From a8a615abd202f2429fd0765466b5ebaa00a83f26 Mon Sep 17 00:00:00 2001 From: niu-junhao01 Date: Mon, 22 Sep 2025 21:06:33 +0800 Subject: [PATCH] refact build dataset process in base trainer. --- mindformers/trainer/base_trainer.py | 430 +++++++++++------- .../test_deepseek3/run_deepseek3.py | 1 + 2 files changed, 261 insertions(+), 170 deletions(-) diff --git a/mindformers/trainer/base_trainer.py b/mindformers/trainer/base_trainer.py index 91d7c02df..458349a15 100644 --- a/mindformers/trainer/base_trainer.py +++ b/mindformers/trainer/base_trainer.py @@ -171,7 +171,6 @@ class BaseTrainer: build_parallel_config(self.config) if os.environ.get("RUN_MODE") != "predict": self._check_grad_accumulation_steps() - self._check_global_batch_size_for_auto_parallel() self._reset_wrapper() return self.config @@ -189,105 +188,6 @@ class BaseTrainer: "you must define the required model, optimizer, and so on" "in the train or evaluate or predict attribute function.") - def _check_global_batch_size_for_auto_parallel(self): - """Check global batch size in auto parallel mode.""" - batch_size = self.config.runner_config.batch_size - self.config.runner_config.mini_batch_size = batch_size - gradient_accumulation_steps = self.config.runner_config.gradient_accumulation_steps - dp = self.config.parallel_config.data_parallel - micro_batch_num = self.config.parallel_config.micro_batch_num - micro_batch_interleave_num = self.config.micro_batch_interleave_num - parallel_mode = ms.get_auto_parallel_context("parallel_mode") - full_batch = ms.get_auto_parallel_context("full_batch") - ds_stra = ms.get_auto_parallel_context("dataset_strategy") - pp = self.get_pipeline_stages() - - if parallel_mode in ["semi_auto_parallel", "auto_parallel"]: - if pp == 1 and micro_batch_num > 1: - logger.warning("When pipeline parallel is not enabled, " - "config.parallel_config.micro_batch_num does not take effect. 
Reset it to 1.") - micro_batch_num = self.config.parallel_config.micro_batch_num = 1 - if full_batch: - if ds_stra != 'full_batch': - logger.warning(f"full_batch=True only supports dataset_strategy='full_batch', " - f"reset dataset_strategy {ds_stra} to 'full_batch'.") - ms.set_auto_parallel_context(dataset_strategy='full_batch') - - if pp > 1: - self.global_batch_size = batch_size * dp * micro_batch_num * micro_batch_interleave_num - logger.info("Pipeline parallel was opened: pipeline_stages = %s, full batch is True, " - "gradient_accumulation_steps will not take effect in pipeline parallel, " - "global batch size will be changed: " - "global_batch_size = " - "batch_size * data_parallel * micro_batch_num * micro_batch_interleave_num " - "= %s = %s * %s * %s * %s).", - pp, self.global_batch_size, batch_size, dp, micro_batch_num, - micro_batch_interleave_num) - self.config.runner_config.batch_size = self.global_batch_size - self._reset_wrapper_for_pipeline_parallel() - else: - self.global_batch_size = batch_size * dp * micro_batch_interleave_num * gradient_accumulation_steps - logger.info("The current parallel mode is %s, full batch is True," - "so global batch size will be changed: " - "global_batch_size = batch_size * data_parallel * micro_batch_interleave_num " - "* gradient_accumulation_steps = %s = %s * %s * %s * %s", - parallel_mode, self.global_batch_size, batch_size, dp, micro_batch_interleave_num, - gradient_accumulation_steps) - self.config.runner_config.batch_size = self.global_batch_size - else: # full_batch = False - if not isinstance(ds_stra, (tuple, list)): - raise ValueError("If set full_batch=False, dataset_strategy must be set as 'tuple', " - "such as ((dp, 1), ).") - ds_stra_dp = ds_stra[0][0] - if dp != ds_stra_dp: - raise ValueError(f"data_parallel {dp} should be equal to dataset_strategy[0][0] {ds_stra_dp}.") - - if pp > 1: - per_batch_size = batch_size * micro_batch_num * micro_batch_interleave_num - self.global_batch_size = per_batch_size * dp - logger.info("Pipeline parallel was opened: pipeline_stages = %s, full batch is False, " - "gradient_accumulation_steps will not take effect in pipeline parallel, " - "batch size per card will be changed: " - "per_batch_size = batch_size * micro_batch_num * micro_batch_interleave_num " - "= %s = %s * %s * %s).", - pp, per_batch_size, batch_size, micro_batch_num, - micro_batch_interleave_num) - logger.info("global_batch_size = per_batch_size * data_parallel = %s * %s = %s", - per_batch_size, dp, self.global_batch_size) - self.config.runner_config.batch_size = per_batch_size - self._reset_wrapper_for_pipeline_parallel() - else: - per_batch_size = batch_size * micro_batch_interleave_num * gradient_accumulation_steps - self.global_batch_size = per_batch_size * dp - logger.info("The current parallel mode is %s, full batch is False, " - "batch size per card will be changed: " - "per_batch_size = batch_size * micro_batch_interleave_num * " - "gradient_accumulation_steps = %s = %s * %s * %s).", - parallel_mode, per_batch_size, batch_size, micro_batch_interleave_num, - gradient_accumulation_steps) - logger.info("global_batch_size = per_batch_size * data_parallel = %s * %s = %s", - per_batch_size, dp, self.global_batch_size) - self.config.runner_config.batch_size = per_batch_size - else: - logger.info("The current parallel mode is %s, batch size per card will not be changed: " - "batch_size_per_card = %s", - parallel_mode, batch_size) - self.global_batch_size = batch_size * get_real_group_size() * gradient_accumulation_steps - 
logger.info( - "global_batch_size = batch_size_per_card * device_num * gradient_accumulation_steps " - "= %s = %s * %s * %s", - self.global_batch_size, batch_size, get_real_group_size(), gradient_accumulation_steps) - self.config.runner_config.batch_size = batch_size * gradient_accumulation_steps - self.config.parallel_config.data_parallel = 1 - self.config.parallel_config.model_parallel = 1 - self.config.parallel_config.context_parallel = 1 - self.config.parallel_config.expert_parallel = 1 - self.config.parallel_config.pipeline_stage = 1 - self.config.parallel_config.micro_batch_num = 1 - logger.info("parallel_config will be change to default config: %s.", - self.config.parallel_config) - self.config.runner_config.global_batch_size = self.global_batch_size - def _check_grad_accumulation_steps(self): """check the gradient accumulation steps.""" if self.config.runner_config.gradient_accumulation_steps is None: @@ -337,23 +237,6 @@ class BaseTrainer: if parallel_mode in ["semi_auto_parallel", "auto_parallel"] and gradient_accumulation_steps > 1: self._reset_wrapper_for_grad_accu(gradient_accumulation_steps) - def _reset_wrapper_for_pipeline_parallel(self): - """Reset wrapper when pipeline parallel.""" - if self.config.runner_wrapper is not None: - self.config.runner_wrapper.type = "MFPipelineWithLossScaleCell" \ - if self.config.runner_wrapper.type != "MFPipelineWithLossScaleCell" else self.config.runner_wrapper.type - self.config.runner_wrapper.micro_batch_num = self.config.parallel_config.micro_batch_num - self.config.runner_wrapper.calculate_per_token_loss = self.config.calculate_per_token_loss - logger.warning( - "When using the pipeline parallel mode, " - "the MFPipelineWithLossScaleCell class is used by default.") - else: - logger.info( - "When using the pipeline parallel mode, " - "because the wrapper class is not specified, " - "MindSpore's built-in PipelineCell is used by default") - logger.info("PipelineWrapper under evaluate or predict mode will not take effect.") - def _reset_wrapper_for_grad_accu(self, gradient_accumulation_steps): """Reset wrapper when using grad accumulation.""" if self.config.runner_wrapper is not None: @@ -796,22 +679,6 @@ class BaseTrainer: learning_rate = (base_learning_rate * device_num * per_device_batch_size) / scale_factor return learning_rate - def _process_megatron_dataset(self, dataset, config): - """Dataset processing for Megatron Dataset.""" - if ms.context.get_context("dataset_broadcast_opt_level") < 3: - raise ValueError("If using `BlendedMegatronDatasetDataLoader`, please set " - "`dataset_broadcast_opt_level: 3` in the `parallel_speed_up.json` file.") - - dataset_info = config.train_dataset.data_loader - # reset dataset size to remove redundant data - dataset = dataset.take(int(dataset_info.sizes[0]) // self.global_batch_size) - logger.info(f"Use BlendedMegatronDatasetDataLoader, reset dataset size to {dataset.get_dataset_size()}.") - - # Sync assign eod compression arguments - if self.config.train_dataset.data_loader.config.create_compressed_eod_mask: - self.config.model.model_config.use_eod_attn_mask_compression = True - return dataset, config - @staticmethod def _check_sink_mode_with_ds_broadcast(config): """Check sink_mode with dataset_broadcast_opt_level.""" @@ -825,43 +692,10 @@ class BaseTrainer: def _check_input_sliced_sig(config, usage_info='special'): """Check input_sliced_sig in model config.""" input_sliced_sig = config.model.model_config.get("input_sliced_sig") - if not input_sliced_sig: + if not input_sliced_sig and 
is_legacy_model(): raise ValueError( f"In this {usage_info} configuration, input_sliced_sig in model_config should be set 'True'") - def _train_dataset_postprocess(self, dataset, config): - """Dataset postprocess.""" - dataloader_info = config.train_dataset.get('data_loader') - if not dataloader_info: - return dataset, config - - # check sink_mode for dataset_broadcast_opt_level - self._check_sink_mode_with_ds_broadcast(config) - - dataloader_type = dataloader_info.get('type') - # postprocess for BlendedMegatronDatasetDataLoader - if dataloader_type == "BlendedMegatronDatasetDataLoader": - self._check_input_sliced_sig(config, dataloader_type) - return self._process_megatron_dataset(dataset, config) - - # postprocess for CommonDataLoader / HFDataLoader - if dataloader_type in ['HFDataLoader', 'CommonDataLoader']: - # check config for HF pack mode - handler = [] - if hasattr(dataloader_info, 'handler'): - handler = [sub_handler.get('type') for sub_handler in dataloader_info.handler] - if is_legacy_model() and 'PackingHandler' in handler: # CommonDataLoader legacy option - self._check_input_sliced_sig(config, f"{dataloader_type} with packing") - - # check config for use_broadcast_data - ds_broadcast_level = ms.context.get_context("dataset_broadcast_opt_level") - if dataloader_info.get('use_broadcast_data', True) and ds_broadcast_level < 3: - raise ValueError( - "If you are using `HFDataLoader` or `CommonDataLoader` and enable `use_broadcast_data`, " - "please set `dataset_broadcast_opt_level: 3` in the `parallel_speed_up.json` file." - ) - return dataset, config - @staticmethod def resume_ckpt_path_with_strategy(config): """Get resume checkpoint path with strategy. @@ -919,6 +753,260 @@ class BaseTrainer: return None + @staticmethod + def _reset_wrapper_for_pipeline_parallel(config): + """Reset wrapper when pipeline parallel.""" + if config.runner_wrapper is not None: + config.runner_wrapper.type = "MFPipelineWithLossScaleCell" + config.runner_wrapper.micro_batch_num = config.parallel_config.micro_batch_num + config.runner_wrapper.calculate_per_token_loss = config.calculate_per_token_loss + logger.warning("If pipeline_stages > 1 and config.runner_wrapper.type is None, " + "use `MFPipelineWithLossScaleCell` as config.runner_wrapper.type.") + logger.info("PipelineWrapper under evaluate or predict mode will not take effect.") + return config + + def _calculate_global_batch_size(self, config): + """ + Calculate the effective global batch size under different parallel modes. + Handles auto/semi-auto parallel as well as standalone/manual parallel cases. 
+        """
+        # Micro-batch size per device
+        micro_batch_size = config.runner_config.batch_size
+        config.runner_config.mini_batch_size = micro_batch_size
+
+        gradient_accumulation_steps = config.runner_config.gradient_accumulation_steps
+        micro_batch_num = config.parallel_config.micro_batch_num
+        micro_batch_interleave_num = config.micro_batch_interleave_num
+
+        # Parallel configurations
+        dp = config.parallel_config.data_parallel
+        pipeline_stages = ms.get_auto_parallel_context("pipeline_stages")
+        parallel_mode = ms.get_auto_parallel_context("parallel_mode")
+        dataset_strategy = ms.get_auto_parallel_context("dataset_strategy")
+        full_batch = dataset_strategy == 'full_batch'
+
+        logger.info(f"The current parallel mode is {parallel_mode}.")
+
+        # Case 1: Auto parallel or semi-auto parallel mode
+        if parallel_mode in ["semi_auto_parallel", "auto_parallel"]:
+            if pipeline_stages > 1:
+                # Pipeline parallelism > 1 → disable gradient accumulation
+                logger.warning("If pipeline_stages > 1, reset config.runner_config.gradient_accumulation_steps to 1.")
+                gradient_accumulation_steps = config.runner_config.gradient_accumulation_steps = 1
+            else:
+                # No pipeline parallelism → disable micro-batch split
+                logger.warning("If pipeline_stages = 1, reset config.parallel_config.micro_batch_num to 1.")
+                micro_batch_num = config.parallel_config.micro_batch_num = 1
+
+            per_batch_size = (
+                micro_batch_size
+                * micro_batch_num
+                * gradient_accumulation_steps
+                * micro_batch_interleave_num
+            )
+            global_batch_size = dp * per_batch_size
+
+            logger.info(
+                f"Calculate per_batch_size({per_batch_size}) = micro_batch_size({micro_batch_size}) "
+                f"* micro_batch_num({micro_batch_num}) "
+                f"* gradient_accumulation_steps({gradient_accumulation_steps}) "
+                f"* micro_batch_interleave_num({micro_batch_interleave_num})"
+            )
+            logger.info(
+                f"Calculate global_batch_size({global_batch_size}) = "
+                f"data_parallel({dp}) * per_batch_size({per_batch_size})"
+            )
+
+            # Set effective batch size depending on full_batch flag
+            if full_batch:
+                config.runner_config.batch_size = global_batch_size
+            else:
+                config.runner_config.batch_size = per_batch_size
+
+            # Reset wrapper for pipeline parallel case
+            if pipeline_stages > 1:
+                config = self._reset_wrapper_for_pipeline_parallel(config)
+
+        # Case 2: Standalone or data parallel mode
+        else:
+            device_num = get_real_group_size()
+            per_batch_size = micro_batch_size * gradient_accumulation_steps
+            global_batch_size = device_num * per_batch_size
+
+            logger.info(
+                f"Calculate global_batch_size({global_batch_size}) = device_num({device_num}) "
+                f"* micro_batch_size({micro_batch_size}) "
+                f"* gradient_accumulation_steps({gradient_accumulation_steps})"
+            )
+
+            # Reset batch size and parallel configs to single-device defaults
+            config.runner_config.batch_size = per_batch_size
+            config.parallel_config.data_parallel = 1
+            config.parallel_config.model_parallel = 1
+            config.parallel_config.context_parallel = 1
+            config.parallel_config.expert_parallel = 1
+            config.parallel_config.pipeline_stage = 1
+            config.parallel_config.micro_batch_num = 1
+
+        config.runner_config.global_batch_size = global_batch_size
+        self.global_batch_size = global_batch_size
+        return config
+
+
+    def _train_dataset_postprocess(self, dataset, config):
+        """
+        Postprocess the training dataset after construction.
+        Mainly used to adjust dataset size for special dataloaders.
+ """ + dataloader_info = config.train_dataset.get('data_loader') + if not dataloader_info: + return dataset + + dataloader_type = dataloader_info.get('type') + + # Special handling for BlendedMegatronDatasetDataLoader + if dataloader_type == "BlendedMegatronDatasetDataLoader": + dataset_info = config.train_dataset.data_loader + # Resize dataset to ensure it's divisible by global batch size (remove redundant data) + dataset = dataset.take(int(dataset_info.sizes[0]) // self.global_batch_size) + logger.info( + f"Use BlendedMegatronDatasetDataLoader, reset dataset size to {dataset.get_dataset_size()}." + ) + + return dataset + + @staticmethod + def _preprocess_attention_mask_config(config, dataset_config, dp): + """ + Preprocess attention mask related configurations in model and dataloader. + Ensures consistency between model's attention mask settings and dataloader's mask creation options. + """ + # Default sharding strategy and column names + dataset_strategy = [[dp, 1], [dp, 1], [dp, 1], [dp, 1]] + column_names = ['input_ids', 'labels', 'loss_mask', 'position_ids'] + + if dataset_config.get('create_compressed_eod_mask', False): + dataset_strategy.append([dp, 1]) + column_names.append('actual_seq_len') + config.model.model_config.use_eod_attn_mask_compression = True + elif dataset_config.get('create_attention_mask', False): + dataset_strategy.append([dp, 1, 1, 1]) + column_names.append('attention_mask') + + return config, dataset_strategy, column_names + + def _train_dataset_preprocess(self, config): + """ + Preprocess training dataset configuration before building the dataloader. + This includes validating broadcast options, setting dataset strategies, + and assigning dataset column names depending on the dataloader type. + """ + + # Check dataset sink mode with dataset broadcast optimization level + self._check_sink_mode_with_ds_broadcast(config) + + dataloader_config = config.train_dataset.get('data_loader', dict()) + dataloader_type = dataloader_config.get('type') + dp = config.parallel_config.data_parallel + ds_broadcast_level = ms.context.get_context("dataset_broadcast_opt_level") + + # If full_batch is explicitly set in parallel config, log a warning + full_batch = config.parallel.get('full_batch') + if full_batch is not None: + logger.warning("`full_batch` will be deprecated in future interfaces, use `dataset_strategy` instead.") + + # Default settings + dataset_strategy = 'full_batch' + column_names = None + + # Case 1: BlendedMegatronDatasetDataLoader + if dataloader_type == 'BlendedMegatronDatasetDataLoader': + # Validate that input slicing is enabled + self._check_input_sliced_sig(config, dataloader_type) + + # Must use broadcast opt level >= 3 for this loader + if ds_broadcast_level < 3: + raise ValueError( + "If using `BlendedMegatronDatasetDataLoader`, please set " + "`dataset_broadcast_opt_level: 3` in the `parallel_speed_up.json` file." 
+                )
+
+            # Additional fields depending on sub_config
+            sub_config = dataloader_config.get('config')
+            config, dataset_strategy, column_names = self._preprocess_attention_mask_config(
+                config, sub_config, dp)
+
+        # Case 2: HFDataLoader or CommonDataLoader
+        if dataloader_type in ['HFDataLoader', 'CommonDataLoader']:
+            # Collect handler types if present
+            handler = []
+            if dataloader_config.handler is not None:
+                handler = [sub_handler.get('type') for sub_handler in dataloader_config.handler]
+
+            # Packing or attention mask requires extended dataset strategy
+            if 'PackingHandler' in handler:
+                if is_legacy_model():
+                    self._check_input_sliced_sig(config, f"{dataloader_type} with packing")
+
+                config.train_dataset.data_loader.create_attention_mask = True
+                config, dataset_strategy, column_names = self._preprocess_attention_mask_config(
+                    config, config.train_dataset.data_loader, dp)
+            else:
+                dataset_strategy = [[dp, 1], [dp, 1]]
+                column_names = ['input_ids', 'labels']
+
+            # Must use broadcast opt level >= 3 if broadcast is enabled
+            if dataloader_config.get('use_broadcast_data', True) and ds_broadcast_level < 3:
+                raise ValueError(
+                    "If you are using `HFDataLoader` or `CommonDataLoader` and enable `use_broadcast_data`, "
+                    "please set `dataset_broadcast_opt_level: 3` in the `parallel_speed_up.json` file."
+                )
+
+        # Allow overriding defaults from config
+        dataset_strategy = config.parallel.get('dataset_strategy', dataset_strategy)
+        column_names = config.train_dataset.get('input_columns', column_names)
+        construct_args_key = config.train_dataset.get('construct_args_key', column_names)
+
+        logger.info(f"Got dataset config: "
+                    f"dataset_strategy={dataset_strategy}, "
+                    f"column_names={column_names}, "
+                    f"construct_args_key={construct_args_key}")
+
+        # Convert strategy to correct format and set full_batch flag
+        if dataset_strategy != 'full_batch':
+            full_batch = False
+            if isinstance(dataset_strategy, list):
+                dataset_strategy = tuple(tuple(ds_item) for ds_item in dataset_strategy)
+            elif not isinstance(dataset_strategy, tuple):
+                raise ValueError('`dataset_strategy` must be a list or a tuple.')
+        else:
+            full_batch = True
+
+        # Apply parallel context settings
+        ms.set_auto_parallel_context(**{
+            'full_batch': full_batch,
+            'dataset_strategy': dataset_strategy
+        })
+
+        # Update config with resolved column names and construct keys
+        config.train_dataset.input_columns = column_names
+        config.train_dataset.construct_args_key = construct_args_key
+        config.train_dataset_task.dataset_config = config.train_dataset
+
+        config.parallel.full_batch = full_batch
+        config.parallel.dataset_strategy = dataset_strategy
+
+        return config
+
+    def _preprocess_config(self, config):
+        """
+        High-level preprocessing for training config.
+        Includes dataset preprocessing and global batch size calculation.
+ """ + config = self._train_dataset_preprocess(config) + config = self._calculate_global_batch_size(config) + return config + # pylint: disable=C0330 def training_process( self, @@ -931,18 +1019,19 @@ class BaseTrainer: **kwargs): """Train or Fine-tune for BaseTrainer in MindFormers.""" self.kwargs = kwargs + self.compute_metrics = compute_metrics if compute_metrics else self.compute_metrics self.train_dataset = dataset if dataset else self.train_dataset self.eval_dataset = kwargs.get('eval_dataset', None) - self.compute_metrics = compute_metrics if compute_metrics else self.compute_metrics - construct_args_key = config.train_dataset.pop("construct_args_key", None) is_full_config = kwargs.get("is_full_config", False) + config = self._preprocess_config(config) config = self.set_config(config, is_full_config) # build dataset logger.info(".........Build Dataset For Train..........") + construct_args_key = config.train_dataset.pop("construct_args_key", None) dataset = self.create_train_dataset() # postprocess and check dataset configuration - dataset, config = self._train_dataset_postprocess(dataset, config) + dataset = self._train_dataset_postprocess(dataset, config) logger.info("Create train dataset finish, dataset size:%d", dataset.get_dataset_size()) append_info = None @@ -1381,6 +1470,7 @@ class BaseTrainer: self.eval_dataset = dataset if dataset else self.eval_dataset metric_name = kwargs.get("metric_name") is_full_config = kwargs.get("is_full_config", False) + config = self._preprocess_config(config) config = self.set_config(config, is_full_config) construct_args_key = config.eval_dataset.pop("construct_args_key", None) diff --git a/tests/st/test_multi_cards_cases/test_model/test_deepseek3/run_deepseek3.py b/tests/st/test_multi_cards_cases/test_model/test_deepseek3/run_deepseek3.py index a5e7d8c13..4e5f0950f 100644 --- a/tests/st/test_multi_cards_cases/test_model/test_deepseek3/run_deepseek3.py +++ b/tests/st/test_multi_cards_cases/test_model/test_deepseek3/run_deepseek3.py @@ -41,6 +41,7 @@ def ds3_train(config, dataset, construct_args_key, checker_config): train_dataset=dataset, callbacks=callback) + task_trainer.config.train_dataset.input_columns = construct_args_key task_trainer.config.train_dataset.construct_args_key = construct_args_key def create_network(self, default_args): network = type(self).create_network(self, default_args) -- Gitee
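
To make the batch-size arithmetic introduced by the new BaseTrainer._calculate_global_batch_size easy to check in isolation, here is a minimal standalone sketch of the same formulas. The function name calc_batch_sizes and the example values are illustrative assumptions, not MindFormers API.

# Standalone sketch of the batch-size arithmetic implemented by the new
# BaseTrainer._calculate_global_batch_size in this patch (illustrative only).
def calc_batch_sizes(micro_batch_size, data_parallel, micro_batch_num,
                     gradient_accumulation_steps, micro_batch_interleave_num,
                     pipeline_stages, parallel_mode, device_num, full_batch):
    """Return (dataset_batch_size, global_batch_size) as the patch computes them."""
    if parallel_mode in ("semi_auto_parallel", "auto_parallel"):
        # pipeline parallel disables grad accumulation; without pipeline, micro_batch_num is forced to 1
        if pipeline_stages > 1:
            gradient_accumulation_steps = 1
        else:
            micro_batch_num = 1
        per_batch_size = (micro_batch_size * micro_batch_num
                          * gradient_accumulation_steps * micro_batch_interleave_num)
        global_batch_size = data_parallel * per_batch_size
        # full_batch=True feeds the whole global batch to the dataset; otherwise the per-card batch
        dataset_batch_size = global_batch_size if full_batch else per_batch_size
    else:
        # standalone / manual data parallel: every device reads its own batch
        per_batch_size = micro_batch_size * gradient_accumulation_steps
        global_batch_size = device_num * per_batch_size
        dataset_batch_size = per_batch_size
    return dataset_batch_size, global_batch_size


if __name__ == "__main__":
    # dp=4, pp=2, micro_batch_num=8, batch_size=1, full_batch=True -> (32, 32)
    print(calc_batch_sizes(1, 4, 8, 2, 1, 2, "semi_auto_parallel", 8, True))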
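
Likewise, a small plain-Python sketch (dicts and literals in place of the real config objects, names assumed) of the column/strategy selection performed by the new _preprocess_attention_mask_config and the list-to-tuple conversion applied before ms.set_auto_parallel_context:

# Sketch of the column/strategy selection in _preprocess_attention_mask_config
# plus the conversion of a non-'full_batch' strategy to a tuple of tuples.
def mask_columns_and_strategy(dp, create_compressed_eod_mask=False, create_attention_mask=False):
    strategy = [[dp, 1], [dp, 1], [dp, 1], [dp, 1]]
    columns = ['input_ids', 'labels', 'loss_mask', 'position_ids']
    if create_compressed_eod_mask:
        strategy.append([dp, 1])        # compressed EOD mask adds a 2-D actual_seq_len column
        columns.append('actual_seq_len')
    elif create_attention_mask:
        strategy.append([dp, 1, 1, 1])  # full attention mask is a 4-D tensor
        columns.append('attention_mask')
    # non-'full_batch' strategies are handed to MindSpore as a tuple of tuples
    return columns, tuple(tuple(s) for s in strategy)


if __name__ == "__main__":
    print(mask_columns_and_strategy(4, create_attention_mask=True))
    # (['input_ids', 'labels', 'loss_mask', 'position_ids', 'attention_mask'],
    #  ((4, 1), (4, 1), (4, 1), (4, 1), (4, 1, 1, 1)))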