Files
ai-teacher/backend/src/main/java/com/aiteacher/retrieval/NeurosurgeryRetriever.java
T

112 lines
4.4 KiB
Java

package com.aiteacher.retrieval;
import com.aiteacher.document.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import org.springframework.stereotype.Service;
import java.util.*;
/**
 * Dual-modality retriever: searches text chunks and figure captions independently,
 * then expands text hits to their parent sections and merges linked figures.
 *
 * <p>Both searches are restricted to a single book through a metadata filter on
 * {@code book_id}; the two modalities are told apart by the {@code type} metadata
 * value ({@code TEXT} or {@code FIGURE}).
 */
@Service
public class NeurosurgeryRetriever {

    private static final Logger log = LoggerFactory.getLogger(NeurosurgeryRetriever.class);

    /** Maximum number of text-chunk hits requested from the vector store. */
    private static final int TEXT_TOP_K = 5;

    /** Maximum number of figure-caption hits requested from the vector store. */
    private static final int FIGURE_TOP_K = 3;

    private final VectorStore vectorStore;
    private final SectionRepository sectionRepository;
    private final FigureRepository figureRepository;
    private final ChunkFigureRefRepository chunkFigureRefRepository;

    public NeurosurgeryRetriever(VectorStore vectorStore,
                                 SectionRepository sectionRepository,
                                 FigureRepository figureRepository,
                                 ChunkFigureRefRepository chunkFigureRefRepository) {
        this.vectorStore = vectorStore;
        this.sectionRepository = sectionRepository;
        this.figureRepository = figureRepository;
        this.chunkFigureRefRepository = chunkFigureRefRepository;
    }

    /**
     * Runs a text search and a figure-caption search for {@code query} within one book
     * and assembles the matching sections and figures.
     *
     * <p>Figures come from two sources: those linked to the retrieved text chunks via
     * {@code ChunkFigureRefRepository}, and those whose captions matched the query
     * directly. The result lists linked figures first and contains each figure id once.
     *
     * @param query  the search text passed to both similarity searches
     * @param bookId the book to restrict the search to; must not be {@code null}
     * @return the sections referenced by the text hits and the merged figure list
     * @throws NullPointerException if {@code bookId} is {@code null}
     */
    public RetrievalResult retrieve(String query, UUID bookId) {
        // 1. Text chunk search
        List<Document> textHits = search(query, bookId, "TEXT", TEXT_TOP_K);

        // 2. Figure caption search (independent topK)
        List<Document> figureHits = search(query, bookId, "FIGURE", FIGURE_TOP_K);

        // 3. Expand text chunks to parent sections
        // NOTE(review): findAllById may not return rows in similarity-rank order;
        // confirm whether callers rely on section ordering.
        List<String> sectionIds = metadataIds(textHits, "section_id");
        List<SectionEntity> sections = sectionIds.isEmpty()
                ? List.of()
                : sectionRepository.findAllById(sectionIds);

        // 4. Fetch figures explicitly linked to retrieved chunks
        List<FigureEntity> linkedFigures = findLinkedFigures(textHits);

        // 5. Collect figures from semantic figure search
        List<String> semanticFigureIds = metadataIds(figureHits, "figure_id");
        List<FigureEntity> semanticFigures = semanticFigureIds.isEmpty()
                ? List.of()
                : figureRepository.findAllById(semanticFigureIds);

        // 6. Merge and deduplicate figures by figureId (linked figures take precedence)
        Map<String, FigureEntity> merged = new LinkedHashMap<>();
        linkedFigures.forEach(f -> merged.put(f.getId(), f));
        semanticFigures.forEach(f -> merged.putIfAbsent(f.getId(), f));

        log.debug("Retrieved {} sections, {} figures for query", sections.size(), merged.size());
        return new RetrievalResult(sections, new ArrayList<>(merged.values()));
    }

    /** Runs one similarity search filtered by document type and book; never returns null. */
    private List<Document> search(String query, UUID bookId, String type, int topK) {
        FilterExpressionBuilder b = new FilterExpressionBuilder();
        SearchRequest request = SearchRequest.builder()
                .query(query)
                .topK(topK)
                .filterExpression(b.and(
                        b.eq("type", type),
                        b.eq("book_id", bookId.toString())
                ).build())
                .build();
        List<Document> hits = vectorStore.similaritySearch(request);
        // The VectorStore API is declared nullable in recent Spring AI versions
        // (verify against the version in use); normalise to an empty list.
        return hits != null ? hits : List.of();
    }

    /** Extracts the distinct, non-null values stored under {@code key} in the hits' metadata. */
    private static List<String> metadataIds(List<Document> hits, String key) {
        return hits.stream()
                // Objects.toString avoids a ClassCastException if the store hands back
                // a non-String value for the id.
                .map(hit -> Objects.toString(hit.getMetadata().get(key), null))
                .filter(Objects::nonNull)
                .distinct()
                .toList();
    }

    /** Loads the figures referenced by the given text hits through the chunk-to-figure link table. */
    private List<FigureEntity> findLinkedFigures(List<Document> textHits) {
        List<UUID> chunkIds = parseChunkIds(textHits);
        if (chunkIds.isEmpty()) {
            return List.of();
        }
        List<String> linkedFigureIds = chunkFigureRefRepository.findByChunkIdIn(chunkIds)
                .stream()
                .map(ChunkFigureRefEntity::getFigureId)
                .filter(Objects::nonNull)
                .distinct()
                .toList();
        return linkedFigureIds.isEmpty()
                ? List.of()
                : figureRepository.findAllById(linkedFigureIds);
    }

    /** Parses document ids as UUIDs, skipping (and logging) ids that are missing or malformed. */
    private static List<UUID> parseChunkIds(List<Document> textHits) {
        Set<UUID> ids = new LinkedHashSet<>();
        for (Document hit : textHits) {
            String rawId = hit.getId();
            if (rawId == null) {
                log.debug("Skipping text hit without an id when resolving linked figures");
                continue;
            }
            try {
                ids.add(UUID.fromString(rawId));
            } catch (IllegalArgumentException e) {
                // Best effort: a hit whose id is not a UUID simply has no linked figures.
                log.debug("Skipping text hit with non-UUID id '{}' when resolving linked figures", rawId);
            }
        }
        return List.copyOf(ids);
    }
}