新增知识库相关接口

This commit is contained in:
2025-09-24 11:23:25 +08:00
parent 8efaa62d95
commit 1cf0246b31
7 changed files with 738 additions and 47 deletions

View File

@@ -1,6 +1,5 @@
package com.gxwebsoft.ai.controller; package com.gxwebsoft.ai.controller;
import java.io.FileOutputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
@@ -20,13 +19,13 @@ import com.gxwebsoft.ai.config.TemplateConfig;
import com.gxwebsoft.ai.dto.AuditReportRequest; import com.gxwebsoft.ai.dto.AuditReportRequest;
import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; import com.gxwebsoft.ai.dto.KnowledgeBaseRequest;
import com.gxwebsoft.ai.enums.AuditReportEnum; import com.gxwebsoft.ai.enums.AuditReportEnum;
import com.gxwebsoft.ai.service.KnowledgeBaseService;
import com.gxwebsoft.ai.util.AuditReportUtil; import com.gxwebsoft.ai.util.AuditReportUtil;
import com.gxwebsoft.common.core.web.ApiResult; import com.gxwebsoft.common.core.web.ApiResult;
import com.gxwebsoft.common.core.web.BaseController; import com.gxwebsoft.common.core.web.BaseController;
import com.gxwebsoft.common.system.entity.User; import com.gxwebsoft.common.system.entity.User;
import cn.afterturn.easypoi.word.WordExportUtil; import cn.afterturn.easypoi.word.WordExportUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
@@ -45,7 +44,7 @@ public class AuditReportController extends BaseController {
private TemplateConfig templateConfig; private TemplateConfig templateConfig;
@Autowired @Autowired
private KnowledgeBaseController knowledgeBaseController; private KnowledgeBaseService knowledgeBaseService;
private String invok(String query, String knowledge, String history, String suggestion, String title, String userName) { private String invok(String query, String knowledge, String history, String suggestion, String title, String userName) {
// 构建请求体 // 构建请求体
@@ -91,7 +90,7 @@ public class AuditReportController extends BaseController {
KnowledgeBaseRequest knowledgeBaseRequest = new KnowledgeBaseRequest(); KnowledgeBaseRequest knowledgeBaseRequest = new KnowledgeBaseRequest();
knowledgeBaseRequest.setKbId(req.getKbId()); knowledgeBaseRequest.setKbId(req.getKbId());
knowledgeBaseRequest.setFormCommit((req.getFormCommit() >= 10) ? req.getFormCommit() / 10 : req.getFormCommit()); knowledgeBaseRequest.setFormCommit((req.getFormCommit() >= 10) ? req.getFormCommit() / 10 : req.getFormCommit());
String knowledge = knowledgeBaseController.query(knowledgeBaseRequest).getData().toString(); String knowledge = knowledgeBaseService.queryKnowledgeBase(knowledgeBaseRequest).toString();
String query = AuditReportEnum.getByCode(req.getFormCommit()).getDesc(); String query = AuditReportEnum.getByCode(req.getFormCommit()).getDesc();
// String ret = this.invok(query, knowledge, AuditReportUtil.generateReportContent(req), req.getSuggestion(), loginUser.getUsername()); // String ret = this.invok(query, knowledge, AuditReportUtil.generateReportContent(req), req.getSuggestion(), loginUser.getUsername());

View File

@@ -1,25 +1,15 @@
package com.gxwebsoft.ai.controller; package com.gxwebsoft.ai.controller;
import com.aliyun.bailian20231229.Client;
import com.aliyun.bailian20231229.models.RetrieveResponse;
import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes;
import com.gxwebsoft.ai.config.KnowledgeBaseConfig;
import com.gxwebsoft.ai.constants.KnowledgeBaseConstants;
import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory;
import com.gxwebsoft.ai.util.KnowledgeBaseRetrieve;
import com.gxwebsoft.ai.dto.KnowledgeBaseRequest; import com.gxwebsoft.ai.dto.KnowledgeBaseRequest;
import com.gxwebsoft.ai.service.KnowledgeBaseService;
import com.gxwebsoft.common.core.web.ApiResult; import com.gxwebsoft.common.core.web.ApiResult;
import com.gxwebsoft.common.core.web.BaseController; import com.gxwebsoft.common.core.web.BaseController;
import cn.hutool.core.util.StrUtil;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
@Tag(name = "知识库") @Tag(name = "知识库")
@RestController @RestController
@@ -27,33 +17,80 @@ import java.util.Set;
public class KnowledgeBaseController extends BaseController { public class KnowledgeBaseController extends BaseController {
@Autowired @Autowired
private KnowledgeBaseConfig config; private KnowledgeBaseService knowledgeBaseService;
@Autowired
private KnowledgeBaseClientFactory clientFactory;
@Operation(summary = "查询知识库") @Operation(summary = "查询知识库")
@GetMapping("/query") @GetMapping("/query")
public ApiResult<?> query(KnowledgeBaseRequest req) { public ApiResult<?> query(KnowledgeBaseRequest req) {
Set<String> ret = new LinkedHashSet<>();
String workspaceId = config.getWorkspaceId();
List<String> keyWords = Arrays.asList(KnowledgeBaseConstants.KEY_WORDS);
String indexId = req.getKbId();
String query = StrUtil.isEmpty(req.getQuery()) ? keyWords.get(req.getFormCommit()) : req.getQuery();
Integer topK = req.getTopK() == null ? 10 : req.getTopK();
try { try {
Client client = clientFactory.createClient(); return success(knowledgeBaseService.queryKnowledgeBase(req));
RetrieveResponse resp = KnowledgeBaseRetrieve.retrieveIndex(client, workspaceId, indexId, query);
for (RetrieveResponseBodyDataNodes node : resp.getBody().getData().getNodes()) {
ret.add(node.getText());
if (ret.size() >= topK) {
break;
}
}
} catch (Exception e) { } catch (Exception e) {
return fail("查询失败:" + e.getMessage()); return fail("查询失败:" + e.getMessage());
} }
return success(ret); }
@Operation(summary = "创建知识库")
@PostMapping("/create")
public ApiResult<?> create(@RequestParam String companyName, @RequestParam String companyCode) {
try {
String indexId = knowledgeBaseService.createKnowledgeBase(companyName, companyCode);
return success(indexId);
} catch (Exception e) {
return fail("创建失败:" + e.getMessage());
}
}
@Operation(summary = "查询知识库下的文档列表")
@GetMapping("/documents")
public ApiResult<?> listDocuments(String kbId) {
try {
return success(knowledgeBaseService.listDocuments(kbId));
} catch (Exception e) {
return fail("查询文档列表失败:" + e.getMessage());
}
}
@Operation(summary = "删除知识库下的文档")
@DeleteMapping("/document")
public ApiResult<?> deleteDocument(String kbId, String fileIds) {
try {
boolean result = knowledgeBaseService.deleteIndexDocument(kbId, fileIds);
if (result) {
return success("删除成功");
} else {
return fail("删除失败");
}
} catch (Exception e) {
return fail("删除文档失败: " + e.getMessage());
}
}
@Operation(summary = "上传文档到知识库")
@PostMapping("/upload")
public ApiResult<?> uploadDocuments(@RequestParam String kbId, @RequestParam("files") MultipartFile[] files) {
try {
if (files == null || files.length == 0) {
return fail("请选择要上传的文件");
}
for (MultipartFile file : files) {
if (file.isEmpty()) {
return fail("文件不能为空: " + file.getOriginalFilename());
}
if (file.getSize() > 100 * 1024 * 1024) {
return fail("文件大小不能超过100MB: " + file.getOriginalFilename());
}
}
boolean result = knowledgeBaseService.uploadDocuments(kbId, files);
if (result) {
return success("成功上传 " + files.length + " 个文件");
} else {
return success("部分文件上传成功");
}
} catch (Exception e) {
return fail("上传失败: " + e.getMessage());
}
} }
} }

View File

@@ -0,0 +1,51 @@
package com.gxwebsoft.ai.service;
import com.gxwebsoft.ai.dto.KnowledgeBaseRequest;
import java.util.List;
import java.util.Set;
import org.springframework.web.multipart.MultipartFile;
public interface KnowledgeBaseService {
/**
* 查询知识库
*/
Set<String> queryKnowledgeBase(KnowledgeBaseRequest req);
/**
* 查询知识库(带参数)
*/
Set<String> queryKnowledgeBase(String kbId, String query, Integer topK, Integer formCommit);
/**
* 创建知识库
*/
String createKnowledgeBase(String companyName, String companyCode);
/**
* 检查知识库是否存
*/
boolean existsKnowledgeBase(String companyCode);
/**
* 查找知识库ID
*/
String getKnowledgeBaseId(String companyCode);
/**
* 查询知识库下的文档列表
*/
List<Object> listDocuments(String kbId);
/**
* 删除知识库下的文档
*/
boolean deleteIndexDocument(String kbId, String fileIds);
/**
* 上传知识库文档
*/
boolean uploadDocuments(String kbId, MultipartFile[] files);
}

View File

@@ -0,0 +1,153 @@
package com.gxwebsoft.ai.service.impl;
import com.aliyun.bailian20231229.Client;
import com.aliyun.bailian20231229.models.CreateIndexResponse;
import com.aliyun.bailian20231229.models.DeleteIndexDocumentResponse;
import com.aliyun.bailian20231229.models.ListIndexDocumentsResponse;
import com.aliyun.bailian20231229.models.ListIndicesResponse;
import com.aliyun.bailian20231229.models.RetrieveResponse;
import com.aliyun.bailian20231229.models.RetrieveResponseBody.RetrieveResponseBodyDataNodes;
import com.gxwebsoft.ai.config.KnowledgeBaseConfig;
import com.gxwebsoft.ai.constants.KnowledgeBaseConstants;
import com.gxwebsoft.ai.dto.KnowledgeBaseRequest;
import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory;
import com.gxwebsoft.ai.service.KnowledgeBaseService;
import com.gxwebsoft.ai.util.KnowledgeBaseUploader;
import com.gxwebsoft.ai.util.KnowledgeBaseUtil;
import cn.hutool.core.util.StrUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
@Service
public class KnowledgeBaseServiceImpl implements KnowledgeBaseService {
@Autowired
private KnowledgeBaseConfig config;
@Autowired
private KnowledgeBaseClientFactory clientFactory;
@Override
public Set<String> queryKnowledgeBase(KnowledgeBaseRequest req) {
return queryKnowledgeBase(req.getKbId(), req.getQuery(), req.getTopK(), req.getFormCommit());
}
@Override
public Set<String> queryKnowledgeBase(String kbId, String query, Integer topK, Integer formCommit) {
Set<String> result = new LinkedHashSet<>();
String workspaceId = config.getWorkspaceId();
List<String> keyWords = Arrays.asList(KnowledgeBaseConstants.KEY_WORDS);
String indexId = kbId;
String searchQuery = StrUtil.isEmpty(query) ? keyWords.get(formCommit) : query;
Integer searchTopK = topK == null ? 10 : topK;
try {
Client client = clientFactory.createClient();
RetrieveResponse resp = KnowledgeBaseUtil.retrieveIndex(client, workspaceId, indexId, searchQuery);
for (RetrieveResponseBodyDataNodes node : resp.getBody().getData().getNodes()) {
result.add(node.getText());
if (result.size() >= searchTopK) {
break;
}
}
} catch (Exception e) {
throw new RuntimeException("查询知识库失败: " + e.getMessage(), e);
}
return result;
}
@Override
public String createKnowledgeBase(String companyName, String companyCode) {
String workspaceId = config.getWorkspaceId();
try {
Client client = clientFactory.createClient();
CreateIndexResponse indexResponse = KnowledgeBaseUtil.createIndex(client, workspaceId, companyCode, companyName);
return indexResponse.getBody().getData().getId();
} catch (Exception e) {
throw new RuntimeException("创建知识库失败: " + e.getMessage(), e);
}
}
@Override
public boolean existsKnowledgeBase(String companyCode) {
String workspaceId = config.getWorkspaceId();
try {
Client client = clientFactory.createClient();
ListIndicesResponse indicesResponse = KnowledgeBaseUtil.listIndices(client, workspaceId);
return indicesResponse.getBody().getData().getIndices().stream()
.anyMatch(index -> companyCode.equals(index.getName()));
} catch (Exception e) {
throw new RuntimeException("检查知识库是否存在失败: " + e.getMessage(), e);
}
}
@Override
public String getKnowledgeBaseId(String companyCode) {
String workspaceId = config.getWorkspaceId();
try {
Client client = clientFactory.createClient();
ListIndicesResponse indicesResponse = KnowledgeBaseUtil.listIndices(client, workspaceId);
return indicesResponse.getBody().getData().getIndices().stream()
.filter(index -> companyCode.equals(index.getName()))
.findFirst()
.map(index -> index.getId())
.orElse("");
} catch (Exception e) {
throw new RuntimeException("查找知识库ID失败: " + e.getMessage(), e);
}
}
@Override
public List<Object> listDocuments(String kbId) {
List<Object> ret = new ArrayList<>();
String workspaceId = config.getWorkspaceId();
try {
Client client = clientFactory.createClient();
ListIndexDocumentsResponse indexDocumentsResponse = KnowledgeBaseUtil.listIndexDocuments(client, workspaceId, kbId);
ret.addAll(indexDocumentsResponse.getBody().getData().getDocuments());
} catch (Exception e) {
throw new RuntimeException("查询知识库下的文档列表失败: " + e.getMessage(), e);
}
return ret;
}
@Override
public boolean deleteIndexDocument(String kbId, String fileIds) {
String workspaceId = config.getWorkspaceId();
List<String> ids = StrUtil.splitTrim(fileIds, ",");
try {
Client client = clientFactory.createClient();
DeleteIndexDocumentResponse indexDocumentResponse = KnowledgeBaseUtil.deleteIndexDocument(client, workspaceId, kbId, ids);
return indexDocumentResponse.getBody().getSuccess();
} catch (Exception e) {
throw new RuntimeException("删除知识库下的文档失败: " + e.getMessage(), e);
}
}
@Override
public boolean uploadDocuments(String kbId, MultipartFile[] files) {
String workspaceId = config.getWorkspaceId();
int count = files.length;
try {
Client client = clientFactory.createClient();
List<String> fileIds = KnowledgeBaseUploader.uploadDocuments(client, workspaceId, kbId, files);
//上传切片完成后删除原文档(释放云空间)
for(String fileId : fileIds) {
KnowledgeBaseUtil.deleteAppDocument(client, workspaceId, fileId);
}
return !fileIds.isEmpty() && fileIds.size() == count;
} catch (Exception e) {
throw new RuntimeException("上传文档到知识库失败: " + e.getMessage(), e);
}
}
}

View File

@@ -0,0 +1,303 @@
package com.gxwebsoft.ai.util;
import com.aliyun.bailian20231229.Client;
import com.aliyun.bailian20231229.models.*;
import com.aliyun.teautil.models.RuntimeOptions;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.util.*;
import org.springframework.web.multipart.MultipartFile;
/**
* 知识库上传工具类
* @author GIIT-YC
*
*/
public class KnowledgeBaseUploader {
/**
* 上传文档到知识库直接处理MultipartFile
*
* @param client 阿里云客户端
* @param workspaceId 业务空间ID
* @param indexId 知识库ID
* @param file 上传的文件
* @return 新文档的FileID失败返回null
*/
public static String uploadDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, MultipartFile file) {
try {
// 准备文档信息
String fileName = file.getOriginalFilename();
String fileMd5 = calculateMD5(file.getInputStream());
String fileSize = String.valueOf(file.getSize());
// 申请上传租约
ApplyFileUploadLeaseRequest leaseRequest = new ApplyFileUploadLeaseRequest()
.setFileName(fileName)
.setMd5(fileMd5)
.setSizeInBytes(fileSize);
ApplyFileUploadLeaseResponse leaseResponse = client.applyFileUploadLeaseWithOptions(
"default", workspaceId, leaseRequest, new HashMap<>(), new RuntimeOptions());
String leaseId = leaseResponse.getBody().getData().getFileUploadLeaseId();
String uploadUrl = leaseResponse.getBody().getData().getParam().getUrl();
// 上传文件
ObjectMapper mapper = new ObjectMapper();
Map<String, String> headers = mapper.readValue(mapper.writeValueAsString(leaseResponse.getBody().getData().getParam().getHeaders()), Map.class);
uploadFile(uploadUrl, headers, file);
// 添加文件到类目
AddFileRequest addRequest = new AddFileRequest()
.setLeaseId(leaseId)
.setParser("DASHSCOPE_DOCMIND")
.setCategoryId("default");
AddFileResponse addResponse = client.addFileWithOptions(workspaceId, addRequest, new HashMap<>(), new RuntimeOptions());
String fileId = addResponse.getBody().getData().getFileId();
// 等待文件解析完成
waitForFileParsing(client, workspaceId, fileId);
// 添加到知识库
SubmitIndexAddDocumentsJobRequest indexRequest = new SubmitIndexAddDocumentsJobRequest()
.setIndexId(indexId)
.setDocumentIds(Collections.singletonList(fileId))
.setSourceType("DATA_CENTER_FILE");
SubmitIndexAddDocumentsJobResponse indexResponse = client.submitIndexAddDocumentsJobWithOptions(workspaceId, indexRequest, new HashMap<>(), new RuntimeOptions());
// 等待索引完成
waitForIndexJob(client, workspaceId, indexResponse.getBody().getData().getId(), indexId);
return fileId;
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
/**
* 批量上传文档到知识库
*/
public static List<String> uploadDocuments(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, MultipartFile[] files) {
List<String> fileIds = new ArrayList<>();
for (MultipartFile file : files) {
String fileId = uploadDocument(client, workspaceId, indexId, file);
if (fileId != null) {
fileIds.add(fileId);
}
}
return fileIds;
}
/**
* 上传文档到知识库
*
* @param client 阿里云客户端
* @param workspaceId 业务空间ID
* @param indexId 知识库ID
* @param filePath 文档本地路径
* @return 新文档的FileID失败返回null
*/
public static String uploadDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, String filePath) {
try {
// 准备文档信息
String fileName = Paths.get(filePath).getFileName().toString();
String fileMd5 = calculateMD5(filePath);
String fileSize = String.valueOf(new File(filePath).length());
// 申请上传租约
ApplyFileUploadLeaseRequest leaseRequest = new ApplyFileUploadLeaseRequest()
.setFileName(fileName)
.setMd5(fileMd5)
.setSizeInBytes(fileSize);
ApplyFileUploadLeaseResponse leaseResponse = client.applyFileUploadLeaseWithOptions(
"default", workspaceId, leaseRequest, new HashMap<>(), new RuntimeOptions());
String leaseId = leaseResponse.getBody().getData().getFileUploadLeaseId();
String uploadUrl = leaseResponse.getBody().getData().getParam().getUrl();
// 上传文件
ObjectMapper mapper = new ObjectMapper();
Map<String, String> headers = mapper.readValue(
mapper.writeValueAsString(leaseResponse.getBody().getData().getParam().getHeaders()),
Map.class);
uploadFile(uploadUrl, headers, filePath);
// 添加文件到类目
AddFileRequest addRequest = new AddFileRequest()
.setLeaseId(leaseId)
.setParser("DASHSCOPE_DOCMIND")
.setCategoryId("default");
AddFileResponse addResponse = client.addFileWithOptions(
workspaceId, addRequest, new HashMap<>(), new RuntimeOptions());
String fileId = addResponse.getBody().getData().getFileId();
// 等待文件解析完成
waitForFileParsing(client, workspaceId, fileId);
// 添加到知识库
SubmitIndexAddDocumentsJobRequest indexRequest = new SubmitIndexAddDocumentsJobRequest()
.setIndexId(indexId)
.setDocumentIds(Collections.singletonList(fileId))
.setSourceType("DATA_CENTER_FILE");
SubmitIndexAddDocumentsJobResponse indexResponse = client.submitIndexAddDocumentsJobWithOptions(
workspaceId, indexRequest, new HashMap<>(), new RuntimeOptions());
// 等待索引完成
waitForIndexJob(client, workspaceId, indexResponse.getBody().getData().getId(), indexId);
return fileId;
} catch (Exception e) {
e.printStackTrace();
return null;
}
}
private static String calculateMD5(String filePath) throws Exception {
MessageDigest md = MessageDigest.getInstance("MD5");
try (FileInputStream fis = new FileInputStream(filePath)) {
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = fis.read(buffer)) != -1) {
md.update(buffer, 0, bytesRead);
}
}
StringBuilder sb = new StringBuilder();
for (byte b : md.digest()) {
sb.append(String.format("%02x", b & 0xff));
}
return sb.toString();
}
private static void uploadFile(String preSignedUrl, Map<String, String> headers,
String filePath) throws Exception {
try (FileInputStream fis = new FileInputStream(filePath)) {
HttpURLConnection conn = (HttpURLConnection) new URL(preSignedUrl).openConnection();
conn.setRequestMethod("PUT");
conn.setDoOutput(true);
conn.setRequestProperty("X-bailian-extra", headers.get("X-bailian-extra"));
conn.setRequestProperty("Content-Type", headers.get("Content-Type"));
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = fis.read(buffer)) != -1) {
conn.getOutputStream().write(buffer, 0, bytesRead);
}
if (conn.getResponseCode() != 200) {
throw new RuntimeException("上传失败: " + conn.getResponseCode());
}
}
}
private static void waitForFileParsing(com.aliyun.bailian20231229.Client client,
String workspaceId, String fileId) throws Exception {
while (true) {
DescribeFileResponse response = client.describeFileWithOptions(
workspaceId, fileId, new HashMap<>(), new RuntimeOptions());
String status = response.getBody().getData().getStatus();
if ("PARSE_SUCCESS".equals(status)) break;
if ("PARSE_FAILED".equals(status)) throw new RuntimeException("文档解析失败");
Thread.sleep(5000);
}
}
private static void waitForIndexJob(com.aliyun.bailian20231229.Client client,
String workspaceId, String jobId, String indexId) throws Exception {
while (true) {
GetIndexJobStatusRequest request = new GetIndexJobStatusRequest()
.setIndexId(indexId)
.setJobId(jobId);
GetIndexJobStatusResponse response = client.getIndexJobStatusWithOptions(
workspaceId, request, new HashMap<>(), new RuntimeOptions());
String status = response.getBody().getData().getStatus();
if ("COMPLETED".equals(status)) break;
if ("FAILED".equals(status)) throw new RuntimeException("索引任务失败");
Thread.sleep(5000);
}
}
private static String calculateMD5(InputStream inputStream) throws Exception {
MessageDigest md = MessageDigest.getInstance("MD5");
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
md.update(buffer, 0, bytesRead);
}
StringBuilder sb = new StringBuilder();
for (byte b : md.digest()) {
sb.append(String.format("%02x", b & 0xff));
}
return sb.toString();
}
private static void uploadFile(String preSignedUrl, Map<String, String> headers,
MultipartFile file) throws Exception {
HttpURLConnection conn = (HttpURLConnection) new URL(preSignedUrl).openConnection();
conn.setRequestMethod("PUT");
conn.setDoOutput(true);
conn.setRequestProperty("X-bailian-extra", headers.get("X-bailian-extra"));
conn.setRequestProperty("Content-Type", headers.get("Content-Type"));
try (InputStream inputStream = file.getInputStream()) {
byte[] buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
conn.getOutputStream().write(buffer, 0, bytesRead);
}
}
if (conn.getResponseCode() != 200) {
throw new RuntimeException("上传失败: " + conn.getResponseCode());
}
}
/**
* 初始化客户端Client
*
* @return 配置好的客户端对象
*/
public static com.aliyun.bailian20231229.Client createClient(String accessKeyId, String accessKeySecret) throws Exception {
com.aliyun.teaopenapi.models.Config config = new com.aliyun.teaopenapi.models.Config()
.setAccessKeyId(accessKeyId)
.setAccessKeySecret(accessKeySecret);
// 下方接入地址以公有云的公网接入地址为例,可按需更换接入地址。
config.endpoint = "bailian.cn-beijing.aliyuncs.com";
return new com.aliyun.bailian20231229.Client(config);
}
public static void main(String[] args) throws Exception {
String ALIBABA_CLOUD_ACCESS_KEY_ID = "LTAI5tD5YRKuxWz6Eg7qrM4P";
String ALIBABA_CLOUD_ACCESS_KEY_SECRET = "bO8TBDXflOwbtSKimPpG8XrJnyzgTk";
String WORKSPACE_ID = "llm-4pf5auwewoz34zqu";
String indexId = "b9pvwfqp3d";
String filePath = "D:\\公司经济责任审计方案模板.docx";
Client client = createClient(ALIBABA_CLOUD_ACCESS_KEY_ID, ALIBABA_CLOUD_ACCESS_KEY_SECRET);
uploadDocument(client, WORKSPACE_ID, indexId, filePath);
}
}

