java-10561/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java
package com.gxwebsoft.ai.service.impl;

import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.aliyun.bailian20231229.Client;
import com.aliyun.bailian20231229.models.RetrieveResponse;
import com.aliyun.bailian20231229.models.RetrieveResponseBody;
import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyData;
import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.gxwebsoft.ai.config.KnowledgeBaseConfig;
import com.gxwebsoft.ai.entity.AiCloudFile;
import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory;
import com.gxwebsoft.ai.service.AiCloudFileService;
import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil;
import com.gxwebsoft.common.core.context.TenantContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;

import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
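
/**
 * Base class for audit content services.
 * <p>
 * Provides shared helpers for invoking Dify workflows, retrieving passages from
 * Aliyun Bailian knowledge bases, resolving source file URLs, building success and
 * error responses, and fanning category processing out asynchronously.
 */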
@Slf4j
public abstract class AbstractAuditContentService {

    @Autowired
    protected KnowledgeBaseClientFactory clientFactory;

    @Autowired
    protected KnowledgeBaseConfig config;

    @Autowired
    protected AiCloudFileService aiCloudFileService;

    protected static final String DIFY_WORKFLOW_URL = "http://1.14.159.185:8180/v1/workflows/run";

    // Pool of lock objects used to serialize knowledge base access per kbId
    private static final Map<String, Object> kbLocks = new ConcurrentHashMap<>();
    /**
     * Common helper for invoking a Dify workflow and parsing its result array.
     */
    protected JSONArray callWorkflow(String url, String token, JSONObject requestBody, String workflowName) {
        try {
            log.info("Invoking {} workflow, request body length: {}", workflowName, requestBody.toString().length());
            String result = HttpUtil.createPost(url)
                    .header("Authorization", token)
                    .header("Content-Type", "application/json")
                    .body(requestBody.toString())
                    .timeout(10 * 60 * 1000)
                    .execute()
                    .body();
            log.info("{} workflow response length: {}", workflowName, result.length());

            ObjectMapper objectMapper = new ObjectMapper();
            JsonNode rootNode = objectMapper.readTree(result);
            String outputText = rootNode.path("data")
                    .path("outputs")
                    .path("result")
                    .asText();
            if (StrUtil.isBlank(outputText)) {
                log.warn("{} workflow returned an empty result", workflowName);
                return new JSONArray();
            }

            // Validate the output with Jackson before handing it to fastjson
            JsonNode arrayNode = objectMapper.readTree(outputText);
            JSONArray jsonArray = JSONArray.parseArray(arrayNode.toString());
            log.info("Parsed {} workflow response, record count: {}", workflowName, jsonArray.size());
            return jsonArray;
        } catch (Exception e) {
            log.error("Failed to invoke {} workflow", workflowName, e);
            throw new RuntimeException("Failed to invoke " + workflowName + " workflow: " + e.getMessage(), e);
        }
    }
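
    // Illustrative shape of a blocking workflow response that callWorkflow() can parse.
    // Only the data.outputs.result path is assumed by the parsing code above; the field
    // values below are made up for illustration:
    //   {
    //     "data": {
    //       "outputs": {
    //         "result": "[{\"title\":\"...\",\"score\":0.9}]"
    //       }
    //     }
    //   }
    // "result" is expected to hold a JSON array serialized as a string.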
    /**
     * Common helpers for building a workflow request body.
     */
    protected JSONObject buildWorkflowRequest(String knowledge, String userName) {
        return buildWorkflowRequest(knowledge, userName, null);
    }

    protected JSONObject buildWorkflowRequest(String knowledge, String userName, Integer timeout) {
        JSONObject requestBody = new JSONObject();
        JSONObject inputs = new JSONObject();
        inputs.put("knowledge", knowledge);
        requestBody.put("inputs", inputs);
        requestBody.put("response_mode", "blocking");
        requestBody.put("user", userName);
        if (timeout != null) {
            requestBody.put("timeout", timeout);
        }
        return requestBody;
    }
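
    // Request body produced by buildWorkflowRequest("...", "auditor", 600); the user name
    // and timeout are illustrative values, the keys match the code above:
    //   {
    //     "inputs": { "knowledge": "..." },
    //     "response_mode": "blocking",
    //     "user": "auditor",
    //     "timeout": 600
    //   }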
    /**
     * Common method for querying a knowledge base.
     */
    protected List<String> queryKnowledgeBase(String kbId, List<String> queries, int topK) {
        // Serialize retrievals per knowledge base so the same index is not queried concurrently
        Object lock = kbLocks.computeIfAbsent(kbId, k -> new Object());
        synchronized (lock) {
            try {
                // 1. Collect all retrieved nodes and their document IDs
                List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK);
                if (allNodes.isEmpty()) {
                    return new ArrayList<>();
                }
                // 2. Batch-query the file URLs
                Map<String, String> fileUrlMap = TenantContext.callIgnoreTenant(() -> batchQueryFileUrls(allNodes));
                // 3. Turn the nodes into formatted results
                return processNodesToResults(allNodes, fileUrlMap);
            } catch (Exception e) {
                log.error("Failed to query knowledge base - kbId: {}", kbId, e);
                return new ArrayList<>();
            }
        }
    }
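
    // Illustrative call (identifiers and query text are hypothetical):
    //   List<String> hits = queryKnowledgeBase("kb-123", Arrays.asList("合同违约条款"), 5);
    // Returns formatted passages, or an empty list when retrieval fails or finds nothing.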
    /**
     * Collect nodes from the knowledge base for each query.
     */
    private List<RetrieveResponseBodyDataNodes> collectKnowledgeNodes(String kbId, List<String> queries, int topK) {
        List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
        String workspaceId = config.getWorkspaceId();
        try {
            Client client = clientFactory.createClient();
            for (String query : queries) {
                try {
                    RetrieveResponse resp = AiCloudKnowledgeBaseUtil.retrieveIndex(client, workspaceId, kbId, query);
                    List<RetrieveResponseBodyDataNodes> nodes = Optional.ofNullable(resp)
                            .map(RetrieveResponse::getBody)
                            .map(RetrieveResponseBody::getData)
                            .map(RetrieveResponseBodyData::getNodes)
                            .orElse(Collections.emptyList())
                            .stream()
                            .limit(topK)
                            .collect(Collectors.toList());
                    allNodes.addAll(nodes);
                } catch (Exception e) {
                    log.warn("Knowledge base retrieval failed - kbId: {}, query: {}", kbId, query, e);
                }
            }
        } catch (Exception e) {
            log.error("Failed to create knowledge base client", e);
        }
        return allNodes;
    }
    /**
     * Batch-query file URLs for the retrieved nodes.
     */
    protected Map<String, String> batchQueryFileUrls(List<RetrieveResponseBodyDataNodes> nodes) {
        // Collect all document IDs
        Set<String> docIds = nodes.stream()
                .map(this::extractDocumentId)
                .filter(StrUtil::isNotBlank)
                .collect(Collectors.toSet());
        if (docIds.isEmpty()) {
            return Collections.emptyMap();
        }
        try {
            // Query the files in one batch
            List<AiCloudFile> files = aiCloudFileService.list(
                    new LambdaQueryWrapper<AiCloudFile>().in(AiCloudFile::getFileId, docIds));
            // Build the fileId -> fileUrl map, keeping the first URL if a fileId appears twice
            return files.stream()
                    .filter(file -> file.getFileUrl() != null)
                    .collect(Collectors.toMap(
                            AiCloudFile::getFileId,
                            AiCloudFile::getFileUrl,
                            (existing, replacement) -> existing
                    ));
        } catch (Exception e) {
            log.error("Batch query of file info failed", e);
            return Collections.emptyMap();
        }
    }
    /**
     * Turn the retrieved nodes into formatted result strings.
     */
    private List<String> processNodesToResults(List<RetrieveResponseBodyDataNodes> nodes, Map<String, String> fileUrlMap) {
        Set<String> results = new LinkedHashSet<>();
        for (RetrieveResponseBodyDataNodes node : nodes) {
            try {
                // Skip blank or trivially short passages
                String text = node.getText();
                if (StrUtil.isBlank(text) || text.length() < 10) {
                    continue;
                }
                // Resolve the document name, ID and source file URL
                String docName = extractDocumentName(node);
                String docId = extractDocumentId(node);
                String fileUrl = fileUrlMap.get(docId);
                // Fall back to an empty URL when none is known
                String url = StrUtil.isNotBlank(fileUrl) ? fileUrl : "";
                // Format the result
                String formattedText = String.format("《%s》【FileUrl:%s】%s", docName, url, text);
                results.add(formattedText);
            } catch (Exception e) {
                log.warn("Failed to process knowledge base node", e);
            }
        }
        return new ArrayList<>(results);
    }
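
    // Example of a formatted result line. The document name, URL and passage text are
    // hypothetical; the 《...》【FileUrl:...】 layout comes from the format string above:
    //   《采购合同范本.docx》【FileUrl:https://example.com/files/1.docx】第十条 合同的变更……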
    /**
     * Common method for extracting the document name from node metadata.
     */
    protected String extractDocumentName(RetrieveResponseBodyDataNodes node) {
        try {
            Object metadataObj = node.getMetadata();
            if (metadataObj instanceof Map) {
                Map<?, ?> metadata = (Map<?, ?>) metadataObj;
                Object docNameObj = metadata.get("doc_name");
                if (docNameObj != null) {
                    return docNameObj.toString();
                }
            }
        } catch (Exception e) {
            log.debug("Failed to extract document name", e);
        }
        // Fallback display name ("related document") used when no name is available
        return "相关文档";
    }
    /**
     * Common method for extracting the document ID from node metadata.
     */
    protected String extractDocumentId(RetrieveResponseBodyDataNodes node) {
        try {
            Object metadataObj = node.getMetadata();
            if (metadataObj instanceof Map) {
                Map<?, ?> metadata = (Map<?, ?>) metadataObj;
                Object docIdObj = metadata.get("doc_id");
                if (docIdObj != null) {
                    return docIdObj.toString();
                }
            }
        } catch (Exception e) {
            log.debug("Failed to extract document ID", e);
        }
        // Return an empty string (not a display name) so blank IDs are filtered out upstream
        return "";
    }
    /**
     * Common method for building a success response.
     */
    protected JSONObject buildSuccessResponse(JSONArray data, long startTime) {
        return buildSuccessResponse(data, startTime, null);
    }

    protected JSONObject buildSuccessResponse(JSONArray data, long startTime, String dataSource) {
        JSONObject result = new JSONObject();
        result.put("success", true);
        result.put("data", data);
        result.put("total_records", data.size());
        result.put("generated_time", new Date().toString());
        result.put("processing_time", (System.currentTimeMillis() - startTime) + "ms");
        if (dataSource != null) {
            result.put("data_source", dataSource);
        }
        return result;
    }
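
    // Illustrative success response (counts, timestamp and data_source value are made up;
    // the keys match the ones set above):
    //   {
    //     "success": true,
    //     "data": [ ... ],
    //     "total_records": 12,
    //     "generated_time": "Wed Jan 22 10:34:19 CST 2026",
    //     "processing_time": "1534ms",
    //     "data_source": "knowledge_base"
    //   }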
    /**
     * Common method for building an error response.
     */
    protected JSONObject buildErrorResponse(String errorMessage) {
        JSONObject result = new JSONObject();
        result.put("success", false);
        result.put("error", errorMessage);
        return result;
    }
    /**
     * Common method for processing category data asynchronously.
     */
    protected Map<String, CompletableFuture<JSONArray>> processCategoriesAsync(
            List<String> categories,
            CategoryProcessor processor) {
        Map<String, CompletableFuture<JSONArray>> futures = new LinkedHashMap<>();
        for (String category : categories) {
            CompletableFuture<JSONArray> future = processor.processCategory(category);
            futures.put(category, future);
        }
        return futures;
    }
    /**
     * Common method for merging category results in the given order.
     */
    protected JSONArray mergeCategoryResults(List<String> categoryOrder,
                                             Map<String, CompletableFuture<JSONArray>> futures) {
        JSONArray allData = new JSONArray();
        for (String category : categoryOrder) {
            try {
                JSONArray categoryData = futures.get(category).get();
                if (categoryData != null && !categoryData.isEmpty()) {
                    allData.addAll(categoryData);
                }
            } catch (Exception e) {
                log.error("Failed to get data for category {}", category, e);
            }
        }
        return allData;
    }
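
    // Minimal sketch of how a subclass might combine the two helpers above. The category
    // names, auditCategory() and executor are hypothetical and not part of this class:
    //
    //   List<String> categories = Arrays.asList("合同", "发票");
    //   Map<String, CompletableFuture<JSONArray>> futures = processCategoriesAsync(
    //           categories,
    //           category -> CompletableFuture.supplyAsync(() -> auditCategory(category), executor));
    //   JSONArray merged = mergeCategoryResults(categories, futures);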
    /**
     * Functional interface for processing a single category.
     */
    @FunctionalInterface
    public interface CategoryProcessor {
        CompletableFuture<JSONArray> processCategory(String category);
    }
}