Merge remote-tracking branch 'origin/main'

This commit is contained in:
2026-05-06 11:51:05 +08:00
3 changed files with 62 additions and 2 deletions

View File

@@ -16,6 +16,7 @@ import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory;
import com.gxwebsoft.ai.service.AiCloudFileService; import com.gxwebsoft.ai.service.AiCloudFileService;
import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil; import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil;
import com.gxwebsoft.common.core.context.TenantContext; import com.gxwebsoft.common.core.context.TenantContext;
import com.gxwebsoft.oa.service.OaCompanyService;
import cn.hutool.core.util.StrUtil; import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil; import cn.hutool.http.HttpUtil;
@@ -41,6 +42,9 @@ public abstract class AbstractAuditContentService {
@Autowired @Autowired
protected AiCloudFileService aiCloudFileService; protected AiCloudFileService aiCloudFileService;
@Autowired
protected OaCompanyService oaCompanyService;
protected static final String DIFY_WORKFLOW_URL = "http://1.14.159.185:8180/v1/workflows/run"; protected static final String DIFY_WORKFLOW_URL = "http://1.14.159.185:8180/v1/workflows/run";
// 用于同步的锁对象池 // 用于同步的锁对象池
@@ -195,6 +199,9 @@ public abstract class AbstractAuditContentService {
* 查询知识库通用方法 * 查询知识库通用方法
*/ */
protected List<String> queryKnowledgeBase(String kbId, List<String> queries, int topK) { protected List<String> queryKnowledgeBase(String kbId, List<String> queries, int topK) {
// 递归获取所有父级单位的kbId
Set<String> allKbIds = TenantContext.callIgnoreTenant(() -> oaCompanyService.getAllParentKbIds(kbId));
Object lock = kbLocks.computeIfAbsent(kbId, k -> new Object()); Object lock = kbLocks.computeIfAbsent(kbId, k -> new Object());
synchronized (lock) { synchronized (lock) {
@@ -206,8 +213,11 @@ public abstract class AbstractAuditContentService {
fileIds = requestFileIds.get(); fileIds = requestFileIds.get();
} }
// 1. 收集所有节点和文档ID // 1. 收集所有节点和文档ID包含所有父级kbId
List<RetrieveResponseBodyDataNodes> allNodes = collectKnowledgeNodes(kbId, queries, topK, fileIds); List<RetrieveResponseBodyDataNodes> allNodes = new ArrayList<>();
for (String currentKbId : allKbIds) {
allNodes.addAll(collectKnowledgeNodes(currentKbId, queries, topK, fileIds));
}
if (allNodes.isEmpty()) { if (allNodes.isEmpty()) {
return new ArrayList<>(); return new ArrayList<>();
} }

View File

@@ -6,6 +6,7 @@ import com.gxwebsoft.oa.entity.OaCompany;
import com.gxwebsoft.oa.param.OaCompanyParam; import com.gxwebsoft.oa.param.OaCompanyParam;
import java.util.List; import java.util.List;
import java.util.Set;
/** /**
* 企业信息Service * 企业信息Service
@@ -61,4 +62,12 @@ public interface OaCompanyService extends IService<OaCompany> {
*/ */
boolean removeCompanyKnowledgeBase(Integer companyId); boolean removeCompanyKnowledgeBase(Integer companyId);
/**
* 递归获取所有父级单位的kbId包含当前单位
*
* @param kbId 知识库ID
* @return 所有父级单位的kbId集合
*/
Set<String> getAllParentKbIds(String kbId);
} }

View File

@@ -28,7 +28,9 @@ import org.springframework.stereotype.Service;
import java.time.LocalDateTime; import java.time.LocalDateTime;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
@@ -189,4 +191,43 @@ public class OaCompanyServiceImpl extends ServiceImpl<OaCompanyMapper, OaCompany
} }
return ret; return ret;
} }
@Override
public Set<String> getAllParentKbIds(String kbId) {
Set<String> kbIds = new LinkedHashSet<>();
kbIds.add(kbId);
// 根据kbId查找OaCompany
OaCompany company = getOne(new LambdaQueryWrapper<OaCompany>().eq(OaCompany::getKbId, kbId));
if (company != null && StrUtil.isNotBlank(company.getParentCompany())) {
// 递归查找父级
collectParentKbIds(company.getParentCompany(), kbIds);
}
return kbIds;
}
/**
* 递归收集父级kbId
*/
private void collectParentKbIds(String parentCompanyId, Set<String> kbIds) {
if (StrUtil.isBlank(parentCompanyId)) {
return;
}
// 查找父级公司
OaCompany parentCompany = getById(Integer.valueOf(parentCompanyId));
if (parentCompany != null) {
// 添加父级kbId
if (StrUtil.isNotBlank(parentCompany.getKbId())) {
kbIds.add(parentCompany.getKbId());
}
// 继续递归查找父级的父级
if (StrUtil.isNotBlank(parentCompany.getParentCompany())) {
collectParentKbIds(parentCompany.getParentCompany(), kbIds);
}
}
}
} }