PyTorch TP transformer block benchmarking #2958

Open · wants to merge 6 commits into base: main
148 changes: 148 additions & 0 deletions benchmarks/python/nanogpt.py
@@ -0,0 +1,148 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# NanoGPT model definition taken from https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/tests/nanogpt_model.py
# and modified for compatibility with PyTorch's Tensor Parallel API.

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    """LayerNorm with an optional bias. PyTorch's LayerNorm doesn't simply support bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.n_devices = config.n_devices
        # Split key, query, and value projections.
        # Note: these were performed as one batched projection in the original implementation.
        self.c_attn_key = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.c_attn_query = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.c_attn_value = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention is only supported in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
        # NOTE: Karpathy's original script hides this bias registration behind a flag,
        # but we always register it here due to a now-fixed bug in thunder.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.block_size, config.block_size)).view(
                1, 1, config.block_size, config.block_size
            ),
        )

    def forward(self, x):
        # batch size, sequence length, embedding dimensionality (n_embd)
        (B, T, C) = x.size()

        # Calculate query, key, and values for all heads separately and move the head
        # dim forward to be the batch dim.
        # Note: the original script computed this as one batched projection, but the
        # Tensor Parallel API cannot be used on the resulting 3D weight.
        q = self.c_attn_query(x)
        k = self.c_attn_key(x)
        v = self.c_attn_value(x)

        # Note: it looks like view needs to take in the sharded size.
        # The head dimension is parallelized across devices.
        k = k.view(B, T, self.n_head // self.n_devices, C // self.n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head // self.n_devices, C // self.n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head // self.n_devices, C // self.n_head).transpose(
            1, 2
        )  # (B, nh, T, hs)

        # causal self-attention; self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0,
                is_causal=True,
            )
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf"))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = (
            y.transpose(1, 2).contiguous().view(B, T, C // self.n_devices)
        )  # re-assemble all head outputs side by side (each rank holds C // n_devices channels)

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        # NOTE: Karpathy's original script doesn't set the approximate flag,
        # probably by mistake.
        self.gelu = nn.GELU(approximate="tanh")
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.1
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
    n_devices: int = 1
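
Not part of the diff: a minimal single-device sanity check of the `Block` above could look like the sketch below. It assumes it is run from `benchmarks/python/` so that `nanogpt` is importable, and it relies only on the default `n_devices = 1`, where the sharded `view` sizes reduce to the usual per-head split.

```python
# Hypothetical sanity check, not part of this PR: run the unsharded Block once on CPU.
import torch

from nanogpt import Block, GPTConfig  # assumes the current directory is benchmarks/python/

config = GPTConfig()          # n_devices defaults to 1, so no sharding is involved
model = Block(config).eval()  # eval() disables dropout for a deterministic check

x = torch.randn(2, 128, config.n_embd)  # (B, T, C) with T <= block_size
with torch.no_grad():
    y = model(x)

# The residual block preserves the (B, T, C) shape.
assert y.shape == (2, 128, config.n_embd)
```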
118 changes: 118 additions & 0 deletions benchmarks/python/test_pytorch_transformer.py
@@ -0,0 +1,118 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

"""
Benchmarks Tensor parallel NanoGPT block using Pytorch TP API.

Usage: torchrun --nproc-per-node=<number of processes> test_pytorch_transformer.py
"""

import os
import time

from nanogpt import Block, GPTConfig

import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)
from torch.distributed._tensor.device_mesh import init_device_mesh


num_iters = 10
batch_size = 64
sequence_length = 2048
dtype = torch.bfloat16
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
rank = device_mesh.get_rank()

assert (
    world_size % 2 == 0
), f"TP examples require an even number of GPUs, but got {world_size} GPUs"


def benchmark_loop(model, input):
    forward_time = 0.0
    backward_time = 0.0

    # warm-up
    output = model(input)
    output.sum().backward()
    # make sure the warm-up work has finished before timing starts
    torch.cuda.synchronize()

    for i in range(num_iters):
        start = time.time()
        output = model(input)
        torch.cuda.synchronize()
        end = time.time()
        forward_time += end - start

        start = time.time()
        output.sum().backward()
        torch.cuda.synchronize()
        end = time.time()
        backward_time += end - start

    forward_time /= num_iters
    backward_time /= num_iters
    return forward_time, backward_time


def benchmark_model(use_torch_compile=False):
    # The single-device baseline only runs on rank 0.
    if rank != 0:
        return
    config = GPTConfig()
    model = Block(config).to(dtype).to("cuda")
    if use_torch_compile:
        model = torch.compile(model)

    input = torch.rand(
        batch_size, sequence_length, config.n_embd, dtype=dtype, device="cuda"
    )
    forward_time, backward_time = benchmark_loop(model, input)
    print(
        f"torch.compile {use_torch_compile}, average forward time {forward_time}s, "
        f"backward time {backward_time}s"
    )


def benchmark_tensor_parallel(use_torch_compile=False):
    config = GPTConfig()
    # TODO: use the world size to scale the sizes in the view operations.
    # This was necessary in eager mode, but not for torch.compile.
    if not use_torch_compile:
        config.n_devices = world_size
    tp_model = Block(config).to(dtype).to("cuda")
    if use_torch_compile:
        tp_model = torch.compile(tp_model)

    # Tensor parallel parallelization plan: ColwiseParallel shards each Linear's
    # output features (weight dim 0); RowwiseParallel shards the input features (dim 1).
    tp_model = parallelize_module(
        module=tp_model,
        device_mesh=device_mesh,
        parallelize_plan={
            "attn.c_attn_key": ColwiseParallel(),
            "attn.c_attn_query": ColwiseParallel(),
            "attn.c_attn_value": ColwiseParallel(),
            "attn.c_proj": RowwiseParallel(),
            "mlp.c_fc": ColwiseParallel(),
            "mlp.c_proj": RowwiseParallel(),
        },
    )
    input = torch.rand(
        batch_size, sequence_length, config.n_embd, dtype=dtype, device="cuda"
    )

    forward_time, backward_time = benchmark_loop(tp_model, input)
    print(
        f"{rank}: torch.compile {use_torch_compile}, average tensor parallel forward "
        f"time {forward_time}s, backward time {backward_time}s"
    )


benchmark_model()
benchmark_model(True)
benchmark_tensor_parallel()
benchmark_tensor_parallel(True)
dist.destroy_process_group()
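
As a side note, not part of this PR: after `parallelize_module` runs, the sharded `nn.Linear` weights become DTensors, and their placements can be inspected to confirm the plan (ColwiseParallel shards the weight's output dimension, dim 0; RowwiseParallel shards the input dimension, dim 1). The helper name `show_sharding` below is made up for illustration, and the DTensor import path mirrors the private `torch.distributed._tensor` module already used above, so it depends on the PyTorch version.

```python
# Hypothetical helper (not in this PR): print how parallelize_module sharded the block.
# Assumes it runs under torchrun, after the parallelize_module call in
# benchmark_tensor_parallel; the import path matches this file's private-module usage.
from torch.distributed._tensor import DTensor


def show_sharding(tp_model):
    for name, param in tp_model.named_parameters():
        if isinstance(param, DTensor):
            # ColwiseParallel is expected to shard Linear weights on dim 0 (output
            # features); RowwiseParallel on dim 1 (input features).
            print(name, param.placements, tuple(param.to_local().shape))
```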
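
Also as a side note: `benchmark_loop` measures host wall-clock time with `time.time()` bracketed by `torch.cuda.synchronize()`. A hedged alternative that times the GPU work directly with CUDA events might look like the sketch below; the name `cuda_event_forward_time` is illustrative only and does not appear in the benchmark.

```python
# Hypothetical alternative to benchmark_loop's host-side timing (not in this PR):
# CUDA events record timestamps on the GPU stream, so they measure device time
# without relying on time.time() around a synchronize.
import torch


def cuda_event_forward_time(model, input, iters=10):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total_ms = 0.0

    # warm-up, then make sure the device is idle before timing
    model(input).sum().backward()
    torch.cuda.synchronize()

    for _ in range(iters):
        start.record()
        output = model(input)
        end.record()
        torch.cuda.synchronize()  # elapsed_time() requires both events to have completed
        total_ms += start.elapsed_time(end)  # milliseconds
        output.sum().backward()  # keep the backward pass in the loop, untimed here

    return total_ms / iters / 1e3  # seconds, matching benchmark_loop's units
```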