add: fsdp module
Secbone committed Apr 18, 2024
1 parent 92e94a3 commit 19b2bbc
Showing 4 changed files with 169 additions and 9 deletions.
Empty file added toad/nn/distributed/__init__.py
Empty file.
67 changes: 60 additions & 7 deletions toad/nn/distributed/fsdp.py
@@ -1,3 +1,5 @@
+import torch
+
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     CPUOffload,
@@ -11,16 +13,67 @@
 
 
 class FSDPModule(FSDP):
-    """distributed module class
+    """FSDP module class
     """
-    def fit(self, *args, **kwargs):
-        return self.module.fit(*args, **kwargs)
+    def __init__(self, module, policy = None, *args, **kwargs):
+        import functools
+        from torch.distributed.fsdp.wrap import (
+            size_based_auto_wrap_policy,
+            ModuleWrapPolicy,
+            enable_wrap,
+            wrap,
+        )
+        my_auto_wrap_policy = functools.partial(
+            size_based_auto_wrap_policy, min_num_params=10,
+        )
+
+        super().__init__(
+            module,
+            auto_wrap_policy = my_auto_wrap_policy,
+            # auto_wrap_policy = policy,
+            *args,
+            **kwargs
+        )
+
+
+    def fit(self, loader, trainer = None, optimizer = None, loss = None, early_stopping = None, **kwargs):
+        """train the model
+
+        Args:
+            loader (DataLoader): loader used for training
+            trainer (Trainer): trainer that drives the training loop
+            optimizer (torch.optim.Optimizer): defaults to `Adam(lr = 1e-3)`
+            loss (Callable): called as `loss(y_hat, y)`
+            early_stopping (earlystopping): defaults to `loss_stopping`,
+                set it to `False` to disable early stopping
+            epoch (int): number of epochs for the training loop
+            callback (callable): called after every epoch
+        """
+        if trainer is None:
+            from ..trainer import Trainer
+            trainer = Trainer(self, loader, optimizer = optimizer, loss = loss, early_stopping = early_stopping)
+            trainer.fit_step(self.module.__class__.fit_step)
+
+        trainer.train(**kwargs)
+
+    def save(self, path):
+        """save the sharded state dict
+        """
+        from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
+
+        with FSDP.state_dict_type(self, StateDictType.SHARDED_STATE_DICT):
+            torch.save(self.state_dict(), path)
 
-    def save(self, *args, **kwargs):
-        return self.module.save(*args, **kwargs)
+    def load(self, path, *args, **kwargs):
+        """load the sharded state dict
+        """
+        from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 
-    def load(self, *args, **kwargs):
-        return self.module.load(*args, **kwargs)
+        with FSDP.state_dict_type(self, StateDictType.SHARDED_STATE_DICT):
+            self.load_state_dict(torch.load(path))
+
+        return self
 
 
     def log(self, *args, **kwargs):
         return self.module.log(*args, **kwargs)
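Two things worth noting about the new class: the `policy` argument of `__init__` is accepted but not forwarded yet (the `auto_wrap_policy = policy` line is commented out), so a size-based policy with `min_num_params=10` is always applied, which turns nearly every submodule into its own FSDP unit. And because `SHARDED_STATE_DICT` serializes only each rank's local shard, `save`/`load` need a rank-specific path and generally assume the same world size on reload. A minimal sketch of that path convention (the `checkpoint_path` helper and base path are illustrative, not part of this commit):

import torch.distributed as dist

def checkpoint_path(base = "data/fsdp_model"):
    # one shard file per rank, e.g. data/fsdp_model_0.pkl, data/fsdp_model_1.pkl
    # (assumes the process group has already been initialized)
    return f"{base}_{dist.get_rank()}.pkl"

# each rank saves and reloads only its own shard:
# fsdp_model.save(checkpoint_path())
# fsdp_model.load(checkpoint_path())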
106 changes: 106 additions & 0 deletions toad/nn/distributed/fsdp_test.py
@@ -0,0 +1,106 @@
from ..module import Module
from .fsdp import FSDPModule

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from toad.utils.progress import Progress



class TestModel(Module):
def __init__(self, in_feats, out_feats):
super().__init__()

self.linear = nn.Linear(in_feats, out_feats)

def forward(self, x):
x = self.linear(x)
return F.relu(x)

def fit_step(self, batch):
x, y = batch
y_hat = self(x)
# return F.cross_entropy(y_hat, y)
return F.mse_loss(y_hat, y)

def worker(rank, world):
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
ModuleWrapPolicy,
)

torch.manual_seed(0)

NUM_FEATS = 4096
NUM_CLASSES = 1024
DATASET_SIZE = 10000


X = torch.rand(DATASET_SIZE, NUM_FEATS, dtype = torch.float)
# y = torch.randint(NUM_CLASSES, size = (DATASET_SIZE,), dtype = torch.long)

NUM_CLASSES = 1
    y = torch.sum(X, dim = 1, keepdim = True)  # keep shape (N, 1) to match the model output in mse_loss

loader = DataLoader(
TensorDataset(X, y),
batch_size = 128,
shuffle = True,
)

model = TestModel(NUM_FEATS, NUM_CLASSES)
# print(next(model.linear.parameters()).shape)

model.distributed(backend = "gloo", rank = rank, world_size = world)

fdsp_model = FSDPModule(
model,
# sync_module_states = True,
# auto_wrap_policy = my_auto_wrap_policy,
# policy = ModuleWrapPolicy([nn.Linear,]),
device_id=torch.device("cpu"),
)

optimizer = optim.Adam(fdsp_model.parameters(), lr = 1e-3)

state_path = f"data/fsdp_model_{rank}.pkl"

    import os
    if os.path.exists(state_path):
        # load shards from a previous run, if any
        fdsp_model.load(state_path)

print('before fit:', fdsp_model(X[0]).sum())

# inputs = torch.rand(10, features_dim)
fdsp_model.fit(loader, epoch = 20, early_stopping = False)

print('after fit:', fdsp_model(X[0]).sum())

print(fdsp_model)
# print(fdsp_model.flatten_sharded_optim_state_dict())

# out = fdsp_model(inputs).sum()

# out.backward()

# print("~~~~~", out)

print(model)
model.save(f"data/origin_model_{rank}.pkl")

fdsp_model.save(state_path)




def test_fsdp_model():
import torch.multiprocessing as mp

import os
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

mp.spawn(worker, args=(2,), nprocs=2, join=True)
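The spawn call relies on the standard torch.multiprocessing contract: `mp.spawn(fn, args=(world,), nprocs=world)` invokes `fn(rank, world)` once per process, with `rank` supplied automatically as the first argument. A stripped-down sketch of just that contract (independent of toad):

import torch.multiprocessing as mp

def _worker(rank, world):
    # each spawned process receives its rank as the first positional argument
    print(f"worker {rank} of {world} started")

if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)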

5 changes: 3 additions & 2 deletions toad/nn/trainer/trainer.py
@@ -52,8 +52,8 @@ def __init__(self, model, loader = None, optimizer = None, loss = None, keep_his
             from .earlystop import loss_stopping
             early_stopping = loss_stopping()
 
-        # self.early_stop = early_stopping
-        self.register("earlystop:check", early_stopping)
+        if early_stopping is not False:
+            self.register("earlystop:check", early_stopping)
 
         from collections import deque
         self.history = deque(maxlen = keep_history)
@@ -283,6 +283,7 @@ def train_loop(trainer, model, loader, epoch = 10, start = 0, backward_rounds =
 
             # log loss
             history.log('loss', l)
+
             backward_loss = l + backward_loss
             if i % backward_rounds == 0 or i == len(p):
                 trainer.optimizer.zero_grad()
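The net effect of the first hunk: passing `early_stopping = False` skips registering the `"earlystop:check"` callback, so `train()` runs for the full epoch budget; this is what `fsdp_test.py` relies on when it calls `fit(loader, epoch = 20, early_stopping = False)`. A hedged sketch of the two call patterns, assuming the Trainer API shown above:

from toad.nn.trainer import Trainer

# default: a loss-based early-stopping callback is registered
# trainer = Trainer(model, loader)

# disabled: no "earlystop:check" callback is registered, so train() runs every epoch
# trainer = Trainer(model, loader, early_stopping = False)
# trainer.train(epoch = 20)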
