checkpointing/PGTransport: add NCCL/Gloo transport for checkpoints #110

Merged
merged 1 commit into main from d4l3k/pg_transport on Feb 14, 2025

Conversation

d4l3k (Member) commented Feb 13, 2025

This adds a ProcessGroup-based CheckpointTransport that allows transferring a state_dict via the backend network.

This supports NCCL on CUDA devices and Gloo on CPU devices. DTensor is supported, but other tensor subclasses will likely error.

This is cleaned-up code from #104.

Additional changes:

  • ProcessGroupBaby: pass the timeout parameter to the subprocess so we can catch NCCL timeouts
  • refactored http_transport_test

Core algorithm:

Sender

  1. preprocess the state_dict using pytree to separate the tensors from the other Python objects
  2. serialize the tensor metadata and the non-tensor objects using pickle
  3. send the metadata size via send/recv
  4. send the pickled buffer via send/recv
  5. send each tensor via send/recv

These sends go out to each of the receiving peers simultaneously (a minimal sketch of the sender side follows).
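Here is that sketch, assuming a `torch.distributed.ProcessGroup` `pg` and illustrative helper and tag choices; the real implementation lives in `torchft/checkpointing/pg_transport.py` and differs in details such as device staging and the storage-based `_cast_tensor` trick:

```python
import pickle
from typing import Any, Dict, List

import torch
from torch.utils._pytree import tree_flatten


def send_state_dict(pg, state_dict: Dict[str, Any], dst_ranks: List[int]) -> None:
    # 1. separate tensors from the other Python objects
    leaves, spec = tree_flatten(state_dict)
    tensors: List[torch.Tensor] = []
    meta_leaves: List[Any] = []
    for leaf in leaves:
        if isinstance(leaf, torch.Tensor):
            # replace the tensor with a placeholder carrying enough metadata
            # for the receiver to allocate and reinterpret the bytes
            meta_leaves.append(("__tensor__", len(tensors), leaf.dtype, tuple(leaf.shape)))
            tensors.append(leaf)
        else:
            meta_leaves.append(leaf)

    # 2. pickle the pytree spec, tensor metadata, and non-tensor objects
    buf = pickle.dumps((spec, meta_leaves))
    buf_t = torch.frombuffer(bytearray(buf), dtype=torch.uint8)
    len_t = torch.tensor([buf_t.numel()], dtype=torch.int64)

    for dst_rank in dst_ranks:
        # 3. send the metadata size, then 4. the pickled buffer
        pg.send([len_t], dst_rank, tag=1).wait()
        pg.send([buf_t], dst_rank, tag=2).wait()

    # 5. send each tensor as raw uint8 bytes so any dtype works over NCCL.
    # (Simplified: the real code casts via the untyped storage instead of
    # .view, and stages tensors on the appropriate device for the backend.)
    work = []
    for i, t in enumerate(tensors):
        byte_view = t.contiguous().view(torch.uint8)
        for dst_rank in dst_ranks:
            work.append(pg.send([byte_view], dst_rank, tag=3 + i))
    for w in work:
        w.wait()
```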

Receiver

Receiving is largely the same. Notably, when receiving each tensor, a buffer is allocated on the device, received into, and then transferred to CPU to prevent CUDA OOMs. Allocating and transferring is significantly slower than doing an in-place receive and is something we should fix in a follow-up PR.
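A matching hedged sketch of that receive path; names and tags mirror the sender sketch above and are assumptions, not the merged code:

```python
import math
import pickle
from typing import Any, Dict

import torch
from torch.utils._pytree import tree_unflatten


def recv_state_dict(pg, src_rank: int, device: torch.device) -> Dict[str, Any]:
    # receive the metadata size, then the pickled buffer
    len_t = torch.zeros(1, dtype=torch.int64, device=device)
    pg.recv([len_t], src_rank, tag=1).wait()
    buf_t = torch.empty(int(len_t.item()), dtype=torch.uint8, device=device)
    pg.recv([buf_t], src_rank, tag=2).wait()
    spec, meta_leaves = pickle.loads(buf_t.cpu().numpy().tobytes())

    # receive each tensor: allocate a staging buffer on the transfer device,
    # receive into it, then copy to CPU so the whole checkpoint never has to
    # fit in device memory at once (the extra copy noted above)
    leaves = []
    i = 0
    for leaf in meta_leaves:
        if isinstance(leaf, tuple) and len(leaf) == 4 and leaf[0] == "__tensor__":
            _, _, dtype, shape = leaf
            nbytes = math.prod(shape) * torch.empty((), dtype=dtype).element_size()
            raw = torch.empty(nbytes, dtype=torch.uint8, device=device)
            pg.recv([raw], src_rank, tag=3 + i).wait()
            leaves.append(raw.view(dtype).view(shape).to("cpu"))
            i += 1
        else:
            leaves.append(leaf)
    return tree_unflatten(leaves, spec)
```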

Test plan:

Added a new shared multi-rank recovery test and enabled it for PGTransport with NCCL+Gloo, as well as for the existing HTTPTransport.

pytest torchft/checkpointing/

facebook-github-bot added the CLA Signed label on Feb 13, 2025
allenwang28 (Contributor) left a comment:

this is really cool work!

torchft/checkpointing/transport_test.py

def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
    return (
        _cast_tensor(tensor, torch.uint8),
Contributor:

IIUC, we're casting to uint8 to reduce memory pressure / speed up the transfer, but should we be concerned about any precision loss?

I see that transport_test.py verifies closeness/correctness through run_multi_recovery_test, but it isn't making sense to me!

daulet-askarov commented Feb 13, 2025:

This just reinterprets the tensor as a bunch of bytes (hence the uint8) backed by the same UntypedStorage range of bytes. No bytes are modified, so there's no loss of precision:
https://pytorch.org/docs/stable/storage.html#untyped-storage-api
@d4l3k I presume that if you don't do this cast and just pass the original tensor object with its original dtype, you do lose precision? Or is it just inconvenient on the recv side to interpret the tensor with its original dtype right away?
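To make the byte reinterpretation concrete, a small standalone demo (not the PR's `_cast_tensor`, just an assumed-equivalent round trip over the same UntypedStorage bytes):

```python
import torch

t = torch.randn(4, 4, dtype=torch.float32)

# Reinterpret the tensor's backing storage as raw bytes (no copy of values,
# no rounding): a uint8 tensor over the same UntypedStorage.
as_bytes = torch.empty(0, dtype=torch.uint8)
as_bytes.set_(t.untyped_storage(), storage_offset=0, size=(t.untyped_storage().nbytes(),))

# Simulate the bytes arriving on the other side, then view them as float32
# again with the original shape/stride.
received = as_bytes.clone()
roundtrip = torch.empty(0, dtype=torch.float32)
roundtrip.set_(received.untyped_storage(), storage_offset=0, size=t.shape, stride=t.stride())

assert torch.equal(roundtrip, t)  # bit-identical round trip, no precision loss
```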

d4l3k (Member, Author) commented Feb 14, 2025:

In most cases it doesn't matter and should result in byte-identical output.

There are a few trade-offs:

  • Pro: not all dtypes are supported by NCCL (torch.uint16, for instance, can't be sent via NCCL), so casting to uint8 allows us to support any dtype
  • Con: arguably it's better to use the non-storage option to avoid sending duplicate/extra bytes for strided/offset tensors. If you have two tensors sharing the same underlying storage, or a tensor that's strided, this implementation ends up sending twice as much data
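A quick illustration of that "Con" (illustrative sizes, not code from the PR): for a strided view, the storage-based cast ships the whole backing storage rather than just the bytes the view logically covers:

```python
import torch

base = torch.randn(1000, 1000)   # ~4 MB of float32 storage
col = base[:, 0]                 # strided view: only 1000 elements of logical data

print(col.numel() * col.element_size())   # 4000 bytes actually needed
print(col.untyped_storage().nbytes())     # 4000000 bytes if we ship the whole storage
```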

Member:

can we just do tensor.view(torch.uint8) instead?

d4l3k (Member, Author):

.view doesn't work for strided tensors -- I'm not sure we need to support those but I think I'll leave it as is for now
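For context on why `.view` falls over, a toy repro independent of the PR code: `Tensor.view(dtype)` requires the last dimension to be contiguous, so it raises on a transposed tensor, while the storage-based cast still works:

```python
import torch

t = torch.randn(3, 4).t()  # transposed, non-contiguous

try:
    t.view(torch.uint8)
except RuntimeError as e:
    print("view failed:", e)  # last dim must have stride 1 to view as a smaller dtype

# The storage-based cast doesn't care about strides:
raw = torch.empty(0, dtype=torch.uint8)
raw.set_(t.untyped_storage(), storage_offset=0, size=(t.untyped_storage().nbytes(),))
print(raw.shape)  # torch.Size([48]): all bytes of the underlying 3x4 float32 storage
```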

def metadata(self) -> str:
    return "<n/a>"

def disallow_checkpoint(self) -> None:
Contributor:

should this be implemented?

d4l3k (Member, Author):

We don't support any async/out-of-band calls via PG, so nothing needs to be done here.

For HTTP we need this to avoid serving a checkpoint during the optimizer step, but since send is synchronous in PG we don't need any additional synchronization.

work.append(self._pg.send([t], dst_rank, tag=3 + i))

# allow 3 concurrent transfers at a time
while len(work) > (3 * len(dst_ranks)):
Contributor:

Is this used to avoid OOM?

d4l3k (Member, Author):

Yes, it's to avoid OOM when transferring between devices.
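That is, the send loop keeps a bounded window of in-flight Work handles and drains the oldest before queuing more. A sketch of the pattern, reusing `pg`, `tensors`, and `dst_ranks` from the earlier sketch (the constant name is illustrative; the quoted code hardcodes 3):

```python
MAX_INFLIGHT = 3  # illustrative; the PR hardcodes 3

work = []
for i, t in enumerate(tensors):
    for dst_rank in dst_ranks:
        work.append(pg.send([t], dst_rank, tag=3 + i))

    # Cap the number of outstanding transfers so we don't queue every tensor
    # (and its staging buffer) at once and run the device out of memory.
    while len(work) > MAX_INFLIGHT * len(dst_ranks):
        work.pop(0).wait()

# drain whatever is still in flight
for w in work:
    w.wait()
```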

i += 1

# TODO: allow in place receives to avoid having to copy to cpu to
# avoid OOMs
Contributor:

We should be able to avoid transferring TensorMeta and DTensorMeta, and avoid the to(cpu) copy, if we first call state_dict() to get the state_dict structure and then traverse the state_dict and send/recv the tensors directly.

d4l3k (Member, Author) commented Feb 14, 2025:

Yeah, we totally can do that in most cases -- I wanted to make that refactor in a follow-up PR since I still need to figure out how to do it, and this may be a decent fallback if we have some weird state dict.
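For readers curious what that follow-up might look like, a rough sketch under the assumption that both ranks can produce a structurally identical, flat state_dict (so only raw tensor data needs to move; it also ignores the NCCL dtype caveat discussed earlier):

```python
import torch


def send_in_place(pg, state_dict, dst_rank: int) -> None:
    # Both sides already know the structure, so no metadata exchange is needed.
    for i, (_, value) in enumerate(sorted(state_dict.items())):
        if isinstance(value, torch.Tensor):
            pg.send([value], dst_rank, tag=i).wait()


def recv_in_place(pg, state_dict, src_rank: int) -> None:
    # Receive directly into the existing tensors: no staging buffer and no
    # to("cpu") copy, since the destination tensors already exist.
    for i, (_, value) in enumerate(sorted(state_dict.items())):
        if isinstance(value, torch.Tensor):
            pg.recv([value], src_rank, tag=i).wait()
```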

H-Huang (Member) left a comment:

Overall looks pretty good! I had some comments and questions, mostly around readability and using existing infra, but otherwise it looks solid enough to land.

torchft/checkpointing/pg_transport.py


dtype: torch.dtype
storage_offset: int
stride: Tuple[int, ...]
nbytes: int
H-Huang (Member):

Not important now, but I wonder if we need to store quantization information, and I'm also wondering how that's handled in DTensor, if you know.

d4l3k (Member, Author):

@H-Huang do you know how quantized information is stored? Is it a different tensor subclass or just packed into the storage?


work = []

with _timeit("send pickle"):
H-Huang (Member) commented Feb 14, 2025:

So we have send_object_list and recv_object_list (https://github.com/pytorch/pytorch/blob/8b5ee275fb455156a944445fb92c43731369ace3/torch/distributed/distributed_c10d.py#L3181), which is what we use in PP to exchange shape metadata between stages to preallocate the recv buffers.

It is pretty similar, since it pickles and then sends the object sizes, then the object data. I think yours may be more efficient since there are only 2 additional sends of metadata and the rest are the actual data. But I wanted to flag it in case we want to consolidate some logic!
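For reference, those helpers are used roughly like this (assuming the default process group is already initialized; the metadata payload here is made up):

```python
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
if dist.get_rank() == 0:
    meta = {"step": 100, "shapes": [(1024, 1024)], "dtypes": ["float32"]}
    # pickles each object, sends the sizes, then sends the pickled data
    dist.send_object_list([meta], dst=1)
else:
    received = [None]
    dist.recv_object_list(received, src=0)
    meta = received[0]
```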

d4l3k (Member, Author) commented Feb 14, 2025:

Thanks! Yeah, I missed that before -- good to know. Though for right now I think this implementation is a bit more performant, since it sends the same data to multiple receivers.

If we wrap the underlying PG we should also be able to use the broadcast_object_list variant, which should give us the best of both worlds.

I'm planning a follow-up PR, since to make a subworld we need some underlying improvements in how we calculate the recovering workers.



@dataclass
class _StateDictMeta:
H-Huang (Member) commented Feb 14, 2025:

I was thinking _StateDictMeta and the other dataclasses could probably use some more comments, since they're pretty important in determining how we serialize/deserialize and in being able to update them. I'm also curious how DCP handles this metadata when it transfers, and whether there's an existing structure we can use? @fegin @LucasLLC do you know?

d4l3k (Member, Author):

Added comments and updated the field names to make it clearer.
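For readers following along without the diff, the per-tensor metadata looks roughly like the sketch below, reconstructed from the fields visible in the diff context above; the docstring, comments, and the `shape` field are illustrative assumptions rather than the merged code:

```python
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass
class _TensorMeta:
    """Everything needed to rebuild one tensor from the raw uint8 byte stream."""

    shape: torch.Size        # logical shape of the tensor
    dtype: torch.dtype       # original dtype (the wire format is always uint8)
    storage_offset: int      # offset into the underlying storage, in elements
    stride: Tuple[int, ...]  # strides, in elements, to reconstruct views
    nbytes: int              # number of bytes of storage to receive
```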

d4l3k force-pushed the d4l3k/pg_transport branch from a8ee556 to 05202f3 on February 14, 2025 21:16
d4l3k merged commit 8628a3f into main on Feb 14, 2025
6 checks passed
d4l3k deleted the d4l3k/pg_transport branch on February 14, 2025 21:58