first implementation - image/drawing integration

This commit is contained in:
Adrien
2026-04-04 12:56:56 +02:00
parent fc5b22fba1
commit 5acfdd33c1
42 changed files with 2854 additions and 151 deletions
@@ -0,0 +1,47 @@
package com.aiteacher.document;
import jakarta.persistence.*;
import java.time.Instant;
import java.util.UUID;
/**
 * JPA entity mapped to the {@code chapter} table.
 *
 * <p>The class exposes getters only; every field is assigned in the public
 * all-arguments constructor.
 */
@Entity
@Table(name = "chapter")
public class ChapterEntity {

    /** Primary key, assigned by the caller (column length 200). */
    @Id
    @Column(name = "id", length = 200)
    private String id;

    /** Identifier of the owning book; the column is NOT NULL. */
    @Column(name = "book_id", nullable = false)
    private UUID bookId;

    /** Chapter number; the column is NOT NULL. */
    @Column(name = "number", nullable = false)
    private int number;

    /** Chapter title (column length 500); nullable. */
    @Column(name = "title", length = 500)
    private String title;

    /** Value of the {@code page_start} column; nullable. */
    @Column(name = "page_start")
    private Integer pageStart;

    /** Timestamp taken when the entity object was constructed; the column is NOT NULL. */
    @Column(name = "created_at", nullable = false)
    private Instant createdAt;

    /** No-argument constructor required by JPA; leaves every field unset. */
    public ChapterEntity() {
    }

    /**
     * Creates a chapter and stamps {@code createdAt} with the current instant.
     *
     * @param id        primary key
     * @param bookId    identifier of the owning book
     * @param number    chapter number
     * @param title     chapter title, may be {@code null}
     * @param pageStart value for the {@code page_start} column, may be {@code null}
     */
    public ChapterEntity(String id, UUID bookId, int number, String title, Integer pageStart) {
        this.createdAt = Instant.now();
        this.id = id;
        this.bookId = bookId;
        this.number = number;
        this.title = title;
        this.pageStart = pageStart;
    }

    /** @return the primary key */
    public String getId() {
        return this.id;
    }

    /** @return the identifier of the owning book */
    public UUID getBookId() {
        return this.bookId;
    }

    /** @return the chapter number */
    public int getNumber() {
        return this.number;
    }

    /** @return the chapter title, possibly {@code null} */
    public String getTitle() {
        return this.title;
    }

    /** @return the {@code page_start} value, possibly {@code null} */
    public Integer getPageStart() {
        return this.pageStart;
    }

    /** @return the creation timestamp */
    public Instant getCreatedAt() {
        return this.createdAt;
    }
}
@@ -0,0 +1,9 @@
package com.aiteacher.document;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.UUID;
/**
 * Spring Data JPA repository for {@link ChapterEntity}, keyed by its String id.
 */
public interface ChapterRepository extends JpaRepository<ChapterEntity, String> {

    /**
     * Derived delete query: removes every chapter whose {@code bookId} equals the argument.
     *
     * <p>NOTE(review): no {@code @Transactional} is declared here, so this presumably relies
     * on the caller running inside a transaction — confirm against the call sites.
     *
     * @param bookId identifier of the book whose chapters are removed
     */
    void deleteAllByBookId(UUID bookId);
}
@@ -0,0 +1,58 @@
package com.aiteacher.document;
import jakarta.persistence.*;
import java.io.Serializable;
import java.util.Objects;
import java.util.UUID;
/**
 * JPA entity mapped to the {@code chunk_figure_ref} table: one row per
 * (chunk, figure) pair, identified by the composite key {@link PK}.
 */
@Entity
@Table(name = "chunk_figure_ref")
@IdClass(ChunkFigureRefEntity.PK.class)
public class ChunkFigureRefEntity {

    /** First half of the composite key: the chunk's UUID. */
    @Id
    @Column(name = "chunk_id", nullable = false)
    private UUID chunkId;

    /** Second half of the composite key: the figure's id (column length 200). */
    @Id
    @Column(name = "figure_id", nullable = false, length = 200)
    private String figureId;

    /** Value of the {@code mention_page} column; nullable. */
    @Column(name = "mention_page")
    private Integer mentionPage;

    /** No-argument constructor required by JPA; leaves every field unset. */
    public ChunkFigureRefEntity() {
    }

    /**
     * @param chunkId     UUID of the referencing chunk
     * @param figureId    id of the referenced figure
     * @param mentionPage value for the {@code mention_page} column, may be {@code null}
     */
    public ChunkFigureRefEntity(UUID chunkId, String figureId, Integer mentionPage) {
        this.mentionPage = mentionPage;
        this.figureId = figureId;
        this.chunkId = chunkId;
    }

    /** @return the UUID of the referencing chunk */
    public UUID getChunkId() {
        return this.chunkId;
    }

