333 lines
12 KiB
Java
333 lines
12 KiB
Java
package com.gxwebsoft.ai.service.impl;
|
|
|
|
import com.aliyun.bailian20231229.Client;
|
|
import com.aliyun.bailian20231229.models.RetrieveResponse;
|
|
import com.aliyun.bailian20231229.models.RetrieveResponseBody;
|
|
import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyData;
|
|
import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes;
|
|
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
|
|
import com.fasterxml.jackson.databind.JsonNode;
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
import com.alibaba.fastjson.JSONArray;
|
|
import com.alibaba.fastjson.JSONObject;
|
|
import com.gxwebsoft.ai.config.KnowledgeBaseConfig;
|
|
import com.gxwebsoft.ai.entity.AiCloudFile;
|
|
import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory;
|
|
import com.gxwebsoft.ai.service.AiCloudFileService;
|
|
import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil;
|
|
import com.gxwebsoft.common.core.context.TenantContext;
|
|
|
|
import cn.hutool.core.util.StrUtil;
|
|
import cn.hutool.http.HttpUtil;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
import java.util.*;
|
|
import java.util.concurrent.CompletableFuture;
|
|
import java.util.concurrent.ConcurrentHashMap;
|
|
import java.util.stream.Collectors;
|
|
|
|
@Slf4j
|
|
public abstract class AbstractAuditContentService {
|
|
|
|
@Autowired
|
|
protected KnowledgeBaseClientFactory clientFactory;
|
|
|
|
@Autowired
|
|
protected KnowledgeBaseConfig config;
|
|
|
|
@Autowired
|
|
protected AiCloudFileService aiCloudFileService;
|
|
|
|
protected static final String DIFY_WORKFLOW_URL = "http://1.14.159.185:8180/v1/workflows/run";
|
|
|
|
// 用于同步的锁对象池
|
|
private static final Map<String, Object> kbLocks = new ConcurrentHashMap<>();
|
|
|
|
/**
|
|
* 调用工作流通用方法
|
|
*/
|
|
protected JSONArray callWorkflow(String url, String token, JSONObject requestBody, String workflowName) {
|
|
try {
|
|
log.info("调用{}工作流,请求体长度: {}", workflowName, requestBody.toString().length());
|
|
|
|
String result = HttpUtil.createPost(url)
|
|
.header("Authorization", token)
|
|
.header("Content-Type", "application/json")
|
|
.body(requestBody.toString())
|
|
.timeout(10 * 60 * 1000)
|
|
.execute()
|
|
.body();
|
|
|
|
log.info("{}工作流返回结果长度: {}", workflowName, result.length());
|
|
|
|
ObjectMapper objectMapper = new ObjectMapper();
|
|
JsonNode rootNode = objectMapper.readTree(result);
|
|
|
|
String outputText = rootNode.path("data")
|
|
.path("outputs")
|
|
.path("result")
|
|
.asText();
|
|
|
|
if (StrUtil.isBlank(outputText)) {
|
|
log.warn("{}工作流返回结果为空", workflowName);
|
|
return new JSONArray();
|
|
}
|
|
|
|
JsonNode arrayNode = objectMapper.readTree(outputText);
|
|
JSONArray jsonArray = JSONArray.parseArray(arrayNode.toString());
|
|
|
|
log.info("成功解析{}工作流返回数据,记录数: {}", workflowName, jsonArray.size());
|
|
return jsonArray;
|
|
|
|
} catch (Exception e) {
|
|
log.error("调用{}工作流失败", workflowName, e);
|
|
throw new RuntimeException("调用" + workflowName + "工作流失败: " + e.getMessage(), e);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* 构建工作流请求通用方法
|
|
*/
|
|
protected JSONObject buildWorkflowRequest(String knowledge, String userName) {
|
|
return buildWorkflowRequest(knowledge, userName, null);
|
|
}
|
|
|
|
protected JSONObject buildWorkflowRequest(String knowledge, String userName, Integer timeout) {
|
|
JSONObject requestBody = new JSONObject();
|
|
JSONObject inputs = new JSONObject();
|
|
|
|
inputs.put("knowledge", knowledge);
|
|
|
|
requestBody.put("inputs", inputs);
|
|
requestBody.put("response_mode", "blocking");
|
|
requestBody.put("user", userName);
|
|
if (timeout != null) {
|
|
requestBody.put("timeout", timeout);
|
|
}
|
|
|
|
return requestBody;
|
|
}
|
|
|
|
/**
|
|
* 查询知识库通用方法
|
|
*/
|
|
protected List<String> queryKnowledgeBase(String kbId, List<String> queries, int topK) {
|
|
Object lock = kbLocks.computeIfAbsent(kbId, k -> new Object());
|
|
|
|
synchronized (lock) {
|
|
try {
|
|
// 1. 收集所有节点和文档ID
|
|
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK);
|
|
if (allNodes.isEmpty()) {
|
|
return new ArrayList<>();
|
|
}
|
|
|
|
// 2. 批量查询文件URL
|
|
Map<String, String> fileUrlMap = TenantContext.callIgnoreTenant(() -> batchQueryFileUrls(allNodes));
|
|
|
|
// 3. 处理节点生成结果
|
|
return processNodesToResults(allNodes, fileUrlMap);
|
|
|
|
} catch (Exception e) {
|
|
log.error("查询知识库失败 - kbId: {}", kbId, e);
|
|
return new ArrayList<>();
|
|
}
|
|
}
|
|
}
|
|
|
|
/**
|
|
* 收集知识库节点
|
|
*/
|
|
private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK) {
|
|
List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
|
|
String workspaceId = config.getWorkspaceId();
|
|
try {
|
|
Client client = clientFactory.createClient();
|
|
for (String query : queries) {
|
|
try {
|
|
RetrieveResponse resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
|
|
List<RetrieveResponseBodyDataNodes> nodes = Optional.ofNullable(resp)
|
|
.map(RetrieveResponse::getBody)
|
|
.map(RetrieveResponseBody::getData)
|
|
.map(RetrieveResponseBodyData::getNodes)
|
|
.orElse(Collections.emptyList())
|
|
.stream()
|
|
.limit(topK)
|
|
.collect(Collectors.toList());
|
|
allNodes.addAll(nodes);
|
|
} catch (Exception e) {
|
|
log.warn("查询知识库失败 - kbId: {}, query: {}", kbId, query, e);
|
|
}
|
|
}
|
|
} catch (Exception e) {
|
|
log.error("创建知识库客户端失败", e);
|
|
}
|
|
return allNodes;
|
|
}
|
|
|
|
/**
|
|
* 批量查询文件URL
|
|
*/
|
|
protected Map<String, String> batchQueryFileUrls(List<RetrieveResponseBodyDataNodes> nodes) {
|
|
// 收集所有文档ID
|
|
Set<String> docIds = nodes.stream().map(this::extractDocumentId).filter(StrUtil::isNotBlank).collect(Collectors.toSet());
|
|
if (docIds.isEmpty()) {
|
|
return Collections.emptyMap();
|
|
}
|
|
try {
|
|
// 批量查询
|
|
List<AiCloudFile> files = aiCloudFileService.list(new LambdaQueryWrapper<AiCloudFile>().in(AiCloudFile::getFileId, docIds));
|
|
// 构建映射表
|
|
return files.stream()
|
|
.filter(file -> file.getFileUrl() != null)
|
|
.collect(Collectors.toMap(
|
|
AiCloudFile::getFileId,
|
|
AiCloudFile::getFileUrl
|
|
));
|
|
} catch (Exception e) {
|
|
log.error("批量查询文件信息失败", e);
|
|
return Collections.emptyMap();
|
|
}
|
|
}
|
|
|
|
/**
|
|
* 处理节点生成结果
|
|
*/
|
|
private List<String> processNodesToResults(List<RetrieveResponseBodyDataNodes> nodes, Map<String, String> fileUrlMap) {
|
|
Set<String> results = new LinkedHashSet<>();
|
|
for (RetrieveResponseBodyDataNodes node : nodes) {
|
|
try {
|
|
// 检查文本有效性
|
|
String text = node.getText();
|
|
if (StrUtil.isBlank(text) || text.length() < 10) {
|
|
continue;
|
|
}
|
|
|
|
// 获取文档信息
|
|
String docName = extractDocumentName(node);
|
|
String docId = extractDocumentId(node);
|
|
String fileUrl = fileUrlMap.get(docId);
|
|
|
|
// 处理文件URL为空的情况
|
|
String url = StrUtil.isNotBlank(fileUrl) ? fileUrl : "无";
|
|
// 格式化结果
|
|
String formattedText = String.format("《%s》【FileUrl:%s】%s", docName, url, text);
|
|
results.add(formattedText);
|
|
} catch (Exception e) {
|
|
log.warn("处理知识库节点失败", e);
|
|
}
|
|
}
|
|
return new ArrayList<>(results);
|
|
}
|
|
|
|
/**
|
|
* 提取文档名称通用方法
|
|
*/
|
|
protected String extractDocumentName(RetrieveResponseBodyDataNodes node) {
|
|
try {
|
|
Object metadataObj = node.getMetadata();
|
|
if (metadataObj instanceof Map) {
|
|
Map<?, ?> metadata = (Map<?, ?>) metadataObj;
|
|
Object docNameObj = metadata.get("doc_name");
|
|
if (docNameObj != null) {
|
|
return docNameObj.toString();
|
|
}
|
|
}
|
|
} catch (Exception e) {
|
|
log.debug("提取文档名称失败", e);
|
|
}
|
|
return "相关文档";
|
|
}
|
|
|
|
/**
|
|
* 提取文档Id通用方法
|
|
*/
|
|
protected String extractDocumentId(RetrieveResponseBodyDataNodes node) {
|
|
try {
|
|
Object metadataObj = node.getMetadata();
|
|
if (metadataObj instanceof Map) {
|
|
Map<?, ?> metadata = (Map<?, ?>) metadataObj;
|
|
Object docIdObj = metadata.get("doc_id");
|
|
if (docIdObj != null) {
|
|
return docIdObj.toString();
|
|
}
|
|
}
|
|
} catch (Exception e) {
|
|
log.debug("提取文档名称失败", e);
|
|
}
|
|
return "相关文档";
|
|
}
|
|
|
|
/**
|
|
* 构建成功响应通用方法
|
|
*/
|
|
protected JSONObject buildSuccessResponse(JSONArray data, long startTime) {
|
|
return buildSuccessResponse(data, startTime, null);
|
|
}
|
|
|
|
protected JSONObject buildSuccessResponse(JSONArray data, long startTime, String dataSource) {
|
|
JSONObject result = new JSONObject();
|
|
result.put("success", true);
|
|
result.put("data", data);
|
|
result.put("total_records", data.size());
|
|
result.put("generated_time", new Date().toString());
|
|
result.put("processing_time", (System.currentTimeMillis() - startTime) + "ms");
|
|
if (dataSource != null) {
|
|
result.put("data_source", dataSource);
|
|
}
|
|
return result;
|
|
}
|
|
|
|
/**
|
|
* 构建失败响应通用方法
|
|
*/
|
|
protected JSONObject buildErrorResponse(String errorMessage) {
|
|
JSONObject result = new JSONObject();
|
|
result.put("success", false);
|
|
result.put("error", errorMessage);
|
|
return result;
|
|
}
|
|
|
|
/**
|
|
* 异步处理分类数据通用方法
|
|
*/
|
|
protected Map<String, CompletableFuture<JSONArray>> processCategoriesAsync(
|
|
List<String> categories,
|
|
CategoryProcessor processor) {
|
|
|
|
Map<String, CompletableFuture<JSONArray>> futures = new LinkedHashMap<>();
|
|
for (String category : categories) {
|
|
CompletableFuture<JSONArray> future = processor.processCategory(category);
|
|
futures.put(category, future);
|
|
}
|
|
return futures;
|
|
}
|
|
|
|
/**
|
|
* 合并分类结果通用方法
|
|
*/
|
|
protected JSONArray mergeCategoryResults(List<String> categoryOrder,
|
|
Map<String, CompletableFuture<JSONArray>> futures) {
|
|
JSONArray allData = new JSONArray();
|
|
for (String category : categoryOrder) {
|
|
try {
|
|
JSONArray categoryData = futures.get(category).get();
|
|
if (categoryData != null && !categoryData.isEmpty()) {
|
|
allData.addAll(categoryData);
|
|
}
|
|
} catch (Exception e) {
|
|
log.error("获取分类 {} 数据失败", category, e);
|
|
}
|
|
}
|
|
return allData;
|
|
}
|
|
|
|
/**
|
|
* 分类处理器函数式接口
|
|
*/
|
|
@FunctionalInterface
|
|
public interface CategoryProcessor {
|
|
CompletableFuture<JSONArray> processCategory(String category);
|
|
}
|
|
} |