From 441f26ff3c0fdf5281b6dad8c670a5bf91c1116a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E5=A5=B6=E5=A5=B6=E8=8A=B1=E7=94=9F=E7=B1=B3?= <279094122@qq.com> Date: Mon, 21 Jul 2025 17:20:50 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E6=B7=BB=E5=8A=A0=20PostgreSQL=20p?= =?UTF-8?q?gvector=E7=9F=A2=E9=87=8F=E5=AD=98=E5=82=A8=E5=BA=93=E6=94=AF?= =?UTF-8?q?=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 实现了 PgVectorRepository 类,提供对 PostgreSQL pgvector 扩展的支持 - 添加了元数据字段定义和过滤表达式转换功能- 实现了文档的插入、删除、存在性检查和搜索功能 - 编写了详细的单元测试用例,验证了各种查询和过滤场景 --- .../solon-ai-repo-pgvector/README.md | 185 +++++++ .../solon-ai-repo-pgvector/docker-compose.yml | 32 ++ .../solon-ai-repo-pgvector/pom.xml | 26 + .../ai/rag/repository/PgVectorRepository.java | 495 ++++++++++++++++++ .../pgvector/FilterTransformer.java | 196 +++++++ .../repository/pgvector/MetadataField.java | 70 +++ .../repo/pgvector/PgVectorRepositoryTest.java | 395 ++++++++++++++ 7 files changed, 1399 insertions(+) create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-pgvector/README.md create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-pgvector/docker-compose.yml create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/PgVectorRepository.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/FilterTransformer.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/MetadataField.java create mode 100644 solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/test/java/features/ai/repo/pgvector/PgVectorRepositoryTest.java diff --git a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/README.md b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/README.md new file mode 100644 index 00000000..75a82bf8 --- /dev/null +++ b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/README.md @@ -0,0 +1,185 @@ +# solon-ai-repo-pgvector + +基于 PostgreSQL pgvector 扩展的向量存储知识库实现。 + +## 功能特性 + +- 支持向量相似度搜索 +- 支持元数据字段索引和过滤 +- 支持批量操作 +- 支持连接池管理 +- 支持自定义表名和字段配置 +- 支持 UPSERT 操作 + +## 环境要求 + +- PostgreSQL 11+ +- pgvector 扩展 +- Java 8+ + +## 快速开始 + +### 1. 启动 PostgreSQL 和 pgvector + +使用 Docker Compose 快速启动: + +```bash +docker-compose up -d +``` + +或者手动安装 pgvector 扩展: + +```sql +-- 安装 pgvector 扩展 +CREATE EXTENSION IF NOT EXISTS vector; +``` + +### 2. 添加依赖 + +```xml + + org.noear + solon-ai-repo-pgvector + 3.4.0 + +``` + +### 3. 配置数据库连接 + +在 `app.yml` 中配置: + +```yaml +solon: + ai: + repo: + pgvector: + jdbcUrl: "jdbc:postgresql://localhost:5432/solon_ai_test" + username: "postgres" + password: "password" +``` + +### 4. 使用示例 + +```java +import org.noear.solon.ai.embedding.EmbeddingModel; +import org.noear.solon.ai.rag.repository.PgVectorRepository; +import org.noear.solon.ai.rag.repository.pgvector.MetadataField; +import org.noear.solon.ai.rag.Document; + +// 创建 EmbeddingModel +EmbeddingModel embeddingModel = new EmbeddingModel(properties); + +// 定义元数据字段 +List metadataFields = new ArrayList<>(); +metadataFields.add(MetadataField.text("title")); +metadataFields.add(MetadataField.text("category")); +metadataFields.add(MetadataField.numeric("price")); + +// 创建 Repository +PgVectorRepository repository = PgVectorRepository.builder( + embeddingModel, + "jdbc:postgresql://localhost:5432/solon_ai_test", + "postgres", + "password" +) +.tableName("my_documents") +.metadataFields(metadataFields) +.maxPoolSize(20) +.build(); + +// 插入文档 +Document doc = new Document("这是一个关于人工智能的文档"); +doc.getMetadata().put("title", "AI"); +doc.getMetadata().put("category", "technology"); +doc.getMetadata().put("price", 100); + +repository.insert(doc); + +// 搜索文档 +List results = repository.search("人工智能"); + +// 使用过滤条件搜索 +String filterExpression = "category == 'technology' AND price > 50"; +List filteredResults = repository.search( + new QueryCondition("人工智能") + .filterExpression(filterExpression) + .limit(10) +); +``` + +## 配置选项 + +### Builder 配置 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| tableName | String | "solon_ai_documents" | 表名 | +| metadataFields | List | [] | 元数据字段定义 | +| maxPoolSize | int | 10 | 最大连接池大小 | +| minIdle | int | 2 | 最小空闲连接数 | +| connectionTimeout | long | 30000 | 连接超时时间(ms) | +| idleTimeout | long | 600000 | 空闲超时时间(ms) | +| maxLifetime | long | 1800000 | 最大生命周期(ms) | + +### 元数据字段类型 + +- `TEXT`: 文本字段,支持字符串比较和模糊匹配 +- `NUMERIC`: 数值字段,支持数值比较和范围查询 +- `JSON`: JSON 字段,支持复杂数据结构 + +## 过滤表达式语法 + +支持以下操作符: + +- 比较操作符:`==`, `!=`, `>`, `>=`, `<`, `<=` +- 逻辑操作符:`AND`, `OR`, `NOT` +- 集合操作符:`IN`, `NOT IN` + +示例: + +```java +// 简单比较 +"title == 'AI'" + +// 数值比较 +"price > 100" + +// 逻辑组合 +"category == 'technology' AND price > 50" + +// 集合操作 +"title IN ['AI', 'ML', 'DL']" + +// 复杂表达式 +"(category == 'technology' OR category == 'science') AND price > 50" +``` + +## 性能优化 + +1. **索引优化**:自动创建向量索引,支持余弦相似度搜索 +2. **连接池**:使用 HikariCP 连接池,提高连接效率 +3. **批量操作**:支持批量插入和删除 +4. **参数化查询**:使用 PreparedStatement 防止 SQL 注入 + +## 注意事项 + +1. 确保 PostgreSQL 已安装 pgvector 扩展 +2. 向量维度需要与 EmbeddingModel 的维度一致 +3. 大量数据插入时建议使用批量操作 +4. 定期维护数据库索引以保持查询性能 + +## 测试 + +运行测试前需要启动 PostgreSQL 服务: + +```bash +# 启动测试数据库 +docker-compose up -d + +# 运行测试 +mvn test +``` + +## 许可证 + +Apache License 2.0 \ No newline at end of file diff --git a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/docker-compose.yml b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/docker-compose.yml new file mode 100644 index 00000000..17ce78db --- /dev/null +++ b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/docker-compose.yml @@ -0,0 +1,32 @@ +version: '3.8' + +services: + postgres: + image: pgvector/pgvector:pg16 + container_name: solon-ai-pgvector + environment: + POSTGRES_DB: solon_ai_test + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + command: > + postgres + -c shared_preload_libraries=vector + -c max_connections=100 + -c shared_buffers=256MB + -c effective_cache_size=1GB + -c maintenance_work_mem=64MB + -c checkpoint_completion_target=0.9 + -c wal_buffers=16MB + -c default_statistics_target=100 + -c random_page_cost=1.1 + -c effective_io_concurrency=200 + -c work_mem=4MB + -c min_wal_size=1GB + -c max_wal_size=4GB + +volumes: + postgres_data: \ No newline at end of file diff --git a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/pom.xml b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/pom.xml index 2ec9252b..9e2fcefa 100644 --- a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/pom.xml +++ b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/pom.xml @@ -21,6 +21,20 @@ solon-ai + + + org.postgresql + postgresql + 42.6.0 + + + + + com.zaxxer + HikariCP + 4.0.3 + + org.noear solon-test @@ -32,6 +46,18 @@ solon-logging-simple test + + + org.noear + solon-ai-load-markdown + test + + + + org.noear + solon-ai-load-html + test + diff --git a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/PgVectorRepository.java b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/PgVectorRepository.java new file mode 100644 index 00000000..31097390 --- /dev/null +++ b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/PgVectorRepository.java @@ -0,0 +1,495 @@ +/* + * Copyright 2017-2025 noear.org and authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.noear.solon.ai.rag.repository; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import org.noear.snack.ONode; +import org.noear.solon.Utils; +import org.noear.solon.ai.embedding.EmbeddingModel; +import org.noear.solon.ai.rag.Document; +import org.noear.solon.ai.rag.RepositoryLifecycle; +import org.noear.solon.ai.rag.RepositoryStorable; +import org.noear.solon.ai.rag.repository.pgvector.FilterTransformer; +import org.noear.solon.ai.rag.repository.pgvector.MetadataField; +import org.noear.solon.ai.rag.util.ListUtil; +import org.noear.solon.ai.rag.util.QueryCondition; +import org.noear.solon.ai.rag.util.SimilarityUtil; + +import javax.sql.DataSource; +import java.io.IOException; +import java.sql.*; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * PostgreSQL pgvector 矢量存储知识库 + * + * @author 小奶奶花生米 + */ +public class PgVectorRepository implements RepositoryStorable, RepositoryLifecycle { + private final Builder config; + private final DataSource dataSource; + + /** + * 私有构造函数,通过 Builder 模式创建 + */ + private PgVectorRepository(Builder config) { + this.config = config; + this.dataSource = createDataSource(); + try { + initRepository(); + } catch (Exception e) { + throw new RuntimeException("Failed to initialize pgvector repository", e); + } + } + + /** + * 创建数据源 + */ + private DataSource createDataSource() { + HikariConfig hikariConfig = new HikariConfig(); + hikariConfig.setJdbcUrl(config.jdbcUrl); + hikariConfig.setUsername(config.username); + hikariConfig.setPassword(config.password); + hikariConfig.setMaximumPoolSize(config.maxPoolSize); + hikariConfig.setMinimumIdle(config.minIdle); + hikariConfig.setConnectionTimeout(config.connectionTimeout); + hikariConfig.setIdleTimeout(config.idleTimeout); + hikariConfig.setMaxLifetime(config.maxLifetime); + + return new HikariDataSource(hikariConfig); + } + + /** + * 初始化仓库 + */ + @Override + public void initRepository() throws Exception { + try (Connection conn = dataSource.getConnection()) { + // 确保 pgvector 扩展已安装 + ensurePgVectorExtension(conn); + + // 检查表是否存在 + if (!tableExists(conn, config.tableName)) { + createTable(conn); + } + } catch (SQLException e) { + throw new Exception("Failed to initialize pgvector repository", e); + } catch (IOException e) { + throw new Exception("Failed to initialize pgvector repository", e); + } + } + + /** + * 确保 pgvector 扩展已安装 + */ + private void ensurePgVectorExtension(Connection conn) throws SQLException { + try (Statement stmt = conn.createStatement()) { + stmt.execute("CREATE EXTENSION IF NOT EXISTS vector"); + } + } + + /** + * 检查表是否存在 + */ + private boolean tableExists(Connection conn, String tableName) throws SQLException { + String sql = "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = ?)"; + try (PreparedStatement stmt = conn.prepareStatement(sql)) { + stmt.setString(1, tableName); + try (ResultSet rs = stmt.executeQuery()) { + return rs.next() && rs.getBoolean(1); + } + } + } + + /** + * 创建表 + */ + private void createTable(Connection conn) throws SQLException, IOException { + StringBuilder sql = new StringBuilder(); + sql.append("CREATE TABLE ").append(config.tableName).append(" ("); + sql.append("id VARCHAR(255) PRIMARY KEY,"); + sql.append("content TEXT NOT NULL,"); + sql.append("embedding VECTOR(").append(config.embeddingModel.dimensions()).append("),"); + sql.append("metadata JSONB,"); + sql.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP"); + + // 添加元数据字段 + if (Utils.isNotEmpty(config.metadataFields)) { + for (MetadataField field : config.metadataFields) { + switch (field.getFieldType()) { + case TEXT: + sql.append(", \"").append(field.getName()).append("\" TEXT"); + break; + case NUMERIC: + sql.append(", \"").append(field.getName()).append("\" NUMERIC"); + break; + case JSON: + sql.append(", \"").append(field.getName()).append("\" JSONB"); + break; + } + } + } + + sql.append(")"); + + try (Statement stmt = conn.createStatement()) { + stmt.execute(sql.toString()); + } + + // 创建向量索引 + String indexSql = String.format( + "CREATE INDEX IF NOT EXISTS idx_%s_embedding ON %s USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100)", + config.tableName, config.tableName + ); + + try (Statement stmt = conn.createStatement()) { + stmt.execute(indexSql); + } + } + + /** + * 注销仓库 + */ + @Override + public void dropRepository() { + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + config.tableName); + } catch (SQLException e) { + throw new RuntimeException("Failed to drop pgvector repository", e); + } + } + + /** + * 存储文档列表 + */ + @Override + public void insert(List documents) throws IOException { + if (documents == null || documents.isEmpty()) { + return; + } + + for (List batch : ListUtil.partition(documents, config.embeddingModel.batchSize())) { + config.embeddingModel.embed(batch); + + try (Connection conn = dataSource.getConnection()) { + insertBatch(conn, batch); + } catch (SQLException e) { + throw new IOException("Failed to insert documents", e); + } + } + } + + /** + * 批量插入文档 + */ + private void insertBatch(Connection conn, List documents) throws SQLException { + StringBuilder sql = new StringBuilder(); + sql.append("INSERT INTO ").append(config.tableName).append(" (id, content, embedding, metadata"); + + // 添加元数据字段 + if (Utils.isNotEmpty(config.metadataFields)) { + for (MetadataField field : config.metadataFields) { + sql.append(", \"").append(field.getName()).append("\""); + } + } + + sql.append(") VALUES (?, ?, ?, ?"); + + // 添加元数据字段占位符 + if (Utils.isNotEmpty(config.metadataFields)) { + for (int i = 0; i < config.metadataFields.size(); i++) { + sql.append(", ?"); + } + } + + sql.append(") ON CONFLICT (id) DO UPDATE SET "); + sql.append("content = EXCLUDED.content,"); + sql.append("embedding = EXCLUDED.embedding,"); + sql.append("metadata = EXCLUDED.metadata"); + + // 添加元数据字段更新 + if (Utils.isNotEmpty(config.metadataFields)) { + for (MetadataField field : config.metadataFields) { + sql.append(", \"").append(field.getName()).append("\" = EXCLUDED.\"").append(field.getName()).append("\""); + } + } + + try (PreparedStatement stmt = conn.prepareStatement(sql.toString())) { + for (Document doc : documents) { + if (doc.getId() == null) { + doc.id(Utils.uuid()); + } + + int paramIndex = 1; + stmt.setString(paramIndex++, doc.getId()); + stmt.setString(paramIndex++, doc.getContent()); + stmt.setArray(paramIndex++, conn.createArrayOf("float4", toFloatArray(doc.getEmbedding()))); + stmt.setObject(paramIndex++, ONode.stringify(doc.getMetadata()), Types.OTHER); + + // 设置元数据字段 + if (Utils.isNotEmpty(config.metadataFields)) { + for (MetadataField field : config.metadataFields) { + Object value = doc.getMetadata().get(field.getName()); + if (value != null) { + switch (field.getFieldType()) { + case TEXT: + stmt.setString(paramIndex++, value.toString()); + break; + case NUMERIC: + if (value instanceof Number) { + stmt.setBigDecimal(paramIndex++, new java.math.BigDecimal(value.toString())); + } else { + stmt.setNull(paramIndex++, Types.NUMERIC); + } + break; + case JSON: + stmt.setObject(paramIndex++, ONode.stringify(value), Types.OTHER); + break; + } + } else { + // 根据字段类型设置正确的 NULL 类型 + switch (field.getFieldType()) { + case TEXT: + stmt.setNull(paramIndex++, Types.VARCHAR); + break; + case NUMERIC: + stmt.setNull(paramIndex++, Types.NUMERIC); + break; + case JSON: + stmt.setNull(paramIndex++, Types.OTHER); + break; + } + } + } + } + + stmt.addBatch(); + } + stmt.executeBatch(); + }catch (Exception e){ + throw new SQLException("Failed to insert documents", e); + } + } + + /** + * 删除指定 ID 的文档 + */ + @Override + public void delete(String... ids) throws IOException { + if (ids == null || ids.length == 0) { + return; + } + + String sql = "DELETE FROM " + config.tableName + " WHERE id = ANY(?)"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(sql)) { + stmt.setArray(1, conn.createArrayOf("varchar", ids)); + stmt.executeUpdate(); + } catch (SQLException e) { + throw new IOException("Failed to delete documents", e); + } + } + + /** + * 检查文档是否存在 + */ + @Override + public boolean exists(String id) throws IOException { + String sql = "SELECT COUNT(*) FROM " + config.tableName + " WHERE id = ?"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(sql)) { + stmt.setString(1, id); + try (ResultSet rs = stmt.executeQuery()) { + return rs.next() && rs.getInt(1) > 0; + } + } catch (SQLException e) { + throw new IOException("Failed to check document existence", e); + } + } + + /** + * 搜索文档 + */ + @Override + public List search(QueryCondition condition) throws IOException { + float[] queryEmbedding = config.embeddingModel.embed(condition.getQuery()); + + StringBuilder sql = new StringBuilder(); + sql.append("SELECT id, content, metadata, 1 - (embedding <=> ?::vector) as similarity "); + sql.append("FROM ").append(config.tableName); + + // 添加过滤条件 + String filterClause = FilterTransformer.getInstance().transform(condition.getFilterExpression()); + if (Utils.isNotEmpty(filterClause)) { + sql.append(" WHERE ").append(filterClause); + } + + sql.append(" ORDER BY embedding <=> ?::vector"); + sql.append(" LIMIT ?"); + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(sql.toString())) { + + int paramIndex = 1; + stmt.setArray(paramIndex++, conn.createArrayOf("float4", toFloatArray(queryEmbedding))); + stmt.setArray(paramIndex++, conn.createArrayOf("float4", toFloatArray(queryEmbedding))); + stmt.setInt(paramIndex++, condition.getLimit()); + + List results = new ArrayList<>(); + try (ResultSet rs = stmt.executeQuery()) { + while (rs.next()) { + String id = rs.getString("id"); + String content = rs.getString("content"); + String metadataJson = rs.getString("metadata"); + double similarity = rs.getDouble("similarity"); + + Map metadata = new HashMap<>(); + if (Utils.isNotEmpty(metadataJson)) { + metadata = ONode.deserialize(metadataJson, Map.class); + } + + Document doc = new Document(id, content, metadata, similarity); + results.add(doc); + } + } + + return SimilarityUtil.refilter(results.stream(), condition); + } catch (SQLException e) { + throw new IOException("Failed to search documents", e); + } + } + + /** + * 将 float 数组转换为 Float 数组 + */ + private Float[] toFloatArray(float[] array) { + if (array == null) { + return new Float[0]; + } + Float[] result = new Float[array.length]; + for (int i = 0; i < array.length; i++) { + result[i] = array[i]; + } + return result; + } + + /** + * 创建 PgVectorRepository 构建器 + */ + public static Builder builder(EmbeddingModel embeddingModel, String jdbcUrl, String username, String password) { + return new Builder(embeddingModel, jdbcUrl, username, password); + } + + /** + * Builder 类用于链式构建 PgVectorRepository + */ + public static class Builder { + // 必需参数 + private final EmbeddingModel embeddingModel; + private final String jdbcUrl; + private final String username; + private final String password; + + // 可选参数,设置默认值 + private String tableName = "solon_ai"; + private List metadataFields = new ArrayList<>(); + private int maxPoolSize = 10; + private int minIdle = 2; + private long connectionTimeout = 30000; + private long idleTimeout = 600000; + private long maxLifetime = 1800000; + + /** + * 构造器 + */ + public Builder(EmbeddingModel embeddingModel, String jdbcUrl, String username, String password) { + this.embeddingModel = embeddingModel; + this.jdbcUrl = jdbcUrl; + this.username = username; + this.password = password; + } + + /** + * 设置表名 + */ + public Builder tableName(String tableName) { + this.tableName = tableName; + return this; + } + + /** + * 设置元数据字段 + */ + public Builder metadataFields(List metadataFields) { + this.metadataFields = metadataFields; + return this; + } + + /** + * 设置最大连接池大小 + */ + public Builder maxPoolSize(int maxPoolSize) { + this.maxPoolSize = maxPoolSize; + return this; + } + + /** + * 设置最小空闲连接数 + */ + public Builder minIdle(int minIdle) { + this.minIdle = minIdle; + return this; + } + + /** + * 设置连接超时时间 + */ + public Builder connectionTimeout(long connectionTimeout) { + this.connectionTimeout = connectionTimeout; + return this; + } + + /** + * 设置空闲超时时间 + */ + public Builder idleTimeout(long idleTimeout) { + this.idleTimeout = idleTimeout; + return this; + } + + /** + * 设置最大生命周期 + */ + public Builder maxLifetime(long maxLifetime) { + this.maxLifetime = maxLifetime; + return this; + } + + /** + * 构建 PgVectorRepository 实例 + */ + public PgVectorRepository build() { + return new PgVectorRepository(this); + } + } +} \ No newline at end of file diff --git a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/FilterTransformer.java b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/FilterTransformer.java new file mode 100644 index 00000000..81da31d9 --- /dev/null +++ b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/FilterTransformer.java @@ -0,0 +1,196 @@ +/* + * Copyright 2017-2025 noear.org and authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.noear.solon.ai.rag.repository.pgvector; + +import org.noear.solon.expression.Expression; +import org.noear.solon.expression.Transformer; +import org.noear.solon.expression.snel.*; + +import java.util.Collection; + +/** + * pgvector 过滤转换器,将表达式转换为 SQL WHERE 子句 + * + * @author noear + * @since 3.1 + */ +public class FilterTransformer implements Transformer { + private static FilterTransformer instance = new FilterTransformer(); + + public static FilterTransformer getInstance() { + return instance; + } + + @Override + public String transform(Expression filterExpression) { + if (filterExpression == null) { + return null; + } + + try { + StringBuilder buf = new StringBuilder(); + parseFilterExpression(filterExpression, buf); + + if (buf.length() == 0) { + return null; + } + + return buf.toString(); + } catch (Exception e) { + System.err.println("Error processing filter expression: " + e.getMessage()); + return null; + } + } + + /** + * 解析QueryCondition中的filterExpression,转换为SQL WHERE子句 + * + * @param filterExpression 过滤表达式 + * @param buf 字符串构建器 + */ + private void parseFilterExpression(Expression filterExpression, StringBuilder buf) { + if (filterExpression == null) { + return; + } + + if (filterExpression instanceof VariableNode) { + // 变量节点,获取字段名 + String name = ((VariableNode) filterExpression).getName(); + buf.append("\"").append(name).append("\""); + } else if (filterExpression instanceof ConstantNode) { + ConstantNode node = (ConstantNode) filterExpression; + // 常量节点,获取值 + Object value = node.getValue(); + + if (node.isCollection()) { + // 集合使用 IN 语法 + buf.append("("); + boolean first = true; + for (Object item : (Collection) value) { + if (!first) { + buf.append(", "); + } + if (item instanceof String) { + buf.append("'").append(item.toString().replace("'", "''")).append("'"); + } else { + buf.append(item); + } + first = false; + } + buf.append(")"); + } else if (value instanceof String) { + // 字符串值使用单引号 + buf.append("'").append(value.toString().replace("'", "''")).append("'"); + } else { + buf.append(value); + } + } else if (filterExpression instanceof ComparisonNode) { + ComparisonNode node = (ComparisonNode) filterExpression; + ComparisonOp operator = node.getOperator(); + Expression left = node.getLeft(); + Expression right = node.getRight(); + + // 比较节点 + switch (operator) { + case eq: + parseFilterExpression(left, buf); + buf.append(" = "); + parseFilterExpression(right, buf); + break; + case neq: + parseFilterExpression(left, buf); + buf.append(" != "); + parseFilterExpression(right, buf); + break; + case gt: + parseFilterExpression(left, buf); + buf.append(" > "); + parseFilterExpression(right, buf); + break; + case gte: + parseFilterExpression(left, buf); + buf.append(" >= "); + parseFilterExpression(right, buf); + break; + case lt: + parseFilterExpression(left, buf); + buf.append(" < "); + parseFilterExpression(right, buf); + break; + case lte: + parseFilterExpression(left, buf); + buf.append(" <= "); + parseFilterExpression(right, buf); + break; + case in: + parseFilterExpression(left, buf); + buf.append(" IN "); + parseFilterExpression(right, buf); + break; + case nin: + parseFilterExpression(left, buf); + buf.append(" NOT IN "); + parseFilterExpression(right, buf); + break; + default: + parseFilterExpression(left, buf); + buf.append(" = "); + parseFilterExpression(right, buf); + break; + } + } else if (filterExpression instanceof LogicalNode) { + LogicalNode node = (LogicalNode) filterExpression; + LogicalOp operator = node.getOperator(); + Expression left = node.getLeft(); + Expression right = node.getRight(); + + buf.append("("); + + if (right != null) { + // 二元操作符 (AND, OR) + parseFilterExpression(left, buf); + + switch (operator) { + case AND: + buf.append(" AND "); + break; + case OR: + buf.append(" OR "); + break; + default: + // 其他操作符,默认用 AND + buf.append(" AND "); + break; + } + + parseFilterExpression(right, buf); + } else { + // 一元操作符 (NOT) + switch (operator) { + case NOT: + buf.append("NOT "); + break; + default: + // 其他一元操作符,不添加前缀 + break; + } + parseFilterExpression(left, buf); + } + + buf.append(")"); + } + } +} \ No newline at end of file diff --git a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/MetadataField.java b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/MetadataField.java new file mode 100644 index 00000000..450d976a --- /dev/null +++ b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/main/java/org/noear/solon/ai/rag/repository/pgvector/MetadataField.java @@ -0,0 +1,70 @@ +/* + * Copyright 2017-2025 noear.org and authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.noear.solon.ai.rag.repository.pgvector; + +/** + * pgvector 元数据字段定义 + * + * @author noear + * @since 3.1 + */ +public class MetadataField { + + private String name; + private FieldType fieldType; + + public static MetadataField text(String name) { + return new MetadataField(name, FieldType.TEXT); + } + + public static MetadataField numeric(String name) { + return new MetadataField(name, FieldType.NUMERIC); + } + + public static MetadataField json(String name) { + return new MetadataField(name, FieldType.JSON); + } + + public MetadataField(String name, FieldType fieldType) { + this.name = name; + this.fieldType = fieldType; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public FieldType getFieldType() { + return fieldType; + } + + public void setFieldType(FieldType fieldType) { + this.fieldType = fieldType; + } + + /** + * 字段类型枚举 + */ + public enum FieldType { + TEXT, + NUMERIC, + JSON + } +} \ No newline at end of file diff --git a/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/test/java/features/ai/repo/pgvector/PgVectorRepositoryTest.java b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/test/java/features/ai/repo/pgvector/PgVectorRepositoryTest.java new file mode 100644 index 00000000..d4bc9979 --- /dev/null +++ b/solon-ai-rag-repositorys/solon-ai-repo-pgvector/src/test/java/features/ai/repo/pgvector/PgVectorRepositoryTest.java @@ -0,0 +1,395 @@ +package features.ai.repo.pgvector; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.noear.solon.ai.embedding.EmbeddingModel; +import org.noear.solon.ai.rag.Document; +import org.noear.solon.ai.rag.DocumentLoader; +import org.noear.solon.ai.rag.RepositoryStorable; +import org.noear.solon.ai.rag.loader.HtmlSimpleLoader; +import org.noear.solon.ai.rag.loader.MarkdownLoader; +import org.noear.solon.ai.rag.repository.PgVectorRepository; +import org.noear.solon.ai.rag.repository.pgvector.MetadataField; +import org.noear.solon.ai.rag.splitter.RegexTextSplitter; +import org.noear.solon.ai.rag.splitter.SplitterPipeline; +import org.noear.solon.ai.rag.splitter.TokenSizeTextSplitter; +import org.noear.solon.ai.rag.util.QueryCondition; +import org.noear.solon.net.http.HttpUtils; +import org.noear.solon.test.SolonTest; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * PgVectorRepository 测试类 + * + * @author noear + * @since 3.1 + */ +@SolonTest +public class PgVectorRepositoryTest { + private PgVectorRepository repository; + private EmbeddingModel embeddingModel; + final String apiUrl = "http://192.168.1.16:11434/api/embed"; + final String provider = "ollama"; + final String model = "bge-m3";// + + @BeforeEach + public void setup() throws Exception { + embeddingModel = EmbeddingModel.of(apiUrl).provider(provider).model(model).build(); + + // 从配置中获取数据库连接信息 + String jdbcUrl = "jdbc:postgresql://localhost:5432/solon_ai_test"; + String username = "postgres"; + String password = "test_123!"; + + // 创建元数据索引字段列表 + List metadataFields = new ArrayList<>(); + metadataFields.add(MetadataField.text("title")); + metadataFields.add(MetadataField.text("category")); + metadataFields.add(MetadataField.numeric("price")); + metadataFields.add(MetadataField.numeric("stock")); + + // 创建测试用的 Repository + repository = PgVectorRepository.builder(embeddingModel, jdbcUrl, username, password) + .tableName("test_documents") + .metadataFields(metadataFields) + .build(); + + // 清理并重新初始化 + repository.dropRepository(); + repository.initRepository(); + load(repository, "https://solon.noear.org/article/about?format=md"); + load(repository, "https://h5.noear.org/readme.htm"); + } + + @AfterEach + public void cleanup() { + if (repository != null) { + try { + repository.dropRepository(); + } catch (Exception e) { + // 忽略清理错误 + } + } + } + + @Test + public void testSearch() throws Exception { + + List list = repository.search("solon"); + assert list.size() >= 3;//可能3个(效果更好)或4个 + + list = repository.search("dubbo"); + assert list.isEmpty(); + + Document doc = new Document("Test content"); + repository.insert(Collections.singletonList(doc)); + String key = doc.getId(); + + Thread.sleep(1000); + assertTrue(repository.exists(key), "Document should exist after storing"); + + Thread.sleep(1000); + repository.delete(doc.getId()); + assertFalse(repository.exists(key), "Document should not exist after removal"); + } + + @Test + public void testRemove() { + // 准备并存储测试数据 + List documents = new ArrayList<>(); + Document doc = new Document("Document to be removed", new HashMap<>()); + documents.add(doc); + + try { + repository.insert(documents); + Thread.sleep(1000); + // 删除文档 + repository.delete(doc.getId()); + + Thread.sleep(1000); + // 验证文档已被删除 + assertFalse(repository.exists(doc.getId()), "文档应该已被删除"); + + } catch (Exception e) { + fail("测试过程中发生异常: " + e.getMessage()); + } + } + + @Test + public void testScoreOutput() throws IOException { + + try { + // 执行搜索查询 + QueryCondition condition = new QueryCondition("solon").disableRefilter(true); + List results = repository.search(condition); + + // 验证结果不为空 + assertFalse(results.isEmpty(), "搜索结果不应为空"); + + // 验证每个文档都有评分 + for (Document doc : results) { + assertTrue(doc.getScore() >= 0, "文档评分应该是非负数"); + + System.out.println("Document ID: " + doc.getId() + ", Score: " + doc.getScore()); + } + + // 验证评分排序(如果有多个结果) + if (results.size() > 1) { + double firstScore = results.get(0).getScore(); + double secondScore = results.get(1).getScore(); + + // 检查第一个结果的评分是否大于或等于第二个结果 + assertTrue(firstScore >= secondScore, "结果应该按评分降序排序"); + } + + // 打印所有结果的评分 + System.out.println("\n=== 评分测试结果 ==="); + for (Document doc : results) { + double score = doc.getScore(); + System.out.println("ID: " + doc.getId()); + System.out.println("Score: " + score); + System.out.println("Content: " + doc.getContent().substring(0, Math.min(50, doc.getContent().length())) + "..."); + System.out.println("---"); + } + + } catch (Exception e) { + fail("测试过程中发生异常: " + e.getMessage()); + } + } + + @Test + public void case1_search() throws Exception { + List list = repository.search("solon"); + assert list.size() == 4; + + List list2 = repository.search("temporal"); + assert list2.isEmpty(); + + /// ///////////////////////////// + + // 准备并存储文档,显式指定 ID + Document doc = new Document("Test content"); + repository.insert(Collections.singletonList(doc)); + String key = doc.getId(); + + // 验证存储成功 + assertTrue(repository.exists(key), "Document should exist after storing"); + + // 删除文档 + repository.delete(doc.getId()); + + // 验证删除成功 + assertFalse(repository.exists(key), "Document should not exist after removal"); + } + + @Test + public void case2_expression() throws Exception { + // 新增带有元数据的文档 + Document doc1 = new Document("Document about Solon framework"); + doc1.getMetadata().put("title", "solon"); + doc1.getMetadata().put("category", "framework"); + + Document doc2 = new Document("Document about Java settings"); + doc2.getMetadata().put("title", "设置"); + doc2.getMetadata().put("category", "tutorial"); + + Document doc3 = new Document("Document about Spring framework"); + doc3.getMetadata().put("title", "spring"); + doc3.getMetadata().put("category", "framework"); + + List documents = new ArrayList<>(); + documents.add(doc1); + documents.add(doc2); + documents.add(doc3); + repository.insert(documents); + + try { + // 1. 使用OR表达式过滤进行搜索 + String orExpression = "title == 'solon' OR title == '设置'"; + List orResults = repository.search(new QueryCondition("framework").filterExpression(orExpression).disableRefilter(true)); + + System.out.println("Found " + orResults.size() + " documents with OR filter expression: " + orExpression); + + // 验证结果包含2个文档 + assert orResults.size() == 2; + + // 2. 使用AND表达式过滤 + String andExpression = "title == 'solon' AND category == 'framework'"; + List andResults = repository.search(new QueryCondition("framework").filterExpression(andExpression).disableRefilter(true)); + + System.out.println("Found " + andResults.size() + " documents with AND filter expression: " + andExpression); + + // 验证结果只包含1个文档 + assertEquals(1, andResults.size()); + + // 3. 使用category过滤 + String categoryExpression = "category == 'framework'"; + List categoryResults = repository.search(new QueryCondition("framework").filterExpression(categoryExpression).disableRefilter(true)); + + System.out.println("Found " + categoryResults.size() + " documents with category filter: " + categoryExpression); + + // 验证结果包含2个framework类别的文档 + assertEquals(2, categoryResults.size()); + } finally { + // 清理测试数据 + repository.delete(doc1.getId(), doc2.getId(), doc3.getId()); + } + } + + @Test + public void testAdvancedExpressionFilter() throws IOException { + try { + // 创建测试文档 + List documents = new ArrayList<>(); + + Document doc1 = new Document("Document with numeric properties"); + doc1.metadata("price", 100); + doc1.metadata("stock", 50); + doc1.metadata("category", "electronics"); + + Document doc2 = new Document("Document with different price"); + doc2.metadata("price", 200); + doc2.metadata("stock", 10); + doc2.metadata("category", "electronics"); + + Document doc3 = new Document("Document with different category"); + doc3.metadata("price", 150); + doc3.metadata("stock", 25); + doc3.metadata("category", "books"); + + documents.add(doc1); + documents.add(doc2); + documents.add(doc3); + + // 插入测试文档 + repository.insert(documents); + + // 等待索引更新 + Thread.sleep(1000); + + // 1. 测试数值比较 (大于) + String gtExpression = "price > 120"; + QueryCondition gtCondition = new QueryCondition("document") + .filterExpression(gtExpression); + + List gtResults = repository.search(gtCondition); + System.out.println("找到 " + gtResults.size() + " 个文档,使用大于表达式: " + gtExpression); + + // 验证结果 - 应该找到两个价格大于120的文档 + assertTrue(gtResults.size() > 0, "大于表达式应该找到文档"); + int countGt120 = 0; + for (Document doc : gtResults) { + int price = ((Number) doc.getMetadata("price")).intValue(); + if (price > 120) { + countGt120++; + } + } + assertTrue(countGt120 > 0, "应该找到价格大于120的文档"); + + // 2. 测试数值比较 (小于等于) + String lteExpression = "stock <= 25"; + QueryCondition lteCondition = new QueryCondition("document") + .filterExpression(lteExpression); + + List lteResults = repository.search(lteCondition); + System.out.println("找到 " + lteResults.size() + " 个文档,使用小于等于表达式: " + lteExpression); + + // 验证结果 - 应该找到两个库存小于等于25的文档 + assertTrue(lteResults.size() > 0, "小于等于表达式应该找到文档"); + int countLte25 = 0; + for (Document doc : lteResults) { + int stock = ((Number) doc.getMetadata("stock")).intValue(); + if (stock <= 25) { + countLte25++; + } + } + assertTrue(countLte25 > 0, "应该找到库存小于等于25的文档"); + + // 3. 测试复合表达式 (价格区间和类别) + String complexExpression = "(price >= 100 AND price <= 180) AND category == 'electronics'"; + QueryCondition complexCondition = new QueryCondition("document") + .filterExpression(complexExpression); + + List complexResults = repository.search(complexCondition); + System.out.println("找到 " + complexResults.size() + " 个文档,使用复合表达式: " + complexExpression); + + // 验证结果 - 应该找到一个满足所有条件的文档 + assertTrue(complexResults.size() > 0, "复合表达式应该找到文档"); + boolean foundMatch = false; + for (Document doc : complexResults) { + int price = ((Number) doc.getMetadata("price")).intValue(); + String category = (String) doc.getMetadata("category"); + if (price >= 100 && price <= 180 && "electronics".equals(category)) { + foundMatch = true; + break; + } + } + assertTrue(foundMatch, "应该找到符合复合条件的文档"); + + // 4. 测试 IN 操作符 + String inExpression = "category IN ['electronics', 'books']"; + QueryCondition inCondition = new QueryCondition("document") + .filterExpression(inExpression); + + List inResults = repository.search(inCondition); + System.out.println("找到 " + inResults.size() + " 个文档,使用IN表达式: " + inExpression); + assertTrue(inResults.size() > 0, "IN表达式应该找到文档"); + + // 5. 测试 NOT 操作符 + String notExpression = "NOT (category == 'books')"; + QueryCondition notCondition = new QueryCondition("document") + .filterExpression(notExpression); + List notResults = repository.search(notCondition); + System.out.println("找到 " + notResults.size() + " 个文档,使用NOT表达式: " + notExpression); + assertTrue(notResults.size() > 0, "NOT表达式应该找到文档"); + boolean foundNonBooks = false; + for (Document doc : notResults) { + String category = (String) doc.getMetadata("category"); + if (!"books".equals(category)) { + foundNonBooks = true; + break; + } + } + assertTrue(foundNonBooks, "应该找到非books类别的文档"); + + // 打印结果 + System.out.println("\n=== 高级表达式过滤测试结果 ==="); + System.out.println("大于表达式结果数量: " + gtResults.size()); + System.out.println("小于等于表达式结果数量: " + lteResults.size()); + System.out.println("复合表达式结果数量: " + complexResults.size()); + System.out.println("IN表达式结果数量: " + inResults.size()); + System.out.println("NOT表达式结果数量: " + notResults.size()); + + } catch (Exception e) { + e.printStackTrace(); + fail("测试过程中发生异常: " + e.getMessage()); + } + } + + private void load(RepositoryStorable repository, String url) throws IOException { + String text = HttpUtils.http(url).get(); + + DocumentLoader loader = null; + if (text.contains("")) { + loader = new HtmlSimpleLoader(text.getBytes(StandardCharsets.UTF_8)); + } else { + loader = new MarkdownLoader(text.getBytes(StandardCharsets.UTF_8)); + } + + List documents = new SplitterPipeline() //2.分割文档(确保不超过 max-token-size) + .next(new RegexTextSplitter()) + .next(new TokenSizeTextSplitter(500)) + .split(loader.load()); + + repository.insert(documents); //(推入文档) + } +} \ No newline at end of file -- Gitee