diff --git a/CMakeLists.txt b/CMakeLists.txt index 91ffb3c6aca4e9e1b413a77eeebb44b0ddf4a107..18e3d76cbea5db91bc75a6dc8ffe9925a053d16b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,15 @@ add_executable(kylin-ai-business-framework-service src/main.cpp src/embeddingtaskmanager/embeddingtaskmanager.h src/embeddingtaskmanager/embeddingtaskmanager.cpp src/embeddingtaskmanager/imageembeddingservice.h src/embeddingtaskmanager/imageembeddingservice.cpp src/embeddingtaskmanager/textembeddingservice.h src/embeddingtaskmanager/textembeddingservice.cpp - src/util/python/pythonthreadlocker.h) + src/util/python/pythonthreadlocker.h + src/datamanagement/datamanagementjsonhelper.cpp + src/datamanagement/datamanagementjsonhelper.h + src/datamanagement/datamanagementdatabase.cpp + src/datamanagement/datamanagementdatabase.h + src/datamanagement/datamanagementservice.cpp + src/datamanagement/datamanagementservice.h + src/datamanagement/segmenttokenizer.cpp + src/datamanagement/segmenttokenizer.h) find_package(PkgConfig REQUIRED) find_package(Python 3.8 COMPONENTS Development REQUIRED) @@ -47,6 +55,8 @@ pkg_check_modules(poppler REQUIRED IMPORTED_TARGET poppler-cpp) pkg_check_modules(opencv REQUIRED IMPORTED_TARGET opencv4) pkg_check_modules(onnxruntime REQUIRED IMPORTED_TARGET libonnxruntime) pkg_check_modules(GIO REQUIRED IMPORTED_TARGET gio-unix-2.0) +pkg_check_modules(vectordb REQUIRED IMPORTED_TARGET vectordb-engine) +pkg_check_modules(jsoncpp REQUIRED IMPORTED_TARGET jsoncpp) include_directories( ./src @@ -58,9 +68,11 @@ target_link_libraries(kylin-ai-business-framework-service ${Python_LIBRARIES} PkgConfig::GIO pthread + PkgConfig::jsoncpp PkgConfig::poppler PkgConfig::opencv PkgConfig::onnxruntime + PkgConfig::vectordb ) include(GNUInstallDirs) diff --git a/src/datamanagement/datamanagementdatabase.cpp b/src/datamanagement/datamanagementdatabase.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5732d65006ebae2c0b868caa34f9bc92f34fcb6 --- 
/dev/null +++ b/src/datamanagement/datamanagementdatabase.cpp @@ -0,0 +1,290 @@ +#include "datamanagementdatabase.h" + +#include +#include +#include +#include + +static const char* VECTOR_DATABASE_PATH = "/usr/share/kylin-ai-runtime/datamanagement/database/"; +static const char* VECTOR_DATABASE_NAME = "search"; +static const char* FILE_INFO_COLLECTION_NAME = "files-info"; +static const char* VISION_FILE_CONTENT_COLLECTION_NAME = "vision-files-content-vector"; +static const char* TEXT_FILE_CONTENT_COLLECTION_NAME = "text-files-content-vector"; + +namespace { + +std::string generateUUID() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, 15); + + std::stringstream ss; + ss << std::hex; + + for (int i = 0; i < 32; ++i) { + if (i == 8 || i == 12 || i == 16 || i == 20) { + ss << "-"; + } + ss << dis(gen); + } + + return ss.str(); +} + +std::vector convertFloatToDouble(const std::vector& floatVec) { + std::vector doubleVec; + doubleVec.reserve(floatVec.size()); // 预留空间以避免不必要的重新分配 + + for (float floatValue : floatVec) { + doubleVec.push_back(static_cast(floatValue)); + } + + return doubleVec; +} + +} + +DataManagementDatabase::~DataManagementDatabase() +{ + if (connected_) + database_.destroyClient(); +} + +bool DataManagementDatabase::createClient(uid_t uid) +{ + std::string databasePath = VECTOR_DATABASE_PATH + std::to_string(uid) + "/" + VECTOR_DATABASE_NAME; + + if (database_.createClient(databasePath) != VectorDBErrorCode::Success) { + return false; + } + + connected_ = true; + return true; +} + +void DataManagementDatabase::destroyClient() +{ + connected_ = false; + database_.destroyClient(); +} + +SimilaritySearchResult DataManagementDatabase::visionSearch(const std::vector &vector) +{ + return similaritySearch(VISION_FILE_CONTENT_COLLECTION_NAME, vector); +} + +SimilaritySearchResult DataManagementDatabase::textSearch(const std::vector &vector) +{ + return similaritySearch(TEXT_FILE_CONTENT_COLLECTION_NAME, vector); +} + 
+SimilaritySearchResult DataManagementDatabase::similaritySearch( + const std::string &collection, const std::vector& vector) +{ + if (vector.empty() || collection.empty() || !connected_) { + return std::vector>(); + } + + std::vector collections = {FILE_INFO_COLLECTION_NAME}; + collections.push_back(collection); + + if (!hasCollections(collections)) { + return std::vector>(); + } + std::vector> vectors; + vectors.emplace_back(convertFloatToDouble(vector)); + + auto searchResults = database_.queryData(collection, vectors, 50); + + std::vector> result; + // 通过相似度搜到得到的sourceId,在另一个collection中查找对应的文件 + for (const std::pair& searchResult : searchResults) { + VectorData data = searchResult.first; + double similarity = 1 - searchResult.second; // 目前向量库返回的是距离,我们要转换成相似度 + std::string sourceId = std::get(data.metadata["sourceId"]); + std::vector ids = { sourceId }; + + auto filedatas = database_.getData(FILE_INFO_COLLECTION_NAME, ids); + if (filedatas.empty()) + continue; + std::string filepath = std::get(filedatas[0].metadata["source"]); + + result.emplace_back(std::move(filepath), std::move(similarity)); + } + + return result; +} + +void DataManagementDatabase::addImageDatas(const std::vector &fileinfos) +{ + addDatas(fileinfos, VISION_FILE_CONTENT_COLLECTION_NAME); +} + +void DataManagementDatabase::addTextDatas(const std::vector &fileinfos) +{ + addDatas(fileinfos, TEXT_FILE_CONTENT_COLLECTION_NAME); +} + +void DataManagementDatabase::addDatas(const std::vector &fileinfos, const std::string& collection) +{ + if (!connected_) { + std::cerr << "Please create client first" << std::endl; + return; + } + + std::vector fileDatas; + std::vector vectorDatas; + + for (const auto& fileinfo : fileinfos) { + if (fileinfo.empty()) { + std::cerr << "Add file " << fileinfo.filepath << " to database failed," + << "file vector is empty" << std::endl; + continue; + } + + VectorData fileData = getFileDataFromFileInfo(fileinfo); + std::vector vectorData = 
getVectorDatasFromFileInfo(fileinfo, fileData.id); + std::cout << "Add file: " << fileinfo.filepath + << " file id: " << fileData.id << std::endl; + + fileDatas.emplace_back(std::move(fileData)); + vectorDatas.insert(vectorDatas.end(), + std::make_move_iterator(vectorData.begin()), + std::make_move_iterator(vectorData.end())); + } + + if (fileDatas.empty() || vectorDatas.empty()) { + return; + } + + const std::vector collections = + {FILE_INFO_COLLECTION_NAME, collection}; + checkAndCreateCollections(collections); + + // 分别向两个collection添加数据,若有一方添加失败则撤回操作 + auto filesResult = database_.addData(FILE_INFO_COLLECTION_NAME, fileDatas); + if (filesResult != VectorDBErrorCode::Success) { + std::cerr << "Add file data failed." << std::endl; + return; + } + auto vectorResult = database_.addData(collection, vectorDatas); + if (vectorResult != VectorDBErrorCode::Success) { + std::cerr << "Add vector data failed." << std::endl; + std::vector ids; + for (const auto& data : fileDatas) { + ids.push_back(data.id); + }; + database_.deleteData(FILE_INFO_COLLECTION_NAME, {}, ids); + } +} + +GetAllFileInfosResult DataManagementDatabase::getAllFileInfos() +{ + GetAllFileInfosResult result; + if (!connected_) { + std::cerr << "Please create client first" << std::endl; + return result; + } + if (!hasCollections({FILE_INFO_COLLECTION_NAME})) { + std::cerr << "don't have " << FILE_INFO_COLLECTION_NAME << " collection, " + << "can't get file infos" << std::endl; + return result; + } + + std::vector datas = database_.getData(FILE_INFO_COLLECTION_NAME); + for (auto& data : datas) { + result.emplace_back(std::move(std::get(data.metadata["source"])), + std::move(std::get(data.metadata["modifytime"]))); + } + + return result; +} + +void DataManagementDatabase::deleteFiles(const std::vector &files) +{ + if (!connected_) { + std::cerr << "Please create client first" << std::endl; + return; + } + if (!hasCollections({FILE_INFO_COLLECTION_NAME})) { + std::cerr << "don't have " << 
FILE_INFO_COLLECTION_NAME << " collection, " + << "can't get file infos" << std::endl; + return; + } + + for (const auto& file : files) { + std::cout << "delete file: " << file << std::endl; + } + + std::vector deletedSourceIds; + for (const auto& file : files) { + std::vector res = database_.getData(FILE_INFO_COLLECTION_NAME, {}, file); + if (res.empty()) { + std::cerr << "Get file: " << file << " data, failed." << std::endl; + continue; + } + for (const auto& filedata : res) { + deletedSourceIds.emplace_back(filedata.id); + } + } + if (deletedSourceIds.empty()) { + std::cerr << "Can't get any data from files, cancel delete." << std::endl; + return; + } + + database_.deleteData(FILE_INFO_COLLECTION_NAME, {}, deletedSourceIds); + database_.deleteData(VISION_FILE_CONTENT_COLLECTION_NAME, deletedSourceIds); + database_.deleteData(TEXT_FILE_CONTENT_COLLECTION_NAME, deletedSourceIds); +} + +bool DataManagementDatabase::hasCollections(const std::vector& collections) +{ + if (collections.empty()) + return false; + + std::vector collectionList = database_.collections(); + + bool hasCollections = std::all_of(collections.begin(), collections.end(), + [&collectionList](const std::string& str) { + auto it = std::find(collectionList.begin(), collectionList.end(), str); + bool hasCollection = it != collectionList.end(); + return hasCollection; + }); + + return hasCollections; +} + +void DataManagementDatabase::checkAndCreateCollections(const std::vector &collections) +{ + std::vector collectionList = database_.collections(); + for (const auto& collection : collections) { + auto it = std::find(collectionList.begin(), collectionList.end(), collection); + if (it == collectionList.end()) { + database_.createCollection(collection, "cosine"); + } + } +} + +VectorData DataManagementDatabase::getFileDataFromFileInfo(const DBFileInfo &fileinfo) const +{ + VectorData fileData; + fileData.id = generateUUID(); + fileData.metadata.insert({"source", std::string(fileinfo.filepath)}); + 
fileData.metadata.insert({"modifytime", int(fileinfo.modifyTime)}); + + return fileData; +} + +std::vector DataManagementDatabase::getVectorDatasFromFileInfo(const DBFileInfo &fileinfo, const std::string& sourceId) const +{ + std::vector datas; + for (auto const& vector : fileinfo.vectors) { + VectorData vectorData; + vectorData.id = generateUUID(); + vectorData.embedding = convertFloatToDouble(vector); + vectorData.metadata.insert({"sourceId", sourceId}); + datas.emplace_back(std::move(vectorData)); + } + + return datas; +} diff --git a/src/datamanagement/datamanagementdatabase.h b/src/datamanagement/datamanagementdatabase.h new file mode 100644 index 0000000000000000000000000000000000000000..ba424a305fd192d6fd16b6333563b19256ae1142 --- /dev/null +++ b/src/datamanagement/datamanagementdatabase.h @@ -0,0 +1,58 @@ +#ifndef DATAMANAGEMENTDATABASE_H +#define DATAMANAGEMENTDATABASE_H + +#include + +// pair.first: filepath, pair.second: similarity +using SimilaritySearchResult = std::vector>; + +// pair.first: filepath, pair.second: modify time +using GetAllFileInfosResult = std::vector>; + +struct DBFileInfo { + DBFileInfo(const std::string& filepath, + time_t modifyTime, + const std::vector>& vectors) + : filepath(filepath) + , modifyTime(modifyTime) + , vectors(vectors) {} + + bool empty() const { + return (vectors.empty()); + } + + std::string filepath; + time_t modifyTime; + std::vector> vectors; +}; + +class DataManagementDatabase +{ +public: + DataManagementDatabase() = default; + ~DataManagementDatabase(); + bool createClient(uid_t uid); + void destroyClient(); + + SimilaritySearchResult visionSearch(const std::vector &vector); + SimilaritySearchResult textSearch(const std::vector &vector); + void addImageDatas(const std::vector &fileinfos); + void addTextDatas(const std::vector& fileinfos); + GetAllFileInfosResult getAllFileInfos(); + void deleteFiles(const std::vector &files); + +private: + SimilaritySearchResult similaritySearch(const std::string& collection, 
const std::vector &vector);
+    void addDatas(const std::vector& fileinfos, const std::string& collection);
+
+    bool hasCollections(const std::vector& collections);
+    void checkAndCreateCollections(const std::vector& collections);
+    VectorData getFileDataFromFileInfo(const DBFileInfo& fileinfo) const;
+    std::vector getVectorDatasFromFileInfo(const DBFileInfo& fileinfo, const std::string &sourceId) const;
+
+private:
+    VectorDB database_;
+    bool connected_ = false;
+};
+
+#endif // DATAMANAGEMENTDATABASE_H
diff --git a/src/datamanagement/datamanagementjsonhelper.cpp b/src/datamanagement/datamanagementjsonhelper.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..92e64ebba0c56b01971d24ff3bc3edcdb10a6be2
--- /dev/null
+++ b/src/datamanagement/datamanagementjsonhelper.cpp
@@ -0,0 +1,159 @@
+#include "datamanagementjsonhelper.h"
+
+#include
+
+#include
+#include
+#include
+
+static const double DEFAULT_SIMILARITY_THRESHOLD = 0.3;
+
+// Parses the search-conditions JSON document into {query text, similarity threshold}.
+// The text stays empty when parsing fails, when the root is not a JSON object, or when
+// "text" is absent or not convertible to a string. The threshold is 0 when parsing fails
+// and DEFAULT_SIMILARITY_THRESHOLD when "similarity-threshold" is absent or not numeric.
+// NOTE(review): the pair's template arguments were reconstructed from the asString() /
+// asDouble() assignments below - confirm against datamanagementjsonhelper.h.
+std::pair<std::string, double> DataManagementJsonHelper::parserSearchConditions(const std::string &searchConditions)
+{
+    std::pair<std::string, double> result;
+
+    Json::CharReaderBuilder builder;
+    // Hold the reader in a unique_ptr so it is released on every exit path,
+    // including an exception escaping parse().
+    const std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
+    Json::Value root;
+    std::string errors;
+
+    const char* const begin = searchConditions.data();
+    const bool parsingSuccessful = reader->parse(begin, begin + searchConditions.size(), &root, &errors);
+    if (!parsingSuccessful) {
+        std::cerr << "Failed to parse JSON: " << errors << std::endl;
+        return result;
+    }
+
+    result.second = DEFAULT_SIMILARITY_THRESHOLD;
+
+    // jsoncpp only allows isMember() on object (or null) values; a request whose root is
+    // an array or a scalar is rejected here instead of tripping that precondition.
+    if (!root.isObject()) {
+        return result;
+    }
+
+    // asString()/asDouble() fail on non-convertible values (e.g. nested objects), so only
+    // read members whose type can actually be converted.
+    if (root.isMember("text") && root["text"].isConvertibleTo(Json::stringValue)) {
+        result.first = root["text"].asString();
+    }
+
+    if (root.isMember("similarity-threshold")
+            && root["similarity-threshold"].isConvertibleTo(Json::realValue)) {
+        result.second = root["similarity-threshold"].asDouble();
+    }
+
+    return result;
+}
+
+// Serialises search hits as a JSON array of {"filepath", "similarity"} objects.
+std::string DataManagementJsonHelper::convertSearchResultToJson(
+        const std::vector<std::pair<std::string, double>> &searchResult)
+{
+    // Start from an explicit array so that an empty hit list is written as [] rather
+    // than null, matching convertFileInfosResultToJson().
+    Json::Value root(Json::arrayValue);
+
+    // Walk the result list and convert each entry to JSON
+    for (const auto& result : searchResult) {
+        Json::Value item;
+        item["filepath"] = 
result.first; + item["similarity"] = result.second; + + root.append(item); + } + + Json::StreamWriterBuilder builder; + builder["indentation"] = ""; + std::ostringstream oss; + std::unique_ptr writer(builder.newStreamWriter()); + writer->write(root, &oss); + + std::string jsonString = oss.str(); + + return jsonString; +} + +std::vector> DataManagementJsonHelper::parserFilePathAndFormat(const std::string &fileinfos) +{ + std::vector> result; + + // 解析 JSON 字符串 + Json::CharReaderBuilder builder; + Json::Value root; + JSONCPP_STRING err; + std::istringstream jsonStream(fileinfos); + + if (!Json::parseFromStream(builder, jsonStream, &root, &err)) { + std::cerr << "Error parsing JSON: " << err << std::endl; + return result; + } + + if (root.isArray()) { + for (const auto& item : root) { + if (item.isObject() && item.isMember("filepath") && item.isMember("fileformat")) { + std::string filepath = item["filepath"].asString(); + std::string fileformat = item["fileformat"].asString(); + result.emplace_back(filepath, fileformat); + } + } + } + + return result; +} + +std::vector DataManagementJsonHelper::parserDeletedFileInfos(const std::string &fileinfos) +{ + std::vector result; + // 解析 JSON 字符串 + Json::CharReaderBuilder builder; + Json::Value root; + JSONCPP_STRING err; + std::istringstream jsonStream(fileinfos); + + if (!Json::parseFromStream(builder, jsonStream, &root, &err)) { + std::cerr << "Error parsing JSON: " << err << std::endl; + return result; + } + + if (!root.isArray()) { + std::cerr << "Delete file input json " << fileinfos + << " is not array, can't parser" << std::endl; + return result; + } + + for (const auto& item : root) { + if (item.isObject() && item.isMember("filepath")) { + std::string filepath = item["filepath"].asString(); + result.emplace_back(std::move(filepath)); + } + } + + return result; +} + +std::string DataManagementJsonHelper::convertFileInfosResultToJson(const std::vector> &fileinfos) +{ + Json::Value root(Json::arrayValue); // 
直接初始化一个数组类型的根节点 + + for (const auto &info : fileinfos) { + Json::Value file; + file["filepath"] = info.first; + file["timestamp"] = info.second; + root.append(file); // 直接向根节点添加文件信息对象 + } + + Json::StreamWriterBuilder builder; + builder["indentation"] = ""; + std::string jsonString = Json::writeString(builder, root); + + return jsonString; +} diff --git a/src/datamanagement/datamanagementjsonhelper.h b/src/datamanagement/datamanagementjsonhelper.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5765e093179cf0471c8443ac2ce94b8fd17350 --- /dev/null +++ b/src/datamanagement/datamanagementjsonhelper.h @@ -0,0 +1,17 @@ +#ifndef DATAMANAGEMENTJSONHELPER_H +#define DATAMANAGEMENTJSONHELPER_H + +#include +#include + +class DataManagementJsonHelper +{ +public: + static std::pair parserSearchConditions(const std::string& searchConditionsJson); + static std::vector> parserFilePathAndFormat(const std::string& fileinfos); + static std::vector parserDeletedFileInfos(const std::string& fileinfos); + static std::string convertSearchResultToJson(const std::vector>& searchResult); + static std::string convertFileInfosResultToJson(const std::vector>& fileinfos); +}; + +#endif // DATAMANAGEMENTJSONHELPER_H diff --git a/src/datamanagement/datamanagementprocessor.cpp b/src/datamanagement/datamanagementprocessor.cpp index 7591d9def2441995159446305ae3dc9965162fb4..fdcd7b7cbc34cca0e011ac670d465755afd969e3 100644 --- a/src/datamanagement/datamanagementprocessor.cpp +++ b/src/datamanagement/datamanagementprocessor.cpp @@ -1,5 +1,6 @@ // #include "datamanagement/datamanagementtaskmanager.h" #include "datamanagement/datamanagementprocessor.h" +#include "datamanagement/datamanagementservice.h" #include @@ -83,8 +84,7 @@ bool DataManagementProcessor::handleSimilaritySearch(AisdkDataManagementProcesso delegate, invocation, searchResult.c_str(), errorCode); }; - // DataManagementTaskManager::getInstance()->similaritySearch( - // searchConditions, uid, callback); + 
DataManagementService::getInstance()->similaritySearch(searchConditions, uid, callback); return true; } @@ -100,8 +100,7 @@ bool DataManagementProcessor::handleAddTextFiles(AisdkDataManagementProcessor *d delegate, invocation, errorCode); }; - // DataManagementTaskManager::getInstance()->addTextFiles( - // fileinfos, uid, callback); + DataManagementService::getInstance()->addTextFiles(fileinfos, uid, callback); return true; } @@ -117,8 +116,7 @@ bool DataManagementProcessor::handleAddImageFiles(AisdkDataManagementProcessor * delegate, invocation, errorCode); }; - // DataManagementTaskManager::getInstance()->addImageFiles( - // fileinfos, uid, callback); + DataManagementService::getInstance()->addImageFiles(fileinfos, uid, callback); return true; } @@ -133,8 +131,7 @@ bool DataManagementProcessor::handleDeleteFiles(AisdkDataManagementProcessor *de aisdk_data_management_processor_complete_delete_files( delegate, invocation, errorCode); }; - // DataManagementTaskManager::getInstance()->deleteFiles( - // fileinfos, uid, callback); + DataManagementService::getInstance()->deleteFiles(fileinfos, uid, callback); return true; } @@ -150,8 +147,7 @@ bool DataManagementProcessor::handleUpdateFilesName(AisdkDataManagementProcessor delegate, invocation, errorCode); }; - // DataManagementTaskManager::getInstance()->updateFilesName( - // fileinfos, uid, callback); + DataManagementService::getInstance()->updateFilesName(fileinfos, uid, callback); return true; } @@ -167,8 +163,7 @@ bool DataManagementProcessor::handleUpdateFilesContent(AisdkDataManagementProces delegate, invocation, errorCode); }; - // DataManagementTaskManager::getInstance()->updateFilesContent( - // fileinfos, uid, callback); + DataManagementService::getInstance()->updateFilesContent(fileinfos, uid, callback); return true; } @@ -184,7 +179,7 @@ bool DataManagementProcessor::handleGetAllFileinfos(AisdkDataManagementProcessor delegate, invocation, fileinfos.c_str(), errorCode); }; - // 
DataManagementTaskManager::getInstance()->getAllFileInfos(uid,callback); + DataManagementService::getInstance()->getAllFileInfos(uid,callback); return true; } diff --git a/src/datamanagement/datamanagementprocessor.h b/src/datamanagement/datamanagementprocessor.h index 1c2541ba08294f24702593f50d2689345e91b46c..1750bcad4aec1939ec803100053b433f3001f7c1 100644 --- a/src/datamanagement/datamanagementprocessor.h +++ b/src/datamanagement/datamanagementprocessor.h @@ -57,7 +57,6 @@ private: bool isExported_ = false; GDBusConnection &connection_; static const std::string objectPath_; - }; #endif // DATAMANAGEMENTPROCESSOR_H diff --git a/src/datamanagement/datamanagementservice.cpp b/src/datamanagement/datamanagementservice.cpp new file mode 100644 index 0000000000000000000000000000000000000000..019292d048653d5c92e3f524cd29e733fda36760 --- /dev/null +++ b/src/datamanagement/datamanagementservice.cpp @@ -0,0 +1,461 @@ +#include "datamanagementservice.h" +#include "thirdparty/threadpool/async.h" +#include "datamanagement/datamanagementjsonhelper.h" +#include "../util/parser/parser.h" +#include "../util/parser/fileparserfactory.h" +#include "datamanagement/segmenttokenizer.h" +#include "../embeddingtaskmanager/embeddingtaskmanager.h" +#include "../embeddingtaskmanager/embeddingtask.h" + +#include + +#include +#include + +static const double TEXT_SEARCH_THRESHOLD = 0.8; +static const double VISION_SEARCH_THRESHOLD = 0.4; + +DataManagementService::DataManagementService() + : embeddingTaskManager_(EmbeddingTaskManager::getInstance()) + {} + +void DataManagementService::similaritySearch(const std::string& searchConditions, uid_t uid, SimilaritySearchCallback callback) +{ + if (searchConditions.empty()) { + callback(std::string(), DATA_MANAGEMENT_PARAM_ERROR); + return; + } + // cpr::async([searchConditions, uid, callback, this]() { + int errorCode = 0; + + std::string result = doSimilaritySearch(searchConditions, uid, errorCode); + callback(result, errorCode); + // }); +} + +void 
DataManagementService::addTextFiles(const std::string& fileinfos, uid_t uid, AddTextFilesCallback callback)
+{
+    // Reject an empty request before doing any work.
+    if (fileinfos.empty()) {
+        callback(DATA_MANAGEMENT_PARAM_ERROR);
+        return;
+    }
+    // The asynchronous dispatch is disabled for now, so the work runs inline.
+    // cpr::async([fileinfos, uid, callback, this]() {
+        int errorCode = 0;
+
+        doAddTextFiles(fileinfos, uid, errorCode);
+        callback(errorCode);
+    // });
+}
+
+// Adds the image files described by the `fileinfos` JSON and reports the resulting
+// error code (0 when doAddImageFiles() did not set one) through `callback`.
+void DataManagementService::addImageFiles(const std::string& fileinfos, uid_t uid, AddImageFilesCallback callback)
+{
+    if (fileinfos.empty()) {
+        callback(DATA_MANAGEMENT_PARAM_ERROR);
+        return;
+    }
+
+    // The asynchronous dispatch is disabled for now, so the work runs inline.
+    // cpr::async([fileinfos, uid, callback, this]() {
+        int errorCode = 0;
+
+        doAddImageFiles(fileinfos, uid, errorCode);
+        callback(errorCode);
+    // });
+}
+
+// Deletes the files listed in the `fileinfos` JSON and reports the resulting error
+// code (0 when doDeleteFiles() did not set one) through `callback`.
+void DataManagementService::deleteFiles(const std::string& fileinfos, uid_t uid, DeleteFilesCallback callback)
+{
+    if (fileinfos.empty()) {
+        callback(DATA_MANAGEMENT_PARAM_ERROR);
+        return;
+    }
+
+    // The asynchronous dispatch is disabled for now, so the work runs inline.
+    // cpr::async([fileinfos, uid, callback, this]() {
+        int errorCode = 0;
+        doDeleteFiles(fileinfos, uid, errorCode);
+        callback(errorCode);
+    // });
+}
+
+// Renaming indexed files is not implemented yet. The callback is still invoked exactly
+// once: it is the only way the result is reported back to the requester, so returning
+// without calling it would leave the request unanswered.
+// NOTE(review): assumes UpdateFilesNameCallback takes a single error code like the other
+// single-result callbacks used in this file - confirm against datamanagementservice.h.
+void DataManagementService::updateFilesName(const std::string& fileinfos, uid_t uid, UpdateFilesNameCallback callback)
+{
+    if (fileinfos.empty()) {
+        callback(DATA_MANAGEMENT_PARAM_ERROR);
+        return;
+    }
+
+    // Report a failure rather than success so callers do not assume the rename happened.
+    std::cerr << "update files name is not supported yet" << std::endl;
+    callback(DATA_MANAGEMENT_UNKNOWN_ERROR);
+}
+
+// Re-indexes the files listed in the `fileinfos` JSON and reports the resulting error
+// code (0 when doUpdateFilesContent() did not set one) through `callback`.
+void DataManagementService::updateFilesContent(const std::string& fileinfos, uid_t uid, UpdateFilesContentCallback callback)
+{
+    if (fileinfos.empty()) {
+        callback(DATA_MANAGEMENT_PARAM_ERROR);
+        return;
+    }
+
+    // The asynchronous dispatch is disabled for now, so the work runs inline.
+    // cpr::async([fileinfos, uid, callback, this]() {
+        int errorCode = 0;
+        doUpdateFilesContent(fileinfos, uid, errorCode);
+        callback(errorCode);
+    // });
+}
+
+// Fetches every stored file info for `uid` as JSON and hands it, together with the
+// error code, to `callback`.
+void DataManagementService::getAllFileInfos(uid_t uid, GetAllFileinfosCallback callback)
+{
+    int errorCode = 0;
+
+    std::string result = doGetAllFileInfos(uid, errorCode);
+    callback(result, errorCode);
+}
+
+SimilaritySearchResult DataManagementService::mergeSearchResults(
+        const SimilaritySearchResult& visionSearchResult, const SimilaritySearchResult& textSearchResult) const {
+
+    // Merge the two result lists; the database returns its results already sorted, so no re-sorting is needed
+    
std::vector> mergedResult; + mergedResult.insert(mergedResult.end(), visionSearchResult.begin(), visionSearchResult.end()); + mergedResult.insert(mergedResult.end(), textSearchResult.begin(), textSearchResult.end()); + + return mergedResult; +} + +void DataManagementService::filtSearchResults(SimilaritySearchResult &searchResult, double threshold) const +{ + // 丢弃文件相似度小于给定阈值的结果 + searchResult.erase( + std::remove_if(searchResult.begin(), searchResult.end(), + [threshold](const std::pair& result) { + return result.second < threshold; + }), + searchResult.end()); +} + +SimilaritySearchResult DataManagementService::deduplicatedSearchResults(const SimilaritySearchResult &textSearchResult) const +{ + std::unordered_map maxValues; + + for (const auto& pair : textSearchResult) { + auto it = maxValues.find(pair.first); + if (it != maxValues.end()) { + it->second = std::max(it->second, pair.second); + } else { + maxValues.insert(pair); + } + } + + std::vector> result; + for (const auto& pair : maxValues) { + result.push_back(pair); + } + std::sort(result.begin(), result.end(), [](const auto& a, const auto& b) { + return a.second > b.second; + }); + + return result; +} + +FileInfo DataManagementService::parserTextFileInfo( + const std::string &filepath, const std::string &format) const +{ + std::unique_ptr parser = FileParserFactory::createParser(format); + if (!parser) { + std::cerr << "can't create parser for " << format << std::endl; + return FileInfo(); + } + + return parser->parse(filepath); +} + +void DataManagementService::truncateTexts(std::vector &texts, int maxLength) const +{ + if (texts.size() > maxLength) { + int start = (texts.size() - maxLength) / 2; + texts.erase(texts.begin(), texts.begin() + start); // 删除前面多余的部分 + texts.erase(texts.begin() + maxLength, texts.end()); // 删除后面多余的部分 + } +} + +bool DataManagementService::isSupportedImageFormat(const std::string& format) const +{ + const std::vector supportFormat = {"png", "jpg", "jpeg", "jpe", "bmp", "dib"}; + 
bool isSupported = std::find(supportFormat.begin(), supportFormat.end(), format) != supportFormat.end(); + + return isSupported; +} + +bool DataManagementService::isSupportedTextFormat(const std::string& format) const +{ + const std::vector supportFormat = {"txt", "pdf", "docx", "pptx"}; + bool isSupported = std::find(supportFormat.begin(), supportFormat.end(), format) != supportFormat.end(); + + return isSupported; +} + +std::vector DataManagementService::makeDatasByImageFilePathAndFormat(const std::vector> &fileinfos) +{ + std::vector infos; + for (const std::pair& fileinfo : fileinfos) { + // 跳过不支持格式的文件 + std::string filepath = fileinfo.first; + std::string fileformat = fileinfo.second; + std::cout << "start handle file: " << filepath << " format: " << fileformat << std::endl; + + if (!isSupportedImageFormat(fileformat)) { + std::cerr << "format: " << fileformat << " not support" << std::endl; + continue; + } + if (!std::filesystem::exists(filepath)) { + std::cerr << "file: " << filepath << " not exit" << std::endl; + continue; + } + + std::vector vector = imageEmbedderImage(filepath); + if (vector.empty()) { + std::cerr << "file: " << filepath << " embedding error" << std::endl; + continue; + } + time_t modifyTime = Parser::modifyTime(filepath); + DBFileInfo info(filepath, modifyTime, {vector}); + infos.emplace_back(std::move(info)); + } + return infos; +} + +std::vector DataManagementService::makeDatasByTextFilePathAndFormat(const std::vector > &fileinfos) +{ + std::vector infos; + for (const auto& fileinfo : fileinfos) { + std::string path = fileinfo.first; + std::string format = fileinfo.second; + std::cout << "start handle file: " << path << " format: " << format << std::endl; + + if (!isSupportedTextFormat(format)) { + std::cerr << "format: " << format << " not support" << std::endl; + continue; + } + if (!std::filesystem::exists(path)) { + std::cerr << "file: " << path << " not exit" << std::endl; + continue; + } + FileInfo file = parserTextFileInfo(path, 
format); + if (file.contents.empty()) { + std::cerr << "parser file " << path << " to text failed." << std::endl; + continue; + } + + SegmentTokenizer::segmentTokenize(file.contents, 100); + std::cout << "start encode text file " << path << std::endl; + std::vector> vectors = textEmbedderTexts(file.contents); + std::cout << "finish encode text file " << path << std::endl; + + if (vectors.empty()) { + std::cerr << "file: " << path << " embedding error" << std::endl; + continue; + } + DBFileInfo info(path, file.modifyTime, vectors); + infos.emplace_back(std::move(info)); + } + + return infos; +} + +std::string DataManagementService::doSimilaritySearch( + const std::string &searchConditions, uid_t uid, int &errorCode) +{ + errorCode = 0; + std::pair condition = DataManagementJsonHelper::parserSearchConditions(searchConditions); + if (condition.first.empty()) { + errorCode = DATA_MANAGEMENT_PARAM_ERROR; + return std::string(); + } + + std::vector visionSideEmbedding = imageEmbedderText(condition.first); + std::vector textSideEmbedding = textEmbedderText(condition.first); + if (visionSideEmbedding.empty() && textSideEmbedding.empty()) { + errorCode = DATA_MANAGEMENT_INVAILD_PATH; + return std::string(); + } + + if (!database_.createClient(uid)) { + errorCode = DATA_MANAGEMENT_CANT_OPEN_DB; + return std::string(); + } + auto visionSearchResult = database_.visionSearch(visionSideEmbedding); + auto textSearchResult = database_.textSearch(textSideEmbedding); + database_.destroyClient(); + + if (visionSearchResult.empty() && textSearchResult.empty()) { + errorCode = DATA_MANAGEMENT_CANT_OPEN_DB; + return std::string(); + } + + textSearchResult = deduplicatedSearchResults(std::move(textSearchResult)); + + filtSearchResults(textSearchResult, TEXT_SEARCH_THRESHOLD); + filtSearchResults(visionSearchResult, VISION_SEARCH_THRESHOLD); + + std::vector> mergedResult = + mergeSearchResults(std::move(visionSearchResult), std::move(textSearchResult)); + + std::string result = 
DataManagementJsonHelper::convertSearchResultToJson(mergedResult);
+
+    return result;
+}
+
+// Parses the request JSON into (path, format) pairs, builds embeddings for the image
+// files and stores them in the database opened for `uid`.
+// `errorCode` is only written on failure; callers initialise it to 0 beforehand.
+void DataManagementService::doAddImageFiles(
+        const std::string &fileinfosJson, uid_t uid, int &errorCode)
+{
+    auto fileinfos = DataManagementJsonHelper::parserFilePathAndFormat(fileinfosJson);
+    if (fileinfos.empty()) {
+        errorCode = DATA_MANAGEMENT_PARAM_ERROR;
+        std::cerr << "parser add image files input falied" << std::endl;
+        return;
+    }
+
+    const std::vector<DBFileInfo> infos = makeDatasByImageFilePathAndFormat(fileinfos);
+
+    if (infos.empty()) {
+        std::cerr << "get file info failed" << std::endl;
+        errorCode = DATA_MANAGEMENT_UNKNOWN_ERROR;
+        return;
+    }
+
+    if (!database_.createClient(uid)) {
+        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
+        return;
+    }
+    database_.addImageDatas(infos);
+    database_.destroyClient();
+}
+
+// Text counterpart of doAddImageFiles(): parses the request, turns each text file into
+// embeddings and stores them in the database opened for `uid`.
+// `errorCode` is only written on failure; callers initialise it to 0 beforehand.
+void DataManagementService::doAddTextFiles(
+        const std::string &fileinfosJson, uid_t uid, int &errorCode)
+{
+    auto fileinfos = DataManagementJsonHelper::parserFilePathAndFormat(fileinfosJson);
+    if (fileinfos.empty()) {
+        errorCode = DATA_MANAGEMENT_PARAM_ERROR;
+        std::cerr << "parser add text files input falied" << std::endl;
+        return;
+    }
+
+    const std::vector<DBFileInfo> infos = makeDatasByTextFilePathAndFormat(fileinfos);
+    if (infos.empty()) {
+        std::cerr << "can't get any file info, ignore request" << std::endl;
+        errorCode = DATA_MANAGEMENT_UNKNOWN_ERROR;
+        return;
+    }
+
+    if (!database_.createClient(uid)) {
+        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
+        return;
+    }
+    database_.addTextDatas(infos);
+    database_.destroyClient();
+}
+
+// Returns every (filepath, modify time) record stored for `uid` as a JSON array, or an
+// empty string with `errorCode` set when the database cannot be opened.
+std::string DataManagementService::doGetAllFileInfos(uid_t uid, int &errorCode)
+{
+    if (!database_.createClient(uid)) {
+        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
+        return std::string();
+    }
+
+    const GetAllFileInfosResult result = database_.getAllFileInfos();
+    // Close the client again, as every other operation in this class does; without this
+    // the shared database_ member stays connected after the call returns.
+    database_.destroyClient();
+
+    return DataManagementJsonHelper::convertFileInfosResultToJson(result);
+}
+
+void DataManagementService::doDeleteFiles(const std::string &fileinfosJson, uid_t uid, int &errorCode)
+{
+    
const std::vector<std::string> files = DataManagementJsonHelper::parserDeletedFileInfos(fileinfosJson);
+    if (files.empty()) {
+        errorCode = DATA_MANAGEMENT_PARAM_ERROR;
+        return;
+    }
+
+    if (!database_.createClient(uid)) {
+        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
+        return;
+    }
+    database_.deleteFiles(files);
+
+    database_.destroyClient();
+}
+
+// Re-indexes files whose content changed: every listed path is removed from the database,
+// then the files with a supported text or image format are embedded again and re-added.
+// `errorCode` is only written on failure; callers initialise it to 0 beforehand.
+// NOTE(review): container element types below were reconstructed from
+// parserFilePathAndFormat() / parserDeletedFileInfos() usage - confirm against the headers.
+void DataManagementService::doUpdateFilesContent(const std::string &fileinfosJson, uid_t uid, int &errorCode)
+{
+    const auto fileinfos = DataManagementJsonHelper::parserFilePathAndFormat(fileinfosJson);
+    if (fileinfos.empty()) {
+        errorCode = DATA_MANAGEMENT_PARAM_ERROR;
+        std::cerr << "parser updata files content input falied" << std::endl;
+        return;
+    }
+
+    std::vector<std::string> files;
+    std::vector<std::pair<std::string, std::string>> imageFiles;
+    std::vector<std::pair<std::string, std::string>> textFiles;
+    files.reserve(fileinfos.size());
+
+    // Classify in a single read-only pass. Erasing the front element on every step, as
+    // this loop used to, shifts the whole vector each time for no benefit.
+    for (const auto& fileinfo : fileinfos) {
+        // Every path is scheduled for deletion, including ones with unsupported formats.
+        files.emplace_back(fileinfo.first);
+        if (isSupportedTextFormat(fileinfo.second)) {
+            textFiles.emplace_back(fileinfo);
+        } else if (isSupportedImageFormat(fileinfo.second)) {
+            imageFiles.emplace_back(fileinfo);
+        }
+    }
+
+    // Build the embeddings before opening the database so the client is held only for
+    // the delete/add calls.
+    const auto imageDatas = makeDatasByImageFilePathAndFormat(imageFiles);
+    const auto textDatas = makeDatasByTextFilePathAndFormat(textFiles);
+
+    if (!database_.createClient(uid)) {
+        errorCode = DATA_MANAGEMENT_CANT_OPEN_DB;
+        return;
+    }
+    database_.deleteFiles(files);
+
+    database_.addImageDatas(imageDatas);
+    database_.addTextDatas(textDatas);
+
+    database_.destroyClient();
+}
+
+// Queues `text` on the text-side embedder and returns the result fetched for that task.
+std::vector<float> DataManagementService::textEmbedderText(const std::string& text)
+{
+    std::vector<float> textEmbedding;
+    EmbeddingTask task(text, EmbeddingTask::EmbeddingDataType::Text);
+    embeddingTaskManager_->addTextSideTask(task);
+    textEmbedding = embeddingTaskManager_->getTextSideResult(task.id());
+    return textEmbedding;
+}
+
+// Queues every entry of `texts` on the text-side embedder, then collects the results in
+// the same order.
+std::vector<std::vector<float>> DataManagementService::textEmbedderTexts(const std::vector<std::string>& texts)
+{
+    std::vector<std::vector<float>> textEmbeddings;
+    std::vector<float> textEmbedding;
+    std::vector<EmbeddingTask> tasks;
+    for 
(int i = 0; i < texts.size(); i++) { + EmbeddingTask task(texts.at(i), EmbeddingTask::EmbeddingDataType::Text); + embeddingTaskManager_->addTextSideTask(task); + tasks.push_back(task); + } + + for (int i = 0; i < tasks.size(); i++) { + textEmbedding = embeddingTaskManager_->getTextSideResult(tasks.at(i).id()); + textEmbeddings.push_back(textEmbedding); + } + + return textEmbeddings; +} + +std::vector DataManagementService::imageEmbedderText(const std::string& text) +{ + std::vector textEmbedding; + EmbeddingTask task(text, EmbeddingTask::EmbeddingDataType::Text); + embeddingTaskManager_->addImageSideTask(task); + textEmbedding = embeddingTaskManager_->getImageSideResult(task.id()); + return textEmbedding; +} + +std::vector DataManagementService::imageEmbedderImage(const std::string& filePath) +{ + std::vector imageEmbedding; + EmbeddingTask task(filePath, EmbeddingTask::EmbeddingDataType::FilePath); + embeddingTaskManager_->addImageSideTask(task); + imageEmbedding = embeddingTaskManager_->getImageSideResult(task.id()); + return imageEmbedding; +} diff --git a/src/datamanagement/datamanagementservice.h b/src/datamanagement/datamanagementservice.h new file mode 100644 index 0000000000000000000000000000000000000000..1c8226f2d17fab161730a630f78d2be8b5215a63 --- /dev/null +++ b/src/datamanagement/datamanagementservice.h @@ -0,0 +1,72 @@ +#ifndef DATAMANAGEMENTSERVICE_H +#define DATAMANAGEMENTSERVICE_H + +#include "datamanagement/datamanagementdatabase.h" + +#include +#include +#include + +class FileInfo; +class Parser; +class EmbeddingTaskManager; +class DataManagementService +{ +public: + using SimilaritySearchCallback = std::function; + using AddTextFilesCallback = std::function; + using AddImageFilesCallback = std::function; + using DeleteFilesCallback = std::function; + using UpdateFilesNameCallback = std::function; + using UpdateFilesContentCallback = std::function; + using GetAllFileinfosCallback = std::function; + + static DataManagementService* getInstance() { + 
/* Function-local static: a single instance is constructed on the first call and its address is handed out on every call. */static DataManagementService service; + return &service; + } + + /* Public operations. Each takes the caller's uid and a callback of the matching alias type; all except getAllFileInfos also take a string payload (searchConditions / fileinfos). NOTE(review): the payload is presumably JSON, judging by the fileinfosJson parameter names of the private workers - confirm. */void similaritySearch(const std::string& searchConditions, uid_t uid, SimilaritySearchCallback callback); + void addTextFiles(const std::string& fileinfos, uid_t uid, AddTextFilesCallback callback); + void addImageFiles(const std::string& fileinfos, uid_t uid, AddImageFilesCallback callback); + void deleteFiles(const std::string& fileinfos, uid_t uid, DeleteFilesCallback callback); + void updateFilesName(const std::string& fileinfos, uid_t uid, UpdateFilesNameCallback callback); + void updateFilesContent(const std::string& fileinfos, uid_t uid, UpdateFilesContentCallback callback); + void getAllFileInfos(uid_t uid, GetAllFileinfosCallback callback); + + +private: + /* Private constructor: outside code cannot construct the service and has to go through getInstance(). */DataManagementService(); + + /* do* workers: each takes the uid plus an int& errorCode out-parameter. NOTE(review): there is no doUpdateFilesName although updateFilesName is public - confirm that is intentional. */std::string doSimilaritySearch(const std::string& searchConditions, uid_t uid, int& errorCode); + void doAddImageFiles(const std::string& fileinfosJson, uid_t uid, int& errorCode); + std::vector makeDatasByImageFilePathAndFormat(const std::vector>& fileinfos); + void doAddTextFiles(const std::string& fileinfosJson, uid_t uid, int& errorCode); + std::vector makeDatasByTextFilePathAndFormat(const std::vector>& fileinfos); + std::string doGetAllFileInfos(uid_t uid, int& errorCode); + void doDeleteFiles(const std::string& fileinfosJson, uid_t uid, int& errorCode); + void doUpdateFilesContent(const std::string& fileinfosJson, uid_t uid, int& errorCode); + + /* const helpers for post-processing search results and preparing text; none of them is declared to modify the service object. */SimilaritySearchResult mergeSearchResults(const SimilaritySearchResult& visionSearchResult, + const SimilaritySearchResult& textSearchResult) const; + void filtSearchResults(SimilaritySearchResult& searchResult, double threshold) const; + SimilaritySearchResult deduplicatedSearchResults(const SimilaritySearchResult &textSearchResult) const; + FileInfo parserTextFileInfo(const std::string& filepath, const std::string& format) const; + void truncateTexts(std::vector& texts, int maxLength) const; + bool isSupportedImageFormat(const std::string& format) const; + bool
namespace SegmentTokenizer {

// True when `text` is strictly shorter than `maxLength` bytes.
// A negative limit means "no limit": that is what the implicit int -> size_t
// conversion in the original comparison amounted to.
static bool isShorterThan(const std::string& text, int maxLength)
{
    return maxLength < 0 || text.size() < static_cast<size_t>(maxLength);
}

// True when `text` is strictly longer than `maxLength` bytes (never for a negative limit).
static bool isLongerThan(const std::string& text, int maxLength)
{
    return maxLength >= 0 && text.size() > static_cast<size_t>(maxLength);
}

// True for the 10xxxxxx bytes that continue a multi-byte UTF-8 character.
static bool isUtf8ContinuationByte(char byte)
{
    return (static_cast<unsigned char>(byte) & 0xC0) == 0x80;
}

// Heuristic: does the '.' at text[index] end an English sentence?
//
// Returns false when `index` is out of range, when text[index] is not '.',
// or when the '.' is the very first byte. A '.' that is the last byte always
// counts. Otherwise the neighbours decide:
//   - whitespace before and a letter after, or
//   - a letter or punctuation mark before and whitespace after.
// Anything else (e.g. the dot in "3.14") is not a sentence end.
bool isEnglishPeriod(const std::string& text, size_t index)
{
    // Without this check every position was judged by its neighbours alone,
    // so ordinary letters were reported as "periods".
    if (index >= text.size() || text[index] != '.')
        return false;
    if (index == 0)
        return false;
    if (index == text.size() - 1)
        return true;

    // <cctype> classifiers require values representable as unsigned char;
    // raw UTF-8 bytes in a (signed) char would be undefined behaviour.
    const unsigned char prevChar = static_cast<unsigned char>(text[index - 1]);
    const unsigned char nextChar = static_cast<unsigned char>(text[index + 1]);

    if (std::isspace(nextChar))
        return std::isalpha(prevChar) || std::ispunct(prevChar);
    return std::isspace(prevChar) && std::isalpha(nextChar);
}

// True when `text` contains at least one '.' that isEnglishPeriod accepts.
bool containsEnglishPeriod(const std::string& text)
{
    for (size_t pos = text.find('.'); pos != std::string::npos; pos = text.find('.', pos + 1)) {
        if (isEnglishPeriod(text, pos))
            return true;
    }
    return false;
}

// Splits `str` after every occurrence of the string `delimiters` (treated as
// one delimiter, not a character set). Each piece keeps its trailing
// delimiter; a non-empty remainder becomes the last piece. When the delimiter
// is ".", only dots accepted by isEnglishPeriod split the text.
// An empty delimiter cannot split anything: the whole (non-empty) input is
// returned as one piece instead of looping forever.
std::vector<std::string> splitString(const std::string& str, const std::string& delimiters)
{
    std::vector<std::string> result;
    if (delimiters.empty()) {
        if (!str.empty())
            result.push_back(str);
        return result;
    }

    const bool periodMode = (delimiters == ".");
    size_t startPos = 0;
    size_t foundPos = str.find(delimiters);

    while (foundPos != std::string::npos) {
        const size_t afterDelimiter = foundPos + delimiters.size();
        if (periodMode && !isEnglishPeriod(str, foundPos)) {
            // Not a sentence end: keep scanning, the current piece keeps growing.
            foundPos = str.find(delimiters, afterDelimiter);
            continue;
        }
        result.emplace_back(str, startPos, afterDelimiter - startPos);
        startPos = afterDelimiter;
        foundPos = str.find(delimiters, startPos);
    }

    if (startPos < str.size())
        result.emplace_back(str, startPos, std::string::npos);

    return result;
}

// Shared worker for the two public splitters: cut at `sentenceDelimiter`,
// then cut every sentence longer than `maxLength` bytes again at ",".
static std::vector<std::string> splitSentencesThenCommas(const std::string& str,
                                                         const std::string& sentenceDelimiter,
                                                         int maxLength)
{
    std::vector<std::string> pieces;
    const std::vector<std::string> sentences = splitString(str, sentenceDelimiter);
    for (const std::string& sentence : sentences) {
        if (!isLongerThan(sentence, maxLength)) {
            pieces.push_back(sentence);
            continue;
        }
        const std::vector<std::string> phrases = splitString(sentence, ",");
        pieces.insert(pieces.end(), phrases.begin(), phrases.end());
    }
    return pieces;
}

// Splits at the Chinese full stop; sentences longer than `maxLength` bytes are
// split further at commas.
// NOTE(review): only the ASCII comma is used here, while Chinese text usually
// uses the full-width comma - confirm whether that is intended.
std::vector<std::string> splitChineseString(const std::string& str, int maxLength)
{
    return splitSentencesThenCommas(str, "。", maxLength);
}

// Splits at English sentence-ending periods; sentences longer than `maxLength`
// bytes are split further at commas.
std::vector<std::string> splitEnglishString(const std::string& str, int maxLength)
{
    return splitSentencesThenCommas(str, ".", maxLength);
}

// Cuts `text` into consecutive pieces of at most `chunkSize` bytes without
// splitting a UTF-8 character. A character wider than `chunkSize` is emitted
// whole (the piece is then longer than `chunkSize`) so that the loop always
// advances. A non-positive `chunkSize` returns the non-empty text unsplit.
std::vector<std::string> fixedSizeChunks(const std::string& text, int chunkSize)
{
    std::vector<std::string> chunks;
    if (chunkSize <= 0) {
        if (!text.empty())
            chunks.push_back(text);
        return chunks;
    }

    const size_t width = static_cast<size_t>(chunkSize);
    size_t startPos = 0;
    while (startPos < text.size()) {
        size_t endPos = startPos + width;
        if (endPos >= text.size()) {
            endPos = text.size();
        } else {
            // Back up to the lead byte so the cut falls between characters.
            while (endPos > startPos && isUtf8ContinuationByte(text[endPos]))
                --endPos;
            if (endPos == startPos) {
                // The character at startPos does not fit: take it whole.
                endPos = startPos + 1;
                while (endPos < text.size() && isUtf8ContinuationByte(text[endPos]))
                    ++endPos;
            }
        }
        chunks.emplace_back(text, startPos, endPos - startPos);
        startPos = endPos;
    }

    return chunks;
}

// Rewrites `texts` in place so that over-long entries are broken up: first at
// sentence ends, then (for sentences still too long) at commas, and finally
// into fixed-size chunks. Entries shorter than `maxLength` bytes are kept.
void segmentTokenize(std::vector<std::string>& texts, int maxLength)
{
    for (auto current = texts.begin(); current != texts.end();) {
        if (isShorterThan(*current, maxLength)) {
            ++current;
            continue;
        }

        // Text without a Chinese full stop always goes through the English
        // splitter, which falls back to comma splitting when there is no
        // sentence-ending period at all.
        const bool hasChineseFullStop = current->find("。") != std::string::npos;
        const std::vector<std::string> sentences = hasChineseFullStop
            ? splitChineseString(*current, maxLength)
            : splitEnglishString(*current, maxLength);
        current = texts.erase(current);
        current = texts.insert(current, sentences.begin(), sentences.end());
        if (sentences.empty())
            continue;  // an empty entry was dropped; `current` is already the next one

        // Only the first piece is handled here; the remaining pieces are
        // visited by the following iterations.
        if (isShorterThan(*current, maxLength)) {
            ++current;
            continue;
        }

        const std::vector<std::string> chunks = fixedSizeChunks(*current, maxLength);
        current = texts.erase(current);
        current = texts.insert(current, chunks.begin(), chunks.end());
        std::advance(current, chunks.size());
    }
}

}
std::string& str, int maxLength); +std::vector splitEnglishString(const std::string& str, int maxLength); +bool isEnglishPeriod(const std::string& text, size_t index); +bool containsEnglishPeriod(const std::string& text); +std::vector fixedSizeChunks(const std::string& text, int chunkSize); +void segmentTokenize(std::vector& texts, int maxLength); + +} // namespace SegmentTokenizer + +#endif // SEGMENTTOKENIZER_H diff --git a/src/servicemanager.cpp b/src/servicemanager.cpp index f0ec017874e2f763ac3d029b57311f0b0b74b866..5e650cf07cae7ae2259b402e8c47fefe7e15284a 100644 --- a/src/servicemanager.cpp +++ b/src/servicemanager.cpp @@ -24,5 +24,5 @@ ServiceManager::ServiceManager(GDBusConnection &connection) void ServiceManager::initProcessors(GDBusConnection &connection) { - // dataManagementProcessor_ = std::make_unique(connection); + dataManagementProcessor_ = std::make_unique(connection); } diff --git a/src/servicemanager.h b/src/servicemanager.h index a2b569224c99f416c64619bef2ce8705ca39d58e..421fbfaea5924d78e12b03641bcfae3dde38fe85 100644 --- a/src/servicemanager.h +++ b/src/servicemanager.h @@ -21,7 +21,7 @@ #include -// #include "datamanagementprocessor.h" +#include "datamanagement/datamanagementprocessor.h" class ServiceManager { public: @@ -31,7 +31,7 @@ private: void initProcessors(GDBusConnection &connection); private: - // std::unique_ptr dataManagementProcessor_; + std::unique_ptr dataManagementProcessor_; }; #endif diff --git a/src/util/parser/officepyparserwrapper.cpp b/src/util/parser/officepyparserwrapper.cpp index 43e87b559892b1a5104c442652cf5b0b6c64ef90..ea7416f396d2794c3434776aaa02a7db7c563660 100644 --- a/src/util/parser/officepyparserwrapper.cpp +++ b/src/util/parser/officepyparserwrapper.cpp @@ -1,11 +1,14 @@ #include "officepyparserwrapper.h" #include "../python/pythonutil.h" +#include "../python/pythonthreadlocker.h" OfficePyParserWrapper::OfficePyParserWrapper() { if (!Py_IsInitialized()) Py_Initialize(); + + PythonThreadLocker locker; 
PyRun_SimpleString("import sys"); std::string command = "sys.path.append('" + std::string(DATA_MANAGEMENT_PYTHON_PATH) + "')"; PyRun_SimpleString(command.c_str()); @@ -13,6 +16,7 @@ OfficePyParserWrapper::OfficePyParserWrapper() FileInfo OfficePyParserWrapper::parse(const std::string &filePath, const std::string &module, const std::string &className) { + PythonThreadLocker locker; FileInfo fileinfo; PyObject* parserModule = PyImport_ImportModule(module.c_str());