升级 Spring AI 1.0.0-RC1、Chroma 1.0.0
This commit is contained in:
@@ -7,8 +7,11 @@ package com.jeesite.modules.cms.ai.config;
|
||||
import com.jeesite.common.datasource.DataSourceHolder;
|
||||
import com.jeesite.common.lang.StringUtils;
|
||||
import com.jeesite.modules.cms.ai.properties.CmsAiProperties;
|
||||
import com.jeesite.modules.cms.ai.service.CacheChatMemoryRepository;
|
||||
import com.jeesite.modules.cms.ai.tools.CmsAiTools;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -39,6 +42,18 @@ public class CmsAiChatConfig {
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
/**
|
||||
* 聊天对话数据存储
|
||||
* @author ThinkGem
|
||||
*/
|
||||
@Bean
|
||||
public ChatMemory chatMemory(CacheChatMemoryRepository cacheChatMemoryRepository) {
|
||||
return MessageWindowChatMemory.builder()
|
||||
.chatMemoryRepository(cacheChatMemoryRepository)
|
||||
.maxMessages(1024)
|
||||
.build();
|
||||
}
|
||||
|
||||
// @Bean
|
||||
// public BatchingStrategy batchingStrategy() {
|
||||
// return new TokenCountBatchingStrategy(EncodingType.CL100K_BASE, Integer.MAX_VALUE, 0.1);
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.jeesite.common.lang.StringUtils;
|
||||
import com.jeesite.common.lang.TimeUtils;
|
||||
import com.jeesite.common.utils.PageUtils;
|
||||
import com.jeesite.common.web.http.HttpClientUtils;
|
||||
import com.jeesite.common.web.http.ServletUtils;
|
||||
import com.jeesite.modules.cms.entity.Article;
|
||||
import com.jeesite.modules.cms.service.ArticleVectorStore;
|
||||
import com.jeesite.modules.cms.utils.CmsUtils;
|
||||
@@ -21,11 +22,11 @@ import com.vladsch.flexmark.html2md.converter.FlexmarkHtmlConverter;
|
||||
import com.vladsch.flexmark.html2md.converter.HtmlLinkResolver;
|
||||
import com.vladsch.flexmark.html2md.converter.HtmlLinkResolverFactory;
|
||||
import com.vladsch.flexmark.html2md.converter.HtmlNodeConverterContext;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import org.apache.tika.Tika;
|
||||
import org.apache.tika.config.TikaConfig;
|
||||
import org.apache.tika.exception.TikaException;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
@@ -75,13 +76,33 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
|
||||
metadata.put("updateBy", article.getUpdateBy());
|
||||
metadata.put("updateDate", article.getUpdateDate());
|
||||
List<String> attachmentList = ListUtils.newArrayList();
|
||||
HtmlLinkResolverFactory linkResolverFactory = new HtmlLinkResolverFactory() {
|
||||
String content = article.getTitle() + ", " + article.getKeywords() + ", "
|
||||
+ article.getDescription() + ", " + FlexmarkHtmlConverter.builder()
|
||||
.linkResolverFactory(getHtmlLinkResolverFactory(attachmentList)).build()
|
||||
.convert(article.getArticleData().getContent())
|
||||
+ ", attachment: " + attachmentList;
|
||||
List<Document> documents = List.of(new Document(article.getId(), content, metadata));
|
||||
List<Document> splitDocuments = new TokenTextSplitter().apply(documents);
|
||||
this.delete(article); // 删除原数据
|
||||
ListUtils.pageList(splitDocuments, 10, params -> {
|
||||
vectorStore.add((List<Document>)params[0]); // 增加新数据
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析文章中的连接并提取内容
|
||||
* @author ThinkGem
|
||||
*/
|
||||
private @NotNull HtmlLinkResolverFactory getHtmlLinkResolverFactory(List<String> attachmentList) {
|
||||
HttpServletRequest request = ServletUtils.getRequest();
|
||||
return new HtmlLinkResolverFactory() {
|
||||
@Override
|
||||
public @Nullable Set<Class<?>> getAfterDependents() {
|
||||
public @NotNull Set<Class<?>> getAfterDependents() {
|
||||
return Set.of();
|
||||
}
|
||||
@Override
|
||||
public @Nullable Set<Class<?>> getBeforeDependents() {
|
||||
public @NotNull Set<Class<?>> getBeforeDependents() {
|
||||
return Set.of();
|
||||
}
|
||||
@Override
|
||||
@@ -94,11 +115,16 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
|
||||
if ("a".equalsIgnoreCase(node.nodeName())) {
|
||||
String href = node.attributes().get("href"); String url = href;
|
||||
if (StringUtils.contains(url, "://")) {
|
||||
try (InputStream is = HttpClientUtils.getInputStream(url, null)) {
|
||||
String text = getDocumentText(is);
|
||||
attachmentList.add(url + text);
|
||||
} catch (IOException | TikaException e) {
|
||||
logger.error(e.getMessage(), e);
|
||||
// 只提取系统允许跳转的附件内容,外部网站内容不进行提取,shiro.allowRedirects 参数设置范围
|
||||
if (ServletUtils.isAllowRedirects(request, url)) {
|
||||
try (InputStream is = HttpClientUtils.getInputStream(url, null)) {
|
||||
if (is != null) {
|
||||
String text = getDocumentText(is);
|
||||
attachmentList.add(url + text);
|
||||
}
|
||||
} catch (IOException | TikaException e) {
|
||||
logger.error(e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
String ctxPath = Global.getCtxPath();
|
||||
@@ -106,8 +132,10 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
|
||||
url = url.substring(ctxPath.length());
|
||||
}
|
||||
try (InputStream is = IOUtils.getFileInputStream(Global.getUserfilesBaseDir(url))){
|
||||
String text = getDocumentText(is);
|
||||
attachmentList.add(url + text);
|
||||
if (is != null) {
|
||||
String text = getDocumentText(is);
|
||||
attachmentList.add(url + text);
|
||||
}
|
||||
} catch (IOException | TikaException e) {
|
||||
logger.error(e.getMessage(), e);
|
||||
}
|
||||
@@ -130,18 +158,6 @@ public class ArticleVectorStoreImpl implements ArticleVectorStore {
|
||||
.orElse(StringUtils.EMPTY);
|
||||
}
|
||||
};
|
||||
String content = article.getTitle() + ", " + article.getKeywords() + ", "
|
||||
+ article.getDescription() + ", " + FlexmarkHtmlConverter.builder()
|
||||
.linkResolverFactory(linkResolverFactory).build()
|
||||
.convert(article.getArticleData().getContent())
|
||||
+ ", attachment: " + attachmentList;
|
||||
List<Document> documents = List.of(new Document(article.getId(), content, metadata));
|
||||
List<Document> splitDocuments = new TokenTextSplitter().apply(documents);
|
||||
this.delete(article); // 删除原数据
|
||||
ListUtils.pageList(splitDocuments, 64, params -> {
|
||||
vectorStore.add((List<Document>)params[0]); // 增加新数据
|
||||
return null;
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
/**
|
||||
* Copyright (c) 2013-Now http://jeesite.com All rights reserved.
|
||||
* No deletion without permission, or be held responsible to law.
|
||||
*/
|
||||
package com.jeesite.modules.cms.ai.service;
|
||||
|
||||
import com.jeesite.common.cache.CacheUtils;
|
||||
import com.jeesite.common.collect.ListUtils;
|
||||
import org.springframework.ai.chat.memory.ChatMemory;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* AI 对话消息存储
|
||||
* @author ThinkGem
|
||||
*/
|
||||
@Service
|
||||
public class CacheChatMemory implements ChatMemory {
|
||||
|
||||
private static final String CMS_CHAT_MSG_CACHE = "cmsChatMsgCache";
|
||||
|
||||
@Override
|
||||
public void add(String conversationId, List<Message> messages) {
|
||||
List<Message> conversationHistory = CacheUtils.get(CMS_CHAT_MSG_CACHE, conversationId);
|
||||
if (conversationHistory == null) {
|
||||
conversationHistory = ListUtils.newArrayList();
|
||||
}
|
||||
conversationHistory.addAll(messages);
|
||||
CacheUtils.put(CMS_CHAT_MSG_CACHE, conversationId, conversationHistory);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> get(String conversationId, int lastN) {
|
||||
List<Message> all = CacheUtils.get(CMS_CHAT_MSG_CACHE, conversationId);
|
||||
return all != null ? all.stream().skip(Math.max(0, all.size() - lastN)).toList() : List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void clear(String conversationId) {
|
||||
CacheUtils.remove(CMS_CHAT_MSG_CACHE, conversationId);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright (c) 2013-Now http://jeesite.com All rights reserved.
|
||||
* No deletion without permission, or be held responsible to law.
|
||||
*/
|
||||
package com.jeesite.modules.cms.ai.service;
|
||||
|
||||
import com.jeesite.common.cache.CacheUtils;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.springframework.ai.chat.memory.ChatMemoryRepository;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* AI 对话消息存储
|
||||
* @author ThinkGem
|
||||
*/
|
||||
@Service
|
||||
public class CacheChatMemoryRepository implements ChatMemoryRepository {
|
||||
|
||||
private static final String CMS_CHAT_MSG_CACHE = "cmsChatMsgCache";
|
||||
|
||||
@Override
|
||||
public @NotNull List<String> findConversationIds() {
|
||||
return CacheUtils.getCache(CMS_CHAT_MSG_CACHE).keys().stream().map(Object::toString).toList();
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<Message> findByConversationId(@NotNull String conversationId) {
|
||||
List<Message> all = CacheUtils.get(CMS_CHAT_MSG_CACHE, conversationId);
|
||||
return all != null ? all : List.of();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveAll(@NotNull String conversationId, @NotNull List<Message> messages) {
|
||||
CacheUtils.put(CMS_CHAT_MSG_CACHE, conversationId, messages);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteByConversationId(@NotNull String conversationId) {
|
||||
CacheUtils.remove(CMS_CHAT_MSG_CACHE, conversationId);
|
||||
}
|
||||
}
|
||||
@@ -59,7 +59,7 @@ public class CmsAiChatService extends BaseService {
|
||||
* @author ThinkGem
|
||||
*/
|
||||
public List<Message> getChatMessage(String conversationId) {
|
||||
return chatMemory.get(conversationId, 100);
|
||||
return chatMemory.get(conversationId);
|
||||
}
|
||||
|
||||
private static String getChatCacheKey() {
|
||||
@@ -119,7 +119,9 @@ public class CmsAiChatService extends BaseService {
|
||||
new UserMessage(StringUtils.replaceEach(message, USER_MESSAGE_SEARCH, USER_MESSAGE_REPLACE))
|
||||
)
|
||||
.advisors(
|
||||
new MessageChatMemoryAdvisor(chatMemory, conversationId, 1024),
|
||||
MessageChatMemoryAdvisor.builder(chatMemory)
|
||||
.conversationId(conversationId)
|
||||
.build(),
|
||||
QuestionAnswerAdvisor.builder(vectorStore)
|
||||
.searchRequest(SearchRequest.builder().similarityThreshold(0.6F).topK(6).build())
|
||||
.promptTemplate(new PromptTemplate(properties.getDefaultPromptTemplate()))
|
||||
|
||||
@@ -61,8 +61,8 @@ spring:
|
||||
host: http://testserver
|
||||
port: 8000
|
||||
initialize-schema: true
|
||||
collection-name: vector_store
|
||||
#collection-name: vector_store_1024
|
||||
# collection-name: vector_store
|
||||
collection-name: vector_store_1024
|
||||
|
||||
# Postgresql 向量数据库(PG 连接配置,见下文,需要手动建表)【请在 pom.xml 中打开 pgvector 的注释,并注释上其它向量库】
|
||||
pgvector:
|
||||
|
||||
Reference in New Issue
Block a user