View File

@@ -0,0 +1,134 @@
package com.gxwebsoft.ai.util;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.aliyun.bailian20231229.models.CreateIndexResponse;
import com.aliyun.bailian20231229.models.DeleteFileResponse;
import com.aliyun.bailian20231229.models.DeleteIndexDocumentResponse;
import com.aliyun.bailian20231229.models.DeleteIndexResponse;
import com.aliyun.bailian20231229.models.ListIndexDocumentsResponse;
import com.aliyun.bailian20231229.models.ListIndicesResponse;
import com.aliyun.bailian20231229.models.RetrieveRequest;
import com.aliyun.bailian20231229.models.RetrieveResponse;
import com.aliyun.teautil.models.RuntimeOptions;
/**
* 知识库工具类
* @author GIIT-YC
*
*/
public class KnowledgeBaseUtil {
/**
* 在指定的知识库中检索信息。
*
* @param client 客户端对象bailian20231229Client
* @param workspaceId 业务空间ID
* @param indexId 知识库ID
* @param query 检索查询语句
* @return 阿里云百炼服务的响应
*/
public static RetrieveResponse retrieveIndex(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, String query) throws Exception {
RetrieveRequest retrieveRequest = new RetrieveRequest();
retrieveRequest.setIndexId(indexId);
retrieveRequest.setQuery(query);
retrieveRequest.setDenseSimilarityTopK(null);
RuntimeOptions runtime = new RuntimeOptions();
return client.retrieveWithOptions(workspaceId, retrieveRequest, null, runtime);
}
/**
* 在阿里云百炼服务中创建知识库(初始化)。
*
* @param client 客户端对象
* @param workspaceId 业务空间ID
* @param name 知识库名称
* @param desc 知识库描述
* @return 阿里云百炼服务的响应对象
*/
public static CreateIndexResponse createIndex(com.aliyun.bailian20231229.Client client, String workspaceId, String name, String desc) throws Exception {
Map<String, String> headers = new HashMap<>();
com.aliyun.bailian20231229.models.CreateIndexRequest createIndexRequest = new com.aliyun.bailian20231229.models.CreateIndexRequest();
createIndexRequest.setStructureType("unstructured");
createIndexRequest.setName(name);
createIndexRequest.setDescription(desc);
createIndexRequest.setSinkType("DEFAULT");
createIndexRequest.setEmbeddingModelName("text-embedding-v4");
com.aliyun.teautil.models.RuntimeOptions runtime = new com.aliyun.teautil.models.RuntimeOptions();
return client.createIndexWithOptions(workspaceId, createIndexRequest, headers, runtime);
}
/**
* 获取指定业务空间下一个或多个知识库的详细信息
*
* @param client 客户端Client
* @param workspaceId 业务空间ID
* @return 阿里云百炼服务的响应
*/
public static ListIndicesResponse listIndices(com.aliyun.bailian20231229.Client client, String workspaceId) throws Exception {
Map<String, String> headers = new HashMap<>();
com.aliyun.bailian20231229.models.ListIndicesRequest listIndicesRequest = new com.aliyun.bailian20231229.models.ListIndicesRequest();
com.aliyun.teautil.models.RuntimeOptions runtime = new com.aliyun.teautil.models.RuntimeOptions();
return client.listIndicesWithOptions(workspaceId, listIndicesRequest, headers, runtime);
}
/**
* 永久性删除指定的知识库
*
* @param client 客户端Client
* @param workspaceId 业务空间ID
* @param indexId 知识库ID
* @return 阿里云百炼服务的响应
*/
public static DeleteIndexResponse deleteIndex(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId) throws Exception {
Map<String, String> headers = new HashMap<>();
com.aliyun.bailian20231229.models.DeleteIndexRequest deleteIndexRequest = new com.aliyun.bailian20231229.models.DeleteIndexRequest();
deleteIndexRequest.setIndexId(indexId);
com.aliyun.teautil.models.RuntimeOptions runtime = new com.aliyun.teautil.models.RuntimeOptions();
return client.deleteIndexWithOptions(workspaceId, deleteIndexRequest, headers, runtime);
}
/**
* 查询知识库下的文档列表
*
* @param client 客户端Client
* @param workspaceId 业务空间ID
* @param indexId 知识库ID
* @return 阿里云百炼服务的响应
*/
public static ListIndexDocumentsResponse listIndexDocuments(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId) throws Exception {
com.aliyun.bailian20231229.models.ListIndexDocumentsRequest listIndexDocumentsRequest = new com.aliyun.bailian20231229.models.ListIndexDocumentsRequest();
listIndexDocumentsRequest.setIndexId(indexId);
return client.listIndexDocuments(workspaceId, listIndexDocumentsRequest);
}
/**
* 删除知识库下的文档
*
* @param client 客户端Client
* @param workspaceId 业务空间ID
* @param indexId 知识库ID
* @param ids 删除文件ID列表
* @return 阿里云百炼服务的响应
*/
public static DeleteIndexDocumentResponse deleteIndexDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String indexId, List<String> ids) throws Exception {
com.aliyun.bailian20231229.models.DeleteIndexDocumentRequest deleteIndexDocumentRequest = new com.aliyun.bailian20231229.models.DeleteIndexDocumentRequest();
deleteIndexDocumentRequest.setIndexId(indexId);
deleteIndexDocumentRequest.setDocumentIds(ids);
return client.deleteIndexDocument(workspaceId, deleteIndexDocumentRequest);
}
/**
* 删除阿里云应用数据文档
*
* @param client 客户端Client
* @param workspaceId 业务空间ID
* @param fileId 删除文件ID
* @return 阿里云百炼服务的响应
*/
public static DeleteFileResponse deleteAppDocument(com.aliyun.bailian20231229.Client client, String workspaceId, String fileId) throws Exception {
return client.deleteFile(fileId, workspaceId);
}
}

