diff --git a/debian/changelog b/debian/changelog
index cbbe5d51dadd3c2fc99dac0bf8c2d574f2f295ba..29a2ac528a8d343b1dc7e48b4c66c10596ceb727 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+kylin-ai-business-framework-service (1.0.0.0-ok0.5) nile; urgency=medium
+
+  * Fix search being unavailable while the index is being built
+
+ -- wangweinan  Wed, 22 May 2024 10:02:46 +0800
+
 kylin-ai-business-framework-service (1.0.0.0-ok0.4) nile; urgency=medium
 
   * Add the libpython3-dev dependency to fix package build failures
diff --git a/src/datamanagement/datamanagementdatabase.cpp b/src/datamanagement/datamanagementdatabase.cpp
index 0361eebe35235bdedc6ad462fb2342d8fdf7b5e9..153c8fbec593dc529d7421e94e807ba00405e8a3 100644
--- a/src/datamanagement/datamanagementdatabase.cpp
+++ b/src/datamanagement/datamanagementdatabase.cpp
@@ -49,28 +49,9 @@ std::vector<double> convertFloatToDouble(const std::vector<float>& floatVec)
 {
 }
 
-DataManagementDatabase::~DataManagementDatabase()
+DataManagementDatabase::DataManagementDatabase()
+    : database_(std::string(getVectorDatabasePath() + VECTOR_DATABASE_NAME))
 {
-    if (connected_)
-        database_.destroyClient();
-}
-
-bool DataManagementDatabase::createClient()
-{
-    std::string databasePath = getVectorDatabasePath() + "/" + VECTOR_DATABASE_NAME;
-
-    if (database_.createClient(databasePath) != VectorDBErrorCode::Success) {
-        return false;
-    }
-
-    connected_ = true;
-    return true;
-}
-
-void DataManagementDatabase::destroyClient()
-{
-    connected_ = false;
-    database_.destroyClient();
 }
 
 void DataManagementDatabase::initCollections()
@@ -93,7 +74,7 @@ SimilaritySearchResult DataManagementDatabase::textSearch(const std::vector& vector)
 {
-    if (vector.empty() || collection.empty() || !connected_) {
+    if (vector.empty() || collection.empty()) {
         return std::vector<std::pair<std::string, double>>();
     }
 
@@ -110,18 +91,29 @@ SimilaritySearchResult DataManagementDatabase::similaritySearch(
     std::vector<std::pair<std::string, double>> result;
 
     // Use the sourceIds returned by the similarity search to look up the corresponding files in another collection
+    std::vector<std::string> ids;
     for (const std::pair<VectorData, double>& searchResult : searchResults) {
         VectorData data = searchResult.first;
         double similarity = 1 - searchResult.second; // The vector database currently returns a distance, so convert it to a similarity
         std::string sourceId = std::get<std::string>(data.metadata["sourceId"]);
-        std::vector<std::string> ids = { sourceId };
-        auto filedatas = database_.getData(FILE_INFO_COLLECTION_NAME, ids);
-        if (filedatas.empty())
-            continue;
-        std::string filepath = std::get<std::string>(filedatas[0].metadata["source"]);
+        if (std::find(ids.begin(), ids.end(), sourceId) == ids.end())
+            ids.emplace_back(sourceId);
+    }
+
+    auto filedatas = database_.getData(FILE_INFO_COLLECTION_NAME, ids);
 
-        result.emplace_back(std::move(filepath), std::move(similarity));
+    for (const std::pair<VectorData, double>& searchResult : searchResults) {
+        VectorData data = searchResult.first;
+        double similarity = 1 - searchResult.second; // The vector database currently returns a distance, so convert it to a similarity
+        std::string sourceId = std::get<std::string>(data.metadata["sourceId"]);
+        for (VectorData& filedata : filedatas) {
+            if (filedata.id == sourceId) {
+                std::string filepath = std::get<std::string>(filedata.metadata["source"]);
+                result.emplace_back(std::move(filepath), std::move(similarity));
+                break;
+            }
+        }
     }
 
     return result;
@@ -139,11 +131,6 @@ void DataManagementDatabase::addTextDatas(const std::vector &fileinf
 void DataManagementDatabase::addDatas(const std::vector &fileinfos, const std::string& collection)
 {
-    if (!connected_) {
-        std::cerr << "Please create client first" << std::endl;
-        return;
-    }
-
     std::vector fileDatas;
     std::vector vectorDatas;
 
@@ -189,10 +176,6 @@ void DataManagementDatabase::addDatas(const std::vector &fileinfos,
 GetAllFileInfosResult DataManagementDatabase::getAllFileInfos()
 {
     GetAllFileInfosResult result;
-    if (!connected_) {
-        std::cerr << "Please create client first" << std::endl;
-        return result;
-    }
     if (!hasCollections({FILE_INFO_COLLECTION_NAME})) {
         std::cerr << "Don't have " << FILE_INFO_COLLECTION_NAME << " collection, "
                   << "can't get file infos" << std::endl;
@@ -210,10 +193,6 @@ GetAllFileInfosResult DataManagementDatabase::getAllFileInfos()
 void DataManagementDatabase::deleteFiles(const std::vector &files)
 {
-    if (!connected_) {
-        std::cerr << "Please create client first" << std::endl;
-        return;
-    }
     if (!hasCollections({FILE_INFO_COLLECTION_NAME})) {
         std::cerr << "Don't have " << FILE_INFO_COLLECTION_NAME << " collection, "
                   << "can't get file infos" << std::endl;
diff --git a/src/datamanagement/datamanagementdatabase.h b/src/datamanagement/datamanagementdatabase.h
index 4529c3015840421ca63faee6a459b482553cb7b8..a9fd82c2148702e88462a2118cb5ddf0e35b4a82 100644
--- a/src/datamanagement/datamanagementdatabase.h
+++ b/src/datamanagement/datamanagementdatabase.h
@@ -44,10 +44,8 @@ struct DBFileInfo {
 class DataManagementDatabase
 {
 public:
-    DataManagementDatabase() = default;
-    ~DataManagementDatabase();
-    bool createClient();
-    void destroyClient();
+    DataManagementDatabase();
+    ~DataManagementDatabase() = default;
 
     void initCollections();
     SimilaritySearchResult visionSearch(const std::vector &vector);
@@ -68,7 +66,6 @@ private:
 
 private:
     VectorDB database_;
-    bool connected_ = false;
 };
 
 #endif // DATAMANAGEMENTDATABASE_H
diff --git a/src/datamanagement/datamanagementservice.cpp b/src/datamanagement/datamanagementservice.cpp
index ea629e895978ec15595346c8f75722a09e946f88..fa5ffe317c9e457be20bca11745f4a71e35c5d65 100644
--- a/src/datamanagement/datamanagementservice.cpp
+++ b/src/datamanagement/datamanagementservice.cpp
@@ -42,14 +42,7 @@ typedef enum {
 DataManagementService::DataManagementService()
 {
-    //! \note chromadb also uses multiple threads internally; using it directly from another thread
-    //! can conflict with operations on the Python GIL. Testing showed that the first use of chromadb
-    //! has to happen on the thread that initialized the Python environment, so the collections are
-    //! initialized here through the database interface. A further issue: destroyClient never truly
-    //! releases the database resources, and forcibly clearing all reference counts there causes occasional crashes when the database interface is used.
-    database_.createClient();
     database_.initCollections();
-    database_.destroyClient();
 }
 
 void DataManagementService::similaritySearch(const std::string& searchConditions, SimilaritySearchCallback callback)
@@ -305,13 +298,8 @@ std::string DataManagementService::doSimilaritySearch(
         return std::string();
     }
 
-    if (!database_.createClient()) {
-        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
-        return std::string();
-    }
     auto visionSearchResult = database_.visionSearch(visionSideEmbedding);
     auto textSearchResult = database_.textSearch(textSideEmbedding);
-    database_.destroyClient();
 
     if (visionSearchResult.empty() && textSearchResult.empty()) {
         errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
@@ -349,12 +337,7 @@ void DataManagementService::doAddImageFiles(
         return;
     }
 
-    if (!database_.createClient()) {
-        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
-        return;
-    }
     database_.addImageDatas(infos);
-    database_.destroyClient();
 }
 
 void DataManagementService::doAddTextFiles(
@@ -374,21 +357,11 @@
         return;
     }
 
-    if (!database_.createClient()) {
-        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
-        return;
-    }
     database_.addTextDatas(infos);
-    database_.destroyClient();
 }
 
 std::string DataManagementService::doGetAllFileInfos(int &errorCode)
 {
-    if (!database_.createClient()) {
-        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
-        return std::string();
-    }
-
     GetAllFileInfosResult result = database_.getAllFileInfos();
     return DataManagementJsonHelper::convertFileInfosResultToJson(result);
 }
@@ -401,13 +374,8 @@ void DataManagementService::doDeleteFiles(const std::string &fileinfosJson, int
         return;
     }
 
-    if (!database_.createClient()) {
-        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
-        return;
-    }
     database_.deleteFiles(files);
-    database_.destroyClient();
 }
 
 void DataManagementService::doUpdateFilesContent(const std::string &fileinfosJson, int &errorCode)
@@ -436,16 +404,10 @@ void DataManagementService::doUpdateFilesContent(const std::string &fileinfosJso
     auto imageDatas = makeDatasByImageFilePathAndFormat(std::move(imageFiles));
     auto textDatas = makeDatasByTextFilePathAndFormat(std::move(textFiles));
 
-    if (!database_.createClient()) {
-        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
-        return;
-    }
     database_.deleteFiles(files);
     database_.addImageDatas(std::move(imageDatas));
     database_.addTextDatas(std::move(textDatas));
-
-    database_.destroyClient();
 }
 
 std::vector DataManagementService::textSideEmbeddingText(const std::string& text)
diff --git a/src/utils/python/autotokenizer.py b/src/utils/python/autotokenizer.py
index 0bcd07be37fdf05e6a11de9d941c01a5da46cddc..74a3198cde882467955f0607ffe830566451a758 100644
--- a/src/utils/python/autotokenizer.py
+++ b/src/utils/python/autotokenizer.py
@@ -14,7 +14,8 @@
 from transformers import AutoTokenizer
 
-
+import os
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
 tokenizer = None
 
 def init_tokenizer(tokenizer_file_path):
diff --git a/src/utils/python/chroma_db.py b/src/utils/python/chroma_db.py
index c49df26a069f33abd65468c24ad2d6651bdabbfa..a6cef38dc6c9b2821479c4e49e2e5fe542fcc090 100644
--- a/src/utils/python/chroma_db.py
+++ b/src/utils/python/chroma_db.py
@@ -14,21 +14,96 @@
 import chromadb
 from chromadb.config import Settings
+import multiprocessing
 
-class ChromaDB:
-    def __init__(self, pathname):
-        self.client = self.create_PersistentClient(pathname)
+def do_list_collections(path):
+    client = chromadb.PersistentClient(path=path)
+    results = client.list_collections()
+    names = [collection.name for collection in results]
+    return names
+
+def do_create_collection(path, collection_name, collection_metadata="cosine"):
+    client = chromadb.PersistentClient(path=path)
+    client.create_collection(name=collection_name, metadata={"hnsw:space": collection_metadata})
+
+def do_get_or_create_collection(path, collection_name, collection_metadata="cosine"):
+    client = chromadb.PersistentClient(path=path)
+    client.get_or_create_collection(name=collection_name, metadata={"hnsw:space": collection_metadata})
+
+def do_add_data(path, collection_name, ids, metadatas, vectors = None, documents = None):
+    client = chromadb.PersistentClient(path = path)
+    collection = client.get_collection(name = collection_name)
+    if documents == [""]:
+        documents = None
+    if vectors is None:
+        vectors = [[1, 1, 1] for _ in range(len(ids))]
+    collection.add(
+        ids = ids,
+        embeddings = vectors,
+        documents = documents,
+        metadatas = metadatas
+    )
+
+def do_get_data(path, collection_name, id=None, source=None):
+    client = chromadb.PersistentClient(path = path)
+    collection = client.get_collection(name = collection_name)
+    if source == "":
+        source = None
+    filter = None if source is None else {"$contains": source}
+
+    data = collection.get(
+        ids=id,
+        # embeddings=vectors,
+        where_document=filter,
+        include=["embeddings", "metadatas", "documents"]
+    )
+    results = {'ids': data.get('ids'), 'embeddings': data.get('embeddings'), 'metadatas': data.get('metadatas'), 'documents': data.get('documents')}
+    return results
+
+def do_delete_data(path, collection_name, sourceIds=None, ids=None):
+    client = chromadb.PersistentClient(path = path)
+    collection = client.get_collection(name = collection_name)
+    filter=None
+    if sourceIds is not None:
+        filter = {"sourceId": {"$in": [str(sourceId) for sourceId in sourceIds if sourceId is not None]}}
+    print(filter)
+    collection.delete(
+        ids=ids,
+        where=filter
+    )
 
-    # Create the client connection for locally stored data; the allow parameter controls whether the PersistentClient may reset its state or database (disallowed by default)
-    def create_PersistentClient(self, pathname, allow=False):
-        chroma_client = chromadb.PersistentClient(path=pathname, settings=Settings(allow_reset=allow))
-        return chroma_client
+def do_query_data(path, collection_name, embuddings=None, results=10):
+    client = chromadb.PersistentClient(path = path)
+    collection = client.get_collection(name = collection_name)
+    data = collection.query(
+        query_embeddings=embuddings,
+        n_results=results
+    )
+    results = {'ids': data.get('ids')[0],'metadatas': data.get('metadatas')[0], 'documents': data.get('documents')[0], 'distances': data.get('distances')[0]}
+    return results
+
+def do_update_data(path, collection_name, ids, documents = None, vectors = None, metadatas = None):
+    client = chromadb.PersistentClient(path = path)
+    collection = client.get_collection(name = collection_name)
+    collection.update(
+        ids = ids,
+        embeddings = vectors,
+        documents = documents,
+        metadatas = metadatas
+    )
+
+class ChromaDB:
+    def __init__(self, path):
+        self.path = path
 
     def list_collections(self):
-        results = self.client.list_collections()
-        names = [collection.name for collection in results]
-        print(names)
-        return names
+        q = multiprocessing.Queue()
+        target = lambda q: q.put(do_list_collections(self.path))
+        p = multiprocessing.Process(target=target, args=(q,))
+        p.start()
+        p.join()
+        result = q.get()
+        return result
 
     # Create a collection
     def create_collection(
@@ -36,64 +111,39 @@ class ChromaDB:
         collection_name,
         collection_metadata= "cosine"
     ):
-        self.client.create_collection(name=collection_name, metadata={"hnsw:space": collection_metadata})
-
-    # Get a collection
-    def get_collection(
-        self,
-        collection_name,
-    ):
-        collection = self.client.get_collection(name=collection_name)
-        return collection
+        p = multiprocessing.Process(target=do_create_collection, args=(self.path, collection_name))
+        p.start()
+        p.join()
 
     def get_or_create_collection(
         self,
         collection_name,
         collection_metadata="cosine"
     ):
-        self.client.get_or_create_collection(name=collection_name, metadata={"hnsw:space": collection_metadata})
+        p = multiprocessing.Process(target=do_get_or_create_collection, args=(self.path, collection_name))
+        p.start()
+        p.join()
 
     # Add
     def add_data(self, collection_name, ids, metadatas, vectors = None, documents = None):
-        collection = self.get_collection(collection_name)
-        if documents == [""]:
-            documents = None
-        if vectors is None:
-            vectors = [[1, 1, 1] for _ in range(len(ids))]
-        collection.add(
-            ids = ids,
-            embeddings = vectors,
-            documents = documents,
-            metadatas = metadatas
-        )
+        p = multiprocessing.Process(target=do_add_data, args=(self.path, collection_name, ids, metadatas, vectors, documents))
+        p.start()
+        p.join()
 
     # Get all data
    def get_data(self, collection_name, id=None, source=None):
-        collection = self.get_collection(collection_name)
-        if source == "":
-            source = None
-        filter = None if source is None else {"$contains": source}
-
-        data = collection.get(
-            ids=id,
-            # embeddings=vectors,
-            where_document=filter,
-            include=["embeddings", "metadatas", "documents"]
-        )
-        results = {'ids': data.get('ids'), 'embeddings': data.get('embeddings'), 'metadatas': data.get('metadatas'), 'documents': data.get('documents')}
-        return results
+        q = multiprocessing.Queue()
+        target = lambda q: q.put(do_get_data(self.path, collection_name, id, source))
+        p = multiprocessing.Process(target=target, args=(q,))
+        p.start()
+        p.join()
+        return q.get()
 
     # Delete
     def delete_data(self, collection_name, sourceIds=None, ids=None):
-        collection = self.get_collection(collection_name)
-        filter=None
-        if sourceIds is not None:
-            filter = {"sourceId": {"$in": [str(sourceId) for sourceId in sourceIds if sourceId is not None]}}
-        print(filter)
-        collection.delete(
-            ids=ids,
-            where=filter
-        )
+        p = multiprocessing.Process(target=do_delete_data, args=(self.path, collection_name, sourceIds, ids))
+        p.start()
+        p.join()
 
     # Query
     def query_data(
@@ -102,23 +152,18 @@
         self,
         collection_name,
         embuddings=None,
         results=10
     ):
-        collection = self.get_collection(collection_name)
-        data=collection.query(
-            query_embeddings=embuddings,
-            n_results=results
-        )
-        results = {'ids': data.get('ids')[0],'metadatas': data.get('metadatas')[0], 'documents': data.get('documents')[0], 'distances': data.get('distances')[0]}
-        return results
+        q = multiprocessing.Queue()
+        target = lambda q: q.put(do_query_data(self.path, collection_name, embuddings, results))
+        p = multiprocessing.Process(target=target, args=(q,))
+        p.start()
+        p.join()
+        return q.get()
 
     # Update, kept consistent with add
     def update_data(self, collection_name, ids, documents = None, vectors = None, metadatas = None):
-        collection = self.get_collection(collection_name)
-        collection.update(
-            ids = ids,
-            embeddings = vectors,
-            documents = documents,
-            metadatas = metadatas
-        )
+        p = multiprocessing.Process(target=do_update_data, args=(self.path, collection_name, ids, documents, vectors, metadatas))
+        p.start()
+        p.join()
 
 # if "__main__" == __name__:
 #     db = ChromaDB("/home/wangyan/chroma/demo-chroma/database6")
diff --git a/src/utils/vectordb/vectordb.cpp b/src/utils/vectordb/vectordb.cpp
index 3d589b83edf3367dec4e2a9569358c5705826432..b87764a06bffb23372a7ed09d96e588a1eb8ae46 100644
--- a/src/utils/vectordb/vectordb.cpp
+++ b/src/utils/vectordb/vectordb.cpp
@@ -22,67 +22,63 @@
 #include
 #include
 
-VectorDB::VectorDB()
+VectorDB::VectorDB(const std::string& databasePath)
 {
     std::cout << "Initializing VectorDB..." << std::endl;
-    initializeChromaDB();
+    initializeChromaDB(databasePath);
 }
 
-void VectorDB::initializeChromaDB() {
+VectorDB::~VectorDB() {
+    PythonThreadLocker locker;
+    if (dbInstance_ != nullptr) {
+        Py_DECREF(dbInstance_);
+    }
+}
+
+void VectorDB::initializeChromaDB(const std::string& databasePath) {
     std::cout << "Initializing Python..." << std::endl;
     // Initialize the Python interpreter
     if (!Py_IsInitialized())
         Py_Initialize();
 
     PythonThreadLocker locker;
+
+    // Append the directory containing the required Python files to sys.path
     PyRun_SimpleString("import sys");
     std::string command = "sys.path.append('" + std::string(DATA_MANAGEMENT_PYTHON_PATH) + "')";
     PyRun_SimpleString(command.c_str());
-}
 
-VectorDBErrorCode VectorDB::createClient(const std::string& databasePath)
-{
-    PythonThreadLocker locker;
-    std::cout << "Creating ChromaDB instance..." << std::endl;
     // Import the module
+    std::cout << "Creating ChromaDB instance..." << std::endl;
     PyObject* pModule = PyImport_ImportModule("chroma_db");
     if (pModule == nullptr) {
         std::cerr << "Failed to load module: chroma_db" << std::endl;
         PyErr_Print();
-        return VectorDBErrorCode::ArgumentError;
+        return;
     }
 
     // Get the ChromaDB class
     PyObject* pChromaDBClass = PyObject_GetAttrString(pModule, "ChromaDB");
     if (pChromaDBClass == nullptr) {
         std::cerr << "Failed to get ChromaDB class" << std::endl;
-        Py_DECREF(pModule);
-        return VectorDBErrorCode::ArgumentError;
+        Py_XDECREF(pModule);
+        return;
     }
 
     // Create the ChromaDB instance
     PyObject* args = Py_BuildValue("(s)", databasePath.c_str());
     if(args == nullptr) {
         std::cerr << "Error: Could not create argument tuple." << std::endl;
+        PyErr_Print();
+
+        Py_XDECREF(pModule);
+        Py_XDECREF(pChromaDBClass);
+        return;
     }
 
     dbInstance_ = PyObject_CallObject(pChromaDBClass, args);
 
     Py_DECREF(pModule);
     Py_DECREF(args);
     Py_DECREF(pChromaDBClass);
-
-    return VectorDBErrorCode::Success;
-}
-
-VectorDBErrorCode VectorDB::destroyClient()
-{
-    PythonThreadLocker locker;
-    if (dbInstance_ != nullptr) {
-        Py_DECREF(dbInstance_);
-        dbInstance_ = nullptr; // Set to nullptr to avoid a dangling pointer
-    }
-
-    return VectorDBErrorCode::Success;
 }
 
 VectorDBErrorCode VectorDB::createCollection(const std::string& collectionName, const std::string& searchAlgorithm)
diff --git a/src/utils/vectordb/vectordb.h b/src/utils/vectordb/vectordb.h
index 49624ada31dd1ba6218f1c50d9e79ffb5acc855f..a51c84dbb502631596080fd3ba1aa9caa1ecf31b 100644
--- a/src/utils/vectordb/vectordb.h
+++ b/src/utils/vectordb/vectordb.h
@@ -40,13 +40,10 @@ struct VectorData{
 class VectorDB
 {
 public:
-    explicit VectorDB();
-    // explicit VectorDB(const std::string& databasePath);
-    ~VectorDB() = default;
+    explicit VectorDB(const std::string& databasePath);
+    ~VectorDB();
 
     std::string engineName() const {return "ChromaDB";}
 
-    VectorDBErrorCode createClient(const std::string& databasePath);
-    VectorDBErrorCode destroyClient();
     VectorDBErrorCode createCollection(const std::string& collectionName, const std::string& searchAlgorithm);
     std::vector collections();
     VectorDBErrorCode addData(const std::string& collectionName, const std::vector& data);
@@ -62,7 +59,7 @@ public:
         std::vector> embeddings, int64_t numResults = 10);
 
 private:
-    void initializeChromaDB();
+    void initializeChromaDB(const std::string& databasePath);
     std::vector convertPyObjectToVectorVectorData(PyObject* pyObject);
 
     PyObject* dbInstance_ = nullptr;
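
Note on the pattern introduced in chroma_db.py: every ChromaDB operation now runs in a short-lived child process (multiprocessing.Process), and results come back through a multiprocessing.Queue, so the long-running service never keeps a chromadb client alive between calls. The sketch below is a minimal, self-contained illustration of that isolation pattern, not code from this patch; run_in_subprocess, _call_and_put and fake_query are illustrative names, and a real worker would open chromadb.PersistentClient(path=...) inside the child exactly as the do_* helpers above do.

import multiprocessing

def _call_and_put(queue, fn, args, kwargs):
    # Runs in the child process; pushes the callable's result (or its error) back to the parent.
    try:
        queue.put(("ok", fn(*args, **kwargs)))
    except Exception as exc:
        queue.put(("error", repr(exc)))

def run_in_subprocess(fn, *args, **kwargs):
    # Execute fn(*args, **kwargs) in a child process and return its result in the parent.
    queue = multiprocessing.Queue()
    proc = multiprocessing.Process(target=_call_and_put, args=(queue, fn, args, kwargs))
    proc.start()
    status, payload = queue.get()  # read before join() so a large result cannot block the child
    proc.join()
    if status == "error":
        raise RuntimeError(payload)
    return payload

def fake_query(path, collection_name, n_results=10):
    # Stand-in worker; a real one would query a chromadb.PersistentClient opened on `path`.
    return {"path": path, "collection": collection_name, "ids": [], "distances": []}

if __name__ == "__main__":
    print(run_in_subprocess(fake_query, "/tmp/vector-db", "text_collection"))

One caveat worth noting: the patch passes lambdas as Process targets, which relies on the default "fork" start method on Linux; under the "spawn" start method the target must be picklable, which lambdas are not. The sketch above avoids lambdas for that reason and drains the queue before joining the child.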