升级 Spring AI 1.0.0-RC1、Chroma 1.0.0

This commit is contained in:
thinkgem
2025-05-14 21:48:29 +08:00
parent 723a088eed
commit 1b21375e1c
8 changed files with 106 additions and 74 deletions

View File

@@ -7,8 +7,11 @@ package com.jeesite.modules.cms.ai.config;
import com.jeesite.common.datasource.DataSourceHolder;
import com.jeesite.common.lang.StringUtils;
import com.jeesite.modules.cms.ai.properties.CmsAiProperties;
import com.jeesite.modules.cms.ai.service.CacheChatMemoryRepository;
import com.jeesite.modules.cms.ai.tools.CmsAiTools;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
@@ -39,6 +42,18 @@ public class CmsAiChatConfig {
return builder.build();
}
/**
 * Chat conversation storage bean: a {@link MessageWindowChatMemory} that
 * reads and writes its messages through the supplied repository.
 * @param cacheChatMemoryRepository repository the memory delegates storage to
 * @return the configured chat memory
 * @author ThinkGem
 */
@Bean
public ChatMemory chatMemory(CacheChatMemoryRepository cacheChatMemoryRepository) {
    // Value handed to maxMessages(...) on the window memory builder
    final int windowSize = 1024;
    var memoryBuilder = MessageWindowChatMemory.builder()
            .chatMemoryRepository(cacheChatMemoryRepository)
            .maxMessages(windowSize);
    return memoryBuilder.build();
}
// @Bean
// public BatchingStrategy batchingStrategy() {
// return new TokenCountBatchingStrategy(EncodingType.CL100K_BASE, Integer.MAX_VALUE, 0.1);

View File

