[not for land] TE experiments, take 2 #614

Open · wants to merge 1 commit into main
57 changes: 57 additions & 0 deletions parse_sweep.py
@@ -0,0 +1,57 @@
"""
Input: a subdirectory containing the logs from various experiments
Output: a csv file with loss, peak memory usage, and throughput from each experiment
"""

import csv
import os
import re

import fire

OUTPUT_FOLDER = '/home/vasiliy/local/tmp/torchtitan_outputs'

# example:
# [rank0]:[INFO | root ]: step: 10 loss: 7.8774 memory: 0.44GiB(0.47%) tps: 997,458 mfu: 1.50%
# note that number of spaces between terms can vary
regex = r"- step:[ ]+([\d]+).*loss:[ ]+([\d\.]+).*memory:[ ]+([\d\.]+)GiB.*tps: ([\d\,]+).*mfu.*"
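# note: the regex anchors on a literal "- step:", so the dash from the full
# logger prefix must be present in the line; for a matching line, the expected
# parse is e.g. (10, 7.8774, 0.44, 997458)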

def log_to_maybe_data(line):
res = re.search(regex, line)
if res is not None:
        step, loss, memory_gib, tps = res.group(1), res.group(2), res.group(3), res.group(4)
        return int(step), float(loss), float(memory_gib), int(tps.replace(',', ''))
else:
return None

def run(
subfolder_prefix: str,
results_filename: str,
):
subfolder_prefix = str(subfolder_prefix)

results = [['experiment', 'step', 'loss', 'memory_gib', 'tps']]

for entry in os.scandir(OUTPUT_FOLDER):
if entry.is_dir() and subfolder_prefix in entry.path:
print(entry)
log_fname = f"{entry.path}/logs.txt"
short_path = entry.path.replace(f"{OUTPUT_FOLDER}/", '')

with open(log_fname, 'r') as f:
lines = f.readlines()
for l in lines:
res = log_to_maybe_data(l)
if res is not None:
print(l.strip('\n'))
print(res)
results.append([short_path, *res])

with open(results_filename, 'w') as f:
writer = csv.writer(f)
writer.writerows(results)

print('done')

if __name__ == '__main__':
fire.Fire(run)
195 changes: 195 additions & 0 deletions test/test_te.py
@@ -0,0 +1,195 @@
import copy

import torch
import torch.nn as nn

# path hack, TODO remove
import sys
sys.path.insert(0, '/home/vasiliy/local/torchtitan/')
import torchtitan.te_utils as te_utils
from torchtitan.models.norms import build_norm
from torchtitan.models.llama.model import FeedForward, Attention, ModelArgs, precompute_freqs_cis

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

# torch.use_deterministic_algorithms(True)
torch.manual_seed(0)

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
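# Format.HYBRID uses e4m3 for the forward pass (activations/weights) and e5m2
# for gradients in the backward pass; DelayedScaling derives scales from a
# 16-entry amax history, reduced with max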
maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)

def test_linear_module_swap():
x = torch.randn(32, 32, device='cuda')

m = nn.Sequential(nn.Linear(32, 32)).cuda()
te_utils.swap_linear_to_te_linear(m)
print(m)
m = torch.compile(m)

with maybe_te_float8_ctx:
y = m(x)
y.sum().backward()

print('done')

# Subsection of TransformerBlock with only the ffn norm and the ffn
class NormFFNBlock(nn.Module):
def __init__(self, dim, hidden_dim, multiple_of):
super().__init__()
self.ffn_norm = build_norm("rmsnorm", dim, eps=1e-12)
self.feed_forward = FeedForward(dim, hidden_dim, multiple_of, None)

def forward(self, h):
out = h + self.feed_forward(self.ffn_norm(h))
return out

class NormAttnBlock(nn.Module):
def __init__(self, model_args):
super().__init__()
self.attention_norm = build_norm("rmsnorm", model_args.dim, eps=1e-12)
self.attention = Attention(model_args)
self.model_args = model_args
self.freqs_cis = precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# TODO: explain in docs/composability.md why we removed the 2x
# relaxing in our CP enablement PR
self.model_args.max_seq_len,
self.model_args.rope_theta,
).cuda()

def forward(self, x):
x = self.attention_norm(x)
x = self.attention(x, self.freqs_cis)
return x
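
# signal-to-quantization-noise ratio in dB: 20 * log10(||x|| / ||x - y||);
# higher means y more closely matches x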

def SQNR(x, y):
return 20 * torch.log10(
torch.linalg.norm(x) / torch.linalg.norm(x - y)
)

def test_norm_attn_rewrite():
    dim = 256
    model_args = ModelArgs(dim=dim)
m = NormAttnBlock(model_args).cuda().bfloat16()
m_copy = copy.deepcopy(m)
te_utils.swap_norm_attn_to_te_friendly_norm_attn(m_copy)
    print(m_copy)

x = torch.randn(1, 128, model_args.dim).cuda().bfloat16()
x_copy = copy.deepcopy(x)

y = m(x)

y_copy = m_copy(x_copy)

print(torch.allclose(y, y_copy))
print(SQNR(y, y_copy))

te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(m_copy)
    print(m_copy)
y_copy2 = m_copy(x_copy)
print(torch.allclose(y_copy, y_copy2))
print(SQNR(y_copy, y_copy2))



def test_norm_ln_ffn_rewrite():
dim = 256
hidden_dim = 512
multiple_of = 1

x = torch.randn(1, 128, 256).cuda().bfloat16()
x_copy = copy.deepcopy(x)

m = NormFFNBlock(dim, hidden_dim, multiple_of).cuda().bfloat16()
m_copy = copy.deepcopy(m)
print(m)

y = m(x)
y.sum().backward()

te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(m_copy)
print(m_copy)

y_copy = m_copy(x_copy)
y_copy.sum().backward()

# TODO: debug why not an exact match
print(torch.allclose(y, y_copy))
print(SQNR(y, y_copy))

# TODO test w13
# assert torch.allclose(m.ffn.w2.grad, m_copy.ffn.w2.grad, atol=0, rtol=0)

te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_linear(m_copy)
print(m_copy)

y_copy2 = m_copy(x_copy)
print(torch.allclose(y_copy, y_copy2))
print(SQNR(y_copy, y_copy2))

def test_norm_mlp_ffn_rewrite():
dim = 256
hidden_dim = 512
multiple_of = 1

x = torch.randn(1, 128, 256).cuda().bfloat16()
x_copy = copy.deepcopy(x)

m = NormFFNBlock(dim, hidden_dim, multiple_of).cuda().bfloat16()
m_copy = copy.deepcopy(m)
print(m)

y = m(x)
y.sum().backward()

te_utils.swap_norm_ffn_to_te_friendly_norm_ffn(m_copy)
print(m_copy)

y_copy = m_copy(x_copy)
y_copy.sum().backward()

# TODO: debug why not an exact match
print(torch.allclose(y, y_copy))
print(SQNR(y, y_copy))

# TODO test w13
# assert torch.allclose(m.ffn.w2.grad, m_copy.ffn.w2.grad, atol=0, rtol=0)

te_utils.swap_te_friendly_norm_ffn_to_te_layernorm_mlp(m_copy)
print(m_copy)

y_copy2 = m_copy(x_copy)
print(torch.allclose(y_copy, y_copy2))
print(SQNR(y_copy, y_copy2))

# works, so a bug in the swap above?
def test_split_linear():
M, K, N = 32, 64, 128
# M, K, N = 4, 6, 8
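    # sketch of the horizontal fusion used by training.horizontally_fuse_fcs:
    # two K->N linears become one K->2N linear by concatenating weights along
    # the output dim; the fused output is split back into the two halves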

x = torch.randn(M, K)

fc1 = nn.Linear(K, N, bias=False)
fc2 = nn.Linear(K, N, bias=False)

fc3 = nn.Linear(K, N * 2, bias=False)
fc3.weight = torch.nn.Parameter(
torch.cat([copy.deepcopy(fc1.weight), copy.deepcopy(fc2.weight)], dim=0)
)

y1 = fc1(x)
y2 = fc2(x)
y3 = fc3(x)
y3_1, y3_2 = torch.split(y3, fc3.out_features // 2, dim=-1)

assert torch.allclose(y1, y3_1)
assert torch.allclose(y2, y3_2)


if __name__ == '__main__':
    test_linear_module_swap()
    test_norm_attn_rewrite()
    test_norm_ln_ffn_rewrite()
    test_norm_mlp_ffn_rewrite()
    test_split_linear()
71 changes: 71 additions & 0 deletions torchtitan/config_manager.py
@@ -373,6 +373,77 @@ def __init__(self):
action="store_true",
help="Whether to compile the model",
)
self.parser.add_argument(
"--training.compile_ln_mlp",
action="store_true",
help="Whether to compile only the LNMLP blocks",
)
self.parser.add_argument(
"--training.compile_ln_linear",
action="store_true",
help="Whether to compile only the LNLinear blocks",
)
self.parser.add_argument(
"--training.compile_linear",
action="store_true",
help="Whether to compile only the LNLinear blocks",
)
self.parser.add_argument(
"--training.horizontally_fuse_fcs",
action="store_true",
help="""
If true, fuses ffn.fc1 and ffn.fc3 into ffn.fc13. Note that this is required
to use te.LayerNormLinear for FFNs.
TODO also implement this for attention.
""",
)
self.parser.add_argument(
"--training.te_swap_linear",
action="store_true",
help="""
If true, swaps torch.nn.Linear with te.Linear
(not for land)

Note:
* requires training.te_float8_autocast to use float8
""",
)
self.parser.add_argument(
"--training.te_swap_ln_linear",
action="store_true",
help="""
If true, swaps NormFeedForward.norm_w13 from
nn.Sequential(RMSNorm, nn.Linear) to te.LayerNormLinear
(not for land)

Note:
* requires training.horizontally_fuse_fcs to enable this swap
* this swap happens strictly before `training.te_swap_linear` if both are enabled
* requires training.te_float8_autocast to use float8
""",
)
self.parser.add_argument(
"--training.te_swap_ln_mlp",
action="store_true",
help="""
If true, swaps `NormFeedForward` to te.LayerNormMLP
(not for land)

Note:
* requires training.horizontally_fuse_fcs to enable this swap
* this swap happens strictly before `training.te_swap_linear` if both are enabled
* this swap happens strictly before `training.te_swap_ln_linear` if both are enabled
* requires training.te_float8_autocast to use float8
""",
)
self.parser.add_argument(
"--training.te_float8_autocast",
action="store_true",
help="""
If true, enables TE's float8 autocast context manager
(not for land)
""",
)
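        # Composition note (from the help strings above): when several TE swaps
        # are enabled they apply in order te_swap_ln_mlp -> te_swap_ln_linear ->
        # te_swap_linear; the ln_linear/ln_mlp swaps also require
        # training.horizontally_fuse_fcs, and running in float8 additionally
        # requires training.te_float8_autocast.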
self.parser.add_argument(
"--training.gc_freq",
type=int,
15 changes: 14 additions & 1 deletion torchtitan/float8.py
@@ -31,6 +31,7 @@ def _is_sm89_or_later():
class Float8Handler:
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False
self.job_config = job_config

float8_config = job_config.float8
if not float8_config.enable_float8_linear:
@@ -92,16 +93,28 @@ def convert_to_float8_training(self, model: nn.Module):

from torchao.float8 import convert_to_float8_training

        if self.job_config.training.compile_ln_linear:
            # only convert compiled regions to float8
            module_filter_fn = lambda mod, fqn: (fqn != "output" and "norm_" in fqn)
        elif self.job_config.training.compile_ln_mlp:
            # only convert compiled regions to float8
            module_filter_fn = lambda mod, fqn: (fqn != "output" and "feed_forward" in fqn)
        else:
            module_filter_fn = lambda mod, fqn: fqn != "output"

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
convert_to_float8_training(
model,
config=self.config,
module_filter_fn=lambda mod, fqn: fqn != "output",
module_filter_fn=module_filter_fn,
# module_filter_fn=lambda mod, fqn: fqn != "output",
# module_filter_fn=lambda mod, fqn: fqn != "output" and "norm_w13" in fqn,
)
logger.info(
"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="
f"{self.config.enable_fsdp_float8_all_gather}"
)
print(model)

def precompute_float8_dynamic_scale_for_fsdp(
self, model: Union[nn.Module, List[nn.Module]]
4 changes: 3 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -29,10 +29,12 @@
}

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
# "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
"debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16, rope_theta=500000),
"8B": ModelArgs(
dim=4096,
n_layers=32,
# n_layers=1,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
7 changes: 6 additions & 1 deletion torchtitan/models/llama/model.py
@@ -329,7 +329,12 @@ def init_weights(self):
for norm in (self.attention_norm, self.ffn_norm):
norm.reset_parameters()
self.attention.init_weights(self.weight_init_std)
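        # te.LayerNormMLP holds flat parameters (layer_norm_weight, fc1_weight,
        # fc2_weight) instead of submodules; match on the class name so that
        # transformer_engine does not need to be imported here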
if 'LayerNormMLP' in str(type(self.feed_forward)):
torch.nn.init.ones_(self.feed_forward.layer_norm_weight)
torch.nn.init.trunc_normal_(self.feed_forward.fc1_weight, mean=0.0, std=self.weight_init_std)
torch.nn.init.trunc_normal_(self.feed_forward.fc2_weight, mean=0.0, std=self.weight_init_std)
else:
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):