Merge remote-tracking branch 'origin/main'

This commit is contained in:
2026-05-06 11:51:05 +08:00
3 changed files with 62 additions and 2 deletions

View File

@@ -16,6 +16,7 @@ import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory;
import com.gxwebsoft.ai.service.AiCloudFileService;
import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil;
import com.gxwebsoft.common.core.context.TenantContext;
import com.gxwebsoft.oa.service.OaCompanyService;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
@@ -41,6 +42,9 @@ public abstract class AbstractAuditContentService {
@Autowired
protected AiCloudFileService aiCloudFileService;
@Autowired
protected OaCompanyService oaCompanyService;
protected static final String DIFY_WORKFLOW_URL = "http://1.14.159.185:8180/v1/workflows/run";
// 用于同步的锁对象池
@@ -195,6 +199,9 @@ public abstract class AbstractAuditContentService {
* 查询知识库通用方法
*/
protected List<String> queryKnowledgeBase(String kbId, List<String> queries, int topK) {
// 递归获取所有父级单位的kbId
Set<String> allKbIds = TenantContext.callIgnoreTenant(() -> oaCompanyService.getAllParentKbIds(kbId));
Object lock = kbLocks.computeIfAbsent(kbId, k -> new Object());
synchronized (lock) {
@@ -206,8 +213,11 @@ public abstract class AbstractAuditContentService {
fileIds = requestFileIds.get();
}
// 1. 收集所有节点和文档ID
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK, fileIds);
// 1. 收集所有节点和文档ID包含所有父级kbId
List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
for (String currentKbId : allKbIds) {
allNodes.addAll(collectKnowledgeNodes(currentKbId, queries, topK, fileIds));
}
if (allNodes.isEmpty()) {
return new ArrayList<>();
}

View File

@@ -6,6 +6,7 @@ import com.gxwebsoft.oa.entity.OaCompany;
import com.gxwebsoft.oa.param.OaCompanyParam;
import java.util.List;
import java.util.Set;
/**
* 企业信息Service
@@ -61,4 +62,12 @@ public interface OaCompanyService extends IService<OaCompany> {
*/
boolean removeCompanyKnowledgeBase(Integer companyId);
/**
* 递归获取所有父级单位的kbId包含当前单位
*
* @param kbId 知识库ID
* @return 所有父级单位的kbId集合
*/
Set<String> getAllParentKbIds(String kbId);
}

View File

@@ -28,7 +28,9 @@ import org.springframework.stereotype.Service;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -189,4 +191,43 @@ public class OaCompanyServiceImpl extends ServiceImpl<OaCompanyMapper, OaCompany
}
return ret;
}
@Override
public Set<String> getAllParentKbIds(String kbId) {
Set<String> kbIds = new LinkedHashSet<>();
kbIds.add(kbId);
// 根据kbId查找OaCompany
OaCompany company = getOne(new LambdaQueryWrapper<OaCompany>().eq(OaCompany::getKbId, kbId));
if (company != null && StrUtil.isNotBlank(company.getParentCompany())) {
// 递归查找父级
collectParentKbIds(company.getParentCompany(), kbIds);
}
return kbIds;
}
/**
* 递归收集父级kbId
*/
private void collectParentKbIds(String parentCompanyId, Set<String> kbIds) {
if (StrUtil.isBlank(parentCompanyId)) {
return;
}
// 查找父级公司
OaCompany parentCompany = getById(Integer.valueOf(parentCompanyId));
if (parentCompany != null) {
// 添加父级kbId
if (StrUtil.isNotBlank(parentCompany.getKbId())) {
kbIds.add(parentCompany.getKbId());
}
// 继续递归查找父级的父级
if (StrUtil.isNotBlank(parentCompany.getParentCompany())) {
collectParentKbIds(parentCompany.getParentCompany(), kbIds);
}
}
}
}