From a9bd1a0215a7130230b0e87446fb4f127959ce5f Mon Sep 17 00:00:00 2001 From: mcowan Date: Tue, 17 Sep 2024 09:02:21 -0700 Subject: [PATCH 1/6] init --- benchmarks/python/test_pytorch_transformer.py | 208 ++++++++++++++++++ 1 file changed, 208 insertions(+) create mode 100644 benchmarks/python/test_pytorch_transformer.py diff --git a/benchmarks/python/test_pytorch_transformer.py b/benchmarks/python/test_pytorch_transformer.py new file mode 100644 index 00000000000..d5ba0f05599 --- /dev/null +++ b/benchmarks/python/test_pytorch_transformer.py @@ -0,0 +1,208 @@ +import os +import math +from dataclasses import dataclass +from collections.abc import Sequence + +import torch +import torch.nn as nn +from torch.nn import functional as F +import logging + +from torch.distributed.tensor.parallel import ( + parallelize_module, + ColwiseParallel, + RowwiseParallel, +) +from torch.distributed._tensor.device_mesh import init_device_mesh + +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO +) + +# Usage: torchrun --nproc-per-node= transformer.py + +def get_logger(): + return logging.getLogger(__name__) + + +def rank_log(_rank, logger, msg): + """helper function to log only on global rank 0""" + if _rank == 0: + logger.info(f" {msg}") + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + # self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # Splitting key, query, and value projections + self.c_attn_key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.c_attn_query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.c_attn_value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") + # NOTE: The original Karpathy's script hides bias registration behind a flag + # but we don't do that here. We always register bias due to a now-fixed + # bug in thunder. + # TODO: Move the bias registration to be happening `if not self.flash` once the bug is fixed. + # if not self.flash: + # print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + # q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + q = self.c_attn_query(x) + k = self.c_attn_key(x) + v = self.c_attn_value(x) + + # TODO: It looks like view needs to take in the sharded size. + k = k.view(B, T, self.n_head, C // self.n_head // 2).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head // 2).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head // 2).transpose(1, 2) # (B, nh, T, hs) + + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C // 2) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + # NOTE: The original Karpathy's script doesn't set the approximate flag, + # probably by mistake. + self.gelu = nn.GELU(approximate="tanh") + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster + + +logger = get_logger() + +# DeviceMesh creation +_world_size = int(os.environ["WORLD_SIZE"]) +device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) +_rank = device_mesh.get_rank() + +print(f"Starting PyTorch TP example on rank {_rank}.") +assert ( + _world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {_world_size} gpus" + +rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") + +dtype = torch.bfloat16 +config = GPTConfig() +tp_model = Block(config).to(dtype).to("cuda") + +# Create an optimizer for the parallelized module. +lr = 0.25 +optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True) + +# Parallelization plan +tp_model = parallelize_module( + module=tp_model, + device_mesh=device_mesh, + parallelize_plan={ + "attn.c_attn_key": ColwiseParallel(), + "attn.c_attn_query": ColwiseParallel(), + "attn.c_attn_value": ColwiseParallel(), + "attn.c_proj": RowwiseParallel(), + "mlp.c_fc": ColwiseParallel(), + "mlp.c_proj": RowwiseParallel(), + } +) +# Perform a num of iterations of forward/backward +# and optimizations for the sharded module. +num_iters = 10 +rank_log(_rank, logger, "Tensor Parallel training starting...") + +for i in range(num_iters): + # For TP, input needs to be same across all TP ranks. + # Setting the random seed is to mimic the behavior of dataloader. + torch.manual_seed(i) + inp = torch.rand(64, 2048, config.n_embd, dtype=dtype, device="cuda") + output = tp_model(inp) + output.sum().backward() + optimizer.step() + rank_log(_rank, logger, f"Tensor Parallel iter {i} completed") + +rank_log(_rank, logger, "Tensor Parallel training completed!") \ No newline at end of file From bcbc71ad05e7721a07c7ef932848eabbe250c96d Mon Sep 17 00:00:00 2001 From: mcowan Date: Tue, 17 Sep 2024 16:40:40 -0700 Subject: [PATCH 2/6] fix hard coded devices --- benchmarks/python/test_pytorch_transformer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/benchmarks/python/test_pytorch_transformer.py b/benchmarks/python/test_pytorch_transformer.py index d5ba0f05599..6ecc488bef3 100644 --- a/benchmarks/python/test_pytorch_transformer.py +++ b/benchmarks/python/test_pytorch_transformer.py @@ -1,3 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +''' +Benchmarks Tensor parallel NanoGPT with Pytorch. + +Usage: torchrun --nproc-per-node= transformer.py +''' + import os import math from dataclasses import dataclass @@ -87,9 +97,9 @@ def forward(self, x): v = self.c_attn_value(x) # TODO: It looks like view needs to take in the sharded size. 
- k = k.view(B, T, self.n_head, C // self.n_head // 2).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head // 2).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head // 2).transpose(1, 2) # (B, nh, T, hs) + k = k.view(B, T, self.n_head, C // self.n_head // _world_size).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head // _world_size).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head // _world_size).transpose(1, 2) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) @@ -105,7 +115,7 @@ def forward(self, x): att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C // 2) # re-assemble all head outputs side by side + y = y.transpose(1, 2).contiguous().view(B, T, C // _world_size) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) From 7a76b5a4c08fa95b5c6928d8b9b8fd44f2449008 Mon Sep 17 00:00:00 2001 From: mcowan Date: Wed, 18 Sep 2024 13:47:41 -0700 Subject: [PATCH 3/6] tp tests --- benchmarks/python/nanogpt.py | 138 +++++++++ benchmarks/python/test_pytorch_transformer.py | 287 ++++++------------ 2 files changed, 225 insertions(+), 200 deletions(-) create mode 100644 benchmarks/python/nanogpt.py diff --git a/benchmarks/python/nanogpt.py b/benchmarks/python/nanogpt.py new file mode 100644 index 00000000000..91586feffc7 --- /dev/null +++ b/benchmarks/python/nanogpt.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# NanoGPT model definition taken from https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/nanogpt_model.py +# and modified for compatibility with PyTorch's Tensor Parallel API. + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn +from torch.nn import functional as F + + +class LayerNorm(nn.Module): + """LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False""" + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + + +class CausalSelfAttention(nn.Module): + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + self.n_devices = config.n_devices + # key, query, value projections for all heads, but in a batch + # self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # Splitting key, query, and value projections + self.c_attn_key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.c_attn_query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + self.c_attn_value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") + # NOTE: The original Karpathy's script hides bias registration behind a flag + # but we don't do that here. We always register bias due to a now-fixed + # bug in thunder. + # TODO: Move the bias registration to be happening `if not self.flash` once the bug is fixed. + # if not self.flash: + # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads separately and move head forward to be the batch dim + q = self.c_attn_query(x) + k = self.c_attn_key(x) + v = self.c_attn_value(x) + + # TODO: It looks like view needs to take in the sharded size. 
+ k = k.view(B, T, self.n_head, C // self.n_head // self.n_devices).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head // self.n_devices).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head // self.n_devices).transpose(1, 2) # (B, nh, T, hs) + + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True + ) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C // self.n_devices) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + # NOTE: The original Karpathy's script doesn't set the approximate flag, + # probably by mistake. + self.gelu = nn.GELU(approximate="tanh") + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + + +class Block(nn.Module): + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +@dataclass +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 768 + dropout: float = 0.1 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + n_devices: int = 1 \ No newline at end of file diff --git a/benchmarks/python/test_pytorch_transformer.py b/benchmarks/python/test_pytorch_transformer.py index 6ecc488bef3..ac69b43ad8b 100644 --- a/benchmarks/python/test_pytorch_transformer.py +++ b/benchmarks/python/test_pytorch_transformer.py @@ -2,22 +2,19 @@ # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -''' -Benchmarks Tensor parallel NanoGPT with Pytorch. +""" +Benchmarks Tensor parallel NanoGPT block using Pytorch TP API. 
-Usage: torchrun --nproc-per-node= transformer.py -''' +Usage: torchrun --nproc-per-node= test_pytorch_transformer.py +""" import os -import math -from dataclasses import dataclass -from collections.abc import Sequence +import time -import torch -import torch.nn as nn -from torch.nn import functional as F -import logging +from nanogpt import * +import torch +import torch.distributed as dist from torch.distributed.tensor.parallel import ( parallelize_module, ColwiseParallel, @@ -25,194 +22,84 @@ ) from torch.distributed._tensor.device_mesh import init_device_mesh -logging.basicConfig( - format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO -) - -# Usage: torchrun --nproc-per-node= transformer.py - -def get_logger(): - return logging.getLogger(__name__) - - -def rank_log(_rank, logger, msg): - """helper function to log only on global rank 0""" - if _rank == 0: - logger.info(f" {msg}") - - -class LayerNorm(nn.Module): - """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" - - def __init__(self, ndim, bias): - super().__init__() - self.weight = nn.Parameter(torch.ones(ndim)) - self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None - - def forward(self, input): - return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) - - -class CausalSelfAttention(nn.Module): - def __init__(self, config): - super().__init__() - assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - # self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) - # Splitting key, query, and value projections - self.c_attn_key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - self.c_attn_query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - self.c_attn_value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - # regularization - self.attn_dropout = nn.Dropout(config.dropout) - self.resid_dropout = nn.Dropout(config.dropout) - self.n_head = config.n_head - self.n_embd = config.n_embd - self.dropout = config.dropout - # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 - self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") - # NOTE: The original Karpathy's script hides bias registration behind a flag - # but we don't do that here. We always register bias due to a now-fixed - # bug in thunder. - # TODO: Move the bias registration to be happening `if not self.flash` once the bug is fixed. - # if not self.flash: - # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer( - "bias", - torch.tril(torch.ones(config.block_size, config.block_size)).view( - 1, 1, config.block_size, config.block_size - ), - ) - - def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - # q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - q = self.c_attn_query(x) - k = self.c_attn_key(x) - v = self.c_attn_value(x) - - # TODO: It looks like view needs to take in the sharded size. 
- k = k.view(B, T, self.n_head, C // self.n_head // _world_size).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head // _world_size).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head // _world_size).transpose(1, 2) # (B, nh, T, hs) - - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - if self.flash: - # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True - ) - else: - # manual implementation of attention - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C // _world_size) # re-assemble all head outputs side by side - - # output projection - y = self.resid_dropout(self.c_proj(y)) - return y - - -class MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - # NOTE: The original Karpathy's script doesn't set the approximate flag, - # probably by mistake. - self.gelu = nn.GELU(approximate="tanh") - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x): - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - x = self.dropout(x) - return x - - -class Block(nn.Module): - def __init__(self, config): - super().__init__() - self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) - self.attn = CausalSelfAttention(config) - self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) - self.mlp = MLP(config) - - def forward(self, x): - x = x + self.attn(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -@dataclass -class GPTConfig: - block_size: int = 1024 - vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: int = 12 - n_head: int = 12 - n_embd: int = 768 - dropout: float = 0.0 - bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - - -logger = get_logger() - -# DeviceMesh creation -_world_size = int(os.environ["WORLD_SIZE"]) -device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) -_rank = device_mesh.get_rank() - -print(f"Starting PyTorch TP example on rank {_rank}.") -assert ( - _world_size % 2 == 0 -), f"TP examples require even number of GPUs, but got {_world_size} gpus" - -rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") +# Usage: torchrun --nproc-per-node= transformer.py -dtype = torch.bfloat16 -config = GPTConfig() -tp_model = Block(config).to(dtype).to("cuda") - -# Create an optimizer for the parallelized module. -lr = 0.25 -optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True) - -# Parallelization plan -tp_model = parallelize_module( - module=tp_model, - device_mesh=device_mesh, - parallelize_plan={ - "attn.c_attn_key": ColwiseParallel(), - "attn.c_attn_query": ColwiseParallel(), - "attn.c_attn_value": ColwiseParallel(), - "attn.c_proj": RowwiseParallel(), - "mlp.c_fc": ColwiseParallel(), - "mlp.c_proj": RowwiseParallel(), - } -) -# Perform a num of iterations of forward/backward -# and optimizations for the sharded module. 
num_iters = 10 -rank_log(_rank, logger, "Tensor Parallel training starting...") - -for i in range(num_iters): - # For TP, input needs to be same across all TP ranks. - # Setting the random seed is to mimic the behavior of dataloader. - torch.manual_seed(i) - inp = torch.rand(64, 2048, config.n_embd, dtype=dtype, device="cuda") - output = tp_model(inp) - output.sum().backward() - optimizer.step() - rank_log(_rank, logger, f"Tensor Parallel iter {i} completed") - -rank_log(_rank, logger, "Tensor Parallel training completed!") \ No newline at end of file +batch_size = 64 +sequence_length = 2048 +dtype = torch.bfloat16 +world_size = int(os.environ["WORLD_SIZE"]) +device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,)) +rank = device_mesh.get_rank() + +assert ( + world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {world_size} gpus" + + +def benchmark_loop(model, input): + forward_time = 0.0 + backward_time = 0.0 + + for i in range(num_iters): + start = time.time() + output = model(input) + torch.cuda.synchronize() + end = time.time() + forward_time += end - start + + start = time.time() + output.sum().backward() + torch.cuda.synchronize() + end = time.time() + backward_time += end - start + + forward_time /= num_iters + backward_time /= num_iters + return forward_time, backward_time + + +def benchmark_model(): + if rank != 0: + return + config = GPTConfig() + model = Block(config).to(dtype).to("cuda") + + input = torch.rand( + batch_size, sequence_length, config.n_embd, dtype=dtype, device="cuda" + ) + forward_time, backward_time = benchmark_loop(model, input) + print(f"Average forward time {forward_time}s, backward time {backward_time}s") + + +def benchmark_tensor_parallel(): + config = GPTConfig() + config.n_devices = world_size + tp_model = Block(config).to(dtype).to("cuda") + + # Parallelization plan. Tensor parallel + tp_model = parallelize_module( + module=tp_model, + device_mesh=device_mesh, + parallelize_plan={ + "attn.c_attn_key": ColwiseParallel(), + "attn.c_attn_query": ColwiseParallel(), + "attn.c_attn_value": ColwiseParallel(), + "attn.c_proj": RowwiseParallel(), + "mlp.c_fc": ColwiseParallel(), + "mlp.c_proj": RowwiseParallel(), + }, + ) + input = torch.rand( + batch_size, sequence_length, config.n_embd, dtype=dtype, device="cuda" + ) + + forward_time, backward_time = benchmark_loop(tp_model, input) + print( + f"{rank}: Average tensor parallel forward time {forward_time}s, backward time {backward_time}s" + ) + + +benchmark_model() +benchmark_tensor_parallel() +dist.destroy_process_group() From f7a9c41a5d212fc4b5cb71ba53c59d3270314785 Mon Sep 17 00:00:00 2001 From: mcowan Date: Wed, 18 Sep 2024 16:24:02 -0700 Subject: [PATCH 4/6] fix bug, lint --- benchmarks/python/nanogpt.py | 44 ++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/benchmarks/python/nanogpt.py b/benchmarks/python/nanogpt.py index 91586feffc7..3f75f31c123 100644 --- a/benchmarks/python/nanogpt.py +++ b/benchmarks/python/nanogpt.py @@ -30,9 +30,8 @@ def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 self.n_devices = config.n_devices - # key, query, value projections for all heads, but in a batch - # self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) - # Splitting key, query, and value projections + # Splitting key, query, and value projections + # Note: These were performed batched in the original implementation. 
self.c_attn_key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.c_attn_query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) self.c_attn_value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) @@ -44,15 +43,11 @@ def __init__(self, config): self.n_head = config.n_head self.n_embd = config.n_embd self.dropout = config.dropout - # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 + # flash attention only supported in PyTorch >= 2.0 self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") # NOTE: The original Karpathy's script hides bias registration behind a flag # but we don't do that here. We always register bias due to a now-fixed # bug in thunder. - # TODO: Move the bias registration to be happening `if not self.flash` once the bug is fixed. - # if not self.flash: - # print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") - # causal mask to ensure that attention is only applied to the left in the input sequence self.register_buffer( "bias", torch.tril(torch.ones(config.block_size, config.block_size)).view( @@ -61,24 +56,37 @@ def __init__(self, config): ) def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + # batch size, sequence length, embedding dimensionality (n_embd) + (B, T, C) = x.size() # calculate query, key, values for all heads separately and move head forward to be the batch dim + # Note: The original script calculated this batched but we cannot use the Tensor API on a 3D weight q = self.c_attn_query(x) k = self.c_attn_key(x) v = self.c_attn_value(x) - # TODO: It looks like view needs to take in the sharded size. - k = k.view(B, T, self.n_head, C // self.n_head // self.n_devices).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head // self.n_devices).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head // self.n_devices).transpose(1, 2) # (B, nh, T, hs) - + # Note: It looks like view needs to take in the sharded size. + # Head dimension is parallelized + k = k.view(B, T, self.n_head // self.n_devices, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.n_head // self.n_devices, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.n_head // self.n_devices, C // self.n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels y = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True + q, + k, + v, + attn_mask=None, + dropout_p=self.dropout if self.training else 0, + is_causal=True, ) else: # manual implementation of attention @@ -87,7 +95,9 @@ def forward(self, x): att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C // self.n_devices) # re-assemble all head outputs side by side + y = ( + y.transpose(1, 2).contiguous().view(B, T, C // self.n_devices) + ) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -135,4 +145,4 @@ class GPTConfig: n_embd: int = 768 dropout: float = 0.1 bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster - n_devices: int = 1 \ No newline at end of file + n_devices: int = 1 From 1655e4a4d2c1a70cf6852cc516fcdcd86ad69fdf Mon Sep 17 00:00:00 2001 From: mcowan Date: Wed, 18 Sep 2024 16:47:51 -0700 Subject: [PATCH 5/6] lint --- benchmarks/python/test_pytorch_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/python/test_pytorch_transformer.py b/benchmarks/python/test_pytorch_transformer.py index ac69b43ad8b..29ba7e5f3f8 100644 --- a/benchmarks/python/test_pytorch_transformer.py +++ b/benchmarks/python/test_pytorch_transformer.py @@ -11,7 +11,7 @@ import os import time -from nanogpt import * +from nanogpt import Block, GPTConfig import torch import torch.distributed as dist From 44acd6d408084e022b882c0dbef0ec94bba3be39 Mon Sep 17 00:00:00 2001 From: mcowan Date: Wed, 18 Sep 2024 17:49:42 -0700 Subject: [PATCH 6/6] with torch compile --- benchmarks/python/test_pytorch_transformer.py | 23 +++++++++++++++---- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/benchmarks/python/test_pytorch_transformer.py b/benchmarks/python/test_pytorch_transformer.py index 29ba7e5f3f8..6c139b001a0 100644 --- a/benchmarks/python/test_pytorch_transformer.py +++ b/benchmarks/python/test_pytorch_transformer.py @@ -41,6 +41,10 @@ def benchmark_loop(model, input): forward_time = 0.0 backward_time = 0.0 + # warm-up + output = model(input) + output.sum().backward() + for i in range(num_iters): start = time.time() output = model(input) @@ -59,23 +63,30 @@ def benchmark_loop(model, input): return forward_time, backward_time -def benchmark_model(): +def benchmark_model(use_torch_compile=False): if rank != 0: return config = GPTConfig() model = Block(config).to(dtype).to("cuda") + if use_torch_compile: + model = torch.compile(model) input = torch.rand( batch_size, sequence_length, config.n_embd, dtype=dtype, device="cuda" ) forward_time, backward_time = benchmark_loop(model, input) - print(f"Average forward time {forward_time}s, backward time {backward_time}s") + print(f"torch.compile {not use_torch_compile}, Average forward time {forward_time}s, backward time {backward_time}s") -def benchmark_tensor_parallel(): +def benchmark_tensor_parallel(use_torch_compile=False): config = GPTConfig() - config.n_devices = world_size + # TODO: used the world size to scale the sizes in view operations. + # this was necessary in eager mode, but not for torch.compile + if not use_torch_compile: + config.n_devices = world_size tp_model = Block(config).to(dtype).to("cuda") + if use_torch_compile: + tp_model = torch.compile(tp_model) # Parallelization plan. Tensor parallel tp_model = parallelize_module( @@ -96,10 +107,12 @@ def benchmark_tensor_parallel(): forward_time, backward_time = benchmark_loop(tp_model, input) print( - f"{rank}: Average tensor parallel forward time {forward_time}s, backward time {backward_time}s" + f"{rank}: torch.compile {not use_torch_compile}, Average tensor parallel forward time {forward_time}s, backward time {backward_time}s" ) benchmark_model() +benchmark_model(True) benchmark_tensor_parallel() +benchmark_tensor_parallel(True) dist.destroy_process_group()
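
For reference, the head-sharding arithmetic that the later commits converge on (splitting n_head across devices rather than shrinking the per-head size) can be checked without a multi-GPU launch. The sketch below is not part of the patch series; it is a single-process illustration of the local shapes the Colwise/Rowwise plan produces for the default `GPTConfig` (n_embd=768, n_head=12) on an assumed two-device mesh, with plain tensors standing in for the DTensor shards and made-up B/T sizes.

# Standalone sketch (not part of the patches): local-shard shape arithmetic
# behind the tensor-parallel attention view, assuming n_embd=768, n_head=12,
# and a hypothetical 2-device mesh.
import torch

B, T = 2, 16                              # toy batch/sequence, just for shapes
n_embd, n_head, n_devices = 768, 12, 2
head_size = n_embd // n_head              # 64; unchanged by sharding

# ColwiseParallel on c_attn_{query,key,value} leaves each rank the output
# columns it owns, so the local activation's last dim is n_embd // n_devices.
q_local = torch.randn(B, T, n_embd // n_devices)                  # (2, 16, 384)

# The view from PATCH 4/6: shard the head count, keep head_size intact.
q_heads = q_local.view(B, T, n_head // n_devices, head_size).transpose(1, 2)
assert q_heads.shape == (B, n_head // n_devices, T, head_size)    # (2, 6, 16, 64)

# The re-assembled output stays at the sharded width; RowwiseParallel on
# c_proj consumes it and all-reduces the partial results back to full n_embd.
y_local = q_heads.transpose(1, 2).contiguous().view(B, T, n_embd // n_devices)
print(q_heads.shape, y_local.shape)

Sharding the head count matches how ColwiseParallel splits a projection's output features into contiguous blocks: each rank ends up holding whole heads (6 heads of width 64 here), so per-head attention over the local shard is exact, whereas dividing head_size as in PATCH 2/6 would mix columns from different heads.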
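When the script is launched under torchrun (e.g. torchrun --nproc-per-node=2 benchmarks/python/test_pytorch_transformer.py, assuming two GPUs), the sharding that the plan actually applied can be inspected from the parallelized module. The helper below is a hypothetical debugging aid, not part of the patches; it only relies on the DTensor machinery that parallelize_module already uses.

# Hypothetical helper (not in the patches): after parallelize_module, the
# projection parameters are DTensors whose placements record the plan —
# ColwiseParallel weights are Shard(0), RowwiseParallel weights are Shard(1).
from torch.distributed._tensor import DTensor


def print_shard_placements(module, rank):
    # Print from one rank only to keep the output readable.
    if rank != 0:
        return
    for name, param in module.named_parameters():
        if isinstance(param, DTensor):
            print(f"{name}: {param.placements}, local shape {tuple(param.to_local().shape)}")

A call such as print_shard_placements(tp_model, rank) placed right after the parallelize_module call would show, for example, attn.c_attn_query.weight sharded on dim 0 with a local shape of (384, 768) on a two-device mesh.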