优化审计内容生成接口-添加可选文件、目录
This commit is contained in:
@@ -16,6 +16,7 @@ import com.gxwebsoft.pwl.service.PwlProjectLibraryService;
|
||||
import com.gxwebsoft.ai.service.AiCloudDocService;
|
||||
import com.gxwebsoft.ai.service.AiCloudFileService;
|
||||
import com.gxwebsoft.ai.service.KnowledgeBaseService;
|
||||
import com.gxwebsoft.ai.service.impl.AbstractAuditContentService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
@@ -70,21 +71,23 @@ public abstract class BaseAuditContentController extends BaseController {
|
||||
}
|
||||
request.setHistory(requestHistory);
|
||||
|
||||
String kbIdTmp = "";
|
||||
String libraryKbIds = "";
|
||||
|
||||
try {
|
||||
// 创建临时知识库(如果需要)
|
||||
if (hasUploadedFiles(request)) {
|
||||
kbIdTmp = createTempKnowledgeBase(request);
|
||||
}
|
||||
|
||||
// 查询项目库信息
|
||||
libraryKbIds = getLibraryKbIds(request.getLibraryIds());
|
||||
|
||||
// 生成数据
|
||||
String knowledgeBaseId = getKnowledgeBaseId(kbIdTmp, request.getKbIds());
|
||||
GenerateParams params = new GenerateParams(knowledgeBaseId, libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion());
|
||||
// 如果有docList/fileList,计算去重的fileIds并设置到ThreadLocal
|
||||
if (hasUploadedFiles(request)) {
|
||||
Set<Integer> docIds = request.getDocList().stream().flatMap(docId -> aiCloudDocService.getSelfAndChildren(docId).stream()).map(AiCloudDoc::getId).collect(Collectors.toSet());
|
||||
List<AiCloudFile> relatedFiles = getRelatedFiles(docIds, request.getFileList());
|
||||
List<String> fileIds = relatedFiles.stream().map(AiCloudFile::getFileId).distinct().collect(Collectors.toList());
|
||||
Set<String> mainKbIds = Arrays.stream(request.getKbIds().split(",")).map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet());
|
||||
AbstractAuditContentService.setRequestFileIds(mainKbIds, fileIds);
|
||||
}
|
||||
|
||||
// 生成数据(使用原来的默认知识库)
|
||||
GenerateParams params = new GenerateParams(request.getKbIds(), libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion());
|
||||
|
||||
JSONObject result = generateFunction.apply(params);
|
||||
|
||||
@@ -100,7 +103,7 @@ public abstract class BaseAuditContentController extends BaseController {
|
||||
log.error("生成表格数据失败,接口: {}", interfaceName, e);
|
||||
return fail("生成表格数据失败: " + e.getMessage());
|
||||
} finally {
|
||||
cleanupTempKnowledgeBase(kbIdTmp);
|
||||
AbstractAuditContentService.clearRequestFileIds();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -199,7 +202,9 @@ public abstract class BaseAuditContentController extends BaseController {
|
||||
* 检查是否有上传的文件
|
||||
*/
|
||||
protected boolean hasUploadedFiles(AuditContentRequest request) {
|
||||
return !request.getDocList().isEmpty() || !request.getFileList().isEmpty();
|
||||
List<Integer> docList = request.getDocList();
|
||||
List<Integer> fileList = request.getFileList();
|
||||
return (docList != null && !docList.isEmpty()) || (fileList != null && !fileList.isEmpty());
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -46,6 +46,27 @@ public abstract class AbstractAuditContentService {
|
||||
// 用于同步的锁对象池
|
||||
private static final Map<String, Object> kbLocks = new ConcurrentHashMap<>();
|
||||
|
||||
// 当前请求的主知识库ID集合(用于判断是否需要按文件过滤)
|
||||
private static final ThreadLocal<Set<String>> requestMainKbIds = new ThreadLocal<>();
|
||||
// 当前请求需要过滤的文件ID列表
|
||||
private static final ThreadLocal<List<String>> requestFileIds = new ThreadLocal<>();
|
||||
|
||||
/**
|
||||
* 设置当前请求的文件ID过滤条件
|
||||
*/
|
||||
public static void setRequestFileIds(Set<String> mainKbIds, List<String> fileIds) {
|
||||
requestMainKbIds.set(mainKbIds);
|
||||
requestFileIds.set(fileIds);
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除当前请求的文件ID过滤条件
|
||||
*/
|
||||
public static void clearRequestFileIds() {
|
||||
requestMainKbIds.remove();
|
||||
requestFileIds.remove();
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用工作流通用方法
|
||||
*/
|
||||
@@ -147,8 +168,6 @@ public abstract class AbstractAuditContentService {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* 构建工作流请求通用方法
|
||||
*/
|
||||
@@ -180,8 +199,15 @@ public abstract class AbstractAuditContentService {
|
||||
|
||||
synchronized (lock) {
|
||||
try {
|
||||
// 获取当前请求的fileIds(仅对主知识库生效)
|
||||
List<String> fileIds = null;
|
||||
Set<String> mainKbIds = requestMainKbIds.get();
|
||||
if (mainKbIds != null && mainKbIds.contains(kbId)) {
|
||||
fileIds = requestFileIds.get();
|
||||
}
|
||||
|
||||
// 1. 收集所有节点和文档ID
|
||||
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK);
|
||||
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK, fileIds);
|
||||
if (allNodes.isEmpty()) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
@@ -202,14 +228,24 @@ public abstract class AbstractAuditContentService {
|
||||
/**
|
||||
* 收集知识库节点
|
||||
*/
|
||||
private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK) {
|
||||
private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK, List<String> fileIds) {
|
||||
List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
|
||||
String workspaceId = config.getWorkspaceId();
|
||||
try {
|
||||
Client client = clientFactory.createClient();
|
||||
for (String query : queries) {
|
||||
try {
|
||||
RetrieveResponse resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
|
||||
RetrieveResponse resp;
|
||||
if (fileIds != null && !fileIds.isEmpty()) {
|
||||
// fileId格式 file_xxxx_yyyy,知识库tag保存的是xxxx部分
|
||||
List<String> tags = fileIds.stream()
|
||||
.map(id -> StrUtil.subBetween(id, "_", "_"))
|
||||
.filter(StrUtil::isNotBlank)
|
||||
.collect(Collectors.toList());
|
||||
resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query, tags);
|
||||
} else {
|
||||
resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
|
||||
}
|
||||
List<RetrieveResponseBodyDataNodes> nodes = Optional.ofNullable(resp)
|
||||
.map(RetrieveResponse::getBody)
|
||||
.map(RetrieveResponseBody::getData)
|
||||
|
||||
Reference in New Issue
Block a user