diff --git a/MPC/middleware/CMakeLists.txt b/MPC/middleware/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4dc998808e69486dc7e4a5db14cd072a92f66a7b --- /dev/null +++ b/MPC/middleware/CMakeLists.txt @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 3.16) + +project(kcal_middleware LANGUAGES C CXX) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +find_package(libkcal) + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/kcal) + +install(DIRECTORY kcal + DESTINATION include + FILES_MATCHING PATTERN "*.h" +) diff --git a/MPC/middleware/kcal/CMakeLists.txt b/MPC/middleware/kcal/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a66297233941522b6e1d0f3d03506959b4890c89 --- /dev/null +++ b/MPC/middleware/kcal/CMakeLists.txt @@ -0,0 +1,9 @@ +file(GLOB_RECURSE COMMON_SRCS ${CMAKE_CURRENT_LIST_DIR}/*.cc) + +add_library(kcal_middle ${COMMON_SRCS}) +target_link_libraries(kcal_middle PUBLIC lib_kcal) + +install(TARGETS kcal_middle + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib +) diff --git a/MPC/middleware/kcal/api/kcal_api.h b/MPC/middleware/kcal/api/kcal_api.h new file mode 100644 index 0000000000000000000000000000000000000000..f125f9d809c583afd4089e7497118561b8dbe94f --- /dev/null +++ b/MPC/middleware/kcal/api/kcal_api.h @@ -0,0 +1,27 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef KCAL_API_H +#define KCAL_API_H + +#include "data_guard_anonymization.h" +#include "data_guard_callback.h" +#include "data_guard_cfs.h" +#include "data_guard_config.h" +#include "data_guard_di.h" +#include "data_guard_dp.h" +#include "data_guard_enum.h" +#include "data_guard_error_code.h" +#include "data_guard_mpc.h" +#include "data_guard_struct.h" + +#endif // KCAL_API_H diff --git a/MPC/middleware/kcal/core/context.cc b/MPC/middleware/kcal/core/context.cc new file mode 100644 index 0000000000000000000000000000000000000000..23ce3dcf1643d5541aa11bbc26252ea99b7e703b --- /dev/null +++ b/MPC/middleware/kcal/core/context.cc @@ -0,0 +1,61 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "kcal/core/context.h" + +namespace kcal { + +Context::Context(KCAL_Config config, KCAL_AlgorithmsType type) : config_(config), type_(type) {} + +Context::~Context() +{ + DG_ReleaseConfig(&teeCfg_); + DG_ReleaseConfigOpts(&cfgOpts_); +} + +int Context::Init() +{ + int rv = DG_InitConfigOpts(DG_BusinessType::MPC, &cfgOpts_); + if (rv != DG_SUCCESS) { + return rv; + } + rv = cfgOpts_->init(&teeCfg_); + if (rv != DG_SUCCESS) { + return rv; + } + + cfgOpts_->setIntValue(teeCfg_, DG_CON_MPC_TEE_INT_NODEID, config_.nodeId); + if (type_ == KCAL_AlgorithmsType::PSI || type_ == KCAL_AlgorithmsType::PIR) { + cfgOpts_->setIntValue(teeCfg_, DG_CON_MPC_TEE_INT_FXP_BITS, 0); + } else { + cfgOpts_->setIntValue(teeCfg_, DG_CON_MPC_TEE_INT_FXP_BITS, config_.fixBits); + } + cfgOpts_->setIntValue(teeCfg_, DG_CON_MPC_TEE_INT_THREAD_COUNT, config_.threadCount); + + return DG_SUCCESS; +} + +std::shared_ptr Context::Create(KCAL_Config config, TEE_NET_RES *netRes, KCAL_AlgorithmsType type) +{ + auto context = std::make_shared(config, type); + context->Init(); + context->SetNetRes(netRes); + return context; +} + +void Context::SetNetRes(TEE_NET_RES *teeNetRes) +{ + DG_Void netFunc = {.data = teeNetRes, .size = sizeof(TEE_NET_RES)}; + cfgOpts_->setVoidValue(teeCfg_, DG_CON_MPC_TEE_VOID_NET_API, &netFunc); +} + +} diff --git a/MPC/middleware/kcal/core/context.h b/MPC/middleware/kcal/core/context.h new file mode 100644 index 0000000000000000000000000000000000000000..a30ad80f96f51c90627527924a7b2138c63ccd2c --- /dev/null +++ b/MPC/middleware/kcal/core/context.h @@ -0,0 +1,60 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef CONTEXT_H +#define CONTEXT_H + +#include +#include "kcal/api/kcal_api.h" +#include "kcal/enumeration/kcal_enum.h" + +namespace kcal { + +struct KCAL_Config { + int nodeId; + int fixBits; + int threadCount; + int worldSize; +}; + +class Context { +public: + Context() = default; + + explicit Context(KCAL_Config config, KCAL_AlgorithmsType type); + + Context(const Context &) = delete; + Context &operator=(const Context &) = delete; + + ~Context(); + + static std::shared_ptr Create(KCAL_Config config, TEE_NET_RES *netRes, KCAL_AlgorithmsType type); + + int Init(); + + [[nodiscard]] int GetWorldSize() const { return config_.worldSize; } + + void *GetTeeConfig() { return teeCfg_; } + + void SetNetRes(TEE_NET_RES *teeNetRes); + +private: + KCAL_Config config_; + KCAL_AlgorithmsType type_; + + void *teeCfg_ = nullptr; + DG_ConfigOpts *cfgOpts_ = nullptr; +}; + +} + +#endif // CONTEXT_H diff --git a/MPC/middleware/kcal/enumeration/kcal_enum.h b/MPC/middleware/kcal/enumeration/kcal_enum.h new file mode 100644 index 0000000000000000000000000000000000000000..e7597c480b55082c2eb2ddf541495901ff07df3a --- /dev/null +++ b/MPC/middleware/kcal/enumeration/kcal_enum.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef KCAL_ENUM_H +#define KCAL_ENUM_H + +namespace kcal { + +enum class KCAL_AlgorithmsType { + PSI, + PIR +}; + +} + +#endif // KCAL_ENUM_H diff --git a/MPC/middleware/kcal/operator/kcal_pir.cc b/MPC/middleware/kcal/operator/kcal_pir.cc new file mode 100644 index 0000000000000000000000000000000000000000..76137772d1aea0ffb34d60d1a9267ef54fa6b42d --- /dev/null +++ b/MPC/middleware/kcal/operator/kcal_pir.cc @@ -0,0 +1,64 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "kcal/operator/kcal_pir.h" +#include "kcal/utils/node_info_helper.h" + +namespace kcal { + +Pir::Pir() +{ + opts_ = std::make_unique(); + *opts_ = DG_InitPirOpts(); +} + +Pir::~Pir() +{ + opts_->releaseBucketMap(&bucketMap_); + opts_->releaseTeeCtx(&dgTeeCtx_); +} + +int Pir::Init(std::shared_ptr ctx) +{ + utils::NodeInfoHelper nodeInfoHelper(ctx->GetWorldSize()); + + int rv = opts_->initTeeCtx(ctx->GetTeeConfig(), &dgTeeCtx_); + if (rv != 0) { + return rv; + } + + rv = opts_->setTeeNodeInfos(dgTeeCtx_, nodeInfoHelper.Get()); + if (rv != 0) { + return rv; + } + + baseCtx_ = std::move(ctx); + + return DG_SUCCESS; +} + +int Pir::ServerPreProcess(DG_PairList *pairList) +{ + return opts_->offlineCalculate(dgTeeCtx_, pairList, &bucketMap_); +} + +int Pir::ClientQuery(DG_TeeInput *input, DG_TeeOutput **output, DG_DummyMode dummyMode) +{ + return opts_->clientCalculate(dgTeeCtx_, dummyMode, input, output); +} + +int Pir::ServerAnswer() +{ + return opts_->serverCalculate(dgTeeCtx_, bucketMap_); +} + +} diff --git a/MPC/middleware/kcal/operator/kcal_pir.h b/MPC/middleware/kcal/operator/kcal_pir.h new file mode 100644 index 0000000000000000000000000000000000000000..96a1bc8439feb7d5d490cfb5af87bb77741f037f --- /dev/null +++ b/MPC/middleware/kcal/operator/kcal_pir.h @@ -0,0 +1,43 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef KCAL_PIR_H +#define KCAL_PIR_H + +#include +#include "kcal/core/context.h" + +namespace kcal { + +class Pir { +public: + Pir(); + ~Pir(); + + Pir(const Pir &) = delete; + Pir &operator=(const Pir &) = delete; + + int Init(std::shared_ptr ctx); + int ServerPreProcess(DG_PairList *pairList); + int ClientQuery(DG_TeeInput *input, DG_TeeOutput **output, DG_DummyMode dummyMode); + int ServerAnswer(); + +private: + DG_TeeCtx *dgTeeCtx_ = nullptr; + std::shared_ptr baseCtx_; + std::unique_ptr opts_; + DG_BucketMap *bucketMap_ = nullptr; +}; + +} + +#endif // KCAL_PIR_H diff --git a/MPC/middleware/kcal/operator/kcal_psi.cc b/MPC/middleware/kcal/operator/kcal_psi.cc new file mode 100644 index 0000000000000000000000000000000000000000..2bd5763c7dc47d30abedd7fe3a7ea64a0cc7169a --- /dev/null +++ b/MPC/middleware/kcal/operator/kcal_psi.cc @@ -0,0 +1,50 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "kcal/operator/kcal_psi.h" +#include "kcal/utils/node_info_helper.h" + +namespace kcal { + +Psi::Psi() +{ + opts_ = std::make_unique(); + *opts_ = DG_InitPsiOpts(); +} + +Psi::~Psi() { opts_->releaseTeeCtx(&dgTeeCtx_); } + +int Psi::Init(std::shared_ptr ctx) +{ + utils::NodeInfoHelper nodeInfoHelper(ctx->GetWorldSize()); + + int rv = opts_->initTeeCtx(ctx->GetTeeConfig(), &dgTeeCtx_); + if (rv != 0) { + return rv; + } + + rv = opts_->setTeeNodeInfos(dgTeeCtx_, nodeInfoHelper.Get()); + if (rv != 0) { + return rv; + } + + baseCtx_ = std::move(ctx); + + return DG_SUCCESS; +} + +int Psi::Run(DG_TeeInput *input, DG_TeeOutput **output, DG_TeeMode outputMode) +{ + return opts_->calculate(dgTeeCtx_, PSI, input, output, outputMode); +} + +} diff --git a/MPC/middleware/kcal/operator/kcal_psi.h b/MPC/middleware/kcal/operator/kcal_psi.h new file mode 100644 index 0000000000000000000000000000000000000000..de32d492fc5ac1574056c66b7dabbb5ea7e76d5e --- /dev/null +++ b/MPC/middleware/kcal/operator/kcal_psi.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef KCAL_PSI_H +#define KCAL_PSI_H + +#include +#include "kcal/api/kcal_api.h" +#include "kcal/core/context.h" + +namespace kcal { + +class Psi { +public: + Psi(); + ~Psi(); + + Psi(const Psi &) = delete; + Psi &operator=(const Psi &) = delete; + + int Init(std::shared_ptr ctx); + int Run(DG_TeeInput *input, DG_TeeOutput **output, DG_TeeMode outputMode); + +private: + DG_TeeCtx *dgTeeCtx_ = nullptr; + std::shared_ptr baseCtx_ = nullptr; + std::unique_ptr opts_; +}; + +} + +#endif // KCAL_PSI_H diff --git a/MPC/middleware/kcal/utils/node_info_helper.cc b/MPC/middleware/kcal/utils/node_info_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..99742b1eaaa43846795e3404e5d720944807c03e --- /dev/null +++ b/MPC/middleware/kcal/utils/node_info_helper.cc @@ -0,0 +1,28 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#include "node_info_helper.h" + +namespace kcal::utils { + +NodeInfoHelper::NodeInfoHelper(int worldSize) +{ + infos_.resize(worldSize); + for (int i = 0; i < worldSize; ++i) { + infos_[i].nodeId = i; + } + + nodeInfos_.nodeInfo = infos_.data(); + nodeInfos_.size = infos_.size(); +} + +} diff --git a/MPC/middleware/kcal/utils/node_info_helper.h b/MPC/middleware/kcal/utils/node_info_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..953d890a11959fb7f3bd3fe3c2446704209177e9 --- /dev/null +++ b/MPC/middleware/kcal/utils/node_info_helper.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * virtCCA_sdk is licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +#ifndef NODE_INFO_HELPER_H +#define NODE_INFO_HELPER_H + +#include +#include "kcal/api/kcal_api.h" + +namespace kcal::utils { + +class NodeInfoHelper { +public: + explicit NodeInfoHelper(int worldSize); + + ~NodeInfoHelper() = default; + + NodeInfoHelper(const NodeInfoHelper &) = delete; + NodeInfoHelper &operator=(const NodeInfoHelper &) = delete; + + TeeNodeInfos *Get() { return &nodeInfos_; } + +private: + TeeNodeInfos nodeInfos_; + std::vector infos_; +}; + +} + +#endif // NODE_INFO_HELPER_H diff --git a/MPC/third_party_adaptor/secrerflow/psi/README.md b/MPC/third_party_adaptor/secrerflow/psi/README.md new file mode 100644 index 0000000000000000000000000000000000000000..06e98c97d23ada9189c8487e5a0b895c5b92f862 --- /dev/null +++ b/MPC/third_party_adaptor/secrerflow/psi/README.md @@ -0,0 +1,293 @@ +# 前置条件 + +1. 需获取 `kcal` 包,含 `include` 和 `lib` 目录 +2. 需要 bazel 编译构建工具,编译环境依赖参考:[devtools/dockerfiles/release-ci-aarch64.DockerFile at main · secretflow/devtools](https://github.com/secretflow/devtools/blob/main/dockerfiles/release-ci-aarch64.DockerFile) +3. 运行环境为`virtCCA cvm`​ + +# 中间件在蚂蚁 psi 库的编译 + +## 创建工作目录 + +为方便进行演示,以下操作均以 `/home/admin/dev` 目录作为工作主目录 + +## clone virtCCA\_sdk 仓 + +```shell +cd /home/admin/dev + +git clone https://gitee.com/openeuler/virtCCA_sdk.git +``` + +## clone 蚂蚁 psi 仓,应用 patch + +注:目前已适配蚂蚁 `tag: psi-v0.6.0.dev250507` 版本的 `psi`库 + +```shell +cd /home/admin/dev + +# clone 仓库,并创建一个本地分支 +git clone --branch "v0.6.0.dev250507" https://github.com/secretflow/psi.git +git switch -c kcal-on-v0.6.0 + +# 应用 virtCCA_sdk 下面的 patch +git apply /home/admin/dev/virtCCA_sdk/MPC/third_party_adaptor/secrerflow/psi/patches/kcal.patch +``` + +## 引入中间件和 kcal 库 + +### 引入 kcal 库 + +```shell +# 假设 kcal 库已下载在 /opt/kcal 目录下 +cp -r /opt/kcal/include /home/admin/dev/psi/third_party/kcal/ +cp -r /opt/kcal/lib /home/admin/dev/psi/third_party/kcal/ +``` + +### 引入中间件 + +```shell +cp -r /home/admin/dev/virtCCA_sdk/MPC/middleware/* /home/admin/dev/psi/third_party/kcal_middleware/ +``` + +## 执行编译 + +```shell +cd /home/admin/dev/psi + +bazel build //... -c opt +``` + +编译完成后,在`bazel-bin/psi/apps/psi_launcher`目录下生成`main`可执行文件,后续可通过执行`./bazel-bin/psi/apps/psi_launcher/main --config xxx.json`进行验证 + +## 将二进制部署至 virtCCA 内 + +可让`/home/admin/dev/psi`的目录结构与`virtCCA`内部保持一致,方便进行测试 + +# kcal psi 在蚂蚁 psi 库上的测试 + +编译完成后,进入`/home/admin/dev/psi`目录,步骤可参考:`examples/psi/README.md`​ + +## 生成 psi 测试数据 + +```bash +cd /home/admin/dev/psi + +python examples/psi/generate_psi_data.py \ + --receiver_item_cnt 1e6 \ + --sender_item_cnt 1e6 \ + --intersection_cnt 8e4 \ + --id_cnt 2 \ + --receiver_path /tmp/receiver_input.csv \ + --sender_path /tmp/sender_input.csv \ + --intersection_path /tmp/intersection.csv +``` + +> 说明: +> +> --receiver_item_cnt:receiver 方拥有的数据总量 +> +> --sender_item_cnt:sender 方拥有的数据总量 +> +> --intersection_cnt:约定两方产生交集部分的数据总量 +> +> --id_cnt:每个参与方的输入数据包含几个字段 +> +> --receiver_path:receiver 方输入数据的文件位置 +> +> --sender_path:sender 方输入数据的文件位置 +> +> --intersection_path:生成的交集数据的文件位置 + +## 配置文件说明 + +配置文件已在`patch`中提供,只需修改下列说明的部分进行测试 + +下面以`kcal_receiver.json`为例 + +```json +{ + "psi_config": { + "protocol_config": { + "protocol": "PROTOCOL_KCAL", + "kcal_config": { + "thread_count": 16 // 线程数按需修改,目前固定 16 线程 + }, + "role": "ROLE_SENDER", + "broadcast_result": true + }, + "input_config": { + "type": "IO_TYPE_FILE_CSV", + "path": "/tmp/sender_input.csv" // 当前参与方运行 psi 算法的数据输入文件位置,按需修改 + }, + "output_config": { + "type": "IO_TYPE_FILE_CSV", // 文件类型不需修改 + "path": "/tmp/kcal_sender_output.csv" // 两方运行完 psi 算法后,最终交集文件的输出位置,按需修改 + }, + "keys": ["id_0", "id_1"], // 有几个字段 + "debug_options": { // 无需修改 + "trace_path": "/tmp/kcal_sender.trace" + }, + "disable_alignment": true, + "recovery_config": { // 这个配置不需要修改,kcal 目前无 recovery 模式 + "enabled": false, + "folder": "/tmp/kcal_sender_cache" + } + }, + "link_config": { + "parties": [ // 两个参与方的通信 ip 和 端口,按需修改 + { + "id": "receiver", + "host": "127.0.0.1:5300" + }, + { + "id": "sender", + "host": "127.0.0.1:5400" + } + ] + }, + "self_link_party": "sender" // 当前参与方的标识 +} + +``` + +### kcal 两个配置文件 + +- examples/psi/config/kcal_receiver.json +- examples/psi/config/kcal_sender.json + +### 蚂蚁对比配置文件 + +- examples/psi/config/rr22_receiver_recovery.json +- examples/psi/config/rr22_sender_recovery.json + +## 测试 + +进入两个 cvm 分别执行以下命令 + +```bash +cd /home/admin/dev/psi + +# 参与方 receiver +./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/psi/config/kcal_receiver.json +# 参与方 sender +./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/psi/config/kcal_sender.json +``` + +运行完以上命令后,每个`cvm`内会在配置文件中指明的`output_config.path`路径中生成交集文件 + +## 结果对比 + +将`/tmp/kcal_sender_output.csv`、`/tmp/kcal_receiver_output.csv`的内容与一开始生成的交集文件`/tmp/intersection.csv`内容进行比对,结果保持一致 + +# kcal pir 在蚂蚁 psi 库上的测试 + +编译完成后,进入`/home/admin/dev/psi`目录,步骤可参考:`examples/pir/README.md`​ + +## 生成 pir 测试数据 + +```bash +cd /home/admin/dev/psi + +python examples/pir/apsi/test_data_creator.py \ + --sender_size=10000000 \ + --receiver_size=1000 \ + --intersection_size=100 \ + --label_byte_count=100 \ + --item_byte_count=16 + +mv ground_truth.csv /tmp/ground_truth.csv +mv db.csv /tmp/db.csv +mv query.csv /tmp/query.csv +``` + +> 说明: +> +> --sender_size:服务端数据库总体数据行数 +> +> --receiver_size:客户端要查询的 key 的数量 +> +> --intersection_size:实际上生成的数据里面,服务端只有 intersection_size 个包含客户端能够查到的键值 +> +> --label_byte_count:服务端数据库每个 value 所占的字节数 +> +> --item_byte_count:服务端数据库每个 key 所占的字节数 + +## 配置文件说明 + +配置文件已在`patch`中提供,只需修改下列说明的部分进行测试 + +```json +// 客户端配置文件 +{ + "kcal_pir_receiver_config": { + "threads": 16, // 多线程处理,按需修改 + "query_file": "/tmp/query.csv", // 要查询的 key 的集合文件位置,按需修改 + "output_file": "/tmp/result.csv", // 查询结果 value 的保存位置,按需修改 + "is_dummy_mode": true // 查询的 key 是否进行 dummy,按需修改 + }, + "link_config": { // 两个参与方的通信 ip 和 端口,按需修改 + "parties": [ + { + "id": "sender", + "host": "127.0.0.1:5300" + }, + { + "id": "receiver", + "host": "127.0.0.1:5400" + } + ] + }, + "self_link_party": "receiver" +} + +// 服务端配置文件 +{ + "kcal_pir_sender_config": { + "threads": 16, // 多线程处理,按需修改 + "db_file": "/tmp/db.csv" // 数据库文件位置 + }, + "link_config": { // 两个参与方的通信 ip 和 端口,按需修改 + "parties": [ + { + "id": "sender", + "host": "127.0.0.1:5300" + }, + { + "id": "receiver", + "host": "127.0.0.1:5400" + } + ] + }, + "self_link_party": "sender" +} +``` + +### kcal 两个配置文件 + +- examples/pir/config/kcal_pir_receiver.json +- examples/pir/config/kcal_pir_sender.json + +### 蚂蚁对比配置文件 + +- examples/pir/config/apsi_sender_setup.json +- examples/pir/config/apsi_sender_online.json +- examples/pir/config/apsi_receiver.json + +## 测试 + +进入两个 cvm 分别执行以下命令 + +```bash +cd /home/admin/dev/psi + +# 服务端 +./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/pir/config/kcal_pir_sender.json +# 客户端 +./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/pir/config/kcal_pir_receiver.json +``` + +运行完以上命令后,客户端`cvm`内会在配置文件中指明的`kcal_pir_receiver_config.output_file`路径中生成查询结果文件 + +## 结果对比 + +将`/tmp/result.csv`的内容与一开始生成的交集文件`/tmp/ground_truth.csv`内容进行比对,结果保持一致 diff --git a/MPC/third_party_adaptor/secrerflow/psi/patches/kcal.patch b/MPC/third_party_adaptor/secrerflow/psi/patches/kcal.patch new file mode 100644 index 0000000000000000000000000000000000000000..b837a6eba37bde1961e8be306423f6b43cf2827b --- /dev/null +++ b/MPC/third_party_adaptor/secrerflow/psi/patches/kcal.patch @@ -0,0 +1,1646 @@ +diff --git a/.bazelrc b/.bazelrc +index 6406143..ca5d119 100644 +--- a/.bazelrc ++++ b/.bazelrc +@@ -26,6 +26,7 @@ build --incompatible_disallow_empty_glob=false + + build --cxxopt=-std=c++17 + build --host_cxxopt=-std=c++17 ++build --linkopt=-pthread + + # Binary safety flags + build --copt=-fPIC +diff --git a/MODULE.bazel b/MODULE.bazel +index 7170c4e..4292630 100644 +--- a/MODULE.bazel ++++ b/MODULE.bazel +@@ -66,6 +66,12 @@ new_local_repository( + path = "/opt/homebrew/opt/libomp/", + ) + ++new_local_repository( ++ name = "kunpeng_kcal_middleware", ++ build_file = "//bazel:kunpengkcal.BUILD", ++ path = "third_party/kcal_middleware" ++) ++ + bazel_dep(name = "kuscia", version = "0.14.0b0") + + # test +diff --git a/bazel/kunpengkcal.BUILD b/bazel/kunpengkcal.BUILD +new file mode 100644 +index 0000000..8416212 +--- /dev/null ++++ b/bazel/kunpengkcal.BUILD +@@ -0,0 +1,17 @@ ++load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake") ++ ++package(default_visibility = ["//visibility:public"]) ++ ++filegroup( ++ name = "kcal_all", ++ srcs = glob(["**"]), ++) ++ ++cmake( ++ name = "kcal_middleware", ++ lib_source = ":kcal_all", ++ out_static_libs = ["libkcal_middle.a"], ++ deps = [ ++ "@//third_party:libkcal" ++ ] ++) +diff --git a/examples/pir/README.md b/examples/pir/README.md +index b4bb342..3b0eddd 100644 +--- a/examples/pir/README.md ++++ b/examples/pir/README.md +@@ -6,11 +6,13 @@ + + ```bash + python examples/pir/apsi/test_data_creator.py --sender_size=100000 --receiver_size=1 --intersection_size=1 --label_byte_count=16 ++python examples/pir/apsi/test_data_creator.py --sender_size=10000000 --receiver_size=1000 --intersection_size=100 --label_byte_count=100 --item_byte_count=16 + + mv db.csv /tmp/db.csv + mv query.csv /tmp/query.csv + + cp examples/pir/apsi/parameters/100K-1-16.json /tmp/100K-1-16.json ++cp examples/pir/apsi/parameters/1M-1024-cmp.json /tmp/1M-1024-cmp.json + ``` + + ### NOTE +@@ -39,12 +41,14 @@ At sender terminal, run + + ```bash + ./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/pir/config/apsi_sender_online.json ++./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/pir/config/kcal_pir_sender.json + ``` + + At receiver terminal, run + + ```bash + ./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/pir/config/apsi_receiver.json ++./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/pir/config/kcal_pir_receiver.json + ``` + + ## Run Server with Full Mode (No Seperate Setup Stage) +diff --git a/examples/pir/config/apsi_receiver.json b/examples/pir/config/apsi_receiver.json +index 2c7f56b..f62177a 100644 +--- a/examples/pir/config/apsi_receiver.json ++++ b/examples/pir/config/apsi_receiver.json +@@ -3,7 +3,7 @@ + "threads": 1, + "query_file": "/tmp/query.csv", + "output_file": "/tmp/result.csv", +- "params_file": "/tmp/100K-1-16.json", ++ "params_file": "/tmp/1M-1024-cmp.json", + "log_level": "info" + }, + "link_config": { +diff --git a/examples/pir/config/apsi_sender_online.json b/examples/pir/config/apsi_sender_online.json +index b72e2e8..7444e5c 100644 +--- a/examples/pir/config/apsi_sender_online.json ++++ b/examples/pir/config/apsi_sender_online.json +@@ -1,6 +1,6 @@ + { + "apsi_sender_config": { +- "threads": 1, ++ "threads": 16, + "db_file": "/tmp/sdb", + "log_level": "info" + }, +diff --git a/examples/pir/config/apsi_sender_setup.json b/examples/pir/config/apsi_sender_setup.json +index e3e6991..c35b8da 100644 +--- a/examples/pir/config/apsi_sender_setup.json ++++ b/examples/pir/config/apsi_sender_setup.json +@@ -1,7 +1,7 @@ + { + "apsi_sender_config": { + "source_file": "/tmp/db.csv", +- "params_file": "/tmp/100K-1-16.json", ++ "params_file": "/tmp/1M-1024-cmp.json", + "sdb_out_file": "/tmp/sdb", + "save_db_only": true + } +diff --git a/examples/pir/config/kcal_pir_receiver.json b/examples/pir/config/kcal_pir_receiver.json +new file mode 100644 +index 0000000..14e97e6 +--- /dev/null ++++ b/examples/pir/config/kcal_pir_receiver.json +@@ -0,0 +1,21 @@ ++{ ++ "kcal_pir_receiver_config": { ++ "threads": 16, ++ "query_file": "/tmp/query.csv", ++ "output_file": "/tmp/result.csv", ++ "is_dummy_mode": true ++ }, ++ "link_config": { ++ "parties": [ ++ { ++ "id": "sender", ++ "host": "127.0.0.1:5300" ++ }, ++ { ++ "id": "receiver", ++ "host": "127.0.0.1:5400" ++ } ++ ] ++ }, ++ "self_link_party": "receiver" ++} +\ No newline at end of file +diff --git a/examples/pir/config/kcal_pir_sender.json b/examples/pir/config/kcal_pir_sender.json +new file mode 100644 +index 0000000..c2579cb +--- /dev/null ++++ b/examples/pir/config/kcal_pir_sender.json +@@ -0,0 +1,19 @@ ++{ ++ "kcal_pir_sender_config": { ++ "threads": 16, ++ "db_file": "/tmp/db.csv" ++ }, ++ "link_config": { ++ "parties": [ ++ { ++ "id": "sender", ++ "host": "127.0.0.1:5300" ++ }, ++ { ++ "id": "receiver", ++ "host": "127.0.0.1:5400" ++ } ++ ] ++ }, ++ "self_link_party": "sender" ++} +\ No newline at end of file +diff --git a/examples/psi/README.md b/examples/psi/README.md +index 9327ec7..42f2a83 100644 +--- a/examples/psi/README.md ++++ b/examples/psi/README.md +@@ -7,7 +7,7 @@ + 1. Compile the binary + + ```bash +- bazel build //psi:main -c opt ++ bazel build //... -c opt + ``` + + 2. Generate test data +@@ -36,12 +36,14 @@ + + ```bash + ./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/psi/config/ecdh_receiver_recovery.json ++ ./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/psi/config/kcal_receiver.json + ``` + + For **sender** terminal, + + ```bash + ./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/psi/config/ecdh_sender_recovery.json ++ ./bazel-bin/psi/apps/psi_launcher/main --config $(pwd)/examples/psi/config/kcal_sender.json + ``` + + ## 2P UB PSI +diff --git a/examples/psi/config/kcal_receiver.json b/examples/psi/config/kcal_receiver.json +new file mode 100644 +index 0000000..72e2694 +--- /dev/null ++++ b/examples/psi/config/kcal_receiver.json +@@ -0,0 +1,42 @@ ++{ ++ "psi_config": { ++ "protocol_config": { ++ "protocol": "PROTOCOL_KCAL", ++ "kcal_config": { ++ "thread_count": 16 ++ }, ++ "role": "ROLE_RECEIVER", ++ "broadcast_result": true ++ }, ++ "input_config": { ++ "type": "IO_TYPE_FILE_CSV", ++ "path": "/tmp/receiver_input.csv" ++ }, ++ "output_config": { ++ "type": "IO_TYPE_FILE_CSV", ++ "path": "/tmp/kcal_receiver_output.csv" ++ }, ++ "keys": ["id_0", "id_1"], ++ "debug_options": { ++ "trace_path": "/tmp/kcal_receiver.trace" ++ }, ++ "disable_alignment": true, ++ "recovery_config": { ++ "enabled": false, ++ "folder": "/tmp/kcal_receiver_cache" ++ } ++ }, ++ "link_config": { ++ "parties": [ ++ { ++ "id": "receiver", ++ "host": "127.0.0.1:5300" ++ }, ++ { ++ "id": "sender", ++ "host": "127.0.0.1:5400" ++ } ++ ] ++ }, ++ "self_link_party": "receiver" ++} +diff --git a/examples/psi/config/kcal_sender.json b/examples/psi/config/kcal_sender.json +new file mode 100644 +index 0000000..34abca9 +--- /dev/null ++++ b/examples/psi/config/kcal_sender.json +@@ -0,0 +1,42 @@ ++{ ++ "psi_config": { ++ "protocol_config": { ++ "protocol": "PROTOCOL_KCAL", ++ "kcal_config": { ++ "thread_count": 16 ++ }, ++ "role": "ROLE_SENDER", ++ "broadcast_result": true ++ }, ++ "input_config": { ++ "type": "IO_TYPE_FILE_CSV", ++ "path": "/tmp/sender_input.csv" ++ }, ++ "output_config": { ++ "type": "IO_TYPE_FILE_CSV", ++ "path": "/tmp/kcal_sender_output.csv" ++ }, ++ "keys": ["id_0", "id_1"], ++ "debug_options": { ++ "trace_path": "/tmp/kcal_sender.trace" ++ }, ++ "disable_alignment": true, ++ "recovery_config": { ++ "enabled": false, ++ "folder": "/tmp/kcal_sender_cache" ++ } ++ }, ++ "link_config": { ++ "parties": [ ++ { ++ "id": "receiver", ++ "host": "127.0.0.1:5300" ++ }, ++ { ++ "id": "sender", ++ "host": "127.0.0.1:5400" ++ } ++ ] ++ }, ++ "self_link_party": "sender" ++} +diff --git a/examples/psi/config/rr22_receiver_recovery.json b/examples/psi/config/rr22_receiver_recovery.json +index 9df3437..24acf48 100644 +--- a/examples/psi/config/rr22_receiver_recovery.json ++++ b/examples/psi/config/rr22_receiver_recovery.json +@@ -3,7 +3,10 @@ + "protocol_config": { + "protocol": "PROTOCOL_RR22", + "role": "ROLE_RECEIVER", +- "broadcast_result": true ++ "broadcast_result": true, ++ "rr22_config": { ++ "low_comm_mode": true ++ } + }, + "input_config": { + "type": "IO_TYPE_FILE_CSV", +diff --git a/examples/psi/config/rr22_sender_recovery.json b/examples/psi/config/rr22_sender_recovery.json +index d840dfd..0033330 100644 +--- a/examples/psi/config/rr22_sender_recovery.json ++++ b/examples/psi/config/rr22_sender_recovery.json +@@ -3,7 +3,10 @@ + "protocol_config": { + "protocol": "PROTOCOL_RR22", + "role": "ROLE_SENDER", +- "broadcast_result": true ++ "broadcast_result": true, ++ "rr22_config": { ++ "low_comm_mode": true ++ } + }, + "input_config": { + "type": "IO_TYPE_FILE_CSV", +diff --git a/psi/algorithm/kcal_pir/BUILD.bazel b/psi/algorithm/kcal_pir/BUILD.bazel +new file mode 100644 +index 0000000..72f2f14 +--- /dev/null ++++ b/psi/algorithm/kcal_pir/BUILD.bazel +@@ -0,0 +1,28 @@ ++load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") ++ ++package(default_visibility = ["//visibility:public"]) ++ ++psi_cc_library( ++ name = "pir_receiver", ++ srcs = ["pir_receiver.cc"], ++ hdrs = ["pir_receiver.h"], ++ deps = [ ++ "//psi/kcal_adaptor:data_helper", ++ "//psi/kcal_adaptor:traffic_wrapper", ++ "//psi/kcal_adaptor:global_manager", ++ "//psi/utils:arrow_csv_batch_provider", ++ "@kunpeng_kcal_middleware//:kcal_middleware", ++ ] ++) ++ ++psi_cc_library( ++ name = "pir_sender", ++ srcs = ["pir_sender.cc"], ++ hdrs = ["pir_sender.h"], ++ deps = [ ++ "//psi/kcal_adaptor:traffic_wrapper", ++ "//psi/kcal_adaptor:global_manager", ++ "//psi/utils:arrow_csv_batch_provider", ++ "@kunpeng_kcal_middleware//:kcal_middleware", ++ ] ++) +diff --git a/psi/algorithm/kcal_pir/pir_receiver.cc b/psi/algorithm/kcal_pir/pir_receiver.cc +new file mode 100644 +index 0000000..3761e18 +--- /dev/null ++++ b/psi/algorithm/kcal_pir/pir_receiver.cc +@@ -0,0 +1,129 @@ ++#include "psi/algorithm/kcal_pir/pir_receiver.h" ++ ++#include ++ ++#include "psi/kcal_adaptor/data_helper.h" ++#include "psi/kcal_adaptor/global_manager.h" ++#include "psi/kcal_adaptor/traffic_wrapper.h" ++#include "psi/utils/arrow_csv_batch_provider.h" ++ ++namespace psi::kcal::pir { ++ ++KcalPirReceiver::KcalPirReceiver(const ReceiverOptions& options, ++ std::shared_ptr lctx) ++ : lctx_(std::move(lctx)) { ++ options_ = std::make_shared(); ++ *options_ = options; ++ pir_ = std::make_unique<::kcal::Pir>(); ++} ++ ++void KcalPirReceiver::Init() { ++ SPDLOG_INFO("Kcal pir receiver init."); ++ lctx_->ConnectToMesh(); ++ GlobalManager::GetInstance().InitializeCtx(lctx_); ++ ::kcal::KCAL_Config kcal_cfg; ++ kcal_cfg.fixBits = 1; ++ kcal_cfg.nodeId = lctx_->Rank(); ++ kcal_cfg.threadCount = options_->thread_count; ++ kcal_cfg.worldSize = lctx_->WorldSize(); ++ ++ TEE_NET_RES net_res; ++ net_res.funcRecvData = GlobalRecvWrapper; ++ net_res.funcSendData = GlobalSendWrapper; ++ ++ kcal_ctx_ = ::kcal::Context::Create(kcal_cfg, &net_res, ++ ::kcal::KCAL_AlgorithmsType::PIR); ++ ++ pir_->Init(kcal_ctx_); ++} ++ ++int KcalPirReceiver::PreProcess() { ++ SPDLOG_INFO("Kcal pir receiver preprocess."); ++ std::vector keys{"key"}; ++ size_t batch_size = 1 << 20; ++ auto key_data = ReadPirQuery(options_->query_file, keys, batch_size); ++ ++ input_ = std::make_unique(); ++ input_->Fill(key_data); ++ SPDLOG_INFO("Kcal pir receiver input size = {}.", input_->Size()); ++ return 0; ++} ++ ++int KcalPirReceiver::Connect() { ++ yacl::link::Barrier(lctx_, "pir preprocess"); ++ return 0; ++} ++ ++int KcalPirReceiver::Online(DG_DummyMode dummyMode) { ++ SPDLOG_INFO("Kcal pir receiver online query."); ++ output_ = std::make_unique(); ++ int ret = pir_->ClientQuery(input_->Get(), output_->GetSecondaryPointer(), ++ dummyMode); ++ SPDLOG_INFO("Kcal pir receiver query ret = {}.", ret); ++ return ret; ++} ++ ++int KcalPirReceiver::PostProcess(int* match_cnt) { ++ SPDLOG_INFO("Kcal pir receiver postprocess."); ++ *match_cnt = output_->Size(); ++ SPDLOG_INFO("Kcal pir result size = {}.", *match_cnt); ++ ++ std::ofstream outFile(options_->result_file); ++ if (!outFile.is_open()) { ++ SPDLOG_INFO("Cannot open result file {}.", options_->result_file); ++ return -1; ++ } ++ outFile << "key,value\n"; ++ size_t str_nums = output_->Size(); ++ for (size_t i = 0; i < str_nums; ++i) { ++ outFile << input_->Get()->data.strings[i].str << ','; ++ outFile << output_->Get()->data.strings[i].str << '\n'; ++ } ++ outFile.close(); ++ return 0; ++} ++ ++std::vector KcalPirReceiver::ReadPirQuery( ++ const std::string& file_path, const std::vector& keys, ++ size_t batch_size, char delimiter) { ++ const std::vector labels = {}; ++ psi::ArrowCsvBatchProvider csv_batch_provider(file_path, keys, batch_size, ++ labels, delimiter); ++ std::vector key_data; ++ bool is_reach_end = false; ++ while (!is_reach_end) { ++ auto batch = csv_batch_provider.ReadNextBatch(); ++ if (batch.size() < batch_size) { ++ is_reach_end = true; ++ } ++ ++ if (batch.size() > 0) { ++ key_data.insert(key_data.end(), batch.begin(), batch.end()); ++ } ++ } ++ return key_data; ++} ++ ++int RunKcalReceiver(const ReceiverOptions& options, ++ const std::shared_ptr& lctx, ++ int* match_cnt) { ++ KcalPirReceiver pir_receiver(options, lctx); ++ pir_receiver.Init(); ++ pir_receiver.PreProcess(); ++ pir_receiver.Connect(); ++ auto start_time = std::chrono::steady_clock::now(); ++ DG_DummyMode dummy_mode = NORMAL; ++ if (options.is_dummy_mode) { ++ dummy_mode = DUMMY; ++ } ++ pir_receiver.Online(dummy_mode); ++ auto end_time = std::chrono::steady_clock::now(); ++ auto dur = std::chrono::duration_cast(end_time - ++ start_time) ++ .count(); ++ SPDLOG_INFO("Online query time {} ms", dur); ++ pir_receiver.PostProcess(match_cnt); ++ return 0; ++} ++ ++} // namespace psi::kcal::pir +diff --git a/psi/algorithm/kcal_pir/pir_receiver.h b/psi/algorithm/kcal_pir/pir_receiver.h +new file mode 100644 +index 0000000..98de787 +--- /dev/null ++++ b/psi/algorithm/kcal_pir/pir_receiver.h +@@ -0,0 +1,48 @@ ++#pragma once ++ ++#include "kcal/core/context.h" ++#include "kcal/operator/kcal_pir.h" ++#include "yacl/link/link.h" ++ ++#include "psi/kcal_adaptor/data_helper.h" ++ ++namespace psi::kcal::pir { ++ ++struct ReceiverOptions { ++ int thread_count = 1; ++ std::string query_file; ++ std::string result_file; ++ bool is_dummy_mode = false; ++}; ++ ++class KcalPirReceiver { ++ public: ++ explicit KcalPirReceiver(const ReceiverOptions& options, ++ std::shared_ptr lctx = nullptr); ++ ++ ~KcalPirReceiver() = default; ++ ++ void Init(); ++ int PreProcess(); ++ int Connect(); ++ int Online(DG_DummyMode dummyMode = NORMAL); ++ int PostProcess(int* match_cnt); ++ ++ private: ++ static std::vector ReadPirQuery( ++ const std::string& file_path, const std::vector& keys, ++ size_t batch_size, char delimiter = ','); ++ ++ std::shared_ptr lctx_; ++ std::shared_ptr<::kcal::Context> kcal_ctx_; ++ std::unique_ptr input_; ++ std::unique_ptr output_; ++ std::shared_ptr options_; ++ std::unique_ptr<::kcal::Pir> pir_; ++}; ++ ++int RunKcalReceiver(const ReceiverOptions& options, ++ const std::shared_ptr& lctx = nullptr, ++ int* match_cnt = nullptr); ++ ++} // namespace psi::kcal::pir +diff --git a/psi/algorithm/kcal_pir/pir_sender.cc b/psi/algorithm/kcal_pir/pir_sender.cc +new file mode 100644 +index 0000000..3bdd810 +--- /dev/null ++++ b/psi/algorithm/kcal_pir/pir_sender.cc +@@ -0,0 +1,156 @@ ++#include "psi/algorithm/kcal_pir/pir_sender.h" ++ ++#include "psi/kcal_adaptor/global_manager.h" ++#include "psi/kcal_adaptor/traffic_wrapper.h" ++#include "psi/utils/arrow_csv_batch_provider.h" ++ ++namespace psi::kcal::pir { ++ ++KcalPirSender::KcalPirSender(const SenderOptions& options, ++ std::shared_ptr lctx) ++ : lctx_(std::move(lctx)) { ++ options_ = std::make_shared(); ++ *options_ = options; ++ pir_ = std::make_unique<::kcal::Pir>(); ++} ++ ++void KcalPirSender::Init() { ++ SPDLOG_INFO("Kcal pir sender init."); ++ lctx_->ConnectToMesh(); ++ GlobalManager::GetInstance().InitializeCtx(lctx_); ++ ++ ::kcal::KCAL_Config kcal_cfg{}; ++ kcal_cfg.fixBits = 1; ++ kcal_cfg.nodeId = lctx_->Rank(); ++ kcal_cfg.threadCount = options_->thread_count; ++ kcal_cfg.worldSize = lctx_->WorldSize(); ++ ++ TEE_NET_RES net_res; ++ net_res.funcRecvData = GlobalRecvWrapper; ++ net_res.funcSendData = GlobalSendWrapper; ++ ++ kcal_ctx_ = ::kcal::Context::Create(kcal_cfg, &net_res, ++ ::kcal::KCAL_AlgorithmsType::PIR); ++ pir_->Init(kcal_ctx_); ++} ++ ++int KcalPirSender::PreProcess() { ++ SPDLOG_INFO("Kcal pir sender preprocess."); ++ std::vector keys{"key"}; ++ std::vector labels{"value"}; ++ size_t batch_size = 1 << 20; ++ auto key_label = ReadPirDataSet(options_->db_file, keys, batch_size, labels); ++ ++ DG_PairList* pair_list = nullptr; ++ BuildDgPairList(key_label, &pair_list); ++ SPDLOG_INFO("Kcal pir sender db size = {}.", pair_list->size); ++ int ret = pir_->ServerPreProcess(pair_list); ++ ReleaseDgPairList(pair_list); ++ SPDLOG_INFO("Kcal pir sender preprocess ret = {}.", ret); ++ return ret; ++} ++ ++int KcalPirSender::Connect() { ++ yacl::link::Barrier(lctx_, "pir preprocess"); ++ return 0; ++} ++ ++int KcalPirSender::Online() { ++ SPDLOG_INFO("Kcal pir sender answer."); ++ int ret = pir_->ServerAnswer(); ++ SPDLOG_INFO("Kcal pir sender answer ret = {}.", ret); ++ return ret; ++} ++ ++int KcalPirSender::PostProcess() { ++ SPDLOG_INFO("Kcal pir sender postprocess"); ++ return 0; ++} ++ ++std::pair, std::vector> ++KcalPirSender::ReadPirDataSet(const std::string& file_path, ++ const std::vector& keys, ++ size_t batch_size, ++ const std::vector& labels, ++ char delimiter) { ++ psi::ArrowCsvBatchProvider csv_batch_provider(file_path, keys, batch_size, ++ labels, delimiter); ++ std::vector key_data; ++ std::vector label_data; ++ bool is_reach_end = false; ++ while (!is_reach_end) { ++ auto batch = csv_batch_provider.ReadNextLabeledBatch(); ++ if (batch.first.size() < batch_size) { ++ is_reach_end = true; ++ } ++ ++ if (batch.first.size() > 0) { ++ key_data.insert(key_data.end(), batch.first.begin(), batch.first.end()); ++ label_data.insert(label_data.end(), batch.second.begin(), ++ batch.second.end()); ++ } ++ } ++ return {key_data, label_data}; ++} ++ ++void KcalPirSender::ReleasePairString(DG_String* dgString) { ++ if (dgString != nullptr) { ++ delete[] dgString->str; ++ delete dgString; ++ } ++} ++ ++void KcalPirSender::ReleaseDgPairList(DG_PairList* pairList) { ++ if (pairList != nullptr) { ++ for (unsigned long i = 0; i < pairList->size; ++i) { ++ ReleasePairString(pairList->dgPair[i].key); ++ ReleasePairString(pairList->dgPair[i].value); ++ } ++ } ++ delete[] pairList->dgPair; ++ delete pairList; ++} ++ ++DG_String* KcalPirSender::BuildPairString(const std::string& s) { ++ auto* dgString = new DG_String(); ++ dgString->size = s.size() + 1; ++ dgString->str = new char[dgString->size + 1]; ++ dgString->str[dgString->size] = '\0'; ++ strcpy(dgString->str, s.c_str()); ++ return dgString; ++} ++ ++void KcalPirSender::BuildDgPairList( ++ const std::pair, std::vector>& ++ key_label, ++ DG_PairList** pair_list) { ++ *pair_list = new DG_PairList(); ++ size_t size = key_label.first.size(); ++ (*pair_list)->dgPair = new DG_Pair[size]; ++ ++ for (unsigned long i = 0; i < size; ++i) { ++ (*pair_list)->dgPair[i].key = BuildPairString(key_label.first[i]); ++ (*pair_list)->dgPair[i].value = BuildPairString(key_label.second[i]); ++ } ++ (*pair_list)->size = size; ++} ++ ++int RunKcalSender(const SenderOptions& options, ++ const std::shared_ptr& lctx) { ++ KcalPirSender pir_sender(options, lctx); ++ ++ pir_sender.Init(); ++ pir_sender.PreProcess(); ++ pir_sender.Connect(); ++ auto start_time = std::chrono::steady_clock::now(); ++ pir_sender.Online(); ++ auto end_time = std::chrono::steady_clock::now(); ++ auto dur = std::chrono::duration_cast(end_time - ++ start_time) ++ .count(); ++ SPDLOG_INFO("Online query time {} ms", dur); ++ pir_sender.PostProcess(); ++ return 0; ++} ++ ++} // namespace psi::kcal::pir +diff --git a/psi/algorithm/kcal_pir/pir_sender.h b/psi/algorithm/kcal_pir/pir_sender.h +new file mode 100644 +index 0000000..f44bd39 +--- /dev/null ++++ b/psi/algorithm/kcal_pir/pir_sender.h +@@ -0,0 +1,55 @@ ++#pragma once ++ ++#include "kcal/core/context.h" ++#include "kcal/operator/kcal_pir.h" ++#include "yacl/link/link.h" ++ ++namespace psi::kcal::pir { ++ ++struct SenderOptions { ++ int thread_count = 1; ++ std::string db_file; ++}; ++ ++class KcalPirSender { ++ public: ++ explicit KcalPirSender(const SenderOptions& options, ++ std::shared_ptr lctx = nullptr); ++ ++ ~KcalPirSender() = default; ++ ++ void Init(); ++ int PreProcess(); ++ int Connect(); ++ int Online(); ++ int PostProcess(); ++ ++ private: ++ static std::pair, std::vector> ++ ReadPirDataSet(const std::string& file_path, ++ const std::vector& keys, ++ size_t batch_size = 1 << 20, ++ const std::vector& labels = {}, ++ char delimiter = ','); ++ ++ static void ReleasePairString(DG_String* dgString); ++ ++ static void ReleaseDgPairList(DG_PairList* pairList); ++ ++ static DG_String* BuildPairString(const std::string& s); ++ ++ static void BuildDgPairList( ++ const std::pair, std::vector>& ++ key_label, ++ DG_PairList** pair_list); ++ ++ std::shared_ptr lctx_; ++ std::shared_ptr<::kcal::Context> kcal_ctx_; ++ std::shared_ptr options_; ++ std::unique_ptr<::kcal::Pir> pir_; ++}; ++ ++int RunKcalSender(const SenderOptions& options, ++ const std::shared_ptr& lctx = nullptr); ++ ++} // namespace psi::kcal::pir +diff --git a/psi/algorithm/kcal_psi/BUILD.bazel b/psi/algorithm/kcal_psi/BUILD.bazel +new file mode 100644 +index 0000000..14dd8df +--- /dev/null ++++ b/psi/algorithm/kcal_psi/BUILD.bazel +@@ -0,0 +1,27 @@ ++load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") ++ ++package(default_visibility = ["//visibility:public"]) ++ ++psi_cc_library( ++ name = "receiver", ++ srcs = ["receiver.cc"], ++ hdrs = ["receiver.h"], ++ deps = [ ++ "//psi/kcal_adaptor:data_helper", ++ "//psi/kcal_adaptor:traffic_wrapper", ++ "//psi:interface", ++ "//psi/utils:arrow_csv_batch_provider", ++ ] ++) ++ ++psi_cc_library( ++ name = "sender", ++ srcs = ["sender.cc"], ++ hdrs = ["sender.h"], ++ deps = [ ++ "//psi/kcal_adaptor:data_helper", ++ "//psi/kcal_adaptor:traffic_wrapper", ++ "//psi:interface", ++ "//psi/utils:arrow_csv_batch_provider", ++ ] ++) +diff --git a/psi/algorithm/kcal_psi/receiver.cc b/psi/algorithm/kcal_psi/receiver.cc +new file mode 100644 +index 0000000..5a6bd69 +--- /dev/null ++++ b/psi/algorithm/kcal_psi/receiver.cc +@@ -0,0 +1,62 @@ ++#include "psi/algorithm/kcal_psi/receiver.h" ++ ++#include ++ ++#include "kcal/operator/kcal_psi.h" ++ ++#include "psi/kcal_adaptor/data_helper.h" ++#include "psi/kcal_adaptor/global_manager.h" ++#include "psi/kcal_adaptor/traffic_wrapper.h" ++ ++namespace psi::kcal { ++ ++KcalPsiReceiver::KcalPsiReceiver(const v2::PsiConfig& config, ++ std::shared_ptr lctx) ++ : AbstractPsiReceiver(config, std::move(lctx)) {} ++ ++void KcalPsiReceiver::Init() { ++ AbstractPsiReceiver::Init(); ++ // 全局网络资源管理 ++ GlobalManager::GetInstance().InitializeCtx(lctx_); ++ // 配置 ++ ::kcal::KCAL_Config kcal_cfg; ++ kcal_cfg.fixBits = 1; ++ kcal_cfg.nodeId = lctx_->Rank(); ++ kcal_cfg.threadCount = config_.protocol_config().kcal_config().thread_count(); ++ kcal_cfg.worldSize = lctx_->WorldSize(); ++ // 网络接口 ++ TEE_NET_RES net_res; ++ net_res.funcRecvData = GlobalRecvWrapper; ++ net_res.funcSendData = GlobalSendWrapper; ++ // 计算上下文 ++ kcal_ctx_ = ::kcal::Context::Create(kcal_cfg, &net_res, ++ ::kcal::KCAL_AlgorithmsType::PSI); ++} ++ ++void KcalPsiReceiver::PreProcess() { ++ input_ = std::make_unique(); ++ input_->Fill(batch_provider_); ++} ++ ++void KcalPsiReceiver::Online() { ++ auto start_time = std::chrono::high_resolution_clock::now(); ++ ::kcal::Psi psi; ++ psi.Init(kcal_ctx_); ++ output_ = std::make_unique(); ++ psi.Run(input_->Get(), output_->GetSecondaryPointer(), TEE_OUTPUT_INDEX); ++ auto end_time = std::chrono::high_resolution_clock::now(); ++ auto duration = std::chrono::duration_cast( ++ end_time - start_time) ++ .count(); ++ SPDLOG_INFO("Online cost: {} ms", duration); ++} ++ ++void KcalPsiReceiver::PostProcess() { ++ for (int i = 0; i < output_->Size(); ++i) { ++ intersection_indices_writer_->WriteCache( ++ output_->Get()->data.u64Numbers[i]); ++ } ++ intersection_indices_writer_->Commit(); ++} ++ ++} // namespace psi::kcal +\ No newline at end of file +diff --git a/psi/algorithm/kcal_psi/receiver.h b/psi/algorithm/kcal_psi/receiver.h +new file mode 100644 +index 0000000..247e2cb +--- /dev/null ++++ b/psi/algorithm/kcal_psi/receiver.h +@@ -0,0 +1,34 @@ ++#pragma once ++ ++#include "kcal/core/context.h" ++ ++#include "psi/interface.h" ++#include "psi/kcal_adaptor/data_helper.h" ++#include "psi/utils/arrow_csv_batch_provider.h" ++ ++#include "psi/proto/psi_v2.pb.h" ++ ++namespace psi::kcal { ++ ++class KcalPsiReceiver final : public AbstractPsiReceiver { ++ public: ++ explicit KcalPsiReceiver(const v2::PsiConfig& config, ++ std::shared_ptr lctx = nullptr); ++ ++ ~KcalPsiReceiver() override = default; ++ ++ private: ++ void Init() override; ++ ++ void PreProcess() override; ++ ++ void Online() override; ++ ++ void PostProcess() override; ++ ++ std::shared_ptr<::kcal::Context> kcal_ctx_; ++ std::unique_ptr input_; ++ std::unique_ptr output_; ++}; ++ ++} // namespace psi::kcal +diff --git a/psi/algorithm/kcal_psi/sender.cc b/psi/algorithm/kcal_psi/sender.cc +new file mode 100644 +index 0000000..551b0ff +--- /dev/null ++++ b/psi/algorithm/kcal_psi/sender.cc +@@ -0,0 +1,60 @@ ++#include "psi/algorithm/kcal_psi/sender.h" ++ ++#include "kcal/operator/kcal_psi.h" ++ ++#include "psi/kcal_adaptor/data_helper.h" ++#include "psi/kcal_adaptor/global_manager.h" ++#include "psi/kcal_adaptor/traffic_wrapper.h" ++ ++namespace psi::kcal { ++ ++KcalPsiSender::KcalPsiSender(const v2::PsiConfig& config, ++ std::shared_ptr lctx) ++ : AbstractPsiSender(config, std::move(lctx)) {} ++ ++void KcalPsiSender::Init() { ++ AbstractPsiSender::Init(); ++ // 全局网络资源管理 ++ GlobalManager::GetInstance().InitializeCtx(lctx_); ++ // 配置 ++ ::kcal::KCAL_Config kcal_cfg; ++ kcal_cfg.fixBits = 1; ++ kcal_cfg.nodeId = lctx_->Rank(); ++ kcal_cfg.threadCount = config_.protocol_config().kcal_config().thread_count(); ++ kcal_cfg.worldSize = lctx_->WorldSize(); ++ // 网络接口 ++ TEE_NET_RES net_res; ++ net_res.funcRecvData = GlobalRecvWrapper; ++ net_res.funcSendData = GlobalSendWrapper; ++ // 计算上下文 ++ kcal_ctx_ = ::kcal::Context::Create(kcal_cfg, &net_res, ++ ::kcal::KCAL_AlgorithmsType::PSI); ++} ++ ++void KcalPsiSender::PreProcess() { ++ input_ = std::make_unique(); ++ input_->Fill(batch_provider_); ++} ++ ++void KcalPsiSender::Online() { ++ auto start_time = std::chrono::high_resolution_clock::now(); ++ ::kcal::Psi psi; ++ psi.Init(kcal_ctx_); ++ output_ = std::make_unique(); ++ psi.Run(input_->Get(), output_->GetSecondaryPointer(), TEE_OUTPUT_INDEX); ++ auto end_time = std::chrono::high_resolution_clock::now(); ++ auto duration = std::chrono::duration_cast( ++ end_time - start_time) ++ .count(); ++ SPDLOG_INFO("Online cost: {} ms", duration); ++} ++ ++void KcalPsiSender::PostProcess() { ++ for (int i = 0; i < output_->Size(); ++i) { ++ intersection_indices_writer_->WriteCache( ++ output_->Get()->data.u64Numbers[i]); ++ } ++ intersection_indices_writer_->Commit(); ++} ++ ++} // namespace psi::kcal +\ No newline at end of file +diff --git a/psi/algorithm/kcal_psi/sender.h b/psi/algorithm/kcal_psi/sender.h +new file mode 100644 +index 0000000..b571632 +--- /dev/null ++++ b/psi/algorithm/kcal_psi/sender.h +@@ -0,0 +1,34 @@ ++#pragma once ++ ++#include "kcal/core/context.h" ++ ++#include "psi/interface.h" ++#include "psi/kcal_adaptor/data_helper.h" ++#include "psi/utils/arrow_csv_batch_provider.h" ++ ++#include "psi/proto/psi_v2.pb.h" ++ ++namespace psi::kcal { ++ ++class KcalPsiSender final : public AbstractPsiSender { ++ public: ++ explicit KcalPsiSender(const v2::PsiConfig& config, ++ std::shared_ptr lctx = nullptr); ++ ++ ~KcalPsiSender() override = default; ++ ++ private: ++ void Init() override; ++ ++ void PreProcess() override; ++ ++ void Online() override; ++ ++ void PostProcess() override; ++ ++ std::shared_ptr<::kcal::Context> kcal_ctx_; ++ std::unique_ptr input_; ++ std::unique_ptr output_; ++}; ++ ++} // namespace psi::kcal +diff --git a/psi/apps/psi_launcher/BUILD.bazel b/psi/apps/psi_launcher/BUILD.bazel +index ffa6a3f..15752ff 100644 +--- a/psi/apps/psi_launcher/BUILD.bazel ++++ b/psi/apps/psi_launcher/BUILD.bazel +@@ -29,6 +29,8 @@ psi_cc_library( + "//psi/algorithm/kkrt:sender", + "//psi/algorithm/rr22:receiver", + "//psi/algorithm/rr22:sender", ++ "//psi/algorithm/kcal_psi:receiver", ++ "//psi/algorithm/kcal_psi:sender", + "@yacl//yacl/base:exception", + ], + ) +@@ -52,6 +54,8 @@ psi_cc_library( + "//psi/legacy:bucket_psi", + "//psi/proto:pir_cc_proto", + "//psi/wrapper/apsi/cli:entry", ++ "//psi/algorithm/kcal_pir:pir_sender", ++ "//psi/algorithm/kcal_pir:pir_receiver", + "@boost.algorithm//:boost.algorithm", + ], + ) +@@ -102,6 +106,7 @@ psi_cc_binary( + "//psi:version", + "//psi/proto:entry_cc_proto", + "//psi/utils:resource_manager", ++ "@kunpeng_kcal_middleware//:kcal_middleware", + "@gflags", + ], + ) +diff --git a/psi/apps/psi_launcher/factory.cc b/psi/apps/psi_launcher/factory.cc +index f57f1eb..a96a8ea 100644 +--- a/psi/apps/psi_launcher/factory.cc ++++ b/psi/apps/psi_launcher/factory.cc +@@ -22,6 +22,8 @@ + #include "psi/algorithm/ecdh/sender.h" + #include "psi/algorithm/ecdh/ub_psi/client.h" + #include "psi/algorithm/ecdh/ub_psi/server.h" ++#include "psi/algorithm/kcal_psi/receiver.h" ++#include "psi/algorithm/kcal_psi/sender.h" + #include "psi/algorithm/kkrt/receiver.h" + #include "psi/algorithm/kkrt/sender.h" + #include "psi/algorithm/rr22/receiver.h" +@@ -62,6 +64,16 @@ std::unique_ptr createPsiParty( + YACL_THROW("Role is invalid."); + } + } ++ case v2::Protocol::PROTOCOL_KCAL: { ++ switch (config.protocol_config().role()) { ++ case v2::Role::ROLE_RECEIVER: ++ return std::make_unique(config, lctx); ++ case v2::Role::ROLE_SENDER: ++ return std::make_unique(config, lctx); ++ default: ++ YACL_THROW("Role is invalid."); ++ } ++ } + default: + YACL_THROW("Protocol is unspecified."); + } +diff --git a/psi/apps/psi_launcher/launch.cc b/psi/apps/psi_launcher/launch.cc +index 31deab7..6c6780f 100644 +--- a/psi/apps/psi_launcher/launch.cc ++++ b/psi/apps/psi_launcher/launch.cc +@@ -320,6 +320,30 @@ PirResultReport RunDkPir(const DkPirSenderConfig& dk_pir_sender_config, + return PirResultReport(); + } + ++PirResultReport RunKcalPir( ++ const KcalPirReceiverConfig& kcal_pir_receiver_config, ++ const std::shared_ptr& lctx) { ++ kcal::pir::ReceiverOptions options; ++ options.thread_count = kcal_pir_receiver_config.threads(); ++ options.is_dummy_mode = kcal_pir_receiver_config.is_dummy_mode(); ++ options.query_file = kcal_pir_receiver_config.query_file(); ++ options.result_file = kcal_pir_receiver_config.output_file(); ++ int match_cnt = 0; ++ kcal::pir::RunKcalReceiver(options, lctx, &match_cnt); ++ PirResultReport report; ++ report.set_match_cnt(match_cnt); ++ return report; ++} ++ ++PirResultReport RunKcalPir(const KcalPirSenderConfig& kcal_pir_sender_config, ++ const std::shared_ptr& lctx) { ++ kcal::pir::SenderOptions options; ++ options.thread_count = kcal_pir_sender_config.threads(); ++ options.db_file = kcal_pir_sender_config.db_file(); ++ kcal::pir::RunKcalSender(options, lctx); ++ return {}; ++} ++ + namespace api { + + const std::map kV2ProtoclMap = { +diff --git a/psi/apps/psi_launcher/launch.h b/psi/apps/psi_launcher/launch.h +index 080f550..3e377ee 100644 +--- a/psi/apps/psi_launcher/launch.h ++++ b/psi/apps/psi_launcher/launch.h +@@ -19,6 +19,8 @@ + + #include "yacl/link/context.h" + ++#include "psi/algorithm/kcal_pir/pir_receiver.h" ++#include "psi/algorithm/kcal_pir/pir_sender.h" + #include "psi/apps/psi_launcher/report.h" + #include "psi/config/psi.h" + #include "psi/config/ub_psi.h" +@@ -53,6 +55,13 @@ PirResultReport RunDkPir(const DkPirReceiverConfig& dk_pir_receiver_config, + PirResultReport RunDkPir(const DkPirSenderConfig& dk_pir_sender_config, + const std::shared_ptr& lctx); + ++PirResultReport RunKcalPir( ++ const KcalPirReceiverConfig& kcal_pir_receiver_config, ++ const std::shared_ptr& lctx); ++ ++PirResultReport RunKcalPir(const KcalPirSenderConfig& kcal_pir_sender_config, ++ const std::shared_ptr& lctx); ++ + namespace api { + + namespace internal { +diff --git a/psi/apps/psi_launcher/main.cc b/psi/apps/psi_launcher/main.cc +index 3965006..c39c05d 100644 +--- a/psi/apps/psi_launcher/main.cc ++++ b/psi/apps/psi_launcher/main.cc +@@ -133,6 +133,18 @@ int main(int argc, char* argv[]) { + YACL_ENFORCE(google::protobuf::util::MessageToJsonString( + report, &report_json, json_print_options) + .ok()); ++ } else if (launch_config.has_kcal_pir_sender_config()) { ++ psi::PirResultReport report = ++ psi::RunKcalPir(launch_config.kcal_pir_sender_config(), lctx); ++ YACL_ENFORCE(google::protobuf::util::MessageToJsonString( ++ report, &report_json, json_print_options) ++ .ok()); ++ } else if (launch_config.has_kcal_pir_receiver_config()) { ++ psi::PirResultReport report = ++ psi::RunKcalPir(launch_config.kcal_pir_receiver_config(), lctx); ++ YACL_ENFORCE(google::protobuf::util::MessageToJsonString( ++ report, &report_json, json_print_options) ++ .ok()); + } else { + SPDLOG_WARN("No runtime config is provided."); + } +diff --git a/psi/kcal_adaptor/BUILD.bazel b/psi/kcal_adaptor/BUILD.bazel +new file mode 100644 +index 0000000..e20c134 +--- /dev/null ++++ b/psi/kcal_adaptor/BUILD.bazel +@@ -0,0 +1,32 @@ ++load("//bazel:psi.bzl", "psi_cc_binary", "psi_cc_library", "psi_cc_test") ++ ++package(default_visibility = ["//visibility:public"]) ++ ++psi_cc_library( ++ name = "data_helper", ++ srcs = ["data_helper.cc"], ++ hdrs = ["data_helper.h"], ++ deps = [ ++ "//psi/utils:batch_provider", ++ "@kunpeng_kcal_middleware//:kcal_middleware", ++ ] ++) ++ ++psi_cc_library( ++ name = "global_manager", ++ hdrs = ["global_manager.h"], ++ srcs = ["global_manager.cc"], ++ deps = [ ++ "@yacl//yacl/link", ++ ] ++) ++ ++psi_cc_library( ++ name = "traffic_wrapper", ++ hdrs = ["traffic_wrapper.h"], ++ srcs = ["traffic_wrapper.cc"], ++ deps = [ ++ ":global_manager", ++ "@kunpeng_kcal_middleware//:kcal_middleware", ++ ] ++) +diff --git a/psi/kcal_adaptor/data_helper.cc b/psi/kcal_adaptor/data_helper.cc +new file mode 100644 +index 0000000..ed9f5b9 +--- /dev/null ++++ b/psi/kcal_adaptor/data_helper.cc +@@ -0,0 +1,62 @@ ++#include "psi/kcal_adaptor/data_helper.h" ++ ++#include ++ ++namespace psi::kcal { ++ ++void DataHelper::BuildDgString(const std::vector& strings, ++ DG_String** dg) { ++ auto* dgString = new DG_String[strings.size()]; ++ for (size_t i = 0; i < strings.size(); ++i) { ++ dgString[i].str = strdup(strings[i].c_str()); ++ dgString[i].size = strings[i].size() + 1; ++ } ++ *dg = dgString; ++} ++ ++void DataHelper::ReleaseOutput(DG_TeeOutput** output) { ++ if (output == nullptr || *output == nullptr) { ++ return; ++ } ++ if ((*output)->dataType == MPC_STRING && (*output)->data.strings != nullptr) { ++ for (int i = 0; i < (*output)->size; ++i) { ++ if ((*output)->data.strings[i].str != nullptr) { ++ delete[](*output)->data.strings[i].str; ++ } ++ } ++ delete[](*output)->data.strings; ++ } else if ((*output)->dataType == MPC_DOUBLE && ++ (*output)->data.doubleNumbers != nullptr) { ++ delete[](*output)->data.doubleNumbers; ++ } else if ((*output)->dataType == MPC_INT && ++ (*output)->data.u64Numbers != nullptr) { ++ delete[](*output)->data.u64Numbers; ++ } ++ delete *output; ++ *output = nullptr; ++} ++ ++void KcalInput::Fill(const std::vector& data) { ++ DG_String* strings = nullptr; ++ DataHelper::BuildDgString(data, &strings); ++ ++ input_ = new DG_TeeInput(); ++ input_->data.strings = strings; ++ input_->size = static_cast(data.size()); ++ input_->dataType = MPC_STRING; ++} ++ ++void KcalInput::Fill(std::shared_ptr provider) { ++ std::vector rawData; ++ while (true) { ++ auto data = provider->ReadNextBatch(); ++ if (data.empty()) { ++ break; ++ } ++ rawData.insert(rawData.end(), data.begin(), data.end()); ++ } ++ ++ Fill(rawData); ++} ++ ++} // namespace psi::kcal +diff --git a/psi/kcal_adaptor/data_helper.h b/psi/kcal_adaptor/data_helper.h +new file mode 100644 +index 0000000..ef643c9 +--- /dev/null ++++ b/psi/kcal_adaptor/data_helper.h +@@ -0,0 +1,44 @@ ++#pragma once ++ ++#include ++#include ++#include ++ ++#include "kcal/api/kcal_api.h" ++ ++#include "psi/utils/batch_provider.h" ++ ++namespace psi::kcal { ++ ++class DataHelper { ++ public: ++ static void BuildDgString(const std::vector& strings, ++ DG_String** dg); ++ ++ static void ReleaseOutput(DG_TeeOutput** output); ++}; ++ ++class KcalInput { ++ public: ++ KcalInput() = default; ++ ++ ~KcalInput() { DataHelper::ReleaseOutput(&input_); } ++ ++ KcalInput(const KcalInput&) = delete; ++ KcalInput& operator=(const KcalInput&) = delete; ++ ++ void Fill(const std::vector& data); ++ void Fill(std::shared_ptr provider); ++ ++ DG_TeeInput* Get() { return input_; } ++ DG_TeeInput** GetSecondaryPointer() { return &input_; } ++ ++ int Size() { return input_->size; } ++ ++ private: ++ DG_TeeInput* input_ = nullptr; ++}; ++ ++using KcalOutput = KcalInput; ++ ++} // namespace psi::kcal +\ No newline at end of file +diff --git a/psi/kcal_adaptor/global_manager.cc b/psi/kcal_adaptor/global_manager.cc +new file mode 100644 +index 0000000..92ca35c +--- /dev/null ++++ b/psi/kcal_adaptor/global_manager.cc +@@ -0,0 +1,19 @@ ++#include "psi/kcal_adaptor/global_manager.h" ++ ++namespace psi::kcal { ++ ++GlobalManager& GlobalManager::GetInstance() { ++ static GlobalManager instance; ++ return instance; ++} ++ ++void GlobalManager::InitializeCtx( ++ const std::shared_ptr& lctx) { ++ lctx_ = lctx; ++} ++ ++std::shared_ptr& GlobalManager::GetLink() { return lctx_; } ++ ++void GlobalManager::CleanLink() { lctx_.reset(); } ++ ++} // namespace psi::kcal +\ No newline at end of file +diff --git a/psi/kcal_adaptor/global_manager.h b/psi/kcal_adaptor/global_manager.h +new file mode 100644 +index 0000000..c596b0f +--- /dev/null ++++ b/psi/kcal_adaptor/global_manager.h +@@ -0,0 +1,25 @@ ++#pragma once ++ ++#include "yacl/link/link.h" ++ ++namespace psi::kcal { ++ ++class GlobalManager { ++ public: ++ static GlobalManager& GetInstance(); ++ ++ ~GlobalManager() = default; ++ ++ void InitializeCtx(const std::shared_ptr& lctx); ++ ++ std::shared_ptr& GetLink(); ++ ++ void CleanLink(); ++ ++ private: ++ GlobalManager() = default; ++ ++ std::shared_ptr lctx_; ++}; ++ ++} // namespace psi::kcal +diff --git a/psi/kcal_adaptor/traffic_wrapper.cc b/psi/kcal_adaptor/traffic_wrapper.cc +new file mode 100644 +index 0000000..4ea541f +--- /dev/null ++++ b/psi/kcal_adaptor/traffic_wrapper.cc +@@ -0,0 +1,26 @@ ++#include "psi/kcal_adaptor/traffic_wrapper.h" ++ ++#include "psi/kcal_adaptor/global_manager.h" ++ ++namespace psi::kcal { ++ ++int GlobalSendWrapper(struct TeeNodeInfo* nodeInfo, unsigned char* buf, ++ unsigned long long len) { ++ yacl::Buffer msg(reinterpret_cast(buf), len); ++ std::string tag = "kcal_send"; ++ auto lctx = GlobalManager::GetInstance().GetLink(); ++ lctx->SendAsync(nodeInfo->nodeId, msg, tag); ++ return 0; ++} ++ ++int GlobalRecvWrapper(struct TeeNodeInfo* nodeInfo, unsigned char* buf, ++ unsigned long long* len) { ++ std::string tag = "kcal_recv"; ++ auto lctx = GlobalManager::GetInstance().GetLink(); ++ auto msg = lctx->Recv(nodeInfo->nodeId, tag); ++ memcpy(buf, msg.data(), msg.size()); ++ *len = msg.size(); ++ return 0; ++} ++ ++} // namespace psi::kcal +diff --git a/psi/kcal_adaptor/traffic_wrapper.h b/psi/kcal_adaptor/traffic_wrapper.h +new file mode 100644 +index 0000000..fc00fcc +--- /dev/null ++++ b/psi/kcal_adaptor/traffic_wrapper.h +@@ -0,0 +1,13 @@ ++#pragma once ++ ++#include "kcal/api/kcal_api.h" ++ ++namespace psi::kcal { ++ ++int GlobalSendWrapper(struct TeeNodeInfo* nodeInfo, unsigned char* buf, ++ unsigned long long len); ++ ++int GlobalRecvWrapper(struct TeeNodeInfo* nodeInfo, unsigned char* buf, ++ unsigned long long* len); ++ ++} // namespace psi::kcal +\ No newline at end of file +diff --git a/psi/proto/entry.proto b/psi/proto/entry.proto +index 265ebaa..76b2881 100644 +--- a/psi/proto/entry.proto ++++ b/psi/proto/entry.proto +@@ -47,5 +47,9 @@ message LaunchConfig { + DkPirSenderConfig dk_pir_sender_config = 8; + + DkPirReceiverConfig dk_pir_receiver_config = 9; ++ ++ KcalPirSenderConfig kcal_pir_sender_config = 10; ++ ++ KcalPirReceiverConfig kcal_pir_receiver_config = 11; + } + } +diff --git a/psi/proto/pir.proto b/psi/proto/pir.proto +index 094ecba..4fc74c8 100644 +--- a/psi/proto/pir.proto ++++ b/psi/proto/pir.proto +@@ -127,6 +127,33 @@ message ApsiReceiverConfig { + uint32 query_batch_size = 10; + } + ++message KcalPirSenderConfig { ++ // Number of threads to use ++ uint32 threads = 1; ++ ++ // Path to a CSV file describing the sender's dataset (an item-label pair on ++ // each row) or a file containing a serialized SenderDB; the CLI will first ++ // attempt to load the data as a serialized SenderDB, and - upon failure - ++ // will proceed to attempt to read it as a CSV file ++ // For CSV File: ++ // 1. the first col is processed as item while the second col as label. OTHER ++ // COLS ARE IGNORED. ++ // 2. NO HEADERS ARE ALLOWED ++ string db_file = 2; ++} ++ ++message KcalPirReceiverConfig { ++ // Number of threads to use ++ uint32 threads = 1; ++ ++ string query_file = 2; ++ ++ // Path to a file where intersection result will be written. ++ string output_file = 3; ++ ++ bool is_dummy_mode = 4; ++} ++ + message DkPirSenderConfig { + enum Mode { + MODE_UNSPECIFIED = 0; +diff --git a/psi/proto/psi_v2.proto b/psi/proto/psi_v2.proto +index 5682ef7..c56b575 100644 +--- a/psi/proto/psi_v2.proto ++++ b/psi/proto/psi_v2.proto +@@ -60,6 +60,8 @@ enum Protocol { + + // Blazing Fast PSI https://eprint.iacr.org/2022/320.pdf + PROTOCOL_RR22 = 3; ++ ++ PROTOCOL_KCAL = 4; + } + + // Configs for ECDH protocol. +@@ -92,6 +94,10 @@ message Rr22Config { + bool low_comm_mode = 2; + } + ++message KcalConfig { ++ uint32 thread_count = 1; ++} ++ + // Any items related to PSI protocols. + message ProtocolConfig { + Protocol protocol = 1; +@@ -109,6 +115,8 @@ message ProtocolConfig { + + // For RR22 protocol. + Rr22Config rr22_config = 6; ++ ++ KcalConfig kcal_config = 7; + } + + // TODO(junfeng): support more io types including oss, sql, etc. +diff --git a/third_party/BUILD.bazel b/third_party/BUILD.bazel +new file mode 100644 +index 0000000..a980827 +--- /dev/null ++++ b/third_party/BUILD.bazel +@@ -0,0 +1,17 @@ ++load("@rules_foreign_cc//foreign_cc:defs.bzl", "cmake") ++ ++package(default_visibility = ["//visibility:public"]) ++ ++cmake( ++ name = "libkcal", ++ generate_args = ["-GNinja"], ++ lib_source = "//third_party/kcal:kcal_srcs", ++ out_shared_libs = [ ++ "libdata_guard_common.so", ++ "libdata_guard.so", ++ "libhitls_bsl.so", ++ "libhitls_crypto.so", ++ "libmpc_tee.so", ++ "libsecurec.so", ++ ] ++) +diff --git a/third_party/kcal/BUILD.bazel b/third_party/kcal/BUILD.bazel +new file mode 100644 +index 0000000..88895f4 +--- /dev/null ++++ b/third_party/kcal/BUILD.bazel +@@ -0,0 +1,5 @@ ++filegroup( ++ name = "kcal_srcs", ++ srcs = glob(["**"]), ++ visibility = ["//visibility:public"], ++) +\ No newline at end of file +diff --git a/third_party/kcal/CMakeLists.txt b/third_party/kcal/CMakeLists.txt +new file mode 100644 +index 0000000..972c809 +--- /dev/null ++++ b/third_party/kcal/CMakeLists.txt +@@ -0,0 +1,69 @@ ++cmake_minimum_required(VERSION 3.16) ++ ++project(lib_kcal) ++ ++add_library(lib_data_guard_common SHARED IMPORTED GLOBAL) ++set_target_properties(lib_data_guard_common PROPERTIES ++ IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/lib/libdata_guard_common.so" ++ INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_LIST_DIR}/include" ++) ++ ++add_library(lib_data_guard SHARED IMPORTED GLOBAL) ++set_target_properties(lib_data_guard PROPERTIES ++ IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/lib/libdata_guard.so" ++ INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_LIST_DIR}/include" ++) ++ ++add_library(lib_hitls_bsl SHARED IMPORTED GLOBAL) ++set_target_properties(lib_hitls_bsl PROPERTIES ++ IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/lib/libhitls_bsl.so" ++ INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_LIST_DIR}/include" ++) ++ ++add_library(lib_hitls_crypto SHARED IMPORTED GLOBAL) ++set_target_properties(lib_hitls_crypto PROPERTIES ++ IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/lib/libhitls_crypto.so" ++ INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_LIST_DIR}/include" ++) ++ ++add_library(lib_mpc_tee SHARED IMPORTED GLOBAL) ++set_target_properties(lib_mpc_tee PROPERTIES ++ IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/lib/libmpc_tee.so" ++ INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_LIST_DIR}/include" ++) ++ ++add_library(lib_securec SHARED IMPORTED GLOBAL) ++set_target_properties(lib_securec PROPERTIES ++ IMPORTED_LOCATION "${CMAKE_CURRENT_LIST_DIR}/lib/libsecurec.so" ++ INTERFACE_INCLUDE_DIRECTORIES "${CMAKE_CURRENT_LIST_DIR}/include" ++) ++ ++add_library(lib_kcal INTERFACE) ++target_include_directories(lib_kcal INTERFACE ++ $ ++ $ ++) ++target_link_libraries(lib_kcal INTERFACE ++ lib_data_guard_common ++ lib_data_guard ++ lib_hitls_bsl ++ lib_hitls_crypto ++ lib_mpc_tee ++ lib_securec ++) ++ ++install(TARGETS lib_kcal ++ EXPORT lib_kcal_targets ++) ++ ++install(DIRECTORY include/ DESTINATION include ++ FILES_MATCHING PATTERN "*.h" ++) ++ ++file(GLOB SO_FILES "${CMAKE_CURRENT_LIST_DIR}/lib/*.so") ++install(FILES ${SO_FILES} DESTINATION lib) ++ ++install(EXPORT lib_kcal_targets ++ DESTINATION lib/cmake/libkcal ++ FILE libkcal-config.cmake ++) +diff --git a/third_party/kcal_middleware/BUILD.bazel b/third_party/kcal_middleware/BUILD.bazel +new file mode 100644 +index 0000000..69e75e2 +--- /dev/null ++++ b/third_party/kcal_middleware/BUILD.bazel +@@ -0,0 +1,5 @@ ++filegroup( ++ name = "kcal_middleware_srcs", ++ srcs = glob(["**"]), ++ visibility = ["//visibility:public"], ++) +\ No newline at end of file