    /** @return the id of the referenced figure */
    public String getFigureId() {
        return this.figureId;
    }

    /** @return the {@code mention_page} value, possibly {@code null} */
    public Integer getMentionPage() {
        return this.mentionPage;
    }

    /**
     * Composite primary-key class for {@link ChunkFigureRefEntity}; its field names
     * mirror the entity's {@code @Id} fields. Equality and hash code are value-based
     * over both components.
     */
    public static class PK implements Serializable {

        private UUID chunkId;
        private String figureId;

        /** No-argument constructor; leaves both components unset. */
        public PK() {
        }

        /**
         * @param chunkId  UUID of the chunk
         * @param figureId id of the figure
         */
        public PK(UUID chunkId, String figureId) {
            this.chunkId = chunkId;
            this.figureId = figureId;
        }

        /** Two keys are equal when both the chunk id and the figure id are equal. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (o instanceof PK) {
                PK other = (PK) o;
                boolean sameChunk = Objects.equals(this.chunkId, other.chunkId);
                return sameChunk && Objects.equals(this.figureId, other.figureId);
            }
            return false;
        }

        /** Hash over both key components, consistent with {@link #equals(Object)}. */
        @Override
        public int hashCode() {
            return Objects.hash(this.chunkId, this.figureId);
        }
    }
}
@@ -0,0 +1,18 @@
package com.aiteacher.document;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import java.util.List;
import java.util.UUID;
/**
 * Spring Data JPA repository for {@link ChunkFigureRefEntity} rows, keyed by the
 * composite {@link ChunkFigureRefEntity.PK} (chunk id + figure id).
 */
public interface ChunkFigureRefRepository extends JpaRepository<ChunkFigureRefEntity, ChunkFigureRefEntity.PK> {

    /**
     * Returns every reference row whose chunk id is contained in {@code chunkIds}.
     *
     * <p>NOTE(review): how an empty {@code chunkIds} list is handled depends on the JPA
     * provider and database — confirm callers never pass one, or guard at the call site.
     *
     * @param chunkIds chunk ids to look up
     * @return the matching reference rows
     */
    @Query("SELECT r FROM ChunkFigureRefEntity r WHERE r.chunkId IN :chunkIds")
    List<ChunkFigureRefEntity> findByChunkIdIn(@Param("chunkIds") List<UUID> chunkIds);

    /**
     * Bulk JPQL delete of every reference row whose figure id is contained in
     * {@code figureIds}.
     *
     * <p>NOTE(review): this is a modifying query and no {@code @Transactional} is declared
     * here, so it presumably relies on the caller's transaction — confirm against the call
     * sites.
     *
     * @param figureIds figure ids whose reference rows are removed
     */
    @Query("DELETE FROM ChunkFigureRefEntity r WHERE r.figureId IN :figureIds")
    @org.springframework.data.jpa.repository.Modifying
    void deleteByFigureIdIn(@Param("figureIds") List<String> figureIds);
}
@@ -0,0 +1,62 @@
package com.aiteacher.document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Scans chunk text for "Fig. X" and "Figure X" references and persists
 * {@link ChunkFigureRefEntity} rows linking each chunk to the figures it mentions.
 *
 * <p>A reference and a figure are paired when the figure number in the text is exactly
 * the figure number found in the figure's label. Exact comparison matters: a suffix
 * comparison would link "Fig. 2" to a figure labelled "Fig. 12".
 */
@Service
public class ChunkFigureRefService {

    private static final Logger log = LoggerFactory.getLogger(ChunkFigureRefService.class);

    // A figure number: digit groups joined by '-' or '.', e.g. "12", "12-4", "12.4".
    // Every separator must be followed by a digit, so sentence punctuation
    // ("see Fig. 12.") is never captured as part of the number.
    private static final String FIGURE_NUMBER = "\\d+(?:[\\-.]\\d+)*";

    // Matches: "Fig. 12-4", "Fig. 12.4", "Fig 12", "Figure 12-4", etc. Group 2 is the number.
    private static final Pattern REF_PATTERN =
            Pattern.compile("(?i)\\b(Fig\\.?|Figure)\\s+(" + FIGURE_NUMBER + ")");

    // Pulls the figure number out of a stored label such as "Fig. 12-4".
    private static final Pattern LABEL_NUMBER_PATTERN = Pattern.compile(FIGURE_NUMBER);

    private final ChunkFigureRefRepository refRepository;

    public ChunkFigureRefService(ChunkFigureRefRepository refRepository) {
        this.refRepository = refRepository;
    }

