112 lines
4.4 KiB
Java
112 lines
4.4 KiB
Java
package com.aiteacher.retrieval;
|
|
|
|
import com.aiteacher.document.*;
|
|
import org.slf4j.Logger;
|
|
import org.slf4j.LoggerFactory;
|
|
import org.springframework.ai.document.Document;
|
|
import org.springframework.ai.vectorstore.SearchRequest;
|
|
import org.springframework.ai.vectorstore.VectorStore;
|
|
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
|
|
import org.springframework.stereotype.Service;
|
|
|
|
import java.util.*;
|
|
|
|
/**
|
|
* Dual-modality retriever: searches text chunks and figure captions independently,
|
|
* then expands text hits to their parent sections and merges linked figures.
|
|
*/
|
|
@Service
|
|
public class NeurosurgeryRetriever {
|
|
|
|
private static final Logger log = LoggerFactory.getLogger(NeurosurgeryRetriever.class);
|
|
|
|
private static final int TEXT_TOP_K = 5;
|
|
private static final int FIGURE_TOP_K = 3;
|
|
|
|
private final VectorStore vectorStore;
|
|
private final SectionRepository sectionRepository;
|
|
private final FigureRepository figureRepository;
|
|
private final ChunkFigureRefRepository chunkFigureRefRepository;
|
|
|
|
public NeurosurgeryRetriever(VectorStore vectorStore,
|
|
SectionRepository sectionRepository,
|
|
FigureRepository figureRepository,
|
|
ChunkFigureRefRepository chunkFigureRefRepository) {
|
|
this.vectorStore = vectorStore;
|
|
this.sectionRepository = sectionRepository;
|
|
this.figureRepository = figureRepository;
|
|
this.chunkFigureRefRepository = chunkFigureRefRepository;
|
|
}
|
|
|
|
public RetrievalResult retrieve(String query, UUID bookId) {
|
|
FilterExpressionBuilder b = new FilterExpressionBuilder();
|
|
|
|
// 1. Text chunk search
|
|
List<Document> textHits = vectorStore.similaritySearch(
|
|
SearchRequest.builder()
|
|
.query(query)
|
|
.topK(TEXT_TOP_K)
|
|
.filterExpression(b.and(
|
|
b.eq("type", "TEXT"),
|
|
b.eq("book_id", bookId.toString())
|
|
).build())
|
|
.build()
|
|
);
|
|
|
|
// 2. Figure caption search (independent topK)
|
|
List<Document> figureHits = vectorStore.similaritySearch(
|
|
SearchRequest.builder()
|
|
.query(query)
|
|
.topK(FIGURE_TOP_K)
|
|
.filterExpression(b.and(
|
|
b.eq("type", "FIGURE"),
|
|
b.eq("book_id", bookId.toString())
|
|
).build())
|
|
.build()
|
|
);
|
|
|
|
// 3. Expand text chunks to parent sections from Postgres
|
|
List<String> sectionIds = textHits.stream()
|
|
.map(d -> (String) d.getMetadata().get("section_id"))
|
|
.filter(Objects::nonNull)
|
|
.distinct()
|
|
.toList();
|
|
List<SectionEntity> sections = sectionIds.isEmpty()
|
|
? List.of()
|
|
: sectionRepository.findAllById(sectionIds);
|
|
|
|
// 4. Fetch figures explicitly linked to retrieved chunks
|
|
List<UUID> chunkIds = textHits.stream()
|
|
.map(d -> {
|
|
try { return UUID.fromString(d.getId()); }
|
|
catch (Exception e) { return null; }
|
|
})
|
|
.filter(Objects::nonNull)
|
|
.toList();
|
|
List<String> linkedFigureIds = chunkIds.isEmpty()
|
|
? List.of()
|
|
: chunkFigureRefRepository.findByChunkIdIn(chunkIds)
|
|
.stream().map(ChunkFigureRefEntity::getFigureId).distinct().toList();
|
|
List<FigureEntity> linkedFigures = linkedFigureIds.isEmpty()
|
|
? List.of()
|
|
: figureRepository.findAllById(linkedFigureIds);
|
|
|
|
// 5. Collect figures from semantic figure search
|
|
List<String> semanticFigureIds = figureHits.stream()
|
|
.map(d -> (String) d.getMetadata().get("figure_id"))
|
|
.filter(Objects::nonNull)
|
|
.toList();
|
|
List<FigureEntity> semanticFigures = semanticFigureIds.isEmpty()
|
|
? List.of()
|
|
: figureRepository.findAllById(semanticFigureIds);
|
|
|
|
// 6. Merge and deduplicate figures by figureId (linked figures take precedence)
|
|
Map<String, FigureEntity> merged = new LinkedHashMap<>();
|
|
linkedFigures.forEach(f -> merged.put(f.getId(), f));
|
|
semanticFigures.forEach(f -> merged.putIfAbsent(f.getId(), f));
|
|
|
|
log.debug("Retrieved {} sections, {} figures for query", sections.size(), merged.size());
|
|
return new RetrievalResult(sections, new ArrayList<>(merged.values()));
|
|
}
|
|
}
|