优化审计内容生成接口-添加可选文件、目录
This commit is contained in:
@@ -16,6 +16,7 @@ import com.gxwebsoft.pwl.service.PwlProjectLibraryService;
|
|||||||
import com.gxwebsoft.ai.service.AiCloudDocService;
|
import com.gxwebsoft.ai.service.AiCloudDocService;
|
||||||
import com.gxwebsoft.ai.service.AiCloudFileService;
|
import com.gxwebsoft.ai.service.AiCloudFileService;
|
||||||
import com.gxwebsoft.ai.service.KnowledgeBaseService;
|
import com.gxwebsoft.ai.service.KnowledgeBaseService;
|
||||||
|
import com.gxwebsoft.ai.service.impl.AbstractAuditContentService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.web.bind.annotation.RequestBody;
|
import org.springframework.web.bind.annotation.RequestBody;
|
||||||
@@ -70,21 +71,23 @@ public abstract class BaseAuditContentController extends BaseController {
|
|||||||
}
|
}
|
||||||
request.setHistory(requestHistory);
|
request.setHistory(requestHistory);
|
||||||
|
|
||||||
String kbIdTmp = "";
|
|
||||||
String libraryKbIds = "";
|
String libraryKbIds = "";
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 创建临时知识库(如果需要)
|
|
||||||
if (hasUploadedFiles(request)) {
|
|
||||||
kbIdTmp = createTempKnowledgeBase(request);
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查询项目库信息
|
// 查询项目库信息
|
||||||
libraryKbIds = getLibraryKbIds(request.getLibraryIds());
|
libraryKbIds = getLibraryKbIds(request.getLibraryIds());
|
||||||
|
|
||||||
// 生成数据
|
// 如果有docList/fileList,计算去重的fileIds并设置到ThreadLocal
|
||||||
String knowledgeBaseId = getKnowledgeBaseId(kbIdTmp, request.getKbIds());
|
if (hasUploadedFiles(request)) {
|
||||||
GenerateParams params = new GenerateParams(knowledgeBaseId, libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion());
|
Set<Integer> docIds = request.getDocList().stream().flatMap(docId -> aiCloudDocService.getSelfAndChildren(docId).stream()).map(AiCloudDoc::getId).collect(Collectors.toSet());
|
||||||
|
List<AiCloudFile> relatedFiles = getRelatedFiles(docIds, request.getFileList());
|
||||||
|
List<String> fileIds = relatedFiles.stream().map(AiCloudFile::getFileId).distinct().collect(Collectors.toList());
|
||||||
|
Set<String> mainKbIds = Arrays.stream(request.getKbIds().split(",")).map(String::trim).filter(StrUtil::isNotBlank).collect(Collectors.toSet());
|
||||||
|
AbstractAuditContentService.setRequestFileIds(mainKbIds, fileIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 生成数据(使用原来的默认知识库)
|
||||||
|
GenerateParams params = new GenerateParams(request.getKbIds(), libraryKbIds, request.getProjectLibrary(), loginUser.getUsername(), request.getHistory(), request.getSuggestion());
|
||||||
|
|
||||||
JSONObject result = generateFunction.apply(params);
|
JSONObject result = generateFunction.apply(params);
|
||||||
|
|
||||||
@@ -100,7 +103,7 @@ public abstract class BaseAuditContentController extends BaseController {
|
|||||||
log.error("生成表格数据失败,接口: {}", interfaceName, e);
|
log.error("生成表格数据失败,接口: {}", interfaceName, e);
|
||||||
return fail("生成表格数据失败: " + e.getMessage());
|
return fail("生成表格数据失败: " + e.getMessage());
|
||||||
} finally {
|
} finally {
|
||||||
cleanupTempKnowledgeBase(kbIdTmp);
|
AbstractAuditContentService.clearRequestFileIds();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,7 +202,9 @@ public abstract class BaseAuditContentController extends BaseController {
|
|||||||
* 检查是否有上传的文件
|
* 检查是否有上传的文件
|
||||||
*/
|
*/
|
||||||
protected boolean hasUploadedFiles(AuditContentRequest request) {
|
protected boolean hasUploadedFiles(AuditContentRequest request) {
|
||||||
return !request.getDocList().isEmpty() || !request.getFileList().isEmpty();
|
List<Integer> docList = request.getDocList();
|
||||||
|
List<Integer> fileList = request.getFileList();
|
||||||
|
return (docList != null && !docList.isEmpty()) || (fileList != null && !fileList.isEmpty());
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -46,6 +46,27 @@ public abstract class AbstractAuditContentService {
|
|||||||
// 用于同步的锁对象池
|
// 用于同步的锁对象池
|
||||||
private static final Map<String, Object> kbLocks = new ConcurrentHashMap<>();
|
private static final Map<String, Object> kbLocks = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
// 当前请求的主知识库ID集合(用于判断是否需要按文件过滤)
|
||||||
|
private static final ThreadLocal<Set<String>> requestMainKbIds = new ThreadLocal<>();
|
||||||
|
// 当前请求需要过滤的文件ID列表
|
||||||
|
private static final ThreadLocal<List<String>> requestFileIds = new ThreadLocal<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置当前请求的文件ID过滤条件
|
||||||
|
*/
|
||||||
|
public static void setRequestFileIds(Set<String> mainKbIds, List<String> fileIds) {
|
||||||
|
requestMainKbIds.set(mainKbIds);
|
||||||
|
requestFileIds.set(fileIds);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清除当前请求的文件ID过滤条件
|
||||||
|
*/
|
||||||
|
public static void clearRequestFileIds() {
|
||||||
|
requestMainKbIds.remove();
|
||||||
|
requestFileIds.remove();
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 调用工作流通用方法
|
* 调用工作流通用方法
|
||||||
*/
|
*/
|
||||||
@@ -147,8 +168,6 @@ public abstract class AbstractAuditContentService {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 构建工作流请求通用方法
|
* 构建工作流请求通用方法
|
||||||
*/
|
*/
|
||||||
@@ -180,8 +199,15 @@ public abstract class AbstractAuditContentService {
|
|||||||
|
|
||||||
synchronized (lock) {
|
synchronized (lock) {
|
||||||
try {
|
try {
|
||||||
|
// 获取当前请求的fileIds(仅对主知识库生效)
|
||||||
|
List<String> fileIds = null;
|
||||||
|
Set<String> mainKbIds = requestMainKbIds.get();
|
||||||
|
if (mainKbIds != null && mainKbIds.contains(kbId)) {
|
||||||
|
fileIds = requestFileIds.get();
|
||||||
|
}
|
||||||
|
|
||||||
// 1. 收集所有节点和文档ID
|
// 1. 收集所有节点和文档ID
|
||||||
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK);
|
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK, fileIds);
|
||||||
if (allNodes.isEmpty()) {
|
if (allNodes.isEmpty()) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
@@ -202,14 +228,24 @@ public abstract class AbstractAuditContentService {
|
|||||||
/**
|
/**
|
||||||
* 收集知识库节点
|
* 收集知识库节点
|
||||||
*/
|
*/
|
||||||
private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK) {
|
private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK, List<String> fileIds) {
|
||||||
List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
|
List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
|
||||||
String workspaceId = config.getWorkspaceId();
|
String workspaceId = config.getWorkspaceId();
|
||||||
try {
|
try {
|
||||||
Client client = clientFactory.createClient();
|
Client client = clientFactory.createClient();
|
||||||
for (String query : queries) {
|
for (String query : queries) {
|
||||||
try {
|
try {
|
||||||
RetrieveResponse resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
|
RetrieveResponse resp;
|
||||||
|
if (fileIds != null && !fileIds.isEmpty()) {
|
||||||
|
// fileId格式 file_xxxx_yyyy,知识库tag保存的是xxxx部分
|
||||||
|
List<String> tags = fileIds.stream()
|
||||||
|
.map(id -> StrUtil.subBetween(id, "_", "_"))
|
||||||
|
.filter(StrUtil::isNotBlank)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query, tags);
|
||||||
|
} else {
|
||||||
|
resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
|
||||||
|
}
|
||||||
List<RetrieveResponseBodyDataNodes> nodes = Optional.ofNullable(resp)
|
List<RetrieveResponseBodyDataNodes> nodes = Optional.ofNullable(resp)
|
||||||
.map(RetrieveResponse::getBody)
|
.map(RetrieveResponse::getBody)
|
||||||
.map(RetrieveResponseBody::getData)
|
.map(RetrieveResponseBody::getData)
|
||||||
|
|||||||
Reference in New Issue
Block a user