    /**
     * For each text chunk, finds figure references and persists one
     * {@link ChunkFigureRefEntity} per distinct (chunk, figure) pair.
     *
     * <p>Chunks whose id is missing or not a UUID are skipped with a warning; chunks
     * without text are skipped silently. When several figures carry the same number,
     * the first one in {@code bookFigures} is used.
     *
     * @param chunks      chunks to scan; {@code null} or empty means nothing to do
     * @param bookFigures candidate figures; {@code null} or empty means nothing to do
     * @param pageNum     stored as the mention page of every row created by this call
     */
    public void linkChunksToFigures(List<Document> chunks, List<FigureEntity> bookFigures,
                                    int pageNum) {
        if (chunks == null || chunks.isEmpty() || bookFigures == null || bookFigures.isEmpty()) {
            return;
        }
        java.util.Map<String, FigureEntity> figuresByNumber = indexByFigureNumber(bookFigures);
        if (figuresByNumber.isEmpty()) {
            return;
        }
        for (Document chunk : chunks) {
            UUID chunkId = parseChunkId(chunk.getId());
            String text = chunk.getText();
            if (chunkId == null || text == null) {
                continue;
            }
            // A chunk often mentions the same figure several times; persist the link once.
            java.util.Set<String> linkedFigureIds = new java.util.HashSet<>();
            Matcher m = REF_PATTERN.matcher(text);
            while (m.find()) {
                FigureEntity figure = figuresByNumber.get(m.group(2));
                if (figure != null && linkedFigureIds.add(figure.getId())) {
                    refRepository.save(new ChunkFigureRefEntity(chunkId, figure.getId(), pageNum));
                }
            }
        }
    }

    /** Maps each figure number found in a label to the first figure carrying it. */
    private static java.util.Map<String, FigureEntity> indexByFigureNumber(List<FigureEntity> figures) {
        java.util.Map<String, FigureEntity> byNumber = new java.util.HashMap<>();
        for (FigureEntity figure : figures) {
            if (figure == null || figure.getLabel() == null) {
                continue;
            }
            Matcher m = LABEL_NUMBER_PATTERN.matcher(figure.getLabel());
            if (m.find()) {
                byNumber.putIfAbsent(m.group(), figure);
            }
        }
        return byNumber;
    }

    /** Parses a chunk id as a UUID; logs a warning and returns {@code null} when it is not one. */
    private static UUID parseChunkId(String rawId) {
        if (rawId != null) {
            try {
                return UUID.fromString(rawId);
            } catch (IllegalArgumentException ignored) {
                // Not a UUID: fall through to the shared warning below.
            }
        }
        log.warn("Chunk has non-UUID id: {}", rawId);
        return null;
    }
}
@@ -0,0 +1,82 @@
package com.aiteacher.document;
import jakarta.persistence.*;
import java.time.Instant;
import java.util.UUID;
/**
 * JPA entity mapped to the {@code figure} table: one image extracted from a book,
 * together with its label, caption, type and storage path.
 *
 * <p>Only {@code caption} and {@code captionEmbeddingId} have setters; every other
 * field is assigned in the all-arguments constructor.
 */
@Entity
@Table(name = "figure")
public class FigureEntity {

    /** Primary key, assigned by the caller (column length 200). */
    @Id
    @Column(name = "id", length = 200)
    private String id;

    /** Identifier of the owning book; the column is NOT NULL. */
    @Column(name = "book_id", nullable = false)
    private UUID bookId;

    /** Id of the section the figure was found in (column length 200); nullable. */
    @Column(name = "section_id", length = 200)
    private String sectionId;

    /** Id of the chapter the figure belongs to (column length 200); nullable. */
    @Column(name = "chapter_id", length = 200)
    private String chapterId;

    /** Short label such as "Fig. 12-4" (column length 100); nullable. */
    @Column(name = "label", length = 100)
    private String label;

    /** Caption text stored in a TEXT column; nullable and mutable. */
    @Column(name = "caption", columnDefinition = "TEXT")
    private String caption;

    /** Figure category, persisted by enum name; the column is NOT NULL. */
    @Enumerated(EnumType.STRING)
    @Column(name = "figure_type", nullable = false, length = 50)
    private FigureType figureType;

    /** Page number of the figure; the column is NOT NULL. */
    @Column(name = "page", nullable = false)
    private int page;

    /** Location of the stored image (column length 1000); the column is NOT NULL. */
    @Column(name = "image_path", nullable = false, length = 1000)
    private String imagePath;

    /** Id of the caption embedding; nullable, never set by the constructor. */
    @Column(name = "caption_embedding_id")
    private UUID captionEmbeddingId;

    /** Timestamp taken when the entity object was constructed; the column is NOT NULL. */
    @Column(name = "created_at", nullable = false)
    private Instant createdAt;

    /** No-argument constructor required by JPA; leaves every field unset. */
    public FigureEntity() {
    }

