diff --git a/src/main/java/com/gxwebsoft/ai/service/impl/KnowledgeBaseServiceImpl.java b/src/main/java/com/gxwebsoft/ai/service/impl/KnowledgeBaseServiceImpl.java index 76d10bb..d3b0b72 100644 --- a/src/main/java/com/gxwebsoft/ai/service/impl/KnowledgeBaseServiceImpl.java +++ b/src/main/java/com/gxwebsoft/ai/service/impl/KnowledgeBaseServiceImpl.java @@ -1,207 +1,215 @@ -//package com.gxwebsoft.ai.service.impl; -// -//import com.aliyun.bailian20231229.Client; -//import com.aliyun.bailian20231229.models.CreateIndexResponse; -//import com.aliyun.bailian20231229.models.DeleteIndexDocumentResponse; -//import com.aliyun.bailian20231229.models.DeleteIndexResponse; -//import com.aliyun.bailian20231229.models.ListIndexDocumentsResponse; -//import com.aliyun.bailian20231229.models.ListIndicesResponse; -//import com.aliyun.bailian20231229.models.RetrieveResponse; -//import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes; -//import com.gxwebsoft.ai.config.KnowledgeBaseConfig; -//import com.gxwebsoft.ai.constants.KnowledgeBaseConstants; -//import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; -//import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory; -//import com.gxwebsoft.ai.service.KnowledgeBaseService; -//import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil; -//import com.gxwebsoft.ai.util.KnowledgeBaseUploader; -//import com.gxwebsoft.ai.util.KnowledgeBaseUtil; -//import cn.hutool.core.util.StrUtil; -//import org.springframework.beans.factory.annotation.Autowired; -//import org.springframework.scheduling.annotation.Async; -//import org.springframework.stereotype.Service; -//import org.springframework.web.multipart.MultipartFile; -// -//import java.time.LocalDateTime; -//import java.time.format.DateTimeFormatter; -//import java.util.Arrays; -//import java.util.HashMap; -//import java.util.LinkedHashSet; -//import java.util.List; -//import java.util.Map; -//import java.util.Set; -// -//@Service -//public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { -// -// @Autowired -// private KnowledgeBaseConfig config; -// -// @Autowired -// private KnowledgeBaseClientFactory clientFactory; -// -// @Override -// public Set queryKnowledgeBase(KnowledgeBaseRequest req) { -// return queryKnowledgeBase(req.getKbId(), req.getQuery(), req.getTopK(), req.getFormCommit()); -// } -// -// @Override -// public Set queryKnowledgeBase(String kbId, String query, Integer topK, Integer formCommit) { -// Set result = new LinkedHashSet<>(); -// String workspaceId = config.getWorkspaceId(); -// List keyWords = Arrays.asList(KnowledgeBaseConstants.KEY_WORDS); -// String indexId = kbId; -// String searchQuery = StrUtil.isEmpty(query) ? keyWords.get(formCommit) : query; -// Integer searchTopK = topK == null ? 10 : topK; -// -// try { -// Client client = clientFactory.createClient(); -// RetrieveResponse resp = KnowledgeBaseUtil.retrieveIndex(client, workspaceId, indexId, searchQuery); -// for (RetrieveResponseBodyDataNodes node : resp.getBody().getData().getNodes()) { -// result.add(node.getText()); -// if (result.size() >= searchTopK) { -// break; -// } -// } -// } catch (Exception e) { -// throw new RuntimeException("查询知识库失败: " + e.getMessage(), e); -// } -// return result; -// } -// -// @Override -// public String createKnowledgeBase(String companyName, String companyCode) { -// String workspaceId = config.getWorkspaceId(); -// try { -// String kbId = getKnowledgeBaseIdByName(companyCode); -// if(StrUtil.isNotEmpty(kbId)) { -// return kbId; -// } -// -// Client client = clientFactory.createClient(); -// CreateIndexResponse indexResponse = KnowledgeBaseUtil.createIndex(client, workspaceId, companyCode, companyName); -// return indexResponse.getBody().getData().getId(); -// } catch (Exception e) { -// throw new RuntimeException("创建知识库失败: " + e.getMessage(), e); -// } -// } -// -// @Override -// public String createKnowledgeBaseTemp() { -// String code = "Temp_" + LocalDateTime.now().format(DateTimeFormatter.ofPattern("MMddHHmmssSSS")); -// return createKnowledgeBase(code, code); -// } -// -// @Override -// public boolean existsKnowledgeBase(String companyCode) { -// String workspaceId = config.getWorkspaceId(); -// try { -// Client client = clientFactory.createClient(); -// ListIndicesResponse indicesResponse = KnowledgeBaseUtil.listIndices(client, workspaceId); -// -// return indicesResponse.getBody().getData().getIndices().stream() -// .anyMatch(index -> companyCode.equals(index.getName())); -// } catch (Exception e) { -// throw new RuntimeException("检查知识库是否存在失败: " + e.getMessage(), e); -// } -// } -// -// @Override -// public String getKnowledgeBaseIdByName(String companyCode) { -// String workspaceId = config.getWorkspaceId(); -// try { -// Client client = clientFactory.createClient(); -// ListIndicesResponse indicesResponse = KnowledgeBaseUtil.listIndices(client, workspaceId); -// -// return indicesResponse.getBody().getData().getIndices().stream() -// .filter(index -> companyCode.equals(index.getName())) -// .findFirst() -// .map(index -> index.getId()) -// .orElse(""); -// } catch (Exception e) { -// throw new RuntimeException("查找知识库ID失败: " + e.getMessage(), e); -// } -// } -// -// @Override -// public Map listDocuments(String kbId, Integer pageSize, Integer pageNumber) { -// Map ret = new HashMap<>(); -// String workspaceId = config.getWorkspaceId(); -// try { -// Client client = clientFactory.createClient(); -// ListIndexDocumentsResponse indexDocumentsResponse = KnowledgeBaseUtil.listIndexDocuments(client, workspaceId, kbId, pageSize, pageNumber); -// ret.put("data", indexDocumentsResponse.getBody().getData().getDocuments()); -// ret.put("total", indexDocumentsResponse.getBody().getData().getTotalCount()); -// } catch (Exception e) { -// throw new RuntimeException("查询知识库下的文档列表失败: " + e.getMessage(), e); -// } -// return ret; -// } -// -// @Override -// public boolean deleteIndex(String kbId) { -// String workspaceId = config.getWorkspaceId(); -// try { -// Client client = clientFactory.createClient(); -// DeleteIndexResponse indexDocumentResponse = KnowledgeBaseUtil.deleteIndex(client, workspaceId, kbId); -// return indexDocumentResponse.getBody().getSuccess(); -// } catch (Exception e) { -// throw new RuntimeException("删除知识库失败: " + e.getMessage(), e); -// } -// } -// -// @Override -// public boolean deleteIndexDocument(String kbId, String fileIds) { -// String workspaceId = config.getWorkspaceId(); -// List ids = StrUtil.splitTrim(fileIds, ","); -// try { -// Client client = clientFactory.createClient(); -// DeleteIndexDocumentResponse indexDocumentResponse = KnowledgeBaseUtil.deleteIndexDocument(client, workspaceId, kbId, ids); -// return indexDocumentResponse.getBody().getSuccess(); -// } catch (Exception e) { -// throw new RuntimeException("删除知识库下的文档失败: " + e.getMessage(), e); -// } -// } -// -// @Override -// public boolean uploadDocuments(String kbId, MultipartFile[] files) { -// String workspaceId = config.getWorkspaceId(); -// int count = files.length; -// try { -// Client client = clientFactory.createClient(); -// List fileIds = KnowledgeBaseUploader.uploadDocuments(client, workspaceId, kbId, files); -// //上传切片完成后删除原文档(释放云空间) -// for(String fileId : fileIds) { -// KnowledgeBaseUtil.deleteAppDocument(client, workspaceId, fileId); -// } -// return !fileIds.isEmpty() && fileIds.size() == count; -// } catch (Exception e) { -// throw new RuntimeException("上传文档到知识库失败: " + e.getMessage(), e); -// } -// } -// -// @Async -// @Override -// public void submitDocuments(String kbId, String fileId) { -// String workspaceId = config.getWorkspaceId(); -// try { -// Client client = clientFactory.createClient(); -// AiCloudKnowledgeBaseUtil.submitIndexAddDocumentsJob(client, workspaceId, kbId, fileId); -// } catch (Exception e) { -// throw new RuntimeException("添加文档到知识库失败: " + e.getMessage(), e); -// } -// } -// -// @Override -// public void submitDocuments(String kbId, List fileIds) { -// String workspaceId = config.getWorkspaceId(); -// try { -// Client client = clientFactory.createClient(); -// AiCloudKnowledgeBaseUtil.submitIndexAddDocumentsJob(client, workspaceId, kbId, fileIds); -// } catch (Exception e) { -// throw new RuntimeException("添加文档到知识库失败: " + e.getMessage(), e); -// } -// } -// -// -//} \ No newline at end of file +package com.gxwebsoft.ai.service.impl; + +import com.aliyun.bailian20231229.Client; +import com.aliyun.bailian20231229.models.AddFileResponse; +import com.aliyun.bailian20231229.models.CreateIndexResponse; +import com.aliyun.bailian20231229.models.DeleteIndexDocumentResponse; +import com.aliyun.bailian20231229.models.DeleteIndexResponse; +import com.aliyun.bailian20231229.models.ListIndexDocumentsResponse; +import com.aliyun.bailian20231229.models.ListIndicesResponse; +import com.aliyun.bailian20231229.models.RetrieveResponse; +import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes; +import com.gxwebsoft.ai.config.KnowledgeBaseConfig; +import com.gxwebsoft.ai.constants.KnowledgeBaseConstants; +import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; +import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory; +import com.gxwebsoft.ai.service.KnowledgeBaseService; +import com.gxwebsoft.ai.util.AiCloudDataCenterUtil; +import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil; +import cn.hutool.core.util.StrUtil; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; +import org.springframework.web.multipart.MultipartFile; + +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +@Service +public class KnowledgeBaseServiceImpl implements KnowledgeBaseService { + + @Autowired + private KnowledgeBaseConfig config; + + @Autowired + private KnowledgeBaseClientFactory clientFactory; + + @Override + public Set queryKnowledgeBase(KnowledgeBaseRequest req) { + return queryKnowledgeBase(req.getKbId(), req.getQuery(), req.getTopK(), req.getFormCommit()); + } + + @Override + public Set queryKnowledgeBase(String kbId, String query, Integer topK, Integer formCommit) { + Set result = new LinkedHashSet<>(); + String workspaceId = config.getWorkspaceId(); + List keyWords = Arrays.asList(KnowledgeBaseConstants.KEY_WORDS); + String indexId = kbId; + String searchQuery = StrUtil.isEmpty(query) ? keyWords.get(formCommit) : query; + Integer searchTopK = topK == null ? 10 : topK; + + try { + Client client = clientFactory.createClient(); + RetrieveResponse resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, indexId, searchQuery); + for (RetrieveResponseBodyDataNodes node : resp.getBody().getData().getNodes()) { + result.add(node.getText()); + if (result.size() >= searchTopK) { + break; + } + } + } catch (Exception e) { + throw new RuntimeException("查询知识库失败: " + e.getMessage(), e); + } + return result; + } + + @Override + public String createKnowledgeBase(String companyName, String companyCode) { + String workspaceId = config.getWorkspaceId(); + try { + String kbId = getKnowledgeBaseIdByName(companyCode); + if(StrUtil.isNotEmpty(kbId)) { + return kbId; + } + + Client client = clientFactory.createClient(); + CreateIndexResponse indexResponse = AiCloudKnowledgeBaseUtil.createIndex(client, workspaceId, companyCode, companyName); + return indexResponse.getBody().getData().getId(); + } catch (Exception e) { + throw new RuntimeException("创建知识库失败: " + e.getMessage(), e); + } + } + + @Override + public String createKnowledgeBaseTemp() { + String code = "Temp_" + LocalDateTime.now().format(DateTimeFormatter.ofPattern("MMddHHmmssSSS")); + return createKnowledgeBase(code, code); + } + + @Override + public boolean existsKnowledgeBase(String companyCode) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + ListIndicesResponse indicesResponse = AiCloudKnowledgeBaseUtil.listIndices(client, workspaceId); + + return indicesResponse.getBody().getData().getIndices().stream() + .anyMatch(index -> companyCode.equals(index.getName())); + } catch (Exception e) { + throw new RuntimeException("检查知识库是否存在失败: " + e.getMessage(), e); + } + } + + @Override + public String getKnowledgeBaseIdByName(String companyCode) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + ListIndicesResponse indicesResponse = AiCloudKnowledgeBaseUtil.listIndices(client, workspaceId); + + return indicesResponse.getBody().getData().getIndices().stream() + .filter(index -> companyCode.equals(index.getName())) + .findFirst() + .map(index -> index.getId()) + .orElse(""); + } catch (Exception e) { + throw new RuntimeException("查找知识库ID失败: " + e.getMessage(), e); + } + } + + @Override + public Map listDocuments(String kbId, Integer pageSize, Integer pageNumber) { + Map ret = new HashMap<>(); + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + ListIndexDocumentsResponse indexDocumentsResponse = AiCloudKnowledgeBaseUtil.listIndexDocuments(client, workspaceId, kbId, pageSize, pageNumber); + ret.put("data", indexDocumentsResponse.getBody().getData().getDocuments()); + ret.put("total", indexDocumentsResponse.getBody().getData().getTotalCount()); + } catch (Exception e) { + throw new RuntimeException("查询知识库下的文档列表失败: " + e.getMessage(), e); + } + return ret; + } + + @Override + public boolean deleteIndex(String kbId) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + DeleteIndexResponse indexDocumentResponse = AiCloudKnowledgeBaseUtil.deleteIndex(client, workspaceId, kbId); + return indexDocumentResponse.getBody().getSuccess(); + } catch (Exception e) { + throw new RuntimeException("删除知识库失败: " + e.getMessage(), e); + } + } + + @Override + public boolean deleteIndexDocument(String kbId, String fileIds) { + String workspaceId = config.getWorkspaceId(); + List ids = StrUtil.splitTrim(fileIds, ","); + try { + Client client = clientFactory.createClient(); + DeleteIndexDocumentResponse indexDocumentResponse = AiCloudKnowledgeBaseUtil.deleteIndexDocument(client, workspaceId, kbId, ids); + return indexDocumentResponse.getBody().getSuccess(); + } catch (Exception e) { + throw new RuntimeException("删除知识库下的文档失败: " + e.getMessage(), e); + } + } + + @Override + public boolean uploadDocuments(String kbId, MultipartFile[] files) { + String workspaceId = config.getWorkspaceId(); + int count = files.length; + try { + Client client = clientFactory.createClient(); + + List fileIds = new ArrayList<>(); + for(MultipartFile file : files) { + AddFileResponse addFileResponse = AiCloudDataCenterUtil.uploadFile(client, workspaceId, "", file); + String fileId = addFileResponse.getBody().getData().getFileId(); + fileIds.add(fileId); + } +// List fileIds = AiCloudKnowledgeBaseUtil.uploadDocuments(client, workspaceId, kbId, files); + //上传切片完成后删除原文档(释放云空间) + for(String fileId : fileIds) { + AiCloudKnowledgeBaseUtil.deleteAppDocument(client, workspaceId, fileId); + } + return !fileIds.isEmpty() && fileIds.size() == count; + } catch (Exception e) { + throw new RuntimeException("上传文档到知识库失败: " + e.getMessage(), e); + } + } + + @Async + @Override + public void submitDocuments(String kbId, String fileId) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + AiCloudKnowledgeBaseUtil.submitIndexAddDocumentsJob(client, workspaceId, kbId, fileId); + } catch (Exception e) { + throw new RuntimeException("添加文档到知识库失败: " + e.getMessage(), e); + } + } + + @Override + public void submitDocuments(String kbId, List fileIds) { + String workspaceId = config.getWorkspaceId(); + try { + Client client = clientFactory.createClient(); + AiCloudKnowledgeBaseUtil.submitIndexAddDocumentsJob(client, workspaceId, kbId, fileIds); + } catch (Exception e) { + throw new RuntimeException("添加文档到知识库失败: " + e.getMessage(), e); + } + } + + +} \ No newline at end of file diff --git a/src/main/java/com/gxwebsoft/ai/util/AiCloudDataCenterUtil.java b/src/main/java/com/gxwebsoft/ai/util/AiCloudDataCenterUtil.java index 8917330..c942cc2 100644 --- a/src/main/java/com/gxwebsoft/ai/util/AiCloudDataCenterUtil.java +++ b/src/main/java/com/gxwebsoft/ai/util/AiCloudDataCenterUtil.java @@ -15,7 +15,7 @@ import java.io.InputStream; import java.net.HttpURLConnection; import java.net.URL; import java.security.MessageDigest; -import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -261,6 +261,11 @@ public class AiCloudDataCenterUtil { String fileId = addFileResponse.getBody().getData().getFileId(); waitForFileParsing(client, workspaceId, fileId); + + // fileId的UUID部分做为标签更新 file_df12ed21b7384353bd75868444c516ae_10377381 -> df12ed21b7384353bd75868444c516ae + String tag = StrUtil.subBetween(fileId, "_", "_").substring(0, 32); + updateFileTag(client, workspaceId, fileId, Arrays.asList(tag)); + return addFileResponse; } diff --git a/src/main/java/com/gxwebsoft/ai/util/AiCloudKnowledgeBaseUtil.java b/src/main/java/com/gxwebsoft/ai/util/AiCloudKnowledgeBaseUtil.java index 39f4bbd..eac84e3 100644 --- a/src/main/java/com/gxwebsoft/ai/util/AiCloudKnowledgeBaseUtil.java +++ b/src/main/java/com/gxwebsoft/ai/util/AiCloudKnowledgeBaseUtil.java @@ -68,6 +68,7 @@ public class AiCloudKnowledgeBaseUtil { searchFiltersTags.put("tags", JSON.toJSONString(filesIds)); searchFilters.add(searchFiltersTags); retrieveRequest.setSearchFilters(searchFilters); +// retrieveRequest.setRerankMinScore(null); RuntimeOptions runtime = new RuntimeOptions(); return client.retrieveWithOptions(workspaceId, retrieveRequest, null, runtime); }