From 609fe679fbd164508b3da1a0f6944ceba6d8184b Mon Sep 17 00:00:00 2001 From: dmahan93 Date: Wed, 4 Sep 2024 10:25:45 -0500 Subject: [PATCH] Add initial ring flash attention support --- megatron/data/data_utils.py | 13 +- megatron/initialize.py | 20 ++- megatron/model/gpt2_model.py | 16 +- megatron/model/positional_embeddings.py | 23 ++- megatron/model/transformer.py | 103 ++++++++++++- megatron/model/utils.py | 10 ++ megatron/monkeypatcher.py | 109 ++++++++++++++ megatron/mpu/__init__.py | 7 + megatron/mpu/data.py | 37 ++++- megatron/mpu/initialize.py | 142 +++++++++++++++--- megatron/mpu/mappings.py | 157 +++++++++++++++++++- megatron/mpu/utils.py | 30 ++++ megatron/neox_arguments/arguments.py | 38 ++++- megatron/neox_arguments/neox_args.py | 20 ++- megatron/training.py | 15 +- megatron/utils.py | 27 ++-- requirements/requirements-ringattention.txt | 1 + 17 files changed, 707 insertions(+), 61 deletions(-) create mode 100644 megatron/monkeypatcher.py create mode 100644 requirements/requirements-ringattention.txt diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index bc5754cdb..d2e2bb0f4 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -311,8 +311,12 @@ def build_train_valid_test_data_iterators(neox_args): else: pipe_load = True - # Data loader only on rank 0 of each model parallel group. - if mpu.get_model_parallel_rank() == 0 and pipe_load: + # Data loader only on rank 0 of each model/sequence parallel group. + if ( + mpu.get_model_parallel_rank() == 0 + and pipe_load + and mpu.get_seq_parallel_rank() == 0 + ): # Number of train/valid/test samples. train_iters = neox_args.train_iters eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters @@ -441,6 +445,11 @@ def build_train_valid_test_data_iterators(neox_args): mpu.get_model_parallel_src_rank(), group=mpu.get_model_parallel_group(), ) + torch.distributed.broadcast( + flags, + mpu.get_seq_parallel_src_rank(), + group=mpu.get_seq_parallel_group(), + ) neox_args.do_train = flags[0].item() neox_args.do_valid = flags[1].item() neox_args.do_test = flags[2].item() diff --git a/megatron/initialize.py b/megatron/initialize.py index 29afe7f9a..8d6c2cd47 100644 --- a/megatron/initialize.py +++ b/megatron/initialize.py @@ -158,16 +158,18 @@ def _initialize_distributed(neox_args): # Setup 3D topology. pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1 mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1 + sp = ( + neox_args.sequence_parallel_size if neox_args.sequence_parallel_size >= 1 else 1 + ) assert ( - neox_args.world_size % (pp * mp) == 0 - ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}" - dp = neox_args.world_size // (pp * mp) - - from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + neox_args.world_size % (pp * mp * sp) == 0 + ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}, sp={sp}" + dp = neox_args.world_size // (pp * mp * sp) + from deepspeed.runtime.pipe.topology import ProcessTopology - # this does pipe on the most outside, then data, then model. - # PipeModelDataParallelTopology is just a wrapper over ProcessTopology that predefines this order. - topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp) + # With 4D parallelism, we have 4 dimensions: pipe, data, model, sequence + # So we need to define it manually... + topo = ProcessTopology(axes=["pipe", "data", "model", "seq"], dims=[pp, dp, mp, sp]) # Offset base seeds for the interior pipeline stages. 
# TODO: adjust last stage too once IO is improved. @@ -186,6 +188,8 @@ def _initialize_distributed(neox_args): else: mpu.initialize_model_parallel( neox_args.model_parallel_size, + neox_args.pipe_parallel_size, + neox_args.sequence_parallel_size, topology=topo, fp32_allreduce=neox_args.fp32_allreduce, ) diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 9e643874a..1d26f97bc 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -74,7 +74,21 @@ def cross_entropy(output, labels, _fp16=False): else: losses = mpu.vocab_parallel_cross_entropy(output.float().contiguous(), labels) loss_mask = loss_mask.view(-1) - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + loss_mask_sum = loss_mask.sum() + if mpu.get_seq_parallel_world_size() > 1: + torch.distributed.all_reduce( + loss_mask_sum, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_seq_parallel_group(), + ) + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum + torch.distributed.all_reduce( + loss, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_seq_parallel_group(), + ) + else: + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask_sum return loss diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index fcded9e96..9335fc487 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -14,6 +14,7 @@ import torch import math +import megatron.mpu as mpu class SinusoidalPositionalEmbedding(torch.nn.Module): @@ -37,7 +38,13 @@ def forward(self, x, seq_dim=1): class RotaryEmbedding(torch.nn.Module): def __init__( - self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False + self, + dim, + max_seq_len, + base=10000, + precision=torch.half, + save_inv_freqs=False, + zigzag=True, ): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) @@ -49,6 +56,7 @@ def __init__( self.max_seq_len = max_seq_len self.base = base self.dim = dim + self.zigzag = zigzag # seq parallel zigzag # precompute cos_cached, sin_cached in fp32 cos_cached, sin_cached, inv_freq = self._prepare_cache( @@ -64,6 +72,19 @@ def _prepare_cache(self, seq_len, precision, base): inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim)) t = torch.arange(seq_len).type_as(inv_freq) + if mpu.get_seq_parallel_world_size() > 1: + if not self.zigzag: + t_chunks = torch.chunk(t, mpu.get_seq_parallel_world_size()) + t = t_chunks[mpu.get_seq_parallel_rank()].contiguous() + else: + t_chunks = torch.chunk(t, 2 * mpu.get_seq_parallel_world_size()) + t = torch.cat( + ( + t_chunks[mpu.get_seq_parallel_rank()], + t_chunks[-(mpu.get_seq_parallel_rank() + 1)], + ), + dim=0, + ).contiguous() freqs = torch.einsum("i,j->ij", t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index c154b09f4..c79e3bb7c 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -464,6 +464,7 @@ def __init__( self.rope_fusion = neox_args.rope_fusion self.attention_type = neox_args.attention_config[layer_number] self.use_flash_attention = self.attention_type == "flash" + self.use_ring_attention = self.attention_type == "ring" self.use_triton = ( self.use_flash_attention and self.pos_emb == "alibi" @@ -472,7 +473,7 @@ def __init__( >= packaging.version.Version("2.4.0.post1") ) ) - self.sparse = self.attention_type not in ("global", "flash") + self.sparse = self.attention_type not in 
("global", "flash", "ring") if self.gqa: assert not self.sparse @@ -501,6 +502,12 @@ def __init__( self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton self.flash_qkv_fn = flash_attn_func self.flash_varlen_qkv_fn = flash_attn_varlen_func + elif self.use_ring_attention: + from ring_flash_attn.zigzag_ring_flash_attn import ( + zigzag_ring_flash_attn_func, + ) + + self.ring_attn_fn = zigzag_ring_flash_attn_func else: self.scale_mask_softmax = FusedScaleMaskSoftmax( input_in_fp16=self.fp16, @@ -748,6 +755,96 @@ def flash_attention(self, query_layer, key_layer, value_layer): return matmul_result + def ring_attention(self, query_layer, key_layer, value_layer): + # [b, np, sq, sk] + output_size = ( + query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0), + ) + + # [sk, b, np, hn] -> [b, sk, np, hn] -> [b * sk, 1, np, hn] + key_layer = key_layer.transpose(0, 1).reshape( + output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 + ) + value_layer = value_layer.transpose(0, 1).reshape( + output_size[0], output_size[3], self.num_kv_heads_per_partition, -1 + ) + + # [sq, b, np, hn] -> [b, sq, np, hn] + query_layer = query_layer.transpose(0, 1).reshape( + output_size[0], output_size[2], output_size[1], -1 + ) + + # only pass in window_size or alibi_slopes kwarg + # if we use Sliding Window Attention / AliBi. + # Flash attn defaults to (-1,-1), or + # does not have this kwarg prior to v2.3.0 + extra_kwargs = ( + {"window_size": (self.sliding_window_width, -1)} + if self.sliding_window_width is not None + else {} + ) + if self.pos_emb == "alibi": + extra_kwargs["alibi_slopes"] = self.alibi_embed.slopes.to( + query_layer.device + ).to(torch.float32) + + if not self.training: + batch_size = output_size[0] + max_seqlen_q = output_size[2] + max_seqlen_k = output_size[3] + + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * max_seqlen_q, + step=max_seqlen_q, + dtype=torch.int32, + device=query_layer.device, + ) + + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * max_seqlen_k, + step=max_seqlen_k, + dtype=torch.int32, + device=key_layer.device, + ) + + q_shape = query_layer.shape + k_shape = key_layer.shape + v_shape = value_layer.shape + is_causal = max_seqlen_q == max_seqlen_k + output = self.ring_attn_fn( + query_layer, + key_layer, + value_layer, + 0.0, + softmax_scale=None, + causal=is_causal, + group=mpu.get_seq_parallel_group(), + **extra_kwargs, + ) + output = output.reshape(q_shape) + else: + output = self.ring_attn_fn( + query_layer, + key_layer, + value_layer, + self.dropout_p if self.training else 0.0, + softmax_scale=None, + causal=True, + group=mpu.get_seq_parallel_group(), + **extra_kwargs, + ) + + matmul_result = output + # [b, sq, np, hn] -> [b, np, sq, hn] + matmul_result = matmul_result.transpose(1, 2) + + return matmul_result + def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask): # TODO: sparse attn dropout? 
# TODO: pad to block size @@ -843,7 +940,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None): value_layer = value_layer.view(*new_kv_shape) # if not using Flash attention, we repeat K/V heads to match Q head counts - if not self.use_flash_attention: + if not (self.use_flash_attention or self.use_ring_attention): key_layer = torch.repeat_interleave( key_layer, repeats=int( @@ -957,6 +1054,8 @@ def forward(self, hidden_states, attention_mask, layer_past=None): if self.use_flash_attention: context_layer = self.flash_attention(query_layer, key_layer, value_layer) + elif self.use_ring_attention: + context_layer = self.ring_attention(query_layer, key_layer, value_layer) elif not self.sparse: context_layer = self.attention( query_layer, key_layer, value_layer, layer_past, attention_mask diff --git a/megatron/model/utils.py b/megatron/model/utils.py index c3da2ce8b..0946d041f 100644 --- a/megatron/model/utils.py +++ b/megatron/model/utils.py @@ -22,6 +22,16 @@ from megatron.model.fused_softmax import SoftmaxFusionTypes from types import GeneratorType import torch.distributed as dist +from megatron.mpu import ( + get_seq_parallel_group, + get_seq_parallel_src_rank, + get_seq_parallel_rank, + get_seq_parallel_world_size, +) +from megatron.mpu.mappings import ( + _GatherFromSeqParallelRegion, + _ScatterToSeqParallelRegion, +) def get_params_for_weight_decay_optimization(module, neox_args): diff --git a/megatron/monkeypatcher.py b/megatron/monkeypatcher.py new file mode 100644 index 000000000..4b36c5aba --- /dev/null +++ b/megatron/monkeypatcher.py @@ -0,0 +1,109 @@ +from deepspeed import comm as dist + +try: + from torch._six import inf +except ModuleNotFoundError: + from torch import inf +import torch +from collections.abc import Iterable +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.utils import clip_tensors_by_global_norm + + +def get_global_norm_of_tensors(input_tensors, norm_type=2, mpu=None): + # Overwrite of https://github.com/EleutherAI/DeeperSpeed/blob/main/deepspeed/runtime/utils.py#L866-L901 + # To support sequence parallel + """Get norm of an iterable of tensors. + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Taken from Nvidia Megatron. + Arguments: + input_tensors (Iterable[Tensor]): an iterable of Tensors will have norm computed + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + Returns: + Total norm of the tensors (viewed as a single vector). 
+ """ + + assert isinstance( + input_tensors, Iterable + ), f"expected Iterable type not {type(input_tensors)}" + assert all( + [torch.is_tensor(t) for t in input_tensors] + ), f"expected list of only tensors" + + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(t.data.abs().max() for t in input_tensors) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + if mpu is not None: + dist.all_reduce( + total_norm_cuda, + op=dist.ReduceOp.MAX, + group=mpu.get_model_parallel_group(), + ) + dist.all_reduce( + total_norm_cuda, + op=dist.ReduceOp.MAX, + group=mpu.get_seq_parallel_group(), + ) + total_norm = total_norm_cuda[0].item() + else: + total_norm = sum( + [t.data.float().norm(norm_type).item() ** norm_type for t in input_tensors] + ) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + if mpu is not None: + dist.all_reduce( + total_norm_cuda, + op=dist.ReduceOp.SUM, + group=mpu.get_model_parallel_group(), + ) + dist.all_reduce( + total_norm_cuda, + op=dist.ReduceOp.SUM, + group=mpu.get_seq_parallel_group(), + ) + total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) + + if ( + total_norm == float("inf") + or total_norm == -float("inf") + or total_norm != total_norm + ): + total_norm = -1 + + return total_norm + + +def replace_engine_get_global_norm(engine): + @torch.no_grad() + def step(closure=None): + # to monkeypatch https://github.com/EleutherAI/DeeperSpeed/blob/main/deepspeed/runtime/bf16_optimizer.py#L233-L253 + self = engine.optimizer + if closure is not None: + raise NotImplementedError(f"{self.__class__} does not support closure.") + + all_groups_norm = get_global_norm_of_tensors( + input_tensors=self.get_grads_for_norm(), + mpu=self.mpu, + norm_type=self.norm_type, + ) + self._global_grad_norm = all_groups_norm + + assert all_groups_norm > 0.0 + if self.clip_grad > 0.0: + clip_tensors_by_global_norm( + input_tensors=self.get_grads_for_norm(for_clipping=True), + max_norm=self.clip_grad, + global_norm=all_groups_norm, + mpu=self.mpu, + ) + + self.optimizer.step() + + self.update_lp_params() + + self.clear_hp_grads() + + engine.optimizer.step = step + return engine diff --git a/megatron/mpu/__init__.py b/megatron/mpu/__init__.py index 2365507d9..5fdc26f6e 100644 --- a/megatron/mpu/__init__.py +++ b/megatron/mpu/__init__.py @@ -54,3 +54,10 @@ from .utils import divide from .utils import split_tensor_along_last_dim +from .data import zigzag_data +from .initialize import ( + get_seq_parallel_group, + get_seq_parallel_rank, + get_seq_parallel_world_size, + get_seq_parallel_src_rank, +) diff --git a/megatron/mpu/data.py b/megatron/mpu/data.py index 87e2a9615..d8b1275ce 100644 --- a/megatron/mpu/data.py +++ b/megatron/mpu/data.py @@ -17,6 +17,10 @@ from .initialize import get_model_parallel_group from .initialize import get_model_parallel_rank from .initialize import get_model_parallel_src_rank +from .initialize import get_seq_parallel_src_rank +from .initialize import get_seq_parallel_group +from .initialize import get_seq_parallel_rank +from .initialize import get_seq_parallel_world_size _MAX_DATA_DIM = 4 @@ -38,7 +42,7 @@ def _build_key_size_numel_dictionaries(keys, data): sizes = [0 for _ in range(max_dim) for _ in keys] # Pack the sizes on rank zero. 
- if get_model_parallel_rank() == 0: + if (get_model_parallel_rank() == 0) and (get_seq_parallel_rank() == 0): offset = 0 for key in keys: assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM" @@ -52,6 +56,9 @@ def _build_key_size_numel_dictionaries(keys, data): torch.distributed.broadcast( sizes_cuda, get_model_parallel_src_rank(), group=get_model_parallel_group() ) + torch.distributed.broadcast( + sizes_cuda, get_seq_parallel_src_rank(), group=get_seq_parallel_group() + ) # Move back to cpu and unpack. sizes_cpu = sizes_cuda.cpu() @@ -76,7 +83,7 @@ def _build_key_size_numel_dictionaries(keys, data): return key_size, key_numel, total_numel -def broadcast_data(keys, data, datatype): +def broadcast_data(keys, data, datatype, zigzag=False): """Broadcast data from rank zero of each model parallel group to the members of the same model parallel group. @@ -91,7 +98,7 @@ def broadcast_data(keys, data, datatype): key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data) # Pack on rank zero. - if get_model_parallel_rank() == 0: + if (get_model_parallel_rank() == 0) and (get_seq_parallel_rank() == 0): # Check that all keys have the same data type. _check_data_types(keys, data, datatype) # Flatten the data associated with the keys @@ -107,6 +114,9 @@ def broadcast_data(keys, data, datatype): torch.distributed.broadcast( flatten_data, get_model_parallel_src_rank(), group=get_model_parallel_group() ) + torch.distributed.broadcast( + flatten_data, get_seq_parallel_src_rank(), group=get_seq_parallel_group() + ) # Unpack output = {} @@ -117,4 +127,23 @@ def broadcast_data(keys, data, datatype): output[key] = flatten_data.narrow(0, offset, numel).view(size) offset += numel - return output + return output if not zigzag else {key: zigzag_data(output[key]) for key in keys} + + +def zigzag_data(data, seq_dim=1): + """Zigzag the data along the seq dimension. + Arguments: + data: data dictionary of string keys and cpu tensor values. + seq_dim: the sequence dimension to zigzag. + """ + worldsize = get_seq_parallel_world_size() + # first check if we can just skip it... + if worldsize == 1: + return data + # otherwise prepare for zigzagging + seq_chunks = torch.chunk(data, 2 * worldsize, dim=seq_dim) + data = [ + torch.cat((seq_chunks[i], seq_chunks[-(i + 1)]), dim=seq_dim) + for i in range(worldsize) + ] + return data[get_seq_parallel_rank()].contiguous() diff --git a/megatron/mpu/initialize.py b/megatron/mpu/initialize.py index 19d231524..946ab6e23 100644 --- a/megatron/mpu/initialize.py +++ b/megatron/mpu/initialize.py @@ -28,6 +28,8 @@ _DATA_PARALLEL_GROUP = None # Pipeline parallel group that the current rank belongs to. _PIPE_PARALLEL_GROUP = None +# Sequence parallel group that the current rank belongs to. +_SEQUENCE_PARALLEL_GROUP = None # A group used to sync during the IO process. Usually this is data_parallel_group(), # but with pipeline parallelism it must also involve the last stage (which is not in the @@ -38,7 +40,7 @@ _MPU_WORLD_SIZE = None _MPU_RANK = None -# Used to query 3D topology +# Used to query 4D topology _MPU_TOPOLOGY = None # Get fp32_allreduce flag @@ -50,7 +52,13 @@ def is_unitialized(): return _DATA_PARALLEL_GROUP is None -def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce=False): +def initialize_model_parallel( + model_parallel_size, + pipe_parallel_size, + sequence_parallel_size, + topology=None, + fp32_allreduce=False, +): """ Initialize model data parallel groups. 
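For reference, the zigzag layout used by zigzag_data above (and by split_tensor_along_seq_dim later in this patch) pairs the i-th sequence chunk with the (2W-1-i)-th one, so each sequence-parallel rank holds one early and one late slice of the sequence and causal-attention work stays roughly balanced around the ring. A minimal single-process sketch of that partitioning (illustrative only, not part of the patch; zigzag_slices and world_size are stand-ins for the mpu helpers):

import torch

def zigzag_slices(tensor, world_size, seq_dim=1):
    # Split into 2 * world_size chunks and give rank r the pair
    # (chunk r, chunk -(r + 1)), mirroring mpu.zigzag_data above.
    chunks = torch.chunk(tensor, 2 * world_size, dim=seq_dim)
    return [
        torch.cat((chunks[r], chunks[-(r + 1)]), dim=seq_dim).contiguous()
        for r in range(world_size)
    ]

positions = torch.arange(8).view(1, 8)  # one sequence of 8 token positions
for rank, shard in enumerate(zigzag_slices(positions, world_size=2)):
    print(rank, shard.tolist())
# 0 [[0, 1, 6, 7]]   <- first and last chunk
# 1 [[2, 3, 4, 5]]   <- the two middle chunks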
@@ -74,9 +82,11 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() - if world_size < model_parallel_size: - raise ValueError("world size cannot be smaller than model parallel size") - ensure_divisibility(world_size, model_parallel_size) + if world_size < model_parallel_size * sequence_parallel_size: + raise ValueError( + "world size cannot be smaller than (model parallel size) * (sequence parallel size)" + ) + ensure_divisibility(world_size, model_parallel_size * sequence_parallel_size) rank = torch.distributed.get_rank() global _MPU_TOPOLOGY @@ -125,7 +135,35 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce else: _IO_PARALLEL_GROUP = get_data_parallel_group() - # Build the model parallel groups. + # Build the sequence parallel groups. + global _SEQUENCE_PARALLEL_GROUP + assert ( + _SEQUENCE_PARALLEL_GROUP is None + ), "sequence parallel group is already initialized" + if topology: + # short circuit case without sequence parallelism + if sequence_parallel_size == 1: + for group_rank in range(world_size): + group = torch.distributed.new_group(ranks=[group_rank]) + if rank == 0: + print(f"MPU SP:", [group_rank]) + if rank == group_rank: + _SEQUENCE_PARALLEL_GROUP = group + else: + for sp_group in topology.get_axis_comm_lists("seq"): + group = torch.distributed.new_group(ranks=sp_group) + if rank == 0: + print(f"MPU SP:", sp_group) + if rank in sp_group: + _SEQUENCE_PARALLEL_GROUP = group + else: + for i in range(world_size // sequence_parallel_size): + ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size) + group = torch.distributed.new_group(ranks) + if i == (rank // sequence_parallel_size): + _SEQUENCE_PARALLEL_GROUP = group + + # Build the model parallel groups, per sequence parallel group global _MODEL_PARALLEL_GROUP assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" if topology: @@ -138,21 +176,36 @@ def initialize_model_parallel(model_parallel_size, topology=None, fp32_allreduce print(f"MPU MP:", [group_rank]) if rank == group_rank: _MODEL_PARALLEL_GROUP = group - return - - for mp_group in topology.get_axis_comm_lists("model"): - group = torch.distributed.new_group(ranks=mp_group) - if rank == 0: - print(f"MPU MP:", mp_group) - if rank in mp_group: - _MODEL_PARALLEL_GROUP = group - + else: + for mp_group in topology.get_axis_comm_lists("model"): + group = torch.distributed.new_group(ranks=mp_group) + if rank == 0: + print(f"MPU MP:", mp_group) + if rank in mp_group: + _MODEL_PARALLEL_GROUP = group else: - for i in range(world_size // model_parallel_size): - ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) - group = torch.distributed.new_group(ranks) - if i == (rank // model_parallel_size): - _MODEL_PARALLEL_GROUP = group + if model_parallel_size == 1: + for i in range(world_size): + group = torch.distributed.new_group([i]) + if i == rank: + _MODEL_PARALLEL_GROUP = group + else: + for i in range( + world_size // (model_parallel_size * sequence_parallel_size) + ): + sp_mp_groups = [[] for _ in range(sequence_parallel_size)] + for j in range(model_parallel_size): + for k in range(sequence_parallel_size): + rank_ = ( + i * model_parallel_size * sequence_parallel_size + + j * sequence_parallel_size + + k + ) + sp_mp_groups[k].append(rank_) + for sp_mp_group in sp_mp_groups: + group = 
torch.distributed.new_group(ranks=sp_mp_group) + if i in sp_mp_group: + _MODEL_PARALLEL_GROUP = group global _FP32_ALLREDUCE assert _FP32_ALLREDUCE is None, "fp32_allreduce is already initialized" @@ -184,6 +237,14 @@ def get_io_parallel_group(): return _IO_PARALLEL_GROUP +def get_seq_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + assert ( + _SEQUENCE_PARALLEL_GROUP is not None + ), "sequence parallel group is not initialized" + return _SEQUENCE_PARALLEL_GROUP + + def set_model_parallel_world_size(world_size): """Set the model parallel size""" global _MPU_WORLD_SIZE @@ -216,7 +277,28 @@ def get_model_parallel_src_rank(): """Calculate the global rank corresponding to a local rank zero in the model parallel group.""" global_rank = torch.distributed.get_rank() - local_world_size = get_model_parallel_world_size() + local_world_size = get_model_parallel_world_size() * get_seq_parallel_world_size() + sp_rank = global_rank % get_seq_parallel_world_size() + return ( + global_rank // local_world_size + ) * local_world_size + sp_rank # src is per seq parallel group + + +def get_seq_parallel_world_size(): + """Return world size for the sequence parallel group.""" + return torch.distributed.get_world_size(group=get_seq_parallel_group()) + + +def get_seq_parallel_rank(): + """Return my rank for the sequence parallel group.""" + return torch.distributed.get_rank(group=get_seq_parallel_group()) + + +def get_seq_parallel_src_rank(): + """Calculate the global rank corresponding to a local rank zero + in the sequence parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_seq_parallel_world_size() return (global_rank // local_world_size) * local_world_size @@ -316,9 +398,27 @@ def destroy_model_parallel(): _MPU_TOPOLOGY = None global _FP32_ALLREDUCE _FP32_ALLREDUCE = None + global _SEQUENCE_PARALLEL_GROUP + _SEQUENCE_PARALLEL_GROUP = None def get_fp32_allreduce(): """Get the fp32 allreduce flag""" assert _FP32_ALLREDUCE is not None, "fp32_allreduce is not Initialized" return _FP32_ALLREDUCE + + +def get_sequence_data_parallel_group(): + return get_seq_parallel_group() + + +def get_sequence_data_parallel_world_size(): + return get_seq_parallel_world_size() + + +def get_sequence_data_parallel_rank(): + return get_seq_parallel_rank() + + +def get_sequence_data_parallel_src_rank(): + return get_seq_parallel_src_rank() diff --git a/megatron/mpu/mappings.py b/megatron/mpu/mappings.py index 535fe6255..6ce2cb42e 100644 --- a/megatron/mpu/mappings.py +++ b/megatron/mpu/mappings.py @@ -22,8 +22,11 @@ get_model_parallel_world_size, get_model_parallel_rank, get_fp32_allreduce, + get_seq_parallel_rank, + get_seq_parallel_world_size, + get_seq_parallel_group, ) -from .utils import split_tensor_along_last_dim +from .utils import split_tensor_along_last_dim, split_tensor_along_seq_dim def _reduce(input_): @@ -48,6 +51,30 @@ def _reduce(input_): return input_ +def _sum_seq(input_): + """All-reduce the the input tensor across seq parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if get_seq_parallel_world_size() == 1: + return input_ + + # Bf16 convert + dt = input_.dtype + if dt == torch.bfloat16 and get_fp32_allreduce(): + input_ = input_.float() + + # All-reduce. 
+ torch.distributed.all_reduce( + input_, group=get_seq_parallel_group(), op=torch.distributed.ReduceOp.SUM + ) + + # Bf16 convert + if dt == torch.bfloat16 and get_fp32_allreduce(): + input_ = input_.bfloat16() + + return input_ + + def _split(input_): """Split the tensor along its last dimension and keep the corresponding slice.""" @@ -98,6 +125,68 @@ def _gather(input_): return output +def _split_seq(input_): + """Split the tensor along its seq dimension and keep the + corresponding slice.""" + + world_size = get_seq_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Split along last dimension. + input_list = split_tensor_along_seq_dim(input_, world_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = get_seq_parallel_rank() + output = input_list[rank].contiguous() + + return output + + +def _gather_seq(input_): + """Gather tensors and concatinate along the second dimension.""" + + world_size = get_seq_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + + # Bf16 convert + dt = input_.dtype + if dt == torch.bfloat16 and get_fp32_allreduce(): + input_ = input_.float() + + # Size and dimension. + seq_dim = 1 + rank = get_seq_parallel_rank() + + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + torch.distributed.all_gather(tensor_list, input_, group=get_seq_parallel_group()) + # unzigzag by concatenating the first half with the second half in reverse order + tensor_list = [torch.chunk(tensor_list[i], 2, seq_dim) for i in range(world_size)] + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat( + ( + torch.cat([item[0] for item in tensor_list], dim=seq_dim), # first half... 
+ torch.cat( + [item[1] for item in tensor_list[::-1]], dim=seq_dim + ), # ...and second half + ), + dim=seq_dim, + ).contiguous() + + # # regular ring attention + # output = torch.cat(tensor_list, dim=seq_dim).contiguous() + + # Bf16 convert + if dt == torch.bfloat16 and get_fp32_allreduce(): + output = output.bfloat16() + + return output + + class _CopyToModelParallelRegion(torch.autograd.Function): """Pass the input to the model parallel region.""" @@ -146,6 +235,22 @@ def backward(ctx, grad_output): return _gather(grad_output) +class _ScatterToSeqParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_): + return _split_seq(input_) + + @staticmethod + def forward(ctx, input_): + return _split_seq(input_) + + @staticmethod + def backward(ctx, grad_output): + return _gather_seq(grad_output) + + class _GatherFromModelParallelRegion(torch.autograd.Function): """Gather the input from model parallel region and concatinate.""" @@ -162,6 +267,39 @@ def backward(ctx, grad_output): return _split(grad_output) +class _GatherFromSeqParallelRegion(torch.autograd.Function): + """Gather the input from model parallel region and concatinate.""" + + @staticmethod + def symbolic(graph, input_): + return _gather_seq(input_) + + @staticmethod + def forward(ctx, input_): + return _gather_seq(input_) + + @staticmethod + def backward(ctx, grad_output): + return _split_seq(grad_output) + + +class _SumFromSeqParallelRegion(torch.autograd.Function): + """All-reduce the input from the model parallel region.""" + + @staticmethod + def symbolic(graph, input_): + return _sum_seq(input_) + + @staticmethod + def forward(ctx, input_): + return _sum_seq(input_) + + @staticmethod + def backward(ctx, grad_output): + print("sum_grad: ", grad_output, flush=True) + return grad_output + + # ----------------- # Helper functions. # ----------------- @@ -181,3 +319,20 @@ def scatter_to_model_parallel_region(input_): def gather_from_model_parallel_region(input_): return _GatherFromModelParallelRegion.apply(input_) + + +def gather_from_seq_parallel_region(input_): + return _GatherFromSeqParallelRegion.apply(input_) + + +def sum_from_seq_parallel_region(input_): + return _SumFromSeqParallelRegion.apply(input_) + + +def max_from_seq_parallel_region(input_): + if get_seq_parallel_world_size() > 1: + torch.distributed.all_reduce( + input_, + op=torch.distributed.ReduceOp.MAX, + group=get_seq_parallel_group(), + ) diff --git a/megatron/mpu/utils.py b/megatron/mpu/utils.py index 13941dc29..e171f6f7a 100644 --- a/megatron/mpu/utils.py +++ b/megatron/mpu/utils.py @@ -53,6 +53,36 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks= return tensor_list +def split_tensor_along_seq_dim( + tensor, num_partitions, contiguous_split_chunks=False, zigzag=True +): + """Split a tensor along its seq dimension. + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + """ + # Get the size and dimension. + seq_dim = 1 + + if zigzag: + # Split. + tensor_list = torch.chunk(tensor, 2 * num_partitions, dim=seq_dim) + # zigzag + tensor_list = [ + torch.cat((tensor_list[i], tensor_list[-(i + 1)]), dim=seq_dim) + for i in range(num_partitions) + ] + # Note: torch.split does not create contiguous tensors by default. 
+ else: + tensor_list = torch.chunk(tensor, num_partitions, dim=seq_dim) + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + class VocabUtility: """Split the vocabulary into `world_size` chunks amd return the first and last index of the vocabulary belonging to the `rank` diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 9cad02c43..9c664fd70 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -753,8 +753,11 @@ def configure_distributed_args(self): if self.rank == 0: print( self.__class__.__name__ - + ".configure_distributed_args() using world size: {} and model-parallel size: {} ".format( - self.world_size, self.model_parallel_size + + ".configure_distributed_args() using world size: {}, pipe-parallel size: {}, sequence-parallel size: {}, and model-parallel size: {} ".format( + self.world_size, + self.pipe_parallel_size, + self.sequence_parallel_size, + self.model_parallel_size, ), flush=True, ) @@ -794,7 +797,9 @@ def calculate_batch_parameters( # either none of the three parameters are provided or just gradient_accumulation_step is provided else: - assert False, "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided" + assert ( + False + ), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided" return int(train_batch), int(micro_batch), int(grad_acc) @staticmethod @@ -855,10 +860,13 @@ def calculate_derived(self): pp_size = pp_size if pp_size >= 1 else 1 mp_size = self.model_parallel_size mp_size = mp_size if mp_size >= 1 else 1 + sp_size = self.sequence_parallel_size + sp_size = sp_size if sp_size >= 1 else 1 self.update_value("model_parallel_size", mp_size) + self.update_value("sequence_parallel_size", sp_size) - # pp_size and mp_size are only used here to compute dp world size and nowhere else. - dp_world_size = (global_num_gpus / pp_size) / mp_size + # pp_size, mp_size, and sp_size are only used here to compute dp world size and nowhere else. + dp_world_size = (global_num_gpus / pp_size) / (mp_size * sp_size) if not (dp_world_size % 1 == 0): error_message = ( self.__class__.__name__ @@ -1029,6 +1037,11 @@ def calculate_derived(self): # if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1) + # update 'is sequence parallel' flag + self.update_value( + "is_sequence_parallel", + self.sequence_parallel_size > 1 and self.num_experts == 1, + ) if self.moe_num_experts > 1: assert not ( self.is_pipe_parallel or self.pipe_parallel_size > 1 @@ -1043,6 +1056,13 @@ def calculate_derived(self): "attention_config", expand_attention_types(self.attention_config, self.num_layers), ) + self.update_value( + "requires_attention_mask", + not all([item in ["ring", "flash"] for item in self.attention_config]), + ) + assert all([item == "ring" for item in self.attention_config]) or ( + not self.is_sequence_parallel + ), "Sequence parallel requires ring attention!" 
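# Worked example (illustrative aside, not part of the patch): under 4D parallelism the
# derived data-parallel degree is global_num_gpus / pipe / (model * sequence), and it
# must come out integral. For a hypothetical 16-GPU run with pipe_parallel_size=2,
# model_parallel_size=2 and sequence_parallel_size=2:
#   dp_world_size = (16 / 2) / (2 * 2) = 2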
assert ( len(self.attention_config) == self.num_layers ), "Length of attention config list must equal num_layers" @@ -1088,7 +1108,9 @@ def calculate_derived(self): not self.sparsity_config ), "Sparse attention not compatible with GQA or MQA" assert all( - (attn_type == "flash") or (attn_type == "global") + (attn_type == "flash") + or (attn_type == "global") + or (attn_type == "ring") for attn_type in self.attention_config ), "GQA / MQA currently only compatible with Flash or standard global/sliding window Attention" assert ( @@ -1098,8 +1120,8 @@ def calculate_derived(self): if "flash" in self.attention_config: _flash_version = packaging.version.Version(version("flash-attn")) if self.sliding_window_width is not None: - assert ( - _flash_version >= packaging.version.Version("2.3.0") + assert _flash_version >= packaging.version.Version( + "2.3.0" ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention." if self.pos_emb == "alibi": if not _flash_version >= packaging.version.Version("2.4.0.post1"): diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index febefb3c2..e7ef7733c 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -38,6 +38,7 @@ "flash", "rwkv", "mamba", + "ring", ] @@ -67,6 +68,11 @@ class NeoXArgsParallelism(NeoXArgsTemplate): Size of the model parallelism. """ + sequence_parallel_size: int = 1 + """ + Size of the sequence parallelism. + """ + pipe_partition_method: str = "type:transformer|mlp" """ method used to distribute model layers across pipeline stages. Choose from "parameters", which balances the number @@ -85,6 +91,12 @@ class NeoXArgsParallelism(NeoXArgsTemplate): according to pipeline parallel size. """ + is_sequence_parallel: bool = False + """ + flag to determine whether sequence parallelism is on - shouldn't be set by user, is automatically determined + according to sequence parallel size. + """ + expert_interval: int = 2 """ Have one MoE layer every expert_interval layers @@ -217,7 +229,7 @@ class NeoXArgsModel(NeoXArgsTemplate): The first item in the list specifies the attention type(s), and should be a list of strings. The second item specifies the number of times to repeat those attention types in the full list. - attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"] + attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv", "ring"] So a 12 layer network with only global attention could be specified like: [[[`global`], 12]] @@ -229,6 +241,12 @@ class NeoXArgsModel(NeoXArgsTemplate): [[[`global`], n_layers]] """ + requires_attention_mask: bool = True + """ + If true, the model requires an attention mask to be passed in. + Automatically configured based on attention type. 
+ """ + sparsity_config: dict = None """ diff --git a/megatron/training.py b/megatron/training.py index 3265680c5..02d20ad33 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -21,6 +21,7 @@ """Pretrain utilities.""" from datetime import datetime from functools import partial +from megatron.monkeypatcher import replace_engine_get_global_norm import math import sys @@ -296,11 +297,20 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): eod_token=neox_args.tokenizer.eod, eod_mask_loss=neox_args.eod_mask_loss, sliding_window_width=neox_args.sliding_window_width, + requires_mask=neox_args.requires_attention_mask, ) # If `label` is present, any token < 0 (e.g., -100, the default for torch) skips the loss computation if "label" in data_b: loss_mask = (data_b["label"][:, 1:] >= 0).to(loss_mask.dtype) - return tokens, labels, loss_mask, attention_mask, position_ids + return ( + mpu.zigzag_data(tokens), + mpu.zigzag_data(labels), + mpu.zigzag_data(loss_mask), + mpu.zigzag_data(attention_mask, -2) + if neox_args.requires_attention_mask + else None, + mpu.zigzag_data(position_ids), + ) def get_batch(neox_args, data_iterator): @@ -361,6 +371,7 @@ def get_batch_sequential(forward_input, neox_args): data=forward_input[0], eod_token=neox_args.tokenizer.eod, eod_mask_loss=neox_args.eod_mask_loss, + requires_mask=neox_args.requires_attention_mask, ) return (forward_input[0], forward_input[1], attention_mask) @@ -970,7 +981,7 @@ def train( # to monitor if we've skipped many iterations in a row and trigger an early exit overflow_monitor = OverflowMonitor(optimizer) - + replace_engine_get_global_norm(model) if neox_args.profile: schedule = torch.profiler.schedule( wait=neox_args.profile_step_start, diff --git a/megatron/utils.py b/megatron/utils.py index 26b4439bd..275eb4725 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -81,22 +81,29 @@ def get_attn_mask(seq_length, device, sliding_window_width): def get_ltor_masks_and_position_ids( - data, - eod_token, - eod_mask_loss=False, - sliding_window_width=None, + data, eod_token, eod_mask_loss=False, sliding_window_width=None, requires_mask=True ): """Build masks and position id for left to right model.""" # Extract batch size and sequence length. batch_size, seq_length = data.size() - # Attention mask (lower triangular). - attention_mask = get_attn_mask( - seq_length=seq_length, - device=data.device, - sliding_window_width=sliding_window_width, - ) + if requires_mask: + # Attention mask (lower triangular). + attention_mask = get_attn_mask( + seq_length=seq_length, + device=data.device, + sliding_window_width=sliding_window_width, + ) + else: + # Need this to actually do long context, 128k**2 is v big. + # Give it a dummy value + # Surely there is a better way to do this... + attention_mask = get_attn_mask( + seq_length=64, + device=data.device, + sliding_window_width=sliding_window_width, + ) # Loss mask. loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) diff --git a/requirements/requirements-ringattention.txt b/requirements/requirements-ringattention.txt new file mode 100644 index 000000000..77fa631ba --- /dev/null +++ b/requirements/requirements-ringattention.txt @@ -0,0 +1 @@ +git+https://github.com/zhuzilin/ring-flash-attention.git@0.1