    /**
     * Creates a figure and stamps {@code createdAt} with the current instant.
     * {@code captionEmbeddingId} is left {@code null}.
     *
     * @param id         primary key
     * @param bookId     identifier of the owning book
     * @param sectionId  id of the containing section, may be {@code null}
     * @param chapterId  id of the containing chapter, may be {@code null}
     * @param label      short label, may be {@code null}
     * @param caption    caption text, may be {@code null}
     * @param figureType figure category
     * @param page       page number of the figure
     * @param imagePath  location of the stored image
     */
    public FigureEntity(String id, UUID bookId, String sectionId, String chapterId,
                        String label, String caption, FigureType figureType,
                        int page, String imagePath) {
        this.createdAt = Instant.now();
        this.id = id;
        this.bookId = bookId;
        this.chapterId = chapterId;
        this.sectionId = sectionId;
        this.figureType = figureType;
        this.label = label;
        this.caption = caption;
        this.imagePath = imagePath;
        this.page = page;
    }

    /** @return the primary key */
    public String getId() {
        return this.id;
    }

    /** @return the identifier of the owning book */
    public UUID getBookId() {
        return this.bookId;
    }

    /** @return the id of the containing section, possibly {@code null} */
    public String getSectionId() {
        return this.sectionId;
    }

    /** @return the id of the containing chapter, possibly {@code null} */
    public String getChapterId() {
        return this.chapterId;
    }

    /** @return the short label, possibly {@code null} */
    public String getLabel() {
        return this.label;
    }

    /** @return the caption text, possibly {@code null} */
    public String getCaption() {
        return this.caption;
    }

    /** @return the figure category */
    public FigureType getFigureType() {
        return this.figureType;
    }

    /** @return the page number of the figure */
    public int getPage() {
        return this.page;
    }

    /** @return the location of the stored image */
    public String getImagePath() {
        return this.imagePath;
    }

    /** @return the id of the caption embedding, possibly {@code null} */
    public UUID getCaptionEmbeddingId() {
        return this.captionEmbeddingId;
    }

    /** @return the creation timestamp */
    public Instant getCreatedAt() {
        return this.createdAt;
    }

    /** @param captionEmbeddingId id of the caption embedding to associate with this figure */
    public void setCaptionEmbeddingId(UUID captionEmbeddingId) {
        this.captionEmbeddingId = captionEmbeddingId;
    }

    /** @param caption replacement caption text */
    public void setCaption(String caption) {
        this.caption = caption;
    }
}
@@ -0,0 +1,135 @@
package com.aiteacher.document;
import com.aiteacher.figure.FigureStorageService;
import org.apache.pdfbox.Loader;
import org.apache.pdfbox.cos.COSName;
import org.apache.pdfbox.pdmodel.PDDocument;
import org.apache.pdfbox.pdmodel.PDPage;
import org.apache.pdfbox.pdmodel.graphics.PDXObject;
import org.apache.pdfbox.pdmodel.graphics.image.PDImageXObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Extracts embedded images from PDF pages using PDFBox and persists one
 * {@link FigureEntity} per qualifying image.
 *
 * <p>Images narrower or shorter than the configured minimum size are skipped.
 * Captions are detected heuristically: a line of the page text that starts with
 * "Fig.", "Fig" or "Figure" followed by a number.
 */
@Service
public class FigureExtractionService {

    private static final Logger log = LoggerFactory.getLogger(FigureExtractionService.class);

    // Caption: line starting with "Fig.", "Fig" or "Figure" followed by a number.
    // Group 1 is the caption line without leading blanks.
    private static final Pattern CAPTION_PATTERN =
            Pattern.compile("(?m)^[ \\t]*((?:Figure|Fig\\.?)\\s*\\d+[^\\n]*)", Pattern.CASE_INSENSITIVE);

    // Figure label number: "12", "12-4" or "12.4". Every separator must be followed by a
    // digit so the punctuation that ends "Fig. 12-4." does not become part of the label.
    private static final Pattern LABEL_PATTERN =
            Pattern.compile("(?i)(?:Figure|Fig\\.?)\\s*(\\d+(?:[\\-.]\\d+)*)");

    // Classification keywords. "mri", "ct", "chart", "histogram" and "graph" are matched as
    // whole words: plain substring tests would classify "effect " as CT and "photograph"
    // as a chart. Case-insensitive regex also avoids locale-sensitive lower-casing.
    private static final Pattern MRI_CT_KEYWORDS =
            Pattern.compile("\\b(?:mri|ct)\\b|magnetic|tomography", Pattern.CASE_INSENSITIVE);
    private static final Pattern INTRAOPERATIVE_KEYWORDS =
            Pattern.compile("intraoperative|intra-op", Pattern.CASE_INSENSITIVE);
    private static final Pattern CHART_KEYWORDS =
            Pattern.compile("\\b(?:chart|histogram|graph)s?\\b", Pattern.CASE_INSENSITIVE);
    private static final Pattern PHOTO_KEYWORDS =
            Pattern.compile("photo", Pattern.CASE_INSENSITIVE);