View File

@@ -11,9 +11,12 @@ import com.gxwebsoft.common.core.web.ApiResult;
import com.gxwebsoft.common.core.web.PageResult; import com.gxwebsoft.common.core.web.PageResult;
import com.gxwebsoft.common.core.web.PageParam; import com.gxwebsoft.common.core.web.PageParam;
import com.gxwebsoft.common.core.web.BatchParam; import com.gxwebsoft.common.core.web.BatchParam;
import com.gxwebsoft.ai.service.KnowledgeBaseService;
import com.gxwebsoft.common.core.annotation.OperationLog; import com.gxwebsoft.common.core.annotation.OperationLog;
import io.swagger.v3.oas.annotations.tags.Tag; import io.swagger.v3.oas.annotations.tags.Tag;
import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Operation;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@@ -32,6 +35,9 @@ import java.util.List;
public class OaCompanyController extends BaseController { public class OaCompanyController extends BaseController {
@Resource @Resource
private OaCompanyService oaCompanyService; private OaCompanyService oaCompanyService;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
@Operation(summary = "分页查询企业信息") @Operation(summary = "分页查询企业信息")
@GetMapping("/page") @GetMapping("/page")
@@ -63,12 +69,16 @@ public class OaCompanyController extends BaseController {
@Operation(summary = "添加企业信息") @Operation(summary = "添加企业信息")
@PostMapping() @PostMapping()
public ApiResult<?> save(@RequestBody OaCompany oaCompany) { public ApiResult<?> save(@RequestBody OaCompany oaCompany) {
if(StrUtil.isEmpty(oaCompany.getCompanyCode())) {
return fail("单位唯一标识不能为空");
}
if (oaCompanyService.save(oaCompany)) { if (oaCompanyService.save(oaCompany)) {
//TODO 查询知识库(kb_name=enterprise.getCreditCode) //查询知识库
String kbId = knowledgeBaseService.getKnowledgeBaseId(oaCompany.getCompanyCode());
//TODO 新建知识库 //新建知识库
String kbId = "pggi9mpair"; if(StrUtil.isEmpty(kbId)) {
kbId = knowledgeBaseService.createKnowledgeBase(oaCompany.getCompanyName(), oaCompany.getCompanyCode());
}
//绑定知识库 //绑定知识库
oaCompany.setKbId(kbId); oaCompany.setKbId(kbId);
oaCompanyService.updateById(oaCompany); oaCompanyService.updateById(oaCompany);
@@ -81,12 +91,16 @@ public class OaCompanyController extends BaseController {
@Operation(summary = "修改企业信息") @Operation(summary = "修改企业信息")
@PutMapping() @PutMapping()
public ApiResult<?> update(@RequestBody OaCompany oaCompany) { public ApiResult<?> update(@RequestBody OaCompany oaCompany) {
if(StrUtil.isEmpty(oaCompany.getCompanyCode())) {
return fail("单位唯一标识不能为空");
}
if(StrUtil.isEmpty(oaCompany.getKbId())) { if(StrUtil.isEmpty(oaCompany.getKbId())) {
//TODO 查询知识库 //查询知识库
String kbId = knowledgeBaseService.getKnowledgeBaseId(oaCompany.getCompanyCode());
//TODO 新建知识库 //新建知识库
String kbId = "pggi9mpair"; if(StrUtil.isEmpty(kbId)) {
kbId = knowledgeBaseService.createKnowledgeBase(oaCompany.getCompanyName(), oaCompany.getCompanyCode());
}
//绑定知识库 //绑定知识库
oaCompany.setKbId(kbId); oaCompany.setKbId(kbId);
} }