package com.aiteacher.retrieval;

import com.aiteacher.document.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Dual-modality retriever: searches text chunks and figure captions independently,
 * then expands text hits to their parent sections and merges linked figures.
 *
 * <p>Both searches go through the same {@link VectorStore} and are restricted to a single
 * book by a metadata filter on {@code book_id}; the {@code type} metadata field
 * ({@code "TEXT"} or {@code "FIGURE"}) selects the modality.
 */
@Service
public class NeurosurgeryRetriever {

    private static final Logger log = LoggerFactory.getLogger(NeurosurgeryRetriever.class);

    /** Maximum number of text chunks requested from the vector store. */
    private static final int TEXT_TOP_K = 5;

    /** Maximum number of figure captions requested from the vector store. */
    private static final int FIGURE_TOP_K = 3;

    private final VectorStore vectorStore;
    private final SectionRepository sectionRepository;
    private final FigureRepository figureRepository;
    private final ChunkFigureRefRepository chunkFigureRefRepository;

    public NeurosurgeryRetriever(VectorStore vectorStore,
                                 SectionRepository sectionRepository,
                                 FigureRepository figureRepository,
                                 ChunkFigureRefRepository chunkFigureRefRepository) {
        this.vectorStore = vectorStore;
        this.sectionRepository = sectionRepository;
        this.figureRepository = figureRepository;
        this.chunkFigureRefRepository = chunkFigureRefRepository;
    }

    /**
     * Retrieves the sections and figures relevant to {@code query} within one book.
     *
     * <p>Steps, in order: (1) similarity search over text chunks, (2) independent similarity
     * search over figure captions, (3) load the sections referenced by the text hits'
     * {@code section_id} metadata, (4) load figures linked to the retrieved chunks,
     * (5) load figures referenced by the figure hits' {@code figure_id} metadata,
     * (6) merge both figure lists, de-duplicated by figure id with linked figures first.
     *
     * @param query  the text to search for
     * @param bookId the book the search is restricted to; must not be {@code null}
     * @return the matching sections together with the merged, de-duplicated figures
     * @throws NullPointerException if {@code bookId} is {@code null}
     */
    public RetrievalResult retrieve(String query, UUID bookId) {
        Objects.requireNonNull(bookId, "bookId");
        String bookFilterValue = bookId.toString();

        // 1. Text chunk search
        List<Document> textHits = searchByType(query, "TEXT", bookFilterValue, TEXT_TOP_K);

        // 2. Figure caption search (independent topK)
        List<Document> figureHits = searchByType(query, "FIGURE", bookFilterValue, FIGURE_TOP_K);

        // 3. Expand text chunks to their parent sections
        List<String> sectionIds = metadataIds(textHits, "section_id");
        // Element type is inferred from the repository's findAllById return type.
        var sections = findAllUnlessEmpty(sectionIds, sectionRepository::findAllById);

        // 4. Fetch figures explicitly linked to retrieved chunks
        List<UUID> chunkIds = textHits.stream()
                .map(Document::getId)
                .map(NeurosurgeryRetriever::parseUuidOrNull)
                .filter(Objects::nonNull)
                .toList();
        Stream<ChunkFigureRefEntity> chunkFigureRefs = chunkIds.isEmpty()
                ? Stream.empty()
                : chunkFigureRefRepository.findByChunkIdIn(chunkIds).stream();
        var linkedFigureIds = chunkFigureRefs
                .map(ChunkFigureRefEntity::getFigureId)
                .distinct()
                .toList();
        var linkedFigures = findAllUnlessEmpty(linkedFigureIds, figureRepository::findAllById);

        // 5. Collect figures from semantic figure search
        List<String> semanticFigureIds = metadataIds(figureHits, "figure_id");
        var semanticFigures = findAllUnlessEmpty(semanticFigureIds, figureRepository::findAllById);

        // 6. Merge and deduplicate figures by id. Linked figures are streamed first and the
        //    merge function keeps the first occurrence, so they take precedence; the
        //    LinkedHashMap preserves that encounter order in the result.
        var figuresById = Stream.concat(linkedFigures.stream(), semanticFigures.stream())
                .collect(Collectors.toMap(
                        figure -> figure.getId(),
                        Function.identity(),
                        (kept, duplicate) -> kept,
                        LinkedHashMap::new));

        log.debug("Retrieved {} sections, {} figures for query", sections.size(), figuresById.size());
        return new RetrievalResult(sections, new ArrayList<>(figuresById.values()));
    }

    /** Runs one similarity search limited to documents of the given type within one book. */
    private List<Document> searchByType(String query, String type, String bookFilterValue, int topK) {
        FilterExpressionBuilder filter = new FilterExpressionBuilder();
        return vectorStore.similaritySearch(
                SearchRequest.builder()
                        .query(query)
                        .topK(topK)
                        .filterExpression(filter.and(
                                filter.eq("type", type),
                                filter.eq("book_id", bookFilterValue)
                        ).build())
                        .build());
    }

    /** Collects the distinct, non-null string values stored under {@code key} in the hits' metadata. */
    private static List<String> metadataIds(List<Document> hits, String key) {
        return hits.stream()
                .map(hit -> (String) hit.getMetadata().get(key))
                .filter(Objects::nonNull)
                .distinct()
                .toList();
    }

    /** Applies {@code finder} to {@code ids}, or returns an empty list without calling it when there are no ids. */
    private static <I, T> List<T> findAllUnlessEmpty(
            List<I> ids, Function<? super List<I>, ? extends List<T>> finder) {
        return ids.isEmpty() ? List.of() : finder.apply(ids);
    }

    /**
     * Parses a document id as a UUID, returning {@code null} when it is absent or malformed.
     * Such chunks are skipped on purpose (best effort) rather than failing the whole retrieval.
     */
    private static UUID parseUuidOrNull(String id) {
        if (id == null) {
            return null;
        }
        try {
            return UUID.fromString(id);
        } catch (IllegalArgumentException ex) {
            log.debug("Skipping chunk whose id is not a UUID: {}", id);
            return null;
        }
    }
}