    private final FigureStorageService storageService;
    private final FigureRepository figureRepository;
    private final int minImageSizePx;

    public FigureExtractionService(
            FigureStorageService storageService,
            FigureRepository figureRepository,
            @Value("${app.figure-storage.min-image-size-px:100}") int minImageSizePx) {
        this.storageService = storageService;
        this.figureRepository = figureRepository;
        this.minImageSizePx = minImageSizePx;
    }

    /**
     * Extracts all qualifying images from the PDF for the given book.
     *
     * <p>Only the start page of each section is inspected. Extraction is best-effort:
     * an image that cannot be read or stored is logged and skipped without affecting the
     * other images, and a PDF that cannot be opened yields an empty list.
     *
     * @param bookId    identifier of the book, also used as prefix of the figure ids
     * @param chapterId chapter id stored on every created figure
     * @param sections  sections whose start pages are scanned
     * @param pdfPath   location of the PDF file
     * @return the persisted figures, in extraction order
     */
    public List<FigureEntity> extract(UUID bookId, String chapterId,
                                      List<SectionEntity> sections, Path pdfPath) {
        List<FigureEntity> figures = new ArrayList<>();
        int figureCounter = 0;
        try (PDDocument doc = Loader.loadPDF(pdfPath.toFile())) {
            for (SectionEntity section : sections) {
                int pageIndex = section.getPageStart() - 1; // 0-based
                if (pageIndex < 0 || pageIndex >= doc.getNumberOfPages()) {
                    continue;
                }
                PDPage page = doc.getPage(pageIndex);
                var resources = page.getResources();
                if (resources == null) {
                    continue; // page declares no resources, hence no images
                }
                String pageText = section.getFullText();
                List<String> captions = detectCaptions(pageText);
                int imagesOnPage = 0;
                for (COSName name : resources.getXObjectNames()) {
                    // Per-image try/catch: one unreadable image must not drop the rest of the page.
                    try {
                        PDXObject xObject = resources.getXObject(name);
                        if (!(xObject instanceof PDImageXObject image)) {
                            continue;
                        }
                        BufferedImage bufferedImage = image.getImage();
                        if (bufferedImage.getWidth() < minImageSizePx
                                || bufferedImage.getHeight() < minImageSizePx) {
                            continue; // skip decorative images
                        }
                        figureCounter++;
                        String figureId = bookId + "-fig-" + pageIndex + "-" + figureCounter;
                        String caption = captionFor(captions, imagesOnPage);
                        imagesOnPage++;
                        String label = detectLabel(caption, figureCounter);
                        FigureType type = classifyType(caption, pageText);
                        String imagePath = storageService.save(bookId, figureId, bufferedImage);
                        FigureEntity figure = new FigureEntity(
                                figureId, bookId, section.getId(), chapterId,
                                label, caption, type, section.getPageStart(), imagePath
                        );
                        figures.add(figureRepository.save(figure));
                    } catch (IOException ex) {
                        log.warn("Failed to extract image {} from page {} of book {}: {}",
                                name.getName(), section.getPageStart(), bookId, ex.getMessage(), ex);
                    }
                }
            }
        } catch (IOException ex) {
            log.error("Could not open PDF for image extraction, book {}", bookId, ex);
        }
        log.info("Extracted {} figures for book {}", figures.size(), bookId);
        return figures;
    }

    /** Returns every caption line found in the page text, in order of appearance. */
    private List<String> detectCaptions(String pageText) {
        List<String> captions = new ArrayList<>();
        if (pageText == null) {
            return captions;
        }
        Matcher m = CAPTION_PATTERN.matcher(pageText);
        while (m.find()) {
            captions.add(m.group(1).trim());
        }
        return captions;
    }

    /**
     * Picks the caption for the n-th qualifying image of a page. Captions are paired with
     * images positionally (a best-effort heuristic); when a page has more images than
     * captions the last caption is reused, and with no caption at all the result is
     * {@code null}.
     */
    private static String captionFor(List<String> captions, int imageIndex) {
        if (captions.isEmpty()) {
            return null;
        }
        return captions.get(Math.min(imageIndex, captions.size() - 1));
    }

    /** Builds "Fig. N" from the caption's figure number, or from the counter when there is none. */
    private String detectLabel(String caption, int counter) {
        if (caption != null) {
            Matcher m = LABEL_PATTERN.matcher(caption);
            if (m.find()) {
                return "Fig. " + m.group(1);
            }
        }
        return "Fig. " + counter;
    }

