From fc0aa3d6da665b2dc642951013ef85c211818d99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=90=AF=E8=88=AA?= <280645618@qq.com> Date: Tue, 10 Mar 2026 14:41:10 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=81=8A=E5=A4=A9=E5=8E=86?= =?UTF-8?q?=E5=8F=B2=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cn/qihangerp/erp/config/AiConfig.java | 29 +++++ .../erp/controller/SseController.java | 112 +++++++++++++++-- .../qihangerp/erp/serviceImpl/AiService.java | 113 ++++++++++++++---- .../ConversationHistoryManager.java | 112 +++++++++++++++++ .../erp/serviceImpl/SessionManager.java | 75 ++++++++++++ vue/src/api/ai/ollama.js | 9 ++ vue/src/views/index.vue | 31 ++++- 7 files changed, 444 insertions(+), 37 deletions(-) create mode 100644 api/ai-agent/src/main/java/cn/qihangerp/erp/config/AiConfig.java create mode 100644 api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/ConversationHistoryManager.java create mode 100644 api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/SessionManager.java diff --git a/api/ai-agent/src/main/java/cn/qihangerp/erp/config/AiConfig.java b/api/ai-agent/src/main/java/cn/qihangerp/erp/config/AiConfig.java new file mode 100644 index 00000000..8d5a4f47 --- /dev/null +++ b/api/ai-agent/src/main/java/cn/qihangerp/erp/config/AiConfig.java @@ -0,0 +1,29 @@ +package cn.qihangerp.erp.config; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager; +import cn.qihangerp.erp.serviceImpl.SessionManager; + +/** + * AI相关配置类 + */ +@Configuration +public class AiConfig { + + /** + * 会话管理服务Bean + */ + @Bean + public SessionManager sessionManager() { + return new SessionManager(); + } + + /** + * 对话历史管理服务Bean + */ + @Bean + public ConversationHistoryManager conversationHistoryManager() { + return new ConversationHistoryManager(); + } +} \ No newline at end of file diff --git a/api/ai-agent/src/main/java/cn/qihangerp/erp/controller/SseController.java b/api/ai-agent/src/main/java/cn/qihangerp/erp/controller/SseController.java index 5523c787..bc464b1f 100644 --- a/api/ai-agent/src/main/java/cn/qihangerp/erp/controller/SseController.java +++ b/api/ai-agent/src/main/java/cn/qihangerp/erp/controller/SseController.java @@ -9,9 +9,15 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; +import cn.qihangerp.security.LoginUser; +import cn.qihangerp.security.TokenService; import cn.qihangerp.erp.serviceImpl.AiService; +import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager; +import cn.qihangerp.erp.serviceImpl.SessionManager; +import jakarta.servlet.http.HttpServletRequest; import java.io.IOException; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; @@ -24,19 +30,47 @@ import java.util.concurrent.TimeUnit; public class SseController { private static final Map emitters = new ConcurrentHashMap<>(); + private static final Map clientUserIdMap = new ConcurrentHashMap<>(); private final ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1); @Autowired private AiService aiService; + + @Autowired + private SessionManager sessionManager; + + @Autowired + private ConversationHistoryManager conversationHistoryManager; + + @Autowired + private TokenService tokenService; @GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public SseEmitter connect(@RequestParam String clientId) { + public SseEmitter connect(@RequestParam String clientId, @RequestParam String token, HttpServletRequest request) { SseEmitter emitter = new SseEmitter(Long.MAX_VALUE); emitters.put(clientId, emitter); + // 从token中获取用户信息 + try { + LoginUser loginUser = tokenService.getLoginUser(request); + if (loginUser != null) { + Long userId = loginUser.getUserId(); + clientUserIdMap.put(clientId, userId); + log.info("用户 {} 连接成功,clientId: {}", userId, clientId); + } + } catch (Exception e) { + log.error("获取用户信息失败: {}", e.getMessage()); + } + // 设置超时处理 - emitter.onTimeout(() -> emitters.remove(clientId)); - emitter.onCompletion(() -> emitters.remove(clientId)); + emitter.onTimeout(() -> { + emitters.remove(clientId); + clientUserIdMap.remove(clientId); + }); + emitter.onCompletion(() -> { + emitters.remove(clientId); + clientUserIdMap.remove(clientId); + }); // 发送连接成功消息 try { @@ -45,6 +79,7 @@ public class SseController { .data("连接成功")); } catch (IOException e) { emitters.remove(clientId); + clientUserIdMap.remove(clientId); } // 定期发送心跳 @@ -57,6 +92,7 @@ public class SseController { } } catch (IOException e) { emitters.remove(clientId); + clientUserIdMap.remove(clientId); } }, 30, 30, TimeUnit.SECONDS); @@ -64,14 +100,46 @@ public class SseController { } @GetMapping("/send") - public String sendMessage(@RequestParam String clientId, @RequestParam String message, @RequestParam(required = false, defaultValue = "llama3") String model) { + public String sendMessage(@RequestParam String clientId, @RequestParam String message, @RequestParam(required = false, defaultValue = "llama3") String model, @RequestParam String token, HttpServletRequest request) { log.info("=============来新消息了!"); SseEmitter emitter = emitters.get(clientId); if (emitter != null) { try { - // 使用AiService处理消息,传递模型参数 - String response = aiService.processMessage(message, model); - log.info("==========AI回复:{}",response); + // 从token中获取用户信息 + LoginUser loginUser = tokenService.getLoginUser(request); + Long userId = null; + if (loginUser != null) { + userId = loginUser.getUserId(); + clientUserIdMap.put(clientId, userId); + } else { + // 尝试从映射中获取用户ID + userId = clientUserIdMap.get(clientId); + } + + String sessionId = null; + if (userId != null) { + // 获取或创建会话ID + sessionId = sessionManager.getOrCreateSessionId(userId); + log.info("用户 {} 的会话ID: {}", userId, sessionId); + + // 添加用户消息到对话历史 + conversationHistoryManager.addMessage(sessionId, "user", message); + } + + // 获取对话历史 + List conversationHistory = null; + if (sessionId != null) { + conversationHistory = conversationHistoryManager.getRecentConversationHistory(sessionId, 10); // 只获取最近10条消息作为上下文 + } + + // 使用AiService处理消息,传递模型参数、会话ID和对话历史 + String response = aiService.processMessage(message, model, sessionId, conversationHistory); + log.info("==========AI回复:{}", response); + + // 如果有会话ID,添加AI回复到对话历史 + if (sessionId != null) { + conversationHistoryManager.addMessage(sessionId, "assistant", response); + } // 检查响应是否已经是JSON格式(以{开头) String jsonResponse; @@ -80,7 +148,7 @@ public class SseController { jsonResponse = response; } else { // 否则包装成JSON格式 - jsonResponse = String.format("{\"text\": \"%s\"}", response.replace("\"", "\\\"").replace("\n", "\\n")); + jsonResponse = String.format("{\"text\": \"%s\", \"sessionId\": \"%s\"}", response.replace("\"", "\\\"").replace("\n", "\\n"), sessionId != null ? sessionId : ""); } // 发送JSON格式的消息 @@ -93,7 +161,8 @@ public class SseController { } catch (Exception e) { log.error("消息处理失败: {}", e.getMessage()); emitters.remove(clientId); - return "消息发送失败"; + clientUserIdMap.remove(clientId); + return "消息发送失败: " + e.getMessage(); } } return "客户端不存在"; @@ -102,6 +171,7 @@ public class SseController { @GetMapping("/disconnect") public String disconnect(@RequestParam String clientId) { SseEmitter emitter = emitters.remove(clientId); + clientUserIdMap.remove(clientId); if (emitter != null) { emitter.complete(); return "断开连接成功"; @@ -111,6 +181,28 @@ public class SseController { @GetMapping("/status") public String getStatus() { - return "当前连接数: " + emitters.size(); + return "当前连接数: " + emitters.size() + ", 活跃会话数: " + sessionManager.getSessionCount(); + } + + @GetMapping("/history") + public Object getConversationHistory(@RequestParam String token, HttpServletRequest request) { + try { + // 从token中获取用户信息 + LoginUser loginUser = tokenService.getLoginUser(request); + if (loginUser != null) { + Long userId = loginUser.getUserId(); + // 获取用户的会话ID + String sessionId = sessionManager.getSessionId(userId); + if (sessionId != null) { + // 获取对话历史 + var history = conversationHistoryManager.getConversationHistory(sessionId); + return Map.of("success", true, "data", history, "sessionId", sessionId); + } + } + return Map.of("success", false, "message", "获取对话历史失败"); + } catch (Exception e) { + log.error("获取对话历史失败: {}", e.getMessage()); + return Map.of("success", false, "message", "获取对话历史失败: " + e.getMessage()); + } } } \ No newline at end of file diff --git a/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/AiService.java b/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/AiService.java index 43d9de83..145227d1 100644 --- a/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/AiService.java +++ b/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/AiService.java @@ -15,6 +15,7 @@ import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.List; import cn.qihangerp.erp.service.OrderToolService; +import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager; /** * AI服务类,使用LangChain4J调用Ollama模型处理聊天内容 @@ -129,9 +130,11 @@ public class AiService { * 处理聊天消息 * @param message 用户消息 * @param model 模型名称 + * @param sessionId 会话ID + * @param conversationHistory 对话历史 * @return AI回复 */ - public String processMessage(String message, String model) { + public String processMessage(String message, String model, String sessionId, List conversationHistory) { try { // 优先检查页面跳转规则 String pageRuleResponse = checkPageRules(message); @@ -146,13 +149,39 @@ public class AiService { // 替换消息中的"今天"为具体日期 message = message.replace("今天", today); - // 创建订单工具服务 - OrderToolService orderToolService = new OrderToolService(); + // 构建包含历史对话的提示词 + StringBuilder promptBuilder = new StringBuilder(); + promptBuilder.append("今天的日期是:").append(today).append("\n"); - // 执行AI服务,添加今天的日期信息 - String enhancedMessage = "今天的日期是:" + today + "\n" + message; + // 添加历史对话作为上下文 + if (conversationHistory != null && !conversationHistory.isEmpty()) { + promptBuilder.append("以下是之前的对话历史:\n"); + for (ConversationHistoryManager.Message msg : conversationHistory) { + if (msg.getRole().equals("user")) { + promptBuilder.append("用户: " + msg.getContent()).append("\n"); + } else { + promptBuilder.append("助手: " + msg.getContent()).append("\n"); + } + } + promptBuilder.append("\n当前用户消息:\n"); + } + + // 添加当前消息 + promptBuilder.append(message); + + String enhancedMessage = promptBuilder.toString(); System.out.println("发送给AI的消息: " + enhancedMessage); + // 尝试创建订单工具服务 + OrderToolService orderToolService = null; + try { + orderToolService = new OrderToolService(); + System.out.println("成功创建OrderToolService"); + } catch (Exception e) { + System.out.println("创建OrderToolService失败: " + e.getMessage()); + // 工具创建失败,仍然继续执行,只是不使用工具 + } + // 根据模型名称选择使用Ollama还是DeepSeek API OrderAiService aiService; if (model.startsWith("deepseek")) { @@ -171,10 +200,16 @@ public class AiService { .timeout(Duration.ofSeconds(300)) .build(); - aiService = AiServices.builder(OrderAiService.class) - .chatModel(deepSeekModelInstance) - .tools(orderToolService) - .build(); + if (orderToolService != null) { + aiService = AiServices.builder(OrderAiService.class) + .chatModel(deepSeekModelInstance) + .tools(orderToolService) + .build(); + } else { + aiService = AiServices.builder(OrderAiService.class) + .chatModel(deepSeekModelInstance) + .build(); + } System.out.println("使用DeepSeek API处理消息"); } catch (Exception e) { // 如果DeepSeek依赖不可用,返回错误消息 @@ -182,35 +217,63 @@ public class AiService { } } else { // 使用Ollama - OllamaChatModel modelInstance = OllamaChatModel.builder() - .baseUrl("http://localhost:11434") // Ollama默认端口 - .modelName(model) // 使用指定的模型 - .temperature(0.7) - .timeout(Duration.ofSeconds(300)) // 超时时间设置为300秒(5分钟) - .build(); - - aiService = AiServices.builder(OrderAiService.class) - .chatModel(modelInstance) - .tools(orderToolService) - .build(); - System.out.println("使用Ollama处理消息,模型: " + model); + try { + System.out.println("尝试连接Ollama,模型: " + model); + OllamaChatModel modelInstance = OllamaChatModel.builder() + .baseUrl("http://localhost:11434") // Ollama默认端口 + .modelName(model) // 使用指定的模型 + .temperature(0.7) + .timeout(Duration.ofSeconds(300)) // 超时时间设置为300秒(5分钟) + .build(); + + if (orderToolService != null) { + aiService = AiServices.builder(OrderAiService.class) + .chatModel(modelInstance) + .tools(orderToolService) + .build(); + } else { + aiService = AiServices.builder(OrderAiService.class) + .chatModel(modelInstance) + .build(); + } + System.out.println("成功创建Ollama模型实例,模型: " + model); + } catch (Exception e) { + System.out.println("创建Ollama模型实例失败: " + e.getMessage()); + return "错误: 无法连接到Ollama服务,请检查Ollama是否已启动,端口是否正确(默认11434)"; + } } - String result = aiService.chat(enhancedMessage); - System.out.println("AI返回的结果: " + result); - return result; + try { + System.out.println("开始调用AI服务"); + String result = aiService.chat(enhancedMessage); + System.out.println("AI返回的结果: " + result); + return result; + } catch (Exception e) { + System.out.println("调用AI服务失败: " + e.getMessage()); + return "错误: 调用AI服务失败: " + e.getMessage(); + } } catch (Exception e) { e.printStackTrace(); return "错误: " + e.getMessage(); } } + /** + * 处理聊天消息 + * @param message 用户消息 + * @param model 模型名称 + * @return AI回复 + */ + public String processMessage(String message, String model) { + return processMessage(message, model, null, null); + } + /** * 处理聊天消息(使用默认模型) * @param message 用户消息 * @return AI回复 */ public String processMessage(String message) { - return processMessage(message, "llama3"); + return processMessage(message, "llama3", null, null); } } diff --git a/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/ConversationHistoryManager.java b/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/ConversationHistoryManager.java new file mode 100644 index 00000000..8908fb6b --- /dev/null +++ b/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/ConversationHistoryManager.java @@ -0,0 +1,112 @@ +package cn.qihangerp.erp.serviceImpl; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +/** + * 对话历史管理服务,用于保存和管理用户的对话历史 + */ +public class ConversationHistoryManager { + private static final Map> sessionHistoryMap = new ConcurrentHashMap<>(); + private static final AtomicLong messageIdCounter = new AtomicLong(0); + + /** + * 消息实体类 + */ + public static class Message { + private long id; + private String role; // user 或 assistant + private String content; + private long timestamp; + + public Message(String role, String content) { + this.id = messageIdCounter.incrementAndGet(); + this.role = role; + this.content = content; + this.timestamp = System.currentTimeMillis(); + } + + public long getId() { + return id; + } + + public String getRole() { + return role; + } + + public String getContent() { + return content; + } + + public long getTimestamp() { + return timestamp; + } + } + + /** + * 添加消息到对话历史 + * @param sessionId 会话ID + * @param role 角色 + * @param content 消息内容 + */ + public void addMessage(String sessionId, String role, String content) { + if (sessionId == null) { + return; + } + sessionHistoryMap.computeIfAbsent(sessionId, k -> new ArrayList<>()) + .add(new Message(role, content)); + } + + /** + * 获取会话的所有对话历史 + * @param sessionId 会话ID + * @return 对话历史列表 + */ + public List getConversationHistory(String sessionId) { + if (sessionId == null) { + return new ArrayList<>(); + } + return sessionHistoryMap.getOrDefault(sessionId, new ArrayList<>()); + } + + /** + * 获取会话的最近几条对话历史 + * @param sessionId 会话ID + * @param limit 限制数量 + * @return 最近的对话历史列表 + */ + public List getRecentConversationHistory(String sessionId, int limit) { + if (sessionId == null) { + return new ArrayList<>(); + } + List history = sessionHistoryMap.getOrDefault(sessionId, new ArrayList<>()); + int startIndex = Math.max(0, history.size() - limit); + return history.subList(startIndex, history.size()); + } + + /** + * 清空会话的对话历史 + * @param sessionId 会话ID + */ + public void clearConversationHistory(String sessionId) { + if (sessionId != null) { + sessionHistoryMap.remove(sessionId); + } + } + + /** + * 获取会话的对话历史数量 + * @param sessionId 会话ID + * @return 对话历史数量 + */ + public int getMessageCount(String sessionId) { + if (sessionId == null) { + return 0; + } + List history = sessionHistoryMap.get(sessionId); + return history != null ? history.size() : 0; + } +} \ No newline at end of file diff --git a/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/SessionManager.java b/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/SessionManager.java new file mode 100644 index 00000000..a90857a9 --- /dev/null +++ b/api/ai-agent/src/main/java/cn/qihangerp/erp/serviceImpl/SessionManager.java @@ -0,0 +1,75 @@ +package cn.qihangerp.erp.serviceImpl; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.UUID; + +/** + * 会话管理服务,用于管理用户的会话ID + */ +public class SessionManager { + private static final Map userIdSessionMap = new ConcurrentHashMap<>(); + private static final Map sessionUserIdMap = new ConcurrentHashMap<>(); + + /** + * 获取或创建用户的会话ID + * @param userId 用户ID + * @return 会话ID + */ + public String getOrCreateSessionId(Long userId) { + if (userId == null) { + return null; + } + return userIdSessionMap.computeIfAbsent(userId, k -> { + String sessionId = UUID.randomUUID().toString(); + sessionUserIdMap.put(sessionId, userId); + return sessionId; + }); + } + + /** + * 根据用户ID获取会话ID + * @param userId 用户ID + * @return 会话ID + */ + public String getSessionId(Long userId) { + if (userId == null) { + return null; + } + return userIdSessionMap.get(userId); + } + + /** + * 根据会话ID获取用户ID + * @param sessionId 会话ID + * @return 用户ID + */ + public Long getUserIdBySessionId(String sessionId) { + if (sessionId == null) { + return null; + } + return sessionUserIdMap.get(sessionId); + } + + /** + * 移除用户的会话 + * @param userId 用户ID + */ + public void removeSession(Long userId) { + if (userId == null) { + return; + } + String sessionId = userIdSessionMap.remove(userId); + if (sessionId != null) { + sessionUserIdMap.remove(sessionId); + } + } + + /** + * 获取当前活跃会话数 + * @return 活跃会话数 + */ + public int getSessionCount() { + return userIdSessionMap.size(); + } +} \ No newline at end of file diff --git a/vue/src/api/ai/ollama.js b/vue/src/api/ai/ollama.js index f5d99b8e..c2a1d80d 100644 --- a/vue/src/api/ai/ollama.js +++ b/vue/src/api/ai/ollama.js @@ -7,3 +7,12 @@ export function getOllamaModels() { method: 'get' }) } + +// 获取对话历史 +export function getConversationHistory(token) { + return request({ + url: '/api/ai-agent/sse/history', + method: 'get', + params: { token } + }) +} diff --git a/vue/src/views/index.vue b/vue/src/views/index.vue index 2b7ba653..b68a453b 100644 --- a/vue/src/views/index.vue +++ b/vue/src/views/index.vue @@ -128,7 +128,7 @@