enhance rag retrieval + summary
This commit is contained in:
@@ -5,10 +5,11 @@ import com.aiteacher.book.BookStatus;
|
||||
import com.aiteacher.book.NoKnowledgeSourceException;
|
||||
import com.aiteacher.document.FigureEntity;
|
||||
import com.aiteacher.document.SectionEntity;
|
||||
import com.aiteacher.retrieval.CitationValidatorService;
|
||||
import com.aiteacher.retrieval.LabelledContext;
|
||||
import com.aiteacher.retrieval.NeurosurgeryRetriever;
|
||||
import com.aiteacher.retrieval.QueryExpansionService;
|
||||
import com.aiteacher.retrieval.RetrievalResult;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -17,8 +18,6 @@ import java.util.*;
|
||||
@Service
|
||||
public class ChatService {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ChatService.class);
|
||||
|
||||
private static final String SYSTEM_PROMPT = """
|
||||
You are an expert neurosurgery educator assistant. Answer questions using the
|
||||
medical textbook content provided to you as context.
|
||||
@@ -29,8 +28,8 @@ public class ChatService {
|
||||
- Build answers from what is present: procedures, conditions, techniques, and descriptions all contribute; combine them into a rich, structured response
|
||||
- Use clear structure: headings, bullet points, or numbered steps where appropriate to maximize clarity
|
||||
- Only say you cannot answer if the context is entirely unrelated to the question
|
||||
- Cite sources for each major point (book title and page number from the context)
|
||||
- When referencing diagrams or figures, cite them as [Fig. X, p.N]
|
||||
- Cite sources for each major claim using the reference labels from the context (e.g. [S1], [F2]). Prefer these labels over inventing page numbers, but you may also describe the source naturally if needed.
|
||||
- When referencing diagrams or figures, prefer their label from the context (e.g. [F1])
|
||||
- Maintain continuity with the conversation history
|
||||
- Never fabricate clinical information not present in the context
|
||||
""";
|
||||
@@ -40,17 +39,23 @@ public class ChatService {
|
||||
private final ChatSessionRepository sessionRepository;
|
||||
private final MessageRepository messageRepository;
|
||||
private final NeurosurgeryRetriever retriever;
|
||||
private final QueryExpansionService queryExpansionService;
|
||||
private final CitationValidatorService citationValidatorService;
|
||||
|
||||
/**
 * Creates the chat service and stores its collaborators.
 *
 * @param chatClient               client used to send the system and user prompts to the chat model
 * @param bookRepository           repository used to look up books (e.g. a book's title by id)
 * @param sessionRepository        repository for chat sessions
 * @param messageRepository        repository used to load a session's messages in creation order
 * @param retriever                retriever queried once per book with the retrieval query
 * @param queryExpansionService    service that rewrites the user's question before retrieval
 * @param citationValidatorService service that validates the generated answer against the
 *                                 set of labels present in the retrieved context
 */
public ChatService(ChatClient chatClient,
                   BookRepository bookRepository,
                   ChatSessionRepository sessionRepository,
                   MessageRepository messageRepository,
                   NeurosurgeryRetriever retriever,
                   QueryExpansionService queryExpansionService,
                   CitationValidatorService citationValidatorService) {
    this.chatClient = chatClient;
    this.bookRepository = bookRepository;
    this.sessionRepository = sessionRepository;
    this.messageRepository = messageRepository;
    this.retriever = retriever;
    this.queryExpansionService = queryExpansionService;
    this.citationValidatorService = citationValidatorService;
}
|
||||
|
||||
public ChatSession createSession(String topicId) {
|
||||
@@ -85,25 +90,34 @@ public class ChatService {
|
||||
List<Message> history = messageRepository.findBySessionIdOrderByCreatedAtAsc(sessionId);
|
||||
String fullQuestion = buildQuestionWithHistory(history, userContent, session.getTopicId());
|
||||
|
||||
// Retrieve context from all ready books (aggregate across books)
|
||||
// Expand only the current user question to clinical terminology for retrieval (US1).
|
||||
// fullQuestion (which includes conversation history) is used for the LLM context prompt,
|
||||
// but retrieval should be driven by a concise clinical rewrite of the actual question.
|
||||
String retrievalQuery = queryExpansionService.expand(userContent).rewritten();
|
||||
|
||||
// Retrieve context from all ready books using the expanded query
|
||||
List<SectionEntity> allSections = new ArrayList<>();
|
||||
List<FigureEntity> allFigures = new ArrayList<>();
|
||||
for (com.aiteacher.book.Book book : readyBooks) {
|
||||
RetrievalResult result = retriever.retrieve(fullQuestion, book.getId());
|
||||
RetrievalResult result = retriever.retrieve(retrievalQuery, book.getId());
|
||||
allSections.addAll(result.parentSections());
|
||||
allFigures.addAll(result.figures());
|
||||
}
|
||||
|
||||
// Build LLM prompt with section full texts and figure references
|
||||
String contextPrompt = buildContextPrompt(fullQuestion, allSections, allFigures);
|
||||
// Build labelled context prompt (US2): assigns [S1]/[F1] labels to each source
|
||||
LabelledContext ctx = buildContextPrompt(fullQuestion, allSections, allFigures);
|
||||
|
||||
String assistantContent = chatClient.prompt()
|
||||
// Generate answer
|
||||
String rawContent = chatClient.prompt()
|
||||
.system(SYSTEM_PROMPT)
|
||||
.user(contextPrompt)
|
||||
.user(ctx.promptText())
|
||||
.call()
|
||||
.content();
|
||||
|
||||
// Build sources list with TEXT and FIGURE entries
|
||||
// Strip any citation labels not present in the retrieved context (US2)
|
||||
String assistantContent = citationValidatorService.validate(rawContent, ctx.allLabels());
|
||||
|
||||
// Attach sources with their ref-labels for frontend traceability
|
||||
List<Map<String, Object>> sources = buildSources(allSections, allFigures);
|
||||
|
||||
Message assistantMessage = new Message(sessionId, MessageRole.ASSISTANT, assistantContent);
|
||||
@@ -126,51 +140,71 @@ public class ChatService {
|
||||
// Private helpers
|
||||
// -------------------------------------------------------------------------
|
||||
|
||||
private String buildContextPrompt(String question,
|
||||
List<SectionEntity> sections,
|
||||
List<FigureEntity> figures) {
|
||||
/**
|
||||
* Builds the LLM context prompt, tagging each section as [S1], [S2]… and
|
||||
* each figure as [F1], [F2]… so the model can cite only known sources.
|
||||
*/
|
||||
private LabelledContext buildContextPrompt(String question,
|
||||
List<SectionEntity> sections,
|
||||
List<FigureEntity> figures) {
|
||||
Map<String, SectionEntity> sectionLabels = new LinkedHashMap<>();
|
||||
Map<String, FigureEntity> figureLabels = new LinkedHashMap<>();
|
||||
StringBuilder sb = new StringBuilder();
|
||||
|
||||
if (!sections.isEmpty()) {
|
||||
sb.append("CONTEXT:\n\n");
|
||||
for (SectionEntity section : sections) {
|
||||
sb.append("[").append(section.getTitle())
|
||||
.append(", p.").append(section.getPageStart()).append("]\n");
|
||||
for (int i = 0; i < sections.size(); i++) {
|
||||
SectionEntity section = sections.get(i);
|
||||
String label = "S" + (i + 1);
|
||||
sectionLabels.put(label, section);
|
||||
sb.append("[").append(label).append("] ")
|
||||
.append(section.getTitle())
|
||||
.append(", p.").append(section.getPageStart()).append("\n");
|
||||
sb.append(section.getFullText()).append("\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
if (!figures.isEmpty()) {
|
||||
sb.append("AVAILABLE FIGURES:\n");
|
||||
for (FigureEntity figure : figures) {
|
||||
sb.append("- ").append(figure.getLabel() != null ? figure.getLabel() : "Figure")
|
||||
for (int i = 0; i < figures.size(); i++) {
|
||||
FigureEntity figure = figures.get(i);
|
||||
String label = "F" + (i + 1);
|
||||
figureLabels.put(label, figure);
|
||||
sb.append("[").append(label).append("] ")
|
||||
.append(figure.getLabel() != null ? figure.getLabel() : "Figure")
|
||||
.append(" (p.").append(figure.getPage()).append("): ")
|
||||
.append(figure.getCaption() != null ? figure.getCaption() : "")
|
||||
.append("\n");
|
||||
}
|
||||
sb.append("\nWhen referencing diagrams, cite them as [Fig. X, p.N].\n\n");
|
||||
sb.append("\nWhen referencing diagrams, use their label from the context (e.g. [F1]).\n\n");
|
||||
}
|
||||
|
||||
sb.append("QUESTION:\n").append(question);
|
||||
return sb.toString();
|
||||
return new LabelledContext(sectionLabels, figureLabels, sb.toString());
|
||||
}
|
||||
|
||||
private List<Map<String, Object>> buildSources(List<SectionEntity> sections,
|
||||
List<FigureEntity> figures) {
|
||||
List<Map<String, Object>> sources = new ArrayList<>();
|
||||
|
||||
for (SectionEntity section : sections) {
|
||||
for (int i = 0; i < sections.size(); i++) {
|
||||
SectionEntity section = sections.get(i);
|
||||
Map<String, Object> source = new LinkedHashMap<>();
|
||||
source.put("type", "TEXT");
|
||||
source.put("refLabel", "S" + (i + 1));
|
||||
source.put("bookId", section.getBookId());
|
||||
source.put("bookTitle", deriveTitleFromSection(section));
|
||||
source.put("page", section.getPageStart());
|
||||
source.put("chunkText", truncate(section.getFullText(), 500));
|
||||
sources.add(source);
|
||||
}
|
||||
|
||||
for (FigureEntity figure : figures) {
|
||||
for (int i = 0; i < figures.size(); i++) {
|
||||
FigureEntity figure = figures.get(i);
|
||||
Map<String, Object> source = new LinkedHashMap<>();
|
||||
source.put("type", "FIGURE");
|
||||
source.put("refLabel", "F" + (i + 1));
|
||||
source.put("bookId", figure.getBookId());
|
||||
source.put("bookTitle", bookRepository.findById(figure.getBookId())
|
||||
.map(com.aiteacher.book.Book::getTitle).orElse("Book"));
|
||||
source.put("page", figure.getPage());
|
||||
@@ -178,7 +212,6 @@ public class ChatService {
|
||||
source.put("label", figure.getLabel() != null ? figure.getLabel() : "");
|
||||
source.put("caption", figure.getCaption() != null ? figure.getCaption() : "");
|
||||
source.put("figureType", figure.getFigureType().name());
|
||||
// imageUrl assembled from relative path: figures/{bookId}/{filename}
|
||||
String filename = figure.getImagePath().substring(
|
||||
figure.getImagePath().lastIndexOf('/') + 1);
|
||||
source.put("imageUrl", "/api/v1/figures/" + figure.getBookId() + "/" + filename);
|
||||
|
||||
Reference in New Issue
Block a user