diff --git a/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java b/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java index ccfbb21..cc2b2ff 100644 --- a/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java +++ b/src/main/java/com/gxwebsoft/ai/service/impl/AbstractAuditContentService.java @@ -16,6 +16,7 @@ import com.gxwebsoft.ai.factory.KnowledgeBaseClientFactory; import com.gxwebsoft.ai.service.AiCloudFileService; import com.gxwebsoft.ai.util.AiCloudKnowledgeBaseUtil; import com.gxwebsoft.common.core.context.TenantContext; +import com.gxwebsoft.oa.service.OaCompanyService; import cn.hutool.core.util.StrUtil; import cn.hutool.http.HttpUtil; @@ -41,6 +42,9 @@ public abstract class AbstractAuditContentService { @Autowired protected AiCloudFileService aiCloudFileService; + @Autowired + protected OaCompanyService oaCompanyService; + protected static final String DIFY_WORKFLOW_URL = "http://1.14.159.185:8180/v1/workflows/run"; // 用于同步的锁对象池 @@ -195,6 +199,9 @@ public abstract class AbstractAuditContentService { * 查询知识库通用方法 */ protected List queryKnowledgeBase(String kbId, List queries, int topK) { + // 递归获取所有父级单位的kbId + Set allKbIds = TenantContext.callIgnoreTenant(() -> oaCompanyService.getAllParentKbIds(kbId)); + Object lock = kbLocks.computeIfAbsent(kbId, k -> new Object()); synchronized (lock) { @@ -206,8 +213,11 @@ public abstract class AbstractAuditContentService { fileIds = requestFileIds.get(); } - // 1. 收集所有节点和文档ID - List allNodes = collectKnowledgeNodes(kbId, queries, topK, fileIds); + // 1. 收集所有节点和文档ID(包含所有父级kbId) + List allNodes = new ArrayList<>(); + for (String currentKbId : allKbIds) { + allNodes.addAll(collectKnowledgeNodes(currentKbId, queries, topK, fileIds)); + } if (allNodes.isEmpty()) { return new ArrayList<>(); } diff --git a/src/main/java/com/gxwebsoft/oa/service/OaCompanyService.java b/src/main/java/com/gxwebsoft/oa/service/OaCompanyService.java index dd91ee7..98fdc25 100644 --- a/src/main/java/com/gxwebsoft/oa/service/OaCompanyService.java +++ b/src/main/java/com/gxwebsoft/oa/service/OaCompanyService.java @@ -6,6 +6,7 @@ import com.gxwebsoft.oa.entity.OaCompany; import com.gxwebsoft.oa.param.OaCompanyParam; import java.util.List; +import java.util.Set; /** * 企业信息Service @@ -61,4 +62,12 @@ public interface OaCompanyService extends IService { */ boolean removeCompanyKnowledgeBase(Integer companyId); + /** + * 递归获取所有父级单位的kbId(包含当前单位) + * + * @param kbId 知识库ID + * @return 所有父级单位的kbId集合 + */ + Set getAllParentKbIds(String kbId); + } diff --git a/src/main/java/com/gxwebsoft/oa/service/impl/OaCompanyServiceImpl.java b/src/main/java/com/gxwebsoft/oa/service/impl/OaCompanyServiceImpl.java index f1b8937..be4b937 100644 --- a/src/main/java/com/gxwebsoft/oa/service/impl/OaCompanyServiceImpl.java +++ b/src/main/java/com/gxwebsoft/oa/service/impl/OaCompanyServiceImpl.java @@ -28,7 +28,9 @@ import org.springframework.stereotype.Service; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; /** @@ -189,4 +191,43 @@ public class OaCompanyServiceImpl extends ServiceImpl getAllParentKbIds(String kbId) { + Set kbIds = new LinkedHashSet<>(); + kbIds.add(kbId); + + // 根据kbId查找OaCompany + OaCompany company = getOne(new LambdaQueryWrapper().eq(OaCompany::getKbId, kbId)); + + if (company != null && StrUtil.isNotBlank(company.getParentCompany())) { + // 递归查找父级 + collectParentKbIds(company.getParentCompany(), kbIds); + } + + return kbIds; + } + + /** + * 递归收集父级kbId + */ + private void collectParentKbIds(String parentCompanyId, Set kbIds) { + if (StrUtil.isBlank(parentCompanyId)) { + return; + } + + // 查找父级公司 + OaCompany parentCompany = getById(Integer.valueOf(parentCompanyId)); + + if (parentCompany != null) { + // 添加父级kbId + if (StrUtil.isNotBlank(parentCompany.getKbId())) { + kbIds.add(parentCompany.getKbId()); + } + // 继续递归查找父级的父级 + if (StrUtil.isNotBlank(parentCompany.getParentCompany())) { + collectParentKbIds(parentCompany.getParentCompany(), kbIds); + } + } + } }