增加聊天历史记录
This commit is contained in:
parent
999dd279dc
commit
fc0aa3d6da
|
|
@ -0,0 +1,29 @@
|
||||||
|
package cn.qihangerp.erp.config;
|
||||||
|
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager;
|
||||||
|
import cn.qihangerp.erp.serviceImpl.SessionManager;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI相关配置类
|
||||||
|
*/
|
||||||
|
@Configuration
|
||||||
|
public class AiConfig {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 会话管理服务Bean
|
||||||
|
*/
|
||||||
|
@Bean
|
||||||
|
public SessionManager sessionManager() {
|
||||||
|
return new SessionManager();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 对话历史管理服务Bean
|
||||||
|
*/
|
||||||
|
@Bean
|
||||||
|
public ConversationHistoryManager conversationHistoryManager() {
|
||||||
|
return new ConversationHistoryManager();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -9,9 +9,15 @@ import org.springframework.web.bind.annotation.RequestParam;
|
||||||
import org.springframework.web.bind.annotation.RestController;
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
import cn.qihangerp.security.LoginUser;
|
||||||
|
import cn.qihangerp.security.TokenService;
|
||||||
import cn.qihangerp.erp.serviceImpl.AiService;
|
import cn.qihangerp.erp.serviceImpl.AiService;
|
||||||
|
import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager;
|
||||||
|
import cn.qihangerp.erp.serviceImpl.SessionManager;
|
||||||
|
|
||||||
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
|
|
@ -24,19 +30,47 @@ import java.util.concurrent.TimeUnit;
|
||||||
public class SseController {
|
public class SseController {
|
||||||
|
|
||||||
private static final Map<String, SseEmitter> emitters = new ConcurrentHashMap<>();
|
private static final Map<String, SseEmitter> emitters = new ConcurrentHashMap<>();
|
||||||
|
private static final Map<String, Long> clientUserIdMap = new ConcurrentHashMap<>();
|
||||||
private final ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1);
|
private final ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1);
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private AiService aiService;
|
private AiService aiService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private SessionManager sessionManager;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ConversationHistoryManager conversationHistoryManager;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private TokenService tokenService;
|
||||||
|
|
||||||
@GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
@GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
public SseEmitter connect(@RequestParam String clientId) {
|
public SseEmitter connect(@RequestParam String clientId, @RequestParam String token, HttpServletRequest request) {
|
||||||
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
|
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
|
||||||
emitters.put(clientId, emitter);
|
emitters.put(clientId, emitter);
|
||||||
|
|
||||||
|
// 从token中获取用户信息
|
||||||
|
try {
|
||||||
|
LoginUser loginUser = tokenService.getLoginUser(request);
|
||||||
|
if (loginUser != null) {
|
||||||
|
Long userId = loginUser.getUserId();
|
||||||
|
clientUserIdMap.put(clientId, userId);
|
||||||
|
log.info("用户 {} 连接成功,clientId: {}", userId, clientId);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("获取用户信息失败: {}", e.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
// 设置超时处理
|
// 设置超时处理
|
||||||
emitter.onTimeout(() -> emitters.remove(clientId));
|
emitter.onTimeout(() -> {
|
||||||
emitter.onCompletion(() -> emitters.remove(clientId));
|
emitters.remove(clientId);
|
||||||
|
clientUserIdMap.remove(clientId);
|
||||||
|
});
|
||||||
|
emitter.onCompletion(() -> {
|
||||||
|
emitters.remove(clientId);
|
||||||
|
clientUserIdMap.remove(clientId);
|
||||||
|
});
|
||||||
|
|
||||||
// 发送连接成功消息
|
// 发送连接成功消息
|
||||||
try {
|
try {
|
||||||
|
|
@ -45,6 +79,7 @@ public class SseController {
|
||||||
.data("连接成功"));
|
.data("连接成功"));
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
emitters.remove(clientId);
|
emitters.remove(clientId);
|
||||||
|
clientUserIdMap.remove(clientId);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定期发送心跳
|
// 定期发送心跳
|
||||||
|
|
@ -57,6 +92,7 @@ public class SseController {
|
||||||
}
|
}
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
emitters.remove(clientId);
|
emitters.remove(clientId);
|
||||||
|
clientUserIdMap.remove(clientId);
|
||||||
}
|
}
|
||||||
}, 30, 30, TimeUnit.SECONDS);
|
}, 30, 30, TimeUnit.SECONDS);
|
||||||
|
|
||||||
|
|
@ -64,14 +100,46 @@ public class SseController {
|
||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping("/send")
|
@GetMapping("/send")
|
||||||
public String sendMessage(@RequestParam String clientId, @RequestParam String message, @RequestParam(required = false, defaultValue = "llama3") String model) {
|
public String sendMessage(@RequestParam String clientId, @RequestParam String message, @RequestParam(required = false, defaultValue = "llama3") String model, @RequestParam String token, HttpServletRequest request) {
|
||||||
log.info("=============来新消息了!");
|
log.info("=============来新消息了!");
|
||||||
SseEmitter emitter = emitters.get(clientId);
|
SseEmitter emitter = emitters.get(clientId);
|
||||||
if (emitter != null) {
|
if (emitter != null) {
|
||||||
try {
|
try {
|
||||||
// 使用AiService处理消息,传递模型参数
|
// 从token中获取用户信息
|
||||||
String response = aiService.processMessage(message, model);
|
LoginUser loginUser = tokenService.getLoginUser(request);
|
||||||
log.info("==========AI回复:{}",response);
|
Long userId = null;
|
||||||
|
if (loginUser != null) {
|
||||||
|
userId = loginUser.getUserId();
|
||||||
|
clientUserIdMap.put(clientId, userId);
|
||||||
|
} else {
|
||||||
|
// 尝试从映射中获取用户ID
|
||||||
|
userId = clientUserIdMap.get(clientId);
|
||||||
|
}
|
||||||
|
|
||||||
|
String sessionId = null;
|
||||||
|
if (userId != null) {
|
||||||
|
// 获取或创建会话ID
|
||||||
|
sessionId = sessionManager.getOrCreateSessionId(userId);
|
||||||
|
log.info("用户 {} 的会话ID: {}", userId, sessionId);
|
||||||
|
|
||||||
|
// 添加用户消息到对话历史
|
||||||
|
conversationHistoryManager.addMessage(sessionId, "user", message);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取对话历史
|
||||||
|
List<ConversationHistoryManager.Message> conversationHistory = null;
|
||||||
|
if (sessionId != null) {
|
||||||
|
conversationHistory = conversationHistoryManager.getRecentConversationHistory(sessionId, 10); // 只获取最近10条消息作为上下文
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用AiService处理消息,传递模型参数、会话ID和对话历史
|
||||||
|
String response = aiService.processMessage(message, model, sessionId, conversationHistory);
|
||||||
|
log.info("==========AI回复:{}", response);
|
||||||
|
|
||||||
|
// 如果有会话ID,添加AI回复到对话历史
|
||||||
|
if (sessionId != null) {
|
||||||
|
conversationHistoryManager.addMessage(sessionId, "assistant", response);
|
||||||
|
}
|
||||||
|
|
||||||
// 检查响应是否已经是JSON格式(以{开头)
|
// 检查响应是否已经是JSON格式(以{开头)
|
||||||
String jsonResponse;
|
String jsonResponse;
|
||||||
|
|
@ -80,7 +148,7 @@ public class SseController {
|
||||||
jsonResponse = response;
|
jsonResponse = response;
|
||||||
} else {
|
} else {
|
||||||
// 否则包装成JSON格式
|
// 否则包装成JSON格式
|
||||||
jsonResponse = String.format("{\"text\": \"%s\"}", response.replace("\"", "\\\"").replace("\n", "\\n"));
|
jsonResponse = String.format("{\"text\": \"%s\", \"sessionId\": \"%s\"}", response.replace("\"", "\\\"").replace("\n", "\\n"), sessionId != null ? sessionId : "");
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送JSON格式的消息
|
// 发送JSON格式的消息
|
||||||
|
|
@ -93,7 +161,8 @@ public class SseController {
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("消息处理失败: {}", e.getMessage());
|
log.error("消息处理失败: {}", e.getMessage());
|
||||||
emitters.remove(clientId);
|
emitters.remove(clientId);
|
||||||
return "消息发送失败";
|
clientUserIdMap.remove(clientId);
|
||||||
|
return "消息发送失败: " + e.getMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return "客户端不存在";
|
return "客户端不存在";
|
||||||
|
|
@ -102,6 +171,7 @@ public class SseController {
|
||||||
@GetMapping("/disconnect")
|
@GetMapping("/disconnect")
|
||||||
public String disconnect(@RequestParam String clientId) {
|
public String disconnect(@RequestParam String clientId) {
|
||||||
SseEmitter emitter = emitters.remove(clientId);
|
SseEmitter emitter = emitters.remove(clientId);
|
||||||
|
clientUserIdMap.remove(clientId);
|
||||||
if (emitter != null) {
|
if (emitter != null) {
|
||||||
emitter.complete();
|
emitter.complete();
|
||||||
return "断开连接成功";
|
return "断开连接成功";
|
||||||
|
|
@ -111,6 +181,28 @@ public class SseController {
|
||||||
|
|
||||||
@GetMapping("/status")
|
@GetMapping("/status")
|
||||||
public String getStatus() {
|
public String getStatus() {
|
||||||
return "当前连接数: " + emitters.size();
|
return "当前连接数: " + emitters.size() + ", 活跃会话数: " + sessionManager.getSessionCount();
|
||||||
|
}
|
||||||
|
|
||||||
|
@GetMapping("/history")
|
||||||
|
public Object getConversationHistory(@RequestParam String token, HttpServletRequest request) {
|
||||||
|
try {
|
||||||
|
// 从token中获取用户信息
|
||||||
|
LoginUser loginUser = tokenService.getLoginUser(request);
|
||||||
|
if (loginUser != null) {
|
||||||
|
Long userId = loginUser.getUserId();
|
||||||
|
// 获取用户的会话ID
|
||||||
|
String sessionId = sessionManager.getSessionId(userId);
|
||||||
|
if (sessionId != null) {
|
||||||
|
// 获取对话历史
|
||||||
|
var history = conversationHistoryManager.getConversationHistory(sessionId);
|
||||||
|
return Map.of("success", true, "data", history, "sessionId", sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Map.of("success", false, "message", "获取对话历史失败");
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("获取对话历史失败: {}", e.getMessage());
|
||||||
|
return Map.of("success", false, "message", "获取对话历史失败: " + e.getMessage());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -15,6 +15,7 @@ import java.time.format.DateTimeFormatter;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import cn.qihangerp.erp.service.OrderToolService;
|
import cn.qihangerp.erp.service.OrderToolService;
|
||||||
|
import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* AI服务类,使用LangChain4J调用Ollama模型处理聊天内容
|
* AI服务类,使用LangChain4J调用Ollama模型处理聊天内容
|
||||||
|
|
@ -129,9 +130,11 @@ public class AiService {
|
||||||
* 处理聊天消息
|
* 处理聊天消息
|
||||||
* @param message 用户消息
|
* @param message 用户消息
|
||||||
* @param model 模型名称
|
* @param model 模型名称
|
||||||
|
* @param sessionId 会话ID
|
||||||
|
* @param conversationHistory 对话历史
|
||||||
* @return AI回复
|
* @return AI回复
|
||||||
*/
|
*/
|
||||||
public String processMessage(String message, String model) {
|
public String processMessage(String message, String model, String sessionId, List<ConversationHistoryManager.Message> conversationHistory) {
|
||||||
try {
|
try {
|
||||||
// 优先检查页面跳转规则
|
// 优先检查页面跳转规则
|
||||||
String pageRuleResponse = checkPageRules(message);
|
String pageRuleResponse = checkPageRules(message);
|
||||||
|
|
@ -146,13 +149,39 @@ public class AiService {
|
||||||
// 替换消息中的"今天"为具体日期
|
// 替换消息中的"今天"为具体日期
|
||||||
message = message.replace("今天", today);
|
message = message.replace("今天", today);
|
||||||
|
|
||||||
// 创建订单工具服务
|
// 构建包含历史对话的提示词
|
||||||
OrderToolService orderToolService = new OrderToolService();
|
StringBuilder promptBuilder = new StringBuilder();
|
||||||
|
promptBuilder.append("今天的日期是:").append(today).append("\n");
|
||||||
|
|
||||||
// 执行AI服务,添加今天的日期信息
|
// 添加历史对话作为上下文
|
||||||
String enhancedMessage = "今天的日期是:" + today + "\n" + message;
|
if (conversationHistory != null && !conversationHistory.isEmpty()) {
|
||||||
|
promptBuilder.append("以下是之前的对话历史:\n");
|
||||||
|
for (ConversationHistoryManager.Message msg : conversationHistory) {
|
||||||
|
if (msg.getRole().equals("user")) {
|
||||||
|
promptBuilder.append("用户: " + msg.getContent()).append("\n");
|
||||||
|
} else {
|
||||||
|
promptBuilder.append("助手: " + msg.getContent()).append("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
promptBuilder.append("\n当前用户消息:\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加当前消息
|
||||||
|
promptBuilder.append(message);
|
||||||
|
|
||||||
|
String enhancedMessage = promptBuilder.toString();
|
||||||
System.out.println("发送给AI的消息: " + enhancedMessage);
|
System.out.println("发送给AI的消息: " + enhancedMessage);
|
||||||
|
|
||||||
|
// 尝试创建订单工具服务
|
||||||
|
OrderToolService orderToolService = null;
|
||||||
|
try {
|
||||||
|
orderToolService = new OrderToolService();
|
||||||
|
System.out.println("成功创建OrderToolService");
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.out.println("创建OrderToolService失败: " + e.getMessage());
|
||||||
|
// 工具创建失败,仍然继续执行,只是不使用工具
|
||||||
|
}
|
||||||
|
|
||||||
// 根据模型名称选择使用Ollama还是DeepSeek API
|
// 根据模型名称选择使用Ollama还是DeepSeek API
|
||||||
OrderAiService aiService;
|
OrderAiService aiService;
|
||||||
if (model.startsWith("deepseek")) {
|
if (model.startsWith("deepseek")) {
|
||||||
|
|
@ -171,10 +200,16 @@ public class AiService {
|
||||||
.timeout(Duration.ofSeconds(300))
|
.timeout(Duration.ofSeconds(300))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
aiService = AiServices.builder(OrderAiService.class)
|
if (orderToolService != null) {
|
||||||
.chatModel(deepSeekModelInstance)
|
aiService = AiServices.builder(OrderAiService.class)
|
||||||
.tools(orderToolService)
|
.chatModel(deepSeekModelInstance)
|
||||||
.build();
|
.tools(orderToolService)
|
||||||
|
.build();
|
||||||
|
} else {
|
||||||
|
aiService = AiServices.builder(OrderAiService.class)
|
||||||
|
.chatModel(deepSeekModelInstance)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
System.out.println("使用DeepSeek API处理消息");
|
System.out.println("使用DeepSeek API处理消息");
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
// 如果DeepSeek依赖不可用,返回错误消息
|
// 如果DeepSeek依赖不可用,返回错误消息
|
||||||
|
|
@ -182,35 +217,63 @@ public class AiService {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 使用Ollama
|
// 使用Ollama
|
||||||
OllamaChatModel modelInstance = OllamaChatModel.builder()
|
try {
|
||||||
.baseUrl("http://localhost:11434") // Ollama默认端口
|
System.out.println("尝试连接Ollama,模型: " + model);
|
||||||
.modelName(model) // 使用指定的模型
|
OllamaChatModel modelInstance = OllamaChatModel.builder()
|
||||||
.temperature(0.7)
|
.baseUrl("http://localhost:11434") // Ollama默认端口
|
||||||
.timeout(Duration.ofSeconds(300)) // 超时时间设置为300秒(5分钟)
|
.modelName(model) // 使用指定的模型
|
||||||
.build();
|
.temperature(0.7)
|
||||||
|
.timeout(Duration.ofSeconds(300)) // 超时时间设置为300秒(5分钟)
|
||||||
|
.build();
|
||||||
|
|
||||||
aiService = AiServices.builder(OrderAiService.class)
|
if (orderToolService != null) {
|
||||||
.chatModel(modelInstance)
|
aiService = AiServices.builder(OrderAiService.class)
|
||||||
.tools(orderToolService)
|
.chatModel(modelInstance)
|
||||||
.build();
|
.tools(orderToolService)
|
||||||
System.out.println("使用Ollama处理消息,模型: " + model);
|
.build();
|
||||||
|
} else {
|
||||||
|
aiService = AiServices.builder(OrderAiService.class)
|
||||||
|
.chatModel(modelInstance)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
System.out.println("成功创建Ollama模型实例,模型: " + model);
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.out.println("创建Ollama模型实例失败: " + e.getMessage());
|
||||||
|
return "错误: 无法连接到Ollama服务,请检查Ollama是否已启动,端口是否正确(默认11434)";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
String result = aiService.chat(enhancedMessage);
|
try {
|
||||||
System.out.println("AI返回的结果: " + result);
|
System.out.println("开始调用AI服务");
|
||||||
return result;
|
String result = aiService.chat(enhancedMessage);
|
||||||
|
System.out.println("AI返回的结果: " + result);
|
||||||
|
return result;
|
||||||
|
} catch (Exception e) {
|
||||||
|
System.out.println("调用AI服务失败: " + e.getMessage());
|
||||||
|
return "错误: 调用AI服务失败: " + e.getMessage();
|
||||||
|
}
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
return "错误: " + e.getMessage();
|
return "错误: " + e.getMessage();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理聊天消息
|
||||||
|
* @param message 用户消息
|
||||||
|
* @param model 模型名称
|
||||||
|
* @return AI回复
|
||||||
|
*/
|
||||||
|
public String processMessage(String message, String model) {
|
||||||
|
return processMessage(message, model, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 处理聊天消息(使用默认模型)
|
* 处理聊天消息(使用默认模型)
|
||||||
* @param message 用户消息
|
* @param message 用户消息
|
||||||
* @return AI回复
|
* @return AI回复
|
||||||
*/
|
*/
|
||||||
public String processMessage(String message) {
|
public String processMessage(String message) {
|
||||||
return processMessage(message, "llama3");
|
return processMessage(message, "llama3", null, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,112 @@
|
||||||
|
package cn.qihangerp.erp.serviceImpl;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 对话历史管理服务,用于保存和管理用户的对话历史
|
||||||
|
*/
|
||||||
|
public class ConversationHistoryManager {
|
||||||
|
private static final Map<String, List<Message>> sessionHistoryMap = new ConcurrentHashMap<>();
|
||||||
|
private static final AtomicLong messageIdCounter = new AtomicLong(0);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 消息实体类
|
||||||
|
*/
|
||||||
|
public static class Message {
|
||||||
|
private long id;
|
||||||
|
private String role; // user 或 assistant
|
||||||
|
private String content;
|
||||||
|
private long timestamp;
|
||||||
|
|
||||||
|
public Message(String role, String content) {
|
||||||
|
this.id = messageIdCounter.incrementAndGet();
|
||||||
|
this.role = role;
|
||||||
|
this.content = content;
|
||||||
|
this.timestamp = System.currentTimeMillis();
|
||||||
|
}
|
||||||
|
|
||||||
|
public long getId() {
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getRole() {
|
||||||
|
return role;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getContent() {
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
|
||||||
|
public long getTimestamp() {
|
||||||
|
return timestamp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加消息到对话历史
|
||||||
|
* @param sessionId 会话ID
|
||||||
|
* @param role 角色
|
||||||
|
* @param content 消息内容
|
||||||
|
*/
|
||||||
|
public void addMessage(String sessionId, String role, String content) {
|
||||||
|
if (sessionId == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
sessionHistoryMap.computeIfAbsent(sessionId, k -> new ArrayList<>())
|
||||||
|
.add(new Message(role, content));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取会话的所有对话历史
|
||||||
|
* @param sessionId 会话ID
|
||||||
|
* @return 对话历史列表
|
||||||
|
*/
|
||||||
|
public List<Message> getConversationHistory(String sessionId) {
|
||||||
|
if (sessionId == null) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
return sessionHistoryMap.getOrDefault(sessionId, new ArrayList<>());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取会话的最近几条对话历史
|
||||||
|
* @param sessionId 会话ID
|
||||||
|
* @param limit 限制数量
|
||||||
|
* @return 最近的对话历史列表
|
||||||
|
*/
|
||||||
|
public List<Message> getRecentConversationHistory(String sessionId, int limit) {
|
||||||
|
if (sessionId == null) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
List<Message> history = sessionHistoryMap.getOrDefault(sessionId, new ArrayList<>());
|
||||||
|
int startIndex = Math.max(0, history.size() - limit);
|
||||||
|
return history.subList(startIndex, history.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清空会话的对话历史
|
||||||
|
* @param sessionId 会话ID
|
||||||
|
*/
|
||||||
|
public void clearConversationHistory(String sessionId) {
|
||||||
|
if (sessionId != null) {
|
||||||
|
sessionHistoryMap.remove(sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取会话的对话历史数量
|
||||||
|
* @param sessionId 会话ID
|
||||||
|
* @return 对话历史数量
|
||||||
|
*/
|
||||||
|
public int getMessageCount(String sessionId) {
|
||||||
|
if (sessionId == null) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
List<Message> history = sessionHistoryMap.get(sessionId);
|
||||||
|
return history != null ? history.size() : 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,75 @@
|
||||||
|
package cn.qihangerp.erp.serviceImpl;
|
||||||
|
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import java.util.UUID;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 会话管理服务,用于管理用户的会话ID
|
||||||
|
*/
|
||||||
|
public class SessionManager {
|
||||||
|
private static final Map<Long, String> userIdSessionMap = new ConcurrentHashMap<>();
|
||||||
|
private static final Map<String, Long> sessionUserIdMap = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取或创建用户的会话ID
|
||||||
|
* @param userId 用户ID
|
||||||
|
* @return 会话ID
|
||||||
|
*/
|
||||||
|
public String getOrCreateSessionId(Long userId) {
|
||||||
|
if (userId == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return userIdSessionMap.computeIfAbsent(userId, k -> {
|
||||||
|
String sessionId = UUID.randomUUID().toString();
|
||||||
|
sessionUserIdMap.put(sessionId, userId);
|
||||||
|
return sessionId;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据用户ID获取会话ID
|
||||||
|
* @param userId 用户ID
|
||||||
|
* @return 会话ID
|
||||||
|
*/
|
||||||
|
public String getSessionId(Long userId) {
|
||||||
|
if (userId == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return userIdSessionMap.get(userId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据会话ID获取用户ID
|
||||||
|
* @param sessionId 会话ID
|
||||||
|
* @return 用户ID
|
||||||
|
*/
|
||||||
|
public Long getUserIdBySessionId(String sessionId) {
|
||||||
|
if (sessionId == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return sessionUserIdMap.get(sessionId);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 移除用户的会话
|
||||||
|
* @param userId 用户ID
|
||||||
|
*/
|
||||||
|
public void removeSession(Long userId) {
|
||||||
|
if (userId == null) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
String sessionId = userIdSessionMap.remove(userId);
|
||||||
|
if (sessionId != null) {
|
||||||
|
sessionUserIdMap.remove(sessionId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取当前活跃会话数
|
||||||
|
* @return 活跃会话数
|
||||||
|
*/
|
||||||
|
public int getSessionCount() {
|
||||||
|
return userIdSessionMap.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -7,3 +7,12 @@ export function getOllamaModels() {
|
||||||
method: 'get'
|
method: 'get'
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取对话历史
|
||||||
|
export function getConversationHistory(token) {
|
||||||
|
return request({
|
||||||
|
url: '/api/ai-agent/sse/history',
|
||||||
|
method: 'get',
|
||||||
|
params: { token }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -128,7 +128,7 @@
|
||||||
|
|
||||||
<script>
|
<script>
|
||||||
import { todayDaily } from "@/api/report/report";
|
import { todayDaily } from "@/api/report/report";
|
||||||
import { getOllamaModels } from "@/api/ai/ollama";
|
import { getOllamaModels, getConversationHistory } from "@/api/ai/ollama";
|
||||||
import { getToken } from "@/utils/auth";
|
import { getToken } from "@/utils/auth";
|
||||||
import MarkdownIt from 'markdown-it';
|
import MarkdownIt from 'markdown-it';
|
||||||
|
|
||||||
|
|
@ -177,11 +177,13 @@ export default {
|
||||||
isSseConnected: false,
|
isSseConnected: false,
|
||||||
isLoading: false,
|
isLoading: false,
|
||||||
selectedModel: 'deepseek',
|
selectedModel: 'deepseek',
|
||||||
models: []
|
models: [],
|
||||||
|
sessionId: ''
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
mounted() {
|
mounted() {
|
||||||
this.loadSystemStats();
|
this.loadSystemStats();
|
||||||
|
this.loadConversationHistory();
|
||||||
this.initSse();
|
this.initSse();
|
||||||
this.loadOllamaModels();
|
this.loadOllamaModels();
|
||||||
},
|
},
|
||||||
|
|
@ -204,6 +206,31 @@ export default {
|
||||||
console.error('获取模型列表失败:', error);
|
console.error('获取模型列表失败:', error);
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
|
loadConversationHistory() {
|
||||||
|
const token = getToken();
|
||||||
|
if (token) {
|
||||||
|
getConversationHistory(token).then(response => {
|
||||||
|
if (response.success && response.data) {
|
||||||
|
// 清空当前消息列表
|
||||||
|
this.messages = [];
|
||||||
|
// 添加历史消息
|
||||||
|
response.data.forEach(msg => {
|
||||||
|
this.messages.push({
|
||||||
|
content: msg.content,
|
||||||
|
time: this.formatTime(new Date(msg.timestamp)),
|
||||||
|
isMe: msg.role === 'user',
|
||||||
|
avatar: ''
|
||||||
|
});
|
||||||
|
});
|
||||||
|
// 保存会话ID
|
||||||
|
this.sessionId = response.sessionId;
|
||||||
|
console.log('加载对话历史成功:', response.data.length, '条消息');
|
||||||
|
}
|
||||||
|
}).catch(error => {
|
||||||
|
console.error('获取对话历史失败:', error);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
initSse() {
|
initSse() {
|
||||||
// 生成唯一客户端ID
|
// 生成唯一客户端ID
|
||||||
this.clientId = 'client_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
|
this.clientId = 'client_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue