diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/README.md b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..989d18d969ee0936256a7ea796aecdd42b537945
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/README.md
@@ -0,0 +1,108 @@
+# HSTU NPU vs. GPU performance comparison
+npu tag: https://gitee.com/ascend/RecSDK.git branch_v7.1.0-RC1
+gpu tag: https://github.com/NVIDIA/recsys-examples.git v25.05
+
+## 1. Prepare one GPU server and one NPU server
+## 2. Download and build recsys-examples on the GPU server
+
+```
+git clone https://github.com/NVIDIA/recsys-examples.git
+cd recsys-examples/corelib/hstu
+git checkout v25.05
+make install
+```
+Configure the SFTP server service on the GPU server.
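+(test_benchmark.py connects with paramiko over SSH/SFTP to upload the test scripts and launch the GPU run, so password login must be enabled.)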
+
+## 3. Configure the NFS shared path
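+Mount the same NFS export on both servers and point NFS_DIR in config.py at the mount point; the shared directory holds benchmark.csv, result.csv, and the datasets/ directory of tensors exchanged between the two machines.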
+
+## 4. Build the NPU operators
+```
+git clone https://gitee.com/ascend/RecSDK.git
+cd RecSDK/mxrec_add_ons/rec_for_torch/operators/
+```
+In hstu_dense_forward/op_host/tiling_policy.cpp, remove the checks in GeneralShapeCheck so that it returns true directly, then build and install the operator.
+In hstu_dense_backward/op_host/hstu_dense_backward_tiling_common.cpp, remove the checks in BasicShapeCheck so that it returns true directly, then build and install the operator.
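+Disabling these shape checks lets benchmark shapes that fall outside the default validation ranges run during the tests.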
+
+
+## 5. Configure config.py in this project, replacing every "x"
+```
+# Absolute directory path of the recsys-example project recsys-examples/corelib/hstu (replace with actual path)
+RECSYS_DIR="x"
+
+# Network File System (NFS) mount directory path (replace with actual path)
+NFS_DIR="x"
+
+# IP address of the GPU server (replace with actual IP)
+GPU_IP="x"
+
+# Login username for the GPU server (replace with actual username)
+GPU_USER="x"
+
+# Absolute path to Python3 interpreter (replace with actual path, e.g. /usr/bin/python3)
+PYTHON3="x"
+```
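+test_benchmark.py uploads config.py together with the test scripts to the GPU server, so the file only needs to be filled in once on the NPU side.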
+## 6. Upload test_npu_performance to the NPU shared directory NFS_DIR
+
+## 7. Log in to the NPU server and run the performance test
+Input parameters are read from benchmark.csv; results are written to result.csv.
+```
+cd test_npu_performance
+pip install -r requirements.txt
+python3 test_benchmark.py # prompts interactively for the GPU server password
+```
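+
+For each index, test_benchmark.py generates the input tensors and saves them to the shared datasets/ directory, uploads and runs test_gpu_hstu.py on the GPU server over SSH, profiles the NPU run with msprof (via test_msprof_npu_hstu.py), compares the saved outputs with torch.allclose, and writes the timing ratios to result.csv. A single case can be run with `python3 test_benchmark.py --index=5`.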
+
+
+### Benchmark file column reference
+
+| Parameter | Type | Description |
+|--------|------|------|
+| index | int | Unique identifier of the test case |
+| shape_info | str | Tensor shape description |
+| total_len | int | Total input sequence length |
+| batch_size | int | Batch size |
+| heads | int | Number of attention heads |
+| heads_rab | int | Number of heads when RAB is used |
+| max_seq_len_q | int | Maximum query sequence length |
+| max_seq_len_k | int | Maximum key sequence length |
+| max_context_len | int | Maximum context sequence length |
+| max_target_len | int | Maximum target sequence length |
+| target_group_size | int | Target sequence group size |
+| attn_dim | int | Attention dimension |
+| hidden_dim | int | Hidden dimension |
+| alpha | float | Scaling factor |
+| has_rab | bool | Whether RAB is used |
+| has_drab | bool | Whether DRAB is used |
+| window_size | int | Sliding window size |
+| run_benchmark | bool | Whether to run the benchmark |
+| dtype | str | Data type |
+| full_batch | bool | Whether a full batch is used |
+| is_delta_q | bool | Whether delta query is used |
+| format | str | Data format |
+| npu_fw_time | float | NPU forward time (ms) |
+| npu_bw_time | float | NPU backward time (ms) |
+| gpu_fw_time | float | GPU forward time (ms) |
+| gpu_bw_time | float | GPU backward time (ms) |
+| precision | bool | Whether NPU and GPU results match within tolerance |
+| npu_fw/gpu_fw | float | NPU/GPU forward time ratio |
+| npu_bw/gpu_bw | float | NPU/GPU backward time ratio |
+| npu_fw+bw/gpu_fw+bw | float | NPU/GPU total time ratio |
+
+### Result file column reference
+
+| Parameter | Type | Description |
+|--------|------|------|
+| precision | bool | Whether NPU and GPU results match within tolerance |
+| npu_fw/gpu_fw | float | NPU/GPU forward time ratio |
+| npu_bw/gpu_bw | float | NPU/GPU backward time ratio |
+| npu_fw+bw/gpu_fw+bw | float | NPU/GPU total time ratio |
+| npu_fw/benchmark | float | NPU forward time ratio vs. the recorded benchmark |
+| npu_bw/benchmark | float | NPU backward time ratio vs. the recorded benchmark |
+| npu_fw+bw/benchmark | float | NPU total time ratio vs. the recorded benchmark |
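+
+A minimal sketch for inspecting the results (assuming result.csv in the working directory; the ±20% threshold is an arbitrary choice, not part of the test):
+```
+import pandas as pd
+
+df = pd.read_csv("result.csv")
+# Benchmark-validation ratios should sit close to 1; flag anything outside +/-20%.
+ratio_cols = ["npu_fw/benchmark", "npu_bw/benchmark", "npu_fw+bw/benchmark"]
+outliers = df[(df[ratio_cols] - 1).abs().max(axis=1) > 0.2]
+print(outliers[["index", "shape_info"] + ratio_cols])
+```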
+
+### Usage notes
+
+1. **Performance comparison**:
+   - A ratio > 1 means the NPU is faster
+   - A ratio < 1 means the GPU is faster
+
+2. **Benchmark validation**:
+   - A ratio ≈ 1 means the result matches expectations
+   - A significant deviation from 1 should be investigated
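+
+For example, row 1 of benchmark.csv records npu_fw_time = 1.727 ms and gpu_fw_time = 0.991 ms, so npu_fw/gpu_fw = 0.991 / 1.727 ≈ 0.574 < 1, i.e. the GPU forward pass is faster for that case.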
\ No newline at end of file
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/benchmark.csv b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/benchmark.csv
new file mode 100644
index 0000000000000000000000000000000000000000..c15c1ac824b0039ef163f7dc05ee6fa42f815f66
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/benchmark.csv
@@ -0,0 +1,30 @@
+index,shape_info,total_len,batch_size,heads,heads_rab,max_seq_len_q,max_seq_len_k,max_context_len,max_target_len,target_group_size,attn_dim,hidden_dim,alpha,has_rab,has_drab,window_size,run_benchmark,dtype,full_batch,is_delta_q,format,npu_fw_time,npu_bw_time,gpu_fw_time,gpu_bw_time,precision,npu_fw/gpu_fw,npu_bw/gpu_bw,npu_fw+bw/gpu_fw+bw
+1,01shape,64570,96,3,3,0,0,1991,417,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.727,3.729,0.991,3.276,TRUE,0.574,0.879,0.782
+2,02shape,88696,8,4,4,16384,16384,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,20.419,65.717,6.475,18.819,TRUE,0.317,0.286,0.294
+3,03shape,21536,32,4,4,1024,1024,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.42,1.165,0.197,0.552,TRUE,0.469,0.474,0.473
+4,blockless256,400,8,4,4,100,100,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.069,0.098,0.049,0.236,TRUE,0.71,2.408,1.707
+5,bound,40960,4,2,2,20480,20480,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,5.871,19.499,2.058,5.167,TRUE,0.351,0.265,0.285
+6,Random,65536,32,8,8,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,8.105,18.877,3.097,11.97,TRUE,0.382,0.634,0.558
+7,Random,65536,32,8,8,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,5.797,15.801,1.829,5.533,TRUE,0.316,0.35,0.341
+8,Random,65536,32,2,2,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.12,5.043,1.099,3.231,TRUE,0.518,0.641,0.604
+9,Random,65536,32,2,2,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.427,4.169,0.659,1.603,TRUE,0.462,0.385,0.404
+10,Random,54400,32,8,8,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,8.082,17.313,2.594,8.767,TRUE,0.321,0.506,0.447
+11,Random,54400,32,8,8,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,6.057,14.197,1.405,4.285,TRUE,0.232,0.302,0.281
+12,Random,54400,32,2,2,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.049,4.408,0.704,2.39,TRUE,0.344,0.542,0.479
+13,Random,54400,32,2,2,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.511,3.486,0.468,1.166,TRUE,0.31,0.334,0.327
+14,Random,19248,32,8,8,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.324,2.915,0.482,1.71,TRUE,0.364,0.587,0.517
+15,Random,19248,32,8,8,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.99,2.347,0.325,0.884,TRUE,0.328,0.377,0.362
+16,Random,19248,32,2,2,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.438,0.861,0.175,0.813,TRUE,0.4,0.944,0.761
+17,Random,19248,32,2,2,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.333,0.674,0.113,0.422,TRUE,0.339,0.626,0.531
+18,Random,16384,8,8,8,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.456,5.708,1.056,3.704,TRUE,0.43,0.649,0.583
+19,Random,16384,8,8,8,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.624,4.827,0.624,1.684,TRUE,0.384,0.349,0.358
+20,Random,16384,8,2,2,4096,4096,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.687,1.607,0.375,1.144,TRUE,0.546,0.712,0.662
+21,Random,16384,8,2,2,4096,4096,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.47,1.402,0.29,0.524,TRUE,0.617,0.374,0.435
+22,Random,13600,8,8,8,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,2.395,5.084,0.944,2.727,TRUE,0.394,0.536,0.491
+23,Random,13600,8,8,8,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,1.78,4.091,0.544,1.703,TRUE,0.306,0.416,0.383
+24,Random,13600,8,2,2,3400,3400,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.798,1.563,0.307,0.968,TRUE,0.385,0.619,0.54
+25,Random,13600,8,2,2,3400,3400,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.592,1.194,0.231,0.409,TRUE,0.39,0.343,0.358
+26,Random,4812,8,8,8,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.473,0.922,0.158,0.523,TRUE,0.334,0.567,0.488
+27,Random,4812,8,8,8,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.351,0.712,0.135,0.294,TRUE,0.385,0.413,0.404
+28,Random,4812,8,2,2,1203,1203,0,0,1,256,256,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.173,0.328,0.091,0.23,TRUE,0.526,0.701,0.641
+29,Random,4812,8,2,2,1203,1203,0,0,1,128,128,1,FALSE,FALSE,"(-1, 0)",0b01,torch.bfloat16,FALSE,FALSE,jagged,0.131,0.284,0.074,0.137,TRUE,0.565,0.482,0.508
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/requirements.txt b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..7b73200af8ca162f91f4948a1c467f7c52ee52d8
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/requirements.txt
@@ -0,0 +1,5 @@
+pandas
+numpy
+matplotlib
+paramiko==3.5.1
+torch==2.6.0
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_benchmark.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d0f4f47905d0e5393cb579385c975a63a5024b0
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_benchmark.py
@@ -0,0 +1,368 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import argparse
+import time
+import getpass
+import os
+from typing import Optional
+import socket
+from select import select
+
+import torch
+import paramiko
+import pandas as pd
+
+import config
+from test_read_benchmark import (
+ benchmark_csv,
+ result_csv,
+ logger,
+ DATASETS,
+ init_result_csv_index,
+ create_and_save_params,
+)
+from test_msprof_npu_hstu import msprof_main
+
+
+INDEX_STR = "index"
+GPU_FW_TIME = "gpu_fw_time"
+NPU_FW_TIME = "npu_fw_time"
+NPU_BW_TIME = "npu_bw_time"
+GPU_BW_TIME = "gpu_bw_time"
+
+
+_remote_password_cache = None
+
+
+def get_remote_password() -> Optional[str]:
+ """Securely retrieves and caches the GPU server password (single input, reusable for the session).
+
+ Returns:
+ Optional[str]: The password if successfully retrieved, None otherwise.
+ """
+ global _remote_password_cache
+
+ if _remote_password_cache is not None:
+ return _remote_password_cache
+
+ try:
+ # Interactive input (hidden)
+ logger.info("\n[Security Notice] Password input will not display characters.")
+ password = getpass.getpass(
+ prompt="Enter GPU server password (valid for this session): "
+ )
+ if password.strip():
+ _remote_password_cache = password
+ return _remote_password_cache
+ else:
+ logger.error("Password cannot be empty.")
+ return None
+
+ except Exception as e:
+ logger.error(f"Password retrieval failed: {str(e)}")
+ return None
+
+
+def execute_remote_linux_cmd(client, cmd, timeout=600):
+ """Secure execution of Linux remote commands (supports timeout control and complete output capture)"""
+ stdin, stdout, stderr = None, None, None
+ exit_status = -1
+ output_buffer = []
+ error_buffer = []
+
+ try:
+ # Create execution channel (set timeout to prevent network blocking)
+ transport = client.get_transport()
+ channel = transport.open_session(timeout=timeout)
+ channel.settimeout(timeout)
+
+ # Execute command (non-blocking mode)
+ channel.exec_command(cmd)
+ stdin = channel.makefile_stdin("wb")
+ stdout = channel.makefile("r")
+ stderr = channel.makefile_stderr("r")
+
+ # Use select for efficient I/O multiplexing
+ start_time = time.time()
+ while not channel.exit_status_ready():
+ # Timeout check
+ if time.time() - start_time > timeout:
+ raise socket.timeout(f"Command timeout after {timeout}s: {cmd}")
+
+ # Wait for readable events (0.5s polling interval)
+ rlist, _, _ = select([channel], [], [], 0.5)
+ if not rlist:
+ continue
+
+ # Prioritize reading error stream (avoid buffer overflow)
+ while channel.recv_stderr_ready():
+ error_buffer.append(stderr.read(4096))
+
+ while channel.recv_ready():
+ output_buffer.append(stdout.read(4096))
+
+ # Get final exit status
+ exit_status = channel.recv_exit_status()
+
+ # Read all remaining output
+ output_buffer.append(stdout.read())
+ error_buffer.append(stderr.read())
+
+ except socket.timeout as te:
+ logger.error(f"Command execution timed out: {te}")
+ # Attempt to terminate remote process (send SIGTERM)
+ if "channel" in locals():
+ channel.close()
+ return False
+ except Exception as e:
+ logger.error(f"SSH command execution failed: {e}", exc_info=True)
+ return False
+ finally:
+ # Safely close channels (without closing underlying client connection)
+ for stream in [stdin, stdout, stderr]:
+ try:
+ if stream:
+ stream.close()
+ except Exception as e:
+ logger.info(e)
+
+ try:
+ if "channel" in locals():
+ channel.close()
+ except Exception as e:
+ logger.info(e)
+
+ # Build output results
+ decoded_output = [
+ item.decode("utf-8") if isinstance(item, bytes) else item
+ for item in output_buffer
+ ]
+ decoded_error = [
+ item.decode("utf-8") if isinstance(item, bytes) else item
+ for item in error_buffer
+ ]
+ full_output = "".join(decoded_output)
+ full_error = "".join(decoded_error)
+
+ # Log results (limit log length)
+ logger.info(f"Command [{cmd}] exited with status: {exit_status}")
+ if full_output:
+ logger.debug(
+ f"stdout: {full_output[:2000]}{'...' if len(full_output)>2000 else ''}"
+ )
+ if full_error:
+ logger.warning(
+ f"stderr: {full_error[:2000]}{'...' if len(full_error)>2000 else ''}"
+ )
+
+ return exit_status == 0
+
+
+def transfer_and_execute(bx):
+ if not isinstance(bx, int):
+ raise ValueError(f"input must be an integer but got {bx}")
+ remote_host = config.GPU_IP
+ remote_user = config.GPU_USER
+ remote_password = get_remote_password()
+ remote_dir = os.path.realpath(config.RECSYS_DIR)
+ client = paramiko.SSHClient()
+ client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+ cmd = f"cd {remote_dir} && {os.path.realpath(config.PYTHON3)} test_gpu_hstu.py --index={bx}"
+
+ try:
+ client.connect(
+ remote_host, port=22, username=remote_user, password=remote_password
+ )
+ sftp = client.open_sftp()
+ files_to_transfer = ["test_gpu_hstu.py", "test_read_benchmark.py", "config.py"]
+ for file in files_to_transfer:
+ sftp.put(file, os.path.join(remote_dir, file))
+ sftp.close()
+
+ logger.info(f"Executing remote script: {cmd}")
+ return execute_remote_linux_cmd(client, cmd)
+ except Exception as e:
+ logger.error(f"Failed to execute remote script: {cmd}")
+ logger.error(e)
+ return False
+ finally:
+        # Safely close the connection
+ try:
+ if client:
+ client.close()
+ except Exception as close_error:
+ logger.error(f"SSH close error: {close_error}")
+ logger.info(f"Exit remote script: {cmd}")
+
+
+def compare_npu_gpu_precision(save_dir=DATASETS, device="cpu"):
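+    """Load the .pth tensors saved by the NPU and GPU runs from the shared dataset directory and compare them with torch.allclose."""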
+ logger.info(f"Compating npu and gpu results of {save_dir}")
+ try:
+ data_type = torch.load(os.path.join(save_dir, "dtype.pth"), map_location=device)
+ npu_out = torch.load(
+ os.path.join(save_dir, "npu_out.pth"), map_location=device
+ ).to(dtype=data_type)
+ gpu_out = torch.load(
+ os.path.join(save_dir, "gpu_out.pth"), map_location=device
+ ).to(dtype=data_type)
+ gpu_out = gpu_out.view(gpu_out.shape[0], -1)
+
+ npu_q = torch.load(os.path.join(save_dir, "npu_q.pth"), map_location=device).to(
+ dtype=data_type
+ )
+ gpu_q = torch.load(os.path.join(save_dir, "gpu_q.pth"), map_location=device).to(
+ dtype=data_type
+ )
+
+ npu_k = torch.load(os.path.join(save_dir, "npu_k.pth"), map_location=device).to(
+ dtype=data_type
+ )
+ gpu_k = torch.load(os.path.join(save_dir, "gpu_k.pth"), map_location=device).to(
+ dtype=data_type
+ )
+
+ npu_v = torch.load(os.path.join(save_dir, "npu_v.pth"), map_location=device).to(
+ dtype=data_type
+ )
+ gpu_v = torch.load(os.path.join(save_dir, "gpu_v.pth"), map_location=device).to(
+ dtype=data_type
+ )
+
+ except Exception as e:
+ logger.error(f"error : {e}")
+ return False
+
+ if data_type == torch.bfloat16:
+ eps = 1e-2
+ elif data_type == torch.float16:
+ eps = 1e-3
+ else:
+ logger.error(f"error type : {data_type}")
+ return False
+
+ try:
+ out_close = torch.allclose(npu_out, gpu_out, eps, eps)
+ out_q = torch.allclose(npu_q, gpu_q, eps, eps)
+ out_k = torch.allclose(npu_k, gpu_k, eps, eps)
+ out_v = torch.allclose(npu_v, gpu_v, eps, eps)
+ except Exception as e:
+ logger.error(f"error : {e}")
+ return False
+
+ logger.info(f"npu_out vs gpu_out: {out_close}")
+ logger.info(f"npu_q vs gpu_q: {out_q}")
+ logger.info(f"npu_k vs gpu_k: {out_k}")
+ logger.info(f"npu_v vs gpu_v: {out_v}")
+
+ ret = out_close and out_q and out_k and out_v
+ return ret
+
+
+def update_index_csv(df_res, benchmark_df, bx, precision):
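+    # Ratios are stored as gpu_time / npu_time (and recorded benchmark_time / measured npu_time),
+    # so a value greater than 1 means the NPU (or the current run) is the faster side.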
+ mask_res = df_res[INDEX_STR] == bx
+ mask_benchmark = benchmark_df[INDEX_STR] == bx
+
+ df_res.loc[mask_res, "precision"] = precision
+ df_res.loc[mask_res, "npu_fw/gpu_fw"] = (
+ df_res.loc[mask_res, GPU_FW_TIME] / df_res.loc[mask_res, NPU_FW_TIME]
+ )
+ df_res.loc[mask_res, "npu_bw/gpu_bw"] = (
+ df_res.loc[mask_res, GPU_BW_TIME] / df_res.loc[mask_res, NPU_BW_TIME]
+ )
+
+ df_res.loc[mask_res, "npu_fw+bw/gpu_fw+bw"] = df_res.loc[
+ mask_res, [GPU_FW_TIME, GPU_BW_TIME]
+ ].sum(axis=1) / df_res.loc[mask_res, [NPU_FW_TIME, NPU_BW_TIME]].sum(axis=1)
+ df_res.loc[mask_res, "npu_fw/benchmark"] = (
+ benchmark_df.loc[mask_benchmark, NPU_FW_TIME].item()
+ / df_res.loc[mask_res, NPU_FW_TIME].item()
+ )
+ df_res.loc[mask_res, "npu_bw/benchmark"] = (
+ benchmark_df.loc[mask_benchmark, NPU_BW_TIME].item()
+ / df_res.loc[mask_res, NPU_BW_TIME].item()
+ )
+ df_res.loc[mask_res, "npu_fw+bw/benchmark"] = benchmark_df.loc[
+ mask_benchmark, [NPU_FW_TIME, NPU_BW_TIME]
+ ].sum(axis=1).item() / df_res.loc[mask_res, [NPU_FW_TIME, NPU_BW_TIME]].sum(axis=1).item()
+
+
+def retry_operation(operation, operation_name, bx, max_retries=2):
+ for attempt in range(max_retries):
+ ret = operation(bx)
+ if ret:
+ return True
+ logger.warning(
+ f"{operation_name} failed for benchmark {bx}, attempt {attempt + 1}/{max_retries}"
+ )
+ logger.error(
+ f"{operation_name} failed for benchmark {bx} after {max_retries} retries"
+ )
+ return False
+
+
+def main(index=None):
+ benchmark_df = pd.read_csv(benchmark_csv)
+ benchmark_df[INDEX_STR] = benchmark_df[INDEX_STR].astype(int)
+ all_indices = benchmark_df[INDEX_STR].tolist()
+
+ if index is not None:
+ all_indices = [index]
+ failed = []
+ for bx in all_indices:
+ logger.info(f"benchmark {bx} testing")
+ init_result_csv_index(bx)
+
+ df_res = pd.read_csv(result_csv)
+
+ if (
+ bx in df_res[INDEX_STR].values
+ and df_res[df_res[INDEX_STR] == bx].notna().all().all()
+ ):
+ logger.info(f"benchmark {bx} already, pass")
+ continue
+
+ create_and_save_params(bx)
+ remote_success = retry_operation(transfer_and_execute, "Remote execution", bx)
+ if not remote_success:
+ failed.append(bx)
+ continue
+
+ local_success = retry_operation(msprof_main, "Local execution", bx)
+ if not local_success:
+ failed.append(bx)
+ continue
+
+ df_res = pd.read_csv(result_csv)
+ precision = compare_npu_gpu_precision()
+ update_index_csv(df_res, benchmark_df, bx, precision)
+ df_res.to_csv(result_csv, index=False)
+ logger.info(f"benchmark {bx} tested")
+
+ if len(failed) != 0:
+ logger.error("index %s failed!", ", ".join(str(x) for x in failed))
+ else:
+ logger.info("All successed!")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run benchmark tests.")
+ parser.add_argument("--index", type=int, help="benchmark to run. Default run all.")
+ args = parser.parse_args()
+ main(args.index)
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_gpu_hstu.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_gpu_hstu.py
new file mode 100644
index 0000000000000000000000000000000000000000..586a7f5c7d995652faa0859f41249bb961a2da4b
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_gpu_hstu.py
@@ -0,0 +1,552 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import argparse
+import os
+import traceback
+from statistics import mean
+from typing import Optional, Tuple
+
+import pandas as pd
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from pynvml import (
+ nvmlDeviceGetCount,
+ nvmlDeviceGetHandleByIndex,
+ nvmlDeviceGetMemoryInfo,
+ nvmlInit,
+ nvmlShutdown,
+)
+
+from test_read_benchmark import (
+ DATASETS,
+ init_result_csv_index,
+ logger,
+ read_and_validate_parameters,
+ result_csv,
+ load_params,
+)
+
+sm_major_version = torch.cuda.get_device_properties(0).major
+sm_minor_version = torch.cuda.get_device_properties(0).minor
+if sm_major_version == 9 and sm_minor_version == 0:
+    from hstu_attn_interface import hstu_attn_varlen_func
+elif sm_major_version == 8:
+    from hstu_attn import hstu_attn_varlen_func
+else:
+    raise RuntimeError(
+        f"Unsupported GPU architecture sm_{sm_major_version}{sm_minor_version}: "
+        "hstu_attn_varlen_func is only available for sm_90 and sm_8x"
+    )
+
+
+PERFORMANCE = True
+if PERFORMANCE:
+ g_iterations = 100
+ g_profiler_step_start = 20
+ loop = 3
+else:
+ g_iterations = 1
+ g_profiler_step_start = 0
+ loop = 3
+
+
+def get_gpu_memory_info():
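+    """Query total and used memory of every visible GPU via NVML."""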
+ nvmlInit()
+ device_count = nvmlDeviceGetCount()
+ memory_info = []
+ for i in range(device_count):
+ handle = nvmlDeviceGetHandleByIndex(i)
+ info = nvmlDeviceGetMemoryInfo(handle)
+ memory_info.append((i, info.total, info.used))
+ nvmlShutdown()
+ return memory_info
+
+
+def auto_select_gpu():
+ memory_info = get_gpu_memory_info()
+ min_memory_used = float("inf")
+ best_gpu_index = None
+ for i, total, used in memory_info:
+ logger.info(
+ f"GPU {i}: Total Memory = {total / 1024**2} MiB, Used Memory = {used / 1024**2} MiB"
+ )
+ if used < min_memory_used:
+ min_memory_used = used
+ best_gpu_index = i
+ return best_gpu_index
+
+
+gdevice = auto_select_gpu()
+logger.info(f"Selected GPU device index: {gdevice}")
+torch.cuda.set_device(gdevice)
+
+
+def pad_input(unpadded_input, cu_seqlen, batch, seqlen):
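+    """Scatter a jagged (varlen) tensor into a zero-padded [batch, seqlen, ...] layout."""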
+ indices = []
+ for i in range(batch):
+ indices.append(
+ torch.arange(seqlen * i, seqlen * i + cu_seqlen[i + 1] - cu_seqlen[i])
+ )
+ indices = torch.cat(indices)
+ output = torch.zeros(
+ (batch * seqlen),
+ *unpadded_input.shape[1:],
+ device=unpadded_input.device,
+ dtype=unpadded_input.dtype,
+ )
+ output[indices] = unpadded_input
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
+
+
+def pad_input_delta_q(unpadded_input, cu_seqlen_q, cu_seqlen_k, batch, seqlen):
+ indices = []
+ for i in range(batch):
+ act_seqlen_q = (cu_seqlen_q[i + 1] - cu_seqlen_q[i]).item()
+ act_seqlen_k = (cu_seqlen_k[i + 1] - cu_seqlen_k[i]).item()
+ indices.append(
+ torch.arange(
+ seqlen * i + act_seqlen_k - act_seqlen_q, seqlen * i + act_seqlen_k
+ )
+ )
+ indices = torch.cat(indices)
+ output = torch.zeros(
+ (batch * seqlen),
+ *unpadded_input.shape[1:],
+ device=unpadded_input.device,
+ dtype=unpadded_input.dtype,
+ )
+ output[indices] = unpadded_input
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
+
+
+def unpad_input(padded_input, cu_seqlen):
+    padded_input = padded_input.reshape(padded_input.size(0), padded_input.size(1), -1)
+ output = []
+ for i in range(len(cu_seqlen) - 1):
+ output.append(padded_input[i, : (cu_seqlen[i + 1] - cu_seqlen[i]), :])
+ return torch.cat(output, dim=0)
+
+
+def unpad_input_delta_q(padded_input, cu_seqlen_q, cu_seqlen_k, batch, seqlen):
+    padded_input = padded_input.reshape(padded_input.size(0), padded_input.size(1), -1)
+ output = []
+ for i in range(batch):
+ act_seqlen_q = (cu_seqlen_q[i + 1] - cu_seqlen_q[i]).item()
+ act_seqlen_k = (cu_seqlen_k[i + 1] - cu_seqlen_k[i]).item()
+ output.append(padded_input[i, act_seqlen_k - act_seqlen_q: act_seqlen_k, :])
+ return torch.cat(output, dim=0)
+
+
+def _hstu_attention_maybe_from_cache(
+ num_heads: int,
+ attention_dim: int,
+ linear_dim: int,
+ seqlen_q: int,
+ seqlen_k: int,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ q_offsets: torch.Tensor,
+ k_offsets: torch.Tensor,
+ rab: Optional[torch.Tensor],
+ invalid_attn_mask: torch.Tensor,
+ alpha: float,
+ upcast: bool = True,
+ reorder_op: bool = False,
+ is_delta_q: bool = False,
+):
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+ batch: int = q_offsets.size(0) - 1
+ dtype_out = q.dtype
+ if is_delta_q:
+ padded_q = pad_input_delta_q(q, q_offsets, k_offsets, batch, seqlen_k)
+ else:
+ padded_q = pad_input(q, q_offsets, batch, seqlen_q)
+ padded_k = pad_input(k, k_offsets, batch, seqlen_k)
+ padded_v = pad_input(v, k_offsets, batch, seqlen_k)
+
+ padded_q = padded_q.view(batch, seqlen_k, num_heads, attention_dim)
+ padded_k = padded_k.view(batch, seqlen_k, num_heads, attention_dim)
+ padded_v = padded_v.view(batch, seqlen_k, num_heads, linear_dim)
+ if upcast:
+ padded_q, padded_k, padded_v = (
+ padded_q.float(),
+ padded_k.float(),
+ padded_v.float(),
+ )
+ if rab is not None:
+ rab = rab.float()
+ qk_attn = torch.einsum(
+ "bnhd,bmhd->bhnm",
+ padded_q,
+ padded_k,
+ )
+
+ if rab is not None:
+ padding = (
+ 0,
+ qk_attn.shape[-1] - rab.shape[-1],
+ 0,
+ qk_attn.shape[-2] - rab.shape[-2],
+ )
+ rab = F.pad(rab, padding, value=0)
+ masked_qk_attn = qk_attn + rab
+ else:
+ masked_qk_attn = qk_attn
+ masked_qk_attn = masked_qk_attn * alpha
+ masked_qk_attn = F.silu(masked_qk_attn)
+ masked_qk_attn = masked_qk_attn / seqlen_q
+ if invalid_attn_mask is not None:
+ if invalid_attn_mask.ndim == 2:
+ invalid_attn_mask = invalid_attn_mask.unsqueeze(0).unsqueeze(0)
+ masked_qk_attn = (
+ masked_qk_attn * invalid_attn_mask.type(masked_qk_attn.dtype)[:, :, :, :]
+ )
+
+ attn_output = torch.einsum(
+ "bhnm,bmhd->bnhd",
+ masked_qk_attn,
+ padded_v,
+ )
+
+ attn_output = attn_output.reshape(batch, seqlen_k, num_heads * linear_dim)
+ if is_delta_q:
+ attn_output = unpad_input_delta_q(
+ attn_output, q_offsets, k_offsets, batch, seqlen_k
+ )
+ else:
+ attn_output = unpad_input(attn_output, q_offsets)
+ attn_output = attn_output.reshape(-1, num_heads * linear_dim)
+
+ return attn_output.to(dtype_out)
+
+
+def test_fused_attn(
+ total_len: int,
+ batch_size: int,
+ heads: int,
+ heads_rab: Optional[int],
+ max_seq_len_q: int,
+ max_seq_len_k: int,
+ max_context_len: int,
+ max_target_len: int,
+ target_group_size: int,
+ attn_dim: int,
+ hidden_dim: int,
+ alpha: float,
+ has_rab: bool,
+ has_drab: bool,
+ window_size: Tuple[int, int],
+ run_benchmark: Optional[int],
+ dtype: torch.dtype,
+ full_batch: bool,
+ is_delta_q: bool,
+) -> Tuple[Optional[float], Optional[float]]:
+ has_context = max_context_len > 0
+ has_target = max_target_len > 0
+ group_target = target_group_size > 1
+ is_causal = window_size[0] == -1 and window_size[1] == 0
+ if dtype == torch.float8_e4m3fn:
+ raise ValueError(
+ "float8_e4m3fn is not supported, please use test_fused_attn_fp8 instead"
+ )
+ if has_drab and not has_rab:
+ raise ValueError("has_drab is True but has_rab is False")
+ cond = has_target and (window_size[0] > 0 or window_size[1] > 0)
+ if (has_target and not is_causal) or cond:
+ raise ValueError(
+ "has_target is True but is_causal is False or window_size is not (-1, -1)"
+ )
+ if (max_seq_len_q != max_seq_len_k) and not is_delta_q:
+ raise ValueError("max_seq_len_q != max_seq_len_k but is_delta_q is False")
+ if is_delta_q and max_seq_len_q > max_seq_len_k:
+ raise ValueError("is_delta_q is True but max_seq_len_q > max_seq_len_k")
+ if group_target and not has_target:
+ raise ValueError("group_target is True but has_target is False")
+
+ if is_delta_q and has_target:
+ return None, None
+ if is_delta_q and has_context:
+ return None, None
+ if not is_causal and has_context:
+ return None, None
+ if (window_size[0] > 0 or window_size[1] > 0) and has_context:
+ return None, None
+
+ torch.cuda.synchronize()
+
+ if run_benchmark not in [
+ 0b01,
+ 0b10,
+ 0b11,
+ ]: # 0b01 is run hstu benchmark and 0b10 is run torch benchmark, 0b11 is run both and compare precision
+ raise ValueError("run_benchmark should be in [0b01, 0b10, 0b11]")
+
+ iterations = g_iterations
+ profiler_step_start = g_profiler_step_start
+
+ (
+ lq,
+ lk,
+ num_contexts,
+ seq_offsets_q,
+ seq_offsets_k,
+ num_targets,
+ q,
+ k,
+ v,
+ rab,
+ attn_mask,
+ grad,
+ ) = read_param()
+
+ logger.info(
+ f"max_context_len: {max_context_len}, max_seq_len_q: {max_seq_len_q}, "
+ f"max_target_len: {max_target_len}"
+ )
+ logger.info(f"q.shape: {q.shape}")
+ logger.info(f"k.shape: {k.shape}")
+ logger.info(f"v.shape: {v.shape}")
+ logger.info(f"grad.shape: {grad.shape}")
+ logger.info(f"attn_mask.shape: {attn_mask.shape}")
+ logger.info(f"seq_offsets_q.shape: {seq_offsets_q.shape}")
+ logger.info(f"seq_offsets_k.shape: {seq_offsets_k.shape}")
+ logger.info(
+ f"total_max_seq_len_q: {max_seq_len_q + max_context_len + max_target_len}"
+ )
+ logger.info(
+ f"total_max_seq_len_k: {max_seq_len_k + max_context_len + max_target_len}"
+ )
+ logger.info(
+ f"num_contexts.shape: {num_contexts.shape if (has_context and run_benchmark & 0b01) else None}"
+ )
+ logger.info(
+ f"num_targets.shape: {num_targets.shape if (has_target and run_benchmark & 0b01) else None}"
+ )
+ logger.info(f"target_group_size: {target_group_size}")
+ logger.info(f"window_size: {window_size}")
+ logger.info(f"alpha: {alpha}")
+ logger.info(f"rab.shape: {rab.shape if has_rab else None}")
+ logger.info(f"has_drab: {has_drab}")
+ logger.info(f"is_delta_q: {is_delta_q}")
+
+ fwd_event_start = torch.cuda.Event(enable_timing=True)
+ fwd_event_stop = torch.cuda.Event(enable_timing=True)
+ torch.cuda.synchronize()
+ for i in range(iterations):
+ if i == profiler_step_start:
+ fwd_event_start.record()
+
+ if run_benchmark & 0b01:
+ out_hstu = hstu_attn_varlen_func(
+ q=q,
+ k=k,
+ v=v,
+ seq_offsets_q=seq_offsets_q,
+ seq_offsets_k=seq_offsets_k,
+ max_seqlen_q=max_context_len + max_seq_len_q + max_target_len,
+ max_seqlen_k=max_context_len + max_seq_len_k + max_target_len,
+ num_contexts=num_contexts if has_context else None,
+ num_targets=num_targets if has_target else None,
+ target_group_size=target_group_size,
+ window_size=window_size,
+ alpha=alpha,
+ rab=rab if has_rab else None,
+ has_drab=has_drab,
+ is_delta_q=is_delta_q,
+ )
+ if run_benchmark & 0b10:
+ out_torch = _hstu_attention_maybe_from_cache(
+ num_heads=heads,
+ attention_dim=attn_dim,
+ linear_dim=hidden_dim,
+ seqlen_q=max_context_len + max_seq_len_q + max_target_len,
+ seqlen_k=max_context_len + max_seq_len_k + max_target_len,
+ q=q.view(lq, -1),
+ k=k.view(lk, -1),
+ v=v.view(lk, -1),
+ q_offsets=seq_offsets_q,
+ k_offsets=seq_offsets_k,
+ rab=rab if has_rab else None,
+ invalid_attn_mask=(
+ attn_mask.to(torch.float32) if attn_mask is not None else None
+ ),
+ alpha=alpha,
+ upcast=False,
+ reorder_op=True,
+ is_delta_q=is_delta_q,
+ )
+ out_torch = out_torch.reshape(-1, heads, attn_dim)
+ fwd_event_stop.record()
+ torch.cuda.synchronize()
+ fwd_time = fwd_event_start.elapsed_time(fwd_event_stop) / (
+ iterations - profiler_step_start
+ )
+
+ bwd_event_start = torch.cuda.Event(enable_timing=True)
+ bwd_event_stop = torch.cuda.Event(enable_timing=True)
+ torch.cuda.synchronize()
+ for i in range(iterations):
+ if i == profiler_step_start:
+ bwd_event_start.record()
+
+ autograd_input = (q, k, v, rab) if has_rab else (q, k, v)
+ if run_benchmark & 0b01:
+        dq_hstu, dk_hstu, dv_hstu, *_ = torch.autograd.grad(
+ out_hstu, autograd_input, grad, retain_graph=True
+ )
+ if run_benchmark & 0b10:
+        dq_torch, dk_torch, dv_torch, *_ = torch.autograd.grad(
+ out_torch, autograd_input, grad, retain_graph=True
+ )
+ bwd_event_stop.record()
+ torch.cuda.synchronize()
+ bwd_time = bwd_event_start.elapsed_time(bwd_event_stop) / (
+ iterations - profiler_step_start
+ )
+
+ if run_benchmark & 0b11 == 0b11:
+ if dtype == torch.bfloat16:
+ eps = 1e-2
+ elif dtype == torch.float16:
+ eps = 1e-3
+ else:
+ eps = 1e-4
+
+ out_close = torch.allclose(out_hstu, out_torch, eps, eps)
+ q_close = torch.allclose(dq_hstu, dq_torch, eps, eps)
+ k_close = torch.allclose(dk_hstu, dk_torch, eps, eps)
+ v_close = torch.allclose(dv_hstu, dv_torch, eps, eps)
+
+ logger.info(f"cpu_out vs gpu_out: {out_close}")
+ logger.info(f"cpu_q vs gpu_q: {q_close}")
+ logger.info(f"cpu_k vs gpu_k: {k_close}")
+ logger.info(f"cpu_v vs gpu_v: {v_close}")
+ logger.info(f"all pass: {(out_close and q_close and k_close and v_close)}")
+
+ save_dir = DATASETS
+ prefix = "gpu_"
+ logger.info(f"prefix: {prefix}")
+
+ cpu = "cpu"
+ torch.save(out_hstu.to(cpu), os.path.join(save_dir, f"{prefix}out.pth"))
+ torch.save(dq_hstu.to(cpu), os.path.join(save_dir, f"{prefix}q.pth"))
+ torch.save(dk_hstu.to(cpu), os.path.join(save_dir, f"{prefix}k.pth"))
+ torch.save(dv_hstu.to(cpu), os.path.join(save_dir, f"{prefix}v.pth"))
+
+ return fwd_time, bwd_time
+
+
+def read_param():
+ """Load parameters and automatically convert to CUDA tensors if needed."""
+ # Parameter field names in original order
+ params_def = [
+ "l_q", # length_q
+ "l_k", # length_k
+ "num_contexts", # num_contexts
+ "seq_offsets_q_wt", # seq_offsets_q
+ "seq_offsets_k_wt", # seq_offsets_k
+ "num_targets", # num_targets
+ "q", # query
+ "k", # key
+ "v", # value
+ "rab", # relative_attention_bias
+ "attn_mask", # attention_mask
+ "grad", # gradient
+ ]
+
+ # Error messages
+ _error_key = "key_error"
+ _error_messages = {
+ _error_key: "Missing parameter: {}",
+ "cuda_transfer_failed": "CUDA transfer failed",
+ }
+
+ param = load_params(DATASETS)
+ result = []
+
+ try:
+ for field_name in params_def:
+ param_value = param[field_name]
+ result.append(
+ param_value.cuda()
+ if isinstance(param_value, torch.Tensor)
+ else param_value
+ )
+ return tuple(result)
+
+ except KeyError as e:
+ error_msg = _error_messages[_error_key].format(e)
+ logger.error(error_msg, exc_info=True)
+ raise
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Read CSV file and run a specific index benchmark"
+ )
+ parser.add_argument(
+ "--index", type=int, required=True, help="index of the benchmark to run"
+ )
+ args = parser.parse_args()
+
+ _, params = read_and_validate_parameters(args.index)
+
+ init_result_csv_index(args.index)
+ df_res = pd.read_csv(result_csv)
+ select_mask = df_res["index"] == args.index
+ if (
+ df_res.loc[select_mask, "gpu_fw_time"].notna().all()
+ and df_res.loc[select_mask, "gpu_bw_time"].notna().all()
+ ):
+ logger.info(f"bemchmark {args.index} already")
+ exit(0)
+
+ if params is None:
+ logger.error("Invalid data, benchmark failed")
+ return
+
+ fwd_time_list = []
+ bwd_time_list = []
+
+ for i in range(loop):
+ try:
+ fwd_time, bwd_time = test_fused_attn(**params)
+ if fwd_time is None and bwd_time is None:
+ continue
+ logger.info(f"iter: {i}, fwd time: {fwd_time}, bwd time: {bwd_time}")
+ fwd_time_list.append(fwd_time)
+ bwd_time_list.append(bwd_time)
+ except Exception as e:
+ logger.error(e)
+ logger.error(traceback.format_exc())
+ return
+
+ if len(fwd_time_list) < loop:
+ raise Exception(f"Failed: {len(fwd_time_list)}/{loop} iterations finished")
+
+ df_res.loc[select_mask, "gpu_fw_time"] = mean(fwd_time_list[1:])
+ df_res.loc[select_mask, "gpu_bw_time"] = mean(bwd_time_list[1:])
+ df_res.to_csv(result_csv, index=False)
+
+ logger.info(f"Forward gpu time = {mean(fwd_time_list[1:])} ms")
+ logger.info(f"Backward gpu time = {mean(bwd_time_list[1:])} ms")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_msprof_npu_hstu.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_msprof_npu_hstu.py
new file mode 100644
index 0000000000000000000000000000000000000000..05aa9a0df79bfb0bc5b2cc5e6c3f2d24b301ae75
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_msprof_npu_hstu.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import os
+import argparse
+import glob
+import pandas as pd
+
+import config
+from test_read_benchmark import (
+ logger,
+ read_and_validate_parameters,
+ result_csv,
+ init_result_csv_index,
+)
+
+
+def msprof_main(index):
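+    """Run test_npu_hstu.py under msprof and record the HstuDenseForward/Backward average kernel times into result.csv."""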
+ index_str = "index"
+ try:
+ # Initialize the result CSV file
+ init_result_csv_index(index)
+ df_res = pd.read_csv(result_csv)
+
+ # Check if the current index has already been processed
+ if (
+ df_res.loc[df_res[index_str] == index, "npu_fw_time"].notna().all()
+ and df_res.loc[df_res[index_str] == index, "npu_bw_time"].notna().all()
+ ):
+ logger.info(f"Benchmark with index {index} is already done. Exiting.")
+ return True
+
+        # Read and validate parameters; abort early on invalid data
+        _, params = read_and_validate_parameters(index)
+        if params is None:
+            logger.error(f"Invalid parameters for benchmark index {index}")
+            return False
+
+ # Execute the msprof command
+ cmd = f'rm -rf profnpu/ ; msprof --application="python3 test_npu_hstu.py --index={index}" --output=profnpu'
+ ret = os.system(cmd)
+ if ret != 0:
+ logger.error(f"Command execution failed (ret={ret}): {cmd}")
+ return False
+
+ # Locate the generated CSV file
+ search_dir = os.path.join(os.path.realpath(config.NFS_DIR), "profnpu")
+ csv_files = glob.glob(
+ f"{search_dir}/PROF_*/mindstudio_profiler_output/op_stati*.csv"
+ )
+ if len(csv_files) == 0:
+ logger.error(f"MSProf run failed. No CSV file found in {search_dir}.")
+ return False
+
+ csv_file = csv_files[0]
+ logger.info(f"Profile located at: {csv_file}")
+
+ # Read the CSV file
+ df_op_stati = pd.read_csv(csv_file)
+
+ # Extract Forward and Backward data
+ forward_row = df_op_stati[df_op_stati["OP Type"] == "HstuDenseForward"]
+ backward_row = df_op_stati[df_op_stati["OP Type"] == "HstuDenseBackward"]
+
+ # Check if the data exists
+ if forward_row.empty or backward_row.empty:
+ missing_ops = []
+ if forward_row.empty:
+ missing_ops.append("HstuDenseForward")
+ if backward_row.empty:
+ missing_ops.append("HstuDenseBackward")
+ logger.error(f"Missing OP types in CSV: {', '.join(missing_ops)}")
+ return False
+
+ # Update the result DataFrame
+ df_res.loc[df_res[index_str] == index, "npu_fw_time"] = (
+ forward_row["Avg Time(us)"].squeeze() / 1000
+ )
+ df_res.loc[df_res[index_str] == index, "npu_bw_time"] = (
+ backward_row["Avg Time(us)"].squeeze() / 1000
+ )
+
+ # Save the results
+ df_res.to_csv(result_csv, index=False)
+
+ # Log the results
+ logger.info(
+ f"Forward time: {df_res.loc[df_res[index_str] == index, 'npu_fw_time'].values[0]} ms"
+ )
+ logger.info(
+ f"Backward time: {df_res.loc[df_res[index_str] == index, 'npu_bw_time'].values[0]} ms"
+ )
+
+ return True
+
+ except FileNotFoundError as e:
+ logger.error(f"File not found: {e}")
+ return False
+ except pd.errors.EmptyDataError as e:
+ logger.error(f"Empty CSV file or invalid data: {e}")
+ return False
+ except KeyError as e:
+ logger.error(f"Missing required column in DataFrame: {e}")
+ return False
+ except Exception as e:
+ logger.error(f"Unexpected error: {e}")
+ return False
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Read CSV file and run a specific index benchmark"
+ )
+ parser.add_argument(
+ "--index", type=int, required=True, help="index of the benchmark to run"
+ )
+ args = parser.parse_args()
+ if msprof_main(args.index):
+ logger.info("success")
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_npu_hstu.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_npu_hstu.py
new file mode 100644
index 0000000000000000000000000000000000000000..c63d8dc606f5bb508a977f770dced5f943eed7fe
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_npu_hstu.py
@@ -0,0 +1,158 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import argparse
+import os
+import sysconfig
+
+import torch
+
+import config
+from test_read_benchmark import (
+ logger,
+ DATASETS,
+ load_params,
+ read_and_validate_parameters,
+)
+
+torch.npu.config.allow_internal_format = False
+torch.ops.load_library(f"{sysconfig.get_path('purelib')}/libfbgemm_npu_api.so")
+
+
+def _hstu_attention_maybe_from_cache(
+ num_heads: int,
+ attention_dim: int,
+ linear_dim: int,
+ silu_value: float,
+ grad: torch.Tensor,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ invalid_attn_mask: torch.Tensor,
+ seq_offset: torch.Tensor,
+ data_type: torch.dtype,
+ device: str,
+):
+ n: int = invalid_attn_mask.size(-1)
+ torch.npu.set_device(device)
+
+ q_ = q.reshape(-1, num_heads, attention_dim).to(device=device).to(data_type)
+ k_ = k.reshape(-1, num_heads, attention_dim).to(device=device).to(data_type)
+    v_ = v.reshape(-1, num_heads, linear_dim).to(device=device).to(data_type)
+ grad = grad.to(device=device).to(data_type)
+
+ seq_offset = seq_offset.to(device=device).tolist()
+
+ if len(invalid_attn_mask.shape) == 2:
+ invalid_attn_mask = invalid_attn_mask.repeat(
+ len(seq_offset) - 1, num_heads, 1, 1
+ )
+ if len(invalid_attn_mask.shape) == 4 and invalid_attn_mask.shape[1] == 1:
+ invalid_attn_mask = invalid_attn_mask.repeat(1, num_heads, 1, 1)
+
+ logger.info(f"invalid_attn_mask shape: {invalid_attn_mask.shape}")
+
+ invalid_attn_mask = invalid_attn_mask.to(device=device).to(data_type)
+ mask_type = 3
+ silu_value = silu_value / n
+ local_cycle_nums = 100
+ for _ in range(local_cycle_nums):
+ grad_output = torch.ops.mxrec.hstu_dense(
+ q_,
+ k_,
+ v_,
+ invalid_attn_mask,
+ None,
+ mask_type,
+ n,
+ silu_value,
+ "jagged",
+ seq_offset,
+ )
+ q_grad, k_grad, v_grad, _ = torch.ops.mxrec.hstu_dense_backward(
+ grad,
+ q_,
+ k_,
+ v_,
+ invalid_attn_mask,
+ None,
+ "jagged",
+ mask_type,
+ n,
+ silu_value,
+ seq_offset,
+ )
+
+ torch.npu.synchronize()
+ grad_output = grad_output.reshape(-1, num_heads * linear_dim)
+
+ save_dir = DATASETS
+
+ torch.save(grad_output, os.path.join(save_dir, "npu_out.pth"))
+ torch.save(q_grad, os.path.join(save_dir, "npu_q.pth"))
+ torch.save(k_grad, os.path.join(save_dir, "npu_k.pth"))
+ torch.save(v_grad, os.path.join(save_dir, "npu_v.pth"))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Read CSV file and run a specific index benchmark"
+ )
+ parser.add_argument(
+ "--index", type=int, required=True, help="index of the benchmark to run"
+ )
+ args = parser.parse_args()
+
+ devicex = 0
+ deviceg = f"npu:{devicex}"
+ logger.info(f"device: {deviceg}")
+ read_dir = os.path.join(os.path.realpath(config.NFS_DIR), DATASETS)
+
+ _, param_ben = read_and_validate_parameters(args.index)
+ param = load_params(DATASETS)
+ try:
+ grad_data = param["grad"]
+ q_data = param["q"]
+ k_data = param["k"]
+ v_data = param["v"]
+ bias_data = param["rab"]
+ mask_data = param["attn_mask"]
+ max_seq_len_data = param["num_contexts"]
+ seq_offset_data = param["seq_offsets_q_wt"]
+ num_heads_data = v_data.shape[1]
+ data_type_data = param_ben["dtype"]
+ alpha_data = param["alpha"]
+ attention_dim_data = q_data.shape[2]
+ linear_dim_data = v_data.shape[2]
+ except KeyError as e:
+ logger.error(e)
+ exit(1)
+
+ _hstu_attention_maybe_from_cache(
+ num_heads=num_heads_data,
+ attention_dim=attention_dim_data,
+ linear_dim=linear_dim_data,
+ silu_value=alpha_data,
+ grad=grad_data,
+ q=q_data,
+ k=k_data,
+ v=v_data,
+ invalid_attn_mask=mask_data,
+ seq_offset=seq_offset_data,
+ data_type=data_type_data,
+ device=deviceg,
+ )
diff --git a/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_read_benchmark.py b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_read_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..d837991e8502bb2b61e775a333d3ebb7c88e196b
--- /dev/null
+++ b/mxrec_add_ons/rec_for_torch/torch_plugin/torch_demo/hstu_dense/test_hstu_performance/test_read_benchmark.py
@@ -0,0 +1,755 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+# Copyright 2025. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import argparse
+import ast
+import logging
+import math
+import os
+from typing import Dict, Optional, Tuple
+
+import pandas as pd
+import numpy as np
+import torch
+import matplotlib.colors as mcolors
+import matplotlib.pyplot as plt
+import matplotlib
+
+import config
+
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(filename)s %(lineno)d [%(levelname)s] %(message)s",
+ handlers=[logging.FileHandler("test_benchmark.log"), logging.StreamHandler()],
+)
+logger = logging.getLogger(__name__)
+
+matplotlib.use("Agg")
+INDEX_STR = "index"
+DATASETS = os.path.join(os.path.realpath(config.NFS_DIR), "datasets")
+benchmark_csv = os.path.join(os.path.realpath(config.NFS_DIR), "benchmark.csv")
+result_csv = os.path.join(os.path.realpath(config.NFS_DIR), "result.csv")
+
+column_names = [
+ "index",
+ "shape_info",
+ "total_len",
+ "batch_size",
+ "heads",
+ "heads_rab",
+ "max_seq_len_q",
+ "max_seq_len_k",
+ "max_context_len",
+ "max_target_len",
+ "target_group_size",
+ "attn_dim",
+ "hidden_dim",
+ "alpha",
+ "has_rab",
+ "has_drab",
+ "window_size",
+ "run_benchmark",
+ "dtype",
+ "full_batch",
+ "is_delta_q",
+ "format",
+ "npu_fw_time",
+ "npu_bw_time",
+ "gpu_fw_time",
+ "gpu_bw_time",
+ "precision",
+ "npu_fw/gpu_fw",
+ "npu_bw/gpu_bw",
+ "npu_fw+bw/gpu_fw+bw",
+ "npu_fw/benchmark",
+ "npu_bw/benchmark",
+ "npu_fw+bw/benchmark",
+]
+
+column_left = column_names[: column_names.index("format") + 1]
+
+hstu_required_params = {
+ "total_len": int,
+ "batch_size": int,
+ "heads": int,
+ "heads_rab": int,
+ "max_seq_len_q": int,
+ "max_seq_len_k": int,
+ "max_context_len": int,
+ "max_target_len": int,
+ "target_group_size": int,
+ "attn_dim": int,
+ "hidden_dim": int,
+ "alpha": float,
+ "has_rab": bool,
+ "has_drab": bool,
+ "window_size": tuple,
+ "run_benchmark": int,
+ "dtype": torch.dtype,
+ "full_batch": bool,
+ "is_delta_q": bool,
+}
+
+generate_params = [
+ "total_len",
+ "batch_size",
+ "heads",
+ "heads_rab",
+ "max_seq_len_q",
+ "max_seq_len_k",
+ "max_context_len",
+ "max_target_len",
+ "target_group_size",
+ "attn_dim",
+ "hidden_dim",
+ "window_size",
+ "dtype",
+ "full_batch",
+ "has_drab",
+ "is_delta_q",
+]
+
+
+def convert_value(value, required_type):
+ if isinstance(value, str):
+ if value == "torch.float16":
+ return torch.float16
+ elif value == "torch.bfloat16":
+ return torch.bfloat16
+ else:
+ try:
+ value = ast.literal_eval(value)
+            except (ValueError, SyntaxError):
+ pass
+ try:
+ return required_type(value)
+ except (ValueError, TypeError):
+ return value
+
+
+def read_and_validate_parameters(index, csv_file_path=benchmark_csv):
+ try:
+ # Try to read CSV file
+ df_benchmark = pd.read_csv(csv_file_path, encoding="utf-8")
+
+ # Check if index column exists
+ if INDEX_STR not in df_benchmark.columns:
+ logger.error("Missing index column in CSV: %s", INDEX_STR)
+ return None, None
+
+ # Filter rows by index
+ df_benchmark = df_benchmark.loc[df_benchmark[INDEX_STR] == index]
+ if df_benchmark.empty:
+ logger.info("No data found for index %d", index)
+ return None, None
+
+ # Extract required parameters
+ try:
+ row = df_benchmark[list(hstu_required_params.keys())]
+ except KeyError as e:
+ logger.error("Missing required columns in CSV: %s", e)
+ return None, None
+
+ # Convert to dict and validate types
+ params = row.iloc[0].to_dict()
+ for key, required_type in hstu_required_params.items():
+ try:
+ params[key] = convert_value(params[key], required_type)
+ except (ValueError, TypeError) as e:
+ logger.error("Type conversion failed for parameter %s: %s", key, e)
+ return None, None
+
+ logger.info("Parameters for index %d: %s", index, params)
+ return df_benchmark, params
+
+ except Exception as e:
+ logger.error("Error: %s", e, exc_info=True) # Added exc_info for stack trace
+ return None, None
+
+
+def init_result_csv_index(index):
+ try:
+ # Read benchmark CSV and ensure index column is integer type
+ ben_df = pd.read_csv(benchmark_csv)
+ ben_df[INDEX_STR] = ben_df[INDEX_STR].astype(int)
+
+ # Create empty result file if not exists
+ if not os.path.exists(result_csv):
+ pd.DataFrame(columns=column_names).to_csv(result_csv, index=False)
+
+ # Read result file
+ df_res = pd.read_csv(result_csv)
+
+ # Check if index exists in benchmark data
+ benchmark_row = ben_df.loc[ben_df[INDEX_STR] == index, column_left]
+ if benchmark_row.empty: # Explicit empty case handling
+ logger.warning("Index %d not found in benchmark file, skipping", index)
+ raise ValueError("Row %d not found in %s" % (index, benchmark_csv))
+
+ # Check if index already exists in result file
+ if index in df_res[INDEX_STR].values:
+ return
+
+ # Add new row and save
+ new_row = pd.DataFrame([benchmark_row.iloc[0]], columns=column_left)
+ df_res = pd.concat([df_res, new_row], ignore_index=True)
+ df_res.to_csv(result_csv, index=False)
+ logger.info("Successfully added index %d to result file %s", index, result_csv)
+
+ except Exception as e:
+ logger.error("Error initializing result file: %s", str(e), exc_info=True)
+ raise
+
+
+def construct_mask(
+ seqlen_c,
+ seqlen,
+ seqlen_t=0,
+ target_group_size=1,
+ window_size=(-1, -1), # -1 means infinite window size
+ seq_offsets=None,
+ num_contexts=None,
+ device=None,
+):
+ seqlen = seqlen_c + seqlen + seqlen_t
+ bs = seq_offsets.size(0) - 1
+
+ mask = torch.zeros((seqlen, seqlen), device=device, dtype=torch.bool)
+ if window_size[0] < 0 and window_size[1] == 0:
+ # causal mask
+ for i in range(seqlen):
+ mask[i, : i + 1] = True
+
+ # context mask
+ if seqlen_c != 0:
+ mask = mask.unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1, 1)
+ for i in range(bs):
+ target_start = (
+ num_contexts[i] + seq_offsets[i + 1] - seq_offsets[i]
+ ).item()
+ mask[i, 0, : num_contexts[i], :target_start] = True
+
+ # target mask
+ if seqlen_t != 0:
+ mask = (
+ mask.unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1, 1)
+ if mask.ndim == 2
+ else mask
+ )
+ for i in range(bs):
+ target_start = (
+ num_contexts[i] + seq_offsets[i + 1] - seq_offsets[i]
+ ).item()
+ # target group mask
+ if target_group_size > 1:
+ group_num = math.ceil((seqlen - target_start) / target_group_size)
+ for j in range(group_num):
+ for k in range(
+ min(
+ target_group_size,
+ seqlen - target_start - j * target_group_size,
+ )
+ ):
+ mask[
+ i,
+ 0,
+ target_start + j * target_group_size + k,
+ target_start: target_start + j * target_group_size,
+ ] = False
+ else:
+ for j in range(target_start, seqlen):
+ mask[i, 0, j, target_start:j] = False
+
+ # local mask
+ else:
+ window_size_0 = window_size[0] if window_size[0] > 0 else seqlen
+ window_size_1 = window_size[1] if window_size[1] > 0 else seqlen
+ for i in range(seqlen):
+ mask[i, max(0, i - window_size_0): min(seqlen, i + window_size_1 + 1)] = (
+ True
+ )
+ return mask
+
+
+def gen_seq(length, max_value, total_sum):
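+    """Generate `length` sequence lengths that sum to total_sum; for length > 1 the longest entry is max_value."""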
+ if max_value * length < total_sum:
+ raise ValueError("total_sum error %d" % total_sum)
+ logger.info("gen_seq with total_sum %d", total_sum)
+ if length == 1:
+ return np.array([total_sum])
+ if length == 2:
+ return np.array([max_value, total_sum - max_value])
+ remaining_sum = total_sum - max_value
+ mean_value = remaining_sum // (length - 2)
+ min_val = remaining_sum - mean_value * (length - 2)
+ sequence = [mean_value] * (length - 2)
+ if min_val == 0 and sequence[-1] > 1:
+ min_val += 1
+ sequence[-1] -= 1
+ sequence.extend([max_value, min_val])
+ return np.array(sequence)
+
+
+def adjust_ratio(total_sum, max_context_len, max_seq_len_k, max_target_len):
+    """Split total_sum across context/key/target in proportion to their max lengths."""
+    total_content, total_k, total_target = 0, 0, 0
+
+    # A term with a zero max length receives nothing; the others share
+    # total_sum in proportion to their max lengths.
+    valid_denominator = max_context_len + max_seq_len_k + max_target_len
+    if valid_denominator == 0:
+        logger.debug("All max lengths are 0, returning zeros")
+        return 0, 0, 0
+
+    if total_sum == 0:
+        logger.debug("No sum to distribute")
+        return total_k, total_content, total_target
+
+    logger.debug("Distributing sum %d with denominator %d", total_sum, valid_denominator)
+
+    if max_context_len > 0:
+        total_content = int(round(total_sum * max_context_len / valid_denominator))
+    if max_seq_len_k > 0:
+        total_k = int(round(total_sum * max_seq_len_k / valid_denominator))
+    if max_target_len > 0:
+        total_target = int(round(total_sum * max_target_len / valid_denominator))
+
+    # Fold any rounding error into the target share so the parts sum exactly.
+    diff = total_sum - (total_content + total_k + total_target)
+    if diff != 0:
+        logger.debug("Adjusting for rounding difference of %d", diff)
+        total_target += diff
+
+    logger.info("Final distribution: total_k=%d, total_content=%d, total_target=%d",
+                total_k, total_content, total_target)
+    return total_k, total_content, total_target
+
+
+def generate_input(
+ total_len: int,
+ batch_size: int,
+ heads: int,
+ heads_rab: Optional[int],
+ max_seq_len_q: int,
+ max_seq_len_k: int,
+ max_context_len: int,
+ max_target_len: int,
+ target_group_size: int,
+ attn_dim: int,
+ hidden_dim: int,
+ window_size: Tuple[int, int],
+ dtype: torch.dtype,
+ full_batch: bool,
+ has_drab: bool,
+ is_delta_q: bool,
+):
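+    """Generate CPU-side inputs (q/k/v, offsets, rab, mask) for one benchmark case.
+
+    The total sequence length is split proportionally by adjust_ratio, then
+    gen_seq spreads each share across the batch so the overall workload
+    tracks total_len.
+    """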
+    device_str = "cpu"
+    has_context = max_context_len > 0
+    has_target = max_target_len > 0
+
+ # Modification Note: Original random length generation caused uncontrolled total length,
+ # leading to unstable computational load. Now using proportional allocation based on total_len.
+ logger.info("generate with total_len %d", total_len)
+
+ # Allocate total length proportionally according to max_seq_len_k/max_context_len/max_target_len
+ total_k, total_content, total_target = adjust_ratio(
+ total_len, max_context_len, max_seq_len_k, max_target_len
+ )
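+    # Illustration (assumed values): adjust_ratio(1000, 100, 300, 100) returns
+    # total_k=600, total_content=200, total_target=200, i.e. shares proportional
+    # to 100:300:100 that sum exactly to total_len.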
+
+ # Generate key sequence lengths (proportionally allocated)
+ lengths_k = (
+ torch.from_numpy(gen_seq(batch_size, max_seq_len_k, total_k))
+ .to(device_str)
+ .int()
+ )
+
+ # Generate context sequence lengths (proportionally allocated)
+ num_contexts = (
+ torch.from_numpy(gen_seq(batch_size, max_context_len, total_content))
+ .to(device_str)
+ .int()
+ )
+
+ # Generate target sequence lengths (proportionally allocated)
+ num_targets = (
+ torch.from_numpy(gen_seq(batch_size, max_target_len, total_target))
+ .to(device_str)
+ .int()
+ )
+
+    # Cumulative offsets for context sequences
+    seq_offsets_c = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+    )
+    seq_offsets_c[1:] = torch.cumsum(num_contexts, dim=0)
+
+    # Cumulative offsets for historical qkv
+    seq_offsets_k = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+    )
+    seq_offsets_k[1:] = torch.cumsum(lengths_k, dim=0)
+
+    # Cumulative offsets for target qkv
+    seq_offsets_t = torch.zeros(
+        (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+    )
+    seq_offsets_t[1:] = torch.cumsum(num_targets, dim=0)
+
+ # Generate lengths for delta q
+ if is_delta_q:
+ if full_batch:
+ lengths_q = (
+ torch.ones(
+ (batch_size,), device=torch.device(device_str), dtype=torch.int32
+ )
+ * max_seq_len_q
+ )
+ else:
+ # lengths_q[i] is an integer between 1 and min(max_seq_len_q, lengths_k[i])
+ lengths_q = torch.zeros(
+ (batch_size,), device=torch.device(device_str), dtype=torch.int32
+ )
+ for i in range(batch_size):
+ lengths_q[i] = torch.randint(
+ 1,
+ min(max_seq_len_q, lengths_k[i]) + 1,
+ size=(1,),
+ device=torch.device(device_str),
+ )
+ seq_offsets_q = torch.zeros(
+ (batch_size + 1,), dtype=torch.int32, device=torch.device(device_str)
+ )
+ seq_offsets_q[1:] = torch.cumsum(lengths_q, dim=0)
+ else:
+ seq_offsets_q = seq_offsets_k
+
+    # Offsets for the whole q, kv (context + history + target)
+    seq_offsets_q_wt = seq_offsets_c + seq_offsets_q + seq_offsets_t
+    seq_offsets_k_wt = seq_offsets_c + seq_offsets_k + seq_offsets_t
+
+ l_q = int(seq_offsets_q_wt[-1].item())
+ l_k = int(seq_offsets_k_wt[-1].item())
+    # float8 lacks uniform_ support, so initialize in fp16 and cast afterwards
+    if dtype == torch.float8_e4m3fn:
+        dtype_init = torch.float16
+    else:
+        dtype_init = dtype
+
+ # Generate q, k, v for history + target
+ q = (
+ torch.empty(
+ (l_q, heads, attn_dim), dtype=dtype_init, device=torch.device(device_str)
+ )
+ .uniform_(-1, 1)
+ .requires_grad_()
+ ).to(dtype)
+ k = (
+ torch.empty(
+ (l_k, heads, attn_dim), dtype=dtype_init, device=torch.device(device_str)
+ )
+ .uniform_(-1, 1)
+ .requires_grad_()
+ ).to(dtype)
+ v = (
+ torch.empty(
+ (l_k, heads, hidden_dim), dtype=dtype_init, device=torch.device(device_str)
+ )
+ .uniform_(-1, 1)
+ .requires_grad_()
+ ).to(dtype)
+ rab = None
+ if has_drab:
+ rab = torch.empty(
+ (
+ batch_size,
+ heads if heads_rab is None else heads_rab,
+ max_context_len + max_seq_len_k + max_target_len,
+ max_context_len + max_seq_len_k + max_target_len,
+ ),
+ dtype=dtype_init,
+ device=torch.device(device_str),
+ ).uniform_(-1, 1)
+ rab = rab.requires_grad_()
+
+ if window_size[0] == -1 and window_size[1] == -1:
+ attn_mask = None
+ else:
+ attn_mask = (
+ construct_mask(
+ seqlen_c=max_context_len,
+ seqlen=max_seq_len_k,
+ seqlen_t=max_target_len,
+ target_group_size=target_group_size,
+ window_size=window_size,
+ num_contexts=num_contexts,
+ seq_offsets=seq_offsets_k,
+ )
+ .cpu()
+ .to(torch.float32)
+ )
+ grad = torch.rand_like(v)
+ params = {
+ "l_q": l_q,
+ "l_k": l_k,
+ "num_contexts": num_contexts if has_context else None,
+ "seq_offsets_q_wt": seq_offsets_q_wt,
+ "seq_offsets_k_wt": seq_offsets_k_wt,
+ "num_targets": num_targets if has_target else None,
+ "q": q,
+ "k": k,
+ "v": v,
+ "rab": rab,
+ "attn_mask": attn_mask,
+ "grad": grad,
+ }
+ return params
+
+
+def save_mask(matrix, title="matrix"):
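+    """Render a 2-D mask matrix as a heatmap and save it to <title>.png."""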
+ cmap = mcolors.LinearSegmentedColormap.from_list(
+ "CustomMap", [(1, 1, 1), (0.8, 0.902, 0.8)], N=256
+ )
+
+ max_pixels = 65536
+ dpi = 100
+
+ width_inches = min(12, max(6, matrix.shape[1] * 0.05))
+ height_inches = min(12, max(6, matrix.shape[0] * 0.05))
+
+ width_pixels = width_inches * dpi
+ height_pixels = height_inches * dpi
+
+ if width_pixels > max_pixels or height_pixels > max_pixels:
+ scale_factor = min(max_pixels / width_pixels, max_pixels / height_pixels)
+ width_inches *= scale_factor
+ height_inches *= scale_factor
+
+ fig, ax = plt.subplots(figsize=(width_inches, height_inches))
+
+ img = ax.imshow(
+ matrix, cmap=cmap, origin="lower", vmin=0, vmax=1, interpolation="nearest"
+ )
+
+ ax.xaxis.set_ticks_position("top")
+ ax.yaxis.set_ticks_position("left")
+ ax.invert_yaxis()
+
+ plt.colorbar(img)
+ plt.title(title)
+ plt.xlabel("seq_len")
+ plt.ylabel("seq_len")
+ plt.savefig(f"{title}.png", bbox_inches="tight", dpi=dpi)
+ plt.close(fig)
+ logger.info(f"save {title}.png")
+
+
+PARAM_META = {
+ # Tensor parameters (name: (type, required, description))
+ "l_q": ("Tensor", True, "Query sequence length"),
+ "l_k": ("Tensor", True, "Key sequence length"),
+ "num_contexts": ("Tensor", True, "Number of contexts"),
+ "seq_offsets_q_wt": ("Tensor", True, "Query sequence weight offsets"),
+ "seq_offsets_k_wt": ("Tensor", True, "Key sequence weight offsets"),
+ "num_targets": ("Tensor", True, "Number of targets"),
+ "q": ("Tensor", True, "Query tensor"),
+ "k": ("Tensor", True, "Key tensor"),
+ "v": ("Tensor", True, "Value tensor"),
+ "rab": ("Tensor", True, "Relative attention bias"),
+ "attn_mask": ("Tensor", True, "Attention mask"),
+ "grad": ("Tensor", True, "Gradient tensor"),
+ # Configuration parameters
+ "dtype": ("Config", True, "Data type"),
+ "max_context_len": ("Config", True, "Max context length"),
+ "max_seq_len_q": ("Config", True, "Max query sequence length"),
+ "max_target_len": ("Config", True, "Max target length"),
+ "alpha": ("Config", True, "silu scale tensor"),
+}
+
+
+def _get_save_paths(save_dir: str) -> Dict[str, str]:
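+    """Map every PARAM_META entry to its .pth file path under save_dir."""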
+ return {
+ name: os.path.normpath(os.path.join(save_dir, f"{name}.pth"))
+ for name in PARAM_META
+ if not name.startswith("_")
+ }
+
+
+def save_params(save_dir=DATASETS, **kwargs):
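+    """Validate required parameters, save each to disk, and write a completion flag."""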
+ # Normalize path
+ save_dir = os.path.normpath(save_dir)
+ os.makedirs(save_dir, exist_ok=True)
+
+ # Validate required parameters
+ missing = [
+ name
+ for name, (_, required, _) in PARAM_META.items()
+ if required and name not in kwargs
+ ]
+ if missing:
+ logger.error("Missing required parameters: %s", ", ".join(missing))
+ raise ValueError("Required parameters missing: %s" % ", ".join(missing))
+
+ # Get all save paths
+ paths = _get_save_paths(save_dir)
+ image_name = "image_name"
+ # Save parameters with logging
+ logger.info("\n%s", "=" * 50)
+ logger.info("[SAVE] Target directory: %s", save_dir)
+ for name, path in paths.items():
+ if name in kwargs:
+ torch.save(kwargs[name], path)
+ param = kwargs[name]
+ log_msg = "%s: " % name.ljust(15)
+ if torch.is_tensor(param):
+ log_msg += "shape=%s | dtype=%s" % (
+ str(param.shape).ljust(18),
+ param.dtype,
+ )
+ else:
+ log_msg += str(param)
+ logger.info("%s -> %s", log_msg, path)
+
+    # Save attention matrix image if provided (attn_mask is None when no
+    # sliding window is configured)
+    if kwargs.get("attn_mask") is not None and image_name in kwargs:
+        mask = kwargs["attn_mask"]
+        if mask.dim() == 4:
+            save_matrix = mask[0, 0].cpu()
+        elif mask.dim() == 3:
+            save_matrix = mask[0].cpu()
+        else:
+            save_matrix = mask.cpu()
+
+        img_path = os.path.join(save_dir, kwargs[image_name])
+        save_mask(save_matrix, img_path)
+
+ # Create completion flag
+ flag_path = os.path.join(save_dir, "complete.flag")
+ torch.save(torch.tensor(1), flag_path)
+ logger.info("%s\n[SUCCESS] All parameters saved\n%s", "=" * 50, "=" * 50)
+
+
+def load_params(save_dir=DATASETS, device="cpu") -> Dict[str, object]:
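+    """Load all saved parameters from save_dir onto the requested device."""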
+ save_dir = os.path.normpath(save_dir)
+ paths = _get_save_paths(save_dir)
+
+ # Validate directory integrity
+ if not os.path.exists(paths["l_q"]):
+ logger.error("Invalid parameter directory: %s", save_dir)
+ raise FileNotFoundError("Invalid parameter directory: %s" % save_dir)
+
+ # Load parameters with logging (lazy interpolation)
+ logger.info("\n%s", "=" * 50)
+ logger.info("[LOAD] Source directory: %s", save_dir)
+ params = {}
+ for name, path in paths.items():
+ if os.path.exists(path):
+ params[name] = torch.load(path, map_location=device)
+ log_msg = "%s: " % name.ljust(15)
+ if torch.is_tensor(params[name]):
+ log_msg += "shape=%s | dtype=%s" % (
+ str(params[name].shape).ljust(18),
+ params[name].dtype,
+ )
+ else:
+ log_msg += str(params[name])
+ logger.info("%s <- %s", log_msg, path)
+
+ logger.info("%s\n[SUCCESS] All parameters loaded\n%s", "=" * 50, "=" * 50)
+ return params
+
+
+def create_and_save_params(bx):
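+    """Read benchmark row bx, generate its input tensors, and save them to disk."""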
+ df, bparams = read_and_validate_parameters(bx)
+ image_name = "image_name"
+ bparams[image_name] = f"{bx}_{df.shape_info.item()}"
+ init_result_csv_index(bx)
+ df_resg = pd.read_csv(result_csv)
+ logger.info(df_resg[df_resg[INDEX_STR] == bx])
+ gene_param_dict = {k: bparams[k] for k in generate_params}
+ iparams = generate_input(**gene_param_dict)
+ iparams[image_name] = bparams[image_name]
+
+ extra_params = [
+ "dtype",
+ "max_context_len",
+ "max_seq_len_q",
+ "max_target_len",
+ "alpha",
+ ]
+ for param_name in extra_params:
+ if param_name not in iparams:
+ iparams[param_name] = bparams[param_name]
+
+ save_params(**iparams)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Read CSV file and run a specific index benchmark"
+ )
+ parser.add_argument(
+ "--index", type=int, required=True, help="index of the benchmark to run"
+ )
+ args = parser.parse_args()
+ benchmark_df = pd.read_csv(benchmark_csv)
+ benchmark_df[INDEX_STR] = benchmark_df[INDEX_STR].astype(int)
+ all_indices = benchmark_df[INDEX_STR].tolist()
+    if args.index is not None:
+        all_indices = [args.index]
+    for bxx in all_indices:
+        create_and_save_params(bxx)
+        # Reload to verify the saved parameters round-trip
+        loaded = load_params()