diff --git a/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java b/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java index 9c8eda5..c58bd9a 100644 --- a/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java +++ b/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java @@ -16,6 +16,7 @@ import com.gxwebsoft.pwl.service.PwlProjectLibraryService; import com.gxwebsoft.ai.service.AiCloudDocService; import com.gxwebsoft.ai.service.AiCloudFileService; import com.gxwebsoft.ai.service.KnowledgeBaseService; +import com.gxwebsoft.ai.service.impl.AbstractAuditContentService; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.RequestBody; @@ -70,21 +71,23 @@ public abstract class BaseAuditContentController extends BaseController { } request.setHistory(requestHistory); - String kbIdTmp = ""; String libraryKbIds = ""; try { - // 创建临时知识库(如果需要) - if (hasUploadedFiles(request)) { - kbIdTmp = createTempKnowledgeBase(request); - } - // 查询项目库信息 libraryKbIds = getLibraryKbIds(request.getLibraryIds()); - // 生成数据 - String knowledgeBaseId = getKnowledgeBaseId(kbIdTmp, request.getKbIds()); - GenerateParams params = new GenerateParams(knowledgeBaseId, libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion()); + // 如果有docList/fileList,计算去重的fileIds并设置到ThreadLocal + if (hasUploadedFiles(request)) { + Set docIds = request.getDocList().stream().flatMap(docId -> aiCloudDocService.getSelfAndChildren(docId).stream()).map(AiCloudDoc::getId).collect(Collectors.toSet()); + List relatedFiles = getRelatedFiles(docIds, request.getFileList()); + List fileIds = relatedFiles.stream().map(AiCloudFile::getFileId).distinct().collect(Collectors.toList()); + Set mainKbIds = Arrays.stream(request.getKbIds().split(",")).map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet()); + AbstractAuditContentService.setRequestFileIds(mainKbIds, fileIds); + } + + // 生成数据(使用原来的默认知识库) + GenerateParams params = new GenerateParams(request.getKbIds(), libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion()); JSONObject result = generateFunction.apply(params); @@ -100,7 +103,7 @@ public abstract class BaseAuditContentController extends BaseController { log.error("生成表格数据失败,接口: {}", interfaceName, e); return fail("生成表格数据失败: " + e.getMessage()); } finally { - cleanupTempKnowledgeBase(kbIdTmp); + AbstractAuditContentService.clearRequestFileIds(); } } @@ -199,7 +202,9 @@ public abstract class BaseAuditContentController extends BaseController { * 检查是否有上传的文件 */ protected boolean hasUploadedFiles(AuditContentRequest request) { - return !request.getDocList().isEmpty() || !request.getFileList().isEmpty(); + List docList = request.getDocList(); + List fileList = request.getFileList(); + return (docList != null && !docList.isEmpty()) || (fileList != null && !fileList.isEmpty()); } /** diff --git a/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java b/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java index 52e2e30..ccfbb21 100644 --- a/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java +++ b/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java @@ -46,6 +46,27 @@ public abstract class AbstractAuditContentService { // 用于同步的锁对象池 private static final Map kbLocks = new ConcurrentHashMap<>(); + // 当前请求的主知识库ID集合(用于判断是否需要按文件过滤) + private static final ThreadLocal> requestMainKbIds = new ThreadLocal<>(); + // 当前请求需要过滤的文件ID列表 + private static final ThreadLocal> requestFileIds = new ThreadLocal<>(); + + /** + * 设置当前请求的文件ID过滤条件 + */ + public static void setRequestFileIds(Set mainKbIds, List fileIds) { + requestMainKbIds.set(mainKbIds); + requestFileIds.set(fileIds); + } + + /** + * 清除当前请求的文件ID过滤条件 + */ + public static void clearRequestFileIds() { + requestMainKbIds.remove(); + requestFileIds.remove(); + } + /** * 调用工作流通用方法 */ @@ -147,8 +168,6 @@ public abstract class AbstractAuditContentService { return null; } - - /** * 构建工作流请求通用方法 */ @@ -180,8 +199,15 @@ public abstract class AbstractAuditContentService { synchronized (lock) { try { + // 获取当前请求的fileIds(仅对主知识库生效) + List fileIds = null; + Set mainKbIds = requestMainKbIds.get(); + if (mainKbIds != null && mainKbIds.contains(kbId)) { + fileIds = requestFileIds.get(); + } + // 1. 收集所有节点和文档ID - List allNodes = collectKnowledgeNodes(kbId, queries, topK); + List allNodes = collectKnowledgeNodes(kbId, queries, topK, fileIds); if (allNodes.isEmpty()) { return new ArrayList<>(); } @@ -202,14 +228,24 @@ public abstract class AbstractAuditContentService { /** * 收集知识库节点 */ - private List collectKnowledgeNodes(String kbId, List queries, int topK) { + private List collectKnowledgeNodes(String kbId, List queries, int topK, List fileIds) { List allNodes = new ArrayList<>(); String workspaceId = config.getWorkspaceId(); try { Client client = clientFactory.createClient(); for (String query : queries) { try { - RetrieveResponse resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query); + RetrieveResponse resp; + if (fileIds != null && !fileIds.isEmpty()) { + // fileId格式 file_xxxx_yyyy,知识库tag保存的是xxxx部分 + List tags = fileIds.stream() + .map(id -> StrUtil.subBetween(id, "_", "_")) + .filter(StrUtil::isNotBlank) + .collect(Collectors.toList()); + resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query, tags); + } else { + resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query); + } List nodes = Optional.ofNullable(resp) .map(RetrieveResponse::getBody) .map(RetrieveResponseBody::getData)