    /**
     * Classifies a figure from keywords in its caption and page text; the first matching
     * rule wins and {@link FigureType#ANATOMICAL_DIAGRAM} is the fallback.
     *
     * <p>The TABLE rule requires a caption starting with "table", which the caption
     * detection in this class (captions start with "Fig"/"Figure") never produces.
     */
    private FigureType classifyType(String caption, String pageText) {
        String combined = (caption != null ? caption : "") + " " + (pageText != null ? pageText : "");
        if (MRI_CT_KEYWORDS.matcher(combined).find()) {
            return FigureType.MRI_CT_SCAN;
        }
        if (INTRAOPERATIVE_KEYWORDS.matcher(combined).find()) {
            return FigureType.INTRAOPERATIVE_IMAGE;
        }
        if (caption != null && caption.regionMatches(true, 0, "table", 0, 5)) {
            return FigureType.TABLE;
        }
        if (CHART_KEYWORDS.matcher(combined).find()) {
            return FigureType.CHART;
        }
        if (PHOTO_KEYWORDS.matcher(combined).find()) {
            return FigureType.SURGICAL_PHOTOGRAPH;
        }
        return FigureType.ANATOMICAL_DIAGRAM;
    }
}
@@ -0,0 +1,11 @@
package com.aiteacher.document;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.List;
import java.util.UUID;
/**
 * Spring Data JPA repository for {@link FigureEntity}, keyed by its String id.
 */
public interface FigureRepository extends JpaRepository<FigureEntity, String> {

    /**
     * Derived query: returns every figure whose {@code bookId} equals the argument.
     *
     * @param bookId identifier of the book
     * @return the figures of that book
     */
    List<FigureEntity> findAllByBookId(UUID bookId);

    /**
     * Derived delete query: removes every figure whose {@code bookId} equals the argument.
     *
     * <p>NOTE(review): no {@code @Transactional} is declared here, so this presumably relies
     * on the caller running inside a transaction — confirm against the call sites.
     *
     * @param bookId identifier of the book whose figures are removed
     */
    void deleteAllByBookId(UUID bookId);
}
@@ -0,0 +1,10 @@
package com.aiteacher.document;
public enum FigureType {
ANATOMICAL_DIAGRAM,
SURGICAL_PHOTOGRAPH,
MRI_CT_SCAN,
TABLE,
CHART,
INTRAOPERATIVE_IMAGE
}
@@ -0,0 +1,71 @@
package com.aiteacher.document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.core.io.FileSystemResource;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
/**
 * Turns a PDF into page-level {@link SectionEntity} rows.
 * Every non-blank page becomes one section, and all sections of a book hang off a
 * single {@link ChapterEntity}.
 */
@Service
public class PdfStructureParser {

    private static final Logger log = LoggerFactory.getLogger(PdfStructureParser.class);

    private final ChapterRepository chapterRepository;
    private final SectionRepository sectionRepository;

    public PdfStructureParser(ChapterRepository chapterRepository,
                              SectionRepository sectionRepository) {
        this.chapterRepository = chapterRepository;
        this.sectionRepository = sectionRepository;
    }

    /**
     * Saves the book's single chapter, reads the PDF one page per document and saves a
     * section for every page whose text is not blank.
     *
     * <p>NOTE(review): the page number is the 1-based position of the document in the
     * reader's output. This assumes the reader emits exactly one document per physical
     * page, in order — confirm it never omits pages, otherwise later page numbers drift.
     *
     * @param bookId    identifier of the book; prefix of the chapter and section ids
     * @param bookTitle used as the title of the single chapter
     * @param pdfPath   location of the PDF file
     * @return the saved sections, in page order
     */
    @Transactional
    public List<SectionEntity> parse(UUID bookId, String bookTitle, Path pdfPath) {
        log.info("Parsing PDF structure for book {}", bookId);

        // The whole book is modelled as chapter number 1, starting on page 1.
        final String chapterId = bookId + "-ch1";
        chapterRepository.save(new ChapterEntity(chapterId, bookId, 1, bookTitle, 1));

        // Ask the reader for one document per page so each page maps to one section.
        FileSystemResource pdfResource = new FileSystemResource(pdfPath.toFile());
        PdfDocumentReaderConfig onePagePerDocument =
                PdfDocumentReaderConfig.builder().withPagesPerDocument(1).build();
        PagePdfDocumentReader reader = new PagePdfDocumentReader(pdfResource, onePagePerDocument);

        List<SectionEntity> saved = new ArrayList<>();
        int pageNum = 0;
        for (org.springframework.ai.document.Document pageDocument : reader.get()) {
            pageNum++;
            String text = pageDocument.getText();
            boolean hasText = text != null && !text.isBlank();
            if (hasText) {
                SectionEntity section = new SectionEntity(
                        bookId + "-p" + pageNum,
                        chapterId,
                        bookId,
                        String.valueOf(pageNum),
                        "Page " + pageNum,
                        pageNum,
                        pageNum,
                        text);
                saved.add(sectionRepository.save(section));
            }
        }

        log.info("Parsed {} sections for book {}", saved.size(), bookId);
        return saved;
    }
}
@@ -0,0 +1,63 @@
package com.aiteacher.document;
import jakarta.persistence.*;
import java.time.Instant;
import java.util.UUID;
/**
 * JPA entity mapped to the {@code section} table: a span of pages of a book together
 * with its full text.
 *
 * <p>The class exposes getters only; every field is assigned in the public
 * all-arguments constructor.
 */
