diff --git a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/config/CmsAiChatConfig.java b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/config/CmsAiChatConfig.java index dfd3617a..916eaabe 100644 --- a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/config/CmsAiChatConfig.java +++ b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/config/CmsAiChatConfig.java @@ -5,9 +5,11 @@ package com.jeesite.modules.cms.ai.config; import com.jeesite.common.datasource.DataSourceHolder; +import com.jeesite.modules.cms.ai.properties.CmsAiProperties; import com.jeesite.modules.cms.ai.tools.CmsAiTools; import org.springframework.ai.chat.client.ChatClient; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Primary; @@ -18,6 +20,7 @@ import org.springframework.jdbc.core.JdbcTemplate; * @author ThinkGem */ @Configuration +@EnableConfigurationProperties(CmsAiProperties.class) public class CmsAiChatConfig { /** @@ -25,19 +28,21 @@ public class CmsAiChatConfig { * @author ThinkGem */ @Bean - public ChatClient chatClient(ChatClient.Builder builder) { - return builder - .defaultSystem(""" - ## 人物设定 - 你是我的知识库AI助手,你把我当作朋友,耐心真诚地回复我提出的相关问题。 - 你需要遵循以下原则,与关注者进行友善而有价值的沟通。 - ## 表达方式: - 1. 使用简体中文回答我的问题。 - 2. 使用幽默有趣的方式与我沟通。 - 3. 增加互动,如 “您的看法如何?” - """) - .defaultTools(new CmsAiTools()) - .build(); + public ChatClient chatClient(ChatClient.Builder builder, CmsAiProperties properties) { + builder.defaultSystem(""" + ## 人物设定 + 你是我的知识库AI助手,你把我当作朋友,耐心真诚地回复我提出的相关问题。 + 你需要遵循以下原则,与关注者进行友善而有价值的沟通。 + ## 表达方式: + 1. 使用简体中文回答我的问题。 + 2. 使用幽默有趣的方式与我沟通。 + 3. 增加互动,如 “您的看法如何?” + 4. 可以用表情,避免过多表情。 + """); + if (properties.getToolCalls()) { + builder.defaultTools(new CmsAiTools()); + } + return builder.build(); } // @Bean diff --git a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/config/WebClientThinkConfig.java b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/config/WebClientThinkConfig.java new file mode 100644 index 00000000..70885aa0 --- /dev/null +++ b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/config/WebClientThinkConfig.java @@ -0,0 +1,129 @@ +/** + * Copyright (c) 2013-Now http://jeesite.com All rights reserved. + * No deletion without permission, or be held responsible to law. + */ +package com.jeesite.modules.cms.ai.config; + +import com.jeesite.common.mapper.JsonMapper; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.web.reactive.function.client.WebClientCustomizer; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.http.HttpHeaders; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * 推理模型OpenAI兼容处理 + * @author ThinkGem + */ +@Configuration +public class WebClientThinkConfig { + + private final Logger logger = LoggerFactory.getLogger(WebClientThinkConfig.class); + + @Bean + @ConditionalOnMissingBean + public WebClientCustomizer webClientCustomizerThink() { + return webClientBuilder -> { + ExchangeFilterFunction requestFilter = ExchangeFilterFunction.ofRequestProcessor(clientRequest -> { + logger.trace("Request url: {}: {}", clientRequest.method(), clientRequest.url()); + return Mono.just(clientRequest); + }); + ExchangeFilterFunction responseFilter = ExchangeFilterFunction.ofResponseProcessor(clientResponse -> { + logger.trace("Response status: {}", clientResponse.statusCode()); + AtomicBoolean thinkingFlag = new AtomicBoolean(false); + Flux modifiedBody = clientResponse.bodyToFlux(DataBuffer.class) + .map(buf -> { + byte[] bytes = new byte[buf.readableByteCount()]; + buf.read(bytes); + DataBufferUtils.release(buf); + return new String(bytes, StandardCharsets.UTF_8); + }) + .flatMap(eventString -> { + logger.trace("Original response: ==> {}", eventString); + List lines = new ArrayList<>(); + String[] list = eventString.split("\\n", -1); + for (String line : list) { + if (!line.startsWith("data: ")) { + lines.add(line); + continue; + } + String jsonPart = line.substring("data: ".length()).trim(); + if (!(StringUtils.startsWith(jsonPart, "{") + && StringUtils.endsWith(jsonPart, "}") + && !"data: [DONE]".equals(line))) { + lines.add(line); + continue; + } + Map map = JsonMapper.fromJson(jsonPart, Map.class); + if (map == null) { + lines.add(line); + continue; + } + // 修改内容字段 + List choices = (List)map.get("choices"); + if (choices == null) { + lines.add(line); + continue; + } + for (Object o : choices) { + Map choice = (Map) o; + if (choice == null) { + continue; + } + Map delta = (Map) choice.get("delta"); + if (delta == null) { + continue; + } + String reasoningContent = (String) delta.get("reasoning_content"); + String content = (String) delta.get("content"); + if (reasoningContent != null) { + if (!thinkingFlag.get()) { + thinkingFlag.set(true); + delta.put("content", "\n" + reasoningContent); + } else { + delta.put("content", reasoningContent); + } + } else { + if (thinkingFlag.get()) { + thinkingFlag.set(false); + delta.put("content", "" + (content == null ? "" : content)); + } + } + } + // 重新生成事件字符串 + lines.add("data: " + JsonMapper.toJson(map)); + } + String finalLine = StringUtils.join(lines, "\n"); + logger.trace("Modified response: ==> {}", finalLine); + return Mono.just(finalLine); + }) + .map(str -> { + byte[] bytes = str.getBytes(StandardCharsets.UTF_8); + return new DefaultDataBufferFactory().wrap(bytes); + }); + ClientResponse modifiedResponse = ClientResponse.from(clientResponse) + .headers(headers -> headers.remove(HttpHeaders.CONTENT_LENGTH)) + .body(modifiedBody) + .build(); + return Mono.just(modifiedResponse); + }); + webClientBuilder.filter(requestFilter).filter(responseFilter); + }; + } +} diff --git a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/properties/CmsAiProperties.java b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/properties/CmsAiProperties.java new file mode 100644 index 00000000..c8ca33e8 --- /dev/null +++ b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/properties/CmsAiProperties.java @@ -0,0 +1,17 @@ +package com.jeesite.modules.cms.ai.properties; + +import org.springframework.boot.context.properties.ConfigurationProperties; + +@ConfigurationProperties("spring.ai") +public class CmsAiProperties { + + private Boolean toolCalls = false; + + public Boolean getToolCalls() { + return toolCalls; + } + + public void setToolCalls(Boolean toolCalls) { + this.toolCalls = toolCalls; + } +} diff --git a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/service/CmsAiChatService.java b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/service/CmsAiChatService.java index b2aad2ba..7b2f41ec 100644 --- a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/service/CmsAiChatService.java +++ b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/service/CmsAiChatService.java @@ -11,18 +11,23 @@ import com.jeesite.common.lang.DateUtils; import com.jeesite.common.lang.StringUtils; import com.jeesite.common.service.BaseService; import com.jeesite.modules.sys.utils.UserUtils; +import jakarta.servlet.http.HttpServletRequest; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; -import org.springframework.ai.chat.client.advisor.QuestionAnswerAdvisor; +import org.springframework.ai.chat.client.advisor.vectorstore.QuestionAnswerAdvisor; import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import org.springframework.web.reactive.function.client.WebClientResponseException; import reactor.core.publisher.Flux; +import reactor.core.publisher.SignalType; import java.util.List; import java.util.Map; @@ -35,6 +40,8 @@ import java.util.Map; public class CmsAiChatService extends BaseService { private static final String CMS_CHAT_CACHE = "cmsChatCache"; + private static final String[] USER_MESSAGE_SEARCH = new String[]{"{", "}", "$", "%"}; + private static final String[] USER_MESSAGE_REPLACE = new String[]{"\\{", "\\}", "\\$", "\\%"}; @Autowired private ChatClient chatClient; @@ -102,14 +109,52 @@ public class CmsAiChatService extends BaseService { * 聊天对话,流输出 * @author ThinkGem */ - public Flux chatStream(String conversationId, String message) { + public Flux chatStream(String conversationId, String message, HttpServletRequest request) { return chatClient.prompt() - .messages(new UserMessage(message)) + .messages( + new UserMessage(StringUtils.replaceEach(message, USER_MESSAGE_SEARCH, USER_MESSAGE_REPLACE)) + ) .advisors( new MessageChatMemoryAdvisor(chatMemory, conversationId, 1024), - new QuestionAnswerAdvisor(vectorStore, SearchRequest.builder().similarityThreshold(0.6F).topK(6).build())) + new QuestionAnswerAdvisor(vectorStore, SearchRequest.builder().similarityThreshold(0.6F).topK(6).build()) + ) .stream() - .chatResponse(); + .chatResponse() + .doOnNext(response -> { + if (response.getResult() != null && StringUtils.isNotBlank(response.getResult().getOutput().getText())) { + AssistantMessage assistantMessage = (AssistantMessage)request.getAttribute("assistantMessage"); + AssistantMessage currAssistantMessage = response.getResult().getOutput(); + if (assistantMessage == null) { + request.setAttribute("assistantMessage", currAssistantMessage); + } else { + request.setAttribute("assistantMessage", new AssistantMessage( + assistantMessage.getText() + currAssistantMessage.getText(), + currAssistantMessage.getMetadata())); + } + } + }) + .doFinally((signalType) -> { + if (signalType != SignalType.ON_COMPLETE) { + AssistantMessage assistantMessage = (AssistantMessage)request.getAttribute("assistantMessage"); + if (assistantMessage != null) { + chatMemory.add(conversationId, assistantMessage); + } else if (signalType == SignalType.CANCEL) { + chatMemory.add(conversationId, new AssistantMessage(text("暂无消息,你已主动停止响应。"))); + } + } + }) + .onErrorResume(error -> { + String errorMessage = error.getMessage(); + if (error instanceof WebClientResponseException webClientError) { + errorMessage = webClientError.getResponseBodyAsString(); + } + AssistantMessage assistantMessage = new AssistantMessage(errorMessage); + chatMemory.add(conversationId, assistantMessage); + logger.error("Error message: {}", errorMessage); + return Flux.just(ChatResponse.builder() + .generations(List.of(new Generation(assistantMessage))) + .build()); + }); } } diff --git a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/web/CmsAiChatController.java b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/web/CmsAiChatController.java index cab9e514..14314611 100644 --- a/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/web/CmsAiChatController.java +++ b/modules/cms-ai/src/main/java/com/jeesite/modules/cms/ai/web/CmsAiChatController.java @@ -7,6 +7,7 @@ package com.jeesite.modules.cms.ai.web; import com.jeesite.common.config.Global; import com.jeesite.common.web.BaseController; import com.jeesite.modules.cms.ai.service.CmsAiChatService; +import jakarta.servlet.http.HttpServletRequest; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.model.ChatResponse; import org.springframework.beans.factory.annotation.Autowired; @@ -77,8 +78,8 @@ public class CmsAiChatController extends BaseController { * @author ThinkGem */ @RequestMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux stream(String id, String message) { - return cmsAiChatService.chatStream(id, message); + public Flux stream(String id, String message, HttpServletRequest request) { + return cmsAiChatService.chatStream(id, message, request); } } diff --git a/modules/cms-ai/src/main/resources/config/jeesite-cms-ai.yml b/modules/cms-ai/src/main/resources/config/jeesite-cms-ai.yml index e673cd46..350283e1 100644 --- a/modules/cms-ai/src/main/resources/config/jeesite-cms-ai.yml +++ b/modules/cms-ai/src/main/resources/config/jeesite-cms-ai.yml @@ -13,19 +13,17 @@ spring: #api-key: ${BAILIAN_APP_KEY} # 聊天对话模型 chat: - enabled: false options: model: deepseek-ai/DeepSeek-R1-Distill-Qwen-7B #model: DeepSeek-R1-Distill-Qwen-14B #model: deepseek-r1-distill-llama-8b max-tokens: 1024 temperature: 0.6 - top-p: 0.7 + top-p: 0.9 frequency-penalty: 0 - logprobs: true + #logprobs: true # 向量库知识库模型(注意:不同的模型维度不同) embedding: - enabled: false options: model: BAAI/bge-m3 #model: bge-large-zh-v1.5 @@ -33,12 +31,14 @@ spring: #model: text-embedding-v3 #dimensions: 1024 + # 是否启用工具调用 + tool-calls: false + # 本地大模型配置(使用该模型,请开启 enabled 参数) ollama: base-url: http://localhost:11434 # 聊天对话模型 chat: - enabled: true options: model: qwen2.5 #model: deepseek-r1:7b @@ -48,7 +48,6 @@ spring: frequency-penalty: 0 # 向量库知识库模型(注意:不同的模型维度不同) embedding: - enabled: true # 维度 dimensions 设置为 384 #model: all-minilm:33m # 维度 dimensions 设置为 768 @@ -68,67 +67,64 @@ spring: #collection-name: vector_store collection-name: vector_store_1024 -# # Postgresql 向量数据库(PG 连接配置,见下文,需要手动建表) -# pgvector: -# id-type: TEXT -# index-type: HNSW -# distance-type: COSINE_DISTANCE -# initialize-schema: false -# #table-name: vector_store_384 -# #dimensions: 384 -# #table-name: vector_store_786 -# #dimensions: 768 -# table-name: vector_store_1024 -# dimensions: 1024 -# batching-strategy: TOKEN_COUNT -# max-document-batch-size: 10000 + # Postgresql 向量数据库(PG 连接配置,见下文,需要手动建表) + pgvector: + id-type: TEXT + index-type: HNSW + distance-type: COSINE_DISTANCE + initialize-schema: false + #table-name: vector_store_384 + #dimensions: 384 + #table-name: vector_store_786 + #dimensions: 768 + table-name: vector_store_1024 + dimensions: 1024 + max-document-batch-size: 10000 -# # ES 向量数据库(ES 连接配置,见下文) -# elasticsearch: -# index-name: vector-index -# initialize-schema: true -# dimensions: 1024 -# similarity: cosine -# batching-strategy: TOKEN_COUNT + # ES 向量数据库(ES 连接配置,见下文) + elasticsearch: + index-name: vector-index + initialize-schema: true + dimensions: 1024 + similarity: cosine -# # Milvus 向量数据库(字符串长度不超过65535) -# milvus: -# client: -# host: "localhost" -# port: 19530 -# username: "root" -# password: "milvus" -# initialize-schema: true -# database-name: "default2" -# collection-name: "vector_store2" -# embedding-dimension: 384 -# index-type: HNSW -# metric-type: COSINE + # Milvus 向量数据库 + milvus: + client: + host: "localhost" + port: 19530 + username: "root" + password: "milvus" + initialize-schema: true + database-name: "default" + collection-name: "vector_store" + embedding-dimension: 384 + index-type: HNSW + metric-type: COSINE # ========= Postgresql 向量数据库数据源 ========= -#jdbc: -# ds_pgvector: -# type: postgresql -# driver: org.postgresql.Driver -# url: jdbc:postgresql://127.0.0.1:5433/jeesite-ai -# username: postgres -# password: postgres -# testSql: SELECT 1 -# pool: -# init: 0 -# minIdle: 0 -# breakAfterAcquireFailure: true +jdbc: + ds_pgvector: + type: postgresql + driver: org.postgresql.Driver + url: jdbc:postgresql://127.0.0.1:5433/jeesite-ai + username: postgres + password: postgres + testSql: SELECT 1 + pool: + init: 0 + minIdle: 0 + breakAfterAcquireFailure: true # ========= ES 向量数据库连接配置 ========= -#spring.elasticsearch: -# enabled: true -# socket-timeout: 120s -# connection-timeout: 120s -# uris: http://127.0.0.1:9200 -# username: elastic -# password: elastic +spring.elasticsearch: + socket-timeout: 120s + connection-timeout: 120s + uris: http://127.0.0.1:9200 + username: elastic + password: elastic # 对话消息存缓存,可自定义存数据库 j2cache: