diff --git a/src/main/java/com/gxwebsoft/ai/controller/AuditReportController.java b/src/main/java/com/gxwebsoft/ai/controller/AuditReportController.java index 3f6370d..30f7b63 100644 --- a/src/main/java/com/gxwebsoft/ai/controller/AuditReportController.java +++ b/src/main/java/com/gxwebsoft/ai/controller/AuditReportController.java @@ -1,6 +1,5 @@ package com.gxwebsoft.ai.controller; -import java.io.FileOutputStream; import java.io.OutputStream; import java.util.HashMap; import java.util.Map; @@ -20,13 +19,13 @@ import com.gxwebsoft.ai.config.TemplateConfig; import com.gxwebsoft.ai.dto.AuditReportRequest; import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; import com.gxwebsoft.ai.enums.AuditReportEnum; +import com.gxwebsoft.ai.service.KnowledgeBaseService; import com.gxwebsoft.ai.util.AuditReportUtil; import com.gxwebsoft.common.core.web.ApiResult; import com.gxwebsoft.common.core.web.BaseController; import com.gxwebsoft.common.system.entity.User; import cn.afterturn.easypoi.word.WordExportUtil; -import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpUtil; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; @@ -45,7 +44,7 @@ public class AuditReportController extends BaseController { private TemplateConfig templateConfig; @Autowired - private KnowledgeBaseController knowledgeBaseController; + private KnowledgeBaseService knowledgeBaseService; private String invok(String query, String knowledge, String history, String suggestion, String title, String userName) { // 构建请求体 @@ -91,7 +90,7 @@ public class AuditReportController extends BaseController { KnowledgeBaseRequest knowledgeBaseRequest = new KnowledgeBaseRequest(); knowledgeBaseRequest.setKbId(req.getKbId()); knowledgeBaseRequest.setFormCommit((req.getFormCommit() >= 10) ? req.getFormCommit() / 10 : req.getFormCommit()); - String knowledge = knowledgeBaseController.query(knowledgeBaseRequest).getData().toString(); + String knowledge = knowledgeBaseService.queryKnowledgeBase(knowledgeBaseRequest).toString(); String query = AuditReportEnum.getByCode(req.getFormCommit()).getDesc(); // String ret = this.invok(query, knowledge, AuditReportUtil.generateReportContent(req), req.getSuggestion(), loginUser.getUsername()); diff --git a/src/main/java/com/gxwebsoft/ai/controller/KnowledgeBaseController.java b/src/main/java/com/gxwebsoft/ai/controller/KnowledgeBaseController.java index ed819e4..0f7f73f 100644 --- a/src/main/java/com/gxwebsoft/ai/controller/KnowledgeBaseController.java +++ b/src/main/java/com/gxwebsoft/ai/controller/KnowledgeBaseController.java @@ -1,25 +1,15 @@ package com.gxwebsoft.ai.controller; -import com.aliyun.bailian20231229.Client; -import com.aliyun.bailian20231229.models.RetrieveResponse; -import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes; -import com.gxwebsoft.ai.config.KnowledgeBaseConfig; -import com.gxwebsoft.ai.constants.KnowledgeBaseConstants; -import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory; -import com.gxwebsoft.ai.util.KnowledgeBaseRetrieve; import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; +import com.gxwebsoft.ai.service.KnowledgeBaseService; import com.gxwebsoft.common.core.web.ApiResult; import com.gxwebsoft.common.core.web.BaseController; -import cn.hutool.core.util.StrUtil; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; - -import java.util.Arrays; -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Set; +import org.springframework.web.multipart.MultipartFile; @Tag(name = "知识库") @RestController @@ -27,33 +17,80 @@ import java.util.Set; public class KnowledgeBaseController extends BaseController { @Autowired - private KnowledgeBaseConfig config; - - @Autowired - private KnowledgeBaseClientFactory clientFactory; + private KnowledgeBaseService knowledgeBaseService; @Operation(summary = "查询知识库") @GetMapping("/query") public ApiResult query(KnowledgeBaseRequest req) { - Set ret = new LinkedHashSet<>(); - String workspaceId = config.getWorkspaceId(); - List keyWords = Arrays.asList(KnowledgeBaseConstants.KEY_WORDS); - String indexId = req.getKbId(); - String query = StrUtil.isEmpty(req.getQuery()) ? keyWords.get(req.getFormCommit()) : req.getQuery(); - Integer topK = req.getTopK() == null ? 10 : req.getTopK(); - try { - Client client = clientFactory.createClient(); - RetrieveResponse resp = KnowledgeBaseRetrieve.retrieveIndex(client, workspaceId, indexId, query); - for (RetrieveResponseBodyDataNodes node : resp.getBody().getData().getNodes()) { - ret.add(node.getText()); - if (ret.size() >= topK) { - break; - } - } + return success(knowledgeBaseService.queryKnowledgeBase(req)); } catch (Exception e) { return fail("查询失败:" + e.getMessage()); } - return success(ret); + } + + @Operation(summary = "创建知识库") + @PostMapping("/create") + public ApiResult create(@RequestParam String companyName, @RequestParam String companyCode) { + try { + String indexId = knowledgeBaseService.createKnowledgeBase(companyName, companyCode); + return success(indexId); + } catch (Exception e) { + return fail("创建失败:" + e.getMessage()); + } + } + + @Operation(summary = "查询知识库下的文档列表") + @GetMapping("/documents") + public ApiResult listDocuments(String kbId) { + try { + return success(knowledgeBaseService.listDocuments(kbId)); + } catch (Exception e) { + return fail("查询文档列表失败:" + e.getMessage()); + } + } + + @Operation(summary = "删除知识库下的文档") + @DeleteMapping("/document") + public ApiResult deleteDocument(String kbId, String fileIds) { + try { + boolean result = knowledgeBaseService.deleteIndexDocument(kbId, fileIds); + if (result) { + return success("删除成功"); + } else { + return fail("删除失败"); + } + } catch (Exception e) { + return fail("删除文档失败: " + e.getMessage()); + } + } + + @Operation(summary = "上传文档到知识库") + @PostMapping("/upload") + public ApiResult uploadDocuments(@RequestParam String kbId, @RequestParam("files") MultipartFile[] files) { + try { + if (files == null || files.length == 0) { + return fail("请选择要上传的文件"); + } + + for (MultipartFile file : files) { + if (file.isEmpty()) { + return fail("文件不能为空: " + file.getOriginalFilename()); + } + if (file.getSize() > 100 * 1024 * 1024) { + return fail("文件大小不能超过100MB: " + file.getOriginalFilename()); + } + } + + boolean result = knowledgeBaseService.uploadDocuments(kbId, files); + + if (result) { + return success("成功上传 " + files.length + " 个文件"); + } else { + return success("部分文件上传成功"); + } + } catch (Exception e) { + return fail("上传失败: " + e.getMessage()); + } } } \ No newline at end of file diff --git a/src/main/java/com/gxwebsoft/ai/service/KnowledgeBaseService.java b/src/main/java/com/gxwebsoft/ai/service/KnowledgeBaseService.java new file mode 100644 index 0000000..320fe41 --- /dev/null +++ b/src/main/java/com/gxwebsoft/ai/service/KnowledgeBaseService.java @@ -0,0 +1,51 @@ +package com.gxwebsoft.ai.service; + +import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; + +import java.util.List; +import java.util.Set; + +import org.springframework.web.multipart.MultipartFile; + +public interface KnowledgeBaseService { + + /** + * 查询知识库 + */ + Set queryKnowledgeBase(KnowledgeBaseRequest req); + + /** + * 查询知识库(带参数) + */ + Set queryKnowledgeBase(String kbId, String query, Integer topK, Integer formCommit); + + /** + * 创建知识库 + */ + String createKnowledgeBase(String companyName, String companyCode); + + /** + * 检查知识库是否存 + */ + boolean existsKnowledgeBase(String companyCode); + + /** + * 查找知识库ID + */ + String getKnowledgeBaseId(String companyCode); + + /** + * 查询知识库下的文档列表 + */ + List listDocuments(String kbId); + + /** + * 删除知识库下的文档 + */ + boolean deleteIndexDocument(String kbId, String fileIds); + + /** + * 上传知识库文档 + */ + boolean uploadDocuments(String kbId, MultipartFile[] files); +} \ No newline at end of file diff --git a/src/main/java/com/gxwebsoft/ai/service/impl/KnowledgeBaseServiceImpl.java b/src/main/java/com/gxwebsoft/ai/service/impl/KnowledgeBaseServiceImpl.java new file mode 100644 index 0000000..ad253b9 --- /dev/null +++ b/src/main/java/com/gxwebsoft/ai/service/impl/KnowledgeBaseServiceImpl.java @@ -0,0 +1,153 @@ +package com.gxwebsoft.ai.service.impl; + +import com.aliyun.bailian20231229.Client; +import com.aliyun.bailian20231229.models.CreateIndexResponse; +import com.aliyun.bailian20231229.models.DeleteIndexDocumentResponse; +import com.aliyun.bailian20231229.models.ListIndexDocumentsResponse; +import com.aliyun.bailian20231229.models.ListIndicesResponse; +import com.aliyun.bailian20231229.models.RetrieveResponse; +import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes; +import com.gxwebsoft.ai.config.KnowledgeBaseConfig; +import com.gxwebsoft.ai.constants.KnowledgeBaseConstants; +import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; +import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory; +import com.gxwebsoft.ai.service.KnowledgeBaseService; +import com.gxwebsoft.ai.util.KnowledgeBaseUploader; +import com.gxwebsoft.ai.util.KnowledgeBaseUtil; +import cn.hutool.core.util.StrUtil; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.web.multipart.MultipartFile; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +@Service +public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { + + @Autowired + private KnowledgeBaseConfig config; + + @Autowired + private KnowledgeBaseClientFactory clientFactory; + + @Override + public Set queryKnowledgeBase(KnowledgeBaseRequest req) { + return queryKnowledgeBase(req.getKbId(), req.getQuery(), req.getTopK(), req.getFormCommit()); + } + + @Override + public Set queryKnowledgeBase(String kbId, String query, Integer topK, Integer formCommit) { + Set result = new LinkedHashSet<>(); + String workspaceId = config.getWorkspaceId(); + List keyWords = Arrays.asList(KnowledgeBaseConstants.KEY_WORDS); + String indexId = kbId; + String searchQuery = StrUtil.isEmpty(query) ? keyWords.get(formCommit) : query; + Integer searchTopK = topK == null ? 10 : topK; + + try { + Client client = clientFactory.createClient(); + RetrieveResponse resp = KnowledgeBaseUtil.retrieveIndex(client, workspaceId, indexId, searchQuery); + for (RetrieveResponseBodyDataNodes node : resp.getBody().getData().getNodes()) { + result.add(node.getText()); + if (result.size() >= searchTopK) { + break; + } + } + } catch (Exception e) { + throw new RuntimeException("查询知识库失败: " + e.getMessage(), e); + } + return result; + } + + @Override + public String createKnowledgeBase(String companyName, String companyCode) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + CreateIndexResponse indexResponse = KnowledgeBaseUtil.createIndex(client, workspaceId, companyCode, companyName); + return indexResponse.getBody().getData().getId(); + } catch (Exception e) { + throw new RuntimeException("创建知识库失败: " + e.getMessage(), e); + } + } + + @Override + public boolean existsKnowledgeBase(String companyCode) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + ListIndicesResponse indicesResponse = KnowledgeBaseUtil.listIndices(client, workspaceId); + + return indicesResponse.getBody().getData().getIndices().stream() + .anyMatch(index -> companyCode.equals(index.getName())); + } catch (Exception e) { + throw new RuntimeException("检查知识库是否存在失败: " + e.getMessage(), e); + } + } + + @Override + public String getKnowledgeBaseId(String companyCode) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + ListIndicesResponse indicesResponse = KnowledgeBaseUtil.listIndices(client, workspaceId); + + return indicesResponse.getBody().getData().getIndices().stream() + .filter(index -> companyCode.equals(index.getName())) + .findFirst() + .map(index -> index.getId()) + .orElse(""); + } catch (Exception e) { + throw new RuntimeException("查找知识库ID失败: " + e.getMessage(), e); + } + } + + @Override + public List listDocuments(String kbId) { + List ret = new ArrayList<>(); + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + ListIndexDocumentsResponse indexDocumentsResponse = KnowledgeBaseUtil.listIndexDocuments(client, workspaceId, kbId); + ret.addAll(indexDocumentsResponse.getBody().getData().getDocuments()); + } catch (Exception e) { + throw new RuntimeException("查询知识库下的文档列表失败: " + e.getMessage(), e); + } + return ret; + } + + @Override + public boolean deleteIndexDocument(String kbId, String fileIds) { + String workspaceId = config.getWorkspaceId(); + List ids = StrUtil.splitTrim(fileIds, ","); + try { + Client client = clientFactory.createClient(); + DeleteIndexDocumentResponse indexDocumentResponse = KnowledgeBaseUtil.deleteIndexDocument(client, workspaceId, kbId, ids); + return indexDocumentResponse.getBody().getSuccess(); + } catch (Exception e) { + throw new RuntimeException("删除知识库下的文档失败: " + e.getMessage(), e); + } + } + + @Override + public boolean uploadDocuments(String kbId, MultipartFile[] files) { + String workspaceId = config.getWorkspaceId(); + int count = files.length; + try { + Client client = clientFactory.createClient(); + List fileIds = KnowledgeBaseUploader.uploadDocuments(client, workspaceId, kbId, files); + //上传切片完成后删除原文档(释放云空间) + for(String fileId : fileIds) { + KnowledgeBaseUtil.deleteAppDocument(client, workspaceId, fileId); + } + return !fileIds.isEmpty() && fileIds.size() == count; + } catch (Exception e) { + throw new RuntimeException("上传文档到知识库失败: " + e.getMessage(), e); + } + } + +} \ No newline at end of file diff --git a/src/main/java/com/gxwebsoft/ai/util/KnowledgeBaseUploader.java b/src/main/java/com/gxwebsoft/ai/util/KnowledgeBaseUploader.java new file mode 100644 index 0000000..eede89c --- /dev/null +++ b/src/main/java/com/gxwebsoft/ai/util/KnowledgeBaseUploader.java @@ -0,0 +1,303 @@ +package com.gxwebsoft.ai.util; + +import com.aliyun.bailian20231229.Client; +import com.aliyun.bailian20231229.models.*; +import com.aliyun.teautil.models.RuntimeOptions; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.file.Paths; +import java.security.MessageDigest; +import java.util.*; + +import org.springframework.web.multipart.MultipartFile; + +/** + * 知识库上传工具类 + * @author GIIT-YC + * + */ +public class KnowledgeBaseUploader { + + /** + * 上传文档到知识库(直接处理MultipartFile) + * + * @param client 阿里云客户端 + * @param workspaceId 业务空间ID + * @param indexId 知识库ID + * @param file 上传的文件 + * @return 新文档的FileID,失败返回null + */ + public static String uploadDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, MultipartFile file) { + try { + // 准备文档信息 + String fileName = file.getOriginalFilename(); + String fileMd5 = calculateMD5(file.getInputStream()); + String fileSize = String.valueOf(file.getSize()); + + // 申请上传租约 + ApplyFileUploadLeaseRequest leaseRequest = new ApplyFileUploadLeaseRequest() + .setFileName(fileName) + .setMd5(fileMd5) + .setSizeInBytes(fileSize); + + ApplyFileUploadLeaseResponse leaseResponse = client.applyFileUploadLeaseWithOptions( + "default", workspaceId, leaseRequest, new HashMap<>(), new RuntimeOptions()); + + String leaseId = leaseResponse.getBody().getData().getFileUploadLeaseId(); + String uploadUrl = leaseResponse.getBody().getData().getParam().getUrl(); + + // 上传文件 + ObjectMapper mapper = new ObjectMapper(); + Map headers = mapper.readValue(mapper.writeValueAsString(leaseResponse.getBody().getData().getParam().getHeaders()), Map.class); + + uploadFile(uploadUrl, headers, file); + + // 添加文件到类目 + AddFileRequest addRequest = new AddFileRequest() + .setLeaseId(leaseId) + .setParser("DASHSCOPE_DOCMIND") + .setCategoryId("default"); + + AddFileResponse addResponse = client.addFileWithOptions(workspaceId, addRequest, new HashMap<>(), new RuntimeOptions()); + + String fileId = addResponse.getBody().getData().getFileId(); + + // 等待文件解析完成 + waitForFileParsing(client, workspaceId, fileId); + + // 添加到知识库 + SubmitIndexAddDocumentsJobRequest indexRequest = new SubmitIndexAddDocumentsJobRequest() + .setIndexId(indexId) + .setDocumentIds(Collections.singletonList(fileId)) + .setSourceType("DATA_CENTER_FILE"); + + SubmitIndexAddDocumentsJobResponse indexResponse = client.submitIndexAddDocumentsJobWithOptions(workspaceId, indexRequest, new HashMap<>(), new RuntimeOptions()); + + // 等待索引完成 + waitForIndexJob(client, workspaceId, indexResponse.getBody().getData().getId(), indexId); + + return fileId; + + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + /** + * 批量上传文档到知识库 + */ + public static List uploadDocuments(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, MultipartFile[] files) { + List fileIds = new ArrayList<>(); + for (MultipartFile file : files) { + String fileId = uploadDocument(client, workspaceId, indexId, file); + if (fileId != null) { + fileIds.add(fileId); + } + } + return fileIds; + } + + /** + * 上传文档到知识库 + * + * @param client 阿里云客户端 + * @param workspaceId 业务空间ID + * @param indexId 知识库ID + * @param filePath 文档本地路径 + * @return 新文档的FileID,失败返回null + */ + public static String uploadDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, String filePath) { + try { + // 准备文档信息 + String fileName = Paths.get(filePath).getFileName().toString(); + String fileMd5 = calculateMD5(filePath); + String fileSize = String.valueOf(new File(filePath).length()); + + // 申请上传租约 + ApplyFileUploadLeaseRequest leaseRequest = new ApplyFileUploadLeaseRequest() + .setFileName(fileName) + .setMd5(fileMd5) + .setSizeInBytes(fileSize); + + ApplyFileUploadLeaseResponse leaseResponse = client.applyFileUploadLeaseWithOptions( + "default", workspaceId, leaseRequest, new HashMap<>(), new RuntimeOptions()); + + String leaseId = leaseResponse.getBody().getData().getFileUploadLeaseId(); + String uploadUrl = leaseResponse.getBody().getData().getParam().getUrl(); + + // 上传文件 + ObjectMapper mapper = new ObjectMapper(); + Map headers = mapper.readValue( + mapper.writeValueAsString(leaseResponse.getBody().getData().getParam().getHeaders()), + Map.class); + + uploadFile(uploadUrl, headers, filePath); + + // 添加文件到类目 + AddFileRequest addRequest = new AddFileRequest() + .setLeaseId(leaseId) + .setParser("DASHSCOPE_DOCMIND") + .setCategoryId("default"); + + AddFileResponse addResponse = client.addFileWithOptions( + workspaceId, addRequest, new HashMap<>(), new RuntimeOptions()); + + String fileId = addResponse.getBody().getData().getFileId(); + + // 等待文件解析完成 + waitForFileParsing(client, workspaceId, fileId); + + // 添加到知识库 + SubmitIndexAddDocumentsJobRequest indexRequest = new SubmitIndexAddDocumentsJobRequest() + .setIndexId(indexId) + .setDocumentIds(Collections.singletonList(fileId)) + .setSourceType("DATA_CENTER_FILE"); + + SubmitIndexAddDocumentsJobResponse indexResponse = client.submitIndexAddDocumentsJobWithOptions( + workspaceId, indexRequest, new HashMap<>(), new RuntimeOptions()); + + // 等待索引完成 + waitForIndexJob(client, workspaceId, indexResponse.getBody().getData().getId(), indexId); + + return fileId; + + } catch (Exception e) { + e.printStackTrace(); + return null; + } + } + + private static String calculateMD5(String filePath) throws Exception { + MessageDigest md = MessageDigest.getInstance("MD5"); + try (FileInputStream fis = new FileInputStream(filePath)) { + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = fis.read(buffer)) != -1) { + md.update(buffer, 0, bytesRead); + } + } + StringBuilder sb = new StringBuilder(); + for (byte b : md.digest()) { + sb.append(String.format("%02x", b & 0xff)); + } + return sb.toString(); + } + + private static void uploadFile(String preSignedUrl, Map headers, + String filePath) throws Exception { + try (FileInputStream fis = new FileInputStream(filePath)) { + HttpURLConnection conn = (HttpURLConnection) new URL(preSignedUrl).openConnection(); + conn.setRequestMethod("PUT"); + conn.setDoOutput(true); + conn.setRequestProperty("X-bailian-extra", headers.get("X-bailian-extra")); + conn.setRequestProperty("Content-Type", headers.get("Content-Type")); + + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = fis.read(buffer)) != -1) { + conn.getOutputStream().write(buffer, 0, bytesRead); + } + + if (conn.getResponseCode() != 200) { + throw new RuntimeException("上传失败: " + conn.getResponseCode()); + } + } + } + + private static void waitForFileParsing(com.aliyun.bailian20231229.Client client, + String workspaceId, String fileId) throws Exception { + while (true) { + DescribeFileResponse response = client.describeFileWithOptions( + workspaceId, fileId, new HashMap<>(), new RuntimeOptions()); + + String status = response.getBody().getData().getStatus(); + if ("PARSE_SUCCESS".equals(status)) break; + if ("PARSE_FAILED".equals(status)) throw new RuntimeException("文档解析失败"); + Thread.sleep(5000); + } + } + + private static void waitForIndexJob(com.aliyun.bailian20231229.Client client, + String workspaceId, String jobId, String indexId) throws Exception { + while (true) { + GetIndexJobStatusRequest request = new GetIndexJobStatusRequest() + .setIndexId(indexId) + .setJobId(jobId); + + GetIndexJobStatusResponse response = client.getIndexJobStatusWithOptions( + workspaceId, request, new HashMap<>(), new RuntimeOptions()); + + String status = response.getBody().getData().getStatus(); + if ("COMPLETED".equals(status)) break; + if ("FAILED".equals(status)) throw new RuntimeException("索引任务失败"); + Thread.sleep(5000); + } + } + + private static String calculateMD5(InputStream inputStream) throws Exception { + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + md.update(buffer, 0, bytesRead); + } + StringBuilder sb = new StringBuilder(); + for (byte b : md.digest()) { + sb.append(String.format("%02x", b & 0xff)); + } + return sb.toString(); + } + + private static void uploadFile(String preSignedUrl, Map headers, + MultipartFile file) throws Exception { + HttpURLConnection conn = (HttpURLConnection) new URL(preSignedUrl).openConnection(); + conn.setRequestMethod("PUT"); + conn.setDoOutput(true); + conn.setRequestProperty("X-bailian-extra", headers.get("X-bailian-extra")); + conn.setRequestProperty("Content-Type", headers.get("Content-Type")); + + try (InputStream inputStream = file.getInputStream()) { + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + conn.getOutputStream().write(buffer, 0, bytesRead); + } + } + + if (conn.getResponseCode() != 200) { + throw new RuntimeException("上传失败: " + conn.getResponseCode()); + } + } + + /** + * 初始化客户端(Client)。 + * + * @return 配置好的客户端对象 + */ + public static com.aliyun.bailian20231229.Client createClient(String accessKeyId, String accessKeySecret) throws Exception { + com.aliyun.teaopenapi.models.Config config = new com.aliyun.teaopenapi.models.Config() + .setAccessKeyId(accessKeyId) + .setAccessKeySecret(accessKeySecret); + // 下方接入地址以公有云的公网接入地址为例,可按需更换接入地址。 + config.endpoint = "bailian.cn-beijing.aliyuncs.com"; + return new com.aliyun.bailian20231229.Client(config); + } + + public static void main(String[] args) throws Exception { + String ALIBABA_CLOUD_ACCESS_KEY_ID = "LTAI5tD5YRKuxWz6Eg7qrM4P"; + String ALIBABA_CLOUD_ACCESS_KEY_SECRET = "bO8TBDXflOwbtSKimPpG8XrJnyzgTk"; + String WORKSPACE_ID = "llm-4pf5auwewoz34zqu"; + String indexId = "b9pvwfqp3d"; + String filePath = "D:\\公司经济责任审计方案模板.docx"; + + Client client = createClient(ALIBABA_CLOUD_ACCESS_KEY_ID, ALIBABA_CLOUD_ACCESS_KEY_SECRET); + + uploadDocument(client, WORKSPACE_ID, indexId, filePath); + } +} \ No newline at end of file diff --git a/src/main/java/com/gxwebsoft/ai/util/KnowledgeBaseUtil.java b/src/main/java/com/gxwebsoft/ai/util/KnowledgeBaseUtil.java new file mode 100644 index 0000000..76a028d --- /dev/null +++ b/src/main/java/com/gxwebsoft/ai/util/KnowledgeBaseUtil.java @@ -0,0 +1,134 @@ +package com.gxwebsoft.ai.util; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import com.aliyun.bailian20231229.models.CreateIndexResponse; +import com.aliyun.bailian20231229.models.DeleteFileResponse; +import com.aliyun.bailian20231229.models.DeleteIndexDocumentResponse; +import com.aliyun.bailian20231229.models.DeleteIndexResponse; +import com.aliyun.bailian20231229.models.ListIndexDocumentsResponse; +import com.aliyun.bailian20231229.models.ListIndicesResponse; +import com.aliyun.bailian20231229.models.RetrieveRequest; +import com.aliyun.bailian20231229.models.RetrieveResponse; +import com.aliyun.teautil.models.RuntimeOptions; + +/** + * 知识库工具类 + * @author GIIT-YC + * + */ +public class KnowledgeBaseUtil { + + /** + * 在指定的知识库中检索信息。 + * + * @param client 客户端对象(bailian20231229Client) + * @param workspaceId 业务空间ID + * @param indexId 知识库ID + * @param query 检索查询语句 + * @return 阿里云百炼服务的响应 + */ + public static RetrieveResponse retrieveIndex(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, String query) throws Exception { + RetrieveRequest retrieveRequest = new RetrieveRequest(); + retrieveRequest.setIndexId(indexId); + retrieveRequest.setQuery(query); + retrieveRequest.setDenseSimilarityTopK(null); + RuntimeOptions runtime = new RuntimeOptions(); + return client.retrieveWithOptions(workspaceId, retrieveRequest, null, runtime); + } + + /** + * 在阿里云百炼服务中创建知识库(初始化)。 + * + * @param client 客户端对象 + * @param workspaceId 业务空间ID + * @param name 知识库名称 + * @param desc 知识库描述 + * @return 阿里云百炼服务的响应对象 + */ + public static CreateIndexResponse createIndex(com.aliyun.bailian20231229.Client client, String workspaceId, String name, String desc) throws Exception { + Map headers = new HashMap<>(); + com.aliyun.bailian20231229.models.CreateIndexRequest createIndexRequest = new com.aliyun.bailian20231229.models.CreateIndexRequest(); + createIndexRequest.setStructureType("unstructured"); + createIndexRequest.setName(name); + createIndexRequest.setDescription(desc); + createIndexRequest.setSinkType("DEFAULT"); + createIndexRequest.setEmbeddingModelName("text-embedding-v4"); + com.aliyun.teautil.models.RuntimeOptions runtime = new com.aliyun.teautil.models.RuntimeOptions(); + return client.createIndexWithOptions(workspaceId, createIndexRequest, headers, runtime); + } + + /** + * 获取指定业务空间下一个或多个知识库的详细信息 + * + * @param client 客户端(Client) + * @param workspaceId 业务空间ID + * @return 阿里云百炼服务的响应 + */ + public static ListIndicesResponse listIndices(com.aliyun.bailian20231229.Client client, String workspaceId) throws Exception { + Map headers = new HashMap<>(); + com.aliyun.bailian20231229.models.ListIndicesRequest listIndicesRequest = new com.aliyun.bailian20231229.models.ListIndicesRequest(); + com.aliyun.teautil.models.RuntimeOptions runtime = new com.aliyun.teautil.models.RuntimeOptions(); + return client.listIndicesWithOptions(workspaceId, listIndicesRequest, headers, runtime); + } + + /** + * 永久性删除指定的知识库 + * + * @param client 客户端(Client) + * @param workspaceId 业务空间ID + * @param indexId 知识库ID + * @return 阿里云百炼服务的响应 + */ + public static DeleteIndexResponse deleteIndex(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId) throws Exception { + Map headers = new HashMap<>(); + com.aliyun.bailian20231229.models.DeleteIndexRequest deleteIndexRequest = new com.aliyun.bailian20231229.models.DeleteIndexRequest(); + deleteIndexRequest.setIndexId(indexId); + com.aliyun.teautil.models.RuntimeOptions runtime = new com.aliyun.teautil.models.RuntimeOptions(); + return client.deleteIndexWithOptions(workspaceId, deleteIndexRequest, headers, runtime); + } + + /** + * 查询知识库下的文档列表 + * + * @param client 客户端(Client) + * @param workspaceId 业务空间ID + * @param indexId 知识库ID + * @return 阿里云百炼服务的响应 + */ + public static ListIndexDocumentsResponse listIndexDocuments(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId) throws Exception { + com.aliyun.bailian20231229.models.ListIndexDocumentsRequest listIndexDocumentsRequest = new com.aliyun.bailian20231229.models.ListIndexDocumentsRequest(); + listIndexDocumentsRequest.setIndexId(indexId); + return client.listIndexDocuments(workspaceId, listIndexDocumentsRequest); + } + + /** + * 删除知识库下的文档 + * + * @param client 客户端(Client) + * @param workspaceId 业务空间ID + * @param indexId 知识库ID + * @param ids 删除文件ID列表 + * @return 阿里云百炼服务的响应 + */ + public static DeleteIndexDocumentResponse deleteIndexDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, List ids) throws Exception { + com.aliyun.bailian20231229.models.DeleteIndexDocumentRequest deleteIndexDocumentRequest = new com.aliyun.bailian20231229.models.DeleteIndexDocumentRequest(); + deleteIndexDocumentRequest.setIndexId(indexId); + deleteIndexDocumentRequest.setDocumentIds(ids); + return client.deleteIndexDocument(workspaceId, deleteIndexDocumentRequest); + } + + /** + * 删除阿里云应用数据文档 + * + * @param client 客户端(Client) + * @param workspaceId 业务空间ID + * @param fileId 删除文件ID + * @return 阿里云百炼服务的响应 + */ + public static DeleteFileResponse deleteAppDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String fileId) throws Exception { + return client.deleteFile(fileId, workspaceId); + } +} diff --git a/src/main/java/com/gxwebsoft/oa/controller/OaCompanyController.java b/src/main/java/com/gxwebsoft/oa/controller/OaCompanyController.java index acfa344..d37c879 100644 --- a/src/main/java/com/gxwebsoft/oa/controller/OaCompanyController.java +++ b/src/main/java/com/gxwebsoft/oa/controller/OaCompanyController.java @@ -11,9 +11,12 @@ import com.gxwebsoft.common.core.web.ApiResult; import com.gxwebsoft.common.core.web.PageResult; import com.gxwebsoft.common.core.web.PageParam; import com.gxwebsoft.common.core.web.BatchParam; +import com.gxwebsoft.ai.service.KnowledgeBaseService; import com.gxwebsoft.common.core.annotation.OperationLog; import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.Operation; + +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.web.bind.annotation.*; @@ -32,6 +35,9 @@ import java.util.List; public class OaCompanyController extends BaseController { @Resource private OaCompanyService oaCompanyService; + + @Autowired + private KnowledgeBaseService knowledgeBaseService; @Operation(summary = "分页查询企业信息") @GetMapping("/page") @@ -63,12 +69,16 @@ public class OaCompanyController extends BaseController { @Operation(summary = "添加企业信息") @PostMapping() public ApiResult save(@RequestBody OaCompany oaCompany) { + if(StrUtil.isEmpty(oaCompany.getCompanyCode())) { + return fail("单位唯一标识不能为空"); + } if (oaCompanyService.save(oaCompany)) { - //TODO 查询知识库(kb_name=enterprise.getCreditCode) - - //TODO 新建知识库 - String kbId = "pggi9mpair"; - + //查询知识库 + String kbId = knowledgeBaseService.getKnowledgeBaseId(oaCompany.getCompanyCode()); + //新建知识库 + if(StrUtil.isEmpty(kbId)) { + kbId = knowledgeBaseService.createKnowledgeBase(oaCompany.getCompanyName(), oaCompany.getCompanyCode()); + } //绑定知识库 oaCompany.setKbId(kbId); oaCompanyService.updateById(oaCompany); @@ -81,12 +91,16 @@ public class OaCompanyController extends BaseController { @Operation(summary = "修改企业信息") @PutMapping() public ApiResult update(@RequestBody OaCompany oaCompany) { + if(StrUtil.isEmpty(oaCompany.getCompanyCode())) { + return fail("单位唯一标识不能为空"); + } if(StrUtil.isEmpty(oaCompany.getKbId())) { - //TODO 查询知识库 - - //TODO 新建知识库 - String kbId = "pggi9mpair"; - + //查询知识库 + String kbId = knowledgeBaseService.getKnowledgeBaseId(oaCompany.getCompanyCode()); + //新建知识库 + if(StrUtil.isEmpty(kbId)) { + kbId = knowledgeBaseService.createKnowledgeBase(oaCompany.getCompanyName(), oaCompany.getCompanyCode()); + } //绑定知识库 oaCompany.setKbId(kbId); }