@Entity
@Table(name = "section")
public class SectionEntity {

    /** Primary key, assigned by the caller (column length 200). */
    @Id
    @Column(name = "id", length = 200)
    private String id;

    /** Id of the owning chapter (column length 200); the column is NOT NULL. */
    @Column(name = "chapter_id", nullable = false, length = 200)
    private String chapterId;

    /** Identifier of the owning book; the column is NOT NULL. */
    @Column(name = "book_id", nullable = false)
    private UUID bookId;

    /** Section number as text (column length 50); nullable. */
    @Column(name = "number", length = 50)
    private String number;

    /** Section title (column length 500); nullable. */
    @Column(name = "title", length = 500)
    private String title;

    /** First page of the section; the column is NOT NULL. */
    @Column(name = "page_start", nullable = false)
    private int pageStart;

    /** Last page of the section; the column is NOT NULL. */
    @Column(name = "page_end", nullable = false)
    private int pageEnd;

    /** Complete text of the section, stored in a TEXT column; the column is NOT NULL. */
    @Column(name = "full_text", nullable = false, columnDefinition = "TEXT")
    private String fullText;

    /** Timestamp taken when the entity object was constructed; the column is NOT NULL. */
    @Column(name = "created_at", nullable = false)
    private Instant createdAt;

    /** No-argument constructor required by JPA; leaves every field unset. */
    public SectionEntity() {
    }

    /**
     * Creates a section and stamps {@code createdAt} with the current instant.
     *
     * @param id        primary key
     * @param chapterId id of the owning chapter
     * @param bookId    identifier of the owning book
     * @param number    section number as text, may be {@code null}
     * @param title     section title, may be {@code null}
     * @param pageStart first page of the section
     * @param pageEnd   last page of the section
     * @param fullText  complete text of the section
     */
    public SectionEntity(String id, String chapterId, UUID bookId, String number,
                         String title, int pageStart, int pageEnd, String fullText) {
        this.createdAt = Instant.now();
        this.id = id;
        this.bookId = bookId;
        this.chapterId = chapterId;
        this.number = number;
        this.title = title;
        this.fullText = fullText;
        this.pageStart = pageStart;
        this.pageEnd = pageEnd;
    }

    /** @return the primary key */
    public String getId() {
        return this.id;
    }

    /** @return the id of the owning chapter */
    public String getChapterId() {
        return this.chapterId;
    }

    /** @return the identifier of the owning book */
    public UUID getBookId() {
        return this.bookId;
    }

    /** @return the section number as text, possibly {@code null} */
    public String getNumber() {
        return this.number;
    }

    /** @return the section title, possibly {@code null} */
    public String getTitle() {
        return this.title;
    }

    /** @return the first page of the section */
    public int getPageStart() {
        return this.pageStart;
    }

    /** @return the last page of the section */
    public int getPageEnd() {
        return this.pageEnd;
    }

    /** @return the complete text of the section */
    public String getFullText() {
        return this.fullText;
    }

    /** @return the creation timestamp */
    public Instant getCreatedAt() {
        return this.createdAt;
    }
}
@@ -0,0 +1,11 @@
package com.aiteacher.document;
import org.springframework.data.jpa.repository.JpaRepository;
import java.util.List;
import java.util.UUID;
/**
 * Spring Data JPA repository for {@link SectionEntity}, keyed by its String id.
 */
public interface SectionRepository extends JpaRepository<SectionEntity, String> {

    /**
     * Derived query: returns every section whose {@code bookId} equals the argument.
     *
     * @param bookId identifier of the book
     * @return the sections of that book
     */
    List<SectionEntity> findAllByBookId(UUID bookId);

    /**
     * Derived delete query: removes every section whose {@code bookId} equals the argument.
     *
     * <p>NOTE(review): no {@code @Transactional} is declared here, so this presumably relies
     * on the caller running inside a transaction — confirm against the call sites.
     *
     * @param bookId identifier of the book whose sections are removed
     */
    void deleteAllByBookId(UUID bookId);
}
@@ -0,0 +1,65 @@
package com.aiteacher.document;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
/**
 * Cuts the full text of a {@link SectionEntity} into overlapping, fixed-size character
 * windows and wraps each window in a {@link Document} ready for vector embedding.
 * Window size: 1800 characters (~450 tokens); consecutive windows share 200 characters.
 */
@Service
public class TextChunkingService {

    private static final int TARGET_CHARS = 1800;
    private static final int OVERLAP_CHARS = 200;

