增加聊天历史记录

This commit is contained in:
启航 2026-03-10 14:41:10 +08:00
parent 999dd279dc
commit fc0aa3d6da
7 changed files with 444 additions and 37 deletions

View File

@ -0,0 +1,29 @@
package cn.qihangerp.erp.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager;
import cn.qihangerp.erp.serviceImpl.SessionManager;
/**
* AI相关配置类
*/
@Configuration
public class AiConfig {
/**
* 会话管理服务Bean
*/
@Bean
public SessionManager sessionManager() {
return new SessionManager();
}
/**
* 对话历史管理服务Bean
*/
@Bean
public ConversationHistoryManager conversationHistoryManager() {
return new ConversationHistoryManager();
}
}

View File

@ -9,9 +9,15 @@ import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import cn.qihangerp.security.LoginUser;
import cn.qihangerp.security.TokenService;
import cn.qihangerp.erp.serviceImpl.AiService;
import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager;
import cn.qihangerp.erp.serviceImpl.SessionManager;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
@ -24,19 +30,47 @@ import java.util.concurrent.TimeUnit;
public class SseController {
private static final Map<String, SseEmitter> emitters = new ConcurrentHashMap<>();
private static final Map<String, Long> clientUserIdMap = new ConcurrentHashMap<>();
private final ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1);
@Autowired
private AiService aiService;
@Autowired
private SessionManager sessionManager;
@Autowired
private ConversationHistoryManager conversationHistoryManager;
@Autowired
private TokenService tokenService;
@GetMapping(value = "/connect", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter connect(@RequestParam String clientId) {
public SseEmitter connect(@RequestParam String clientId, @RequestParam String token, HttpServletRequest request) {
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
emitters.put(clientId, emitter);
// 从token中获取用户信息
try {
LoginUser loginUser = tokenService.getLoginUser(request);
if (loginUser != null) {
Long userId = loginUser.getUserId();
clientUserIdMap.put(clientId, userId);
log.info("用户 {} 连接成功clientId: {}", userId, clientId);
}
} catch (Exception e) {
log.error("获取用户信息失败: {}", e.getMessage());
}
// 设置超时处理
emitter.onTimeout(() -> emitters.remove(clientId));
emitter.onCompletion(() -> emitters.remove(clientId));
emitter.onTimeout(() -> {
emitters.remove(clientId);
clientUserIdMap.remove(clientId);
});
emitter.onCompletion(() -> {
emitters.remove(clientId);
clientUserIdMap.remove(clientId);
});
// 发送连接成功消息
try {
@ -45,6 +79,7 @@ public class SseController {
.data("连接成功"));
} catch (IOException e) {
emitters.remove(clientId);
clientUserIdMap.remove(clientId);
}
// 定期发送心跳
@ -57,6 +92,7 @@ public class SseController {
}
} catch (IOException e) {
emitters.remove(clientId);
clientUserIdMap.remove(clientId);
}
}, 30, 30, TimeUnit.SECONDS);
@ -64,15 +100,47 @@ public class SseController {
}
@GetMapping("/send")
public String sendMessage(@RequestParam String clientId, @RequestParam String message, @RequestParam(required = false, defaultValue = "llama3") String model) {
public String sendMessage(@RequestParam String clientId, @RequestParam String message, @RequestParam(required = false, defaultValue = "llama3") String model, @RequestParam String token, HttpServletRequest request) {
log.info("=============来新消息了!");
SseEmitter emitter = emitters.get(clientId);
if (emitter != null) {
try {
// 使用AiService处理消息传递模型参数
String response = aiService.processMessage(message, model);
// 从token中获取用户信息
LoginUser loginUser = tokenService.getLoginUser(request);
Long userId = null;
if (loginUser != null) {
userId = loginUser.getUserId();
clientUserIdMap.put(clientId, userId);
} else {
// 尝试从映射中获取用户ID
userId = clientUserIdMap.get(clientId);
}
String sessionId = null;
if (userId != null) {
// 获取或创建会话ID
sessionId = sessionManager.getOrCreateSessionId(userId);
log.info("用户 {} 的会话ID: {}", userId, sessionId);
// 添加用户消息到对话历史
conversationHistoryManager.addMessage(sessionId, "user", message);
}
// 获取对话历史
List<ConversationHistoryManager.Message> conversationHistory = null;
if (sessionId != null) {
conversationHistory = conversationHistoryManager.getRecentConversationHistory(sessionId, 10); // 只获取最近10条消息作为上下文
}
// 使用AiService处理消息传递模型参数会话ID和对话历史
String response = aiService.processMessage(message, model, sessionId, conversationHistory);
log.info("==========AI回复{}", response);
// 如果有会话ID添加AI回复到对话历史
if (sessionId != null) {
conversationHistoryManager.addMessage(sessionId, "assistant", response);
}
// 检查响应是否已经是JSON格式{开头
String jsonResponse;
if (response.trim().startsWith("{")) {
@ -80,7 +148,7 @@ public class SseController {
jsonResponse = response;
} else {
// 否则包装成JSON格式
jsonResponse = String.format("{\"text\": \"%s\"}", response.replace("\"", "\\\"").replace("\n", "\\n"));
jsonResponse = String.format("{\"text\": \"%s\", \"sessionId\": \"%s\"}", response.replace("\"", "\\\"").replace("\n", "\\n"), sessionId != null ? sessionId : "");
}
// 发送JSON格式的消息
@ -93,7 +161,8 @@ public class SseController {
} catch (Exception e) {
log.error("消息处理失败: {}", e.getMessage());
emitters.remove(clientId);
return "消息发送失败";
clientUserIdMap.remove(clientId);
return "消息发送失败: " + e.getMessage();
}
}
return "客户端不存在";
@ -102,6 +171,7 @@ public class SseController {
@GetMapping("/disconnect")
public String disconnect(@RequestParam String clientId) {
SseEmitter emitter = emitters.remove(clientId);
clientUserIdMap.remove(clientId);
if (emitter != null) {
emitter.complete();
return "断开连接成功";
@ -111,6 +181,28 @@ public class SseController {
@GetMapping("/status")
public String getStatus() {
return "当前连接数: " + emitters.size();
return "当前连接数: " + emitters.size() + ", 活跃会话数: " + sessionManager.getSessionCount();
}
@GetMapping("/history")
public Object getConversationHistory(@RequestParam String token, HttpServletRequest request) {
try {
// 从token中获取用户信息
LoginUser loginUser = tokenService.getLoginUser(request);
if (loginUser != null) {
Long userId = loginUser.getUserId();
// 获取用户的会话ID
String sessionId = sessionManager.getSessionId(userId);
if (sessionId != null) {
// 获取对话历史
var history = conversationHistoryManager.getConversationHistory(sessionId);
return Map.of("success", true, "data", history, "sessionId", sessionId);
}
}
return Map.of("success", false, "message", "获取对话历史失败");
} catch (Exception e) {
log.error("获取对话历史失败: {}", e.getMessage());
return Map.of("success", false, "message", "获取对话历史失败: " + e.getMessage());
}
}
}

View File

@ -15,6 +15,7 @@ import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.List;
import cn.qihangerp.erp.service.OrderToolService;
import cn.qihangerp.erp.serviceImpl.ConversationHistoryManager;
/**
* AI服务类使用LangChain4J调用Ollama模型处理聊天内容
@ -129,9 +130,11 @@ public class AiService {
* 处理聊天消息
* @param message 用户消息
* @param model 模型名称
* @param sessionId 会话ID
* @param conversationHistory 对话历史
* @return AI回复
*/
public String processMessage(String message, String model) {
public String processMessage(String message, String model, String sessionId, List<ConversationHistoryManager.Message> conversationHistory) {
try {
// 优先检查页面跳转规则
String pageRuleResponse = checkPageRules(message);
@ -146,13 +149,39 @@ public class AiService {
// 替换消息中的"今天"为具体日期
message = message.replace("今天", today);
// 创建订单工具服务
OrderToolService orderToolService = new OrderToolService();
// 构建包含历史对话的提示词
StringBuilder promptBuilder = new StringBuilder();
promptBuilder.append("今天的日期是:").append(today).append("\n");
// 执行AI服务添加今天的日期信息
String enhancedMessage = "今天的日期是:" + today + "\n" + message;
// 添加历史对话作为上下文
if (conversationHistory != null && !conversationHistory.isEmpty()) {
promptBuilder.append("以下是之前的对话历史:\n");
for (ConversationHistoryManager.Message msg : conversationHistory) {
if (msg.getRole().equals("user")) {
promptBuilder.append("用户: " + msg.getContent()).append("\n");
} else {
promptBuilder.append("助手: " + msg.getContent()).append("\n");
}
}
promptBuilder.append("\n当前用户消息\n");
}
// 添加当前消息
promptBuilder.append(message);
String enhancedMessage = promptBuilder.toString();
System.out.println("发送给AI的消息: " + enhancedMessage);
// 尝试创建订单工具服务
OrderToolService orderToolService = null;
try {
orderToolService = new OrderToolService();
System.out.println("成功创建OrderToolService");
} catch (Exception e) {
System.out.println("创建OrderToolService失败: " + e.getMessage());
// 工具创建失败仍然继续执行只是不使用工具
}
// 根据模型名称选择使用Ollama还是DeepSeek API
OrderAiService aiService;
if (model.startsWith("deepseek")) {
@ -171,10 +200,16 @@ public class AiService {
.timeout(Duration.ofSeconds(300))
.build();
if (orderToolService != null) {
aiService = AiServices.builder(OrderAiService.class)
.chatModel(deepSeekModelInstance)
.tools(orderToolService)
.build();
} else {
aiService = AiServices.builder(OrderAiService.class)
.chatModel(deepSeekModelInstance)
.build();
}
System.out.println("使用DeepSeek API处理消息");
} catch (Exception e) {
// 如果DeepSeek依赖不可用返回错误消息
@ -182,6 +217,8 @@ public class AiService {
}
} else {
// 使用Ollama
try {
System.out.println("尝试连接Ollama模型: " + model);
OllamaChatModel modelInstance = OllamaChatModel.builder()
.baseUrl("http://localhost:11434") // Ollama默认端口
.modelName(model) // 使用指定的模型
@ -189,28 +226,54 @@ public class AiService {
.timeout(Duration.ofSeconds(300)) // 超时时间设置为300秒5分钟
.build();
if (orderToolService != null) {
aiService = AiServices.builder(OrderAiService.class)
.chatModel(modelInstance)
.tools(orderToolService)
.build();
System.out.println("使用Ollama处理消息模型: " + model);
} else {
aiService = AiServices.builder(OrderAiService.class)
.chatModel(modelInstance)
.build();
}
System.out.println("成功创建Ollama模型实例模型: " + model);
} catch (Exception e) {
System.out.println("创建Ollama模型实例失败: " + e.getMessage());
return "错误: 无法连接到Ollama服务请检查Ollama是否已启动端口是否正确默认11434";
}
}
try {
System.out.println("开始调用AI服务");
String result = aiService.chat(enhancedMessage);
System.out.println("AI返回的结果: " + result);
return result;
} catch (Exception e) {
System.out.println("调用AI服务失败: " + e.getMessage());
return "错误: 调用AI服务失败: " + e.getMessage();
}
} catch (Exception e) {
e.printStackTrace();
return "错误: " + e.getMessage();
}
}
/**
* 处理聊天消息
* @param message 用户消息
* @param model 模型名称
* @return AI回复
*/
public String processMessage(String message, String model) {
return processMessage(message, model, null, null);
}
/**
* 处理聊天消息使用默认模型
* @param message 用户消息
* @return AI回复
*/
public String processMessage(String message) {
return processMessage(message, "llama3");
return processMessage(message, "llama3", null, null);
}
}

View File

@ -0,0 +1,112 @@
package cn.qihangerp.erp.serviceImpl;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
/**
* 对话历史管理服务用于保存和管理用户的对话历史
*/
public class ConversationHistoryManager {
private static final Map<String, List<Message>> sessionHistoryMap = new ConcurrentHashMap<>();
private static final AtomicLong messageIdCounter = new AtomicLong(0);
/**
* 消息实体类
*/
public static class Message {
private long id;
private String role; // user assistant
private String content;
private long timestamp;
public Message(String role, String content) {
this.id = messageIdCounter.incrementAndGet();
this.role = role;
this.content = content;
this.timestamp = System.currentTimeMillis();
}
public long getId() {
return id;
}
public String getRole() {
return role;
}
public String getContent() {
return content;
}
public long getTimestamp() {
return timestamp;
}
}
/**
* 添加消息到对话历史
* @param sessionId 会话ID
* @param role 角色
* @param content 消息内容
*/
public void addMessage(String sessionId, String role, String content) {
if (sessionId == null) {
return;
}
sessionHistoryMap.computeIfAbsent(sessionId, k -> new ArrayList<>())
.add(new Message(role, content));
}
/**
* 获取会话的所有对话历史
* @param sessionId 会话ID
* @return 对话历史列表
*/
public List<Message> getConversationHistory(String sessionId) {
if (sessionId == null) {
return new ArrayList<>();
}
return sessionHistoryMap.getOrDefault(sessionId, new ArrayList<>());
}
/**
* 获取会话的最近几条对话历史
* @param sessionId 会话ID
* @param limit 限制数量
* @return 最近的对话历史列表
*/
public List<Message> getRecentConversationHistory(String sessionId, int limit) {
if (sessionId == null) {
return new ArrayList<>();
}
List<Message> history = sessionHistoryMap.getOrDefault(sessionId, new ArrayList<>());
int startIndex = Math.max(0, history.size() - limit);
return history.subList(startIndex, history.size());
}
/**
* 清空会话的对话历史
* @param sessionId 会话ID
*/
public void clearConversationHistory(String sessionId) {
if (sessionId != null) {
sessionHistoryMap.remove(sessionId);
}
}
/**
* 获取会话的对话历史数量
* @param sessionId 会话ID
* @return 对话历史数量
*/
public int getMessageCount(String sessionId) {
if (sessionId == null) {
return 0;
}
List<Message> history = sessionHistoryMap.get(sessionId);
return history != null ? history.size() : 0;
}
}

View File

@ -0,0 +1,75 @@
package cn.qihangerp.erp.serviceImpl;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.UUID;
/**
* 会话管理服务用于管理用户的会话ID
*/
public class SessionManager {
private static final Map<Long, String> userIdSessionMap = new ConcurrentHashMap<>();
private static final Map<String, Long> sessionUserIdMap = new ConcurrentHashMap<>();
/**
* 获取或创建用户的会话ID
* @param userId 用户ID
* @return 会话ID
*/
public String getOrCreateSessionId(Long userId) {
if (userId == null) {
return null;
}
return userIdSessionMap.computeIfAbsent(userId, k -> {
String sessionId = UUID.randomUUID().toString();
sessionUserIdMap.put(sessionId, userId);
return sessionId;
});
}
/**
* 根据用户ID获取会话ID
* @param userId 用户ID
* @return 会话ID
*/
public String getSessionId(Long userId) {
if (userId == null) {
return null;
}
return userIdSessionMap.get(userId);
}
/**
* 根据会话ID获取用户ID
* @param sessionId 会话ID
* @return 用户ID
*/
public Long getUserIdBySessionId(String sessionId) {
if (sessionId == null) {
return null;
}
return sessionUserIdMap.get(sessionId);
}
/**
* 移除用户的会话
* @param userId 用户ID
*/
public void removeSession(Long userId) {
if (userId == null) {
return;
}
String sessionId = userIdSessionMap.remove(userId);
if (sessionId != null) {
sessionUserIdMap.remove(sessionId);
}
}
/**
* 获取当前活跃会话数
* @return 活跃会话数
*/
public int getSessionCount() {
return userIdSessionMap.size();
}
}

View File

@ -7,3 +7,12 @@ export function getOllamaModels() {
method: 'get'
})
}
// 获取对话历史
export function getConversationHistory(token) {
return request({
url: '/api/ai-agent/sse/history',
method: 'get',
params: { token }
})
}

View File

@ -128,7 +128,7 @@
<script>
import { todayDaily } from "@/api/report/report";
import { getOllamaModels } from "@/api/ai/ollama";
import { getOllamaModels, getConversationHistory } from "@/api/ai/ollama";
import { getToken } from "@/utils/auth";
import MarkdownIt from 'markdown-it';
@ -177,11 +177,13 @@ export default {
isSseConnected: false,
isLoading: false,
selectedModel: 'deepseek',
models: []
models: [],
sessionId: ''
}
},
mounted() {
this.loadSystemStats();
this.loadConversationHistory();
this.initSse();
this.loadOllamaModels();
},
@ -204,6 +206,31 @@ export default {
console.error('获取模型列表失败:', error);
});
},
loadConversationHistory() {
const token = getToken();
if (token) {
getConversationHistory(token).then(response => {
if (response.success && response.data) {
//
this.messages = [];
//
response.data.forEach(msg => {
this.messages.push({
content: msg.content,
time: this.formatTime(new Date(msg.timestamp)),
isMe: msg.role === 'user',
avatar: ''
});
});
// ID
this.sessionId = response.sessionId;
console.log('加载对话历史成功:', response.data.length, '条消息');
}
}).catch(error => {
console.error('获取对话历史失败:', error);
});
}
},
initSse() {
// ID
this.clientId = 'client_' + Date.now() + '_' + Math.random().toString(36).substr(2, 9);