diff --git a/CMakeLists.txt b/CMakeLists.txt index 4a2ddfa6aa87661c74a4a0597e62283cb401a612..da00545f01b4861d9fcff0ef9569eca51bc1ffd7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,7 +46,10 @@ add_executable(kylin-ai-business-framework-service src/main.cpp src/datamanagement/datamanagementservice.cpp src/datamanagement/datamanagementservice.h src/datamanagement/segmenttokenizer.cpp - src/datamanagement/segmenttokenizer.h) + src/datamanagement/segmenttokenizer.h + src/utils/vectordb/vectordb.cpp + src/utils/vectordb/vectordb.h) + find_package(PkgConfig REQUIRED) find_package(Python 3.8 COMPONENTS Development REQUIRED) @@ -55,7 +58,6 @@ pkg_check_modules(poppler REQUIRED IMPORTED_TARGET poppler-cpp) pkg_check_modules(opencv REQUIRED IMPORTED_TARGET opencv4) pkg_check_modules(onnxruntime REQUIRED IMPORTED_TARGET libonnxruntime) pkg_check_modules(GIO REQUIRED IMPORTED_TARGET gio-unix-2.0) -pkg_check_modules(vectordb REQUIRED IMPORTED_TARGET vectordb-engine) pkg_check_modules(jsoncpp REQUIRED IMPORTED_TARGET jsoncpp) include_directories( @@ -72,7 +74,6 @@ target_link_libraries(kylin-ai-business-framework-service PkgConfig::poppler PkgConfig::opencv PkgConfig::onnxruntime - PkgConfig::vectordb ) set(UTILS_RESOURCE_PATH /usr/share/kylin-ai-business-framework/utils) diff --git a/debian/control b/debian/control index e8954c0c58fdb6833bab6c6bef89288c584c30b1..6e3eab38ae29ab276aa9f9672d87e2ad6f2fe67e 100644 --- a/debian/control +++ b/debian/control @@ -7,9 +7,7 @@ Build-Depends: debhelper-compat (= 12), libjsoncpp-dev, libpoppler-cpp-dev, libopencv-dev, - libglib2.0-dev, - libvectordb-engine, - libvectordb-engine-dev + libglib2.0-dev Standards-Version: 4.4.1 Homepage: https://www.ukui.org/ diff --git a/src/datamanagement/datamanagementdatabase.cpp b/src/datamanagement/datamanagementdatabase.cpp index d8d9a6578cdd2ca294a13d2c3b191a3114921d1f..3cb70ddab7c709c0d87b6701c549e31669b838f0 100644 --- a/src/datamanagement/datamanagementdatabase.cpp +++ b/src/datamanagement/datamanagementdatabase.cpp @@ -1,4 +1,5 @@ #include "datamanagementdatabase.h" +#include "utils/vectordb/vectordb.h" #include #include diff --git a/src/datamanagement/datamanagementdatabase.h b/src/datamanagement/datamanagementdatabase.h index ba424a305fd192d6fd16b6333563b19256ae1142..c71125b21cda781d9d7b2ca3c49cf7b0642459c6 100644 --- a/src/datamanagement/datamanagementdatabase.h +++ b/src/datamanagement/datamanagementdatabase.h @@ -1,7 +1,7 @@ #ifndef DATAMANAGEMENTDATABASE_H #define DATAMANAGEMENTDATABASE_H -#include +#include "utils/vectordb/vectordb.h" // pair.first: filepath, pair.second: similarity using SimilaritySearchResult = std::vector>; @@ -25,7 +25,6 @@ struct DBFileInfo { time_t modifyTime; std::vector> vectors; }; - class DataManagementDatabase { public: diff --git a/src/utils/python/chroma_db.py b/src/utils/python/chroma_db.py new file mode 100644 index 0000000000000000000000000000000000000000..affdd51bd12b32d275934249c8c77bbca6613792 --- /dev/null +++ b/src/utils/python/chroma_db.py @@ -0,0 +1,121 @@ +import chromadb +from chromadb.config import Settings + +class ChromaDB: + def __init__(self, pathname): + self.client = self.create_PersistentClient(pathname) + + #创建本地存储数据的client链接,入参allow表示是否允许PersistentClient重置其状态或数据库,此处默认不允许; + def create_PersistentClient(self, pathname, allow=False): + chroma_client = chromadb.PersistentClient(path=pathname, settings=Settings(allow_reset=allow)) + return chroma_client + + def list_collections(self): + results = self.client.list_collections() + names = [collection.name for collection in results] + print(names) + return names + + #创建collection + def create_collection( + self, + collection_name, + collection_metadata= "cosine" + ): + self.client.create_collection(name=collection_name, metadata={"hnsw:space": collection_metadata}) + + #获取collection + def get_collection( + self, + collection_name, + ): + collection = self.client.get_collection(name=collection_name) + return collection + + def get_or_create_collection( + self, + collection_name, + collection_metadata="cosine" + ): + self.client.get_or_create_collection(name=collection_name, metadata={"hnsw:space": collection_metadata}) + + #增 + def add_data(self, collection_name, ids, metadatas, vectors = None, documents = None): + collection = self.get_collection(collection_name) + if documents == [""]: + documents = None + if vectors is None: + vectors = [[1, 1, 1] for _ in range(len(ids))] + collection.add( + ids = ids, + embeddings = vectors, + documents = documents, + metadatas = metadatas + ) + + #获得所有数据 + def get_data(self, collection_name, id=None, source=None): + collection = self.get_collection(collection_name) + if source == "": + source = None + filter = None if source is None else {"$contains": source} + + data = collection.get( + ids=id, + # embeddings=vectors, + where_document=filter, + include=["embeddings", "metadatas", "documents"] + ) + results = {'ids': data.get('ids'), 'embeddings': data.get('embeddings'), 'metadatas': data.get('metadatas'), 'documents': data.get('documents')} + return results + + #删 + def delete_data(self, collection_name, sourceIds=None, ids=None): + collection = self.get_collection(collection_name) + filter=None + if sourceIds is not None: + filter = {"sourceId": {"$in": [str(sourceId) for sourceId in sourceIds if sourceId is not None]}} + print(filter) + collection.delete( + ids=ids, + where=filter + ) + + #查 + def query_data( + self, + collection_name, + embuddings=None, + results=10 + ): + collection = self.get_collection(collection_name) + data=collection.query( + query_embeddings=embuddings, + n_results=results + ) + results = {'ids': data.get('ids')[0],'metadatas': data.get('metadatas')[0], 'documents': data.get('documents')[0], 'distances': data.get('distances')[0]} + return results + + #更新跟add保持一致 + def update_data(self, collection_name, ids, documents = None, vectors = None, metadatas = None): + collection = self.get_collection(collection_name) + collection.update( + ids = ids, + embeddings = vectors, + documents = documents, + metadatas = metadatas + ) + +# if "__main__" == __name__: +# db = ChromaDB("/home/wangyan/chroma/demo-chroma/database6") +# db.create_collection("collection_demo1") +# # db.add_data("collection_demo", ["id1","id2"], ["/a", "/b"], None, [{"soure":"/a", "time": 1},{"soure":"/b", "time": 2}]) +# # db.delete_data("collection_demo1", None, ["idd1","idd2"] ) +# db.add_data("collection_demo1", ["idd1","idd2"],[{"soureId":"id1"}, {"soureId":"id2"}], [[1,2,3], [2,3,4]]) +# # results=db.get_data("collection_demo") +# results= db.get_data("collection_demo1", ["idd1","idd2"]) +# print(results) + +# ids = uuid +# collection1 = {metadate: path ,time. id:ids} +# colletion2 = {emdiing: sss , metadata: ids, id:dddd} \ No newline at end of file diff --git a/src/utils/pythonutil.cpp b/src/utils/pythonutil.cpp index e65fd97e39a3aad484a066def66a55cdc0d9ad59..dde164bb6f1abd23dbca701ce2913cdc1291745c 100644 --- a/src/utils/pythonutil.cpp +++ b/src/utils/pythonutil.cpp @@ -70,4 +70,257 @@ int convertPythonIntToStd(PyObject* intObject) { return static_cast(longValue); } +PyObject* convertVectorStringToPyObject(const std::vector& vec) +{ + if (vec.empty()) { + return Py_None; + } + PyObject* list = PyList_New(vec.size()); + if(!list) { + std::cout << "convertVectorStringToPyObject error!!!" << std::endl; + } + + for (size_t i = 0; i < vec.size(); ++i) { + PyObject* py_str = PyUnicode_DecodeFSDefault(vec[i].c_str()); + if (!py_str) { + Py_DECREF(list); + return nullptr; + } + + PyList_SET_ITEM(list, i, py_str); + } + return list; +} + +PyObject* convertVectorDoubleToPyObject(const std::vector& vec) { + if (vec.empty()) { + std::cout << "vec is empty" << std::endl; + Py_RETURN_NONE; + } + PyObject* list = PyList_New(vec.size()); + if(!list) { + std::cout << "convertVectorDoubleToPyObject error!!!" << std::endl; + } + + for (size_t i = 0; i < vec.size(); ++i) { + PyList_SET_ITEM(list, i, PyFloat_FromDouble(vec[i])); + } + return list; +} + +PyObject* convertVectorVectorDoubleToPyObject(const std::vector>& vec_vec) +{ + if (vec_vec.empty()) { + Py_RETURN_NONE; + } + PyObject* list = PyList_New(vec_vec.size()); + if(!list) { + std::cout << "convertVectorVectorDoubleToPyObject error!!!" << std::endl; + } + + for (size_t i = 0; i < vec_vec.size(); ++i) { + if (vec_vec[i].empty()) { + Py_RETURN_NONE ; + } else { + PyList_SET_ITEM(list, i, convertVectorDoubleToPyObject(vec_vec[i])); + } + } + + return list; +} + +// 将std::variant转换为Python对象 +PyObject* convertVariantToPyObject(const std::variant& var) +{ + if (std::holds_alternative(var)) { + return PyUnicode_FromString(std::get(var).c_str()); + } else if (std::holds_alternative(var)) { + return PyLong_FromLongLong(std::get(var)); + } + // 如果variant中不是预期的类型,返回None或抛出异常 + Py_RETURN_NONE; +} + +// 将std::map>转换为Python字典对象 +PyObject* convertMapVariantToPyObject(const std::map>& map_var) { + PyObject* dict = PyDict_New(); + for (const auto& entry : map_var) { + PyObject* key = PyUnicode_FromString(entry.first.c_str()); + PyObject* value = convertVariantToPyObject(entry.second); + if (key && value) { + PyDict_SetItem(dict, key, value); + Py_DECREF(key); + Py_DECREF(value); + } else { + // 如果有错误发生,释放已分配的对象并返回NULL + Py_XDECREF(key); + Py_XDECREF(value); + Py_DECREF(dict); + return nullptr; + } + } + return dict; +} + +// 将std::vector>>转换为Python列表对象 +PyObject* convertVectorMapToPyObject(const std::vector>>& vec_map_var) { + PyObject* list = PyList_New(vec_map_var.size()); + for (size_t i = 0; i < vec_map_var.size(); ++i) { + PyList_SET_ITEM(list, i, convertMapVariantToPyObject(vec_map_var[i])); + } + return list; +} + +std::vector convertPythonListToStringVector(PyObject* ids_obj) +{ + std::vector ids={}; + if (!PyList_Check(ids_obj)) { + return ids; + } + + Py_ssize_t size = PyList_Size(ids_obj); + for (Py_ssize_t i = 0; i < size; ++i) { + PyObject* item = PyList_GetItem(ids_obj, i); + + if (!PyUnicode_Check(item)) { + continue; + } + + PyObject* unicode_str = PyUnicode_AsUTF8String(item); + const char* c_str = PyBytes_AsString(unicode_str); + + ids.push_back(std::string(c_str)); + Py_DECREF(unicode_str); + } + + return ids; +} + +std::vector> convertPythonListToVectorOfVectors(PyObject* embeddings_obj) { + std::vector> embeddings; + if (PyList_Check(embeddings_obj)) { + embeddings = extractInnerVectors(embeddings_obj); + } else { + std::cerr << "The embeddings object is not a list." << std::endl; + } + return embeddings; +} + +std::vector> extractInnerVectors(PyObject* outer_list) { + std::vector> embeddings; + Py_ssize_t outer_size = PyList_Size(outer_list); + for (Py_ssize_t i = 0; i < outer_size; ++i) { + PyObject* inner_list = PyList_GetItem(outer_list, i); + if (PyList_Check(inner_list)) { + std::vector inner_vec = extractDoubles(inner_list); + embeddings.push_back(inner_vec); + } else { + std::cerr << "Non-list value found in the embeddings list." << std::endl; + } + } + return embeddings; +} + +std::vector extractDoubles(PyObject* inner_list) { + std::vector inner_vec; + Py_ssize_t inner_size = PyList_Size(inner_list); + for (Py_ssize_t j = 0; j < inner_size; ++j) { + PyObject* item = PyList_GetItem(inner_list, j); + if (PyFloat_Check(item)) { + double value = PyFloat_AsDouble(item); + inner_vec.push_back(value); + } else { + std::cerr << "Non-float value found in the embeddings list." << std::endl; + } + } + return inner_vec; +} + +std::vector>> convertPythonListToMapOfVariant(PyObject* metadatas_obj) { + std::vector>> metadatas; + if (!PyList_Check(metadatas_obj)) { + return metadatas; + } + Py_ssize_t size = PyList_Size(metadatas_obj); + for (Py_ssize_t i = 0; i < size; ++i) { + PyObject* item = PyList_GetItem(metadatas_obj, i); + if (PyDict_Check(item)) { + std::map> metadata_map = extractMapFromDict(item); + metadatas.push_back(metadata_map); + } + } + + return metadatas; +} + +std::map> extractMapFromDict(PyObject* dict_obj) { + std::map> metadata_map; + PyObject* keys = PyDict_Keys(dict_obj); + Py_ssize_t key_count = PyList_Size(keys); + for (Py_ssize_t j = 0; j < key_count; ++j) { + PyObject* key_obj = PyList_GetItem(keys, j); + if (PyUnicode_Check(key_obj)) { + std::string key = extractStringFromUnicode(key_obj); + PyObject* value_obj = PyDict_GetItem(dict_obj, key_obj); + std::variant value = extractVariantFromObj(value_obj); + metadata_map[key] = value; + } + } + Py_DECREF(keys); + return metadata_map; +} + +std::string extractStringFromUnicode(PyObject* unicode_obj) { + PyObject* unicode_as_utf8 = PyUnicode_AsUTF8String(unicode_obj); + const char* c_str = PyBytes_AsString(unicode_as_utf8); + std::string str(c_str); + Py_DECREF(unicode_as_utf8); + return str; +} + +std::variant extractVariantFromObj(PyObject* value_obj) { + std::variant value; + if (PyUnicode_Check(value_obj)) { + value = extractStringFromUnicode(value_obj); + } else if (PyLong_Check(value_obj)) { + int64_t int_value = PyLong_AsLongLong(value_obj); + value = int_value; + } + return value; +} + +std::vector convertPythonListToVectorFloat(PyObject* obj) +{ + std::vector result; + if (!PySequence_Check(obj)) { + PyErr_SetString(PyExc_TypeError, "Input object is not a sequence"); + return result; + } + Py_ssize_t length = PySequence_Size(obj); + result.reserve(length); + for (Py_ssize_t i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(obj, i); + if (item == NULL) { + PyErr_Print(); + Py_XDECREF(item); + return result; + } + if (!PyFloat_Check(item)) { + PyErr_Format(PyExc_TypeError, "Sequence item %zd is not a float", i); + Py_XDECREF(item); + return result; + } + double value = PyFloat_AsDouble(item); + if (value == -1.0 && PyErr_Occurred()) { + PyErr_Print(); + Py_XDECREF(item); + return result; + } + result.push_back(static_cast(value)); + + Py_XDECREF(item); + } + return result; +} + } diff --git a/src/utils/pythonutil.h b/src/utils/pythonutil.h index cf1fbdf5eddc2be02eef01bbc2cb982ceb665907..5a226a9e32ee7ba3b332435f4993bea60fe04ade 100644 --- a/src/utils/pythonutil.h +++ b/src/utils/pythonutil.h @@ -6,6 +6,7 @@ #include #include #include +#include #define DATA_MANAGEMENT_PYTHON_PATH "/usr/share/kylin-ai-business-framework/utils/python/" @@ -16,6 +17,28 @@ std::vector convertPythonStringListToStd(PyObject* listObject); std::string convertPythonStringToStd(PyObject* strObject); int convertPythonIntToStd(PyObject* intObject); +PyObject* convertVectorStringToPyObject(const std::vector& vec); + +PyObject* convertVectorDoubleToPyObject(const std::vector& vec); +PyObject* convertVectorVectorDoubleToPyObject(const std::vector>& vec_vec); + +PyObject* convertVariantToPyObject(const std::variant& var); +PyObject* convertMapVariantToPyObject(const std::map>& map_var); +PyObject* convertVectorMapToPyObject(const std::vector>>& vec_map_var); + +std::vector convertPythonListToStringVector(PyObject* ids_obj); + +std::vector> convertPythonListToVectorOfVectors(PyObject* embeddings_obj); +std::vector> extractInnerVectors(PyObject* outer_list); +std::vector extractDoubles(PyObject* inner_list); + +std::vector>> convertPythonListToMapOfVariant(PyObject* metadatas_obj); +std::map> extractMapFromDict(PyObject* dict_obj); +std::string extractStringFromUnicode(PyObject* unicode_obj); +std::variant extractVariantFromObj(PyObject* value_obj); + +std::vector convertPythonListToVectorFloat(PyObject* obj); + } diff --git a/src/utils/vectordb/vectordb.cpp b/src/utils/vectordb/vectordb.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f1a48c4feb52689949437575fdff4161629065b1 --- /dev/null +++ b/src/utils/vectordb/vectordb.cpp @@ -0,0 +1,395 @@ +#include "vectordb.h" +#include "utils/pythonthreadlocker.h" +#include "utils/pythonutil.h" + +#include +#include +#include + +VectorDB::VectorDB() +{ + std::cout << "Initializing VectorDB..." << std::endl; + initializeChromaDB(); +} + +void VectorDB::initializeChromaDB() { + std::cout << "Initializing Python..." << std::endl; + // 初始化Python解释器 + if (!Py_IsInitialized()) + Py_Initialize(); + + PythonThreadLocker locker; + PyRun_SimpleString("import sys"); + std::string command = "sys.path.append('" + std::string(DATA_MANAGEMENT_PYTHON_PATH) + "')"; + PyRun_SimpleString(command.c_str()); +} + +VectorDBErrorCode VectorDB::createClient(const std::string& databasePath) +{ + PythonThreadLocker locker; + std::cout << "Creating ChromaDB instance..." << std::endl; + // 导入模块 + PyObject* pModule = PyImport_ImportModule("chroma_db"); + if (pModule == nullptr) { + std::cerr << "Failed to load module: chroma_db" << std::endl; + PyErr_Print(); + return VectorDBErrorCode::ArgumentError; + } + + // 获取ChromaDB类 + PyObject* pChromaDBClass = PyObject_GetAttrString(pModule, "ChromaDB"); + if (pChromaDBClass == nullptr) { + std::cerr << "Failed to get ChromaDB class" << std::endl; + Py_DECREF(pModule); + return VectorDBErrorCode::ArgumentError; + } + + // 创建ChromaDB实例 + PyObject* args = Py_BuildValue("(s)", databasePath.c_str()); + if(args == nullptr) { + std::cerr << "Error: Could not create argument tuple." << std::endl; + } + dbInstance_ = PyObject_CallObject(pChromaDBClass, args); + + Py_DECREF(pModule); + Py_DECREF(args); + Py_DECREF(pChromaDBClass); + + return VectorDBErrorCode::Success; +} + +VectorDBErrorCode VectorDB::destroyClient() +{ + PythonThreadLocker locker; + if (dbInstance_ != nullptr) { + Py_DECREF(dbInstance_); + dbInstance_ = nullptr; // 设置为nullptr以避免悬空指针 + } + + return VectorDBErrorCode::Success; +} + +VectorDBErrorCode VectorDB::createCollection(const std::string& collectionName, const std::string& searchAlgorithm) +{ + PythonThreadLocker locker; + if (dbInstance_ == nullptr) { + return VectorDBErrorCode::ValidDataBase; + } + + PyObject* args = Py_BuildValue("(ss)", collectionName.c_str(), searchAlgorithm.c_str()); + if (!args) { + std::cerr << "Error: Could not create argument tuple." << std::endl; + Py_XDECREF(args); + return VectorDBErrorCode::ArgumentError; + } + + PyObject* result = PyObject_CallMethod(dbInstance_, "get_or_create_collection", "O", args); + if (result == nullptr) { + // 处理错误 + PyErr_Print(); + Py_XDECREF(args); + Py_XDECREF(result); + return VectorDBErrorCode::PyObjectCallMethodError; + } + + Py_XDECREF(args); + Py_XDECREF(result); + std::cout << "createCollection success" << std::endl; + return VectorDBErrorCode::Success; +} + +std::vector VectorDB::collections() +{ + PythonThreadLocker locker; + std::vector collectionLists={}; + if (dbInstance_ == nullptr) { + return collectionLists; + } + PyObject* func = PyObject_GetAttrString(dbInstance_, "list_collections"); + if (func == NULL) { + std::cerr << "Could not find function 'list_collections'" << std::endl; + return collectionLists; + } + + PyObject* result = PyObject_CallObject(func, NULL); // 假设list_collections不接受任何参数 + Py_DECREF(func); // 减少对函数的引用计数 + collectionLists = pythonutil::convertPythonListToStringVector(result); + return collectionLists; +} + +VectorDBErrorCode VectorDB::addData(const std::string& collectionName, const std::vector& data) +{ + PythonThreadLocker locker; + std::cout << "Adding data to ChromaDB..." << std::endl; + if (dbInstance_ == nullptr) { + return VectorDBErrorCode::ValidDataBase; + } + + if (data.empty()) { + std::cerr << "Warning: Attempted to addData with an empty data vector." << std::endl; + return VectorDBErrorCode::EmptyDataVector; + } + + std::vector ids; + std::vector> embeddings; + std::vector>> metadatas; + std::vector documents; + + for (int i = 0; i < data.size(); i++) { + ids.push_back(data[i].id); + embeddings.push_back(data[i].embedding); + metadatas.push_back(data[i].metadata); + auto it = data[i].metadata.find("source"); + if (it != data[i].metadata.end()) { + documents.push_back(std::get(it->second)); + } else { + documents.emplace_back(); + } + } + + PyObject* pyIds = pythonutil::convertVectorStringToPyObject(ids); + PyObject* pyEmbeddings = pythonutil::convertVectorVectorDoubleToPyObject(embeddings); + PyObject* pyMetadatas = pythonutil::convertVectorMapToPyObject(metadatas); + PyObject* pyDocuments = pythonutil::convertVectorStringToPyObject(documents); + + // 创建包含所有参数的元组 + PyObject* args_tuple = PyTuple_New(5); + PyTuple_SET_ITEM(args_tuple, 0, PyUnicode_FromString(collectionName.c_str())); + PyTuple_SET_ITEM(args_tuple, 1, pyIds); + PyTuple_SET_ITEM(args_tuple, 2, pyMetadatas); + PyTuple_SET_ITEM(args_tuple, 3, pyEmbeddings); + PyTuple_SET_ITEM(args_tuple, 4, pyDocuments); + + PyObject* result = PyObject_CallObject(PyObject_GetAttrString(dbInstance_, "add_data"), args_tuple); + if (result == nullptr) { + PyErr_Print(); + + Py_XDECREF(pyIds); + Py_XDECREF(pyEmbeddings); + Py_XDECREF(pyMetadatas); + Py_XDECREF(pyDocuments); + Py_XDECREF(args_tuple); + Py_XDECREF(result); + + return VectorDBErrorCode::PyObjectCallMethodError; + } + + Py_XDECREF(pyIds); + Py_XDECREF(pyEmbeddings); + Py_XDECREF(pyMetadatas); + Py_XDECREF(pyDocuments); + Py_XDECREF(args_tuple); + Py_XDECREF(result); + + return VectorDBErrorCode::Success; +} + +VectorDBErrorCode VectorDB::updateData(const std::string& collectionName, const std::vector& data) +{ + PythonThreadLocker locker; + if (dbInstance_ == nullptr) { + return VectorDBErrorCode::ValidDataBase; + } + + if (data.empty()) { + std::cerr << "Warning: Attempted to updateData with an empty data vector." << std::endl; + return VectorDBErrorCode::EmptyDataVector; + } + std::vector ids; + std::vector> embeddings; + std::vector>> metadatas; + std::vector documents; + + for (int i = 0; i < data.size(); i++) { + ids.push_back(data[i].id); + embeddings.push_back(data[i].embedding); + metadatas.push_back(data[i].metadata); + auto it = data[i].metadata.find("source"); + if (it != data[i].metadata.end()) { + documents.push_back(std::get(it->second)); + } else { + documents.emplace_back(); + } + } + + PyObject* pyIds = pythonutil::convertVectorStringToPyObject(ids); + PyObject* pyEmbeddings = pythonutil::convertVectorVectorDoubleToPyObject(embeddings); + PyObject* pyMetadatas = pythonutil::convertVectorMapToPyObject(metadatas); + PyObject* pyDocuments = pythonutil::convertVectorStringToPyObject(documents); + + // 创建包含所有参数的元组 + PyObject* args_tuple = PyTuple_New(5); + PyTuple_SET_ITEM(args_tuple, 0, PyUnicode_FromString(collectionName.c_str())); + PyTuple_SET_ITEM(args_tuple, 1, pyIds); + PyTuple_SET_ITEM(args_tuple, 2, pyMetadatas); + PyTuple_SET_ITEM(args_tuple, 3, pyEmbeddings); + PyTuple_SET_ITEM(args_tuple, 4, pyDocuments); + + PyObject* result = PyObject_CallObject(PyObject_GetAttrString(dbInstance_, "update_data"), args_tuple); + if (result == nullptr) { + PyErr_Print(); + + Py_XDECREF(pyIds); + Py_XDECREF(pyEmbeddings); + Py_XDECREF(pyMetadatas); + Py_XDECREF(pyDocuments); + Py_XDECREF(args_tuple); + Py_XDECREF(result); + + return VectorDBErrorCode::PyObjectCallMethodError; + } + + Py_XDECREF(pyIds); + Py_XDECREF(pyEmbeddings); + Py_XDECREF(pyMetadatas); + Py_XDECREF(pyDocuments); + Py_XDECREF(args_tuple); + Py_XDECREF(result); + + return VectorDBErrorCode::Success; +} + +std::vector VectorDB::getData(const std::string& collectionName, const std::vector& ids, + const std::string& source) +{ + PythonThreadLocker locker; + std::vector vectorDates={}; + if (dbInstance_ == nullptr) { + return vectorDates; + } + PyObject* pyIds = pythonutil::convertVectorStringToPyObject(ids); + + PyObject* args_tuple = PyTuple_New(3); + PyTuple_SET_ITEM(args_tuple, 0, PyUnicode_FromString(collectionName.c_str())); + PyTuple_SET_ITEM(args_tuple, 1, pyIds); + PyTuple_SET_ITEM(args_tuple, 2, PyUnicode_FromString(source.c_str())); + + PyObject* result = PyObject_CallObject(PyObject_GetAttrString(dbInstance_, "get_data"), args_tuple); + if (result == nullptr) { + // 处理错误 + PyErr_Print(); + Py_XDECREF(pyIds); + Py_XDECREF(args_tuple); + return vectorDates; + } + + Py_XDECREF(pyIds); + Py_XDECREF(args_tuple); + vectorDates = convertPyObjectToVectorVectorData(result); + return vectorDates; +} + +VectorDBErrorCode VectorDB::deleteData(const std::string& collectionName, const std::vector& sourceIds, const std::vector& ids) +{ + PythonThreadLocker locker; + if (dbInstance_ == nullptr) { + return VectorDBErrorCode::ValidDataBase; + } + PyObject* pySourceIds = pythonutil::convertVectorStringToPyObject(sourceIds); + PyObject* pyIds = pythonutil::convertVectorStringToPyObject(ids); + PyObject* args_tuple = PyTuple_New(3); + PyTuple_SET_ITEM(args_tuple, 0, PyUnicode_FromString(collectionName.c_str())); + PyTuple_SET_ITEM(args_tuple, 1, pySourceIds); + PyTuple_SET_ITEM(args_tuple, 2, pyIds); + + if (PyObject_CallObject(PyObject_GetAttrString(dbInstance_, "delete_data"), args_tuple) == nullptr) { + PyErr_Print(); + + Py_XDECREF(pySourceIds); + Py_XDECREF(pyIds); + Py_XDECREF(args_tuple); + return VectorDBErrorCode::PyObjectCallMethodError; + } + + Py_XDECREF(pySourceIds); + Py_XDECREF(pyIds); + Py_XDECREF(args_tuple); + + return VectorDBErrorCode::Success; +} + +std::vector> VectorDB::queryData(const std::string& collectionName, std::vector> vectors, + int64_t numResults) +{ + PythonThreadLocker locker; + std::vector> queryDatas={}; + if (dbInstance_ == nullptr) { + return queryDatas; + } + std::vector ids; + std::vector> embeddings; + std::vector>> metadatas; + std::vector distances; + + PyObject* pyEmbeddings = pythonutil::convertVectorVectorDoubleToPyObject(vectors); + + PyObject* args_tuple = PyTuple_New(3); + PyTuple_SET_ITEM(args_tuple, 0, PyUnicode_FromString(collectionName.c_str())); + PyTuple_SET_ITEM(args_tuple, 1, pyEmbeddings); + PyTuple_SET_ITEM(args_tuple, 2, PyLong_FromLongLong(numResults)); + + PyObject* result = PyObject_CallObject(PyObject_GetAttrString(dbInstance_, "query_data"), args_tuple); + if (result == nullptr) { + // 处理错误 + PyErr_Print(); + Py_XDECREF(pyEmbeddings); + Py_XDECREF(args_tuple); + return queryDatas; + } + + PyObject* ids_obj = PyDict_GetItemString(result, "ids"); + PyObject* metadatas_obj = PyDict_GetItemString(result, "metadatas"); + PyObject* distances_obj = PyDict_GetItemString(result, "distances"); + + ids = pythonutil::convertPythonListToStringVector(ids_obj); + metadatas = pythonutil::convertPythonListToMapOfVariant(metadatas_obj); + distances = pythonutil::convertPythonListToVectorFloat(distances_obj); + + for (size_t i = 0; i < ids.size(); ++i) { + std::pair queryData; + VectorData data; + float distance; + + data.id = ids[i]; + data.metadata = metadatas[i]; + distance = distances[i]; + queryData.first = data; + queryData.second = distance; + queryDatas.push_back(queryData); + } + + Py_XDECREF(pyEmbeddings); + Py_XDECREF(args_tuple); + return queryDatas; +} + +std::vector VectorDB::convertPyObjectToVectorVectorData(PyObject* pyObject) +{ + std::vector vectorDats; + std::vector ids; + std::vector> embeddings; + std::vector>> metadatas; + + PyObject* ids_obj = PyDict_GetItemString(pyObject, "ids"); + PyObject* embeddings_obj = PyDict_GetItemString(pyObject, "embeddings"); + PyObject* metadatas_obj = PyDict_GetItemString(pyObject, "metadatas"); + + ids = pythonutil::convertPythonListToStringVector(ids_obj); + embeddings = pythonutil::convertPythonListToVectorOfVectors(embeddings_obj); + metadatas = pythonutil::convertPythonListToMapOfVariant(metadatas_obj); + + if (ids.size() == embeddings.size() && ids.size() == metadatas.size()) { + for (size_t i = 0; i < ids.size(); ++i) { + VectorData data; + data.id = ids[i]; + data.embedding = embeddings[i]; + data.metadata = metadatas[i]; + + vectorDats.push_back(data); + } + } else { + std::cout << "ids, embeddings, 和 metadatas的大小不匹配 " << std::endl; + } + + return vectorDats; +} \ No newline at end of file diff --git a/src/utils/vectordb/vectordb.h b/src/utils/vectordb/vectordb.h new file mode 100644 index 0000000000000000000000000000000000000000..457af1dad23de367f2a0d53f254eb537232c14e8 --- /dev/null +++ b/src/utils/vectordb/vectordb.h @@ -0,0 +1,55 @@ +#ifndef VECTORDB_H +#define VECTORDB_H + +#include +#include +#include +#include +#include + +enum class VectorDBErrorCode +{ + Success, + ArgumentError, + PyObjectCallMethodError, + ValidDataBase, + EmptyDataVector +}; + +struct VectorData{ + std::string id; + std::vector embedding = {}; + std::map> metadata; +}; +class VectorDB +{ +public: + explicit VectorDB(); + // explicit VectorDB(const std::string& databasePath); + ~VectorDB() = default; + std::string engineName() const {return "ChromaDB";} + + VectorDBErrorCode createClient(const std::string& databasePath); + VectorDBErrorCode destroyClient(); + VectorDBErrorCode createCollection(const std::string& collectionName, const std::string& searchAlgorithm); + std::vector collections(); + VectorDBErrorCode addData(const std::string& collectionName, const std::vector& data); + VectorDBErrorCode updateData(const std::string& collectionName, const std::vector& data); + + std::vector getData(const std::string& collectionName, + const std::vector& ids={}, const std::string& source=""); + + VectorDBErrorCode deleteData(const std::string& collectionName, + const std::vector& sourceIds={}, const std::vector& ids={}); + + std::vector> queryData(const std::string& collectionName, + std::vector> embeddings, + int64_t numResults = 10); +private: + void initializeChromaDB(); + std::vector convertPyObjectToVectorVectorData(PyObject* pyObject); + + PyObject* dbInstance_ = nullptr; +}; + +#endif // VECTORDB_H