checkpointing/PGTransport: add NCCL/Gloo transport for checkpoints #110
Conversation
this is really cool work!
def _prepare_tensor(tensor: torch.Tensor) -> Tuple[torch.Tensor, _TensorMeta]:
    return (
        _cast_tensor(tensor, torch.uint8),
IIUC, we're casting to uint8 to reduce memory pressure / speed up the transfer, but should we be concerned about any precision loss? I see that transport_test.py verifies closeness/correctness through run_multi_recovery_test, but it isn't making sense to me!
This just reinterprets the tensor as a bunch of bytes (hence the uint8) backed by the same UntypedStorage range of bytes. No bytes modified, so no loss of precision:
https://pytorch.org/docs/stable/storage.html#untyped-storage-api
@d4l3k I presume if you don't do this cast and just pass the original tensor object with its original dtype, then you do lose precision? Or is it just inconvenient on the recv side to interpret the tensor with its original dtype right away?
In most cases it doesn't matter and should result in byte-identical output.
There are a few trade-offs:
- Pro: not all tensor dtypes are supported by NCCL. torch.uint16, for instance, can't be sent via NCCL, so casting to uint8 lets us support any dtype.
- Con: arguably it's better to use the non-storage option to avoid sending duplicate/extra bytes for strided/offset tensors. If two tensors share the same underlying storage, or a tensor is strided, this implementation ends up sending twice as much data.
can we just do tensor.view(torch.uint8) instead?
.view doesn't work for strided tensors -- I'm not sure we need to support those but I think I'll leave it as is for now
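To make the trade-off concrete, here is a minimal sketch (the `base`/`strided`/`byte_view` names are illustrative, not from the PR) of why `.view(torch.uint8)` rejects strided tensors while a storage-level byte view always works, at the cost of covering the whole storage:

```python
import torch

# A non-contiguous view: every other column of a 4x4 float32 matrix.
base = torch.arange(16, dtype=torch.float32).reshape(4, 4)
strided = base[:, ::2]

# view(dtype) to a smaller element size requires the last dimension to be
# contiguous, so this raises for the strided view.
try:
    strided.view(torch.uint8)
except RuntimeError as e:
    print("view failed:", e)

# A byte view over the UntypedStorage always works, but it covers the
# *entire* storage (all 16 floats, 64 bytes), not just the 8 selected
# elements -- the duplicate/extra-bytes cost mentioned above.
byte_view = torch.empty(0, dtype=torch.uint8, device=base.device)
byte_view.set_(base.untyped_storage())
print(byte_view.shape)  # torch.Size([64])
```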
def metadata(self) -> str:
    return "<n/a>"

def disallow_checkpoint(self) -> None:
should this be implemented?
We don't support any async/out-of-band calls via PG, so nothing needs to be done here.
For HTTP we need this to avoid serving a checkpoint during the optimizer step, but since send is synchronous in PG we don't need any additional synchronization.
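For context, a no-op body would match that reasoning; a sketch, not necessarily the exact code in this PR:

```python
def disallow_checkpoint(self) -> None:
    # For the PG transport the sends block until the transfer completes,
    # so there is no background server (unlike the HTTP transport) that
    # could serve a checkpoint mid-step; nothing to disallow here.
    pass
```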
work.append(self._pg.send([t], dst_rank, tag=3 + i))

# allow 3 concurrent transfers at a time
while len(work) > (3 * len(dst_ranks)):
Is this used to avoid OOM?
Yes, it's to avoid OOM when transferring between devices
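A rough sketch of that bounded-concurrency pattern, assuming `pg.send` returns a waitable Work object as in c10d (`send_tensors` and `MAX_INFLIGHT_PER_DST` are illustrative names, not from the PR):

```python
from typing import List, Sequence

import torch

MAX_INFLIGHT_PER_DST = 3  # cap on outstanding sends per destination


def send_tensors(pg, tensors: Sequence[torch.Tensor], dst_ranks: List[int]) -> None:
    work = []
    for i, t in enumerate(tensors):
        for dst_rank in dst_ranks:
            # tag offset mirrors the quoted snippet (3 + i)
            work.append(pg.send([t], dst_rank, tag=3 + i))
        # Once too many sends are in flight, wait on the oldest ones so the
        # number of live staging buffers (and GPU memory) stays bounded.
        while len(work) > MAX_INFLIGHT_PER_DST * len(dst_ranks):
            work.pop(0).wait()
    # Drain the remaining transfers before returning.
    for w in work:
        w.wait()
```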
i += 1

# TODO: allow in place receives to avoid having to copy to cpu to
# avoid OOMs
We should be able to avoid transferring TensorMeta and DTensorMeta, and avoid to(cpu), if we can first call state_dict() to get the state_dict structure, then traverse the state_dict and send/recv the tensors directly.
Yeah, we totally can do that in most cases -- I wanted to make that refactor a follow-up PR since I still need to figure out how to do it, and this may be a decent fallback if we have some weird state dict.
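A minimal sketch of that kind of traversal, assuming both ranks build a state_dict with identical structure so tensors are visited in the same deterministic order (`iter_tensors` is a hypothetical helper, not part of this PR):

```python
from typing import Any, Iterator

import torch


def iter_tensors(obj: Any) -> Iterator[torch.Tensor]:
    # Walk a (possibly nested) state_dict in a deterministic order so the
    # sender and receiver enumerate tensors identically and can send/recv
    # them directly, without shipping pickled metadata first.
    if isinstance(obj, torch.Tensor):
        yield obj
    elif isinstance(obj, dict):
        for key in sorted(obj):
            yield from iter_tensors(obj[key])
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            yield from iter_tensors(item)
```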
Overall looks pretty good! I had some comments and questions, mostly around readability and using existing infra, but otherwise it looks solid enough to land.
dtype: torch.dtype
storage_offset: int
stride: Tuple[int, ...]
nbytes: int
Not important now, but I wonder if we need to store quantization information; also wondering how that's handled in DTensor, if you know.
@H-Huang do you know how quantized information is stored? Is it a different tensor subclass or just packed into the storage?
work = []

with _timeit("send pickle"):
So we have send_object_list and recv_object_list (https://github.com/pytorch/pytorch/blob/8b5ee275fb455156a944445fb92c43731369ace3/torch/distributed/distributed_c10d.py#L3181), which is what we use in PP to exchange shape metadata between stages to preallocate the recv buffers.
It is pretty similar since it pickles, then sends object sizes, then the object data. I think yours may be more efficient since there are only 2 additional sends of metadata and the rest are the actual data. But wanted to flag in case we wanted to somehow consolidate some logic!
Thanks! Yeah, I missed that before -- good to know, though for right now I think this implementation is a bit more performant since it sends the same data to multiple receivers.
If we wrap the underlying PG we should also be able to use the broadcast_object_list variant, which should give us the best of both worlds.
I'm planning a follow-up PR, since to make a subworld we need to do some underlying improvements in how we calculate the recovering workers.
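For illustration, a sketch of what the broadcast_object_list variant could look like once the sender and recovering workers share a subgroup (the `broadcast_metadata` helper and its arguments are hypothetical):

```python
import torch.distributed as dist


def broadcast_metadata(meta, group, src_rank: int, is_src: bool):
    # One broadcast replaces N per-receiver sends of the pickled metadata:
    # the source supplies the object, every other rank passes a placeholder
    # and gets it filled in.
    objs = [meta if is_src else None]
    dist.broadcast_object_list(objs, src=src_rank, group=group)
    return objs[0]
```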
@dataclass
class _StateDictMeta:
I was thinking the _StateDictMeta and related dataclasses could probably use some more comments, since they're pretty important in determining how we serialize/deserialize and in being able to update them. I'm kinda curious how DCP handles this metadata when it transfers and whether we have existing structure we can use? @fegin @LucasLLC do you know?
Added comments and updated field names to make it clearer
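For reference, a commented sketch of the per-tensor metadata along the lines of the quoted diff; only dtype/storage_offset/stride/nbytes appear in the diff above, so the `shape` field and the comments are assumptions, not the PR's exact code:

```python
from dataclasses import dataclass
from typing import Tuple

import torch


@dataclass
class _TensorMeta:
    """Everything needed to rebuild a tensor from the raw uint8 bytes that
    are actually sent over the process group."""

    shape: torch.Size          # logical shape of the original tensor (assumed field)
    dtype: torch.dtype         # original dtype, restored on the receive side
    storage_offset: int        # offset of this tensor within its storage
    stride: Tuple[int, ...]    # strides, so strided/offset views round-trip
    nbytes: int                # size of the uint8 payload sent for this tensor
```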
This adds a ProcessGroup based CheckpointTransport to allow for transferring a state_dict via the backend network.
This supports NCCL with CUDA devices and Gloo on CPU devices. DTensor is supported, but other tensor subclasses will likely error.
This is cleaned up code from #104
Additional changes:
- ProcessGroupBaby: pass the timeout parameter to the subprocess so we can catch NCCL timeouts

Core algorithm:
Sender: the sends go out to each of the receiving peers simultaneously.

Receiver: receiving is largely the same. Notably, when receiving each tensor, a buffer is allocated on the device, received into, and then transferred to CPU to prevent CUDA OOMs. Allocating and transferring is significantly slower than an in-place receive and is something we should fix in a follow-up PR.
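A sketch of that receive path under the description above (device staging buffer, then copy to CPU); `recv_tensor` and its signature are illustrative, not the PR's exact code:

```python
import torch


def recv_tensor(pg, meta, src_rank: int, device: torch.device, tag: int) -> torch.Tensor:
    # Stage the bytes in a device buffer so NCCL can receive them, then move
    # them to CPU so only one device-sized buffer is live at a time. An
    # in-place receive into the final destination would skip this extra
    # allocation + copy (the follow-up mentioned above).
    buf = torch.empty(meta.nbytes, dtype=torch.uint8, device=device)
    pg.recv([buf], src_rank, tag=tag).wait()
    return buf.cpu()
```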
Test plan:
Added a new shared multi-rank recovery test and enabled it for PGTransport w/ NCCL+Gloo and the existing HTTPTransport.