first implementation
This commit is contained in:
@@ -0,0 +1,75 @@
|
||||
package com.aiteacher.chat;
|
||||
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
@RestController
|
||||
@RequestMapping("/api/v1/chat")
|
||||
public class ChatController {
|
||||
|
||||
private final ChatService chatService;
|
||||
|
||||
public ChatController(ChatService chatService) {
|
||||
this.chatService = chatService;
|
||||
}
|
||||
|
||||
@PostMapping("/sessions")
|
||||
public ResponseEntity<Map<String, Object>> createSession(
|
||||
@RequestBody(required = false) Map<String, String> body) {
|
||||
String topicId = body != null ? body.get("topicId") : null;
|
||||
ChatSession session = chatService.createSession(topicId);
|
||||
|
||||
Map<String, Object> response = new LinkedHashMap<>();
|
||||
response.put("sessionId", session.getId());
|
||||
response.put("topicId", session.getTopicId());
|
||||
response.put("createdAt", session.getCreatedAt());
|
||||
return ResponseEntity.status(HttpStatus.CREATED).body(response);
|
||||
}
|
||||
|
||||
@GetMapping("/sessions/{sessionId}/messages")
|
||||
public ResponseEntity<List<Map<String, Object>>> getMessages(@PathVariable UUID sessionId) {
|
||||
List<Message> messages = chatService.getMessages(sessionId);
|
||||
List<Map<String, Object>> response = messages.stream()
|
||||
.map(this::toMessageResponse)
|
||||
.toList();
|
||||
return ResponseEntity.ok(response);
|
||||
}
|
||||
|
||||
@PostMapping("/sessions/{sessionId}/messages")
|
||||
public ResponseEntity<Map<String, Object>> sendMessage(
|
||||
@PathVariable UUID sessionId,
|
||||
@RequestBody Map<String, String> body) {
|
||||
String content = body.get("content");
|
||||
if (content == null || content.isBlank()) {
|
||||
throw new IllegalArgumentException("Message content must not be empty.");
|
||||
}
|
||||
Message message = chatService.sendMessage(sessionId, content);
|
||||
return ResponseEntity.ok(toMessageResponse(message));
|
||||
}
|
||||
|
||||
@DeleteMapping("/sessions/{sessionId}")
|
||||
public ResponseEntity<Void> deleteSession(@PathVariable UUID sessionId) {
|
||||
chatService.deleteSession(sessionId);
|
||||
return ResponseEntity.noContent().build();
|
||||
}
|
||||
|
||||
private Map<String, Object> toMessageResponse(Message message) {
|
||||
Map<String, Object> map = new LinkedHashMap<>();
|
||||
map.put("id", message.getId());
|
||||
map.put("role", message.getRole().name());
|
||||
map.put("content", message.getContent());
|
||||
if (message.getSources() != null) {
|
||||
map.put("sources", message.getSources());
|
||||
} else {
|
||||
map.put("sources", List.of());
|
||||
}
|
||||
map.put("createdAt", message.getCreatedAt());
|
||||
return map;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package com.aiteacher.chat;
|
||||
|
||||
import com.aiteacher.book.BookRepository;
|
||||
import com.aiteacher.book.BookStatus;
|
||||
import com.aiteacher.book.NoKnowledgeSourceException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor;
|
||||
import org.springframework.ai.chat.model.ChatResponse;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.NoSuchElementException;
|
||||
import java.util.UUID;
|
||||
|
||||
@Service
|
||||
public class ChatService {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ChatService.class);
|
||||
|
||||
private static final String SYSTEM_PROMPT = """
|
||||
You are an expert neurosurgery educator assistant. Your role is to answer
|
||||
questions based ONLY on the content from uploaded medical textbooks that has been
|
||||
retrieved for you as context.
|
||||
|
||||
Rules:
|
||||
- Answer only from the provided context chunks
|
||||
- If the context does not contain enough information, explicitly state:
|
||||
"I could not find relevant information about this topic in the uploaded books."
|
||||
- Cite sources when possible (book title and page number from the context metadata)
|
||||
- Maintain continuity with the conversation history
|
||||
- Never fabricate clinical information
|
||||
""";
|
||||
|
||||
private final ChatClient chatClient;
|
||||
private final VectorStore vectorStore;
|
||||
private final BookRepository bookRepository;
|
||||
private final ChatSessionRepository sessionRepository;
|
||||
private final MessageRepository messageRepository;
|
||||
|
||||
public ChatService(ChatClient chatClient, VectorStore vectorStore,
|
||||
BookRepository bookRepository,
|
||||
ChatSessionRepository sessionRepository,
|
||||
MessageRepository messageRepository) {
|
||||
this.chatClient = chatClient;
|
||||
this.vectorStore = vectorStore;
|
||||
this.bookRepository = bookRepository;
|
||||
this.sessionRepository = sessionRepository;
|
||||
this.messageRepository = messageRepository;
|
||||
}
|
||||
|
||||
public ChatSession createSession(String topicId) {
|
||||
ChatSession session = new ChatSession(topicId);
|
||||
return sessionRepository.save(session);
|
||||
}
|
||||
|
||||
public List<Message> getMessages(UUID sessionId) {
|
||||
if (!sessionRepository.existsById(sessionId)) {
|
||||
throw new NoSuchElementException("Session not found.");
|
||||
}
|
||||
return messageRepository.findBySessionIdOrderByCreatedAtAsc(sessionId);
|
||||
}
|
||||
|
||||
public Message sendMessage(UUID sessionId, String userContent) {
|
||||
ChatSession session = sessionRepository.findById(sessionId)
|
||||
.orElseThrow(() -> new NoSuchElementException("Session not found."));
|
||||
|
||||
if (!bookRepository.existsByStatus(BookStatus.READY)) {
|
||||
throw new NoKnowledgeSourceException("No books are available as knowledge sources.");
|
||||
}
|
||||
|
||||
// Persist user message
|
||||
Message userMessage = new Message(sessionId, MessageRole.USER, userContent);
|
||||
messageRepository.save(userMessage);
|
||||
|
||||
// Build conversation history for context
|
||||
List<Message> history = messageRepository.findBySessionIdOrderByCreatedAtAsc(sessionId);
|
||||
|
||||
// Build the prompt with full conversation history as context
|
||||
String fullQuestion = buildQuestionWithHistory(history, userContent, session.getTopicId());
|
||||
|
||||
var qaAdvisor = QuestionAnswerAdvisor.builder(vectorStore)
|
||||
.searchRequest(SearchRequest.builder().similarityThreshold(0.8d).topK(6).build())
|
||||
.build();
|
||||
|
||||
ChatResponse response = chatClient.prompt()
|
||||
.advisors(qaAdvisor)
|
||||
.system(SYSTEM_PROMPT)
|
||||
.user(fullQuestion)
|
||||
.call()
|
||||
.chatResponse();
|
||||
|
||||
String assistantContent = response.getResult().getOutput().getText();
|
||||
List<Map<String, Object>> sources = extractSources(response);
|
||||
|
||||
// Persist assistant message
|
||||
Message assistantMessage = new Message(sessionId, MessageRole.ASSISTANT, assistantContent);
|
||||
assistantMessage.setSources(sources);
|
||||
return messageRepository.save(assistantMessage);
|
||||
}
|
||||
|
||||
public void deleteSession(UUID sessionId) {
|
||||
if (!sessionRepository.existsById(sessionId)) {
|
||||
throw new NoSuchElementException("Session not found.");
|
||||
}
|
||||
sessionRepository.deleteById(sessionId);
|
||||
}
|
||||
|
||||
private String buildQuestionWithHistory(List<Message> history, String currentQuestion,
|
||||
String topicId) {
|
||||
if (history.size() <= 1) {
|
||||
// Only the current user message is in history; just ask the question
|
||||
return topicId != null
|
||||
? String.format("[Context: This is a question about the neurosurgery topic '%s']\n%s",
|
||||
topicId, currentQuestion)
|
||||
: currentQuestion;
|
||||
}
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (topicId != null) {
|
||||
sb.append(String.format("[Context: This conversation is about the neurosurgery topic '%s']\n\n",
|
||||
topicId));
|
||||
}
|
||||
sb.append("Previous conversation:\n");
|
||||
// Include all messages except the last (which is the current user message just saved)
|
||||
for (int i = 0; i < history.size() - 1; i++) {
|
||||
Message msg = history.get(i);
|
||||
sb.append(msg.getRole().name()).append(": ").append(msg.getContent()).append("\n");
|
||||
}
|
||||
sb.append("\nCurrent question: ").append(currentQuestion);
|
||||
return sb.toString();
|
||||
}
|
||||
|
||||
private List<Map<String, Object>> extractSources(ChatResponse response) {
|
||||
List<Map<String, Object>> sources = new ArrayList<>();
|
||||
|
||||
if (response.getMetadata() != null) {
|
||||
Object retrieved = response.getMetadata().get(QuestionAnswerAdvisor.RETRIEVED_DOCUMENTS);
|
||||
if (retrieved instanceof List<?> docs) {
|
||||
for (Object docObj : docs) {
|
||||
if (docObj instanceof Document doc) {
|
||||
Map<String, Object> metadata = doc.getMetadata();
|
||||
String bookTitle = (String) metadata.get("book_title");
|
||||
Object pageObj = metadata.get("page_number");
|
||||
Integer page = pageObj instanceof Number n ? n.intValue() : null;
|
||||
if (bookTitle != null) {
|
||||
Map<String, Object> source = new HashMap<>();
|
||||
source.put("bookTitle", bookTitle);
|
||||
source.put("page", page);
|
||||
sources.add(source);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return sources;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.aiteacher.chat;
|
||||
|
||||
import jakarta.persistence.*;
|
||||
import java.time.Instant;
|
||||
import java.util.UUID;
|
||||
|
||||
@Entity
|
||||
@Table(name = "chat_session")
|
||||
public class ChatSession {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.UUID)
|
||||
private UUID id;
|
||||
|
||||
@Column(name = "topic_id", length = 100)
|
||||
private String topicId;
|
||||
|
||||
@Column(name = "created_at", nullable = false)
|
||||
private Instant createdAt;
|
||||
|
||||
public ChatSession() {
|
||||
}
|
||||
|
||||
public ChatSession(String topicId) {
|
||||
this.topicId = topicId;
|
||||
this.createdAt = Instant.now();
|
||||
}
|
||||
|
||||
public UUID getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public String getTopicId() {
|
||||
return topicId;
|
||||
}
|
||||
|
||||
public void setTopicId(String topicId) {
|
||||
this.topicId = topicId;
|
||||
}
|
||||
|
||||
public Instant getCreatedAt() {
|
||||
return createdAt;
|
||||
}
|
||||
|
||||
public void setCreatedAt(Instant createdAt) {
|
||||
this.createdAt = createdAt;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.aiteacher.chat;
|
||||
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.util.UUID;
|
||||
|
||||
@Repository
|
||||
public interface ChatSessionRepository extends JpaRepository<ChatSession, UUID> {
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package com.aiteacher.chat;
|
||||
|
||||
import jakarta.persistence.*;
|
||||
import org.hibernate.annotations.JdbcTypeCode;
|
||||
import org.hibernate.type.SqlTypes;
|
||||
|
||||
import java.time.Instant;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
|
||||
@Entity
|
||||
@Table(name = "message")
|
||||
public class Message {
|
||||
|
||||
@Id
|
||||
@GeneratedValue(strategy = GenerationType.UUID)
|
||||
private UUID id;
|
||||
|
||||
@Column(name = "session_id", nullable = false)
|
||||
private UUID sessionId;
|
||||
|
||||
@Enumerated(EnumType.STRING)
|
||||
@Column(name = "role", nullable = false, length = 10)
|
||||
private MessageRole role;
|
||||
|
||||
@Column(name = "content", nullable = false, columnDefinition = "TEXT")
|
||||
private String content;
|
||||
|
||||
@JdbcTypeCode(SqlTypes.JSON)
|
||||
@Column(name = "sources", columnDefinition = "jsonb")
|
||||
private List<Map<String, Object>> sources;
|
||||
|
||||
@Column(name = "created_at", nullable = false)
|
||||
private Instant createdAt;
|
||||
|
||||
public Message() {
|
||||
}
|
||||
|
||||
public Message(UUID sessionId, MessageRole role, String content) {
|
||||
this.sessionId = sessionId;
|
||||
this.role = role;
|
||||
this.content = content;
|
||||
this.createdAt = Instant.now();
|
||||
}
|
||||
|
||||
public UUID getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public UUID getSessionId() {
|
||||
return sessionId;
|
||||
}
|
||||
|
||||
public void setSessionId(UUID sessionId) {
|
||||
this.sessionId = sessionId;
|
||||
}
|
||||
|
||||
public MessageRole getRole() {
|
||||
return role;
|
||||
}
|
||||
|
||||
public void setRole(MessageRole role) {
|
||||
this.role = role;
|
||||
}
|
||||
|
||||
public String getContent() {
|
||||
return content;
|
||||
}
|
||||
|
||||
public void setContent(String content) {
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
public List<Map<String, Object>> getSources() {
|
||||
return sources;
|
||||
}
|
||||
|
||||
public void setSources(List<Map<String, Object>> sources) {
|
||||
this.sources = sources;
|
||||
}
|
||||
|
||||
public Instant getCreatedAt() {
|
||||
return createdAt;
|
||||
}
|
||||
|
||||
public void setCreatedAt(Instant createdAt) {
|
||||
this.createdAt = createdAt;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.aiteacher.chat;
|
||||
|
||||
import org.springframework.data.jpa.repository.JpaRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.UUID;
|
||||
|
||||
@Repository
|
||||
public interface MessageRepository extends JpaRepository<Message, UUID> {
|
||||
|
||||
List<Message> findBySessionIdOrderByCreatedAtAsc(UUID sessionId);
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package com.aiteacher.chat;
|
||||
|
||||
public enum MessageRole {
|
||||
USER,
|
||||
ASSISTANT
|
||||
}
|
||||
Reference in New Issue
Block a user