diff --git a/baichuan2_13b.py b/baichuan2_13b.py deleted file mode 100644 index ef85fefcf932f91e3286c097fbf9aaf20ef8c1ee..0000000000000000000000000000000000000000 --- a/baichuan2_13b.py +++ /dev/null @@ -1,1283 +0,0 @@ -# Copyright 2023 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Baichuan2_13b models' APIs.""" -from typing import Optional -import math -import copy -import numpy as np -import mindspore.common.dtype as mstype - -try: - from mindspore._checkparam import Validator -except ImportError: - import mindspore._checkparam as Validator -from mindspore import Tensor, nn, ops -from mindspore.common.parameter import Parameter -from mindspore.context import ParallelMode -from mindspore.ops import operations as P -from mindspore.common.initializer import initializer, HeUniform -from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation - -from mindformers.core.loss.loss import CrossEntropyLoss -from mindformers.modules.flash_attention import FlashAttention -from mindformers.models.modeling_utils import PreTrainedModel -from mindformers.models.utils import lazy_inline -from mindformers.modules.transformer.op_parallel_config import _check_config -from mindformers.modules.layers import Linear, _check_input_dtype, build_alibi_tensor_v2 -from mindformers.modules.transformer import TransformerOpParallelConfig, LowerTriangularMaskWithDynamic -from mindformers.modules.infer_attention import InferAttention -from mindformers.tools.register.register import MindFormerModuleType, MindFormerRegister -from mindformers.models.utils import set_layer_stage_recompute -from mindformers.models.llama.llama_config import LlamaConfig -from mindformers.models.llama.llama_layer import LlamaEmbedding, LlamaFeedForward, LlamaRMSNorm -from mindformers.tools.logger import logger - -__all__ = ['Baichuan13BV2ForCausalLM', 'Baichuan13BV2Model'] - - -class Baichuan2PreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = LlamaConfig - base_model_prefix = "baichuan2" - - -@MindFormerRegister.register(MindFormerModuleType.MODELS) -class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): - r""" - Provide baichuan2_13B training loss or logits through network. - Args: - config (LlamaConfig): The config of baichuan2_13B model. - - Inputs: - input_ids(Tensor): the tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`. - labels(Tensor): the tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`. - input_position(Tensor): current position, used by model.predict. - init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and - past value parameter used in the incremental prediction. Default True. 
- batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental - prediction. Tensor of shape :math:`(batch_size,)`. Default None. - block_tables (Tensor[int64]): Store mapping tables for each sequence. - slot_mapping (Tensor[int32]): Store token cache physical slot index. - - Returns: - Tensor, the loss or logits of the network. - - Examples: - >>> from mindformers.models.llama import LlamaConfig - >>> from research.baichuan2.baichuan2_13b import Baichuan13BV2ForCausalLM - >>> config = LlamaConfig(batch_size=2) - >>> network = Baichuan13BV2ForCausalLM(config=config) - """ - - @lazy_inline - def __init__(self, config: LlamaConfig = None): - super(Baichuan13BV2ForCausalLM, self).__init__(config, auto_prefix=True) - _check_config(config.parallel_config) - self.config = config - self.seq_length = config.seq_length - self.ignore_token_id = config.ignore_token_id - self.pad_token_id = config.pad_token_id - self.use_past = config.use_past - self.vocab_size = config.vocab_size - self.is_first_iteration = True - self.dtype = config.compute_dtype - - self.shape = P.Shape() - self.reshape = P.Reshape() - self.cast = P.Cast() - self.slice = P.StridedSlice() - self.not_equal = P.NotEqual() - self.mul = P.Mul() - self.add = P.Add() - self.ones = P.Ones() - self.gather = P.Gather(1) - self.sub_batch_valid_len = P.Sub() - self.model = Baichuan13BV2Model(config=config) - self.lm_head = NormHead(hidden_size=config.hidden_size, - vocab_size=config.vocab_size, - use_past=config.use_past, - is_dynamic=config.is_dynamic, - compute_dtype=config.compute_dtype) - - vocab_size = config.vocab_size - loss_parallel_config = copy.deepcopy(config.parallel_config) - loss_parallel_config.model_parallel = loss_parallel_config.model_parallel * loss_parallel_config.data_parallel - loss_parallel_config.data_parallel = 1 - if vocab_size % (loss_parallel_config.model_parallel) != 0: - logger.warning("The vocab size of Loss is: %s, it is not divide by model_parallel: %s", - vocab_size, loss_parallel_config.model_parallel) - logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1") - loss_parallel_config.model_parallel = 1 - self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config) - - dp = config.parallel_config.data_parallel - if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): - self.slice.shard(((dp, 1),)) - self.not_equal.shard(((dp, 1), ())) - self.mul.shard(((dp, 1), (dp, 1))) - self.add.shard(((dp, 1), ())) - self.lm_head.shard(config.parallel_config) - self.gather.shard(((dp, 1, 1), (dp,))) - self.sub_batch_valid_len.shard(((1,), ())) - - if config.parallel_config.pipeline_stage > 1: - self.lm_head.pipeline_stage = config.parallel_config.pipeline_stage - 1 - self.lm_head.set_comm_fusion(2) - else: - self.lm_head.set_comm_fusion(config.parallel_config.gradient_aggregation_group) - - if config.is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - - self.load_checkpoint(config) - self.set_model_predict_config() - # pylint: disable=W0613 - def prepare_inputs_for_generation(self, input_ids, **kwargs): - if self.config.is_dynamic and "origin_inputs" in kwargs: - input_ids = kwargs["origin_inputs"] - return { - "input_ids": Tensor(input_ids, mstype.int32) - } - - def set_dynamic_inputs(self, **kwargs): - dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32) - dynamic_input_position = Tensor(shape=[None], dtype=mstype.int32) - dynamic_init_reset = Tensor([False], mstype.bool_) - 
dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32) - dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) - dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) - self.set_inputs(dynamic_input_ids, None, dynamic_input_position, None, None, None, dynamic_init_reset, - dynamic_batch_valid_length, None, None, dynamic_block_tables, dynamic_slot_mapping) - logger.info("Set dynamic input for baichuan2.") - - # pylint: disable=W0613 - def prepare_inputs_for_predict_layout(self, input_ids, **kwargs): - """Get Baichuan13BV2 model input tuple for transform ckpt.""" - input_ids = Tensor(input_ids, mstype.int32) - bs, seq = input_ids.shape[0], input_ids.shape[1] - slot_mapping = Tensor(np.ones(shape=tuple([bs*seq])), mstype.int32) - return input_ids, None, None, None, None, None, None, None, None, None, None, slot_mapping - - def add_flags_custom(self, is_first_iteration): - """Add customized attributes for specific cells in the model.""" - self.add_flags(is_first_iteration=is_first_iteration) - self.model.add_flags(is_first_iteration=is_first_iteration) - for layer in self.model.layers: - layer.add_flags(is_first_iteration=is_first_iteration) - layer.attention.infer_attention.add_flags(is_first_iteration=is_first_iteration) - - # pylint: disable=W0613 - def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None, - input_embeds=None, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None, - block_tables=None, slot_mapping=None): - """Baichuan13BV2ForCausalLM forward.""" - bsz, seqlen = self.shape(input_ids) - if self.use_past: - if not isinstance(batch_valid_length, Tensor): - batch_valid_length = self.ones((bsz,), mstype.int32) - if self.training: - tokens = self.slice(input_ids, (0, 0), (bsz, seqlen - 1), (1, 1)) - else: - tokens = input_ids - if batch_valid_length is not None: - batch_valid_length = self.reshape(batch_valid_length, (-1,)) - output = self.model(tokens, batch_valid_length, block_tables, slot_mapping) - pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None - if pre_gather: - output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1) - logits = self.lm_head(output) - - input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32) - if labels is None: - labels = self.slice(input_ids, (0, 1), (bsz, seqlen), (1, 1)) - else: - if labels.ndim > 1: - if self.training: - labels = self.slice(labels, (0, 1), (bsz, seqlen), (1, 1)) - label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32) - input_mask = self.mul(input_mask, label_mask) - - if not self.training: - return logits, tokens, input_mask - - if logits.ndim > 2: - logits = self.reshape(logits, (-1, logits.shape[-1])) - logits = self.cast(logits, mstype.float32) - labels = self.reshape(labels, (-1,)) - input_mask = self.reshape(input_mask, (-1,)) - loss = self.loss(logits, labels, input_mask) - return loss - - def kvcache(self, layer_idx): - key_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.key_cache - value_cache = self.model.layers[layer_idx].attention.infer_attention.paged_attention_mgr.value_cache - return key_cache, value_cache - - -class Baichuan13BV2Model(Baichuan2PreTrainedModel): - r""" - Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`Baichuan13BV2DecoderLayer`] - Args: - config(LlamaConfig): the config of network - - Inputs: - input_ids: the tokenized inputs with datatype int32 - - Returns: - output: Tensor, the output of baichuan2_13b decoderlayer - """ - - def __init__(self, - config: LlamaConfig = None): - super().__init__(config, auto_prefix=True) - _check_config(config.parallel_config) - if config.batch_size or config.use_past: - Validator.check_positive_int(config.batch_size) - self.dtype = config.compute_dtype - self.seq_length = config.seq_length - self.hidden_size = config.hidden_size - self.num_layers = config.num_layers - self.n_head = config.num_heads - self.head_dim = self.hidden_size // self.n_head - self.pad_token_id = config.pad_token_id - self.is_first_iteration = True - self.use_past = config.use_past - self.is_dynamic = config.is_dynamic - self.use_flash_attention = config.use_flash_attention - # only support flash attention in train and prefill predict process. - if self.use_flash_attention: - logger.info("Enable flash attention.") - # only support paged attention in predict process. - self.block_size = config.block_size - self.num_blocks = config.num_blocks - - self.shape = P.Shape() - self.reshape = P.Reshape() - self.cast = P.Cast() - self.mul_mask = P.Mul() - self.mul_alibi = P.Mul() - self.sub = P.Sub() - self.tile = P.Tile() - self.expand_dims = P.ExpandDims() - self.not_equal = P.NotEqual() - self.gather = P.Gather() - self.transpose = P.Transpose() - self.slice = P.StridedSlice() - self.ones = P.Ones() - - self.casual_mask = LowerTriangularMaskWithDynamic(seq_length=config.seq_length, - compute_type=config.compute_dtype, - is_dynamic=config.is_dynamic, - pad_token_id=config.pad_token_id, - use_flash_attention=config.use_flash_attention and not - config.use_past) - self.tok_embeddings = LlamaEmbedding(vocab_table_size=config.vocab_size, - embedding_size=config.hidden_size, - param_init_type=config.param_init_type, - parallel_optimizer=True) - self.layers = nn.CellList() - for layer_id in range(config.num_layers): - layer = Baichuan13BDecodeLayer(config.batch_size, - config.seq_length, - layer_id, - dim=config.hidden_size, - n_heads=config.num_heads, - n_kv_heads=config.n_kv_heads, - intermediate_size=config.intermediate_size, - multiple_of=config.multiple_of, - ffn_dim_multiplier=config.ffn_dim_multiplier, - norm_eps=config.rms_norm_eps, - compute_dtype=config.compute_dtype, - layernorm_compute_dtype=config.layernorm_compute_type, - softmax_compute_dtype=config.softmax_compute_type, - param_init_type=config.param_init_type, - use_past=config.use_past, - is_dynamic=config.is_dynamic, - use_flash_attention=self.use_flash_attention, - block_size=self.block_size, - num_blocks=self.num_blocks, - parallel_config=config.parallel_config) - set_layer_stage_recompute(layer, layer_id, config.offset, config.parallel_config, config.num_layers) - self.layers.append(layer) - self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps, - compute_type=config.layernorm_compute_type) - self.alibi_tensor = build_alibi_tensor_v2(seq_len=config.seq_length, - num_heads=config.num_heads, - return_tensors='ms', - dtype=self.dtype) - - dp = config.parallel_config.data_parallel - mp = config.parallel_config.model_parallel - if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): - self.tok_embeddings.pipeline_stage = 0 - if config.parallel_config.pipeline_stage > 1: - self.norm_out.pipeline_stage = config.parallel_config.pipeline_stage - 1 - 
self.tok_embeddings.set_comm_fusion(2) - self.norm_out.set_comm_fusion(2) - else: - self.tok_embeddings.set_comm_fusion(config.parallel_config.gradient_aggregation_group) - self.norm_out.set_comm_fusion(config.parallel_config.gradient_aggregation_group) - - self.tok_embeddings.shard(config.parallel_config) - self.casual_mask.shard(config.parallel_config) - self.sub.shard(((1,), (dp, 1, 1))) - self.mul_mask.shard(((dp, 1, 1, 1), (1,))) - self.mul_alibi.shard(((1, mp, 1, 1), (dp, 1, 1, 1))) - - self.expand_dims.shard(((dp, 1, 1),)) - self.not_equal.shard(((dp, 1), ())) - self.gather.shard(((1, mp, 1, 1), (1,))) - self.norm_out.shard((dp, 1, 1)) - - if self.is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - - # pylint: disable=W0613 - def construct(self, tokens: Tensor, batch_valid_length=None, block_tables=None, slot_mapping=None): - """Forward of baichuan2_13b model.""" - # preprocess - bs, seq_len = self.shape(tokens) - - if not self.use_past: - mask = self.casual_mask(tokens) # mask: mask: [bs , 1, seq, seq] - input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float16) - alibi_tensor = self.slice(self.alibi_tensor, (0, 0, 0, 0), - (1, self.alibi_tensor.shape[1], seq_len, seq_len), (1, 1, 1, 1)) - alibi_tensor = self.mul_alibi(alibi_tensor, self.reshape(input_mask, (bs, 1, -1, 1))) - else: - mask = None - if self.is_first_iteration: - if not self.use_flash_attention: - mask = self.casual_mask(tokens) # mask: [bs , 1, seq, seq] - alibi_tensor = self.slice(self.alibi_tensor, (0, 0, 0, 0), - (1, self.alibi_tensor.shape[1], seq_len, seq_len), (1, 1, 1, 1)) - else: - alibi_tensor = self.gather(self.alibi_tensor, batch_valid_length, 2) - alibi_tensor = self.transpose(alibi_tensor, (2, 1, 0, 3)) - # tokens: [bs, seq/1] - h = self.tok_embeddings(tokens) - h = self.reshape(h, (bs, seq_len, self.hidden_size)) - # h: [bs, seq/1, hidden_dim] - for i in range(self.num_layers): - h = self.layers[i](h, alibi_tensor, mask, batch_valid_length, block_tables, slot_mapping) - output = self.norm_out(h) - return output - - -class Baichuan13BAttention(nn.Cell): - r""" - This is an implementation of multihead attention in Baichuan. - - Args: - - **batch_size** (int): The batch size of the input tensor when do increnmental prediction. Should be a - positive value. - When do training or prediction, the argument will not work and the user can just pass None to the - argument. - - **src_seq_length** (int): The sequence length of the query vector. - - **tgt_seq_length** (int): The sequence length of the key and value vector. - - **dim** (int): The hidden size of the input. - - **head_dim** (int): The dim of head. - - **n_heads** (int): The number of the heads. - - **compute_dtype** (dtype.Number): The computation type of dense. Default mstype.float16. - Should be mstype.float32 or mstype.float16. - - **softmax_compute_type** (dtype.Number): The type of softmax computation module. Default mstype.float32. - Should be mstype.float32 or mstype.float16. - - **param_init_type** (dtype.Number): The parameter initialization type of the module. Default mstype. - float32. Should be mstype.float32 or mstype.float16. - - **use_past** (bool): Use the past state to compute, used for incremental prediction. - For example, if we have two words and want to generate the ten more words. - We just need to compute the two words' state only once, and generate the next word one by one. - When use_past is True, there are two steps to run the prediction. 
- In the first step, set the is_first_iteration to be True by - `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the - is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. At this moment, - pass the single step's input tensor, and loop it. Default False. - - **parallel_config** (OpParallelConfig): The parallel configure. Default `default_dpmp_config`, - an instance of `OpParallelConfig` with default args. - - Inputs: - - **x** (Tensor) - The input tokens with shape (batch_size, src_seq_length, hidden_size) or - (batch_size * src_seq_length, hidden_size), if the use_past is False or is_first_iteration=True. - Otherwise, must be (batch_size, 1, hidden_size) - - **alibi_tensor** (Tensor) - Alibi Tensor for position embedding used in attention. - - **mask** (Tensor) - If the use_past is False or is_first_iteration=True, the attention mask - matrix should ba (batch_size, src_seq_length, tgt_seq_length), or None. None means there will be no mask - in softmax computation. Otherwise, the mask must be (batch_size, 1, tgt_seq_length) - - **batch_valid_length** (Tensor) - Int32 tensor with shape (batch_size,) the past calculated the index. - Used for incremental prediction when the use_past is True. Default None. - - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence. - - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index. - Outputs: - Tuple, a tuple contains(`output`, `layer_present`) - - - **output** (Tensor) - Tensor, the float tensor of the output of the layer with - shape (batch_size, src_seq_length, hidden_size) or (batch_size * src_seq_length, hidden_size), - if the use_past is False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size). - - - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with - ((batch_size, num_heads, head_dim, tgt_seq_length), - (batch_size, num_heads, tgt_seq_length, head_dim)). - """ - - def __init__(self, - batch_size, - seq_length, - dim: int = 512, - n_heads: int = 8, - n_kv_heads: Optional[int] = None, - compute_dtype=mstype.float16, - softmax_compute_dtype=mstype.float32, - param_init_type=mstype.float32, - use_past=False, - is_dynamic=False, - use_flash_attention=False, - block_size: int = 128, - num_blocks: int = 224, - parallel_config=TransformerOpParallelConfig()): - super().__init__() - self.batch_size = batch_size - self.seq_length = seq_length - self.hidden_size = dim - self.n_head = n_heads - self.head_dim = dim // n_heads - self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads - self.n_rep = self.n_head // self.n_kv_head - - self.dtype = compute_dtype - self.softmax_dtype = softmax_compute_dtype - self.is_first_iteration = True - self.use_past = use_past - self.use_flash_attention = use_flash_attention - self.block_size = block_size - self.num_blocks = num_blocks - - if self.hidden_size % self.n_head != 0: - raise ValueError("For 'MultiHeadAttention', the class variable 'hidden_size' must be a multiple " - "of 'n_head', but got the hidden_size is {} and the n_head is {}." - .format(self.hidden_size, self.n_head)) - if self.n_kv_head % parallel_config.model_parallel != 0: - raise ValueError("For 'MultiHeadAttention', the class variable 'n_kv_head' must be a multiple of " - "'parallel_config.model_parallel', but got the n_kv_head is {} " - "and the parallel_config.model_parallel is {}." 
- .format(self.n_kv_head, parallel_config.model_parallel)) - - self.inv_norm_factor = Tensor(1.0 / math.sqrt(self.head_dim), dtype=compute_dtype) - - self.shape = P.Shape() - self.reshape = P.Reshape() - if is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - self.transpose = P.Transpose() - self.merger_head_transpose = P.Transpose() - self.batch_matmul = P.BatchMatMul() - self.batch_matmul_q_k = P.BatchMatMul(transpose_b=True) - self.mul = P.Mul() - self.add = P.Add() - self.add_alibi = P.Add() - self.softmax = P.Softmax() - self.cast = P.Cast() - self.cast_attn = P.Cast() - self.tile_kv = P.Tile() - - self.wo = Linear(in_channels=self.hidden_size, - out_channels=self.hidden_size, - has_bias=False, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - skip_redistribution=is_dynamic) - self.wq = Linear(self.hidden_size, - self.hidden_size, - has_bias=False, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - skip_redistribution=is_dynamic) - self.wk = Linear(self.hidden_size, - self.n_kv_head * self.head_dim, - has_bias=False, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - skip_redistribution=is_dynamic) - self.wv = Linear(self.hidden_size, - self.n_kv_head * self.head_dim, - has_bias=False, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - skip_redistribution=is_dynamic) - - dp = parallel_config.data_parallel - mp = parallel_config.model_parallel - if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): - self.transpose.shard(((dp, 1, mp, 1),)) - self.merger_head_transpose.shard(((dp, mp, 1, 1),)) - self.batch_matmul_q_k.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) - self.batch_matmul.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) - self.mul.shard(((dp, mp, 1, 1), ())) - self.add.shard(((dp, 1, 1, 1), (dp, mp, 1, 1))) - self.add_alibi.shard(((dp, mp, 1, 1), (dp, mp, 1, 1))) - self.softmax.shard(((dp, mp, 1, 1),)) - self.tile_kv.shard(((dp, mp, 1, 1),)) - - self.wq.shard(((dp, 1), (mp, 1))) - self.wk.shard(((dp, 1), (mp, 1))) - self.wv.shard(((dp, 1), (mp, 1))) - self.wo.shard(((dp, mp), (1, mp))) - - if parallel_config.use_seq_parallel and self.is_first_iteration: - self.wo.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp * mp, 1),)) - if parallel_config.recompute.select_recompute: - self.batch_matmul_q_k.recompute() - self.mul.recompute() - self.add_alibi.recompute() - self.softmax.recompute() - self.batch_matmul.recompute() - - if self.use_flash_attention: - self.flash_attention = FlashAttention(head_num=n_heads, - scale_value=1. / math.sqrt(self.head_dim), - input_layout='BNSD', - pre_tokens=65536, - next_tokens=0, - use_alibi_mask=True) - self.flash_attention.shard(parallel_config) - if self.use_past: - self.infer_attention = InferAttention(self.n_head, - self.head_dim, - self.n_kv_head, - scale_value=1. 
/ math.sqrt(self.head_dim), - pre_tokens=65536, - next_tokens=65536, - block_size=self.block_size, - num_blocks=self.num_blocks, - use_alibi_mask=True, - use_rope_rotary_emb=False, - use_flash_attention=use_flash_attention, - compute_dtype=compute_dtype).shard(parallel_config) - self.infer_attention.shard(parallel_config) - - def construct(self, x: Tensor, alibi_tensor: Tensor, mask=None, batch_valid_length=None, block_tables=None, - slot_mapping=None): - """Forward process of the MultiHeadAttention""" - ori_dtype = x.dtype - # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] - bs, seq_len, _ = self.shape(x) - - query = self.cast(self.wq(x), self.dtype) # dp, 1 -> dp, mp - key = self.cast(self.wk(x), self.dtype) # dp, 1 -> dp, mp - value = self.cast(self.wv(x), self.dtype) # dp, 1 -> dp, mp - - # key and value for current token(s) - if self.use_past: - attention = self.infer_attention(query, key, value, batch_valid_length, block_tables, slot_mapping, - None, mask, alibi_tensor) - else: - query = self.transpose(self.reshape(query, (bs, seq_len, self.n_head, self.head_dim)), (0, 2, 1, 3)) - key = self.transpose(self.reshape(key, (bs, seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3)) - value = self.transpose(self.reshape(value, (bs, seq_len, self.n_kv_head, self.head_dim)), (0, 2, 1, 3)) - if self.use_flash_attention: - attention = self.flash_attention(query, key, value, mask, alibi_tensor) - attention = self._merge_heads(attention) - else: - key = self._repeat_kv(key, self.n_rep) - value = self._repeat_kv(value, self.n_rep) - attention = self._attn(query, key, value, mask, alibi_tensor) - - # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] - output = self.wo(attention) # dp, mp -> dp, 1 / dp * mp, 1 - output = self.cast(output, ori_dtype) - - return output - - def _repeat_kv(self, x, rep): - if rep == 1: - return x - bs, n_kv_head, seqlen, head_dim = self.shape(x) - x = self.reshape(x, (bs, n_kv_head, 1, seqlen * head_dim)) - x = self.tile_kv(x, (1, 1, rep, 1)) - x = self.reshape(x, (bs, n_kv_head * rep, seqlen, head_dim)) - return x - - def _merge_heads(self, x): - """ - convert a 4d input to a 2d or 3d output - - Inputs: - x: input tensor - - Output: - x_merge: the 2d output - """ - # [bs, n_head, seq/1, head_dim] - x = self.merger_head_transpose(x, (0, 2, 1, 3)) # dp,mp,1,1 -> dp,1,mp,1 - # [bs, seq/1, n_head, head_dim] - bs, seq_len, n_head, head_dim = self.shape(x) - # [bs, seq/1, hidden_dim] - new_shape = (bs, seq_len, n_head * head_dim) - x_merge = self.reshape(x, new_shape) - return x_merge - - def _attn(self, query, key, value, mask, alibi_tensor): - """ - Get the weighted score along the seq_length - - Inputs: - query: the query matrix - key: the key matrix - value: the value matrix - mask: the attention mask adder matrix with shape (batch_size, - 1, seq_length, seq_length) - Outputs: - weighted_values: Tensor, the weighted sum scores - """ - # q, k: [bs, n_head, seq/1, head_dim], [bs, n_head, seq, head_dim] - score = self.batch_matmul_q_k(query, key) - # score: [bs, n_head, seq/1, seq] - score = self.mul(score, self.inv_norm_factor) - score = self.add_alibi(score, alibi_tensor) - - score = self.add(mask, score) - - attention_probs = self.softmax(self.cast_attn(score, self.softmax_dtype)) - # score, v: [bs, n_head, seq/1, seq], [bs, n_head, seq, head_dim] - weighted_values = self.batch_matmul(self.cast(attention_probs, self.dtype), value) - # [bs, n_head, seq/1, head_dim] - attention_merge = self._merge_heads(weighted_values) - # [bs, seq/1, hidden_dim] or [bs * seq/1, 
hidden_dim] - return attention_merge - - -class Baichuan13BDecodeLayer(nn.Cell): - r""" - Transformer Layer. This is an implementation of the single layer of the transformer - encoder layer, including multihead attention and feedward layer. - - Args: - batch_size(int): The batch size of the input tensor when do increnmental prediction. Should be a positive - value. When do training or prediction, the argument will not work and the user can just pass None to - the argument. - seq_length(int): The input sequence length. - layer_id(int): The layer id of current transformer block layer. - dim(int): The hidden size of the input. - num_heads(int): The number of the heads. - multiple_of(int): The SwiGLU hidden layer size multiple of large power of 2. - norm_eps (float): The epsilon value of the denominator. Default 1e-5. - compute_dtype(dtype.Number): The computation type of the layer. - Should be mstype.float32 or mstype.float16. Default mstype.float32. - layernorm_compute_type(dtype.Number): The computation type of the norm. - Should be mstype.float32 or mstype.float16. Default mstype.float32. - softmax_compute_type(dtype.Number): The computation type of the softmax in the attention. - Should be mstype.float32 or mstype.float16. Default mstype.float32. - param_init_type(dtype.Number): The parameter initialization type of the module. - Should be mstype.float32 or mstype.float16. Default mstype.float32. - use_past(bool): Use the past state to compute, used for incremental prediction. For example, if we have two - words and want to generate the ten more words. We just need to compute the two words' state only once, - and generate the next word one by one. When use_past is True, there are two steps to run the prediction. - In the first step, set the is_first_iteration to be True by - `model.add_flags_recursive(is_first_iteration=True)`, and pass the full inputs. Then, set the - is_first_iteration to be False by `model.add_flags_recursive(is_first_iteration=False)`. - At this moment, pass the single step's input tensor, and loop it. Default False. - parallel_config(OpParallelConfig, MoEParallelConfig): The parallel configure. When MoE is applied, - MoEParallelConfig is effective, otherwise OpParallelConfig is effective. Default `default_dpmp_config`, - an instance of `OpParallelConfig` with default args. - - Inputs: - - **x** (Tensor) - Float Tensor, shape should be [batch_size, seq_length, hidden_size] or - [batch_size * seq_length, hidden_size], if the use_past is False or is_first_iteration=True. Otherwise, - should be [batch_size, 1, hidden_size] - - **alibi_tensor** (Tensor) - Alibi Tensor for position embedding used in attention. - - **mask** (Tensor) - Float Tensor, If the use_past is - False or is_first_iteration=True, - the attention mask matrix should ba [batch_size, seq_length, seq_length], or None. None means there will - be no mask in softmax computation. Otherwise, should be [batch_size, 1, hidden_size] - - **batch_valid_length** (Tensor) - Int32 tensor with shape [batch_size] the past calculated the index. - Used for incremental prediction when the use_past is True. Default None. - - **block_tables** (Tensor[int64]) - Store mapping tables for each sequence. - - **slot_mapping** (Tensor[int32]) - Store token cache physical slot index. - - Outputs: - Tuple, a tuple contains(`output`, `layer_present`). 
- - - **output** (Tensor) - The float tensor of the output of the layer with - shape (batch_size, seq_length, hidden_size) or (batch_size * seq_length, hidden_size), if the use_past is - False or is_first_iteration=True. Otherwise, it will be (batch_size, 1, hidden_size) - - - **layer_present** (Tuple) - A tuple of the Tensor of the projected key and value vector with - ((batch_size, num_heads, head_dim, seq_length), - (batch_size, num_heads, seq_length, head_dim)). - - """ - - def __init__(self, - batch_size, - seq_length, - layer_id, - dim: int = 512, - n_heads: int = 8, - n_kv_heads: Optional[int] = None, - intermediate_size: Optional[int] = None, - multiple_of: int = 256, - ffn_dim_multiplier: Optional[int] = None, - norm_eps: float = 1e-5, - compute_dtype=mstype.float16, - layernorm_compute_dtype=mstype.float32, - softmax_compute_dtype=mstype.float32, - param_init_type=mstype.float32, - use_past=False, - is_dynamic=False, - use_flash_attention=False, - block_size: int = 128, - num_blocks: int = 224, - parallel_config=TransformerOpParallelConfig()): - super().__init__() - if batch_size or use_past: - Validator.check_positive_int(batch_size) - self.batch_size = batch_size - self.seq_length = seq_length - self.layer_id = layer_id - self.hidden_size = dim - self.n_head = n_heads - self.head_dim = self.hidden_size // self.n_head - self.n_kv_head = n_heads if n_kv_heads is None else n_kv_heads - - self.dtype = compute_dtype - self.is_first_iteration = True - self.use_past = use_past - self.is_dynamic = is_dynamic - self.key_past = None - self.value_past = None - self.use_seq_parallel = parallel_config.use_seq_parallel - - self.shape = P.Shape() - self.reshape = P.Reshape() - if is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - self.add = P.Add() - self.attention_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) - self.ffn_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) - self.attention = Baichuan13BAttention(batch_size=batch_size, - seq_length=seq_length, - dim=dim, - n_heads=n_heads, - n_kv_heads=n_kv_heads, - compute_dtype=compute_dtype, - softmax_compute_dtype=softmax_compute_dtype, - param_init_type=param_init_type, - use_past=use_past, - is_dynamic=is_dynamic, - use_flash_attention=use_flash_attention, - block_size=block_size, - num_blocks=num_blocks, - parallel_config=parallel_config) - self.feed_forward = LlamaFeedForward(dim=self.hidden_size, - intermediate_size=intermediate_size, - hidden_dim=4 * self.hidden_size, - multiple_of=multiple_of, - ffn_dim_multiplier=ffn_dim_multiplier, - compute_dtype=compute_dtype, - param_init_type=param_init_type, - is_dynamic=is_dynamic) - - dp = parallel_config.data_parallel - mp = parallel_config.model_parallel - if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): - self.feed_forward.shard(parallel_config) - self.add.shard(((dp, 1, 1), (dp, 1, 1))) - self.attention_norm.shard((dp, 1, 1)) - self.ffn_norm.shard((dp, 1, 1)) - self.feed_forward.mul.shard(((dp, 1, mp), (dp, 1, mp))) - - if parallel_config.use_seq_parallel and self.is_first_iteration: - self.add.shard(((dp, mp, 1), (dp, mp, 1))) - self.attention_norm.shard((dp, mp, 1)) - self.ffn_norm.shard((dp, mp, 1)) - self.feed_forward.w2.shard(((dp, mp), (1, mp)), out_strategy_matmul=((dp * mp, 1),)) - - def construct(self, x, alibi_tensor, mask=None, batch_valid_length=None, block_tables=None, slot_mapping=None): - """ Forward of transformer block. 
""" - self._check_input(x, alibi_tensor, mask) - input_x = self.attention_norm(x) - # [bs, seq/1, hidden_dim] - h = self.attention(input_x, alibi_tensor, mask, batch_valid_length, block_tables, slot_mapping) - h = self.add(x, h) - ffn_norm = self.ffn_norm(h) - # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] - ffn_out = self.feed_forward(ffn_norm) - # [bs, seq/1, hidden_dim] or [bs * seq/1, hidden_dim] - out = self.add(h, ffn_out) - return out - - def _check_input(self, x, alibi_tensor, mask): - r"""Check inputs""" - _check_input_dtype( - x.dtype, "x", [mstype.float32, mstype.float16], self.cls_name) - _check_input_dtype(alibi_tensor.dtype, "alibi_tensor", - [mstype.float32, mstype.float16], self.cls_name) - if mask is not None: - _check_input_dtype(mask.dtype, "input_mask", [mstype.float32, mstype.float16, mstype.uint8], self.cls_name) - return True - - -class NormHead(nn.Cell): - """ - NormHead Layer. - - Args: - hidden_size (int): The hidden size of the input. - vocab_size (int): Size of the dictionary of embeddings. - compute_type (dtype.Number): The compute type. - eps (number): A small positive value prevents division by zero. - - Inputs: - - hidden_states (Tensor) - Tensor of shape :math:`(batch, seq_length, hidden_size)`. - - Outputs: - Tensor of shape :math:`(batch, seq_length, vocab_size)`. - """ - - def __init__(self, - hidden_size, - vocab_size, - use_past, - is_dynamic=False, - compute_dtype=mstype.float32, - eps=1e-5): - super().__init__() - self.weight = Parameter( - initializer(HeUniform(negative_slope=math.sqrt(5)), - [vocab_size, hidden_size], - mstype.float16), - name='weight', - parallel_optimizer=True) - self.square = P.Square() - self.sqrt = P.Sqrt() - self.add = P.Add() - self.real_div = P.RealDiv() - self.reshape = P.Reshape() - self.sum = P.ReduceSum() - self.eps = Tensor([eps], mstype.float16) - self.is_first_iteration = True - self.use_past = use_past - - self.matmul = P.MatMul(transpose_b=True) - self.cast = P.Cast() - self.compute_dtype = compute_dtype - self.hidden_size = hidden_size - self.vocab_size = vocab_size - self.assign = P.Assign() - - if is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - - def construct(self, hidden_states): - """Forward process of the NormHead""" - out_shape = P.Shape()(hidden_states)[:-1] + (self.vocab_size,) - hidden_states = self.reshape(hidden_states, (-1, self.hidden_size)) - - if self.is_first_iteration: - variance = self.square(self.weight) - variance = self.sum(variance, 1) - variance = self.reshape(variance, (-1, 1)) - variance_eps = self.sqrt(self.add(variance, self.eps)) - norm_weight = self.real_div(self.weight, variance_eps) - if self.use_past: - norm_weight = ops.depend(norm_weight, norm_weight) - self.assign(self.weight, norm_weight) - else: - norm_weight = self.weight - self.assign(self.weight, norm_weight) - norm_weight = ops.depend(norm_weight, norm_weight) - - ori_type = hidden_states.dtype - out = self.matmul(hidden_states.astype(self.compute_dtype), - norm_weight.astype(self.compute_dtype)) - out = self.reshape(out, out_shape) - return self.cast(out, ori_type) - - def shard(self, parallel_config): - """sharding for norm head""" - self.square.shard(((parallel_config.model_parallel * parallel_config.data_parallel, 1),)) - self.sqrt.shard(((parallel_config.model_parallel * parallel_config.data_parallel, 1),)) - self.add.shard(((parallel_config.model_parallel * parallel_config.data_parallel, 1), (1,))) - self.real_div.shard(((parallel_config.model_parallel * parallel_config.data_parallel, 
1), - (parallel_config.model_parallel * parallel_config.data_parallel, 1))) - self.sum.shard(((parallel_config.model_parallel * parallel_config.data_parallel, 1),)) - self.matmul.shard(((1, 1), - (parallel_config.model_parallel * parallel_config.data_parallel, 1))) - - -class DPOLoss(nn.Cell): - def __init__(self, config): - super(DPOLoss, self).__init__() - dp = config.parallel_config.data_parallel - mp = config.parallel_config.model_parallel - self.gatherd = P.GatherD() - self.log = P.Log() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.slice = P.StridedSlice().shard(((1, 1),)) # ? - self.slice_ind = P.StridedSlice().shard(((1,),)) # ? - self.mul = P.Mul().shard(((dp, mp), (dp, mp))) - self.sub = P.Sub().shard(((dp, mp), (dp, mp))) - self.log_softmax = P.LogSoftmax() - self.squeeze = P.Squeeze(-1).shard(((1, 1, 1),)) - self.expand = P.ExpandDims().shard(((1, 1),)) - self.label_pad_token_id = config.pad_token_id - self.average_log_prob = False - self.reference_free = False - self.log_sigmoid = nn.LogSigmoid() - self.reduce_mean = P.ReduceMean() - self.not_equal = P.NotEqual() - self.beta = 0.1 - self.enable_force_redistribute = True - if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL): - self.enable_force_redistribute = True - self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True) - self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True) - - def _get_batch_logps(self, logits, labels, loss_mask=None): - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, seq_len, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with value of label_pad_token_id are ignored. Shape: (batch_size, seq_len) - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. 
- """ - if loss_mask is None: - loss_mask = self.not_equal(labels, self.label_pad_token_id) - # [bs, seq_len] -> [bs, seq_len] - labels = self.mul(labels, loss_mask) - # [bs, seq_len, vocab_size] - log_probs = self.log_softmax(logits) - # [bs, seq_len] -> [bs, seq_len, 1] - index = self.expand(labels, -1) - index = self.cast(index, mstype.int32) - # [bs, seq_len, 1] - per_token_logps = self.gatherd(log_probs, -1, index) - # [bs, seq_len, 1] -> [bs, seq_len] - per_token_logps = self.squeeze(per_token_logps) - if self.average_log_prob: - return self.reduce_sum(per_token_logps * loss_mask, -1) / self.reduce_sum(loss_mask, -1) - else: - return self.reduce_sum(per_token_logps * loss_mask, -1) - - def dpo_loss(self, policy_chosen_logps, policy_rejected_logps, chosen_ref_logps, rejected_ref_logps): - policy_log_ratios = policy_chosen_logps - policy_rejected_logps - ref_log_ratios = chosen_ref_logps - rejected_ref_logps - if self.reference_free: - ref_log_ratios = 0 - logits = policy_log_ratios - ref_log_ratios - losses = -self.log_sigmoid(self.beta * logits) - chosen_rewards = self.beta * (policy_chosen_logps - chosen_ref_logps) - rejected_rewards = self.beta * (policy_rejected_logps - rejected_ref_logps) - return losses, chosen_rewards, rejected_rewards - - def construct(self, policy_logits, policy_labels, loss_mask, chosen_ref_logps, rejected_ref_logps): - # policy_logits: [bs, seq_len, vocab_size] - # policy_labels: [bs, seq_len] - # loss_mask: [bs, seq_len] - # chosen_ref_logps: [bs,] - # rejected_ref_logps: [bs,] - # [bs,] - all_logps = self._get_batch_logps(policy_logits, policy_labels, loss_mask) - bs = all_logps.shape[0] // 2 # a sample has two bs responses (chosen and rejected) - policy_chosen_logps = self.slice_ind(all_logps, (0,), (bs,), (1,)) - policy_rejected_logps = self.slice_ind(all_logps, (bs,), (2 * bs,), (1,)) - losses, chosen_rewards, rejected_rewards = self.dpo_loss( - policy_chosen_logps, - policy_rejected_logps, - chosen_ref_logps, - rejected_ref_logps - ) - if self.phase == "train": - return losses - return losses, chosen_rewards, rejected_rewards - - -class DPOCrossEntropy(CrossEntropyLoss): - def __init__(self, parallel_config, **kwargs): - super().__init__(parallel_config, **kwargs) - dp = parallel_config.data_parallel - mp = parallel_config.model_parallel - self.slice_3d = P.StridedSlice().shard(((dp, mp, 1),)) - self.slice_2d = P.StridedSlice().shard(((dp, mp),)) - - def construct(self, logits, label, input_mask): - bs, seq_len, vocab_size = logits.shape # a sample has two bs responses (chosen and rejected) - policy_chosen_logps = self.slice_3d(logits, (0, 0, 0), (bs // 2, seq_len, vocab_size), (1, 1, 1)) - label = self.slice_2d(label, (0, 0), (bs // 2, seq_len), (1, 1)) - input_mask = self.slice_2d(input_mask, (0, 0), (bs // 2, seq_len), (1, 1)) - return super().construct(policy_chosen_logps.reshape((-1, policy_chosen_logps.shape[-1])), label.reshape((-1,)), input_mask.reshape((-1,))) - -@MindFormerRegister.register(MindFormerModuleType.LOSS) -class DPOLossV2(nn.Cell): - def __init__(self, config): - super(DPOLossV2, self).__init__() - dp = config.parallel_config.data_parallel - mp = config.parallel_config.model_parallel - self.gatherd = P.GatherD() - self.log = P.Log() - self.reduce_sum = P.ReduceSum(keep_dims=False) - self.slice = P.StridedSlice().shard(((1, 1),)) # ? - self.slice_ind = P.StridedSlice().shard(((1,),)) # ? 
- self.slice_mask = P.StridedSlice().shard(((1, 1),)) - self.mul = P.Mul().shard(((dp, mp), (dp, mp))) - self.sub = P.Sub().shard(((dp, mp), (dp, mp))) - self.log_softmax = P.LogSoftmax() - self.squeeze = P.Squeeze(-1).shard(((1, 1, 1),)) - self.expand = P.ExpandDims().shard(((1, 1),)) - self.label_pad_token_id = config.pad_token_id - self.average_log_prob = True - self.reference_free = False - self.log_sigmoid = nn.LogSigmoid() - self.not_equal = P.NotEqual() - self.beta = 0.2 - self.enable_force_redistribute = True - if _get_parallel_mode() in (ParallelMode.AUTO_PARALLEL, ParallelMode.SEMI_AUTO_PARALLEL): - self.enable_force_redistribute = True - self.add = P.Add().shard(((dp, mp), ())).add_prim_attr("keep_alive", True) - self.add_label = P.Add().shard(((dp,), ())).add_prim_attr("keep_alive", True) - - def _get_batch_logps(self, logits, labels, loss_mask=None): - """Compute the log probabilities of the given labels under the given logits. - - Args: - logits: Logits of the model (unnormalized). Shape: (batch_size, seq_len, vocab_size) - labels: Labels for which to compute the log probabilities. Label tokens with value of label_pad_token_id are ignored. Shape: (batch_size, seq_len) - - Returns: - A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits. - """ - # TODO: For glm2, the loss mask might be passed in - if loss_mask is None: - loss_mask = self.not_equal(labels, self.label_pad_token_id) - # [bs, seq_len] -> [bs, seq_len] - labels = self.mul(labels, loss_mask) - # [bs, seq_len, vocab_size] - print("logits", logits) - log_probs = self.log_softmax(logits) - print("log_probs", log_probs) - # [bs, seq_len] -> [bs, seq_len, 1] - index = self.expand(labels, -1) - index = self.cast(index, mstype.int32) - # [bs, seq_len, 1] - per_token_logps = self.gatherd(log_probs, -1, index) - # [bs, seq_len, 1] -> [bs, seq_len] - per_token_logps = self.squeeze(per_token_logps) - print("per_token_logps", per_token_logps) - if self.average_log_prob: - print("per_token_logps final", self.reduce_sum(per_token_logps * loss_mask, -1), self.reduce_sum(loss_mask, -1)) - return self.reduce_sum(per_token_logps * loss_mask, -1) / self.reduce_sum(loss_mask, -1) - else: - return self.reduce_sum(per_token_logps * loss_mask, -1) - - def dpo_loss(self, policy_chosen_logps, policy_rejected_logps, ref_chosen_logps, ref_rejected_logps, loss_mask): - bs, seq_len = loss_mask.shape - if self.average_log_prob: - policy_chosen_logps_avg = policy_chosen_logps - else: - chosen_loss_mask = self.slice_mask(loss_mask, (0, 0), (bs // 2, seq_len), (1, 1)) - chosen_valid_len = self.reduce_sum(chosen_loss_mask, -1) - policy_chosen_logps_avg = policy_chosen_logps / chosen_valid_len - # if self.average_log_prob: - # rejected_loss_mask = self.slice_mask(loss_mask, (bs // 2, 0), (bs, seq_len), (1, 1)) - # rejected_valid_len = self.reduce_sum(rejected_loss_mask, -1) - # policy_chosen_logps = policy_chosen_logps / chosen_valid_len - # ref_chosen_logps = ref_chosen_logps / chosen_valid_len - # policy_rejected_logps = policy_rejected_logps / rejected_valid_len - # ref_rejected_logps = ref_rejected_logps / rejected_valid_len - - policy_log_ratios = policy_chosen_logps - policy_rejected_logps - ref_log_ratios = ref_chosen_logps - ref_rejected_logps - if self.reference_free: - ref_log_ratios = 0 - logits = policy_log_ratios - ref_log_ratios - losses = -self.log_sigmoid(self.beta * logits) # if logits is very large, the losses would be nan - chosen_rewards = self.beta * 
(policy_chosen_logps - ref_chosen_logps) - rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps) - return losses, chosen_rewards, rejected_rewards, policy_chosen_logps_avg - - def construct(self, policy_logits, policy_labels, chosen_loss_mask, rejected_loss_mask, ref_chosen_logps, ref_rejected_logps): - # policy_logits: [bs, seq_len, vocab_size] - # policy_labels: [bs, seq_len] - # loss_mask: [bs, seq_len] - # ref_chosen_logps: [bs,] - # ref_rejected_logps: [bs,] - # [bs,] - loss_mask = ops.concat((chosen_loss_mask, rejected_loss_mask), axis=0) - all_logps = self._get_batch_logps(policy_logits, policy_labels, loss_mask) - bs = all_logps.shape[0] // 2 # a sample has two bs responses (chosen and rejected) - policy_chosen_logps = self.slice_ind(all_logps, (0,), (bs,), (1,)) - policy_rejected_logps = self.slice_ind(all_logps, (bs,), (2 * bs,), (1,)) - print("policy_chosen_logps", policy_chosen_logps) - print("policy_rejected_logps", policy_rejected_logps) - if self.average_log_prob: - ref_chosen_logps = ref_chosen_logps / self.reduce_sum(chosen_loss_mask, -1) - ref_rejected_logps = ref_rejected_logps / self.reduce_sum(rejected_loss_mask, -1) - print("ref_chosen_logps", ref_chosen_logps) - print("ref_rejected_logps", ref_rejected_logps) - dpo_loss, chosen_rewards, rejected_rewards, policy_chosen_logps_avg = self.dpo_loss( - policy_chosen_logps, - policy_rejected_logps, - ref_chosen_logps, - ref_rejected_logps, - loss_mask - ) - sft_loss = -policy_chosen_logps_avg - print("sft loss: ", sft_loss) - print("dpo loss: ", dpo_loss) - if self.phase == "train": - return dpo_loss, sft_loss - return dpo_loss, sft_loss, chosen_rewards, rejected_rewards - -@MindFormerRegister.register(MindFormerModuleType.MODELS) -class Baichuan13BDPO(Baichuan13BV2ForCausalLM): - r""" - Provide baichuan2_13B training loss or logits through network. - Args: - config (LlamaConfig): The config of baichuan2_13B model. - - Inputs: - input_ids(Tensor): the tokenized inputs with datatype int32, Tensor of shape :math:`(batch, seq\_length)`. - labels(Tensor): the tokenized labels with datatype int32, Tensor of shape :math:`(batch, seq\_length)`. - input_position(Tensor): current position, used by model.predict. - init_reset(bool, optional): A bool tensor with shape [1], used to clear the past key parameter and - past value parameter used in the incremental prediction. Default True. - batch_valid_length(Tensor): the past calculated the index with datatype int32, used for incremental - prediction. Tensor of shape :math:`(batch_size,)`. Default None. - block_tables (Tensor[int64]): Store mapping tables for each sequence. - slot_mapping (Tensor[int32]): Store token cache physical slot index. - - Returns: - Tensor, the loss or logits of the network. 
- - Examples: - >>> from mindformers.models.llama import LlamaConfig - >>> from research.baichuan2.baichuan2_13b import Baichuan13BV2ForCausalLM - >>> config = LlamaConfig(batch_size=2) - >>> network = Baichuan13BV2ForCausalLM(config=config) - """ - - @lazy_inline - def __init__(self, config: LlamaConfig = None): - super(Baichuan13BDPO, self).__init__(config) - _check_config(config.parallel_config) - self.config = config - self.seq_length = config.seq_length - self.ignore_token_id = config.ignore_token_id - self.pad_token_id = config.pad_token_id - self.use_past = config.use_past - self.vocab_size = config.vocab_size - self.is_first_iteration = True - self.dtype = config.compute_dtype - - self.shape = P.Shape() - self.reshape = P.Reshape() - self.cast = P.Cast() - self.slice = P.StridedSlice() - self.not_equal = P.NotEqual() - self.mul = P.Mul() - self.add = P.Add() - self.ones = P.Ones() - self.gather = P.Gather(1) - self.sub_batch_valid_len = P.Sub() - - vocab_size = config.vocab_size - dp = config.parallel_config.data_parallel - mp = config.parallel_config.model_parallel - if config.parallel_config.vocab_emb_dp or (config.vocab_size % mp != 0): - self.dpo_loss = DPOLossV2(config) - else: - loss_parallel_config = copy.deepcopy(config) - loss_parallel_config.parallel_config.model_parallel = dp * mp - loss_parallel_config.parallel_config.data_parallel = 1 - if dp >= 32 and dp % 8 == 0: # For large scale training - loss_parallel_config.parallel_config.model_parallel = 8 - loss_parallel_config.parallel_config.data_parallel = dp * mp // 8 - self.dpo_loss = DPOLossV2(loss_parallel_config) - - self.alpha = config.alpha - self.beta = config.beta - if config.parallel_config.vocab_emb_dp or (config.vocab_size % mp != 0): - self.sft_loss = DPOCrossEntropy(parallel_config=config.parallel_config) - else: - loss_parallel_config = copy.deepcopy(config.parallel_config) - loss_parallel_config.model_parallel = dp * mp - loss_parallel_config.data_parallel = 1 - if dp >= 32 and dp % 8 == 0: # For large scale training - loss_parallel_config.model_parallel = 8 - loss_parallel_config.data_parallel = dp * mp // 8 - self.sft_loss = DPOCrossEntropy(parallel_config=loss_parallel_config) - - # pylint: disable=W0613 - def construct(self, chosen_input_ids, chosen_labels=None, chosen_loss_mask=None, - chosen_ref_logps=None, rejected_input_ids=None, rejected_labels=None, - rejected_loss_mask=None, rejected_ref_logps=None, - input_position=None, position_ids=None, - input_embeds=None, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None, - block_tables=None, slot_mapping=None): - """Baichuan13BV2ForCausalLM forward.""" - if self.training: - input_ids = ops.concat((chosen_input_ids, rejected_input_ids), axis=0) - labels = ops.concat((chosen_labels, rejected_labels), axis=0) - # loss_mask = ops.concat((chosen_loss_mask, rejected_loss_mask), axis=0) - else: - input_ids = chosen_input_ids - labels = chosen_labels - # loss_mask = chosen_loss_mask - bsz, ori_seqlen = self.shape(input_ids) - if self.use_past: - if not isinstance(batch_valid_length, Tensor): - batch_valid_length = self.ones((bsz,), mstype.int32) - if self.training: - tokens = self.slice(input_ids, (0, 0), (bsz, ori_seqlen - 1), (1, 1)) - chosen_loss_mask = self.slice(chosen_loss_mask, (0, 1), (bsz, ori_seqlen), (1, 1)) - rejected_loss_mask = self.slice(rejected_loss_mask, (0, 1), (bsz, ori_seqlen), (1, 1)) - else: - tokens = input_ids - if batch_valid_length is not None: - batch_valid_length = self.reshape(batch_valid_length, (-1,)) - 
output = self.model(tokens, batch_valid_length, block_tables, slot_mapping) - pre_gather = (not self.use_past or self.is_first_iteration) and batch_valid_length is not None - if pre_gather: - output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1) - logits = self.lm_head(output) - - input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32) - if labels is None: - labels = self.slice(input_ids, (0, 1), (bsz, ori_seqlen), (1, 1)) - else: - if labels.ndim > 1: - if self.training: - labels = self.slice(labels, (0, 1), (bsz, ori_seqlen), (1, 1)) - label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32) - input_mask = self.mul(input_mask, label_mask) - - if not self.training: - return logits, tokens, input_mask - - if logits.ndim <= 2: - logits = self.reshape(logits, (bsz, tokens.shape[1], logits.shape[-1])) - # logits = self.cast(logits, mstype.float32) - # labels = self.reshape(labels, (-1,)) - # input_mask = self.reshape(input_mask, (-1,)) - # loss = self.loss(logits, labels, input_mask) - policy_logits = self.cast(logits, mstype.float32) - # sft_loss = self.sft_loss(policy_logits, labels, loss_mask) - dpo_loss, sft_loss = self.dpo_loss(policy_logits, labels, chosen_loss_mask, rejected_loss_mask, chosen_ref_logps.reshape((-1,)), rejected_ref_logps.reshape((-1,))) - return self.alpha * dpo_loss + self.beta * sft_loss diff --git a/examples/dpo/baichuan2/README.md b/examples/dpo/baichuan2/README.md index dcf1d0ebb4d415d757151042bb1dcdb342bea16d..44ec00d2561e000c222f4b65555b0b840a92ac3c 100644 --- a/examples/dpo/baichuan2/README.md +++ b/examples/dpo/baichuan2/README.md @@ -1,4 +1,4 @@ -# Baichuan2_13b DPO 训练教程 +# Baichuan2_13b-DPO 训练教程 ## 网络介绍 DPO训练中使用的网络和Mindformers中使用的结构一致。请参考[链接](https://portrait.gitee.com/huanglei_Sorry/mindformers/blob/dev/research/baichuan2/baichuan2.md)获得更详细的介绍内容。 @@ -17,9 +17,9 @@ bash scripts/msrun_launcher.sh \ "mindrlhf/tools/dpo_preprocess.py \ --src /path/to/input.jsonl \ --dst /path/to/output.mindrecord \ ---config model_configs/baichuan_config/process_baichuan2_13b.yaml \ +--config /path/to/model_configs/baichuan_config/process_baichuan2_13b.yaml \ --tokenizer /path/mindrlhf/tokenizers/baichuan/tokenizer.model \ ---seq_len 4097 \ +--seq_len 4096 \ --dataset_type cvalues \ --save_interval 2" \ 8 @@ -79,4 +79,4 @@ bash ./run_baichuan2_predict.sh single \ /path/mindrlhf/examples/dpo/baichuan2/checkpoint_network/rank_0/checkpoint_0.ckpt \ /path/mindrlhf/tokenizers/baichuan/tokenizer.model \ "如何制作毒品?" 
-``` +``` \ No newline at end of file diff --git a/examples/dpo/glm4/README.md b/examples/dpo/glm4/README.md index 8a9ecee71f5bb7feb21d60f570ae24bb3127eade..48ef09cef7dc567cdbf0d3f2d01accc638fac17f 100644 --- a/examples/dpo/glm4/README.md +++ b/examples/dpo/glm4/README.md @@ -1,4 +1,4 @@ -# GLM4-DPO训练 +# GLM4-DPO 训练教程 DPO训练中使用的网络和Mindformers中使用的结构一致。请参考[链接](https://gitee.com/mindspore/mindformers/blob/r1.3.0/docs/model_cards/glm4.md)获得更详细的介绍内容。 @@ -41,13 +41,14 @@ bash scripts/msrun_launcher.sh \ "mindrlhf/tools/dpo_preprocess.py \ --src /path/to/input.jsonl \ --dst /path/to/output.mindrecord \ ---config model_configs/glm_config/process_glm4_9b.yaml \ +--config /path/to/model_configs/glm_config/process_glm4_9b.yaml \ --tokenizer /path/to/tokenizer.model \ --load_checkpoint /path/to/glm4_9b.ckpt \ --auto_trans_ckpt True \ --seq_len 8192 \ ---dataset_type cvalues" \ -8 +--dataset_type cvalues \ +--save_interval 2" \ +8 # 参数说明 src: 原始数据集文件路径 @@ -77,7 +78,7 @@ dst: 输出数据集文件路径 # 请核对传入的checkpoint是否为分布式权重,如果不是将脚本中的auto_trans_ckpt设置为true,自动转换成分布式权重 bash ../../../scripts/msrun_launcher.sh \ "run_dpo.py \ - --config finetune_glm4_9b.yaml \ + --config /path/to/model_configs/glm_config/finetune_glm4_9b.yaml \ --train_dataset /path/to/input.mindrecord \ --vocab_file /path/to/tokenizer.model \ --load_checkpoint /path/to/glm4_9b.ckpt \ diff --git a/examples/dpo/qwen2/README.md b/examples/dpo/qwen2/README.md index 8c4bdfdc769d4db7051104236bc863968d0bef47..7a57f0cb8336a6c9c079b322425f12c9860ab4a6 100644 --- a/examples/dpo/qwen2/README.md +++ b/examples/dpo/qwen2/README.md @@ -1,4 +1,4 @@ -## DPO训练 +## QWEN2-DPO 训练教程 DPO训练中使用的网络和Mindformers中使用的结构一致。请参考[链接](https://gitee.com/mindspore/mindformers/blob/dev/research/qwen2/qwen2.md)获得更详细的介绍内容。 @@ -44,7 +44,7 @@ bash scripts/msrun_launcher.sh \ "mindrlhf/tools/dpo_preprocess.py \ --src /path/to/input.jsonl \ --dst /path/to/output.mindrecord \ ---config model_configs/qwen_config/process_qwen2_7b.yaml \ +--config /path/to/model_configs/qwen_config/process_qwen2_7b.yaml \ --tokenizer /path/to/vocab.json \ --merges_file /path/to/merges.txt \ --seq_len 4097 \ diff --git a/examples/dpo/qwen2_5/README.md b/examples/dpo/qwen2_5/README.md index f3921596f6e030e30013489b6889e7d0a9a91720..9129f04322f2a74e6d03e534c269778380cf2b2a 100644 --- a/examples/dpo/qwen2_5/README.md +++ b/examples/dpo/qwen2_5/README.md @@ -1,4 +1,4 @@ -## DPO训练 +## QWEN2_5-DPO 训练教程 DPO训练中使用的网络和Mindformers中使用的结构一致。请参考[链接](https://gitee.com/mindspore/mindformers/blob/dev/research/qwen2_5/qwen2_5.md)获得更详细的介绍内容。 @@ -42,7 +42,7 @@ bash scripts/msrun_launcher.sh \ "mindrlhf/tools/dpo_preprocess.py \ --src /path/to/input.jsonl \ --dst /path/to/output.mindrecord \ ---config model_configs/qwen_config/process_qwen2_5_7b.yaml \ +--config /path/to/model_configs/qwen_config/process_qwen2_5_7b.yaml \ --tokenizer /path/to/vocab.json \ --merges_file /path/to/merges.txt \ --seq_len 4097 \ diff --git a/mindrlhf/models/baichuan2/baichuan2_13b.py b/mindrlhf/models/baichuan2/baichuan2_13b.py index ea817b03940d3daf5cfb84437b4e75e4638dea2f..e17acca4425c8d9c8795fd0fab0b2bfd16625053 100644 --- a/mindrlhf/models/baichuan2/baichuan2_13b.py +++ b/mindrlhf/models/baichuan2/baichuan2_13b.py @@ -43,10 +43,12 @@ from mindformers.tools.register.register import MindFormerModuleType, MindFormer from mindformers.models.llama.llama_config import LlamaConfig from mindformers.models.llama.llama_layer import LlamaEmbedding, LlamaFeedForward, LlamaRMSNorm from mindformers.tools.logger import logger +from mindformers.tools.utils import 
get_predict_run_mode __all__ = ['Baichuan13BV2ForCausalLM', 'Baichuan13BV2Model', 'Baichuan13BDPO'] mindformers_version = mindformers.__version__ + class Baichuan2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -59,8 +61,7 @@ class Baichuan2PreTrainedModel(PreTrainedModel): @MindFormerRegister.register(MindFormerModuleType.MODELS) class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): - r""" - Provide baichuan2_13B training loss or logits through network. + r"""Provide baichuan2_13B training loss or logits through network. Args: config (LlamaConfig): The config of baichuan2_13B model. @@ -83,7 +84,7 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): >>> from research.baichuan2.baichuan2_13b import Baichuan13BV2ForCausalLM >>> config = LlamaConfig(batch_size=2) >>> network = Baichuan13BV2ForCausalLM(config=config) - """ + """ @lazy_inline def __init__(self, config: LlamaConfig = None): @@ -112,7 +113,6 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): self.lm_head = NormHead(hidden_size=config.hidden_size, vocab_size=config.vocab_size, use_past=config.use_past, - is_dynamic=config.is_dynamic, compute_dtype=config.compute_dtype) vocab_size = config.vocab_size @@ -124,7 +124,11 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): vocab_size, loss_parallel_config.model_parallel) logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1") loss_parallel_config.model_parallel = 1 - self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config) + check_for_nan_in_loss_and_grad = getattr(config, "check_for_nan_in_loss_and_grad", False) + calculate_per_token_loss = getattr(config, "calculate_per_token_loss", False) + self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config, + check_for_nan_in_loss_and_grad=check_for_nan_in_loss_and_grad, + calculate_per_token_loss=calculate_per_token_loss) dp = config.parallel_config.data_parallel if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): @@ -142,11 +146,9 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): else: self.lm_head.set_comm_fusion(config.parallel_config.gradient_aggregation_group) - if config.is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - self.load_checkpoint(config) - self.set_model_predict_config() + self.predict_run_mode = get_predict_run_mode() + # pylint: disable=W0613 def prepare_inputs_for_generation(self, input_ids, **kwargs): if self.config.is_dynamic and "origin_inputs" in kwargs: @@ -157,12 +159,10 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): def set_dynamic_inputs(self, **kwargs): dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32) - dynamic_input_position = Tensor(shape=[None], dtype=mstype.int32) - dynamic_init_reset = Tensor([False], mstype.bool_) dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32) dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) - self.set_inputs(dynamic_input_ids, None, dynamic_input_position, None, None, None, dynamic_init_reset, + self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, dynamic_batch_valid_length, None, None, dynamic_block_tables, dynamic_slot_mapping) logger.info("Set dynamic input for baichuan2.") @@ -171,7 +171,7 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): """Get Baichuan13BV2 
model input tuple for transform ckpt.""" input_ids = Tensor(input_ids, mstype.int32) bs, seq = input_ids.shape[0], input_ids.shape[1] - slot_mapping = Tensor(np.ones(shape=tuple([bs*seq])), mstype.int32) + slot_mapping = Tensor(np.ones(shape=tuple([bs * seq])), mstype.int32) return input_ids, None, None, None, None, None, None, None, None, None, None, slot_mapping def add_flags_custom(self, is_first_iteration): @@ -184,7 +184,7 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): # pylint: disable=W0613 def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None, - input_embeds=None, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None, + input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None, zactivate_len=None, block_tables=None, slot_mapping=None): """Baichuan13BV2ForCausalLM forward.""" bsz, seqlen = self.shape(input_ids) @@ -214,6 +214,10 @@ class Baichuan13BV2ForCausalLM(Baichuan2PreTrainedModel): input_mask = self.mul(input_mask, label_mask) if not self.training: + logits = self.cast(logits, mstype.float32) + if self.predict_run_mode: + logits = self.reshape(logits, (-1, logits.shape[-1])) + return logits return logits, tokens, input_mask if logits.ndim > 2: @@ -296,52 +300,55 @@ class Baichuan13BV2Model(Baichuan2PreTrainedModel): from mindformers.models.utils import set_layer_stage_recompute for layer_id in range(config.num_layers): layer = Baichuan13BDecodeLayer(config.batch_size, - config.seq_length, - layer_id, - dim=config.hidden_size, - n_heads=config.num_heads, - n_kv_heads=config.n_kv_heads, - intermediate_size=config.intermediate_size, - multiple_of=config.multiple_of, - ffn_dim_multiplier=config.ffn_dim_multiplier, - norm_eps=config.rms_norm_eps, - compute_dtype=config.compute_dtype, - layernorm_compute_dtype=config.layernorm_compute_type, - softmax_compute_dtype=config.softmax_compute_type, - param_init_type=config.param_init_type, - use_past=config.use_past, - is_dynamic=config.is_dynamic, - use_flash_attention=self.use_flash_attention, - block_size=self.block_size, - num_blocks=self.num_blocks, - parallel_config=config.parallel_config) + config.seq_length, + layer_id, + dim=config.hidden_size, + n_heads=config.num_heads, + n_kv_heads=config.n_kv_heads, + intermediate_size=config.intermediate_size, + multiple_of=config.multiple_of, + ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + compute_dtype=config.compute_dtype, + layernorm_compute_dtype=config.layernorm_compute_type, + softmax_compute_dtype=config.softmax_compute_type, + param_init_type=config.param_init_type, + use_past=config.use_past, + is_dynamic=config.is_dynamic, + use_flash_attention=self.use_flash_attention, + block_size=self.block_size, + num_blocks=self.num_blocks, + parallel_config=config.parallel_config) set_layer_stage_recompute(layer, layer_id, config.offset, config.parallel_config, config.num_layers) self.layers.append(layer) elif mindformers_version == "r1.3.0": from mindformers.models.utils import LayerSetting - self.layers_setting = LayerSetting(config.offset, config.parallel_config, config.num_layers) + self.layer_setting = LayerSetting(config.num_layers, + config.offset, + config.parallel_config, + config.pp_interleave_num) for layer_id in range(config.num_layers): layer = Baichuan13BDecodeLayer(config.batch_size, - config.seq_length, - layer_id, - dim=config.hidden_size, - n_heads=config.num_heads, - n_kv_heads=config.n_kv_heads, - 
intermediate_size=config.intermediate_size, - multiple_of=config.multiple_of, - ffn_dim_multiplier=config.ffn_dim_multiplier, - norm_eps=config.rms_norm_eps, - compute_dtype=config.compute_dtype, - layernorm_compute_dtype=config.layernorm_compute_type, - softmax_compute_dtype=config.softmax_compute_type, - param_init_type=config.param_init_type, - use_past=config.use_past, - is_dynamic=config.is_dynamic, - use_flash_attention=self.use_flash_attention, - block_size=self.block_size, - num_blocks=self.num_blocks, - parallel_config=config.parallel_config) - self.layers_setting(layer, layer_id) + config.seq_length, + layer_id, + dim=config.hidden_size, + n_heads=config.num_heads, + n_kv_heads=config.n_kv_heads, + intermediate_size=config.intermediate_size, + multiple_of=config.multiple_of, + ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + compute_dtype=config.compute_dtype, + layernorm_compute_dtype=config.layernorm_compute_type, + softmax_compute_dtype=config.softmax_compute_type, + param_init_type=config.param_init_type, + use_past=config.use_past, + is_dynamic=config.is_dynamic, + use_flash_attention=self.use_flash_attention, + block_size=self.block_size, + num_blocks=self.num_blocks, + parallel_config=config.parallel_config) + self.layer_setting(layer, layer_id) self.layers.append(layer) self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps, compute_type=config.layernorm_compute_type) @@ -373,9 +380,6 @@ class Baichuan13BV2Model(Baichuan2PreTrainedModel): self.gather.shard(((1, mp, 1, 1), (1,))) self.norm_out.shard((dp, 1, 1)) - if self.is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - # pylint: disable=W0613 def construct(self, tokens: Tensor, batch_valid_length=None, block_tables=None, slot_mapping=None): """Forward of baichuan2_13b model.""" @@ -385,9 +389,7 @@ class Baichuan13BV2Model(Baichuan2PreTrainedModel): if not self.use_past: mask = self.casual_mask(tokens) # mask: mask: [bs , 1, seq, seq] input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float16) - alibi_tensor = self.slice(self.alibi_tensor, (0, 0, 0, 0), - (1, self.alibi_tensor.shape[1], seq_len, seq_len), (1, 1, 1, 1)) - alibi_tensor = self.mul_alibi(alibi_tensor, self.reshape(input_mask, (bs, 1, -1, 1))) + alibi_tensor = self.mul_alibi(self.alibi_tensor, self.reshape(input_mask, (bs, 1, -1, 1))) else: mask = None if self.is_first_iteration: @@ -398,6 +400,12 @@ class Baichuan13BV2Model(Baichuan2PreTrainedModel): else: alibi_tensor = self.gather(self.alibi_tensor, batch_valid_length, 2) alibi_tensor = self.transpose(alibi_tensor, (2, 1, 0, 3)) + alibi_tensor = self.slice(alibi_tensor, + (0, 0, 0, 0), + (alibi_tensor.shape[0], alibi_tensor.shape[1], alibi_tensor.shape[2], + block_tables.shape[1] * self.block_size), + (1, 1, 1, 1), + ) # tokens: [bs, seq/1] h = self.tok_embeddings(tokens) h = self.reshape(h, (bs, seq_len, self.hidden_size)) @@ -509,8 +517,6 @@ class Baichuan13BAttention(nn.Cell): self.shape = P.Shape() self.reshape = P.Reshape() - if is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) self.transpose = P.Transpose() self.merger_head_transpose = P.Transpose() self.batch_matmul = P.BatchMatMul() @@ -792,8 +798,6 @@ class Baichuan13BDecodeLayer(nn.Cell): self.shape = P.Shape() self.reshape = P.Reshape() - if is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) self.add = P.Add() self.attention_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) 
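Baichuan2-13B uses an ALiBi position bias instead of rotary embeddings: the forward pass above multiplies a precomputed `alibi_tensor` by the padding mask and, during incremental decode, gathers and slices it to match the paged KV cache (`block_tables.shape[1] * self.block_size`). The sketch below is a minimal NumPy illustration of how such a bias is typically built; the head count and sequence length are arbitrary examples, and the model itself relies on mindformers' `build_alibi_tensor_v2` rather than this code.

```python
# Minimal NumPy sketch of an ALiBi-style additive attention bias, for intuition
# only; shapes are illustrative and the real bias comes from build_alibi_tensor_v2.
import numpy as np

def alibi_slopes(n_heads):
    """Geometric per-head slopes (power-of-two head counts, as in the ALiBi paper)."""
    start = 2.0 ** (-8.0 / n_heads)
    return start ** np.arange(1, n_heads + 1)

def build_alibi_bias(seq_len, n_heads):
    """Return a [1, n_heads, seq_len, seq_len] bias: slope * -(distance to each past key)."""
    pos = np.arange(seq_len)
    rel = np.minimum(pos[None, :] - pos[:, None], 0)  # 0 on/above the diagonal, -(i - j) below
    return alibi_slopes(n_heads)[None, :, None, None] * rel[None, None, :, :]

print(build_alibi_bias(seq_len=8, n_heads=4).shape)  # (1, 4, 8, 8)
```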
self.ffn_norm = LlamaRMSNorm(self.hidden_size, norm_eps, compute_type=layernorm_compute_dtype) @@ -881,7 +885,6 @@ class NormHead(nn.Cell): hidden_size, vocab_size, use_past, - is_dynamic=False, compute_dtype=mstype.float32, eps=1e-5): super().__init__() @@ -908,9 +911,6 @@ class NormHead(nn.Cell): self.vocab_size = vocab_size self.assign = P.Assign() - if is_dynamic: - self.reshape.add_prim_attr("skip_redistribution", True) - def construct(self, hidden_states): """Forward process of the NormHead""" out_shape = P.Shape()(hidden_states)[:-1] + (self.vocab_size,) @@ -957,7 +957,7 @@ class DPOLoss(nn.Cell): self.log = P.Log() self.reduce_sum = P.ReduceSum(keep_dims=False) self.slice = P.StridedSlice().shard(((1, 1),)) # ? - self.slice_ind = P.StridedSlice().shard(((1,),)) # ? + self.slice_ind = P.StridedSlice().shard(((1,),)) # ? self.mul = P.Mul().shard(((dp, mp), (dp, mp))) self.sub = P.Sub().shard(((dp, mp), (dp, mp))) self.log_softmax = P.LogSoftmax() @@ -1023,7 +1023,7 @@ class DPOLoss(nn.Cell): # rejected_ref_logps: [bs,] # [bs,] all_logps = self._get_batch_logps(policy_logits, policy_labels, loss_mask) - bs = all_logps.shape[0] // 2 # a sample has two bs responses (chosen and rejected) + bs = all_logps.shape[0] // 2 # a sample has two bs responses (chosen and rejected) policy_chosen_logps = self.slice_ind(all_logps, (0,), (bs,), (1,)) policy_rejected_logps = self.slice_ind(all_logps, (bs,), (2 * bs,), (1,)) losses, chosen_rewards, rejected_rewards = self.dpo_loss( @@ -1046,11 +1046,13 @@ class DPOCrossEntropy(CrossEntropyLoss): self.slice_2d = P.StridedSlice().shard(((dp, mp),)) def construct(self, logits, label, input_mask): - bs, seq_len, vocab_size = logits.shape # a sample has two bs responses (chosen and rejected) + bs, seq_len, vocab_size = logits.shape # a sample has two bs responses (chosen and rejected) policy_chosen_logps = self.slice_3d(logits, (0, 0, 0), (bs // 2, seq_len, vocab_size), (1, 1, 1)) label = self.slice_2d(label, (0, 0), (bs // 2, seq_len), (1, 1)) input_mask = self.slice_2d(input_mask, (0, 0), (bs // 2, seq_len), (1, 1)) - return super().construct(policy_chosen_logps.reshape((-1, policy_chosen_logps.shape[-1])), label.reshape((-1,)), input_mask.reshape((-1,))) + return super().construct(policy_chosen_logps.reshape((-1, policy_chosen_logps.shape[-1])), label.reshape((-1,)), + input_mask.reshape((-1,))) + @MindFormerRegister.register(MindFormerModuleType.LOSS) class DPOLossV2(nn.Cell): @@ -1062,7 +1064,7 @@ class DPOLossV2(nn.Cell): self.log = P.Log() self.reduce_sum = P.ReduceSum(keep_dims=False) self.slice = P.StridedSlice().shard(((1, 1),)) # ? - self.slice_ind = P.StridedSlice().shard(((1,),)) # ? + self.slice_ind = P.StridedSlice().shard(((1,),)) # ? 
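For reference while reading `DPOLoss` and `DPOLossV2`: `_get_batch_logps` reduces `[bs, seq_len, vocab]` logits to one log-probability per sequence (log-softmax, gather at the label ids, mask, then sum or average), and the first half of the batch holds chosen responses while the second half holds rejected ones. The NumPy sketch below mirrors that flow and then applies the standard DPO objective to the two halves; the random inputs, shapes and `beta` value are toy examples, not the MindSpore implementation.

```python
# NumPy sketch of the per-sequence log-prob reduction used by _get_batch_logps
# and of the standard DPO objective on the chosen/rejected halves of the batch.
import numpy as np

def get_batch_logps(logits, labels, loss_mask, average_log_prob=False):
    """logits: [bs, seq, vocab]; labels, loss_mask: [bs, seq] -> one value per sequence."""
    shifted = logits - logits.max(axis=-1, keepdims=True)               # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    per_token = np.take_along_axis(log_probs, labels[..., None], -1).squeeze(-1)
    summed = (per_token * loss_mask).sum(axis=-1)
    return summed / loss_mask.sum(axis=-1) if average_log_prob else summed

rng = np.random.default_rng(0)
logits = rng.standard_normal((4, 6, 10)).astype(np.float32)             # 2 samples * 2 responses
labels = rng.integers(0, 10, size=(4, 6))
mask = np.ones((4, 6), dtype=np.float32)

all_logps = get_batch_logps(logits, labels, mask)
chosen, rejected = all_logps[:2], all_logps[2:]                         # first half chosen, second rejected
ref_chosen, ref_rejected = rng.standard_normal(2), rng.standard_normal(2)

beta = 0.1                                                              # illustrative temperature
logits_diff = (chosen - ref_chosen) - (rejected - ref_rejected)
dpo_loss = -np.log(1.0 / (1.0 + np.exp(-beta * logits_diff)))           # -logsigmoid, standard DPO
print(dpo_loss.shape)                                                   # (2,)
```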
self.slice_mask = P.StridedSlice().shard(((1, 1),)) self.mul = P.Mul().shard(((dp, mp), (dp, mp))) self.sub = P.Sub().shard(((dp, mp), (dp, mp))) @@ -1097,9 +1099,7 @@ class DPOLossV2(nn.Cell): # [bs, seq_len] -> [bs, seq_len] labels = self.mul(labels, loss_mask) # [bs, seq_len, vocab_size] - print("logits", logits) log_probs = self.log_softmax(logits) - print("log_probs", log_probs) # [bs, seq_len] -> [bs, seq_len, 1] index = self.expand(labels, -1) index = self.cast(index, mstype.int32) @@ -1107,9 +1107,7 @@ class DPOLossV2(nn.Cell): per_token_logps = self.gatherd(log_probs, -1, index) # [bs, seq_len, 1] -> [bs, seq_len] per_token_logps = self.squeeze(per_token_logps) - print("per_token_logps", per_token_logps) if self.average_log_prob: - print("per_token_logps final", self.reduce_sum(per_token_logps * loss_mask, -1), self.reduce_sum(loss_mask, -1)) return self.reduce_sum(per_token_logps * loss_mask, -1) / self.reduce_sum(loss_mask, -1) else: return self.reduce_sum(per_token_logps * loss_mask, -1) @@ -1140,7 +1138,8 @@ class DPOLossV2(nn.Cell): rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps) return losses, chosen_rewards, rejected_rewards, policy_chosen_logps_avg - def construct(self, policy_logits, policy_labels, chosen_loss_mask, rejected_loss_mask, ref_chosen_logps, ref_rejected_logps): + def construct(self, policy_logits, policy_labels, chosen_loss_mask, rejected_loss_mask, ref_chosen_logps, + ref_rejected_logps): # policy_logits: [bs, seq_len, vocab_size] # policy_labels: [bs, seq_len] # loss_mask: [bs, seq_len] @@ -1149,16 +1148,12 @@ class DPOLossV2(nn.Cell): # [bs,] loss_mask = ops.concat((chosen_loss_mask, rejected_loss_mask), axis=0) all_logps = self._get_batch_logps(policy_logits, policy_labels, loss_mask) - bs = all_logps.shape[0] // 2 # a sample has two bs responses (chosen and rejected) + bs = all_logps.shape[0] // 2 # a sample has two bs responses (chosen and rejected) policy_chosen_logps = self.slice_ind(all_logps, (0,), (bs,), (1,)) policy_rejected_logps = self.slice_ind(all_logps, (bs,), (2 * bs,), (1,)) - print("policy_chosen_logps", policy_chosen_logps) - print("policy_rejected_logps", policy_rejected_logps) if self.average_log_prob: ref_chosen_logps = ref_chosen_logps / self.reduce_sum(chosen_loss_mask, -1) ref_rejected_logps = ref_rejected_logps / self.reduce_sum(rejected_loss_mask, -1) - print("ref_chosen_logps", ref_chosen_logps) - print("ref_rejected_logps", ref_rejected_logps) dpo_loss, chosen_rewards, rejected_rewards, policy_chosen_logps_avg = self.dpo_loss( policy_chosen_logps, policy_rejected_logps, @@ -1167,12 +1162,11 @@ class DPOLossV2(nn.Cell): loss_mask ) sft_loss = -policy_chosen_logps_avg - print("sft loss: ", sft_loss) - print("dpo loss: ", dpo_loss) if self.phase == "train": return dpo_loss, sft_loss return dpo_loss, sft_loss, chosen_rewards, rejected_rewards + @MindFormerRegister.register(MindFormerModuleType.MODELS) class Baichuan13BDPO(Baichuan13BV2ForCausalLM): r""" @@ -1234,28 +1228,19 @@ class Baichuan13BDPO(Baichuan13BV2ForCausalLM): loss_parallel_config = copy.deepcopy(config) loss_parallel_config.parallel_config.model_parallel = dp * mp loss_parallel_config.parallel_config.data_parallel = 1 - if dp >= 32 and dp % 8 == 0: # For large scale training + if dp >= 32 and dp % 8 == 0: # For large scale training loss_parallel_config.parallel_config.model_parallel = 8 loss_parallel_config.parallel_config.data_parallel = dp * mp // 8 self.dpo_loss = DPOLossV2(loss_parallel_config) self.alpha = config.alpha self.beta = 
config.beta - if config.parallel_config.vocab_emb_dp or (config.vocab_size % mp != 0): - self.sft_loss = DPOCrossEntropy(parallel_config=config.parallel_config) - else: - loss_parallel_config = copy.deepcopy(config.parallel_config) - loss_parallel_config.model_parallel = dp * mp - loss_parallel_config.data_parallel = 1 - if dp >= 32 and dp % 8 == 0: # For large scale training - loss_parallel_config.model_parallel = 8 - loss_parallel_config.data_parallel = dp * mp // 8 - self.sft_loss = DPOCrossEntropy(parallel_config=loss_parallel_config) + # pylint: disable=W0613 - def construct(self, chosen_input_ids, chosen_labels=None, chosen_loss_mask=None, - chosen_ref_logps=None, rejected_input_ids=None, rejected_labels=None, - rejected_loss_mask=None, rejected_ref_logps=None, + def construct(self, chosen_input_ids, chosen_labels=None, chosen_loss_mask=None, + chosen_ref_logps=None, rejected_input_ids=None, rejected_labels=None, + rejected_loss_mask=None, rejected_ref_logps=None, input_position=None, position_ids=None, input_embeds=None, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None, block_tables=None, slot_mapping=None): @@ -1272,12 +1257,7 @@ class Baichuan13BDPO(Baichuan13BV2ForCausalLM): if self.use_past: if not isinstance(batch_valid_length, Tensor): batch_valid_length = self.ones((bsz,), mstype.int32) - if self.training: - tokens = self.slice(input_ids, (0, 0), (bsz, ori_seqlen - 1), (1, 1)) - chosen_loss_mask = self.slice(chosen_loss_mask, (0, 1), (bsz, ori_seqlen), (1, 1)) - rejected_loss_mask = self.slice(rejected_loss_mask, (0, 1), (bsz, ori_seqlen), (1, 1)) - else: - tokens = input_ids + tokens = input_ids if batch_valid_length is not None: batch_valid_length = self.reshape(batch_valid_length, (-1,)) output = self.model(tokens, batch_valid_length, block_tables, slot_mapping) @@ -1286,26 +1266,12 @@ class Baichuan13BDPO(Baichuan13BV2ForCausalLM): output = self.gather(output, self.sub_batch_valid_len(batch_valid_length, 1), 1) logits = self.lm_head(output) - input_mask = self.cast(self.not_equal(tokens, self.pad_token_id), mstype.float32) - if labels is None: - labels = self.slice(input_ids, (0, 1), (bsz, ori_seqlen), (1, 1)) - else: - if labels.ndim > 1: - if self.training: - labels = self.slice(labels, (0, 1), (bsz, ori_seqlen), (1, 1)) - label_mask = self.cast(self.not_equal(labels, self.ignore_token_id), mstype.float32) - input_mask = self.mul(input_mask, label_mask) - if not self.training: - return logits, tokens, input_mask + return (logits,) if logits.ndim <= 2: logits = self.reshape(logits, (bsz, tokens.shape[1], logits.shape[-1])) - # logits = self.cast(logits, mstype.float32) - # labels = self.reshape(labels, (-1,)) - # input_mask = self.reshape(input_mask, (-1,)) - # loss = self.loss(logits, labels, input_mask) policy_logits = self.cast(logits, mstype.float32) - # sft_loss = self.sft_loss(policy_logits, labels, loss_mask) - dpo_loss, sft_loss = self.dpo_loss(policy_logits, labels, chosen_loss_mask, rejected_loss_mask, chosen_ref_logps.reshape((-1,)), rejected_ref_logps.reshape((-1,))) - return self.alpha * dpo_loss + self.beta * sft_loss + dpo_loss, sft_loss = self.dpo_loss(policy_logits, labels, chosen_loss_mask, rejected_loss_mask, + chosen_ref_logps.reshape((-1,)), rejected_ref_logps.reshape((-1,))) + return dpo_loss + self.alpha * sft_loss diff --git a/mindrlhf/models/baichuan2/baichuan2_7b.py b/mindrlhf/models/baichuan2/baichuan2_7b.py index 33d18666995987edf72e181a7b572ead3da14be7..fcad8b333761c3062c348cb223b8ab6d7359f899 100644 --- 
a/mindrlhf/models/baichuan2/baichuan2_7b.py +++ b/mindrlhf/models/baichuan2/baichuan2_7b.py @@ -1,4 +1,4 @@ -# Copyright 2023 Huawei Technologies Co., Ltd +# Copyright 2024 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,11 +29,9 @@ from mindspore.ops import operations as P from mindspore.parallel._utils import _get_parallel_mode, _is_sharding_propagation from mindspore.common.initializer import initializer, HeUniform -import mindformers from mindformers.core.loss.loss import CrossEntropyLoss from mindformers.models.modeling_utils import PreTrainedModel - -from mindformers.models.utils import lazy_inline +from mindformers.models.utils import lazy_inline, LayerSetting from mindformers.modules.transformer.op_parallel_config import _check_config from mindformers.modules.transformer.transformer import LowerTriangularMaskWithDynamic from mindformers.modules.layers import FreqsMgr @@ -48,7 +46,6 @@ from mindformers.tools.utils import get_use_rope_self_define, get_predict_run_mo __all__ = ['Baichuan7BV2ForCausalLM', 'Baichuan7BV2Model'] -mindformers_version = mindformers.__version__ class Baichuan2PreTrainedModel(PreTrainedModel): """ @@ -120,61 +117,35 @@ class Baichuan7BV2Model(Baichuan2PreTrainedModel): param_init_type=config.param_init_type, parallel_optimizer=True) self.layers = nn.CellList() - if mindformers_version == "r1.2.0": - from mindformers.models.utils import set_layer_stage_recompute - for layer_id in range(config.num_layers): - layer = LLamaDecodeLayer(layer_id, - dim=config.hidden_size, - n_heads=config.num_heads, - n_kv_heads=config.n_kv_heads, - intermediate_size=config.intermediate_size, - multiple_of=config.multiple_of, - ffn_dim_multiplier=config.ffn_dim_multiplier, - norm_eps=config.rms_norm_eps, - qkv_has_bias=config.qkv_has_bias, - qkv_concat=config.qkv_concat, - compute_dtype=config.compute_dtype, - layernorm_compute_dtype=config.layernorm_compute_type, - softmax_compute_dtype=config.softmax_compute_type, - rotary_dtype=config.rotary_dtype, - param_init_type=config.param_init_type, - use_past=config.use_past, - use_flash_attention=self.use_flash_attention, - is_dynamic=config.is_dynamic, - block_size=config.block_size, - num_blocks=config.num_blocks, - use_rope_slice=config.use_rope_slice, - parallel_config=config.parallel_config) - set_layer_stage_recompute(layer, layer_id, config.offset, config.parallel_config, config.num_layers) - self.layers.append(layer) - elif mindformers_version == "r1.3.0": - from mindformers.models.utils import LayerSetting - self.layers_setting = LayerSetting(config.offset, config.parallel_config, config.num_layers) - for layer_id in range(config.num_layers): - layer = LLamaDecodeLayer(layer_id, - dim=config.hidden_size, - n_heads=config.num_heads, - n_kv_heads=config.n_kv_heads, - intermediate_size=config.intermediate_size, - multiple_of=config.multiple_of, - ffn_dim_multiplier=config.ffn_dim_multiplier, - norm_eps=config.rms_norm_eps, - qkv_has_bias=config.qkv_has_bias, - qkv_concat=config.qkv_concat, - compute_dtype=config.compute_dtype, - layernorm_compute_dtype=config.layernorm_compute_type, - softmax_compute_dtype=config.softmax_compute_type, - rotary_dtype=config.rotary_dtype, - param_init_type=config.param_init_type, - use_past=config.use_past, - use_flash_attention=self.use_flash_attention, - is_dynamic=config.is_dynamic, - block_size=config.block_size, - num_blocks=config.num_blocks, - 
use_rope_slice=config.use_rope_slice, - parallel_config=config.parallel_config) - self.layers_setting(layer, layer_id) - self.layers.append(layer) + self.layer_setting = LayerSetting(config.num_layers, + config.offset, + config.parallel_config, + config.pp_interleave_num) + for layer_id in range(config.num_layers): + layer = LLamaDecodeLayer(layer_id, + dim=config.hidden_size, + n_heads=config.num_heads, + n_kv_heads=config.n_kv_heads, + intermediate_size=config.intermediate_size, + multiple_of=config.multiple_of, + ffn_dim_multiplier=config.ffn_dim_multiplier, + norm_eps=config.rms_norm_eps, + qkv_has_bias=config.qkv_has_bias, + qkv_concat=config.qkv_concat, + compute_dtype=config.compute_dtype, + layernorm_compute_dtype=config.layernorm_compute_type, + softmax_compute_dtype=config.softmax_compute_type, + rotary_dtype=config.rotary_dtype, + param_init_type=config.param_init_type, + use_past=config.use_past, + use_flash_attention=self.use_flash_attention, + is_dynamic=config.is_dynamic, + block_size=config.block_size, + num_blocks=config.num_blocks, + use_rope_slice=config.use_rope_slice, + parallel_config=config.parallel_config) + self.layer_setting(layer, layer_id) + self.layers.append(layer) self.norm_out = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps, compute_type=config.layernorm_compute_type) @@ -258,7 +229,6 @@ class NormHead(nn.Cell): Outputs: Tensor of shape :math:`(batch, seq_length, vocab_size)`. """ - def __init__(self, hidden_size, vocab_size, @@ -328,8 +298,7 @@ class NormHead(nn.Cell): @MindFormerRegister.register(MindFormerModuleType.MODELS) class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel): - r""" - Provide baichuan2_7b training loss or logits through network. + r"""Provide baichuan2_7b training loss or logits through network. Args: config (LlamaConfig): The config of baichuan2_7b model. @@ -351,8 +320,7 @@ class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel): Returns: Tensor, the loss or logits of the network. 
- """ - + """ @lazy_inline def __init__(self, config: LlamaConfig = None): super(Baichuan7BV2ForCausalLM, self).__init__(config, auto_prefix=True) @@ -390,7 +358,11 @@ class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel): vocab_size, loss_parallel_config.model_parallel) logger.warning("Now, the model_parallel num of Loss will be changed: mp = 1") loss_parallel_config.model_parallel = 1 - self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config) + check_for_nan_in_loss_and_grad = getattr(config, "check_for_nan_in_loss_and_grad", False) + calculate_per_token_loss = getattr(config, "calculate_per_token_loss", False) + self.loss = CrossEntropyLoss(parallel_config=loss_parallel_config, + check_for_nan_in_loss_and_grad=check_for_nan_in_loss_and_grad, + calculate_per_token_loss=calculate_per_token_loss) dp = config.parallel_config.data_parallel if not (_get_parallel_mode() in (ParallelMode.AUTO_PARALLEL,) and _is_sharding_propagation()): @@ -409,7 +381,6 @@ class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel): self.lm_head.set_comm_fusion(config.parallel_config.gradient_aggregation_group) self.load_checkpoint(config) - self.set_model_predict_config() self.predict_run_mode = get_predict_run_mode() # pylint: disable=W0613 @@ -431,12 +402,10 @@ class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel): def set_dynamic_inputs(self, **kwargs): dynamic_input_ids = Tensor(shape=[None, None], dtype=mstype.int32) - dynamic_input_position = Tensor(shape=[None], dtype=mstype.int32) - dynamic_init_reset = Tensor([False], mstype.bool_) dynamic_batch_valid_length = Tensor(shape=[None, None], dtype=mstype.int32) dynamic_block_tables = Tensor(shape=[None, None], dtype=mstype.int32) dynamic_slot_mapping = Tensor(shape=[None], dtype=mstype.int32) - self.set_inputs(dynamic_input_ids, None, dynamic_input_position, None, None, None, dynamic_init_reset, + self.set_inputs(dynamic_input_ids, None, None, None, None, None, None, dynamic_batch_valid_length, None, None, dynamic_block_tables, dynamic_slot_mapping) logger.info("Set dynamic input for baichuan2.") @@ -450,7 +419,7 @@ class Baichuan7BV2ForCausalLM(Baichuan2PreTrainedModel): # pylint: disable=W0613 def construct(self, input_ids, labels=None, input_position=None, position_ids=None, attention_mask=None, - input_embeds=None, init_reset=True, batch_valid_length=None, batch_index=None, zactivate_len=None, + input_embeds=None, init_reset=None, batch_valid_length=None, batch_index=None, zactivate_len=None, block_tables=None, slot_mapping=None): """Baichuan7BV2 ForCausalLM forward.""" bsz, seqlen = self.shape(input_ids) diff --git a/mindrlhf/models/qwen2/qwen2_tokenizer.py b/mindrlhf/models/qwen2/qwen2_tokenizer.py index 8ffaee39c386435bb6f72f9271b3198555cd7a2c..054b138a4eda1028ad81abfafea3c386ed3e3fca 100644 --- a/mindrlhf/models/qwen2/qwen2_tokenizer.py +++ b/mindrlhf/models/qwen2/qwen2_tokenizer.py @@ -13,6 +13,7 @@ # limitations under the License. 
# ============================================================================ """Tokenization classes for Qwen2.""" +# https://gitee.com/mindspore/mindformers/blob/r1.3.0/research/qwen2/qwen2_tokenizer.py import json import os @@ -26,7 +27,7 @@ from mindspore import log as logger from mindformers.tools.register import MindFormerRegister, MindFormerModuleType from mindformers.models.tokenization_utils import PreTrainedTokenizer from mindformers.models.tokenization_utils_base import AddedToken - +from mindformers.tools.utils import check_file VOCAB_FILES_NAMES = { "vocab_file": "vocab.json", @@ -43,49 +44,11 @@ MAX_MODEL_INPUT_SIZES = {"qwen/qwen-tokenizer": 32768} PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" ENDOFTEXT = "<|endoftext|>" -IMSTART = "<|im_start|>" -IMEND = "<|im_end|>" -REFSTART = "<|object_ref_start|>" -REFEND = "<|object_ref_end|>" -BOXSTART = "<|box_start|>" -BOXEND = "<|box_end|>" -QUADSTART = "<|quad_start|>" -QUADEND = "<|quad_end|>" -VISIONSTART = "<|vision_start|>" -VISIONEND = "<|vision_end|>" -VISIONPAD = "<|vision_pad|>" -IMAGEPAD = "<|image_pad|>" -VIDEOPAD = "<|video_pad|>" -TOOLCALLSTART = "" -TOOLCALLEND = "" -FIMPREFIX = "<|fim_prefix|>" -FIMMIDDLE = "<|fim_middle|>" -FIMSUFFIX = "<|fim_suffix|>" -FIMPAD = "<|fim_pad|>" -REPONAME = "<|repo_name|>" -FILESEP = "<|file_sep|>" +IMSTART = "<|im_start|>" # used in Qwen-72B-chat +IMEND = "<|im_end|>" # used in Qwen-72B-chat ENDOFTEXTID = 151643 IMSTARTID = 151644 IMENDID = 151645 -REFSTARTID = 151646 -REFENDID = 151647 -BOXSTARTID = 151648 -BOXENDID = 151649 -QUADSTARTID = 151650 -QUADENDID = 151651 -VISIONSTARTID = 151652 -VISIONENDID = 151653 -VISIONPADID = 151654 -IMAGEPADID = 151655 -VIDEOPADID = 151656 -TOOLCALLSTARTID = 151657 -TOOLCALLENDID = 151658 -FIMPREFIXID = 151659 -FIMMIDDLEID = 151660 -FIMSUFFIXID = 151661 -FIMPADID = 151662 -REPONAMEID = 151663 -FILESEPID = 151664 @lru_cache() @@ -100,15 +63,15 @@ def bytes_to_unicode(): tables between utf-8 bytes and unicode strings. 
""" bs = ( - list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), - ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), + ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) ) cs = bs[:] n = 0 - for b in range(2**8): + for b in range(2 ** 8): if b not in bs: bs.append(b) - cs.append(2**8 + n) + cs.append(2 ** 8 + n) n += 1 cs = [chr(n) for n in cs] return dict(zip(bs, cs)) @@ -212,90 +175,15 @@ class Qwen2Tokenizer(PreTrainedTokenizer): IMSTART, lstrip=False, rstrip=False, special=True, normalized=False) im_end_token = AddedToken( IMEND, lstrip=False, rstrip=False, special=True, normalized=False) - ref_start_token = AddedToken( - REFSTART, lstrip=False, rstrip=False, special=True, normalized=False) - ref_end_token = AddedToken( - REFEND, lstrip=False, rstrip=False, special=True, normalized=False) - box_start_token = AddedToken( - BOXSTART, lstrip=False, rstrip=False, special=True, normalized=False) - box_end_token = AddedToken( - BOXEND, lstrip=False, rstrip=False, special=True, normalized=False) - quad_start_token = AddedToken( - QUADSTART, lstrip=False, rstrip=False, special=True, normalized=False) - quad_end_token = AddedToken( - QUADEND, lstrip=False, rstrip=False, special=True, normalized=False) - vision_start_token = AddedToken( - VISIONSTART, lstrip=False, rstrip=False, special=True, normalized=False) - vision_end_token = AddedToken( - VISIONEND, lstrip=False, rstrip=False, special=True, normalized=False) - vision_pad_token = AddedToken( - VISIONPAD, lstrip=False, rstrip=False, special=True, normalized=False) - image_pad_token = AddedToken( - IMAGEPAD, lstrip=False, rstrip=False, special=True, normalized=False) - video_pad_token = AddedToken( - VIDEOPAD, lstrip=False, rstrip=False, special=True, normalized=False) - toolcall_start_token = AddedToken( - TOOLCALLSTART, lstrip=False, rstrip=False, special=True, normalized=False) - toolcall_end_token = AddedToken( - TOOLCALLEND, lstrip=False, rstrip=False, special=True, normalized=False) - fim_prefix_token = AddedToken( - FIMPREFIX, lstrip=False, rstrip=False, special=True, normalized=False) - fim_middle_token = AddedToken( - FIMMIDDLE, lstrip=False, rstrip=False, special=True, normalized=False) - fim_suffix_token = AddedToken( - FIMSUFFIX, lstrip=False, rstrip=False, special=True, normalized=False) - fim_pad_token = AddedToken( - FIMPAD, lstrip=False, rstrip=False, special=True, normalized=False) - repo_name_token = AddedToken( - REPONAME, lstrip=False, rstrip=False, special=True, normalized=False) - file_sep_token = AddedToken( - FILESEP, lstrip=False, rstrip=False, special=True, normalized=False) self.special_tokens = { ENDOFTEXT: ENDOFTEXTID, IMSTART: IMSTARTID, IMEND: IMENDID, - REFSTART: REFSTARTID, - REFEND: REFENDID, - BOXSTART: BOXSTARTID, - BOXEND: BOXENDID, - QUADSTART: QUADSTARTID, - QUADEND: QUADENDID, - VISIONSTART: VISIONSTARTID, - VISIONEND: VISIONENDID, - VISIONPAD: VISIONPADID, - IMAGEPAD: IMAGEPADID, - VIDEOPAD: VIDEOPADID, - TOOLCALLSTART: TOOLCALLSTARTID, - TOOLCALLEND: TOOLCALLENDID, - FIMPREFIX: FIMPREFIXID, - FIMMIDDLE: FIMMIDDLEID, - FIMSUFFIX: FIMSUFFIXID, - FIMPAD: FIMPADID, - REPONAME: REPONAMEID, - FILESEP: FILESEPID, } self.end_of_text_id = self.special_tokens[ENDOFTEXT] self.im_start_id = self.special_tokens[IMSTART] self.im_end_id = self.special_tokens[IMEND] - self.ref_start_id = self.special_tokens[REFSTART] - self.ref_end_id = self.special_tokens[REFEND] - self.box_start_id = self.special_tokens[BOXSTART] - self.box_end_id = 
self.special_tokens[BOXEND] - self.quad_start_id = self.special_tokens[QUADSTART] - self.quad_end_id = self.special_tokens[QUADEND] - self.vision_start_id = self.special_tokens[VISIONSTART] - self.vision_end_id = self.special_tokens[VISIONEND] - self.vision_pad_id = self.special_tokens[VISIONPAD] - self.image_pad_id = self.special_tokens[IMAGEPAD] - self.video_pad_id = self.special_tokens[VIDEOPAD] - self.toolcall_start_id = self.special_tokens[TOOLCALLSTART] - self.toolcall_end_id = self.special_tokens[TOOLCALLEND] - self.fim_prefix_id = self.special_tokens[FIMPREFIX] - self.fim_middle_id = self.special_tokens[FIMMIDDLE] - self.fim_suffix_id = self.special_tokens[FIMSUFFIX] - self.fim_pad_id = self.special_tokens[FIMPAD] - self.repo_name_id = self.special_tokens[REPONAME] - self.file_sep_id = self.special_tokens[FILESEP] + check_file(vocab_file, "tokenizer") with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) self.decoder = {v: k for k, v in self.encoder.items()} @@ -324,25 +212,6 @@ class Qwen2Tokenizer(PreTrainedTokenizer): self.end_of_text_id: end_of_text_token, self.im_start_id: im_start_token, self.im_end_id: im_end_token, - self.ref_start_id: ref_start_token, - self.ref_end_id: ref_end_token, - self.box_start_id: box_start_token, - self.box_end_id: box_end_token, - self.quad_start_id: quad_start_token, - self.quad_end_id: quad_end_token, - self.vision_start_id: vision_start_token, - self.vision_end_id: vision_end_token, - self.vision_pad_id: vision_pad_token, - self.image_pad_id: image_pad_token, - self.video_pad_id: video_pad_token, - self.toolcall_start_id: toolcall_start_token, - self.toolcall_end_id: toolcall_end_token, - self.fim_prefix_id: fim_prefix_token, - self.fim_middle_id: fim_middle_token, - self.fim_suffix_id: fim_suffix_token, - self.fim_pad_id: fim_pad_token, - self.repo_name_id: repo_name_token, - self.file_sep_id: file_sep_token, } super().__init__( @@ -445,9 +314,8 @@ class Qwen2Tokenizer(PreTrainedTokenizer): # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer # avoid out of vocab for i in range(len(token_ids)): - valid_token_ids = np.array([x if x <= 151664 else 151664 for x in token_ids[i]]) + valid_token_ids = np.array([x if x <= 151634 else 151634 for x in token_ids[i]]) token_ids[i] = valid_token_ids - return super().decode( token_ids, skip_special_tokens=skip_special_tokens, @@ -464,13 +332,12 @@ class Qwen2Tokenizer(PreTrainedTokenizer): return None vocab_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + - VOCAB_FILES_NAMES["vocab_file"] + VOCAB_FILES_NAMES["vocab_file"] ) merge_file = os.path.join( save_directory, (filename_prefix + "-" if filename_prefix else "") + - VOCAB_FILES_NAMES["merges_file"] + VOCAB_FILES_NAMES["merges_file"] ) - flags_ = os.O_WRONLY | os.O_CREAT | os.O_TRUNC with os.fdopen(os.open(vocab_file, flags_, 0o750), "w", encoding="utf-8") as f: f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") diff --git a/mindrlhf/models/qwen2_5/qwen2_5_tokenizer.py b/mindrlhf/models/qwen2_5/qwen2_5_tokenizer.py index 7148637dfb1a98df25350513e77fa9fbdc582b7d..cee5e297323af8e6c7fe1b46b4e533734cd5efcb 100644 --- a/mindrlhf/models/qwen2_5/qwen2_5_tokenizer.py +++ b/mindrlhf/models/qwen2_5/qwen2_5_tokenizer.py @@ -26,6 +26,7 @@ from mindspore import log as logger from mindformers.tools.register import MindFormerRegister, MindFormerModuleType from mindformers.models.tokenization_utils import 
PreTrainedTokenizer from mindformers.models.tokenization_utils_base import AddedToken +from mindformers.tools.utils import check_file VOCAB_FILES_NAMES = { @@ -296,6 +297,7 @@ class Qwen2_5Tokenizer(PreTrainedTokenizer): self.fim_pad_id = self.special_tokens[FIMPAD] self.repo_name_id = self.special_tokens[REPONAME] self.file_sep_id = self.special_tokens[FILESEP] + check_file(vocab_file, "tokenizer") with open(vocab_file, encoding="utf-8") as vocab_handle: self.encoder = json.load(vocab_handle) self.decoder = {v: k for k, v in self.encoder.items()} @@ -442,12 +444,10 @@ class Qwen2_5Tokenizer(PreTrainedTokenizer): ) -> str: """decode token ids""" # `spaces_between_special_tokens` defaults to True for _decode in slow tokenizers - # and cannot be configured elsewhere, but it should default to False for Qwen2_5Tokenizer - # avoid out of vocab + # and cannot be configured elsewhere, but it should default to False for Qwen2Tokenizer for i in range(len(token_ids)): - valid_token_ids = np.array([x if x <= 151664 else 151664 for x in token_ids[i]]) + valid_token_ids = np.array([x if x <= 151634 else 151634 for x in token_ids[i]]) token_ids[i] = valid_token_ids - return super().decode( token_ids, skip_special_tokens=skip_special_tokens, diff --git a/mindrlhf/tools/dpo_preprocess.py b/mindrlhf/tools/dpo_preprocess.py index 2f31aa0060584d172bd958ce23a2ea25dbbe8823..e26bf5c224eed432be01c3b2350f68c21171bf32 100644 --- a/mindrlhf/tools/dpo_preprocess.py +++ b/mindrlhf/tools/dpo_preprocess.py @@ -1,3 +1,18 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
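Both tokenizer patches keep the `decode()` guard that clamps out-of-vocabulary ids before delegating to the parent class, and add a `check_file` call before reading the vocab. The snippet below sketches only the clamping idea; the ceiling is model-specific (each patched tokenizer uses its own maximum id), so the value here is illustrative.

```python
# Sketch of the "avoid out of vocab" guard: ids above the largest known id are
# clamped so decoding never indexes past the vocabulary. The ceiling below is an
# illustrative example (e.g. the <|endoftext|> id), not the patched value.
import numpy as np

MAX_KNOWN_ID = 151643

def clamp_token_ids(batched_ids):
    """Clamp every id in each sequence to the largest id the tokenizer knows."""
    return [np.minimum(np.asarray(ids), MAX_KNOWN_ID) for ids in batched_ids]

print(clamp_token_ids([[10, 151643, 999999]]))   # ids above the ceiling become 151643
```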
+# ============================================================================ + import argparse import numpy as np from tqdm import tqdm @@ -7,7 +22,6 @@ import os import mindspore as ms from mindformers import AutoModel from mindformers.tools.utils import str2bool -from examples.rlhf_train_tutorial.rlhf_data import write_mindrecord from mindrlhf.models.qwen2.qwen2_tokenizer import Qwen2Tokenizer from mindrlhf.models.qwen2_5.qwen2_5_tokenizer import Qwen2_5Tokenizer from mindrlhf.models.baichuan2.baichuan2_tokenizer import Baichuan2Tokenizer @@ -89,7 +103,7 @@ def get_logps(model_name, model, input_ids, labels, attention_mask, loss_mask): if len(loss_mask.shape) == 1: loss_mask = ms.ops.unsqueeze(loss_mask, 0) - if model_name in ['qwen2_7b', 'baichuan2_13b', 'qwen2_5_7b']: + if model_name in ['qwen2_7b', 'qwen2_5_7b']: input_ids = P.StridedSlice()(input_ids, (0, 0), (input_ids.shape[0], min(batch_length, input_ids.shape[1] - 1)), (1, 1)) labels = P.StridedSlice()(labels, (0, 1), (labels.shape[0], min(batch_length, labels.shape[1])), (1, 1)) @@ -212,9 +226,13 @@ def preprocess(data_path: str, dst_file: str, config_path: str, tokenizer_path: def _build(prompt_ids, resp_ids): # check input_ids > seq_length input_ids = prompt_ids + resp_ids - labels = input_ids[:] attention_mask = [1] * len(input_ids) - loss_mask = [0] * len(prompt_ids) + [1] * len(resp_ids) + if model_name in ["glm4_9b"]: + labels = input_ids[1:] + [tokenizer.pad_token_id] + loss_mask = [0] * len(prompt_ids) + [1] * (len(resp_ids) - 1) + [0] + else: + labels = input_ids[:] + loss_mask = [0] * len(prompt_ids) + [1] * len(resp_ids) input_len = len(input_ids) input_ids = input_ids + [tokenizer.pad_token_id] * (seq_len - input_len) diff --git a/model_configs/baichuan_config/predict_baichuan2_13b.yaml b/model_configs/baichuan_config/predict_baichuan2_13b.yaml index dbc2d889b8a2431da9184e6211c3bad79e3417c9..aaae2516deb11439436e38de6f429bbfb5cdcec0 100644 --- a/model_configs/baichuan_config/predict_baichuan2_13b.yaml +++ b/model_configs/baichuan_config/predict_baichuan2_13b.yaml @@ -91,7 +91,7 @@ parallel_config: pipeline_stage: 1 use_seq_parallel: False micro_batch_num: 1 - vocab_emb_dp: True + vocab_emb_dp: False gradient_aggregation_group: 4 # when model parallel is greater than 1, we can set micro_batch_interleave_num=2, that may accelerate the train process. 
micro_batch_interleave_num: 1 @@ -150,10 +150,10 @@ model: use_flash_attention: False block_size: 16 num_blocks: 512 - is_dynamic: True + is_dynamic: False extend_method: "None" # support "None", "PI", "NTK" offset: 0 - checkpoint_name_or_path: "/path/Baichuan2_13B_Base.ckpt" + checkpoint_name_or_path: "path/to/baichuan2-13B-Chat.ckpt" repetition_penalty: 1 temperature: 1.0 max_decode_length: 512 diff --git a/model_configs/baichuan_config/process_baichuan2_13b.yaml b/model_configs/baichuan_config/process_baichuan2_13b.yaml index 710c5e681973bc7b62f64e263da546aa9a93f4dc..499778d65906e0614943e5b299a04c854a79f51f 100644 --- a/model_configs/baichuan_config/process_baichuan2_13b.yaml +++ b/model_configs/baichuan_config/process_baichuan2_13b.yaml @@ -133,7 +133,7 @@ model: batch_size: 1 # add for increase predict seq_length: 4096 hidden_size: 5120 - num_layers: 1 + num_layers: 40 num_heads: 40 vocab_size: 125696 multiple_of: 128 diff --git a/model_configs/baichuan_config/run_baichuan2_13b_dpo.yaml b/model_configs/baichuan_config/run_baichuan2_13b_dpo.yaml index bf09786adfaf38205e2a52926d2db08dd296b923..80aac067fd2161a1a34f138113fb40ecee03093c 100644 --- a/model_configs/baichuan_config/run_baichuan2_13b_dpo.yaml +++ b/model_configs/baichuan_config/run_baichuan2_13b_dpo.yaml @@ -1,6 +1,6 @@ seed: 0 output_dir: './output' # path to save checkpoint/strategy -load_checkpoint: '/home/qianjiahong/ckpt/baichuan/Baichuan2_13B_Base.ckpt' +load_checkpoint: '' src_strategy_path_or_dir: '' auto_trans_ckpt: False # If true, auto transform load_checkpoint to load in distributed model only_save_strategy: False @@ -29,11 +29,11 @@ optimizer: beta2: 0.95 eps: 1.e-8 -# lr sechdule +# lr schedule lr_schedule: type: CosineWithWarmUpLR - learning_rate: 0 # pretrain:3.e-4 - lr_end: 0 # pretrain:3.e-5 + learning_rate: 2.e-5 # pretrain:3.e-4 + lr_end: 1.e-6 # pretrain:3.e-5 warmup_ratio: 0.03 total_steps: -1 # -1 means it will load the total steps of the dataset @@ -158,15 +158,15 @@ model: num_blocks: 512 is_dynamic: False offset: 0 - checkpoint_name_or_path: "" #/home/qianjiahong/0808/ckpt/baichuan/Baichuan2_13B_Base.ckpt + checkpoint_name_or_path: "" repetition_penalty: 1 temperature: 1.0 max_decode_length: 512 top_k: 3 top_p: 1 do_sample: False - alpha: 0.0 # coef for dpo loss - beta: 1.0 # coef for sft loss + alpha: 0.1 # coef for sft loss + beta: 1.0 # temperature of dpo loss in logsigmoid function arch: type: Baichuan13BDPO diff --git a/model_configs/glm_config/finetune_glm4_9b.yaml b/model_configs/glm_config/finetune_glm4_9b.yaml index 1b8dd178befb940a76327ec09b117f865aad65b4..c022b26d686346def8140f993c60e96def0ba160 100644 --- a/model_configs/glm_config/finetune_glm4_9b.yaml +++ b/model_configs/glm_config/finetune_glm4_9b.yaml @@ -14,7 +14,7 @@ context: enable_graph_kernel: False graph_kernel_flags: "--disable_expand_ops=Softmax,Dropout --enable_parallel_fusion=true --reduce_fuse_depth=8 --enable_auto_tensor_inplace=true" max_call_depth: 10000 - max_device_memory: "57GB" + max_device_memory: "59GB" save_graphs: False save_graphs_path: "./graph" device_id: 0 @@ -158,7 +158,7 @@ eval_dataset_task: # ==== runner config ==== runner_config: - epochs: 4 + epochs: 1 batch_size: 1 sink_mode: True sink_size: 1 @@ -209,8 +209,8 @@ parallel: strategy_ckpt_config: save_file: "./ckpt_strategy.ckpt" parallel_config: - data_parallel: 8 - model_parallel: 1 + data_parallel: 1 + model_parallel: 8 pipeline_stage: 1 expert_parallel: 1 micro_batch_num: 1 diff --git a/model_configs/glm_config/predict_glm4_9b.yaml 
b/model_configs/glm_config/predict_glm4_9b.yaml index 43c9dae27465711e2d3267c0fa90c428736bf6a8..7ffebb4b7da8ef6c4f43c275f8f8f2edc912068a 100644 --- a/model_configs/glm_config/predict_glm4_9b.yaml +++ b/model_configs/glm_config/predict_glm4_9b.yaml @@ -14,7 +14,7 @@ context: enable_graph_kernel: False graph_kernel_flags: "--disable_expand_ops=Softmax,Dropout --enable_parallel_fusion=true --reduce_fuse_depth=8 --enable_auto_tensor_inplace=true" max_call_depth: 10000 - max_device_memory: "57GB" + max_device_memory: "59GB" save_graphs: False save_graphs_path: "./graph" device_id: 0 @@ -135,7 +135,7 @@ parallel: strategy_ckpt_config: save_file: "./ckpt_strategy.ckpt" parallel_config: - data_parallel: 8 + data_parallel: 1 model_parallel: 1 pipeline_stage: 1 expert_parallel: 1 diff --git a/model_configs/qwen_config/finetune_qwen2_5_7b_dpo.yaml b/model_configs/qwen_config/finetune_qwen2_5_7b_dpo.yaml index fc7e4710fa609b35c7d6fc3cbeff35f59d086841..2088fe7c73aabcef38780f032d02105f587846f3 100644 --- a/model_configs/qwen_config/finetune_qwen2_5_7b_dpo.yaml +++ b/model_configs/qwen_config/finetune_qwen2_5_7b_dpo.yaml @@ -19,7 +19,7 @@ eval_epoch_interval: 50 # num of epoch intervals between each eval, 1 mea # runner config runner_config: - epochs: 10 + epochs: 1 batch_size: 1 sink_mode: True sink_size: 1 @@ -102,8 +102,8 @@ parallel: # default parallel of device num = 8 parallel_config: - data_parallel: 8 - model_parallel: 1 + data_parallel: 2 + model_parallel: 4 pipeline_stage: 1 use_seq_parallel: False micro_batch_num: 1 diff --git a/model_configs/qwen_config/finetune_qwen2_7b_dpo.yaml b/model_configs/qwen_config/finetune_qwen2_7b_dpo.yaml index 1b38a6edf7d0119e3d8f1bf519964798e52455d4..a25054cea2ad3636a221b69160ff9f22e846799f 100644 --- a/model_configs/qwen_config/finetune_qwen2_7b_dpo.yaml +++ b/model_configs/qwen_config/finetune_qwen2_7b_dpo.yaml @@ -19,7 +19,7 @@ eval_epoch_interval: 50 # num of epoch intervals between each eval, 1 mea # runner config runner_config: - epochs: 10 + epochs: 1 batch_size: 1 sink_mode: True sink_size: 1 @@ -102,8 +102,8 @@ parallel: # default parallel of device num = 8 parallel_config: - data_parallel: 8 - model_parallel: 1 + data_parallel: 2 + model_parallel: 4 pipeline_stage: 1 use_seq_parallel: False micro_batch_num: 1 diff --git a/model_configs/qwen_config/predict_qwen2_5_7b.yaml b/model_configs/qwen_config/predict_qwen2_5_7b.yaml index dd950f3f9b1aab3f54dc261ead2dd37e200e8c86..bda699590311cdc5d9be75517545bebdb1ae7a4b 100644 --- a/model_configs/qwen_config/predict_qwen2_5_7b.yaml +++ b/model_configs/qwen_config/predict_qwen2_5_7b.yaml @@ -5,7 +5,7 @@ src_strategy_path_or_dir: '' auto_trans_ckpt: False # If true, auto transform load_checkpoint to load in distributed model only_save_strategy: False resume_training: False -use_parallel: False +use_parallel: True run_mode: 'predict' # trainer config @@ -33,7 +33,7 @@ runner_wrapper: # default parallel of device num = 8 for Atlas 800T A2 parallel_config: data_parallel: 1 - model_parallel: 1 + model_parallel: 4 pipeline_stage: 1 micro_batch_num: 1 vocab_emb_dp: False @@ -45,7 +45,7 @@ model: model_config: type: LlamaConfig batch_size: 1 - seq_length: 32768 + seq_length: 4096 hidden_size: 3584 num_layers: 28 num_heads: 28 @@ -77,7 +77,7 @@ model: top_k: 20 top_p: 0.8 temperature: 0.7 - do_sample: True + do_sample: False is_dynamic: True qkv_concat: False auto_map: diff --git a/model_configs/qwen_config/predict_qwen2_7b.yaml b/model_configs/qwen_config/predict_qwen2_7b.yaml index 
5a70a11c34b5e4cdd3ae6691429a36f9170c608d..43a67c0686e06a29483ff50f10b45888388761e3 100644 --- a/model_configs/qwen_config/predict_qwen2_7b.yaml +++ b/model_configs/qwen_config/predict_qwen2_7b.yaml @@ -5,7 +5,7 @@ src_strategy_path_or_dir: '' auto_trans_ckpt: False # If true, auto transform load_checkpoint to load in distributed model only_save_strategy: False resume_training: False -use_parallel: False +use_parallel: True run_mode: 'predict' # trainer config @@ -74,8 +74,8 @@ callbacks: # default parallel of device num = 8 for Atlas 800T A2 parallel_config: - data_parallel: 8 - model_parallel: 1 + data_parallel: 1 + model_parallel: 4 pipeline_stage: 1 micro_batch_num: 1 vocab_emb_dp: True @@ -115,7 +115,7 @@ model: softmax_compute_type: "float16" rotary_dtype: "float16" param_init_type: "float32" - use_past: False + use_past: True extend_method: "None" # support "None", "PI", "NTK" use_flash_attention: True fine_grain_interleave: 1 diff --git a/run_dpo.py b/run_dpo.py index 292c1b984add067e704389d6c578264c0a548d06..5e87a2c724e5ab303da5ee1ff276e3fba1f410a0 100644 --- a/run_dpo.py +++ b/run_dpo.py @@ -25,8 +25,6 @@ from mindformers.tools.utils import str2bool from mindformers.tools.logger import logger from mindformers.tools.cloud_adapter import cloud_monitor from mindformers.core.context import build_context -from mindformers.tools import get_output_root_path - from mindrlhf.models.baichuan2.baichuan2_13b import Baichuan13BDPO from mindrlhf.models.baichuan2.baichuan2_tokenizer import Baichuan2Tokenizer @@ -38,36 +36,11 @@ from mindrlhf.models.glm4.glm_dpo import Glm4DPO from mindrlhf.models.glm4.glm4_tokenizer import ChatGLM4Tokenizer from mindrlhf import DPODataset -def clear_auto_trans_output(config): - """clear transformed_checkpoint and strategy""" - strategy_dir = os.path.join(get_output_root_path(), "strategy") - if os.path.exists(strategy_dir) and config.local_rank == 0: - shutil.rmtree(strategy_dir) - os.makedirs(strategy_dir, exist_ok=True) - transformed_ckpt_dir = os.path.join(get_output_root_path(), "transformed_checkpoint") - if os.path.exists(transformed_ckpt_dir) and config.local_rank == 0: - shutil.rmtree(transformed_ckpt_dir) - os.makedirs(transformed_ckpt_dir, exist_ok=True) - - -def context_init(use_parallel=False, optimizer_parallel=False, device_id=0): - """init context for mindspore.""" - context_config = ContextConfig(mode=0, device_target="Ascend", device_id=device_id) - parallel_config = None - if use_parallel: - parallel_config = ParallelContextConfig(parallel_mode='SEMI_AUTO_PARALLEL', - gradients_mean=False, - enable_parallel_optimizer=optimizer_parallel, - full_batch=True) - init_context(use_parallel=use_parallel, - context_config=context_config, - parallel_config=parallel_config) - @cloud_monitor() def main(task='text_generation', config='run_baichuan2_7b.yaml', - run_mode='train', + run_mode=None, seq_length=None, mode=None, use_parallel=None, @@ -82,23 +55,22 @@ def main(task='text_generation', max_length=512, remote_save_url=None, vocab_file=None, - data_parallel=None, - model_parallel=None, - pipeline_stage=None, - micro_batch_num=None): + merges_file=None, + batch_size=None): """main function.""" assert os.path.exists(config) and config.endswith(('.yaml', '.yml')) # init config config = MindFormerConfig(os.path.realpath(config)) - run_mode = config.run_mode if seq_length is not None: config.model.model_config.seq_length = seq_length if mode is not None: config.context.mode = mode if mode: config.recompute_config.recompute = False + if run_mode is not None: 
+ config.run_mode = run_mode if use_parallel is not None: config.use_parallel = use_parallel if device_id is not None: @@ -113,14 +85,10 @@ def main(task='text_generation', config.remote_save_url = remote_save_url if vocab_file is not None: config.processor.tokenizer.vocab_file = vocab_file - if data_parallel is not None: - config.parallel_config.data_parallel = data_parallel - if model_parallel is not None: - config.parallel_config.model_parallel = model_parallel - if pipeline_stage is not None: - config.parallel_config.pipeline_stage = pipeline_stage - if micro_batch_num is not None: - config.parallel_config.micro_batch_num = micro_batch_num + if merges_file is not None: + config.processor.tokenizer.merges_file = merges_file + if batch_size is not None: + config.runner_config.batch_size = batch_size # init context build_context(config) @@ -147,11 +115,15 @@ def main(task='text_generation', elif run_mode == 'predict': trainer = Trainer(args=config, task=task) - result = trainer.predict(input_data=predict_data, - predict_checkpoint=ckpt, - auto_trans_ckpt=config.auto_trans_ckpt, - max_length=int(max_length)) - logger.info(result) + batch_input = [[predict_data for _ in range(config.model.model_config.batch_size)]] + for input_prompt in batch_input: + result = trainer.predict(input_data=input_prompt, + predict_checkpoint=ckpt, + auto_trans_ckpt=config.auto_trans_ckpt, + max_length=int(max_length)) + logger.info(result) + else: + raise ValueError(f'run_mode should be one of [train, finetune, eval, predict], but get {config.run_mode}') if __name__ == "__main__": @@ -189,15 +161,11 @@ if __name__ == "__main__": parser.add_argument('--remote_save_url', default='', type=str, help='whether use optimizer parallel. Default: None') parser.add_argument('--vocab_file', default=None, type=str, - help='tokenizer model') - parser.add_argument('--dp', default=None, type=int, - help='data parallel') - parser.add_argument('--mp', default=None, type=int, - help='model parallel') - parser.add_argument('--pp', default=None, type=int, - help='pipeline stage') - parser.add_argument('--micro_batch_num', default=None, type=int, - help='micro batch num') + help='tokenizer model or vocab_file') + parser.add_argument('--merges_file', default=None, type=str, + help='merges_file') + parser.add_argument('--batch_size', default=None, type=str, + help='batch_size') args = parser.parse_args() main(task=args.task, @@ -217,7 +185,5 @@ if __name__ == "__main__": max_length=args.max_length, remote_save_url=args.remote_save_url, vocab_file=args.vocab_file, - data_parallel=args.dp, - model_parallel=args.mp, - pipeline_stage=args.pp, - micro_batch_num=args.micro_batch_num) + merges_file=args.merges_file, + batch_size=args.batch_size) diff --git a/scripts/msrun_launcher.sh b/scripts/msrun_launcher.sh index 2a4ff34656f0e6c3f80661d8fac48caf05dc89c9..e84e4d00bc4cb375cb52d649dc0534750da8c5bc 100644 --- a/scripts/msrun_launcher.sh +++ b/scripts/msrun_launcher.sh @@ -20,7 +20,7 @@ LOCAL_WORKER=8 MASTER_ADDR="127.0.0.1" MASTER_PORT=8118 NODE_RANK=0 -LOG_DIR="output_13b/msrun_log" +LOG_DIR="output/msrun_log" JOIN="False" CLUSTER_TIME_OUT=600 # export HCCL_BUFFSIZE=2 # HCCL memory usage diff --git a/tests/st/test_baichuan2.py b/tests/st/test_baichuan2.py new file mode 100644 index 0000000000000000000000000000000000000000..525a1cf259ac141b98bb49d3c6e7b4f3d540a4a2 --- /dev/null +++ b/tests/st/test_baichuan2.py @@ -0,0 +1,88 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); 
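One behavioural change in `run_dpo.py` worth noting: in predict mode the single `--predict_data` prompt is now repeated to fill the configured batch size before each `trainer.predict` call. A stand-alone sketch of just that expansion step (the Trainer itself is omitted; the prompt text and batch size are arbitrary examples):

```python
# Repeat one prompt so the input list matches the model's batch dimension,
# mirroring the batch_input construction in run_dpo.py's predict branch.
def expand_prompt(prompt, batch_size):
    return [prompt for _ in range(batch_size)]

batch_input = [expand_prompt("an example prompt", 2)]
for input_prompt in batch_input:
    print(input_prompt)        # ['an example prompt', 'an example prompt']
```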
+# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os +import pytest +from mindrlhf.models.baichuan2.baichuan2_13b import Baichuan13BDPO +from mindrlhf.models.baichuan2.baichuan2_tokenizer import Baichuan2Tokenizer +from mindformers.tools.download_tools import download_with_progress_bar + +root_path = os.path.dirname(os.path.abspath(__file__)).split('tests')[0] + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend910b_training +@pytest.mark.env_onecard +class TestBaichuan2DPO: + @staticmethod + def setup_cmd(scripts_cmd, device_nums): + cmd = f"msrun --worker_num={device_nums} " + \ + f"--local_worker_num={device_nums} " + \ + f"--master_port=8118 " + \ + f"--log_dir=msrun_log " + \ + f"--join=True " + \ + f"--cluster_time_out=300 " + \ + f"{scripts_cmd}" + return cmd + + @pytest.mark.run(order=1) + def test_baichuan2_dpo_process(self): + download_with_progress_bar( + "https://www.modelscope.cn/models/baichuan-inc/Baichuan2-13B-Base/resolve/master/tokenizer.model", + f"{root_path}/checkpoint_download/baichuan2/tokenizer.model") + + sh_path = os.path.split(os.path.realpath(__file__))[0] + scripts_path = f"{root_path}/mindrlhf/tools/dpo_preprocess.py" + + scripts_cmd = f"{scripts_path} --src={root_path}/datasets/cvalues/source/one.jsonl " + \ + f"--dst={root_path}/datasets/cvalues/source/baichuan.mindrecord " + \ + f"--config={root_path}/model_configs/baichuan_config/process_baichuan2_13b.yaml " + \ + f"--tokenizer={root_path}/checkpoint_download/baichuan2/tokenizer.model " + \ + f"--seq_len=4096 " + \ + f"--dataset_type=cvalues " + \ + f"--save_interval=2" + ret = os.system(self.setup_cmd(scripts_cmd, 8)) + os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3") + assert ret == 0, "msrun failed, please check msrun_log/worker_*.log" + os.system(f"python {root_path}/mindrlhf/tools/dpo_preprocess.py \ + --merge True --src={root_path}/datasets/cvalues/source/ \ + --dst {root_path}/datasets/cvalues/source/baichuan.mindrecord") + + assert os.path.isfile(f"{root_path}/datasets/cvalues/source/baichuan.mindrecord") + + @pytest.mark.run(order=2) + def test_baichuan2_finetune(self): + sh_path = os.path.split(os.path.realpath(__file__))[0] + scripts_path = f"{root_path}/run_dpo.py" + + scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/baichuan_config/run_baichuan2_13b_dpo.yaml " + \ + f"--train_dataset={root_path}/datasets/cvalues/source/baichuan.mindrecord " + + ret = os.system(self.setup_cmd(scripts_cmd, 8)) + os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3") + assert ret == 0, "msrun failed, please check msrun_log/worker_*.log" + + @pytest.mark.run(order=3) + def test_baichuan2_predict(self): + sh_path = os.path.split(os.path.realpath(__file__))[0] + scripts_path = f"{root_path}/examples/dpo/baichuan2/run_baichuan2_generate.py" + + scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/baichuan_config/predict_baichuan2_13b.yaml " + \ + f"--vocab_file={root_path}/checkpoint_download/baichuan2/tokenizer.model " + \ 
+                      f"--use_parallel " + \
+                      f"--predict_data='hello world' "
+        ret = os.system(self.setup_cmd(scripts_cmd, 8))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
diff --git a/tests/st/test_glm4.py b/tests/st/test_glm4.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c4d2733430ac861404b32a9bfee3796829c977
--- /dev/null
+++ b/tests/st/test_glm4.py
@@ -0,0 +1,86 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import pytest
+from mindrlhf.models.glm4.glm_dpo import Glm4DPO
+from mindrlhf.models.glm4.glm4_tokenizer import ChatGLM4Tokenizer
+from mindformers.tools.download_tools import download_with_progress_bar
+
+root_path = os.path.dirname(os.path.abspath(__file__)).split('tests')[0]
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+class TestGlm4DPO:
+    @staticmethod
+    def setup_cmd(scripts_cmd, device_nums):
+        cmd = f"msrun --worker_num={device_nums} " + \
+              f"--local_worker_num={device_nums} " + \
+              f"--master_port=8118 " + \
+              f"--log_dir=msrun_log " + \
+              f"--join=True " + \
+              f"--cluster_time_out=300 " + \
+              f"{scripts_cmd}"
+        return cmd
+
+    @pytest.mark.run(order=1)
+    def test_glm4_dpo_process(self):
+        download_with_progress_bar("https://www.modelscope.cn/models/ZhipuAI/glm-4-9b/resolve/master/tokenizer.model",
+                                   f"{root_path}/checkpoint_download/glm4/tokenizer.model")
+
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/mindrlhf/tools/dpo_preprocess.py"
+
+        scripts_cmd = f"{scripts_path} --src={root_path}/datasets/cvalues/source/one.jsonl " + \
+                      f"--dst={root_path}/datasets/cvalues/source/glm.mindrecord " + \
+                      f"--config={root_path}/model_configs/glm_config/process_glm4_9b.yaml " + \
+                      f"--tokenizer={root_path}/checkpoint_download/glm4/tokenizer.model " + \
+                      f"--seq_len=8192 " + \
+                      f"--dataset_type=cvalues " + \
+                      f"--save_interval=2"
+        ret = os.system(self.setup_cmd(scripts_cmd, 8))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
+        os.system(f"python {root_path}/mindrlhf/tools/dpo_preprocess.py \
+            --merge True --src={root_path}/datasets/cvalues/source/ \
+            --dst {root_path}/datasets/cvalues/source/glm.mindrecord")
+
+        assert os.path.isfile(f"{root_path}/datasets/cvalues/source/glm.mindrecord")
+
+    @pytest.mark.run(order=2)
+    def test_glm4_finetune(self):
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/run_dpo.py"
+
+        scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/glm_config/finetune_glm4_9b.yaml " + \
+                      f"--train_dataset={root_path}/datasets/cvalues/source/glm.mindrecord "
+
+        ret = os.system(self.setup_cmd(scripts_cmd, 8))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
+
+    @pytest.mark.run(order=3)
+    def test_glm4_predict(self):
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/run_dpo.py"
+
+        scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/glm_config/predict_glm4_9b.yaml " + \
+                      f"--vocab_file={root_path}/checkpoint_download/glm4/tokenizer.model " + \
+                      f"--predict_data='hello world' "
+        ret = os.system(self.setup_cmd(scripts_cmd, 1))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
diff --git a/tests/st/test_qwen2.py b/tests/st/test_qwen2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0323809d31fbd19a515335ca7a18669ea9116f08
--- /dev/null
+++ b/tests/st/test_qwen2.py
@@ -0,0 +1,91 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import pytest
+from mindrlhf.models.qwen2.qwen_dpo import Qwen7BDPO
+from mindrlhf.models.qwen2.qwen2_tokenizer import Qwen2Tokenizer
+from mindformers.tools.download_tools import download_with_progress_bar
+
+root_path = os.path.dirname(os.path.abspath(__file__)).split('tests')[0]
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+class TestQwen2DPO:
+    @staticmethod
+    def setup_cmd(scripts_cmd, device_nums):
+        cmd = f"msrun --worker_num={device_nums} " + \
+              f"--local_worker_num={device_nums} " + \
+              f"--master_port=8118 " + \
+              f"--log_dir=msrun_log " + \
+              f"--join=True " + \
+              f"--cluster_time_out=300 " + \
+              f"{scripts_cmd}"
+        return cmd
+
+    @pytest.mark.run(order=1)
+    def test_qwen2_dpo_process(self):
+        download_with_progress_bar("https://www.modelscope.cn/models/Qwen/Qwen2-7B/resolve/master/vocab.json",
+                                   f"{root_path}/checkpoint_download/qwen2/vocab.json")
+        download_with_progress_bar("https://www.modelscope.cn/models/Qwen/Qwen2-7B/resolve/master/merges.txt",
+                                   f"{root_path}/checkpoint_download/qwen2/merges.txt")
+
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/mindrlhf/tools/dpo_preprocess.py"
+
+        scripts_cmd = f"{scripts_path} --src={root_path}/datasets/cvalues/source/one.jsonl " + \
+                      f"--dst={root_path}/datasets/cvalues/source/qwen.mindrecord " + \
+                      f"--config={root_path}/model_configs/qwen_config/process_qwen2_7b.yaml " + \
+                      f"--tokenizer={root_path}/checkpoint_download/qwen2/vocab.json " + \
+                      f"--merges_file={root_path}/checkpoint_download/qwen2/merges.txt " + \
+                      f"--seq_len=4097 " + \
+                      f"--dataset_type=cvalues " + \
+                      f"--save_interval=2"
+        ret = os.system(self.setup_cmd(scripts_cmd, 8))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
+        os.system(f"python {root_path}/mindrlhf/tools/dpo_preprocess.py \
+            --merge True --src={root_path}/datasets/cvalues/source/ \
+            --dst {root_path}/datasets/cvalues/source/qwen.mindrecord")
+
+        assert os.path.isfile(f"{root_path}/datasets/cvalues/source/qwen.mindrecord")
+
+    @pytest.mark.run(order=2)
+    def test_qwen2_finetune(self):
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/run_dpo.py"
+
+        scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/qwen_config/finetune_qwen2_7b_dpo.yaml " + \
+                      f"--train_dataset={root_path}/datasets/cvalues/source/qwen.mindrecord "
+
+        ret = os.system(self.setup_cmd(scripts_cmd, 8))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
+
+    @pytest.mark.run(order=3)
+    def test_qwen2_predict(self):
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/run_dpo.py"
+
+        scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/qwen_config/predict_qwen2_7b.yaml " + \
+                      f"--vocab_file={root_path}/checkpoint_download/qwen2/vocab.json " + \
+                      f"--merges_file={root_path}/checkpoint_download/qwen2/merges.txt " + \
+                      f"--predict_data='hello world' "
+
+        ret = os.system(self.setup_cmd(scripts_cmd, 4))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
diff --git a/tests/st/test_qwen2_5.py b/tests/st/test_qwen2_5.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d3b3b0edfb19da06ac9c4770764805c9c7a0a93
--- /dev/null
+++ b/tests/st/test_qwen2_5.py
@@ -0,0 +1,92 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import os
+import pytest
+from mindrlhf.models.qwen2_5.qwen_dpo import Qwen2_5_7BDPO
+from mindrlhf.models.qwen2_5.qwen2_5_tokenizer import Qwen2_5Tokenizer
+from mindformers.tools.download_tools import download_with_progress_bar
+
+
+root_path = os.path.dirname(os.path.abspath(__file__)).split('tests')[0]
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend910b_training
+@pytest.mark.env_onecard
+class TestQwen2_5DPO:
+    @staticmethod
+    def setup_cmd(scripts_cmd, device_nums):
+        cmd = f"msrun --worker_num={device_nums} " + \
+              f"--local_worker_num={device_nums} " + \
+              f"--master_port=8118 " + \
+              f"--log_dir=msrun_log " + \
+              f"--join=True " + \
+              f"--cluster_time_out=300 " + \
+              f"{scripts_cmd}"
+        return cmd
+
+    @pytest.mark.run(order=1)
+    def test_qwen2_5_dpo_process(self):
+        download_with_progress_bar("https://www.modelscope.cn/models/Qwen/Qwen2.5-7B/resolve/master/vocab.json",
+                                   f"{root_path}/checkpoint_download/qwen2_5/vocab.json")
+        download_with_progress_bar("https://www.modelscope.cn/models/Qwen/Qwen2.5-7B/resolve/master/merges.txt",
+                                   f"{root_path}/checkpoint_download/qwen2_5/merges.txt")
+
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/mindrlhf/tools/dpo_preprocess.py"
+
+        scripts_cmd = f"{scripts_path} --src={root_path}/datasets/cvalues/source/one.jsonl " + \
+                      f"--dst={root_path}/datasets/cvalues/source/qwen.mindrecord " + \
+                      f"--config={root_path}/model_configs/qwen_config/process_qwen2_5_7b.yaml " + \
+                      f"--tokenizer={root_path}/checkpoint_download/qwen2_5/vocab.json " + \
+                      f"--merges_file={root_path}/checkpoint_download/qwen2_5/merges.txt " + \
+                      f"--seq_len=4097 " + \
+                      f"--dataset_type=cvalues " + \
+                      f"--save_interval=2"
+        ret = os.system(self.setup_cmd(scripts_cmd, 8))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
+        os.system(f"python {root_path}/mindrlhf/tools/dpo_preprocess.py \
+            --merge True --src={root_path}/datasets/cvalues/source/ \
+            --dst {root_path}/datasets/cvalues/source/qwen.mindrecord")
+
+        assert os.path.isfile(f"{root_path}/datasets/cvalues/source/qwen.mindrecord")
+
+    @pytest.mark.run(order=2)
+    def test_qwen2_5_finetune(self):
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/run_dpo.py"
+
+        scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/qwen_config/finetune_qwen2_5_7b_dpo.yaml " + \
+                      f"--train_dataset={root_path}/datasets/cvalues/source/qwen.mindrecord "
+
+        ret = os.system(self.setup_cmd(scripts_cmd, 8))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"
+
+    @pytest.mark.run(order=3)
+    def test_qwen2_5_predict(self):
+        sh_path = os.path.split(os.path.realpath(__file__))[0]
+        scripts_path = f"{root_path}/run_dpo.py"
+
+        scripts_cmd = f"{scripts_path} --config={root_path}/model_configs/qwen_config/predict_qwen2_5_7b.yaml " + \
+                      f"--vocab_file={root_path}/checkpoint_download/qwen2_5/vocab.json " + \
+                      f"--merges_file={root_path}/checkpoint_download/qwen2_5/merges.txt " + \
+                      f"--predict_data='hello world' "
+
+        ret = os.system(self.setup_cmd(scripts_cmd, 4))
+        os.system(f"grep -E 'ERROR|error' {sh_path}/msrun_log/worker_0.log -C 3")
+        assert ret == 0, "msrun failed, please check msrun_log/worker_*.log"