enhance rag retrieval + summary

This commit is contained in:
Adrien
2026-04-07 22:39:28 +02:00
parent 0cf318f0a7
commit aee6a9dfba
34 changed files with 2306 additions and 279 deletions
@@ -5,10 +5,11 @@ import com.aiteacher.book.BookStatus;
import com.aiteacher.book.NoKnowledgeSourceException;
import com.aiteacher.document.FigureEntity;
import com.aiteacher.document.SectionEntity;
import com.aiteacher.retrieval.CitationValidatorService;
import com.aiteacher.retrieval.LabelledContext;
import com.aiteacher.retrieval.NeurosurgeryRetriever;
import com.aiteacher.retrieval.QueryExpansionService;
import com.aiteacher.retrieval.RetrievalResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;
@@ -17,8 +18,6 @@ import java.util.*;
@Service
public class ChatService {
private static final Logger log = LoggerFactory.getLogger(ChatService.class);
private static final String SYSTEM_PROMPT = """
You are an expert neurosurgery educator assistant. Answer questions using the
medical textbook content provided to you as context.
@@ -29,8 +28,8 @@ public class ChatService {
- Build answers from what is present: procedures, conditions, techniques, and descriptions all contribute; combine them into a rich, structured response
- Use clear structure: headings, bullet points, or numbered steps where appropriate to maximize clarity
- Only say you cannot answer if the context is entirely unrelated to the question
- Cite sources for each major point (book title and page number from the context)
- When referencing diagrams or figures, cite them as [Fig. X, p.N]
- Cite sources for each major claim using the reference labels from the context (e.g. [S1], [F2]). Prefer these labels over inventing page numbers, but you may also describe the source naturally if needed.
- When referencing diagrams or figures, prefer their label from the context (e.g. [F1])
- Maintain continuity with the conversation history
- Never fabricate clinical information not present in the context
""";
@@ -40,17 +39,23 @@ public class ChatService {
private final ChatSessionRepository sessionRepository;
private final MessageRepository messageRepository;
private final NeurosurgeryRetriever retriever;
private final QueryExpansionService queryExpansionService;
private final CitationValidatorService citationValidatorService;
/**
* Creates the service and stores each collaborator in the field of the same name.
*
* @param chatClient client used to send the system prompt and the context prompt and read back the answer text
* @param bookRepository repository used to look up a book by id (e.g. to resolve a figure's book title)
* @param sessionRepository chat-session repository (its usage is not visible in this excerpt)
* @param messageRepository repository used to load a session's messages ordered by creation time
* @param retriever retriever queried once per ready book for parent sections and figures
* @param queryExpansionService service whose rewritten form of the current user question is used as the retrieval query
* @param citationValidatorService service applied to the raw model answer together with the labels of the retrieved context
*/
public ChatService(ChatClient chatClient,
BookRepository bookRepository,
ChatSessionRepository sessionRepository,
MessageRepository messageRepository,
NeurosurgeryRetriever retriever) {
NeurosurgeryRetriever retriever,
QueryExpansionService queryExpansionService,
CitationValidatorService citationValidatorService) {
this.chatClient = chatClient;
this.bookRepository = bookRepository;
this.sessionRepository = sessionRepository;
this.messageRepository = messageRepository;
this.retriever = retriever;
this.queryExpansionService = queryExpansionService;
this.citationValidatorService = citationValidatorService;
}
public ChatSession createSession(String topicId) {
@@ -85,25 +90,34 @@ public class ChatService {
List<Message> history = messageRepository.findBySessionIdOrderByCreatedAtAsc(sessionId);
String fullQuestion = buildQuestionWithHistory(history, userContent, session.getTopicId());
// Retrieve context from all ready books (aggregate across books)
// Expand only the current user question to clinical terminology for retrieval (US1).
// fullQuestion (which includes conversation history) is used for the LLM context prompt,
// but retrieval should be driven by a concise clinical rewrite of the actual question.
String retrievalQuery = queryExpansionService.expand(userContent).rewritten();
// Retrieve context from all ready books using the expanded query
List<SectionEntity> allSections = new ArrayList<>();
List<FigureEntity> allFigures = new ArrayList<>();
for (com.aiteacher.book.Book book : readyBooks) {
RetrievalResult result = retriever.retrieve(fullQuestion, book.getId());
RetrievalResult result = retriever.retrieve(retrievalQuery, book.getId());
allSections.addAll(result.parentSections());
allFigures.addAll(result.figures());
}
// Build LLM prompt with section full texts and figure references
String contextPrompt = buildContextPrompt(fullQuestion, allSections, allFigures);
// Build labelled context prompt (US2): assigns [S1]/[F1] labels to each source
LabelledContext ctx = buildContextPrompt(fullQuestion, allSections, allFigures);
String assistantContent = chatClient.prompt()
// Generate answer
String rawContent = chatClient.prompt()
.system(SYSTEM_PROMPT)
.user(contextPrompt)
.user(ctx.promptText())
.call()
.content();
// Build sources list with TEXT and FIGURE entries
// Strip any citation labels not present in the retrieved context (US2)
String assistantContent = citationValidatorService.validate(rawContent, ctx.allLabels());
// Attach sources with their ref-labels for frontend traceability
List<Map<String, Object>> sources = buildSources(allSections, allFigures);
Message assistantMessage = new Message(sessionId, MessageRole.ASSISTANT, assistantContent);
@@ -126,51 +140,71 @@ public class ChatService {
// Private helpers
// -------------------------------------------------------------------------
private String buildContextPrompt(String question,
List<SectionEntity> sections,
List<FigureEntity> figures) {
/**
* Builds the LLM context prompt, tagging each section as [S1], [S2]… and
* each figure as [F1], [F2]… so the model can cite only known sources.
*
* <p>Labels are 1-based and follow list order. Sections are written under a
* {@code CONTEXT:} header as "[Sn] title, p.pageStart" followed by the section's
* full text; figures are written under an {@code AVAILABLE FIGURES:} header as
* "[Fn] label (p.page): caption", with "Figure" substituted for a null label and
* an empty string for a null caption. A header is omitted when its list is empty.
* The question is always appended last under {@code QUESTION:}.
*
* @param question text placed after the {@code QUESTION:} header
* @param sections sections to label S1..Sn, in the order given
* @param figures figures to label F1..Fn, in the order given
* @return a {@link LabelledContext} built from the label-to-section map, the
*         label-to-figure map, and the assembled prompt text
*/
private LabelledContext buildContextPrompt(String question,
List<SectionEntity> sections,
List<FigureEntity> figures) {
// LinkedHashMap keeps the labels in insertion order (S1, S2, … / F1, F2, …).
Map<String, SectionEntity> sectionLabels = new LinkedHashMap<>();
Map<String, FigureEntity> figureLabels = new LinkedHashMap<>();
StringBuilder sb = new StringBuilder();
if (!sections.isEmpty()) {
sb.append("CONTEXT:\n\n");
for (SectionEntity section : sections) {
sb.append("[").append(section.getTitle())
.append(", p.").append(section.getPageStart()).append("]\n");
for (int i = 0; i < sections.size(); i++) {
SectionEntity section = sections.get(i);
// 1-based label: the first section is S1.
String label = "S" + (i + 1);
sectionLabels.put(label, section);
sb.append("[").append(label).append("] ")
.append(section.getTitle())
.append(", p.").append(section.getPageStart()).append("\n");
sb.append(section.getFullText()).append("\n\n");
}
}
if (!figures.isEmpty()) {
sb.append("AVAILABLE FIGURES:\n");
for (FigureEntity figure : figures) {
sb.append("- ").append(figure.getLabel() != null ? figure.getLabel() : "Figure")
for (int i = 0; i < figures.size(); i++) {
FigureEntity figure = figures.get(i);
// 1-based label: the first figure is F1.
String label = "F" + (i + 1);
figureLabels.put(label, figure);
sb.append("[").append(label).append("] ")
.append(figure.getLabel() != null ? figure.getLabel() : "Figure")
.append(" (p.").append(figure.getPage()).append("): ")
.append(figure.getCaption() != null ? figure.getCaption() : "")
.append("\n");
}
sb.append("\nWhen referencing diagrams, cite them as [Fig. X, p.N].\n\n");
sb.append("\nWhen referencing diagrams, use their label from the context (e.g. [F1]).\n\n");
}
// The question always comes last, after any context and figure listings.
sb.append("QUESTION:\n").append(question);
return sb.toString();
return new LabelledContext(sectionLabels, figureLabels, sb.toString());
}
private List<Map<String, Object>> buildSources(List<SectionEntity> sections,
List<FigureEntity> figures) {
List<Map<String, Object>> sources = new ArrayList<>();
for (SectionEntity section : sections) {
for (int i = 0; i < sections.size(); i++) {
SectionEntity section = sections.get(i);
Map<String, Object> source = new LinkedHashMap<>();
source.put("type", "TEXT");
source.put("refLabel", "S" + (i + 1));
source.put("bookId", section.getBookId());
source.put("bookTitle", deriveTitleFromSection(section));
source.put("page", section.getPageStart());
source.put("chunkText", truncate(section.getFullText(), 500));
sources.add(source);
}
for (FigureEntity figure : figures) {
for (int i = 0; i < figures.size(); i++) {
FigureEntity figure = figures.get(i);
Map<String, Object> source = new LinkedHashMap<>();
source.put("type", "FIGURE");
source.put("refLabel", "F" + (i + 1));
source.put("bookId", figure.getBookId());
source.put("bookTitle", bookRepository.findById(figure.getBookId())
.map(com.aiteacher.book.Book::getTitle).orElse("Book"));
source.put("page", figure.getPage());
@@ -178,7 +212,6 @@ public class ChatService {
source.put("label", figure.getLabel() != null ? figure.getLabel() : "");
source.put("caption", figure.getCaption() != null ? figure.getCaption() : "");
source.put("figureType", figure.getFigureType().name());
// imageUrl assembled from relative path: figures/{bookId}/{filename}
String filename = figure.getImagePath().substring(
figure.getImagePath().lastIndexOf('/') + 1);
source.put("imageUrl", "/api/v1/figures/" + figure.getBookId() + "/" + filename);