Sign in
Sign up
Explore
Enterprise
Education
Search
Help
Terms of use
About Us
Explore
Enterprise
Education
Gitee Premium
Gitee AI
Sign in
Sign up
Fetch the repository succeeded.
Open Source
>
AI/ML
>
Artificial Intelligence
&&
Donate
Please sign in before you donate.
Cancel
Sign in
Scan WeChat QR to Pay
Cancel
Complete
Prompt
Switch to Alipay.
OK
Cancel
Watch
Unwatch
Watching
Releases Only
Ignoring
100
Star
1.3K
Fork
908
GVP
MindSpore
/
mindformers
Code
Issues
100
Pull Requests
144
Wiki
Insights
Pipelines
Service
Quality Analysis
Jenkins for Gitee
Tencent CloudBase
Tencent Cloud Serverless
悬镜安全
Aliyun SAE
Codeblitz
SBOM
Don’t show this again
Update failed. Please try again later!
Remove this flag
Content Risk Flag
This task is identified by
as the content contains sensitive information such as code security bugs, privacy leaks, etc., so it is only accessible to contributors of this repository.
【动静一致】MindSpore Transformers支持动静一致的大模型训推需求
TODO
#ICY7S6
RFC
suhaibo
member
Opened this issue
2025-09-16 14:29
### 背景信息 目前业界主流框架,例如torch2.x版本,在持续演进torch.compile的图模式,提供高性能的训推方案,并保持100%的前向兼容;torch.distribute的原生分布式能力也逐步从DDP发到FSDP2,将分布式训练能力逐步往框架内嵌,且动静图模式均可使用。从AI框架发展来看,提供动静一致的分布式训推能力成为趋势。 基于上述能力,目前业界也在发展新的模型训推开发工具库(pytorch官方的torchtitan、字节的Veomni等),其实现逻辑是将模型结构实现和并行算法进行分离,其大大简化了模型构建的成本。而并行优化可以不侵入式修改模型结构,设计上实现了各个组件之间的解耦。与Megatron-LM相比,其扩展性大大增强,Megatron-LM是将并行算法与模型实现耦合,并行调度强感知模型结构,拓展开发的难度高。 MindSpore Transformers从开源到当前版本,都是基于静态图模式构建的大模型训推解决方案,目前在图模式上遇到了如下关键问题: - 调试困难,整网整图,无法分段调试 - 并行算法黑盒(tensor parallel、pipeline parallel、zero optimizer等),例如DeepSeekV3引入的dualpp特性,需要依赖MindSpore开发完成,套件和用户无法自主完成 - 动态Shape、控制流支持度不够,无法满足多模态模型乃至后面的全模态模型训练诉求 综上,为解决MindSpore Transformers遇到的关键问题,结合业界的最新动态,计划联合MindSpore框架构建动静一致的分布式技术,目标如下: - 保持模型实现编程和分布式并行算法解耦,提升算法的可读性、可扩展性 - 支持pynative执行(动态图),并可以通过jit入图,转成图模式执行 - 支持控制流及动态shape - 并行算法实现上移至Python层,预置tensor parallel、pipeline parallel、zero optimizer等分布式算法;并支持用户扩展 ### 需求范围 目前在第一阶段,主要集中在基础能力的改造和补齐,完成动静态统一接口改造,本阶段MindSpore Transformers主要需求规划如下: - 适配动静一致编程范式的并行算法:tensor parallel、pipeline parallel、zero optimizer以及Swap&recompute特性 - 适配动静一致编程范式,逐步替换为函数式算子调用,对齐torch、jax等主流框架的调用方式 - 多模态基础接口:补齐多模态理解、全模态等场景接口 ### 设计方案 #### MindSpore动静一致的设计思路 当前MindSpore计划构建一套动静统一的分布式并行技术,目标如下: · 保持单卡逻辑编程,非侵入式修改单卡算法实现分布式并行,提升算法可读性; · 支持pynative执行,通过图编译接口转成整图执行; · 提供三层编程范式,手动、dtensor、shard; · 支持自定义并行与框架原生并行混合编程; · 支持控制流及动态shape; · 支持MPMD图切分及调度; 整体架构图如下:  #### MindSpore Transformers的parallel core改造 1. tensor parallel:声明式编程范式,按需配置需要重排布的张量,无需全局所有算子进行切分配置 张量并行的核心在于切分策略从模型的输入进行整网的推导,站在每一个算子的视角,即是依据算子输入张量的切分策略结合算子的数学逻辑给出输出张量的切分策略。因此,张量并行的核心是针对算子进行分类与逐类型的建模。 为了对显存进行优化,在特定的位置(特定的Cell)用户可以指定期望的切分策略,如若与推导得到的切分策略不一致,那么将进行重排布的通信操作,如短序列并行场景对norm处进行切分并且以AllGather汇聚给到Linear。 动态图、静态图整体实现有较大差异,因此对动态图的并行流程与静态图的并行流程分别进行设计。但是总体设计思路相同,用户配置切分策略仅对tensor进行配置,不对算子进行配置。对Cell的配置本质上也是配置在Cell的输入与输出的tensor上,当对于模型中间的某个tensor,自上游传播下来的切分策略与用户指定的切分策略不一致时,将触发重排布。整体设计流程图如下。  1、`RowParallelLinear`行切的linear为例,其实现的数学公式是`Y = XA + b`,其中`A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]`,编程的变化如下: - 纯静态图模式下,是针对每个算子进行配置,具体如下: ``` python def shard(self, config: TransformerConfig) -> None: """Shard the operators in RowParallelLinear. Args: config (TransformerConfig): The config of the transformer model. """ dp = config.data_parallel_size if config.data_parallel_size is not None else 1 cp = config.context_parallel_size if config.context_parallel_size is not None else 1 tp = config.tensor_model_parallel_size if config.tensor_model_parallel_size is not None else 1 matmul_in_strategy = ((dp * cp, tp), (tp, 1)) self.matmul.shard(in_strategy=matmul_in_strategy) if self.transpose_b: self.transpose.shard(((1, tp),)) if not self.skip_bias_add: add_in_strategy = ((dp * cp, 1), (1,)) self.add.shard(in_strategy=add_in_strategy) ``` - 动静一致新范式下,所有的切分是站Tensor维度去进行传播,只需要配置输入输出以及权重的layout信息即可,具体如下: ```python def shard(self, config: TransformerConfig) -> None: input_layout = layout("cp_dp", "tp") output_layout = layout("cp_dp", "None") weight_layout = layout("tp", "None") return input_layout, output_layout, weight_layout ``` 2、`ColumnParallelLinear`列切的linear为例,其实现的数学公式是`Y = XA + b`,其中权重A = [A_1, ..., A_p],其本质对于每一个token的计算都是独立无需插入通信重排布;编程的变化如下: - 纯静态图模式下,仍需要针对每个算子切分进行配置: ```python def shard(self, config: TransformerConfig) -> None: """Shard the operators in ColumnParallelLinear. Args: config (TransformerConfig): The config of the transformer model. """ dp = config.data_parallel_size if config.data_parallel_size is not None else 1 tp = config.tensor_model_parallel_size if config.tensor_model_parallel_size is not None else 1 cp = config.context_parallel_size if config.context_parallel_size is not None else 1 matmul_in_strategy = ((dp * cp, 1), (1, tp)) self.matmul.shard(in_strategy=matmul_in_strategy) if not self.skip_bias_add: add_in_strategy = ((dp * cp, tp), (tp,)) self.add.shard(in_strategy=add_in_strategy) ``` - 动静一致模式下,针对不需要进行重排布的Cell层,则可以依赖框架的策略传播模式,进行自动推导tensor的layout切分,具体如下: ```python def shard(self, config: TransformerConfig) -> None: weight_layout = layout("None", "tp") return None, None, weight_layout ``` 综上,对于动静一致,简化了配置化编程的难度,与具体的并行切分算法对其,只需要在插入通信算子处进行layout切分的配置,并利用框架的自动传播机制进行切分广播实现整网切分。 2. pipeline parallel: 将pipeline stage切分与pipeline scheduler实现开放给用户: - pp 切分 模型切分这部分需要在用户侧脚本执行完成,具体地,用户需要考虑每个 stage 需要计算的子模型,接着对这部分子模型进行实例化。参考脚本可如下所示: ``` python class Transformer(nn.Cell): def init (self, num layers) super(). init () self.embedding = nn.Embedding(...) self.layers = nn.CellDict() for layer id in range(num layers): self.layers[str(layer id)] = TransformerBlock(...) self.output = nn.Linear(...) def construct(self, input ids): h = self.embedding(input ids) if self.embedding else input id for layer in self.layers.values(): h= layer(h) output = self.output(h) if self.output else ha return output def model split manual(model, stage index): ""假设model共有8层,共2个stage,每个stage有4层om if stage index == O: for i in range(4, 8): del model.layers[str(i)] model.output = None else: for i range(4): del model.layers[str(i)] model.embedding = None ``` 此处对于MindSpore Transformers来说,只需要在模型实例化时,完成对应层的实例化控制即可。具体形式会与Megatron-LM进行参考对齐,方便大家理解与修改: ```python def build_layer(layer_spec, layer_number): global_layer_number = layer_number + get_transformer_layer_offset( self.config, self.vp_stage ) # 1-based index if self.config.heterogeneous_block_specs: layer_config = self.config.get_config_for_layer(global_layer_number) else: layer_config = self.config fp8_init_context = get_fp8_context(layer_config, global_layer_number - 1, is_init=True) with fp8_init_context: module = build_module( layer_spec, config=layer_config, layer_number=layer_number, model_comm_pgs=self.model_comm_pgs, vp_stage=self.vp_stage, ) return module ``` - pp 调度 流水线并行调度通常包括MicroBatch的切分,MicroBatch的执行,以及输出output的处理。MindSpore提供了PipelineScheduleBase基类,其中内置好了MicroBatch的切分,并可指定输入的BatchDim,如果MicroBatch的切分不满足用户的需求,用户可进行重载。此外,PipelineScheduleBase的子类需要完成run_microbatches方法以及run方法。 为了方便用户自定义执行序调度,PipelineSchedule提供了一套自定义调度的接口,主要包括MetaStep以及MetaStepType,其中MetaStep是流水线并行调度的元操作,MetaStepType是枚举类型,指定了当前MetaStep的操作类型,包含有:FWD(正向)、BWD(反向)、FWD_RECV(正向recv)等, 具体Gpipe样例如下: ``` python class ScheduleGPipe(PipelineScheduleSingle): """ The Gpipe schedule. It first executes all forward micro batches and then execute all backward micro batches. """ def construct_exec_order(self): for stage_index in range(self.stage.stage_num): order_list = [] for mb_index in range(self.micro_batch_num): if stage_index != 0: order_list.append(MetaStep(mb_index, MetaStepType.FWD_RECV, stage_index)) order_list.append(MetaStep(mb_index, MetaStepType.FWD, stage_index)) if stage_index != self.stage.stage_num - 1: order_list.append(MetaStep(mb_index, MetaStepType.FWD_SEND, stage_index)) for mb_index in range(self.micro_batch_num): if stage_index != self.stage.stage_num - 1: order_list.append(MetaStep(mb_index, MetaStepType.BWD_RECV, stage_index)) order_list.append(MetaStep(mb_index, MetaStepType.BWD, stage_index)) if stage_index != 0: order_list.append(MetaStep(mb_index, MetaStepType.BWD_SEND, stage_index)) self.exec_order[stage_index] = order_list ``` MindSpore框架会内置GPipe,1F1B, VPP等常用调度,也支持用户利用调度接口进行自定义调度。 MindSpore Transformers第一阶段会提供不同的PP Scheduler的配置供用户选择,以此来实现调度算法的选择。后续会针对ZBV等算法进行实现支持。 3. zero optimizer(hsdp) 与原有的配置化方式的变化,当前MindSpore会提供类似torch fsdp的api方式提供一个对外的hsdp api,供用户实现zero optimizer的切分功能。具体接口设计如下: ```python hsdp( cell, # 网络模块 shard_size=1, # 参数dp域切分份数,1为不切分,即纯数据并行 threshold=64, # 参数切分阈值,单位KB,默认不超过64KB的不切分 optimizer_level="level1", # 切分方式,level1、level2、level3对应zero1、zero2、zero3 accumulate_grad_step=1 # 梯度累积次数,用来求梯度均值 ) ``` 其中cell就是我们在MindFormers中定义的模型类。 4. swap & recompute 动静一致的设计中,逐步是用的是函数式接口调用,函数式与以前的primitive的类调用相比,缺少了类属性的概念,无法记录哪些算子的输出结果需要被保存或者异步拷贝到cpu 内存中? 以重计算为例:在Megatron-LM中的设计是在forward接口中去在每个需要重算的位置调用checkpoint接口进行保存,或者针对一个结构的每一种重算策略都要开发一个重算的策略。其维护和开发工作量较大,灵活度低。而在新兴的torchtitan中,meta给出的是基于一类算子或者同一类算子的特定shape的算子可以设置细粒度的重算,仍然无法解决用户可以任意自定义设置算子级重算的属性。 为了解决这个问题,在MindFormers中我们将在每一个Cell中定义一个算子申明,用于匹配任意算子信息(用户可以按照字符串设置重算算子信息),算子变量名信息和用户设定信息匹配的则在MindFormers侧自动封装一个重算Cell进行重算。具体伪代码如下: 
### 背景信息 目前业界主流框架,例如torch2.x版本,在持续演进torch.compile的图模式,提供高性能的训推方案,并保持100%的前向兼容;torch.distribute的原生分布式能力也逐步从DDP发到FSDP2,将分布式训练能力逐步往框架内嵌,且动静图模式均可使用。从AI框架发展来看,提供动静一致的分布式训推能力成为趋势。 基于上述能力,目前业界也在发展新的模型训推开发工具库(pytorch官方的torchtitan、字节的Veomni等),其实现逻辑是将模型结构实现和并行算法进行分离,其大大简化了模型构建的成本。而并行优化可以不侵入式修改模型结构,设计上实现了各个组件之间的解耦。与Megatron-LM相比,其扩展性大大增强,Megatron-LM是将并行算法与模型实现耦合,并行调度强感知模型结构,拓展开发的难度高。 MindSpore Transformers从开源到当前版本,都是基于静态图模式构建的大模型训推解决方案,目前在图模式上遇到了如下关键问题: - 调试困难,整网整图,无法分段调试 - 并行算法黑盒(tensor parallel、pipeline parallel、zero optimizer等),例如DeepSeekV3引入的dualpp特性,需要依赖MindSpore开发完成,套件和用户无法自主完成 - 动态Shape、控制流支持度不够,无法满足多模态模型乃至后面的全模态模型训练诉求 综上,为解决MindSpore Transformers遇到的关键问题,结合业界的最新动态,计划联合MindSpore框架构建动静一致的分布式技术,目标如下: - 保持模型实现编程和分布式并行算法解耦,提升算法的可读性、可扩展性 - 支持pynative执行(动态图),并可以通过jit入图,转成图模式执行 - 支持控制流及动态shape - 并行算法实现上移至Python层,预置tensor parallel、pipeline parallel、zero optimizer等分布式算法;并支持用户扩展 ### 需求范围 目前在第一阶段,主要集中在基础能力的改造和补齐,完成动静态统一接口改造,本阶段MindSpore Transformers主要需求规划如下: - 适配动静一致编程范式的并行算法:tensor parallel、pipeline parallel、zero optimizer以及Swap&recompute特性 - 适配动静一致编程范式,逐步替换为函数式算子调用,对齐torch、jax等主流框架的调用方式 - 多模态基础接口:补齐多模态理解、全模态等场景接口 ### 设计方案 #### MindSpore动静一致的设计思路 当前MindSpore计划构建一套动静统一的分布式并行技术,目标如下: · 保持单卡逻辑编程,非侵入式修改单卡算法实现分布式并行,提升算法可读性; · 支持pynative执行,通过图编译接口转成整图执行; · 提供三层编程范式,手动、dtensor、shard; · 支持自定义并行与框架原生并行混合编程; · 支持控制流及动态shape; · 支持MPMD图切分及调度; 整体架构图如下:  #### MindSpore Transformers的parallel core改造 1. tensor parallel:声明式编程范式,按需配置需要重排布的张量,无需全局所有算子进行切分配置 张量并行的核心在于切分策略从模型的输入进行整网的推导,站在每一个算子的视角,即是依据算子输入张量的切分策略结合算子的数学逻辑给出输出张量的切分策略。因此,张量并行的核心是针对算子进行分类与逐类型的建模。 为了对显存进行优化,在特定的位置(特定的Cell)用户可以指定期望的切分策略,如若与推导得到的切分策略不一致,那么将进行重排布的通信操作,如短序列并行场景对norm处进行切分并且以AllGather汇聚给到Linear。 动态图、静态图整体实现有较大差异,因此对动态图的并行流程与静态图的并行流程分别进行设计。但是总体设计思路相同,用户配置切分策略仅对tensor进行配置,不对算子进行配置。对Cell的配置本质上也是配置在Cell的输入与输出的tensor上,当对于模型中间的某个tensor,自上游传播下来的切分策略与用户指定的切分策略不一致时,将触发重排布。整体设计流程图如下。  1、`RowParallelLinear`行切的linear为例,其实现的数学公式是`Y = XA + b`,其中`A = transpose([A_1 .. A_p]) X = [X_1, ..., X_p]`,编程的变化如下: - 纯静态图模式下,是针对每个算子进行配置,具体如下: ``` python def shard(self, config: TransformerConfig) -> None: """Shard the operators in RowParallelLinear. Args: config (TransformerConfig): The config of the transformer model. """ dp = config.data_parallel_size if config.data_parallel_size is not None else 1 cp = config.context_parallel_size if config.context_parallel_size is not None else 1 tp = config.tensor_model_parallel_size if config.tensor_model_parallel_size is not None else 1 matmul_in_strategy = ((dp * cp, tp), (tp, 1)) self.matmul.shard(in_strategy=matmul_in_strategy) if self.transpose_b: self.transpose.shard(((1, tp),)) if not self.skip_bias_add: add_in_strategy = ((dp * cp, 1), (1,)) self.add.shard(in_strategy=add_in_strategy) ``` - 动静一致新范式下,所有的切分是站Tensor维度去进行传播,只需要配置输入输出以及权重的layout信息即可,具体如下: ```python def shard(self, config: TransformerConfig) -> None: input_layout = layout("cp_dp", "tp") output_layout = layout("cp_dp", "None") weight_layout = layout("tp", "None") return input_layout, output_layout, weight_layout ``` 2、`ColumnParallelLinear`列切的linear为例,其实现的数学公式是`Y = XA + b`,其中权重A = [A_1, ..., A_p],其本质对于每一个token的计算都是独立无需插入通信重排布;编程的变化如下: - 纯静态图模式下,仍需要针对每个算子切分进行配置: ```python def shard(self, config: TransformerConfig) -> None: """Shard the operators in ColumnParallelLinear. Args: config (TransformerConfig): The config of the transformer model. """ dp = config.data_parallel_size if config.data_parallel_size is not None else 1 tp = config.tensor_model_parallel_size if config.tensor_model_parallel_size is not None else 1 cp = config.context_parallel_size if config.context_parallel_size is not None else 1 matmul_in_strategy = ((dp * cp, 1), (1, tp)) self.matmul.shard(in_strategy=matmul_in_strategy) if not self.skip_bias_add: add_in_strategy = ((dp * cp, tp), (tp,)) self.add.shard(in_strategy=add_in_strategy) ``` - 动静一致模式下,针对不需要进行重排布的Cell层,则可以依赖框架的策略传播模式,进行自动推导tensor的layout切分,具体如下: ```python def shard(self, config: TransformerConfig) -> None: weight_layout = layout("None", "tp") return None, None, weight_layout ``` 综上,对于动静一致,简化了配置化编程的难度,与具体的并行切分算法对其,只需要在插入通信算子处进行layout切分的配置,并利用框架的自动传播机制进行切分广播实现整网切分。 2. pipeline parallel: 将pipeline stage切分与pipeline scheduler实现开放给用户: - pp 切分 模型切分这部分需要在用户侧脚本执行完成,具体地,用户需要考虑每个 stage 需要计算的子模型,接着对这部分子模型进行实例化。参考脚本可如下所示: ``` python class Transformer(nn.Cell): def init (self, num layers) super(). init () self.embedding = nn.Embedding(...) self.layers = nn.CellDict() for layer id in range(num layers): self.layers[str(layer id)] = TransformerBlock(...) self.output = nn.Linear(...) def construct(self, input ids): h = self.embedding(input ids) if self.embedding else input id for layer in self.layers.values(): h= layer(h) output = self.output(h) if self.output else ha return output def model split manual(model, stage index): ""假设model共有8层,共2个stage,每个stage有4层om if stage index == O: for i in range(4, 8): del model.layers[str(i)] model.output = None else: for i range(4): del model.layers[str(i)] model.embedding = None ``` 此处对于MindSpore Transformers来说,只需要在模型实例化时,完成对应层的实例化控制即可。具体形式会与Megatron-LM进行参考对齐,方便大家理解与修改: ```python def build_layer(layer_spec, layer_number): global_layer_number = layer_number + get_transformer_layer_offset( self.config, self.vp_stage ) # 1-based index if self.config.heterogeneous_block_specs: layer_config = self.config.get_config_for_layer(global_layer_number) else: layer_config = self.config fp8_init_context = get_fp8_context(layer_config, global_layer_number - 1, is_init=True) with fp8_init_context: module = build_module( layer_spec, config=layer_config, layer_number=layer_number, model_comm_pgs=self.model_comm_pgs, vp_stage=self.vp_stage, ) return module ``` - pp 调度 流水线并行调度通常包括MicroBatch的切分,MicroBatch的执行,以及输出output的处理。MindSpore提供了PipelineScheduleBase基类,其中内置好了MicroBatch的切分,并可指定输入的BatchDim,如果MicroBatch的切分不满足用户的需求,用户可进行重载。此外,PipelineScheduleBase的子类需要完成run_microbatches方法以及run方法。 为了方便用户自定义执行序调度,PipelineSchedule提供了一套自定义调度的接口,主要包括MetaStep以及MetaStepType,其中MetaStep是流水线并行调度的元操作,MetaStepType是枚举类型,指定了当前MetaStep的操作类型,包含有:FWD(正向)、BWD(反向)、FWD_RECV(正向recv)等, 具体Gpipe样例如下: ``` python class ScheduleGPipe(PipelineScheduleSingle): """ The Gpipe schedule. It first executes all forward micro batches and then execute all backward micro batches. """ def construct_exec_order(self): for stage_index in range(self.stage.stage_num): order_list = [] for mb_index in range(self.micro_batch_num): if stage_index != 0: order_list.append(MetaStep(mb_index, MetaStepType.FWD_RECV, stage_index)) order_list.append(MetaStep(mb_index, MetaStepType.FWD, stage_index)) if stage_index != self.stage.stage_num - 1: order_list.append(MetaStep(mb_index, MetaStepType.FWD_SEND, stage_index)) for mb_index in range(self.micro_batch_num): if stage_index != self.stage.stage_num - 1: order_list.append(MetaStep(mb_index, MetaStepType.BWD_RECV, stage_index)) order_list.append(MetaStep(mb_index, MetaStepType.BWD, stage_index)) if stage_index != 0: order_list.append(MetaStep(mb_index, MetaStepType.BWD_SEND, stage_index)) self.exec_order[stage_index] = order_list ``` MindSpore框架会内置GPipe,1F1B, VPP等常用调度,也支持用户利用调度接口进行自定义调度。 MindSpore Transformers第一阶段会提供不同的PP Scheduler的配置供用户选择,以此来实现调度算法的选择。后续会针对ZBV等算法进行实现支持。 3. zero optimizer(hsdp) 与原有的配置化方式的变化,当前MindSpore会提供类似torch fsdp的api方式提供一个对外的hsdp api,供用户实现zero optimizer的切分功能。具体接口设计如下: ```python hsdp( cell, # 网络模块 shard_size=1, # 参数dp域切分份数,1为不切分,即纯数据并行 threshold=64, # 参数切分阈值,单位KB,默认不超过64KB的不切分 optimizer_level="level1", # 切分方式,level1、level2、level3对应zero1、zero2、zero3 accumulate_grad_step=1 # 梯度累积次数,用来求梯度均值 ) ``` 其中cell就是我们在MindFormers中定义的模型类。 4. swap & recompute 动静一致的设计中,逐步是用的是函数式接口调用,函数式与以前的primitive的类调用相比,缺少了类属性的概念,无法记录哪些算子的输出结果需要被保存或者异步拷贝到cpu 内存中? 以重计算为例:在Megatron-LM中的设计是在forward接口中去在每个需要重算的位置调用checkpoint接口进行保存,或者针对一个结构的每一种重算策略都要开发一个重算的策略。其维护和开发工作量较大,灵活度低。而在新兴的torchtitan中,meta给出的是基于一类算子或者同一类算子的特定shape的算子可以设置细粒度的重算,仍然无法解决用户可以任意自定义设置算子级重算的属性。 为了解决这个问题,在MindFormers中我们将在每一个Cell中定义一个算子申明,用于匹配任意算子信息(用户可以按照字符串设置重算算子信息),算子变量名信息和用户设定信息匹配的则在MindFormers侧自动封装一个重算Cell进行重算。具体伪代码如下: 
Comments (
0
)
Sign in
to comment
Status
TODO
VALIDATION
CLOSED
TODO
ACCEPTED
WIP
DONE
REJECTED
Assignees
Not set
Labels
Not set
Projects
Unprojected
Unprojected
Milestones
No related milestones
No related milestones
Pull Requests
None yet
None yet
Successfully merging a pull request will close this issue.
Branches
No related branch
Branches (38)
Tags (22)
master
r1.7.0-beta3
r1.6.0
br_feature_infer
r1.7.0-beta1
br_infer_boom
revert-3cfdd0a
dev
br_infer_deepseek_os
r1.5.0
br_feature_checkpoint
br_feature_infer_300iduo
br_feature_mcore
r1.6.0-beta1
br_infer_deepseek_ep
br_feature_rl_dpo
r1.3.0
r1.3.1
r1.4.0-beta2
r1.4.0-beta1
r1.5.0-beta1
r1.2.0
r1.1.0
r1.1.0-infer
r1.1.rc1
r1.0
kbk-infer
r1.0.a
r0.8
r0.7
r0.6.1_demo
r0.6
0.6rc1
r0.3
r0.2
v0.1.2
v0.1.1
v0.1.0
v1.7.0-beta3
v1.7.0-beta2
v1.6.0
v1.6.0-beta1
v1.5.0
v1.5.0-beta2
v1.5.0-beta1
v1.4.0-beta2
v1.3.2
v1.3.1-beta1
v1.4.0-beta1
v1.3.0
v1.2.0
v1.1.0
v1.0.2
v1.0.1
v1.0.0
v0.6.0
v0.3
v0.2_rc
v0.1.1
v0.1.0
Planed to start   -   Planed to end
-
Top level
Not Top
Top Level: High
Top Level: Medium
Top Level: Low
Priority
Not specified
Serious
Main
Secondary
Unimportant
Duration
(hours)
参与者(1)
Python
1
https://gitee.com/mindspore/mindformers.git
[email protected]
:mindspore/mindformers.git
mindspore
mindformers
mindformers
Going to Help Center
Search
Git 命令在线学习
如何在 Gitee 导入 GitHub 仓库
Git 仓库基础操作
企业版和社区版功能对比
SSH 公钥设置
如何处理代码冲突
仓库体积过大,如何减小?
如何找回被删除的仓库数据
Gitee 产品配额说明
GitHub仓库快速导入Gitee及同步更新
什么是 Release(发行版)
将 PHP 项目自动发布到 packagist.org
Comment
Repository Report
Back to the top
Login prompt
This operation requires login to the code cloud account. Please log in before operating.
Go to login
No account. Register