From 8ca9d0a18f6615115bcdf8b08590012a968c8372 Mon Sep 17 00:00:00 2001 From: yuance <182865460@qq.com> Date: Thu, 21 May 2026 18:15:57 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=AE=A1=E8=AE=A1=E5=86=85?= =?UTF-8?q?=E5=AE=B9=E7=94=9F=E6=88=90=E6=97=B6=E6=96=87=E4=BB=B6=E8=BF=87?= =?UTF-8?q?=E6=BB=A4=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../BaseAuditContentController.java | 81 +++++++++++++++++-- .../gxwebsoft/ai/dto/AuditContentRequest.java | 2 +- .../AuditContent11HistoryServiceImpl.java | 4 +- 3 files changed, 76 insertions(+), 11 deletions(-) diff --git a/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java b/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java index c5831cb..267d3a1 100644 --- a/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java +++ b/src/main/java/com/gxwebsoft/ai/controller/BaseAuditContentController.java @@ -85,16 +85,68 @@ public abstract class BaseAuditContentController extends BaseController { String libraryKbIds = ""; try { - // 查询项目库信息 + // 查询公共库信息 libraryKbIds = getLibraryKbIds(request.getLibraryIds()); + // 如果有docList/fileList,计算去重的fileIds并设置到ThreadLocal +// if (hasUploadedFiles(request)) { +// Set docIds = request.getDocList().stream().flatMap(docId -> aiCloudDocService.getSelfAndChildren(docId).stream()).map(AiCloudDoc::getId).collect(Collectors.toSet()); +// List relatedFiles = getRelatedFiles(docIds, request.getFileList()); +// List fileIds = relatedFiles.stream().map(AiCloudFile::getFileId).distinct().collect(Collectors.toList()); +// Set mainKbIds = Arrays.stream(request.getKbIds().split(",")).map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet()); +// AbstractAuditContentService.setRequestFileIds(mainKbIds, fileIds); +// } + // 如果有docList/fileList,计算去重的fileIds并设置到ThreadLocal if (hasUploadedFiles(request)) { Set docIds = request.getDocList().stream().flatMap(docId -> aiCloudDocService.getSelfAndChildren(docId).stream()).map(AiCloudDoc::getId).collect(Collectors.toSet()); List relatedFiles = getRelatedFiles(docIds, request.getFileList()); - List fileIds = relatedFiles.stream().map(AiCloudFile::getFileId).distinct().collect(Collectors.toList()); - Set mainKbIds = Arrays.stream(request.getKbIds().split(",")).map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet()); - AbstractAuditContentService.setRequestFileIds(mainKbIds, fileIds); + + // 查询这些文件所属的目录类型 + Set fileDocIds = relatedFiles.stream().map(AiCloudFile::getDocId).collect(Collectors.toSet()); + Map docTypeMap = aiCloudDocService.list( + new LambdaQueryWrapper() + .select(AiCloudDoc::getId, AiCloudDoc::getDocType) + .in(AiCloudDoc::getId, fileDocIds) + ).stream().collect(Collectors.toMap(AiCloudDoc::getId, AiCloudDoc::getDocType)); + + // 区分项目库文件和公共库文件 (docType=3 为公共目录) + List projectFileIds = new ArrayList<>(); + List libraryFileIds = new ArrayList<>(); + for (AiCloudFile file : relatedFiles) { + Integer docType = docTypeMap.get(file.getDocId()); + if (docType != null && docType == 3) { + libraryFileIds.add(file.getFileId()); + } else { + projectFileIds.add(file.getFileId()); + } + } + projectFileIds = projectFileIds.stream().distinct().collect(Collectors.toList()); + libraryFileIds = libraryFileIds.stream().distinct().collect(Collectors.toList()); + + // 准备 kbIds 集合 + Set projectKbIds = Arrays.stream(request.getKbIds().split(",")) + .map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet()); + Set libraryKbIdsSet = Arrays.stream(libraryKbIds.split(",")) + .map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet()); + + // 确定需要过滤的知识库集合 + Set mainKbIds = new HashSet<>(); + List combinedFileIds = new ArrayList<>(); + + if (!projectFileIds.isEmpty()) { + mainKbIds.addAll(projectKbIds); + combinedFileIds.addAll(projectFileIds); + } + if (!libraryFileIds.isEmpty()) { + mainKbIds.addAll(libraryKbIdsSet); + combinedFileIds.addAll(libraryFileIds); + } + + // 设置到ThreadLocal(若mainKbIds为空,表示无需要过滤的文件,可传空集合或跳过) + if (!mainKbIds.isEmpty()) { + AbstractAuditContentService.setRequestFileIds(mainKbIds, combinedFileIds); + } } // 生成数据(使用原来的默认知识库) @@ -228,12 +280,25 @@ public abstract class BaseAuditContentController extends BaseController { /** * 获取项目库KB IDs */ +// protected String getLibraryKbIds(String libraryIds) { +// if (StrUtil.isBlank(libraryIds)) { +// return ""; +// } +// List idList = StrUtil.split(libraryIds, ','); +// List ret = pwlProjectLibraryService.list(new LambdaQueryWrapper().in(PwlProjectLibrary::getId, idList)); +// return ret.stream().map(PwlProjectLibrary::getKbId).filter(StrUtil::isNotBlank).collect(Collectors.joining(",")); +// } + + /** + * 获取公共库KB IDs + */ protected String getLibraryKbIds(String libraryIds) { - if (StrUtil.isBlank(libraryIds)) { - return ""; + LambdaQueryWrapper wrapper = new LambdaQueryWrapper<>(); + if (StrUtil.isNotBlank(libraryIds)) { + List idList = StrUtil.split(libraryIds, ','); + wrapper.in(PwlProjectLibrary::getId, idList); } - List idList = StrUtil.split(libraryIds, ','); - List ret = pwlProjectLibraryService.list(new LambdaQueryWrapper().in(PwlProjectLibrary::getId, idList)); + List ret = pwlProjectLibraryService.list(wrapper); return ret.stream().map(PwlProjectLibrary::getKbId).filter(StrUtil::isNotBlank).collect(Collectors.joining(",")); } diff --git a/src/main/java/com/gxwebsoft/ai/dto/AuditContentRequest.java b/src/main/java/com/gxwebsoft/ai/dto/AuditContentRequest.java index ef95bac..0c7b85a 100644 --- a/src/main/java/com/gxwebsoft/ai/dto/AuditContentRequest.java +++ b/src/main/java/com/gxwebsoft/ai/dto/AuditContentRequest.java @@ -16,7 +16,7 @@ public class AuditContentRequest { private Long projectId; /** - * 企业库 + * 项目库 */ private String kbIds; diff --git a/src/main/java/com/gxwebsoft/ai/service/impl/AuditContent11HistoryServiceImpl.java b/src/main/java/com/gxwebsoft/ai/service/impl/AuditContent11HistoryServiceImpl.java index f75466c..48108f8 100644 --- a/src/main/java/com/gxwebsoft/ai/service/impl/AuditContent11HistoryServiceImpl.java +++ b/src/main/java/com/gxwebsoft/ai/service/impl/AuditContent11HistoryServiceImpl.java @@ -79,7 +79,7 @@ public class AuditContent11HistoryServiceImpl extends AbstractAuditContentServic .addAll(queryKnowledgeBase(kbId, queries, 150))); } - // 审计报告库检索 + // 法律法规库检索 if (StrUtil.isNotBlank(libraryKbIds)) { Arrays.stream(libraryKbIds.split(",")) .map(String::trim) @@ -92,7 +92,7 @@ public class AuditContent11HistoryServiceImpl extends AbstractAuditContentServic }); } - // 法律法规库检索(从项目库) + // 审计报告库检索(从案例库) if (StrUtil.isNotBlank(projectLibrary)) { knowledgeSources.get("regulations").addAll( queryKnowledgeBase(projectLibrary,