    /**
     * Produces one document per text window of the section.
     *
     * @param section   section whose full text is chunked
     * @param bookTitle title copied into every chunk's metadata
     * @return the chunks in text order, each with a random UUID id; an empty list when the
     *         section text is {@code null} or blank
     */
    public List<Document> chunk(SectionEntity section, String bookTitle) {
        String fullText = section.getFullText();
        if (fullText == null || fullText.isBlank()) {
            return List.of();
        }

        List<String> pieces = split(fullText);
        int total = pieces.size();
        List<Document> result = new ArrayList<>();
        int position = 0;
        for (String piece : pieces) {
            String id = UUID.randomUUID().toString();
            result.add(new Document(id, piece, buildMetadata(section, bookTitle, position, total, id)));
            position++;
        }
        return result;
    }

    /**
     * Slides a TARGET_CHARS-wide window over the text. Each window starts OVERLAP_CHARS
     * before the previous one ended; the last window ends exactly at the end of the text.
     */
    private List<String> split(String text) {
        final int length = text.length();
        // Distance between the starts of two consecutive windows.
        final int stride = TARGET_CHARS - OVERLAP_CHARS;
        List<String> pieces = new ArrayList<>();
        for (int from = 0; from < length; from += stride) {
            int to = Math.min(from + TARGET_CHARS, length);
            pieces.add(text.substring(from, to));
            if (to == length) {
                break; // the window reached the end of the text
            }
        }
        return pieces;
    }

    /** Builds the metadata map attached to one chunk. */
    private Map<String, Object> buildMetadata(SectionEntity section, String bookTitle,
                                              int index, int total, String chunkId) {
        String sectionTitle = section.getTitle();
        if (sectionTitle == null) {
            sectionTitle = "";
        }

        Map<String, Object> metadata = new HashMap<>();
        // Chunk identity and position within the section.
        metadata.put("chunk_id", chunkId);
        metadata.put("chunk_index", index);
        metadata.put("total_chunks", total);
        metadata.put("type", "TEXT");
        // Where the chunk comes from.
        metadata.put("book_id", section.getBookId().toString());
        metadata.put("book_title", bookTitle);
        metadata.put("chapter_id", section.getChapterId());
        metadata.put("section_id", section.getId());
        metadata.put("section_title", sectionTitle);
        metadata.put("page_start", section.getPageStart());
        metadata.put("page_end", section.getPageEnd());
        return metadata;
    }
}
@@ -0,0 +1,49 @@
package com.aiteacher.document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.core.io.FileSystemResource;
import org.springframework.stereotype.Service;
import org.springframework.util.MimeTypeUtils;
import java.nio.file.Path;
/**
 * Generates a clinical text description for an extracted figure image
 * using a vision-capable chat model via Spring AI {@link ChatClient}.
 */
@Service
public class VisionDescriptionService {

    private static final Logger log = LoggerFactory.getLogger(VisionDescriptionService.class);

    private static final String PROMPT =
            "You are a neurosurgery educator. Provide a brief 2-3 sentence clinical description of " +
            "this image. Focus on anatomical structures, surgical landmarks, labels, and clinical " +
            "significance. If text or labels are visible, include them verbatim.";

    // Used when neither the model nor the caller supplies any text.
    private static final String DEFAULT_DESCRIPTION = "Figure";

    private final ChatClient chatClient;

    public VisionDescriptionService(ChatClient chatClient) {
        this.chatClient = chatClient;
    }

    /**
     * Returns a description of the image, never {@code null}.
     *
     * <p>Falls back to {@code captionFallback} (or to "Figure" when that is {@code null})
     * if the image path is missing, the model call throws, or the model returns no text.
     *
     * @param imagePath       location of the image file
     * @param captionFallback text to return when no description can be generated; may be
     *                        {@code null}
     * @return the model's description or the fallback text
     */
    public String describe(Path imagePath, String captionFallback) {
        String fallback = captionFallback != null ? captionFallback : DEFAULT_DESCRIPTION;
        if (imagePath == null) {
            log.warn("Vision description skipped: no image path — using caption as fallback");
            return fallback;
        }
        try {
            // NOTE(review): the media type is fixed to PNG; this assumes stored figures are
            // PNG files — confirm against FigureStorageService.
            String description = chatClient.prompt()
                    .user(u -> u
                            .text(PROMPT)
                            .media(MimeTypeUtils.IMAGE_PNG, new FileSystemResource(imagePath.toFile())))
                    .call()
                    .content();
            if (description == null || description.isBlank()) {
                // The model can answer without any text; treat that like a failed call.
                log.warn("Vision description was empty for {} — using caption as fallback",
                        imagePath.getFileName());
                return fallback;
            }
            return description;
        } catch (Exception ex) {
            log.warn("Vision description failed for {}: {} — using caption as fallback",
                    imagePath.getFileName(), ex.getMessage());
            return fallback;
        }
    }
}