diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..b194bcc3de6b5bcbf8b26876e7c1d8c3b4b0b2b6 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -0,0 +1,476 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +from collections import defaultdict +from paddle.fluid import framework +from paddle.fluid import core +from .dist_attribute import TensorDistributedAttribute +from .dist_attribute import OperatorDistributedAttribute +from .dist_tensor import DistributedTensor +from .dist_op import DistributedOperator +from .process_mesh import ProcessMesh + +# There always exists a default context for user. And user can set it to another one. +_g_default_distributed_context = None + + +def get_default_distributed_context(): + global _g_default_distributed_context + if _g_default_distributed_context is None: + dist_context = DistributedContext() + set_default_distributed_context(dist_context) + return _g_default_distributed_context + + +def set_default_distributed_context(dist_context): + global _g_default_distributed_context + _g_default_distributed_context = dist_context + + +class DistributedContext: + """ + DistributedContext is used to collect related distributed information for program and graph. + One auto-parallel run should use its own DistributedContext to avoid interfering other run. 
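+
+    A minimal usage sketch (illustrative only; ``annotated_program`` is a
+    placeholder for a user-provided serial ``Program``):
+
+    .. code-block:: python
+
+        dist_context = DistributedContext(annotated_program)
+        dist_context.init_dist_attr_for_program()
+        dist_context.init_dist_attr_for_graph()
+        assert dist_context.validate_dist_attr_for_program()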
+ """ + + def __init__(self, program=None): + # Program related data members + self._serial_program = program + self._is_initialized_for_program = False + self._dist_tensors_for_program = {} + self._dist_ops_for_program = {} + # Graph related data members + self._is_initialized_for_graph = False + self._serial_graph = None + self._dist_tensors_for_graph = {} + self._dist_ops_for_graph = {} + self._node_id_to_tensor_id = {} + self._node_id_to_op_id = {} + # Other data members + self._dist_op_context = DistributedOperatorContext() + self._process_meshes = [] + + # Distributed programs + self._dist_main_programs = {} + self._dist_startup_programs = {} + + @property + def serial_program(self): + return self._serial_program + + @property + def serial_graph(self): + return self._serial_graph + + @serial_program.setter + def serial_program(self, program): + assert self._serial_program is None, \ + "This distributed context has already been realted to a serial program" + self._serial_program = program + + @property + def process_meshes(self): + return self._process_meshes + + @property + def dist_op_context(self): + return self._dist_op_context + + @property + def dist_main_programs(self): + return self._dist_main_programs + + @property + def dist_startup_programs(self): + return self._dist_startup_programs + + def add_process_mesh(self, process_mesh): + assert isinstance(process_mesh, ProcessMesh), \ + 'The type of dim_mapping must be ProcessMesh.' + if process_mesh not in self.process_meshes: + self._process_meshes.append(process_mesh) + + def add_dist_tensor_for_program(self, dist_tensor): + inner_serial_tensor = dist_tensor.serial_tensor + inner_serial_tensor_id = inner_serial_tensor.desc.id() + self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor + + def add_dist_op_for_program(self, dist_op): + inner_serial_op = dist_op.serial_op + inner_serial_op_id = inner_serial_op.desc.id() + self._dist_ops_for_program[inner_serial_op_id] = dist_op + + def get_dist_tensor_for_program(self, serial_tensor): + serial_tensor_id = serial_tensor.desc.id() + dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) + if dist_tensor: + return dist_tensor + else: + serial_tensor_id = serial_tensor.desc.original_id() + dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, + None) + if dist_tensor: + return dist_tensor + else: + return None + + def get_dist_tensor_for_graph(self, serial_tensor_node): + serial_tensor_node_id = serial_tensor_node.id() + return self._dist_tensors_for_graph.get(serial_tensor_node_id, None) + + def get_dist_op_for_program(self, serial_op): + serial_op_id = serial_op.desc.id() + dist_op = self._dist_ops_for_program.get(serial_op_id, None) + if dist_op: + return dist_op + else: + serial_op_id = serial_op.desc.original_id() + dist_op = self._dist_ops_for_program.get(serial_op_id, None) + if dist_op: + return dist_op + else: + return None + + def del_dist_op_for_program(self, serial_tensor): + serial_tensor_id = serial_tensor.desc.id() + if self._dist_ops_for_program.get(serial_tensor_id, None): + del self._dist_ops_for_program[serial_tensor_id] + + def get_dist_op_for_graph(self, serial_op_node): + serial_op_node_id = serial_op_node.id() + return self._dist_ops_for_graph.get(serial_op_node_id, None) + + def get_tensor_dist_attr_for_program(self, serial_tensor): + serial_tensor_id = serial_tensor.desc.id() + dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None) + if dist_tensor: + return dist_tensor.dist_attr + else: + 
serial_tensor_id = serial_tensor.desc.original_id() + dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, + None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + + def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr): + dist_tensor = DistributedTensor(serial_tensor, dist_attr) + self.add_dist_tensor_for_program(dist_tensor) + + def get_tensor_dist_attr_for_graph(self, serial_tensor_node): + serial_tensor_node_id = serial_tensor_node.id() + dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id, + None) + if dist_tensor: + return dist_tensor.dist_attr + else: + return None + + def get_op_dist_attr_for_program(self, serial_op): + serial_op_id = serial_op.desc.id() + dist_op = self._dist_ops_for_program.get(serial_op_id, None) + if dist_op: + return dist_op.dist_attr + else: + serial_op_id = serial_op.desc.original_id() + dist_op = self._dist_ops_for_program.get(serial_op_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + + def set_op_dist_attr_for_program(self, serial_op, dist_attr): + dist_op = DistributedOperator(serial_op, dist_attr) + self.add_dist_op_for_program(dist_op) + + def get_op_dist_attr_for_graph(self, serial_op_node): + serial_op_node_id = serial_op_node.id() + dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None) + if dist_op: + return dist_op.dist_attr + else: + return None + + def init_dist_attr_for_program(self): + assert self._serial_program, \ + "Please set the program of this context before initializing its distribute attributes." + if self._is_initialized_for_program: + return + # Copy the dist tensors and dist ops annotated by users from the default context + default_ctx = get_default_distributed_context() + self._process_meshes = copy.deepcopy(default_ctx.process_meshes) + for block in self._serial_program.blocks: + for tensor in block.vars.values(): + # Copy the distributed tensors in the default context + default_dist_tensor = default_ctx.get_dist_tensor_for_program( + tensor) + if default_dist_tensor and default_ctx is not self: + self.add_dist_tensor_for_program(default_dist_tensor) + current_dist_tensor = self.get_dist_tensor_for_program(tensor) + if current_dist_tensor is None: + dist_tensor = DistributedTensor(tensor) + self.add_dist_tensor_for_program(dist_tensor) + for op in block.ops: + # Copy the distributed operators in the default context + default_dist_op = default_ctx.get_dist_op_for_program(op) + if default_dist_op and default_ctx is not self: + self.add_dist_op_for_program(default_dist_op) + current_dist_op = self.get_dist_op_for_program(op) + if current_dist_op is None: + dist_op = DistributedOperator(op) + self.add_dist_op_for_program(dist_op) + self._is_initialized_for_program = True + + def init_dist_attr_for_graph(self): + assert self._is_initialized_for_program, \ + "The program must be initialized before initializing the distributed attributes for its graph." 
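+        # Build an IrGraph from the serial program and attach a copy of each
+        # tensor/op dist_attr to its graph node, matching graph nodes back to
+        # program descs via original_desc_id().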
+ if self._is_initialized_for_graph: + return + # Convert program to graph + self._serial_graph = framework.IrGraph( + core.Graph(self._serial_program.desc)) + all_nodes = self._serial_graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + dist_tensor = None + tensor_id = node.node.original_desc_id() + for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items( + ): + if tensor_id == cur_tensor_id \ + or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id(): + dist_tensor = cur_dist_tensor + self._node_id_to_tensor_id[node.id()] = cur_tensor_id + assert dist_tensor is not None, \ + "Tensor must have a distributed tensor after the initialization for program." + serial_tensor_node_id = node.id() + new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, + dist_tensor.dist_attr) + self._dist_tensors_for_graph[ + serial_tensor_node_id] = new_dist_tensor + if node.is_op() and node.op() is not None: + dist_op = None + op_id = node.node.original_desc_id() + for cur_op_id, cur_dist_op in self._dist_ops_for_program.items( + ): + if op_id == cur_op_id \ + or op_id == cur_dist_op.serial_op.desc.original_id(): + dist_op = cur_dist_op + self._node_id_to_op_id[node.id()] = cur_op_id + assert dist_op is not None, \ + "Operator must have a distributed operator after the initialization for program." + serial_op_node_id = node.id() + new_dist_op = DistributedOperator(dist_op.serial_op, + dist_op.dist_attr) + self._dist_ops_for_graph[serial_op_node_id] = new_dist_op + self._is_initialized_for_graph = True + + def clear_dist_info_for_program(self): + self._dist_tensors_for_program.clear() + self._dist_ops_for_program.clear() + + def clear_dist_info_for_graph(self): + self._dist_tensors_for_graph.clear() + self._dist_ops_for_graph.clear() + + def copy_dist_attr_from_graph_to_program(self): + assert self._is_initialized_for_program and self._is_initialized_for_graph, \ + "Both program and graph must be initialized." + updated_tensors = {} + all_nodes = self._serial_graph.all_nodes() + for node in all_nodes: + if node.is_var() and node.var() is not None: + tensor_id = self._node_id_to_tensor_id[node.id()] + updated = updated_tensors.get(tensor_id, False) + # If a var has multiples var nodes in graph, only use the first one for now + if not updated: + tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph( + node) + dist_tensor_for_program = self._dist_tensors_for_program[ + tensor_id] + dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph + updated_tensors[tensor_id] = True + if node.is_op() and node.op() is not None: + op_id = self._node_id_to_op_id[node.id()] + op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node) + dist_op_for_program = self._dist_ops_for_program[op_id] + dist_op_for_program.dist_attr = op_dist_attr_for_graph + + def amend_dist_attr_for_program(self): + for dist_tensor in self._dist_tensors_for_program.values(): + serial_tensor = dist_tensor.serial_tensor + dist_attr = dist_tensor.dist_attr + if serial_tensor.type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = serial_tensor.shape + dims_mapping = dist_attr.dims_mapping + process_mesh_shape = dist_attr.process_mesh.topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) 
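+            # e.g. a tensor dimension of size 2 mapped to a mesh dimension of
+            # size 4 cannot be evenly sharded, so its dims_mapping entry is
+            # reset to -1 (replicated).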
+ for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 + + for dist_op in self._dist_ops_for_program.values(): + serial_op = dist_op.serial_op + dist_attr = dist_op.dist_attr + for arg_name in serial_op.input_arg_names: + if dist_op.get_serial_input(arg_name) is None: + tensor_shape = [] + else: + if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \ + or dist_op.serial_op.type == "create_py_reader": + tensor_shape = [] + else: + tensor_shape = dist_op.get_serial_input(arg_name).shape + dims_mapping = dist_attr.get_input_dims_mapping(arg_name) + process_mesh_shape = dist_attr.process_mesh.topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) + for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 + for arg_name in serial_op.output_arg_names: + if dist_op.get_serial_output( + arg_name).type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = dist_op.get_serial_output(arg_name).shape + dims_mapping = dist_attr.get_output_dims_mapping(arg_name) + process_mesh_shape = dist_attr.process_mesh.topology + # If the dimension of tensor is less than the sharding dimension of process mesh, + # we just amend the dimension mapping to -1. (Is this really OK?) + for i in range(len(tensor_shape)): + if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + dims_mapping[i] = -1 + + def validate_dist_attr_for_program(self): + if not self._is_initialized_for_program: + assert False, \ + "Program must be initialized before validating its distributed attributes" + for block in self.serial_program.blocks: + for tensor in block.vars.values(): + dist_tensor = self.get_dist_tensor_for_program(tensor) + if (dist_tensor is not None) and ( + not dist_tensor.validate_dist_attr()): + assert False, "Tensor {} has a wrong distributed attributes {}.".format( + dist_tensor.serial_tensor.name, dist_tensor.dist_attr) + for op in block.ops: + dist_op = self.get_dist_op_for_program(op) + if (dist_op is not None) and (not dist_op.validate_dist_attr()): + assert False, "Operator {} has a wrong distributed attributes {}.".format( + dist_op.serial_op.type, dist_tensor.dist_attr) + return True + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + + # update dist tensor's dist_context + for key in result._dist_tensors_for_program.keys(): + result._dist_tensors_for_program[key]._dist_context = result + return result + + +class DistributedOperatorContext: + """ + DistributedOperatorContext is used to create a dist op desc in Program. + Every time to create a new dist op, the context should be updated for it accordingly. 
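+
+    A rough sketch of the expected call order (illustrative; the variable
+    names are placeholders):
+
+    .. code-block:: python
+
+        ctx = DistributedOperatorContext()
+        ctx.set_dst_main_program(dist_main_prog)    # partitioned main program
+        ctx.set_dst_startup_program(dist_startup_prog)
+        ctx.set_varname_mapping(varname_mapping)    # serial -> dist var names
+        ctx.set_rank_id(rank)
+        kinputs, koutputs = ctx.prepare_context(serial_op)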
+ """ + + def __init__(self): + self._dst_main_program = None + self._dst_startup_program = None + self._varname_mapping = None + self._rank_id = None + self._cur_src_op = None + self._cur_dist_attr = None + self.grad_op_id_to_op_id = {} + self.already_init_sync_vars = set() + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_dst_main_program" or k == "_dst_startup_program" or k == "_cur_src_op": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + + def set_dst_main_program(self, prog): + self._dst_main_program = prog + + def get_dst_main_program(self): + return self._dst_main_program + + def set_dst_startup_program(self, prog): + self._dst_startup_program = prog + + def get_dst_startup_program(self): + return self._dst_startup_program + + def set_varname_mapping(self, mapping): + self._varname_mapping = mapping + + def get_varname_mapping(self): + return self._varname_mapping + + def set_rank_id(self, rank_id): + self._rank_id = rank_id + + def get_rank_id(self): + return self._rank_id + + def set_cur_src_op(self, cur_src_op): + self._cur_src_op = cur_src_op + + def get_cur_src_op(self): + return self._cur_src_op + + def prepare_context(self, src_op): + + self.set_cur_src_op(src_op) + + # build input varname mapping + kinputs = {} + for input_name in src_op.desc.input_names(): + varnames = [] + for varname in src_op.desc.input(input_name): + assert varname in self._varname_mapping + varnames.append(self._varname_mapping[varname]) + kinputs[input_name] = varnames + + # build output varname mapping + koutputs = {} + for output_name in src_op.desc.output_names(): + varnames = [] + for varname in src_op.desc.output(output_name): + assert varname in self._varname_mapping + varnames.append(self._varname_mapping[varname]) + koutputs[output_name] = varnames + + return kinputs, koutputs diff --git a/python/paddle/distributed/auto_parallel/dist_tensor.py b/python/paddle/distributed/auto_parallel/dist_tensor.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5e3c852699ab6f8dcb92b386989338e5ca3d2c1f 100644 --- a/python/paddle/distributed/auto_parallel/dist_tensor.py +++ b/python/paddle/distributed/auto_parallel/dist_tensor.py @@ -0,0 +1,393 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import copy +import inspect + +import paddle +from paddle.fluid import core +from paddle.fluid.framework import Parameter, Block, Variable +from .dist_attribute import TensorDistributedAttribute +from .dist_attribute import get_tensor_dist_attr_field_keys +from .utils import _linear_idx2coordinate + + +class DistributedTensor: + """ + DistributedTensor represents the distribution of tensor on the process group and + local tensors can be created by DistributedTensor. + Only support even sharding now and uneven sharding will be supported in the future. 
+ Local tensor information can be obtained from the DistributedTensor instance object, + or obtained by the static methods provided by DistributedTensor, + including shard (i.e. the index in the serial tensor), offsets, and sizes. + """ + + @staticmethod + def _validate_sizes_and_dist_attr(sizes, + dims_mapping, + topology, + processes, + rank=None, + shard_sizes=None): + if not (isinstance(sizes, (list, tuple)) and + all(map(lambda x: isinstance(x, int) and x > 0, sizes))): + raise ValueError( + "The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}". + format(sizes)) + if not (isinstance(dims_mapping, (list, tuple)) and all( + map(lambda x: isinstance(x, int) and x >= -1, dims_mapping))): + raise ValueError( + "The dims_mapping must be list or tuple and item in dims_mapping must >= -1, but got {}". + format(dims_mapping)) + if not (isinstance(processes, (list, tuple)) and + all(map(lambda x: isinstance(x, int) and x >= 0, processes))): + raise ValueError( + "The processes must be list or tuple and item in processes must be integer, but got {}". + format(processes)) + if not (isinstance(topology, (list, tuple)) and + all(map(lambda x: isinstance(x, int) and x > 0, topology))): + raise ValueError( + "The topology must be list or tuple and item in topology must be non-negative integer, but got {}". + format(topology)) + if rank is not None and not (isinstance(rank, int) and rank >= 0): + raise ValueError("The rank must >= 0, but got {}".format(rank)) + + # NOTE: Only support even sharding now + if shard_sizes is not None: + raise ValueError("Only support even sharding now.") + + @staticmethod + def get_local_sizes(global_sizes, + dims_mapping, + topology, + processes, + rank=None, + shard_sizes=None): + DistributedTensor._validate_sizes_and_dist_attr( + global_sizes, dims_mapping, topology, processes, rank, shard_sizes) + + local_sizes = [] + # for even sharding, the local sizes of every rank are equal + for idx, item in enumerate(global_sizes): + if dims_mapping[idx] == -1: + local_sizes.append(item) + else: + local_sizes.append(item // topology[dims_mapping[idx]]) + + return local_sizes + + @staticmethod + def get_local_offsets(global_sizes, + dims_mapping, + topology, + processes, + rank, + shard_sizes=None): + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, rank, shard_sizes) + local_offsets = [] + rank_relatvie = processes.index(rank) + coordinate = _linear_idx2coordinate(topology, rank_relatvie) + + for i in range(len(global_sizes)): + if dims_mapping[i] == -1: + local_offsets.append(0) + else: + local_offsets.append(coordinate[dims_mapping[i]] * + local_sizes[i]) + return local_offsets + + @staticmethod + def get_global_sizes(local_sizes, + dims_mapping, + topology, + processes, + rank=None, + shard_sizes=None): + DistributedTensor._validate_sizes_and_dist_attr( + local_sizes, dims_mapping, topology, processes, rank, shard_sizes) + global_sizes = [] + for idx, item in enumerate(local_sizes): + if dims_mapping[idx] == -1: + global_sizes.append(item) + else: + global_sizes.append(item * topology[dims_mapping[idx]]) + return global_sizes + + @staticmethod + def get_local_shard(global_sizes, + dims_mapping, + topology, + processes, + rank, + shard_sizes=None): + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank, shard_sizes) + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, rank, 
shard_sizes) + assert len(local_sizes) == len( + local_offsets + ), "The length of local_sizes must be equal to local_offsets, but got {} and {}.".format( + len(local_sizes), len(local_offsets)) + + local_end_offsets = list( + map(lambda x: x[0] + x[1], zip(local_offsets, local_sizes))) + local_shard = list(zip(local_offsets, local_end_offsets)) + return local_shard + + def __init__(self, serial_tensor, dist_attr=None, dist_context=None): + self._serial_tensor = serial_tensor + self._dist_attr = None + self._batch_dim = 0 + # Reuse the dist_attr setter to initialize _dist_attr + self.dist_attr = dist_attr + self._local_sizes_map = {} + self._local_offsets_map = {} + self._local_shard_map = {} + self._local_tensor_map = {} + + from .dist_context import get_default_distributed_context + self._dist_context = dist_context if dist_context is not None else get_default_distributed_context( + ) + # TODO: Add Automatically to dist_context after initialized and it will be adapted in the future. + # self._dist_context.add_dist_tensor_for_program(self) + + @property + def serial_tensor(self): + return self._serial_tensor + + @property + def dist_attr(self): + return self._dist_attr + + @property + def dist_context(self): + return self._dist_context + + @dist_attr.setter + def dist_attr(self, dist_attr): + if self._dist_attr is None: + self._dist_attr = TensorDistributedAttribute() + self._dist_attr.init(dist_attr) + self._init_default_dist_attr() + + def _init_default_dist_attr(self): + if self._dist_attr.dims_mapping is None: + if self.serial_tensor.type == core.VarDesc.VarType.READER: + tensor_shape = [] + else: + tensor_shape = self._serial_tensor.shape + tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] + self._dist_attr.dims_mapping = tensor_dims_mapping + + def validate_dist_attr(self): + if self.serial_tensor.type == core.VarDesc.VarType.READER: + return True + tensor_shape = self.serial_tensor.shape + if len(tensor_shape) != len(self.dist_attr.dims_mapping): + return False + for i in range(len(self.dist_attr.dims_mapping)): + if self.dist_attr.dims_mapping[ + i] < -1 or self.dist_attr.dims_mapping[i] >= len( + self.dist_attr.process_mesh.topology): + return False + for i in range(len(self.dist_attr.process_mesh.topology)): + if self.dist_attr.dims_mapping.count(i) > 1: + return False + return True + + def local_sizes(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + local_sizes = None + if rank in self._local_sizes_map.keys(): + local_sizes = self._local_sizes_map[rank] + else: + global_sizes = self.serial_tensor.shape + dims_mapping = self.dist_attr.dims_mapping + shard_sizes = self.dist_attr.shard_sizes + processes = self.dist_attr.process_mesh.processes + topology = self.dist_attr.process_mesh.topology + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes, rank, + shard_sizes) + self._local_sizes_map[rank] = local_sizes + + return local_sizes + + def local_offsets(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + local_offsets = None + if rank in self._local_offsets_map.keys(): + local_offsets = self._local_offsets_map[rank] + else: + global_sizes = self.serial_tensor.shape + dims_mapping = self.dist_attr.dims_mapping + shard_sizes = self.dist_attr.shard_sizes + processes = self.dist_attr.process_mesh.processes + topology = self.dist_attr.process_mesh.topology + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, 
rank, + shard_sizes) + self._local_offsets_map[rank] = local_offsets + + return local_offsets + + def global_sizes(self): + return self.serial_tensor.shape + + def local_shard(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + local_shard = None + if rank in self._local_shard_map.keys(): + local_shard = self._local_shard_map[rank] + else: + global_sizes = self.serial_tensor.shape + dims_mapping = self.dist_attr.dims_mapping + shard_sizes = self.dist_attr.shard_sizes + processes = self.dist_attr.process_mesh.processes + topology = self.dist_attr.process_mesh.topology + local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank, + shard_sizes) + self._local_shard_map[rank] = local_shard + + return local_shard + + def new_local_tensor(self, block=None, rank=None, name=None): + """ + Create a new local tensor of serial tensor corresponding to rank. + + Args: + block (Block): The block contains the new tensor. Default value is recommend and it will be created in the block of dist main program corresponding to the serial tensor block id. Default: None. + rank (int): The rank id. Default value is recommend and it will be the current rank. Default: None. + """ + + def _copy_kwargs(serial_tensor): + kwargs = {} + no_need_copy_args = ["self", "block", "shape", "name"] + arg_spec = inspect.getargspec(Variable.__init__) + + for key in arg_spec.args: + # TODO: Check the copied attribute from serial tensor whether valid + if key in no_need_copy_args: + continue + elif key not in kwargs: + if key == "type": + kwargs[key] = serial_tensor.desc.type() + elif key == "dtype": + kwargs[key] = serial_tensor.desc.dtype() + elif key == "lod_level": + kwargs[key] = serial_tensor.desc.lod_level() + elif key == "persistable": + kwargs[key] = serial_tensor.desc.persistable() + elif key == "stop_gradient": + kwargs[key] = serial_tensor.desc.stop_gradient() + elif key == "need_check_feed": + kwargs[key] = serial_tensor.desc.need_check_feed() + # TODO: Get capacity by framework + elif key == "capacity": + continue + else: + kwargs[key] = self.serial_tensor.__dict__[key] + + if isinstance(serial_tensor, Parameter): + kwargs["trainable"] = serial_tensor.trainable + kwargs["optimize_attr"] = serial_tensor.trainable + kwargs["regularizer"] = serial_tensor.regularizer + kwargs["do_model_average"] = serial_tensor.do_model_average + kwargs["need_clip"] = serial_tensor.need_clip + kwargs["is_distributed"] = serial_tensor.is_distributed + kwargs["is_parameter"] = serial_tensor.is_parameter + + return kwargs + + if rank is not None and not (isinstance(rank, int) and rank >= 0): + raise ValueError("The rank must >= 0, but got {}".format(rank)) + if block is not None and not isinstance(block, Block): + raise TypeError("The block must be Block, but got {}.".format( + type(block))) + rank = paddle.distributed.get_rank() if rank is None else rank + + if block is None: + block_id = self.serial_tensor.block.idx + block = self.dist_context.dist_main_programs[rank].block(block_id) + + # copy serial tensor attribute + kwargs = _copy_kwargs(self.serial_tensor) + kwargs["name"] = name + kwargs["shape"] = self.local_sizes(rank) + + if isinstance(self.serial_tensor, Parameter): + kwargs.pop("persistable") + local_tensor = Parameter(block=block, **kwargs) + else: + local_tensor = block.create_var(**kwargs) + + # TODO: Set original id when set original_id is approved + local_tensor.desc.set_original_id(self.serial_tensor.desc.id()) + self._local_tensor_map[rank] = 
local_tensor + return local_tensor + + def local_tensor(self, rank=None): + rank = paddle.distributed.get_rank() if rank is None else rank + assert rank in self._local_tensor_map, "The rank {} local tensor has not been created.".format( + rank) + return self._local_tensor_map[rank] + + def __deepcopy__(self, memo): + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "_serial_tensor" or k == "_local_tensor_map": + setattr(result, k, v) + else: + setattr(result, k, copy.deepcopy(v, memo)) + return result + + def __str__(self): + str = "{{tensor name: {}, tensor id: {}".format( + self.serial_tensor.desc.name(), self.serial_tensor.desc.id()) + + # str += ", {}".format(self.dist_attr) + # return str + + if self.dist_attr.is_annotated("process_mesh"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", process_mesh ({}): {}".format(annotated_str, + self.dist_attr.process_mesh) + + str += ", is_parameter: {}".format(self.serial_tensor.is_parameter) + + if self.dist_attr.is_annotated("dims_mapping"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", dims_mapping ({}): {}".format(annotated_str, + self.dist_attr.dims_mapping) + + if self.dist_attr.is_annotated("shard_mask"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", shard_mask ({}): {}".format(annotated_str, None) + + if self.dist_attr.is_annotated("offload_device"): + annotated_str = "annotated" + else: + annotated_str = "non-annotated" + str += ", offload_device ({}): {} }}".format(annotated_str, None) + return str diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 64c247e56d1d31253a7abb710ceb99b630f4bb8f..b46a10c8c79d895235fcc24fd10aab64fcd90241 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -94,6 +94,7 @@ list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_partitioner_gpt) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_searcher) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard) +list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_dist_tensor) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_serial) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_mppp) list(APPEND MIXED_DIST_TEST_OPS test_auto_parallel_reshard_dpmppp) @@ -262,6 +263,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_partitioner_gpt) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_searcher) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_dist_tensor) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_serial) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_mppp) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_reshard_dpmppp) @@ -649,6 +651,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_auto_parallel_partitioner_gpt MODULES test_auto_parallel_partitioner_gpt ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_searcher MODULES test_auto_parallel_searcher ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard MODULES test_auto_parallel_reshard ENVS ${dist_ENVS}) + py_test_modules(test_auto_parallel_dist_tensor MODULES test_auto_parallel_dist_tensor ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_serial MODULES test_auto_parallel_reshard_serial ENVS 
${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_mppp MODULES test_auto_parallel_reshard_mppp ENVS ${dist_ENVS}) py_test_modules(test_auto_parallel_reshard_dpmppp MODULES test_auto_parallel_reshard_dpmppp ENVS ${dist_ENVS}) diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..b21cbb5ae78bc5eb6e7d9e668ff45bbbbe6e5eed --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_dist_tensor.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import unittest + +import paddle +from paddle.fluid import core +import paddle.distributed.auto_parallel as auto +from paddle.distributed import fleet +from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer +from paddle.distributed.auto_parallel.partitioner import Partitioner +from paddle.distributed.auto_parallel.dist_context import DistributedContext +from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor +from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute +import test_auto_parallel_reshard +from test_auto_parallel_reshard import mlp_forward + + +def get_dist_prog(train_program, + startup_program, + dist_context, + rank_id, + complete_train_program=None): + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + + fleet._user_defined_strategy = fleet.DistributedStrategy() + fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() + parallelizer = AutoParallelizer(fleet) + parallelizer._dist_context = dist_context + + # serial forward & backward completion + complete_train_program = auto.complete_annotation( + train_program, dist_context + ) if complete_train_program is None else complete_train_program + + # parallelizer._apply_serial_forward_pass(complete_train_program, + # startup_program) + + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None) + + # logical partition + partitioner = Partitioner(dist_context, rank_id) + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition( + complete_train_program, startup_program, params_grads) + + partitioned_optimize_ops = parallelizer._apply_optimize( + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) + + return auto_parallel_main_prog, auto_parallel_startup_prog, complete_train_program + + +class TestDistributedTensor(unittest.TestCase): + def test_new_local_tensor(self): + test_auto_parallel_reshard._global_process_mesh = auto.ProcessMesh( + mesh=[0, 1]) + test_auto_parallel_reshard._global_parallel_strategy = "dp" + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_context = DistributedContext() + 
rank_id = 0 + dist_main_prog, dist_startup_prog, complete_train_program = get_dist_prog( + train_program, startup_program, dist_context, rank_id) + dist_context.dist_main_programs[rank_id] = dist_main_prog + dist_context.dist_startup_programs[rank_id] = dist_startup_prog + name = "layer_norm_1.tmp_2" + dist_tensor = dist_context.get_dist_tensor_for_program( + complete_train_program.global_block().vars[name]) + dist_tensor._dist_context = dist_context + intermediate_var_0 = dist_tensor.new_local_tensor( + name="intermediate_var_0") + self.assertEqual(intermediate_var_0.shape, (2, 1024)) + self.assertEqual(intermediate_var_0.name, "intermediate_var_0") + + rank_id = 1 + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + dist_main_prog, dist_startup_prog, _ = get_dist_prog( + train_program, startup_program, dist_context, rank_id, + complete_train_program) + dist_context.dist_main_programs[rank_id] = dist_main_prog + dist_context.dist_startup_programs[rank_id] = dist_startup_prog + name = "layer_norm_1.tmp_2" + dist_tensor = dist_context.get_dist_tensor_for_program( + complete_train_program.global_block().vars[name]) + dist_tensor._dist_context = dist_context + intermediate_var_1 = dist_tensor.new_local_tensor( + rank=rank_id, name="intermediate_var_1") + self.assertEqual(intermediate_var_0.shape, (2, 1024)) + self.assertEqual(intermediate_var_1.name, "intermediate_var_1") + + name = "linear_0.w_0" + dist_tensor = dist_context.get_dist_tensor_for_program( + complete_train_program.global_block().vars[name]) + dist_tensor._dist_context = dist_context + intermediate_var_1 = dist_tensor.new_local_tensor( + rank=rank_id, name="linear_0.w_0_intermediate") + self.assertEqual(intermediate_var_1.shape, (1024, 4096)) + self.assertEqual(intermediate_var_1.name, "linear_0.w_0_intermediate") + + copied_dist_context = copy.deepcopy(dist_context) + self.assertIsNotNone(copied_dist_context) + self.assertEqual( + id(copied_dist_context), + id( + copied_dist_context.get_dist_tensor_for_program( + dist_tensor.serial_tensor).dist_context)) + + def test_static_method(self): + dims_mapping = [1, 0] + processes = [0, 1, 2, 3, 4, 5, 6] + topology = [2, 3] + global_sizes = [6, 6] + + # rank 0 [(0, 2), (0, 3)] + # rank 1 [(2, 4), (0, 3)] + # rank 4 [(2, 4), (3, 6)] + rank = 0 + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_offsets, [0, 0]) + local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_shard, [(0, 2), (0, 3)]) + + rank = 1 + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_offsets, [2, 0]) + local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_shard, [(2, 4), (0, 3)]) + + rank = 4 + local_sizes = DistributedTensor.get_local_sizes( + global_sizes, dims_mapping, topology, processes) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = DistributedTensor.get_local_offsets( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_offsets, [2, 3]) + 
local_shard = DistributedTensor.get_local_shard( + global_sizes, dims_mapping, topology, processes, rank) + self.assertEqual(local_shard, [(2, 4), (3, 6)]) + + # global sizes + local_sizes = [2, 3] + global_sizes = DistributedTensor.get_global_sizes( + local_sizes, dims_mapping, topology, processes) + self.assertEqual(global_sizes, [6, 6]) + + def test_instance_method(self): + tensor_dist_attr = TensorDistributedAttribute() + tensor_dist_attr.dims_mapping = [1, 0] + tensor_dist_attr.process_mesh = auto.ProcessMesh( + mesh=[[0, 1, 2], [3, 4, 5]]) + serial_tensor = paddle.static.data( + name="data", shape=[6, 6], dtype='float32') + dist_tensor = DistributedTensor(serial_tensor, tensor_dist_attr) + + # rank 0 [(0, 2), (0, 3)] + # rank 1 [(2, 4), (0, 3)] + # rank 4 [(2, 4), (3, 6)] + rank = 0 + local_sizes = dist_tensor.local_sizes(rank) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = dist_tensor.local_offsets(rank) + self.assertEqual(local_offsets, [0, 0]) + local_shard = dist_tensor.local_shard(rank) + self.assertEqual(local_shard, [(0, 2), (0, 3)]) + self.assertEqual(local_sizes, dist_tensor.local_sizes(rank)) + self.assertEqual(local_offsets, dist_tensor.local_offsets(rank)) + self.assertEqual(local_shard, dist_tensor.local_shard(rank)) + self.assertEqual(local_sizes, dist_tensor.local_sizes()) + self.assertEqual(local_offsets, dist_tensor.local_offsets()) + self.assertEqual(local_shard, dist_tensor.local_shard()) + + rank = 1 + local_sizes = dist_tensor.local_sizes(rank) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = dist_tensor.local_offsets(rank) + self.assertEqual(local_offsets, [2, 0]) + local_shard = dist_tensor.local_shard(rank) + self.assertEqual(local_shard, [(2, 4), (0, 3)]) + + rank = 4 + local_sizes = dist_tensor.local_sizes(rank) + self.assertEqual(local_sizes, [2, 3]) + local_offsets = dist_tensor.local_offsets(rank) + self.assertEqual(local_offsets, [2, 3]) + local_shard = dist_tensor.local_shard(rank) + self.assertEqual(local_shard, [(2, 4), (3, 6)]) + + global_sizes = dist_tensor.global_sizes() + self.assertEqual(global_sizes, (6, 6)) + + +if __name__ == "__main__": + unittest.main()