diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py
index 5502cb3191a483bb21932375e3c54647495cbc95..c28b7930124dd6bec09716ea3a2c84ca6c4eff30 100644
--- a/python/paddle/distributed/auto_parallel/operators/__init__.py
+++ b/python/paddle/distributed/auto_parallel/operators/__init__.py
@@ -24,3 +24,4 @@ from . import dist_softmax
 from . import dist_transpose
 from . import dist_default
 from . import dist_check_finite_and_unscale
+from . import dist_update_loss_scaling
diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py
index 5685c40a3227b6bfba2d6b0f70395bd20bec4514..8f1ba33f544fb35e2935dcf0d178f6c7e86cdd48 100644
--- a/python/paddle/distributed/auto_parallel/operators/common.py
+++ b/python/paddle/distributed/auto_parallel/operators/common.py
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 
-DISTRIBUTED_OPERATORS = {}
+from ..dist_attribute import OperatorDistributedAttribute
 
+_g_distributed_operator_impl_registries = {}
+BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
 
-class DistributedOperator:
+
+class DistributedOperatorImplContainer:
     def __init__(self):
         self._impls = []
         self._name = None
+ assert False, "Must register distributed operator registry first." def get_distributed_operator_impl(name, impl_idx): - global DISTRIBUTED_OPERATORS - return DISTRIBUTED_OPERATORS[name].get_impl(impl_idx) + global _g_distributed_operator_impl_registries + return _g_distributed_operator_impl_registries[name].get_impl(impl_idx) -def find_best_compatible_distributed_operator_impl(name, op_dist_attr, - fwd=True): +def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True): """ Here just return the first compatible implemention. This will be improved by cost model in the future. """ - dist_op = get_distributed_operator(name) - if dist_op is None: + dist_op_impl_container = get_distributed_operator_impl_container(name) + if dist_op_impl_container is None: return None, -1 compatible_impls = [] - impls = dist_op.get_impls() + impls = dist_op_impl_container.get_impls() if fwd: for idx, impl in enumerate(impls): - if impl.is_process_mesh_compatible(op_dist_attr) \ - and impl.is_input_compatible(op_dist_attr): + if impl.is_input_compatible(dist_op): compatible_impls.append((impl, idx)) else: for idx, impl in enumerate(impls): - if impl.is_process_mesh_compatible(op_dist_attr) \ - and impl.is_output_compatible(op_dist_attr): + if impl.is_output_compatible(dist_op): compatible_impls.append((impl, idx)) if compatible_impls: @@ -118,48 +117,78 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr, return best_compatible_impl, idx -def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var): - """ - copy src var's dist_attr to dst var - """ - import copy - - auto_paralle_context = src_op_dist_attr.get_owner_context() - dist_attr = copy.deepcopy( - auto_paralle_context.get_tensor_distributed_attr_for_program(src_var)) - dist_attr._owner_tensor = var - dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program( - src_var)._owner_context - auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr) - - -def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr): - """ - copy src op's dist_attr to dst dist op - """ - from ..attribute import OperatorDistributedAttribute - - auto_paralle_context = src_op_dist_attr.get_owner_context() - op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context) - auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc, - op_dist_attr) - auto_paralle_context.set_op_distributed_attr_for_program(dist_op, - op_dist_attr) - - op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh()) - op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx()) - - for input_varname in dist_op.desc.input_arg_names(): - input_var = dst_block.var(input_varname) - tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( - input_var) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() - op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping) - - for output_varname in dist_op.desc.output_arg_names(): - output_var = dst_block.var(output_varname) - tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program( - output_var) - tensor_dims_mapping = tensor_dist_attr.get_dims_mapping() - op_dist_attr.set_output_dims_mapping(output_varname, - tensor_dims_mapping) +def is_parameter_related(varname, block): + if ".cast_fp" in varname: + varname = varname[:varname.index(".cast_fp")] + assert block.has_var(varname) + var = block.var(varname) + return var.is_parameter + + +def infer_shape(block, src_var, 
@@ -118,48 +117,78 @@ def find_best_compatible_distributed_operator_impl(name, op_dist_attr,
     return best_compatible_impl, idx
 
 
-def copy_distributed_attr_for_var(src_op_dist_attr, var, src_var):
-    """
-    copy src var's dist_attr to dst var
-    """
-    import copy
-
-    auto_paralle_context = src_op_dist_attr.get_owner_context()
-    dist_attr = copy.deepcopy(
-        auto_paralle_context.get_tensor_distributed_attr_for_program(src_var))
-    dist_attr._owner_tensor = var
-    dist_attr._owner_context = auto_paralle_context.get_tensor_distributed_attr_for_program(
-        src_var)._owner_context
-    auto_paralle_context.set_tensor_distributed_attr_for_program(var, dist_attr)
-
-
-def copy_distributed_attr_for_dist_op(dist_op, dst_block, src_op_dist_attr):
-    """
-    copy src op's dist_attr to dst dist op
-    """
-    from ..attribute import OperatorDistributedAttribute
-
-    auto_paralle_context = src_op_dist_attr.get_owner_context()
-    op_dist_attr = OperatorDistributedAttribute(dist_op, auto_paralle_context)
-    auto_paralle_context._copy_distributed_attr_from_op_desc(dist_op.desc,
-                                                             op_dist_attr)
-    auto_paralle_context.set_op_distributed_attr_for_program(dist_op,
-                                                             op_dist_attr)
-
-    op_dist_attr.set_process_mesh(src_op_dist_attr.get_process_mesh())
-    op_dist_attr.set_impl_idx(src_op_dist_attr.get_impl_idx())
-
-    for input_varname in dist_op.desc.input_arg_names():
-        input_var = dst_block.var(input_varname)
-        tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
-            input_var)
-        tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
-        op_dist_attr.set_input_dims_mapping(input_varname, tensor_dims_mapping)
-
-    for output_varname in dist_op.desc.output_arg_names():
-        output_var = dst_block.var(output_varname)
-        tensor_dist_attr = auto_paralle_context.get_tensor_distributed_attr_for_program(
-            output_var)
-        tensor_dims_mapping = tensor_dist_attr.get_dims_mapping()
-        op_dist_attr.set_output_dims_mapping(output_varname,
-                                             tensor_dims_mapping)
+def is_parameter_related(varname, block):
+    if ".cast_fp" in varname:
+        varname = varname[:varname.index(".cast_fp")]
+    assert block.has_var(varname)
+    var = block.var(varname)
+    return var.is_parameter
+
+
+def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
+    var_shape = block.var(src_var.name).shape
+    var_topology = src_var_dist_attr.process_mesh.topology
+    var_dims_mapping = src_var_dist_attr.dims_mapping
+
+    complete_shape = []
+    for idx, shape in enumerate(var_shape):
+        if var_dims_mapping[idx] == -1:
+            complete_shape.append(shape)
+        else:
+            new_shape = shape * var_topology[var_dims_mapping[idx]]
+            complete_shape.append(new_shape)
+
+    exact_shape = []
+    input_topology = op_input_dist_attr.process_mesh.topology
+    input_dims_mapping = op_input_dist_attr.dims_mapping
+    for idx, shape in enumerate(complete_shape):
+        if input_dims_mapping[idx] == -1:
+            exact_shape.append(shape)
+        else:
+            new_shape = shape // input_topology[input_dims_mapping[idx]]
+            exact_shape.append(new_shape)
+
+    return exact_shape
+
+
+def set_comm_op_dist_attr_for_program(new_op, process_mesh, tensor_dist_attr,
+                                      ctx):
+    assert process_mesh is not None
+    assert tensor_dist_attr is not None
+
+    new_op_dist_attr = OperatorDistributedAttribute()
+    new_op_dist_attr.process_mesh = process_mesh
+    for input_varname in new_op.desc.input_arg_names():
+        new_op_dist_attr.set_input_dist_attr(input_varname, tensor_dist_attr)
+    for output_varname in new_op.desc.output_arg_names():
+        new_op_dist_attr.set_output_dist_attr(output_varname, tensor_dist_attr)
+    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
+
+
+def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx):
+
+    ref_dist_attr = ctx.get_op_dist_attr_for_program(ref_op)
+    new_op_dist_attr = OperatorDistributedAttribute()
+    new_op_dist_attr.process_mesh = ref_dist_attr.process_mesh
+
+    for input_name in ref_op.input_names:
+        assert input_name in new_op.input_names
+        assert len(ref_op.input(input_name)) == 1
+        assert len(new_op.input(input_name)) == 1
+
+        ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr(
+            ref_op.input(input_name)[0])
+        new_op_dist_attr.set_input_dist_attr(
+            new_op.input(input_name)[0], ref_tensor_dist_attr)
+
+    for output_name in ref_op.output_names:
+        assert output_name in new_op.output_names
+        assert len(ref_op.output(output_name)) == 1
+        assert len(new_op.output(output_name)) == 1
+
+        ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr(
+            ref_op.output(output_name)[0])
+        new_op_dist_attr.set_output_dist_attr(
+            new_op.output(output_name)[0], ref_tensor_dist_attr)
+
+    ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
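
The new infer_shape helper first recovers a variable's complete (global) shape by multiplying each sharded dimension of the locally stored shape by the size of the mapped mesh dimension, and then re-shards that complete shape according to the op input's dims_mapping to get the exact local shape the op should see. Below is a minimal sketch of that arithmetic with plain Python lists and hypothetical numbers; the real helper reads these values from the block and the tensor/op distributed attributes.

# Sketch of the infer_shape arithmetic with plain lists (hypothetical values,
# not the real block/dist-attr objects).
def sketch_infer_shape(local_shape, var_dims_mapping, var_topology,
                       input_dims_mapping, input_topology):
    # Local shape -> complete (global) shape of src_var.
    complete_shape = [
        s if var_dims_mapping[i] == -1 else s * var_topology[var_dims_mapping[i]]
        for i, s in enumerate(local_shape)
    ]
    # Complete shape -> exact (local) shape expected for the op input.
    return [
        s if input_dims_mapping[i] == -1 else s // input_topology[input_dims_mapping[i]]
        for i, s in enumerate(complete_shape)
    ]

# Hypothetical numbers: src_var is stored locally as [2, 8], sharded on dim 0
# over a 1-D mesh of 4 processes, so its complete shape is [8, 8]; the op
# input wants dim 1 sharded instead, giving an exact local shape of [8, 2].
assert sketch_infer_shape([2, 8], [0, -1], [4], [-1, 0], [4]) == [8, 2]
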
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..56782bec0856a79e3971037974110d51c84e719f
--- /dev/null
+++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from .common import DistributedOperatorImplContainer
+from .common import DistributedOperatorImpl
+from .common import register_distributed_operator_impl_container
+from .common import register_distributed_operator_impl
+from ..utils import set_dist_op_desc_original_id
+
+
+class DistributedUpdateLossScaling(DistributedOperatorImplContainer):
+    def __init__(self, name):
+        super(DistributedUpdateLossScaling, self).__init__()
+        self._name = name
+
+
+register_distributed_operator_impl_container(
+    "update_loss_scaling", DistributedUpdateLossScaling("update_loss_scaling"))
+
+
+class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
+    def __init__(self, name):
+        super(DistributedUpdateLossScalingImpl, self).__init__()
+        self._name = name
+        self._forward_implemented = False
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        raise RuntimeError(
+            "DistributedUpdateLossScalingImpl's is_input_compatible should not be called !"
+        )
+
+    def is_output_compatible(self, dist_op):
+        raise RuntimeError(
+            "DistributedUpdateLossScalingImpl's is_output_compatible should not be called !"
+        )
+
+    def update_dims_mapping(self, dist_op):
+        raise RuntimeError(
+            "DistributedUpdateLossScalingImpl's update_dims_mapping should not be called !"
+        )
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        raise RuntimeError(
+            "DistributedUpdateLossScalingImpl's forward should not be called !")
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+
+        # the backward function only filters the gradients belonging to the current rank id
+        dist_op_context = ctx.dist_op_context
+        main_block = dist_op_context.get_dst_main_program().global_block()
+        backward_op = dist_op_context.get_cur_src_op()
+        rank_id = dist_op_context.get_rank_id()
+        dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
+        assert dist_attr is not None, "backward op [{}] doesn't have dist attribute !".format(
+            str(backward_op))
+
+        assert rank_id in dist_attr.process_mesh.processes
+
+        assert 'X' in kwargs, "input [{}] is not given".format('X')
+        assert 'FoundInfinite' in kwargs, "input [{}] is not given".format(
+            'FoundInfinite')
+        assert 'PrevLossScaling' in kwargs, "input [{}] is not given".format(
+            'PrevLossScaling')
+        assert 'InGoodSteps' in kwargs, "input [{}] is not given".format(
+            'InGoodSteps')
+        assert 'InBadSteps' in kwargs, "input [{}] is not given".format(
+            'InBadSteps')
+
+        assert 'Out' in kwargs, "output [{}] is not given".format('Out')
+        assert 'LossScaling' in kwargs, "output [{}] is not given".format(
+            'LossScaling')
+        assert 'OutGoodSteps' in kwargs, "output [{}] is not given".format(
+            'OutGoodSteps')
+        assert 'OutBadSteps' in kwargs, "output [{}] is not given".format(
+            'OutBadSteps')
+
+        assert len(kwargs['FoundInfinite']) == 1, \
+            "update_loss_scaling input FoundInfinite takes 1 variable but got {}".format(
+            kwargs['FoundInfinite'])
+        assert len(kwargs['PrevLossScaling']) == 1, \
+            "update_loss_scaling input PrevLossScaling takes 1 variable but got {}".format(
+            kwargs['PrevLossScaling'])
+        assert len(kwargs['InGoodSteps']) == 1, \
+            "update_loss_scaling input InGoodSteps takes 1 variable but got {}".format(
+            kwargs['InGoodSteps'])
+        assert len(kwargs['InBadSteps']) == 1, \
+            "update_loss_scaling input InBadSteps takes 1 variable but got {}".format(
+            kwargs['InBadSteps'])
+        assert len(kwargs['LossScaling']) == 1, \
+            "update_loss_scaling output LossScaling takes 1 variable but got {}".format(
+            kwargs['LossScaling'])
+        assert len(kwargs['OutGoodSteps']) == 1, \
+            "update_loss_scaling output OutGoodSteps takes 1 variable but got {}".format(
+            kwargs['OutGoodSteps'])
+        assert len(kwargs['OutBadSteps']) == 1, \
+            "update_loss_scaling output OutBadSteps takes 1 variable but got {}".format(
+            kwargs['OutBadSteps'])
+
+        assert len(kwargs['X']) == len(kwargs['Out']), \
+            "update_loss_scaling got [{}] X and [{}] Out, which are supposed to be equal".format(
+            len(kwargs['X']), len(kwargs['Out']))
+
+        filter_vars = []
+        for varname in kwargs['X']:
+            if rank_id in ctx.get_tensor_dist_attr_for_program(
+                    main_block.var(varname)).process_mesh.processes:
+                filter_vars.append(varname)
+
+        # replicate op in dist program
+        dist_op_desc = main_block.desc.append_op()
+        dist_op_desc.copy_from(backward_op.desc)
+        set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
+        dist_op_desc.set_input('X', filter_vars)
+        dist_op_desc.set_output('Out', filter_vars)
+        main_block._sync_with_cpp()
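
The backward hook above performs no computation of its own: it re-appends update_loss_scaling into the rank's program and keeps only the gradient arguments whose tensor process mesh contains the current rank. A small sketch of that filtering step, using hypothetical gradient names and rank sets in place of the real tensor dist attrs:

# Hypothetical mapping from gradient name to the ranks of its process mesh;
# in the real pass this comes from ctx.get_tensor_dist_attr_for_program(...).
grad_ranks = {
    "linear_0.w_0@GRAD": {0, 1},  # placed on ranks 0 and 1
    "linear_1.w_0@GRAD": {2, 3},  # placed on ranks 2 and 3
}

def filter_grads_for_rank(grad_names, rank_id):
    # Mirrors the filter_vars loop: keep a gradient only if this rank owns it.
    return [name for name in grad_names if rank_id in grad_ranks[name]]

# Rank 0 keeps only the gradient it actually holds; the replicated
# update_loss_scaling op then gets this filtered list as both 'X' and 'Out'.
assert filter_grads_for_rank(list(grad_ranks), rank_id=0) == ["linear_0.w_0@GRAD"]
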
+
+
+register_distributed_operator_impl(
+    "update_loss_scaling",
+    DistributedUpdateLossScalingImpl("update_loss_scaling"))
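
With the container and its impl registered under the op type, the partitioner can look them up by name, and find_best_compatible_distributed_operator_impl simply returns the first registered impl whose input (forward) or output (backward) check passes. A rough usage sketch, assuming this internal module path is importable in your Paddle build; for update_loss_scaling itself the compatibility checks deliberately raise, so the lookup call is shown for a regular op such as softmax:

# Assumes paddle.distributed.auto_parallel.operators is importable; the names
# below come from common.py in this diff.
from paddle.distributed.auto_parallel.operators.common import (
    get_distributed_operator_impl_container,
    find_best_compatible_distributed_operator_impl)

# The container registered above; it holds the single
# DistributedUpdateLossScalingImpl instance.
container = get_distributed_operator_impl_container("update_loss_scaling")
print(container.get_impls())

# For ops with real compatibility checks (e.g. the container registered by
# dist_softmax), the partitioner would pick an implementation like this:
# impl, impl_idx = find_best_compatible_distributed_operator_impl(
#     "softmax", dist_op, fwd=True)
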