diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/README.md b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..989d18d969ee0936256a7ea796aecdd42b537945
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/README.md
@@ -0,0 +1,108 @@
+# HSTU NPU vs. GPU performance comparison
+NPU tag: https://gitee.com/ascend/RecSDK.git branch_v7.1.0-RC1
+GPU tag: https://github.com/NVIDIA/recsys-examples.git v25.05
+
+## 1. Prepare one GPU server and one NPU server
+## 2. Download and build recsys-examples on the GPU server
+
+```
+git clone https://github.com/NVIDIA/recsys-examples.git
+cd recsys-examples/corelib/hstu
+git checkout v25.05
+make install
+```
+Set up the sftp server service on the GPU server.
+
+## 3. Configure an NFS shared path
+
+## 4. Build the NPU operators
+git clone https://gitee.com/ascend/RecSDK.git
+cd RecSDK/mxrec_add_ons/rec_for_torch/operators/
+In hstu_dense_forward/op_host/tiling_policy.cpp, remove the checks in the GeneralShapeCheck function so that it simply returns true, then build and install the operator.
+In hstu_dense_backward/op_host/hstu_dense_backward_tiling_common.cpp, remove the checks in the BasicShapeCheck function so that it simply returns true, then build and install the operator.
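+
+To sanity-check the rebuilt operators, the snippet below can be run on the NPU server. This is a minimal sketch: it assumes the library name and load path that test_npu_hstu.py itself uses.
+```
+import sysconfig
+
+import torch
+
+# Load the custom-op library installed with the NPU operator package
+torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
+# Both custom ops should now be resolvable
+print(hasattr(torch.ops.mxrec, "hstu_dense"))           # expect True
+print(hasattr(torch.ops.mxrec, "hstu_dense_backward"))  # expect True
+```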
+
+
+## 5. Configure config.py in this project and fill in every "x"
+```
+# Absolute directory path of the recsys-example project recsys-examples/corelib/hstu (replace with actual path)
+RECSYS_DIR="x"
+
+# Network File System (NFS) mount directory path (replace with actual path)
+NFS_DIR="x"
+
+# IP address of the GPU server (replace with actual IP)
+GPU_IP="x"
+
+# Login username for the GPU server (replace with actual username)
+GPU_USER="x"
+
+# Absolute path to Python3 interpreter (replace with actual path, e.g. /usr/bin/python3)
+PYTHON3="x"
+```
+## 6. Upload test_npu_performance to the shared directory NFS_DIR on the NPU server
+
+## 7. Log in to the NPU server and run the performance test
+Input parameters are read from benchmark.csv; results are saved to result.csv.
+```
+cd test_npu_performance
+pip install -r requirements.txt
+python3 test_benchmark.py # prompts interactively for the GPU server password
+```
+
+
+### Benchmark column descriptions
+
+| Field | Type | Description |
+|--------|------|------|
+| index | int | Unique identifier of the test case |
+| shape_info | str | Tensor shape description |
+| total_len | int | Total input sequence length |
+| batch_size | int | Batch size |
+| heads | int | Number of attention heads |
+| heads_rab | int | Number of heads when RAB is used |
+| max_seq_len_q | int | Maximum query sequence length |
+| max_seq_len_k | int | Maximum key sequence length |
+| max_context_len | int | Maximum context sequence length |
+| max_target_len | int | Maximum target sequence length |
+| target_group_size | int | Target sequence group size |
+| attn_dim | int | Attention dimension |
+| hidden_dim | int | Hidden-layer dimension |
+| alpha | float | Scaling factor |
+| has_rab | bool | Whether RAB is used |
+| has_drab | bool | Whether dRAB is used |
+| window_size | tuple | Sliding window (left, right); -1 means unlimited |
+| run_benchmark | int | Benchmark selector bitmask (0b01: HSTU kernel, 0b10: torch reference, 0b11: both with precision compare) |
+| dtype | str | Data type |
+| full_batch | bool | Whether full batches are used |
+| is_delta_q | bool | Whether delta query is used |
+| format | str | Data format |
+| npu_fw_time | float | NPU forward time (ms) |
+| npu_bw_time | float | NPU backward time (ms) |
+| gpu_fw_time | float | GPU forward time (ms) |
+| gpu_bw_time | float | GPU backward time (ms) |
+| precision | bool | Whether NPU and GPU results match within tolerance |
+| npu_fw/gpu_fw | float | NPU/GPU forward-time ratio |
+| npu_bw/gpu_bw | float | NPU/GPU backward-time ratio |
+| npu_fw+bw/gpu_fw+bw | float | NPU/GPU total-time ratio |
+
+### Result column descriptions
+
+| Field | Type | Description |
+|--------|------|------|
+| precision | bool | Whether NPU and GPU results match within tolerance |
+| npu_fw/gpu_fw | float | NPU/GPU forward-time ratio |
+| npu_bw/gpu_bw | float | NPU/GPU backward-time ratio |
+| npu_fw+bw/gpu_fw+bw | float | NPU/GPU total-time ratio |
+| npu_fw/benchmark | float | NPU/baseline forward-time ratio |
+| npu_bw/benchmark | float | NPU/baseline backward-time ratio |
+| npu_fw+bw/benchmark | float | NPU/baseline total-time ratio |
+
+### Usage notes
+
+1. **Performance comparison**:
+   - A ratio > 1 means the NPU is faster.
+   - A ratio < 1 means the GPU is faster.
+
+2. **Baseline validation**:
+   - A ratio ≈ 1 matches expectations.
+   - A significant deviation from 1 needs investigation.
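+
+For reference, the ratios can be recomputed from result.csv by hand; this mirrors what update_index_csv in test_benchmark.py does (a minimal sketch, assuming result.csv is in the working directory):
+```
+import pandas as pd
+
+df = pd.read_csv("result.csv")
+row = df[df["index"] == 1].iloc[0]
+# npu_fw/gpu_fw is defined as gpu_fw_time / npu_fw_time,
+# so a value above 1 means the NPU forward pass is faster.
+print(row["gpu_fw_time"] / row["npu_fw_time"])
+print((row["gpu_fw_time"] + row["gpu_bw_time"]) / (row["npu_fw_time"] + row["npu_bw_time"]))
+```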
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/benchmark.csv b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/benchmark.csv
new file mode 100644
index 0000000000000000000000000000000000000000..c15c1ac824b0039ef163f7dc05ee6fa42f815f66
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/benchmark.csv
@@ -0,0 +1,30 @@
+index,shape_info,total_len,batch_size,heads,heads_rab,max_seq_len_q,max_seq_len_k,max_context_len,max_target_len,target_group_size,attn_dim,hidden_dim,alpha,has_rab,has_drab,window_size,run_benchmark,dtype,full_batch,is_delta_q,format,npu_fw_time,npu_bw_time,gpu_fw_time,gpu_bw_time,precision,npu_fw/gpu_fw,npu_bw/gpu_bw,npu_fw+bw/gpu_fw+bw
+1,01shape,64570,96,3,3,0,0,1991,417,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.727,3.729,0.991,3.276,TRUE,0.574,0.879,0.782
+2,02shape,88696,8,4,4,16384,16384,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,20.419,65.717,6.475,18.819,TRUE,0.317,0.286,0.294
+3,03shape,21536,32,4,4,1024,1024,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.42,1.165,0.197,0.552,TRUE,0.469,0.474,0.473
+4,blockless256,400,8,4,4,100,100,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.069,0.098,0.049,0.236,TRUE,0.71,2.408,1.707
+5,bound,40960,4,2,2,20480,20480,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,5.871,19.499,2.058,5.167,TRUE,0.351,0.265,0.285
+6,Random,65536,32,8,8,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,8.105,18.877,3.097,11.97,TRUE,0.382,0.634,0.558
+7,Random,65536,32,8,8,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,5.797,15.801,1.829,5.533,TRUE,0.316,0.35,0.341
+8,Random,65536,32,2,2,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.12,5.043,1.099,3.231,TRUE,0.518,0.641,0.604
+9,Random,65536,32,2,2,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.427,4.169,0.659,1.603,TRUE,0.462,0.385,0.404
+10,Random,54400,32,8,8,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,8.082,17.313,2.594,8.767,TRUE,0.321,0.506,0.447
+11,Random,54400,32,8,8,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,6.057,14.197,1.405,4.285,TRUE,0.232,0.302,0.281
+12,Random,54400,32,2,2,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.049,4.408,0.704,2.39,TRUE,0.344,0.542,0.479
+13,Random,54400,32,2,2,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.511,3.486,0.468,1.166,TRUE,0.31,0.334,0.327
+14,Random,19248,32,8,8,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.324,2.915,0.482,1.71,TRUE,0.364,0.587,0.517
+15,Random,19248,32,8,8,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.99,2.347,0.325,0.884,TRUE,0.328,0.377,0.362
+16,Random,19248,32,2,2,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.438,0.861,0.175,0.813,TRUE,0.4,0.944,0.761
+17,Random,19248,32,2,2,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.333,0.674,0.113,0.422,TRUE,0.339,0.626,0.531
+18,Random,16384,8,8,8,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.456,5.708,1.056,3.704,TRUE,0.43,0.649,0.583 +19,Random,16384,8,8,8,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.624,4.827,0.624,1.684,TRUE,0.384,0.349,0.358 +20,Random,16384,8,2,2,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.687,1.607,0.375,1.144,TRUE,0.546,0.712,0.662 +21,Random,16384,8,2,2,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.47,1.402,0.29,0.524,TRUE,0.617,0.374,0.435 +22,Random,13600,8,8,8,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.395,5.084,0.944,2.727,TRUE,0.394,0.536,0.491 +23,Random,13600,8,8,8,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.78,4.091,0.544,1.703,TRUE,0.306,0.416,0.383 +24,Random,13600,8,2,2,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.798,1.563,0.307,0.968,TRUE,0.385,0.619,0.54 +25,Random,13600,8,2,2,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.592,1.194,0.231,0.409,TRUE,0.39,0.343,0.358 +26,Random,4812,8,8,8,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.473,0.922,0.158,0.523,TRUE,0.334,0.567,0.488 +27,Random,4812,8,8,8,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.351,0.712,0.135,0.294,TRUE,0.385,0.413,0.404 +28,Random,4812,8,2,2,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.173,0.328,0.091,0.23,TRUE,0.526,0.701,0.641 +29,Random,4812,8,2,2,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.131,0.284,0.074,0.137,TRUE,0.565,0.482,0.508 diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/requirements.txt b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7b73200af8ca162f91f4948a1c467f7c52ee52d8 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/requirements.txt @@ -0,0 +1,5 @@ +pandas +numpy +matplotlib +paramiko==3.5.1 +torch==2.6.0 diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_benchmark.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0f4f47905d0e5393cb579385c975a63a5024b0 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_benchmark.py @@ -0,0 +1,368 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import argparse +import time +import getpass +import os +import subprocess +from typing import Optional +import socket +from select import select + +import torch +import paramiko +import pandas as pd + +import config +from test_read_benchmark import ( + benchmark_csv, + result_csv, + logger, + DATASETS, + init_result_csv_index, + create_and_save_params, +) +from test_msprof_npu_hstu import msprof_main + + +INDEX_STR = "index" +GPU_FW_TIME = "gpu_fw_time" +NPU_FW_TIME = "npu_fw_time" +NPU_BW_TIME = "npu_bw_time" +GPU_BW_TIME = "gpu_bw_time" + + +_remote_password_cache = None + + +def get_remote_password() -> Optional[str]: + """Securely retrieves and caches the GPU server password (single input, reusable for the session). + + Returns: + Optional[str]: The password if successfully retrieved, None otherwise. + """ + global _remote_password_cache + + if _remote_password_cache is not None: + return _remote_password_cache + + try: + # Interactive input (hidden) + logger.info("\n[Security Notice] Password input will not display characters.") + password = getpass.getpass( + prompt="Enter GPU server password (valid for this session): " + ) + if password.strip(): + _remote_password_cache = password + return _remote_password_cache + else: + logger.error("Password cannot be empty.") + return None + + except Exception as e: + logger.error(f"Password retrieval failed: {str(e)}") + return None + + +def execute_remote_linux_cmd(client, cmd, timeout=600): + """Secure execution of Linux remote commands (supports timeout control and complete output capture)""" + stdin, stdout, stderr = None, None, None + exit_status = -1 + output_buffer = [] + error_buffer = [] + + try: + # Create execution channel (set timeout to prevent network blocking) + transport = client.get_transport() + channel = transport.open_session(timeout=timeout) + channel.settimeout(timeout) + + # Execute command (non-blocking mode) + channel.exec_command(cmd) + stdin = channel.makefile_stdin("wb") + stdout = channel.makefile("r") + stderr = channel.makefile_stderr("r") + + # Use select for efficient I/O multiplexing + start_time = time.time() + while not channel.exit_status_ready(): + # Timeout check + if time.time() - start_time > timeout: + raise socket.timeout(f"Command timeout after {timeout}s: {cmd}") + + # Wait for readable events (0.5s polling interval) + rlist, _, _ = select([channel], [], [], 0.5) + if not rlist: + continue + + # Prioritize reading error stream (avoid buffer overflow) + while channel.recv_stderr_ready(): + error_buffer.append(stderr.read(4096)) + + while channel.recv_ready(): + output_buffer.append(stdout.read(4096)) + + # Get final exit status + exit_status = channel.recv_exit_status() + + # Read all remaining output + output_buffer.append(stdout.read()) + error_buffer.append(stderr.read()) + + except socket.timeout as te: + logger.error(f"Command execution timed out: {te}") + # Attempt to terminate remote process (send SIGTERM) + if "channel" in locals(): + channel.close() + return False + except Exception as e: + logger.error(f"SSH command execution failed: {e}", exc_info=True) + return False + finally: + # Safely close channels (without closing underlying client connection) + for stream in [stdin, stdout, stderr]: + try: + if stream: + stream.close() + except Exception as e: + logger.info(e) + + try: + if "channel" in locals(): + channel.close() + except Exception as e: + logger.info(e) + + # Build output results + 
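+    # The buffers may hold a mix of bytes (raw channel reads) and already
+    # decoded str (file-object reads), so normalize every chunk before joining.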
decoded_output = [
+        item.decode("utf-8") if isinstance(item, bytes) else item
+        for item in output_buffer
+    ]
+    decoded_error = [
+        item.decode("utf-8") if isinstance(item, bytes) else item
+        for item in error_buffer
+    ]
+    full_output = "".join(decoded_output)
+    full_error = "".join(decoded_error)
+
+    # Log results (limit log length)
+    logger.info(f"Command [{cmd}] exited with status: {exit_status}")
+    if full_output:
+        logger.debug(
+            f"stdout: {full_output[:2000]}{'...' if len(full_output)>2000 else ''}"
+        )
+    if full_error:
+        logger.warning(
+            f"stderr: {full_error[:2000]}{'...' if len(full_error)>2000 else ''}"
+        )
+
+    return exit_status == 0
+
+
+def transfer_and_execute(bx):
+    if not isinstance(bx, int):
+        raise ValueError(f"input must be an integer but got {bx}")
+    remote_host = config.GPU_IP
+    remote_user = config.GPU_USER
+    remote_password = get_remote_password()
+    remote_dir = os.path.realpath(config.RECSYS_DIR)
+    client = paramiko.SSHClient()
+    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    cmd = f"cd {remote_dir} && {os.path.realpath(config.PYTHON3)} test_gpu_hstu.py --index={bx}"
+
+    try:
+        client.connect(
+            remote_host, port=22, username=remote_user, password=remote_password
+        )
+        sftp = client.open_sftp()
+        files_to_transfer = ["test_gpu_hstu.py", "test_read_benchmark.py", "config.py"]
+        for file in files_to_transfer:
+            sftp.put(file, os.path.join(remote_dir, file))
+        sftp.close()
+
+        logger.info(f"Executing remote script: {cmd}")
+        return execute_remote_linux_cmd(client, cmd)
+    except Exception as e:
+        logger.error(f"Failed to execute remote script: {cmd}")
+        logger.error(e)
+        return False
+    finally:
+        # Safely close the connection
+        try:
+            if client:
+                client.close()
+        except Exception as close_error:
+            logger.error(f"SSH close error: {close_error}")
+        logger.info(f"Exit remote script: {cmd}")
+
+
+def compare_npu_gpu_precision(save_dir=DATASETS, device="cpu"):
+    logger.info(f"Comparing npu and gpu results of {save_dir}")
+    try:
+        data_type = torch.load(os.path.join(save_dir, "dtype.pth"), map_location=device)
+        npu_out = torch.load(
+            os.path.join(save_dir, "npu_out.pth"), map_location=device
+        ).to(dtype=data_type)
+        gpu_out = torch.load(
+            os.path.join(save_dir, "gpu_out.pth"), map_location=device
+        ).to(dtype=data_type)
+        gpu_out = gpu_out.view(gpu_out.shape[0], -1)
+
+        npu_q = torch.load(os.path.join(save_dir, "npu_q.pth"), map_location=device).to(
+            dtype=data_type
+        )
+        gpu_q = torch.load(os.path.join(save_dir, "gpu_q.pth"), map_location=device).to(
+            dtype=data_type
+        )
+
+        npu_k = torch.load(os.path.join(save_dir, "npu_k.pth"), map_location=device).to(
+            dtype=data_type
+        )
+        gpu_k = torch.load(os.path.join(save_dir, "gpu_k.pth"), map_location=device).to(
+            dtype=data_type
+        )
+
+        npu_v = torch.load(os.path.join(save_dir, "npu_v.pth"), map_location=device).to(
+            dtype=data_type
+        )
+        gpu_v = torch.load(os.path.join(save_dir, "gpu_v.pth"), map_location=device).to(
+            dtype=data_type
+        )
+
+    except Exception as e:
+        logger.error(f"error : {e}")
+        return False
+
+    if data_type == torch.bfloat16:
+        eps = 1e-2
+    elif data_type == torch.float16:
+        eps = 1e-3
+    else:
+        logger.error(f"error type : {data_type}")
+        return False
+
+    try:
+        out_close = torch.allclose(npu_out, gpu_out, eps, eps)
+        out_q = torch.allclose(npu_q, gpu_q, eps, eps)
+        out_k = torch.allclose(npu_k, gpu_k, eps, eps)
+        out_v = torch.allclose(npu_v, gpu_v, eps, eps)
+    except Exception as e:
+        logger.error(f"error : {e}")
+        return False
+
+    logger.info(f"npu_out vs gpu_out: {out_close}")
+    logger.info(f"npu_q vs gpu_q: {out_q}")
+    logger.info(f"npu_k vs gpu_k: {out_k}")
+    logger.info(f"npu_v vs gpu_v: {out_v}")
+
+    ret = out_close and out_q and out_k and out_v
+    return ret
+
+
+def update_index_csv(df_res, benchmark_df, bx, precision):
+    mask_res = df_res[INDEX_STR] == bx
+    mask_benchmark = benchmark_df[INDEX_STR] == bx
+
+    df_res.loc[mask_res, "precision"] = precision
+    df_res.loc[mask_res, "npu_fw/gpu_fw"] = (
+        df_res.loc[mask_res, GPU_FW_TIME] / df_res.loc[mask_res, NPU_FW_TIME]
+    )
+    df_res.loc[mask_res, "npu_bw/gpu_bw"] = (
+        df_res.loc[mask_res, GPU_BW_TIME] / df_res.loc[mask_res, NPU_BW_TIME]
+    )
+
+    df_res.loc[mask_res, "npu_fw+bw/gpu_fw+bw"] = df_res.loc[
+        mask_res, [GPU_FW_TIME, GPU_BW_TIME]
+    ].sum(axis=1) / df_res.loc[mask_res, [NPU_FW_TIME, NPU_BW_TIME]].sum(axis=1)
+    df_res.loc[mask_res, "npu_fw/benchmark"] = (
+        benchmark_df.loc[mask_benchmark, NPU_FW_TIME].item()
+        / df_res.loc[mask_res, NPU_FW_TIME].item()
+    )
+    df_res.loc[mask_res, "npu_bw/benchmark"] = (
+        benchmark_df.loc[mask_benchmark, NPU_BW_TIME].item()
+        / df_res.loc[mask_res, NPU_BW_TIME].item()
+    )
+    df_res.loc[mask_res, "npu_fw+bw/benchmark"] = benchmark_df.loc[
+        mask_benchmark, [NPU_FW_TIME, NPU_BW_TIME]
+    ].sum(axis=1).item() / df_res.loc[mask_res, [NPU_FW_TIME, NPU_BW_TIME]].sum(axis=1).item()
+
+
+def retry_operation(operation, operation_name, bx, max_retries=2):
+    for attempt in range(max_retries):
+        ret = operation(bx)
+        if ret:
+            return True
+        logger.warning(
+            f"{operation_name} failed for benchmark {bx}, attempt {attempt + 1}/{max_retries}"
+        )
+    logger.error(
+        f"{operation_name} failed for benchmark {bx} after {max_retries} retries"
+    )
+    return False
+
+
+def main(index=None):
+    benchmark_df = pd.read_csv(benchmark_csv)
+    benchmark_df[INDEX_STR] = benchmark_df[INDEX_STR].astype(int)
+    all_indices = benchmark_df[INDEX_STR].tolist()
+
+    if index is not None:
+        all_indices = [index]
+    failed = []
+    for bx in all_indices:
+        logger.info(f"benchmark {bx} testing")
+        init_result_csv_index(bx)
+
+        df_res = pd.read_csv(result_csv)
+
+        if (
+            bx in df_res[INDEX_STR].values
+            and df_res[df_res[INDEX_STR] == bx].notna().all().all()
+        ):
+            logger.info(f"benchmark {bx} already done, skipping")
+            continue
+
+        create_and_save_params(bx)
+        remote_success = retry_operation(transfer_and_execute, "Remote execution", bx)
+        if not remote_success:
+            failed.append(bx)
+            continue
+
+        local_success = retry_operation(msprof_main, "Local execution", bx)
+        if not local_success:
+            failed.append(bx)
+            continue
+
+        df_res = pd.read_csv(result_csv)
+        precision = compare_npu_gpu_precision()
+        update_index_csv(df_res, benchmark_df, bx, precision)
+        df_res.to_csv(result_csv, index=False)
+        logger.info(f"benchmark {bx} tested")
+
+    if len(failed) != 0:
+        logger.error("index %s failed!", ", ".join(str(x) for x in failed))
+    else:
+        logger.info("All succeeded!")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run benchmark tests.")
+    parser.add_argument("--index", type=int, help="benchmark to run. 
Default run all.") + args = parser.parse_args() + main(args.index) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_gpu_hstu.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_gpu_hstu.py new file mode 100644 index 0000000000000000000000000000000000000000..586a7f5c7d995652faa0859f41249bb961a2da4b --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_gpu_hstu.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +import math +import os +import traceback +from statistics import mean +from typing import Optional, Tuple + +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torch.nn.functional as F +from einops import rearrange +from pynvml import ( + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlInit, + nvmlShutdown, +) + +from test_read_benchmark import ( + DATASETS, + init_result_csv_index, + logger, + read_and_validate_parameters, + result_csv, + load_params, +) + +sm_major_version = torch.cuda.get_device_properties(0).major +sm_minor_version = torch.cuda.get_device_properties(0).minor +if sm_major_version == 9 and sm_minor_version == 0: + from hstu_attn_interface import hstu_attn_varlen_func +elif sm_major_version == 8: + from hstu_attn import hstu_attn_varlen_func + + +PERFORMANCE = True +if PERFORMANCE: + g_iterations = 100 + g_profiler_step_start = 20 + loop = 3 +else: + g_iterations = 1 + g_profiler_step_start = 0 + loop = 3 + + +def get_gpu_memory_info(): + nvmlInit() + device_count = nvmlDeviceGetCount() + memory_info = [] + for i in range(device_count): + handle = nvmlDeviceGetHandleByIndex(i) + info = nvmlDeviceGetMemoryInfo(handle) + memory_info.append((i, info.total, info.used)) + nvmlShutdown() + return memory_info + + +def auto_select_gpu(): + memory_info = get_gpu_memory_info() + min_memory_used = float("inf") + best_gpu_index = None + for i, total, used in memory_info: + logger.info( + f"GPU {i}: Total Memory = {total / 1024**2} MiB, Used Memory = {used / 1024**2} MiB" + ) + if used < min_memory_used: + min_memory_used = used + best_gpu_index = i + return best_gpu_index + + +gdevice = auto_select_gpu() +logger.info(f"Selected GPU device index: {gdevice}") +torch.cuda.set_device(gdevice) + + +def pad_input(unpadded_input, cu_seqlen, batch, seqlen): + indices = [] + for i in range(batch): + indices.append( + torch.arange(seqlen * i, seqlen * i + cu_seqlen[i + 1] - cu_seqlen[i]) + ) + indices = torch.cat(indices) + output = torch.zeros( + (batch * seqlen), + *unpadded_input.shape[1:], + device=unpadded_input.device, + dtype=unpadded_input.dtype, + ) + output[indices] = unpadded_input + 
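+    # `output` is now a flat (batch * seqlen, ...) buffer with each sequence's
+    # tokens scattered into its padded slot; the rearrange below only reshapes
+    # it to (batch, seqlen, ...).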
return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def pad_input_delta_q(unpadded_input, cu_seqlen_q, cu_seqlen_k, batch, seqlen): + indices = [] + for i in range(batch): + act_seqlen_q = (cu_seqlen_q[i + 1] - cu_seqlen_q[i]).item() + act_seqlen_k = (cu_seqlen_k[i + 1] - cu_seqlen_k[i]).item() + indices.append( + torch.arange( + seqlen * i + act_seqlen_k - act_seqlen_q, seqlen * i + act_seqlen_k + ) + ) + indices = torch.cat(indices) + output = torch.zeros( + (batch * seqlen), + *unpadded_input.shape[1:], + device=unpadded_input.device, + dtype=unpadded_input.dtype, + ) + output[indices] = unpadded_input + return rearrange(output, "(b s) ... -> b s ...", b=batch) + + +def unpad_input(padded_input, cu_seqlen): + padded_input.reshape(padded_input.size(0), padded_input.size(1), -1) + output = [] + for i in range(len(cu_seqlen) - 1): + output.append(padded_input[i, : (cu_seqlen[i + 1] - cu_seqlen[i]), :]) + return torch.cat(output, dim=0) + + +def unpad_input_delta_q(padded_input, cu_seqlen_q, cu_seqlen_k, batch, seqlen): + padded_input.reshape(padded_input.size(0), padded_input.size(1), -1) + output = [] + for i in range(batch): + act_seqlen_q = (cu_seqlen_q[i + 1] - cu_seqlen_q[i]).item() + act_seqlen_k = (cu_seqlen_k[i + 1] - cu_seqlen_k[i]).item() + output.append(padded_input[i, act_seqlen_k - act_seqlen_q: act_seqlen_k, :]) + return torch.cat(output, dim=0) + + +def _hstu_attention_maybe_from_cache( + num_heads: int, + attention_dim: int, + linear_dim: int, + seqlen_q: int, + seqlen_k: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q_offsets: torch.Tensor, + k_offsets: torch.Tensor, + rab: Optional[torch.Tensor], + invalid_attn_mask: torch.Tensor, + alpha: float, + upcast: bool = True, + reorder_op: bool = False, + is_delta_q: bool = False, +): + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + batch: int = q_offsets.size(0) - 1 + dtype_out = q.dtype + if is_delta_q: + padded_q = pad_input_delta_q(q, q_offsets, k_offsets, batch, seqlen_k) + else: + padded_q = pad_input(q, q_offsets, batch, seqlen_q) + padded_k = pad_input(k, k_offsets, batch, seqlen_k) + padded_v = pad_input(v, k_offsets, batch, seqlen_k) + + padded_q = padded_q.view(batch, seqlen_k, num_heads, attention_dim) + padded_k = padded_k.view(batch, seqlen_k, num_heads, attention_dim) + padded_v = padded_v.view(batch, seqlen_k, num_heads, linear_dim) + if upcast: + padded_q, padded_k, padded_v = ( + padded_q.float(), + padded_k.float(), + padded_v.float(), + ) + if rab is not None: + rab = rab.float() + qk_attn = torch.einsum( + "bnhd,bmhd->bhnm", + padded_q, + padded_k, + ) + + if rab is not None: + padding = ( + 0, + qk_attn.shape[-1] - rab.shape[-1], + 0, + qk_attn.shape[-2] - rab.shape[-2], + ) + rab = F.pad(rab, padding, value=0) + masked_qk_attn = qk_attn + rab + else: + masked_qk_attn = qk_attn + masked_qk_attn = masked_qk_attn * alpha + masked_qk_attn = F.silu(masked_qk_attn) + masked_qk_attn = masked_qk_attn / seqlen_q + if invalid_attn_mask is not None: + if invalid_attn_mask.ndim == 2: + invalid_attn_mask = invalid_attn_mask.unsqueeze(0).unsqueeze(0) + masked_qk_attn = ( + masked_qk_attn * invalid_attn_mask.type(masked_qk_attn.dtype)[:, :, :, :] + ) + + attn_output = torch.einsum( + "bhnm,bmhd->bnhd", + masked_qk_attn, + padded_v, + ) + + attn_output = attn_output.reshape(batch, seqlen_k, num_heads * linear_dim) + if is_delta_q: + attn_output = unpad_input_delta_q( + attn_output, q_offsets, k_offsets, batch, seqlen_k + ) + else: + attn_output = 
unpad_input(attn_output, q_offsets) + attn_output = attn_output.reshape(-1, num_heads * linear_dim) + + return attn_output.to(dtype_out) + + +def test_fused_attn( + total_len: int, + batch_size: int, + heads: int, + heads_rab: Optional[int], + max_seq_len_q: int, + max_seq_len_k: int, + max_context_len: int, + max_target_len: int, + target_group_size: int, + attn_dim: int, + hidden_dim: int, + alpha: float, + has_rab: bool, + has_drab: bool, + window_size: Tuple[int, int], + run_benchmark: Optional[int], + dtype: torch.dtype, + full_batch: bool, + is_delta_q: bool, +) -> Tuple[Optional[float], Optional[float]]: + has_context = max_context_len > 0 + has_target = max_target_len > 0 + group_target = target_group_size > 1 + is_causal = window_size[0] == -1 and window_size[1] == 0 + if dtype == torch.float8_e4m3fn: + raise ValueError( + "float8_e4m3fn is not supported, please use test_fused_attn_fp8 instead" + ) + if has_drab and not has_rab: + raise ValueError("has_drab is True but has_rab is False") + cond = has_target and (window_size[0] > 0 or window_size[1] > 0) + if (has_target and not is_causal) or cond: + raise ValueError( + "has_target is True but is_causal is False or window_size is not (-1, -1)" + ) + if (max_seq_len_q != max_seq_len_k) and not is_delta_q: + raise ValueError("max_seq_len_q != max_seq_len_k but is_delta_q is False") + if is_delta_q and max_seq_len_q > max_seq_len_k: + raise ValueError("is_delta_q is True but max_seq_len_q > max_seq_len_k") + if group_target and not has_target: + raise ValueError("group_target is True but has_target is False") + + if is_delta_q and has_target: + return None, None + if is_delta_q and has_context: + return None, None + if not is_causal and has_context: + return None, None + if (window_size[0] > 0 or window_size[1] > 0) and has_context: + return None, None + + torch.cuda.synchronize() + + if run_benchmark not in [ + 0b01, + 0b10, + 0b11, + ]: # 0b01 is run hstu benchmark and 0b10 is run torch benchmark, 0b11 is run both and compare precision + raise ValueError("run_benchmark should be in [0b01, 0b10, 0b11]") + + iterations = g_iterations + profiler_step_start = g_profiler_step_start + + ( + lq, + lk, + num_contexts, + seq_offsets_q, + seq_offsets_k, + num_targets, + q, + k, + v, + rab, + attn_mask, + grad, + ) = read_param() + + logger.info( + f"max_context_len: {max_context_len}, max_seq_len_q: {max_seq_len_q}, " + f"max_target_len: {max_target_len}" + ) + logger.info(f"q.shape: {q.shape}") + logger.info(f"k.shape: {k.shape}") + logger.info(f"v.shape: {v.shape}") + logger.info(f"grad.shape: {grad.shape}") + logger.info(f"attn_mask.shape: {attn_mask.shape}") + logger.info(f"seq_offsets_q.shape: {seq_offsets_q.shape}") + logger.info(f"seq_offsets_k.shape: {seq_offsets_k.shape}") + logger.info( + f"total_max_seq_len_q: {max_seq_len_q + max_context_len + max_target_len}" + ) + logger.info( + f"total_max_seq_len_k: {max_seq_len_k + max_context_len + max_target_len}" + ) + logger.info( + f"num_contexts.shape: {num_contexts.shape if (has_context and run_benchmark & 0b01) else None}" + ) + logger.info( + f"num_targets.shape: {num_targets.shape if (has_target and run_benchmark & 0b01) else None}" + ) + logger.info(f"target_group_size: {target_group_size}") + logger.info(f"window_size: {window_size}") + logger.info(f"alpha: {alpha}") + logger.info(f"rab.shape: {rab.shape if has_rab else None}") + logger.info(f"has_drab: {has_drab}") + logger.info(f"is_delta_q: {is_delta_q}") + + fwd_event_start = torch.cuda.Event(enable_timing=True) + 
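+    # The paired CUDA events bracket the steady-state iterations below;
+    # Event.elapsed_time() returns milliseconds, which is divided by
+    # (iterations - profiler_step_start) to get a per-iteration average.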
fwd_event_stop = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + for i in range(iterations): + if i == profiler_step_start: + fwd_event_start.record() + + if run_benchmark & 0b01: + out_hstu = hstu_attn_varlen_func( + q=q, + k=k, + v=v, + seq_offsets_q=seq_offsets_q, + seq_offsets_k=seq_offsets_k, + max_seqlen_q=max_context_len + max_seq_len_q + max_target_len, + max_seqlen_k=max_context_len + max_seq_len_k + max_target_len, + num_contexts=num_contexts if has_context else None, + num_targets=num_targets if has_target else None, + target_group_size=target_group_size, + window_size=window_size, + alpha=alpha, + rab=rab if has_rab else None, + has_drab=has_drab, + is_delta_q=is_delta_q, + ) + if run_benchmark & 0b10: + out_torch = _hstu_attention_maybe_from_cache( + num_heads=heads, + attention_dim=attn_dim, + linear_dim=hidden_dim, + seqlen_q=max_context_len + max_seq_len_q + max_target_len, + seqlen_k=max_context_len + max_seq_len_k + max_target_len, + q=q.view(lq, -1), + k=k.view(lk, -1), + v=v.view(lk, -1), + q_offsets=seq_offsets_q, + k_offsets=seq_offsets_k, + rab=rab if has_rab else None, + invalid_attn_mask=( + attn_mask.to(torch.float32) if attn_mask is not None else None + ), + alpha=alpha, + upcast=False, + reorder_op=True, + is_delta_q=is_delta_q, + ) + out_torch = out_torch.reshape(-1, heads, attn_dim) + fwd_event_stop.record() + torch.cuda.synchronize() + fwd_time = fwd_event_start.elapsed_time(fwd_event_stop) / ( + iterations - profiler_step_start + ) + + bwd_event_start = torch.cuda.Event(enable_timing=True) + bwd_event_stop = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + for i in range(iterations): + if i == profiler_step_start: + bwd_event_start.record() + + autograd_input = (q, k, v, rab) if has_rab else (q, k, v) + if run_benchmark & 0b01: + dq_hstu, dk_hstu, dv_hstu = torch.autograd.grad( + out_hstu, autograd_input, grad, retain_graph=True + ) + if run_benchmark & 0b10: + dq_torch, dk_torch, dv_torch = torch.autograd.grad( + out_torch, autograd_input, grad, retain_graph=True + ) + bwd_event_stop.record() + torch.cuda.synchronize() + bwd_time = bwd_event_start.elapsed_time(bwd_event_stop) / ( + iterations - profiler_step_start + ) + + if run_benchmark & 0b11 == 0b11: + if dtype == torch.bfloat16: + eps = 1e-2 + elif dtype == torch.float16: + eps = 1e-3 + else: + eps = 1e-4 + + out_close = torch.allclose(out_hstu, out_torch, eps, eps) + q_close = torch.allclose(dq_hstu, dq_torch, eps, eps) + k_close = torch.allclose(dk_hstu, dk_torch, eps, eps) + v_close = torch.allclose(dv_hstu, dv_torch, eps, eps) + + logger.info(f"cpu_out vs gpu_out: {out_close}") + logger.info(f"cpu_q vs gpu_q: {q_close}") + logger.info(f"cpu_k vs gpu_k: {k_close}") + logger.info(f"cpu_v vs gpu_v: {v_close}") + logger.info(f"all pass: {(out_close and q_close and k_close and v_close)}") + + save_dir = DATASETS + prefix = "gpu_" + logger.info(f"prefix: {prefix}") + + cpu = "cpu" + torch.save(out_hstu.to(cpu), os.path.join(save_dir, f"{prefix}out.pth")) + torch.save(dq_hstu.to(cpu), os.path.join(save_dir, f"{prefix}q.pth")) + torch.save(dk_hstu.to(cpu), os.path.join(save_dir, f"{prefix}k.pth")) + torch.save(dv_hstu.to(cpu), os.path.join(save_dir, f"{prefix}v.pth")) + + return fwd_time, bwd_time + + +def read_param(): + """Load parameters and automatically convert to CUDA tensors if needed.""" + # Parameter field names in original order + params_def = [ + "l_q", # length_q + "l_k", # length_k + "num_contexts", # num_contexts + "seq_offsets_q_wt", # seq_offsets_q + 
"seq_offsets_k_wt", # seq_offsets_k + "num_targets", # num_targets + "q", # query + "k", # key + "v", # value + "rab", # relative_attention_bias + "attn_mask", # attention_mask + "grad", # gradient + ] + + # Error messages + _error_key = "key_error" + _error_messages = { + _error_key: "Missing parameter: {}", + "cuda_transfer_failed": "CUDA transfer failed", + } + + param = load_params(DATASETS) + result = [] + + try: + for field_name in params_def: + param_value = param[field_name] + result.append( + param_value.cuda() + if isinstance(param_value, torch.Tensor) + else param_value + ) + return tuple(result) + + except KeyError as e: + error_msg = _error_messages[_error_key].format(e) + logger.error(error_msg, exc_info=True) + raise + + +def main(): + parser = argparse.ArgumentParser( + description="Read CSV file and run a specific index benchmark" + ) + parser.add_argument( + "--index", type=int, required=True, help="index of the benchmark to run" + ) + args = parser.parse_args() + + _, params = read_and_validate_parameters(args.index) + + init_result_csv_index(args.index) + df_res = pd.read_csv(result_csv) + select_mask = df_res["index"] == args.index + if ( + df_res.loc[select_mask, "gpu_fw_time"].notna().all() + and df_res.loc[select_mask, "gpu_bw_time"].notna().all() + ): + logger.info(f"bemchmark {args.index} already") + exit(0) + + if params is None: + logger.error("Invalid data, benchmark failed") + return + + fwd_time_list = [] + bwd_time_list = [] + + for i in range(loop): + try: + fwd_time, bwd_time = test_fused_attn(**params) + if fwd_time is None and bwd_time is None: + continue + logger.info(f"iter: {i}, fwd time: {fwd_time}, bwd time: {bwd_time}") + fwd_time_list.append(fwd_time) + bwd_time_list.append(bwd_time) + except Exception as e: + logger.error(e) + logger.error(traceback.format_exc()) + return + + if len(fwd_time_list) < loop: + raise Exception(f"Failed: {len(fwd_time_list)}/{loop} iterations finished") + + df_res.loc[select_mask, "gpu_fw_time"] = mean(fwd_time_list[1:]) + df_res.loc[select_mask, "gpu_bw_time"] = mean(bwd_time_list[1:]) + df_res.to_csv(result_csv, index=False) + + logger.info(f"Forward gpu time = {mean(fwd_time_list[1:])} ms") + logger.info(f"Backward gpu time = {mean(bwd_time_list[1:])} ms") + + +if __name__ == "__main__": + main() diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_msprof_npu_hstu.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_msprof_npu_hstu.py new file mode 100644 index 0000000000000000000000000000000000000000..05aa9a0df79bfb0bc5b2cc5e6c3f2d24b301ae75 --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_msprof_npu_hstu.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import os +import argparse +import glob +import subprocess +import pandas as pd + +import config +from test_read_benchmark import ( + logger, + read_and_validate_parameters, + result_csv, + init_result_csv_index, +) + + +def msprof_main(index): + index_str = "index" + try: + # Initialize the result CSV file + init_result_csv_index(index) + df_res = pd.read_csv(result_csv) + + # Check if the current index has already been processed + if ( + df_res.loc[df_res[index_str] == index, "npu_fw_time"].notna().all() + and df_res.loc[df_res[index_str] == index, "npu_bw_time"].notna().all() + ): + logger.info(f"Benchmark with index {index} is already done. Exiting.") + return True + + # Read and validate parameters + _, params = read_and_validate_parameters(index) + + # Execute the msprof command + cmd = f'rm -rf profnpu/ ; msprof --application="python3 test_npu_hstu.py --index={index}" --output=profnpu' + ret = os.system(cmd) + if ret != 0: + logger.error(f"Command execution failed (ret={ret}): {cmd}") + return False + + # Locate the generated CSV file + search_dir = os.path.join(os.path.realpath(config.NFS_DIR), "profnpu") + csv_files = glob.glob( + f"{search_dir}/PROF_*/mindstudio_profiler_output/op_stati*.csv" + ) + if len(csv_files) == 0: + logger.error(f"MSProf run failed. No CSV file found in {search_dir}.") + return False + + csv_file = csv_files[0] + logger.info(f"Profile located at: {csv_file}") + + # Read the CSV file + df_op_stati = pd.read_csv(csv_file) + + # Extract Forward and Backward data + forward_row = df_op_stati[df_op_stati["OP Type"] == "HstuDenseForward"] + backward_row = df_op_stati[df_op_stati["OP Type"] == "HstuDenseBackward"] + + # Check if the data exists + if forward_row.empty or backward_row.empty: + missing_ops = [] + if forward_row.empty: + missing_ops.append("HstuDenseForward") + if backward_row.empty: + missing_ops.append("HstuDenseBackward") + logger.error(f"Missing OP types in CSV: {', '.join(missing_ops)}") + return False + + # Update the result DataFrame + df_res.loc[df_res[index_str] == index, "npu_fw_time"] = ( + forward_row["Avg Time(us)"].squeeze() / 1000 + ) + df_res.loc[df_res[index_str] == index, "npu_bw_time"] = ( + backward_row["Avg Time(us)"].squeeze() / 1000 + ) + + # Save the results + df_res.to_csv(result_csv, index=False) + + # Log the results + logger.info( + f"Forward time: {df_res.loc[df_res[index_str] == index, 'npu_fw_time'].values[0]} ms" + ) + logger.info( + f"Backward time: {df_res.loc[df_res[index_str] == index, 'npu_bw_time'].values[0]} ms" + ) + + return True + + except FileNotFoundError as e: + logger.error(f"File not found: {e}") + return False + except pd.errors.EmptyDataError as e: + logger.error(f"Empty CSV file or invalid data: {e}") + return False + except KeyError as e: + logger.error(f"Missing required column in DataFrame: {e}") + return False + except Exception as e: + logger.error(f"Unexpected error: {e}") + return False + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Read CSV file and run a specific index benchmark" + ) + parser.add_argument( + "--index", type=int, required=True, help="index of the benchmark to run" + ) + args = parser.parse_args() + if msprof_main(args.index): + logger.info("success") diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_npu_hstu.py 
b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_npu_hstu.py new file mode 100644 index 0000000000000000000000000000000000000000..c63d8dc606f5bb508a977f770dced5f943eed7fe --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_npu_hstu.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import argparse +import os +import sysconfig + +import torch + +import config +from test_read_benchmark import ( + logger, + DATASETS, + load_params, + read_and_validate_parameters, +) + +torch.npu.config.allow_internal_format = False +torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so") + + +def _hstu_attention_maybe_from_cache( + num_heads: int, + attention_dim: int, + linear_dim: int, + silu_value: float, + grad: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + invalid_attn_mask: torch.Tensor, + seq_offset: torch.Tensor, + data_type: torch.dtype, + device: str, +): + n: int = invalid_attn_mask.size(-1) + torch.npu.set_device(device) + + q_ = q.reshape(-1, num_heads, attention_dim).to(device=device).to(data_type) + k_ = k.reshape(-1, num_heads, attention_dim).to(device=device).to(data_type) + v_ = v.reshape(-1, num_heads, attention_dim).to(device=device).to(data_type) + grad = grad.to(device=device).to(data_type) + + seq_offset = seq_offset.to(device=device).tolist() + + if len(invalid_attn_mask.shape) == 2: + invalid_attn_mask = invalid_attn_mask.repeat( + len(seq_offset) - 1, num_heads, 1, 1 + ) + if len(invalid_attn_mask.shape) == 4 and invalid_attn_mask.shape[1] == 1: + invalid_attn_mask = invalid_attn_mask.repeat(1, num_heads, 1, 1) + + logger.info(f"invalid_attn_mask shape: {invalid_attn_mask.shape}") + + invalid_attn_mask = invalid_attn_mask.to(device=device).to(data_type) + mask_type = 3 + silu_value = silu_value / n + local_cycle_nums = 100 + for _ in range(local_cycle_nums): + grad_output = torch.ops.mxrec.hstu_dense( + q_, + k_, + v_, + invalid_attn_mask, + None, + mask_type, + n, + silu_value, + "jagged", + seq_offset, + ) + q_grad, k_grad, v_grad, _ = torch.ops.mxrec.hstu_dense_backward( + grad, + q_, + k_, + v_, + invalid_attn_mask, + None, + "jagged", + mask_type, + n, + silu_value, + seq_offset, + ) + + torch.npu.synchronize() + grad_output = grad_output.reshape(-1, num_heads * linear_dim) + + save_dir = DATASETS + + torch.save(grad_output, os.path.join(save_dir, "npu_out.pth")) + torch.save(q_grad, os.path.join(save_dir, "npu_q.pth")) + torch.save(k_grad, os.path.join(save_dir, "npu_k.pth")) + torch.save(v_grad, os.path.join(save_dir, "npu_v.pth")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Read CSV file and run a specific index benchmark" + ) + parser.add_argument( + "--index", type=int, 
required=True, help="index of the benchmark to run" + ) + args = parser.parse_args() + + devicex = 0 + deviceg = f"npu:{devicex}" + logger.info(f"device: {deviceg}") + read_dir = os.path.join(os.path.realpath(config.NFS_DIR), DATASETS) + + _, param_ben = read_and_validate_parameters(args.index) + param = load_params(DATASETS) + try: + grad_data = param["grad"] + q_data = param["q"] + k_data = param["k"] + v_data = param["v"] + bias_data = param["rab"] + mask_data = param["attn_mask"] + max_seq_len_data = param["num_contexts"] + seq_offset_data = param["seq_offsets_q_wt"] + num_heads_data = v_data.shape[1] + data_type_data = param_ben["dtype"] + alpha_data = param["alpha"] + attention_dim_data = q_data.shape[2] + linear_dim_data = v_data.shape[2] + except KeyError as e: + logger.error(e) + exit(1) + + _hstu_attention_maybe_from_cache( + num_heads=num_heads_data, + attention_dim=attention_dim_data, + linear_dim=linear_dim_data, + silu_value=alpha_data, + grad=grad_data, + q=q_data, + k=k_data, + v=v_data, + invalid_attn_mask=mask_data, + seq_offset=seq_offset_data, + data_type=data_type_data, + device=deviceg, + ) diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_read_benchmark.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_read_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..d837991e8502bb2b61e775a333d3ebb7c88e196b --- /dev/null +++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_read_benchmark.py @@ -0,0 +1,755 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + + +import argparse +import ast +import logging +import math +import os +from typing import Dict, Optional, Tuple + +import pandas as pd +import numpy as np +import torch +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt +import matplotlib + +import config + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(filename)s %(lineno)d [%(levelname)s] %(message)s", + handlers=[logging.FileHandler("test_benchmark.log"), logging.StreamHandler()], +) +logger = logging.getLogger(__name__) + +matplotlib.use("Agg") +INDEX_STR = "index" +DATASETS = os.path.join(os.path.realpath(config.NFS_DIR), "datasets") +benchmark_csv = os.path.join(os.path.realpath(config.NFS_DIR), "benchmark.csv") +result_csv = os.path.join(os.path.realpath(config.NFS_DIR), "result.csv") + +column_names = [ + "index", + "shape_info", + "total_len", + "batch_size", + "heads", + "heads_rab", + "max_seq_len_q", + "max_seq_len_k", + "max_context_len", + "max_target_len", + "target_group_size", + "attn_dim", + "hidden_dim", + "alpha", + "has_rab", + "has_drab", + "window_size", + "run_benchmark", + "dtype", + "full_batch", + "is_delta_q", + "format", + "npu_fw_time", + "npu_bw_time", + "gpu_fw_time", + "gpu_bw_time", + "precision", + "npu_fw/gpu_fw", + "npu_bw/gpu_bw", + "npu_fw+bw/gpu_fw+bw", + "npu_fw/benchmark", + "npu_bw/benchmark", + "npu_fw+bw/benchmark", +] + +column_left = column_names[: column_names.index("format") + 1] + +hstu_required_params = { + "total_len": int, + "batch_size": int, + "heads": int, + "heads_rab": int, + "max_seq_len_q": int, + "max_seq_len_k": int, + "max_context_len": int, + "max_target_len": int, + "target_group_size": int, + "attn_dim": int, + "hidden_dim": int, + "alpha": float, + "has_rab": bool, + "has_drab": bool, + "window_size": tuple, + "run_benchmark": int, + "dtype": torch.dtype, + "full_batch": bool, + "is_delta_q": bool, +} + +generate_params = [ + "total_len", + "batch_size", + "heads", + "heads_rab", + "max_seq_len_q", + "max_seq_len_k", + "max_context_len", + "max_target_len", + "target_group_size", + "attn_dim", + "hidden_dim", + "window_size", + "dtype", + "full_batch", + "has_drab", + "is_delta_q", +] + + +def convert_value(value, required_type): + if isinstance(value, str): + if value == "torch.float16": + return torch.float16 + elif value == "torch.bfloat16": + return torch.bfloat16 + else: + try: + value = ast.literal_eval(value) + except ValueError: + pass + try: + return required_type(value) + except (ValueError, TypeError): + return value + + +def read_and_validate_parameters(index, csv_file_path=benchmark_csv): + try: + # Try to read CSV file + df_benchmark = pd.read_csv(csv_file_path, encoding="utf-8") + + # Check if index column exists + if INDEX_STR not in df_benchmark.columns: + logger.error("Missing index column in CSV: %s", INDEX_STR) + return None, None + + # Filter rows by index + df_benchmark = df_benchmark.loc[df_benchmark[INDEX_STR] == index] + if df_benchmark.empty: + logger.info("No data found for index %d", index) + return None, None + + # Extract required parameters + try: + row = df_benchmark[list(hstu_required_params.keys())] + except KeyError as e: + logger.error("Missing required columns in CSV: %s", e) + return None, None + + # Convert to dict and validate types + params = row.iloc[0].to_dict() + for key, required_type in hstu_required_params.items(): + try: + params[key] = convert_value(params[key], required_type) + except (ValueError, 
TypeError) as e: + logger.error("Type conversion failed for parameter %s: %s", key, e) + return None, None + + logger.info("Parameters for index %d: %s", index, params) + return df_benchmark, params + + except Exception as e: + logger.error("Error: %s", e, exc_info=True) # Added exc_info for stack trace + return None, None + + +def init_result_csv_index(index): + try: + # Read benchmark CSV and ensure index column is integer type + ben_df = pd.read_csv(benchmark_csv) + ben_df[INDEX_STR] = ben_df[INDEX_STR].astype(int) + + # Create empty result file if not exists + if not os.path.exists(result_csv): + pd.DataFrame(columns=column_names).to_csv(result_csv, index=False) + + # Read result file + df_res = pd.read_csv(result_csv) + + # Check if index exists in benchmark data + benchmark_row = ben_df.loc[ben_df[INDEX_STR] == index, column_left] + if benchmark_row.empty: # Explicit empty case handling + logger.warning("Index %d not found in benchmark file, skipping", index) + raise ValueError("Row %d not found in %s" % (index, benchmark_csv)) + + # Check if index already exists in result file + if index in df_res[INDEX_STR].values: + return + + # Add new row and save + new_row = pd.DataFrame([benchmark_row.iloc[0]], columns=column_left) + df_res = pd.concat([df_res, new_row], ignore_index=True) + df_res.to_csv(result_csv, index=False) + logger.info("Successfully added index %d to result file %s", index, result_csv) + + except Exception as e: + logger.error("Error initializing result file: %s", str(e), exc_info=True) + raise + + +def construct_mask( + seqlen_c, + seqlen, + seqlen_t=0, + target_group_size=1, + window_size=(-1, -1), # -1 means infinite window size + seq_offsets=None, + num_contexts=None, + device=None, +): + seqlen = seqlen_c + seqlen + seqlen_t + bs = seq_offsets.size(0) - 1 + + mask = torch.zeros((seqlen, seqlen), device=device, dtype=torch.bool) + if window_size[0] < 0 and window_size[1] == 0: + # causal mask + for i in range(seqlen): + mask[i, : i + 1] = True + + # context mask + if seqlen_c != 0: + mask = mask.unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1, 1) + for i in range(bs): + target_start = ( + num_contexts[i] + seq_offsets[i + 1] - seq_offsets[i] + ).item() + mask[i, 0, : num_contexts[i], :target_start] = True + + # target mask + if seqlen_t != 0: + mask = ( + mask.unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1, 1) + if mask.ndim == 2 + else mask + ) + for i in range(bs): + target_start = ( + num_contexts[i] + seq_offsets[i + 1] - seq_offsets[i] + ).item() + # target group mask + if target_group_size > 1: + group_num = math.ceil((seqlen - target_start) / target_group_size) + for j in range(group_num): + for k in range( + min( + target_group_size, + seqlen - target_start - j * target_group_size, + ) + ): + mask[ + i, + 0, + target_start + j * target_group_size + k, + target_start: target_start + j * target_group_size, + ] = False + else: + for j in range(target_start, seqlen): + mask[i, 0, j, target_start:j] = False + + # local mask + else: + window_size_0 = window_size[0] if window_size[0] > 0 else seqlen + window_size_1 = window_size[1] if window_size[1] > 0 else seqlen + for i in range(seqlen): + mask[i, max(0, i - window_size_0): min(seqlen, i + window_size_1 + 1)] = ( + True + ) + return mask + + +def gen_seq(length, max_value, total_sum): + if max_value * length < total_sum: + raise ValueError("total_sum error %d" % total_sum) + logger.info("gen_seq with total_sum %d", total_sum) + if length == 1: + return np.array([total_sum]) + if length == 2: + return np.array([max_value, 
total_sum - max_value]) + remaining_sum = total_sum - max_value + mean_value = remaining_sum // (length - 2) + min_val = remaining_sum - mean_value * (length - 2) + sequence = [mean_value] * (length - 2) + if min_val == 0 and sequence[-1] > 1: + min_val += 1 + sequence[-1] -= 1 + sequence.extend([max_value, min_val]) + return np.array(sequence) + + +def adjust_ratio(total_sum, max_context_len, max_seq_len_k, max_target_len): + """Adjust ratio distribution based on given parameters.""" + # Initialize result variables + total_content, total_k, total_target = 0, 0, 0 + + # Check denominator + denominator = max_context_len + max_seq_len_k + max_target_len + if denominator == 0: + logger.debug("All max lengths are 0, returning zeros") + return 0, 0, 0 # Return zeros if all inputs are zero + + # Handle zero cases + if max_context_len == 0: + total_content = 0 + if max_seq_len_k == 0: + total_k = 0 + if max_target_len == 0: + total_target = 0 + + # Calculate remaining sum to distribute + remaining_sum = total_sum - (total_content + total_k + total_target) + if remaining_sum == 0: + logger.debug("No remaining sum to distribute") + return total_content, total_k, total_target + + # Calculate valid denominator (excluding zero terms) + valid_denominator = 0 + if max_context_len > 0: + valid_denominator += max_context_len + if max_seq_len_k > 0: + valid_denominator += max_seq_len_k + if max_target_len > 0: + valid_denominator += max_target_len + + if valid_denominator == 0: + raise ValueError("valid_denominator cannot be 0") + + logger.debug("Distributing remaining sum %d with valid denominator %d", remaining_sum, valid_denominator) + + try: + # Distribute remaining sum proportionally + if max_context_len > 0 and valid_denominator != 0: + total_content += int(round(remaining_sum * max_context_len / valid_denominator)) + + if max_seq_len_k > 0 and valid_denominator != 0: + total_k += int(round(remaining_sum * max_seq_len_k / valid_denominator)) + + if max_target_len > 0 and valid_denominator != 0: + total_target += int(round(remaining_sum * max_target_len / valid_denominator)) + except ZeroDivisionError as e: + logger.info(e) + raise e + + # Handle rounding errors + diff = total_sum - (total_content + total_k + total_target) + if diff != 0: + logger.debug("Adjusting for rounding difference of %d", diff) + total_target += diff # Default adjustment to target + + logger.info("Final distribution: total_k=%d, total_content=%d, total_target=%d", \ + total_k, total_content, total_target) + return total_k, total_content, total_target + + +def generate_input( + total_len: int, + batch_size: int, + heads: int, + heads_rab: Optional[int], + max_seq_len_q: int, + max_seq_len_k: int, + max_context_len: int, + max_target_len: int, + target_group_size: int, + attn_dim: int, + hidden_dim: int, + window_size: Tuple[int, int], + dtype: torch.dtype, + full_batch: bool, + has_drab: bool, + is_delta_q: bool, +): + device_str = "cpu" + has_context = max_context_len > 0 + has_target = max_target_len > 0 + target_group_size > 1 + + # Modification Note: Original random length generation caused uncontrolled total length, + # leading to unstable computational load. Now using proportional allocation based on total_len. 
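+    # Hypothetical example of the split: total_len=1000 with max_context_len=100,
+    # max_seq_len_k=300 and max_target_len=100 makes adjust_ratio return
+    # total_k=600, total_content=200, total_target=200; rounding remainders
+    # are folded into total_target.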
+    logger.info("generate with total_len %d", total_len)
+
+    # Allocate the total length proportionally across key/context/target parts
+    total_k, total_content, total_target = adjust_ratio(
+        total_len, max_context_len, max_seq_len_k, max_target_len
+    )
+
+    # Generate key sequence lengths (proportionally allocated)
+    lengths_k = (
+        torch.from_numpy(gen_seq(batch_size, max_seq_len_k, total_k))
+        .to(device_str)
+        .int()
+    )
+
+    # Generate context sequence lengths (proportionally allocated)
+    num_contexts = (
+        torch.from_numpy(gen_seq(batch_size, max_context_len, total_content))
+        .to(device_str)
+        .int()
+    )
+
+    # Generate target sequence lengths (proportionally allocated)
+    num_targets = (
+        torch.from_numpy(gen_seq(batch_size, max_target_len, total_target))
+        .to(device_str)
+        .int()
+    )
+
+    # Offsets for context
+    seq_offsets_c = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+    )
+    seq_offsets_c[1:] = torch.cumsum(num_contexts, dim=0)
+
+    # Offsets for historical qkv
+    seq_offsets_k = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+    )
+    seq_offsets_k[1:] = torch.cumsum(lengths_k, dim=0)
+
+    # Offsets for target qkv
+    seq_offsets_t = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+    )
+    seq_offsets_t[1:] = torch.cumsum(num_targets, dim=0)
+
+    # Lengths for delta q
+    if is_delta_q:
+        if full_batch:
+            lengths_q = (
+                torch.ones(
+                    (batch_size,), device=torch.device(device_str), dtype=torch.int32
+                )
+                * max_seq_len_q
+            )
+        else:
+            # lengths_q[i] is an integer between 1 and min(max_seq_len_q, lengths_k[i])
+            lengths_q = torch.zeros(
+                (batch_size,), device=torch.device(device_str), dtype=torch.int32
+            )
+            for i in range(batch_size):
+                lengths_q[i] = torch.randint(
+                    1,
+                    min(max_seq_len_q, lengths_k[i]) + 1,
+                    size=(1,),
+                    device=torch.device(device_str),
+                )
+        seq_offsets_q = torch.zeros(
+            (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+        )
+        seq_offsets_q[1:] = torch.cumsum(lengths_q, dim=0)
+    else:
+        seq_offsets_q = seq_offsets_k
+
+    # Offsets for the whole q and kv sequences (context + history + target)
+    seq_offsets_q_wt = seq_offsets_c + seq_offsets_q + seq_offsets_t
+    seq_offsets_k_wt = seq_offsets_c + seq_offsets_k + seq_offsets_t
+
+    l_q = int(seq_offsets_q_wt[-1].item())
+    l_k = int(seq_offsets_k_wt[-1].item())
+    # float8 tensors cannot be filled with uniform_ directly, so initialize in
+    # half precision and cast afterwards
+    if dtype == torch.float8_e4m3fn:
+        dtype_init = torch.float16
+    else:
+        dtype_init = dtype
+
+    # Generate q, k, v for history + target
+    q = (
+        torch.empty(
+            (l_q, heads, attn_dim), dtype=dtype_init, device=torch.device(device_str)
+        )
+        .uniform_(-1, 1)
+        .requires_grad_()
+    ).to(dtype)
+    k = (
+        torch.empty(
+            (l_k, heads, attn_dim), dtype=dtype_init, device=torch.device(device_str)
+        )
+        .uniform_(-1, 1)
+        .requires_grad_()
+    ).to(dtype)
+    v = (
+        torch.empty(
+            (l_k, heads, hidden_dim), dtype=dtype_init, device=torch.device(device_str)
+        )
+        .uniform_(-1, 1)
+        .requires_grad_()
+    ).to(dtype)
+    rab = None
+    if has_drab:
+        rab = torch.empty(
+            (
+                batch_size,
+                heads if heads_rab is None else heads_rab,
+                max_context_len + max_seq_len_k + max_target_len,
+                max_context_len + max_seq_len_k + max_target_len,
+            ),
+            dtype=dtype_init,
+            device=torch.device(device_str),
+        ).uniform_(-1, 1)
+        rab = rab.requires_grad_()
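+
+    # Illustrative mask semantics (inferred from construct_mask above): with
+    # seqlen_c=0, seqlen=3, seqlen_t=0 and a causal window_size=(-1, 0),
+    # construct_mask(..., seq_offsets=torch.tensor([0, 3])) yields the 3x3
+    # lower-triangular matrix
+    #   [[1, 0, 0],
+    #    [1, 1, 0],
+    #    [1, 1, 1]]
+    # while the (-1, -1) case below skips mask construction entirely and
+    # leaves attn_mask as None.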
+    if window_size[0] == -1 and window_size[1] == -1:
+        attn_mask = None
+    else:
+        attn_mask = (
+            construct_mask(
+                seqlen_c=max_context_len,
+                seqlen=max_seq_len_k,
+                seqlen_t=max_target_len,
+                target_group_size=target_group_size,
+                window_size=window_size,
+                num_contexts=num_contexts,
+                seq_offsets=seq_offsets_k,
+            )
+            .cpu()
+            .to(torch.float32)
+        )
+    grad = torch.rand_like(v)
+    params = {
+        "l_q": l_q,
+        "l_k": l_k,
+        "num_contexts": num_contexts if has_context else None,
+        "seq_offsets_q_wt": seq_offsets_q_wt,
+        "seq_offsets_k_wt": seq_offsets_k_wt,
+        "num_targets": num_targets if has_target else None,
+        "q": q,
+        "k": k,
+        "v": v,
+        "rab": rab,
+        "attn_mask": attn_mask,
+        "grad": grad,
+    }
+    return params
+
+
+def save_mask(matrix, title="matrix"):
+    """Render a boolean/float mask as a PNG heat map named <title>.png."""
+    cmap = mcolors.LinearSegmentedColormap.from_list(
+        "CustomMap", [(1, 1, 1), (0.8, 0.902, 0.8)], N=256
+    )
+
+    max_pixels = 65536
+    dpi = 100
+
+    width_inches = min(12, max(6, matrix.shape[1] * 0.05))
+    height_inches = min(12, max(6, matrix.shape[0] * 0.05))
+
+    width_pixels = width_inches * dpi
+    height_pixels = height_inches * dpi
+
+    # Scale the figure down if it would exceed the pixel budget
+    if width_pixels > max_pixels or height_pixels > max_pixels:
+        scale_factor = min(max_pixels / width_pixels, max_pixels / height_pixels)
+        width_inches *= scale_factor
+        height_inches *= scale_factor
+
+    fig, ax = plt.subplots(figsize=(width_inches, height_inches))
+
+    img = ax.imshow(
+        matrix, cmap=cmap, origin="lower", vmin=0, vmax=1, interpolation="nearest"
+    )
+
+    ax.xaxis.set_ticks_position("top")
+    ax.yaxis.set_ticks_position("left")
+    ax.invert_yaxis()
+
+    plt.colorbar(img)
+    plt.title(title)
+    plt.xlabel("seq_len")
+    plt.ylabel("seq_len")
+    plt.savefig(f"{title}.png", bbox_inches="tight", dpi=dpi)
+    plt.close(fig)
+    logger.info("save %s.png", title)
+
+
+PARAM_META = {
+    # Tensor parameters (name: (type, required, description))
+    "l_q": ("Tensor", True, "Total query sequence length"),
+    "l_k": ("Tensor", True, "Total key sequence length"),
+    "num_contexts": ("Tensor", True, "Per-batch context lengths"),
+    "seq_offsets_q_wt": ("Tensor", True, "Whole query sequence offsets"),
+    "seq_offsets_k_wt": ("Tensor", True, "Whole key sequence offsets"),
+    "num_targets": ("Tensor", True, "Per-batch target lengths"),
+    "q": ("Tensor", True, "Query tensor"),
+    "k": ("Tensor", True, "Key tensor"),
+    "v": ("Tensor", True, "Value tensor"),
+    "rab": ("Tensor", True, "Relative attention bias"),
+    "attn_mask": ("Tensor", True, "Attention mask"),
+    "grad": ("Tensor", True, "Gradient tensor"),
+    # Configuration parameters
+    "dtype": ("Config", True, "Data type"),
+    "max_context_len": ("Config", True, "Max context length"),
+    "max_seq_len_q": ("Config", True, "Max query sequence length"),
+    "max_target_len": ("Config", True, "Max target length"),
+    "alpha": ("Config", True, "SiLU scaling factor"),
+}
+
+
+def _get_save_paths(save_dir: str) -> Dict[str, str]:
+    return {
+        name: os.path.normpath(os.path.join(save_dir, f"{name}.pth"))
+        for name in PARAM_META
+        if not name.startswith("_")
+    }
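+
+
+# Illustrative output of _get_save_paths (assuming DATASETS resolves to a
+# directory named "datasets"): every PARAM_META key maps to one .pth file, e.g.
+#   {"l_q": "datasets/l_q.pth", "q": "datasets/q.pth", ..., "alpha": "datasets/alpha.pth"}
+# save_params and load_params below use this same mapping, so a save/load
+# round trip restores exactly the tensors and config values that were saved.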
"image_name" + # Save parameters with logging + logger.info("\n%s", "=" * 50) + logger.info("[SAVE] Target directory: %s", save_dir) + for name, path in paths.items(): + if name in kwargs: + torch.save(kwargs[name], path) + param = kwargs[name] + log_msg = "%s: " % name.ljust(15) + if torch.is_tensor(param): + log_msg += "shape=%s | dtype=%s" % ( + str(param.shape).ljust(18), + param.dtype, + ) + else: + log_msg += str(param) + logger.info("%s -> %s", log_msg, path) + + # Save attention matrix image if provided + if "attn_mask" in kwargs and image_name in kwargs: + mask = kwargs["attn_mask"] + if mask.dim() == 4: + save_matrix = mask[0, 0].cpu() + elif mask.dim() == 3: + save_matrix = mask[0].cpu() + else: + save_matrix = mask.cpu() + + img_path = os.path.join(save_dir, kwargs[image_name]) + save_mask(save_matrix, img_path) + + # Create completion flag + flag_path = os.path.join(save_dir, "complete.flag") + torch.save(torch.tensor(1), flag_path) + logger.info("%s\n[SUCCESS] All parameters saved\n%s", "=" * 50, "=" * 50) + + +def load_params(save_dir=DATASETS, device="cpu") -> Dict[str, object]: + save_dir = os.path.normpath(save_dir) + paths = _get_save_paths(save_dir) + + # Validate directory integrity + if not os.path.exists(paths["l_q"]): + logger.error("Invalid parameter directory: %s", save_dir) + raise FileNotFoundError("Invalid parameter directory: %s" % save_dir) + + # Load parameters with logging (lazy interpolation) + logger.info("\n%s", "=" * 50) + logger.info("[LOAD] Source directory: %s", save_dir) + params = {} + for name, path in paths.items(): + if os.path.exists(path): + params[name] = torch.load(path, map_location=device) + log_msg = "%s: " % name.ljust(15) + if torch.is_tensor(params[name]): + log_msg += "shape=%s | dtype=%s" % ( + str(params[name].shape).ljust(18), + params[name].dtype, + ) + else: + log_msg += str(params[name]) + logger.info("%s <- %s", log_msg, path) + + logger.info("%s\n[SUCCESS] All parameters loaded\n%s", "=" * 50, "=" * 50) + return params + + +def create_and_save_params(bx): + df, bparams = read_and_validate_parameters(bx) + image_name = "image_name" + bparams[image_name] = f"{bx}_{df.shape_info.item()}" + init_result_csv_index(bx) + df_resg = pd.read_csv(result_csv) + logger.info(df_resg[df_resg[INDEX_STR] == bx]) + gene_param_dict = {k: bparams[k] for k in generate_params} + iparams = generate_input(**gene_param_dict) + iparams[image_name] = bparams[image_name] + + extra_params = [ + "dtype", + "max_context_len", + "max_seq_len_q", + "max_target_len", + "alpha", + ] + for param_name in extra_params: + if param_name not in iparams: + iparams[param_name] = bparams[param_name] + + save_params(**iparams) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Read CSV file and run a specific index benchmark" + ) + parser.add_argument( + "--index", type=int, required=True, help="index of the benchmark to run" + ) + args = parser.parse_args() + benchmark_df = pd.read_csv(benchmark_csv) + benchmark_df[INDEX_STR] = benchmark_df[INDEX_STR].astype(int) + all_indices = benchmark_df[INDEX_STR].tolist() + if args.index: + all_indices = [args.index] + for bxx in all_indices: + create_and_save_params(bxx) + # Load example + loaded = load_params()