优化审计内容生成接口-添加可选文件、目录

This commit is contained in:
2026-04-29 15:59:30 +08:00
parent d3833cfd0f
commit e583f8e352
2 changed files with 57 additions and 16 deletions

View File

@@ -16,6 +16,7 @@ import com.gxwebsoft.pwl.service.PwlProjectLibraryService;
import com.gxwebsoft.ai.service.AiCloudDocService;
import com.gxwebsoft.ai.service.AiCloudFileService;
import com.gxwebsoft.ai.service.KnowledgeBaseService;
import com.gxwebsoft.ai.service.impl.AbstractAuditContentService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.RequestBody;
@@ -70,21 +71,23 @@ public abstract class BaseAuditContentController extends BaseController {
}
request.setHistory(requestHistory);
String kbIdTmp = "";
String libraryKbIds = "";
try {
// 创建临时知识库(如果需要)
if (hasUploadedFiles(request)) {
kbIdTmp = createTempKnowledgeBase(request);
}
// 查询项目库信息
libraryKbIds = getLibraryKbIds(request.getLibraryIds());
// 生成数据
String knowledgeBaseId = getKnowledgeBaseId(kbIdTmp, request.getKbIds());
GenerateParams params = new GenerateParams(knowledgeBaseId, libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion());
// 如果有docList/fileList计算去重的fileIds并设置到ThreadLocal
if (hasUploadedFiles(request)) {
Set<Integer> docIds = request.getDocList().stream().flatMap(docId -> aiCloudDocService.getSelfAndChildren(docId).stream()).map(AiCloudDoc::getId).collect(Collectors.toSet());
List<AiCloudFile> relatedFiles = getRelatedFiles(docIds, request.getFileList());
List<String> fileIds = relatedFiles.stream().map(AiCloudFile::getFileId).distinct().collect(Collectors.toList());
Set<String> mainKbIds = Arrays.stream(request.getKbIds().split(",")).map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet());
AbstractAuditContentService.setRequestFileIds(mainKbIds, fileIds);
}
// 生成数据(使用原来的默认知识库)
GenerateParams params = new GenerateParams(request.getKbIds(), libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion());
JSONObject result = generateFunction.apply(params);
@@ -100,7 +103,7 @@ public abstract class BaseAuditContentController extends BaseController {
log.error("生成表格数据失败,接口: {}", interfaceName, e);
return fail("生成表格数据失败: " + e.getMessage());
} finally {
cleanupTempKnowledgeBase(kbIdTmp);
AbstractAuditContentService.clearRequestFileIds();
}
}
@@ -199,7 +202,9 @@ public abstract class BaseAuditContentController extends BaseController {
* 检查是否有上传的文件
*/
protected boolean hasUploadedFiles(AuditContentRequest request) {
return !request.getDocList().isEmpty() || !request.getFileList().isEmpty();
List<Integer> docList = request.getDocList();
List<Integer> fileList = request.getFileList();
return (docList != null && !docList.isEmpty()) || (fileList != null && !fileList.isEmpty());
}
/**

View File

@@ -46,6 +46,27 @@ public abstract class AbstractAuditContentService {
// 用于同步的锁对象池
private static final Map<String, Object> kbLocks = new ConcurrentHashMap<>();
// 当前请求的主知识库ID集合用于判断是否需要按文件过滤
private static final ThreadLocal<Set<String>> requestMainKbIds = new ThreadLocal<>();
// 当前请求需要过滤的文件ID列表
private static final ThreadLocal<List<String>> requestFileIds = new ThreadLocal<>();
/**
* 设置当前请求的文件ID过滤条件
*/
public static void setRequestFileIds(Set<String> mainKbIds, List<String> fileIds) {
requestMainKbIds.set(mainKbIds);
requestFileIds.set(fileIds);
}
/**
* 清除当前请求的文件ID过滤条件
*/
public static void clearRequestFileIds() {
requestMainKbIds.remove();
requestFileIds.remove();
}
/**
* 调用工作流通用方法
*/
@@ -147,8 +168,6 @@ public abstract class AbstractAuditContentService {
return null;
}
/**
* 构建工作流请求通用方法
*/
@@ -180,8 +199,15 @@ public abstract class AbstractAuditContentService {
synchronized (lock) {
try {
// 获取当前请求的fileIds仅对主知识库生效
List<String> fileIds = null;
Set<String> mainKbIds = requestMainKbIds.get();
if (mainKbIds != null && mainKbIds.contains(kbId)) {
fileIds = requestFileIds.get();
}
// 1. 收集所有节点和文档ID
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK);
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK, fileIds);
if (allNodes.isEmpty()) {
return new ArrayList<>();
}
@@ -202,14 +228,24 @@ public abstract class AbstractAuditContentService {
/**
* 收集知识库节点
*/
private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK) {
private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK, List<String> fileIds) {
List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
String workspaceId = config.getWorkspaceId();
try {
Client client = clientFactory.createClient();
for (String query : queries) {
try {
RetrieveResponse resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
RetrieveResponse resp;
if (fileIds != null && !fileIds.isEmpty()) {
// fileId格式 file_xxxx_yyyy知识库tag保存的是xxxx部分
List<String> tags = fileIds.stream()
.map(id -> StrUtil.subBetween(id, "_", "_"))
.filter(StrUtil::isNotBlank)
.collect(Collectors.toList());
resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query, tags);
} else {
resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
}
List<RetrieveResponseBodyDataNodes> nodes = Optional.ofNullable(resp)
.map(RetrieveResponse::getBody)
.map(RetrieveResponseBody::getData)