@@ -12,6 +12,7 @@ import com.jeesite.common.lang.StringUtils;
import com.jeesite.common.lang.TimeUtils;
import com.jeesite.common.utils.PageUtils;
import com.jeesite.common.web.http.HttpClientUtils;
import com.jeesite.common.web.http.ServletUtils;
import com.jeesite.modules.cms.entity.Article;
import com.jeesite.modules.cms.service.ArticleVectorStore;
import com.jeesite.modules.cms.utils.CmsUtils;
@@ -21,11 +22,11 @@ import com.vladsch.flexmark.html2md.converter.FlexmarkHtmlConverter;
import com.vladsch.flexmark.html2md.converter.HtmlLinkResolver;
import com.vladsch.flexmark.html2md.converter.HtmlLinkResolverFactory;
import com.vladsch.flexmark.html2md.converter.HtmlNodeConverterContext;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.tika.Tika;
import org.apache.tika.config.TikaConfig;
import org.apache.tika.exception.TikaException;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
@@ -75,13 +76,33 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
metadata.put("updateBy", article.getUpdateBy());
metadata.put("updateDate", article.getUpdateDate());
List<String> attachmentList = ListUtils.newArrayList();
HtmlLinkResolverFactory linkResolverFactory = new HtmlLinkResolverFactory() {
String content = article.getTitle() + ", " + article.getKeywords() + ", "
+ article.getDescription() + ", " + FlexmarkHtmlConverter.builder()
.linkResolverFactory(getHtmlLinkResolverFactory(attachmentList)).build()
.convert(article.getArticleData().getContent())
+ ", attachment: " + attachmentList;
List<Document> documents = List.of(new Document(article.getId(), content, metadata));
List<Document> splitDocuments = new TokenTextSplitter().apply(documents);
this.delete(article); // 删除原数据
ListUtils.pageList(splitDocuments, 10, params -> {
vectorStore.add((List<Document>)params[0]); // 增加新数据
return null;
});
}
/**
* 解析文章中的连接并提取内容
* @author ThinkGem
*/
private @NotNull HtmlLinkResolverFactory getHtmlLinkResolverFactory(List<String> attachmentList) {
HttpServletRequest request = ServletUtils.getRequest();
return new HtmlLinkResolverFactory() {
@Override
public @Nullable Set<Class<?>> getAfterDependents() {
public @NotNull Set<Class<?>> getAfterDependents() {
return Set.of();
}
@Override
public @Nullable Set<Class<?>> getBeforeDependents() {
public @NotNull Set<Class<?>> getBeforeDependents() {
return Set.of();
}
@Override
@@ -94,11 +115,16 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
if ("a".equalsIgnoreCase(node.nodeName())) {
String href = node.attributes().get("href"); String url = href;
if (StringUtils.contains(url, "://")) {
try (InputStream is = HttpClientUtils.getInputStream(url, null)) {
String text = getDocumentText(is);
attachmentList.add(url + text);
} catch (IOException | TikaException e) {
logger.error(e.getMessage(), e);
// 只提取系统允许跳转的附件内容外部网站内容不进行提取shiro.allowRedirects 参数设置范围
if (ServletUtils.isAllowRedirects(request, url)) {
try (InputStream is = HttpClientUtils.getInputStream(url, null)) {
if (is != null) {
String text = getDocumentText(is);
attachmentList.add(url + text);
}
} catch (IOException | TikaException e) {
logger.error(e.getMessage(), e);
}
}
} else {
String ctxPath = Global.getCtxPath();
@@ -106,8 +132,10 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
url = url.substring(ctxPath.length());
}
try (InputStream is = IOUtils.getFileInputStream(Global.getUserfilesBaseDir(url))){
String text = getDocumentText(is);
attachmentList.add(url + text);
if (is != null) {
String text = getDocumentText(is);
attachmentList.add(url + text);
}
} catch (IOException | TikaException e) {
logger.error(e.getMessage(), e);
}
@@ -130,18 +158,6 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
.orElse(StringUtils.EMPTY);
}
};
String content = article.getTitle() + ", " + article.getKeywords() + ", "
+ article.getDescription() + ", " + FlexmarkHtmlConverter.builder()
.linkResolverFactory(linkResolverFactory).build()
.convert(article.getArticleData().getContent())
+ ", attachment: " + attachmentList;
List<Document> documents = List.of(new Document(article.getId(), content, metadata));
List<Document> splitDocuments = new TokenTextSplitter().apply(documents);
this.delete(article); // 删除原数据
ListUtils.pageList(splitDocuments, 64, params -> {
vectorStore.add((List<Document>)params[0]); // 增加新数据
return null;
});
}
/**

View File

@@ -1,45 +0,0 @@
/**
* Copyright (c) 2013-Now http://jeesite.com All rights reserved.
* No deletion without permission, or be held responsible to law.
*/
package com.jeesite.modules.cms.ai.service;
import com.jeesite.common.cache.CacheUtils;
import com.jeesite.common.collect.ListUtils;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Service;
import java.util.List;
/**
 * AI conversation message storage backed by the {@code cmsChatMsgCache} cache.
 * @author ThinkGem
 */
@Service
public class CacheChatMemory implements ChatMemory {

    private static final String CMS_CHAT_MSG_CACHE = "cmsChatMsgCache";

    /**
     * Appends the given messages to the conversation's cached history and
     * writes the history back to the cache.
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        List<Message> cached = CacheUtils.get(CMS_CHAT_MSG_CACHE, conversationId);
        // Start a fresh history when nothing is cached for this conversation yet
        List<Message> history = cached != null ? cached : ListUtils.newArrayList();
        history.addAll(messages);
        CacheUtils.put(CMS_CHAT_MSG_CACHE, conversationId, history);
    }

    /**
     * Returns the last {@code lastN} cached messages of the conversation,
     * or an empty list when nothing is cached.
     */
    @Override
    public List<Message> get(String conversationId, int lastN) {
        List<Message> history = CacheUtils.get(CMS_CHAT_MSG_CACHE, conversationId);
        if (history == null) {
            return List.of();
        }
        // Number of leading messages to drop so that at most lastN remain
        int skipCount = Math.max(0, history.size() - lastN);
        return history.stream().skip(skipCount).toList();
    }

    /**
     * Removes the conversation's entry from the cache.
     */
    @Override
    public void clear(String conversationId) {
        CacheUtils.remove(CMS_CHAT_MSG_CACHE, conversationId);
    }
}

View File

@@ -0,0 +1,44 @@
/**
* Copyright (c) 2013-Now http://jeesite.com All rights reserved.
* No deletion without permission, or be held responsible to law.
*/
package com.jeesite.modules.cms.ai.service;
import com.jeesite.common.cache.CacheUtils;
import org.jetbrains.annotations.NotNull;
import org.springframework.ai.chat.memory.ChatMemoryRepository;
import org.springframework.ai.chat.messages.Message;
import org.springframework.stereotype.Service;
import java.util.List;
/**
 * AI conversation message repository backed by the {@code cmsChatMsgCache} cache.
 * Each conversation's message list is stored as one cache entry keyed by the
 * conversation id.
 * @author ThinkGem
 */
@Service
public class CacheChatMemoryRepository implements ChatMemoryRepository {

    private static final String CMS_CHAT_MSG_CACHE = "cmsChatMsgCache";

    /**
     * Lists the ids of the conversations present in the cache.
     * @return the cache keys converted to strings
     */
    @Override
    public @NotNull List<String> findConversationIds() {
        return CacheUtils.getCache(CMS_CHAT_MSG_CACHE).keys().stream().map(Object::toString).toList();
    }

    /**
     * Looks up the messages cached for a conversation.
     * @param conversationId id of the conversation to read
     * @return an unmodifiable snapshot of the cached messages, or an empty
     *         list when the conversation has no cache entry
     */
    @Override
    public @NotNull List<Message> findByConversationId(@NotNull String conversationId) {
        List<Message> all = CacheUtils.get(CMS_CHAT_MSG_CACHE, conversationId);
        // Hand out a snapshot rather than the cached instance, so a caller that
        // modifies the result cannot alter the stored history behind our back.
        return all != null ? List.copyOf(all) : List.of();
    }

    /**
     * Puts the given message list into the cache under the conversation id.
     * @param conversationId id of the conversation to write
     * @param messages the messages to store for that conversation
     */
    @Override
    public void saveAll(@NotNull String conversationId, @NotNull List<Message> messages) {
        CacheUtils.put(CMS_CHAT_MSG_CACHE, conversationId, messages);
    }

    /**
     * Removes the conversation's entry from the cache.
     * @param conversationId id of the conversation to delete
     */
    @Override
    public void deleteByConversationId(@NotNull String conversationId) {
        CacheUtils.remove(CMS_CHAT_MSG_CACHE, conversationId);
    }
}

View File

@@ -59,7 +59,7 @@ public class CmsAiChatService extends BaseService {
* @author ThinkGem
*/
public List<Message> getChatMessage(String conversationId) {
return chatMemory.get(conversationId, 100);
return chatMemory.get(conversationId);
}
private static String getChatCacheKey() {
@@ -119,7 +119,9 @@ public class CmsAiChatService extends BaseService {
new UserMessage(StringUtils.replaceEach(message, USER_MESSAGE_SEARCH, USER_MESSAGE_REPLACE))
)
.advisors(
new MessageChatMemoryAdvisor(chatMemory, conversationId, 1024),
MessageChatMemoryAdvisor.builder(chatMemory)
.conversationId(conversationId)
.build(),
QuestionAnswerAdvisor.builder(vectorStore)
.searchRequest(SearchRequest.builder().similarityThreshold(0.6F).topK(6).build())
.promptTemplate(new PromptTemplate(properties.getDefaultPromptTemplate()))

View File

@@ -61,8 +61,8 @@ spring:
host: http://testserver
port: 8000
initialize-schema: true
collection-name: vector_store
#collection-name: vector_store_1024
# collection-name: vector_store
collection-name: vector_store_1024
# Postgresql 向量数据库(PG 连接配置,见下文,需要手动建表)【请在 pom.xml 中打开 pgvector 的注释,并注释上其它向量库】
pgvector: