import cv2 + nvidia/pytorch:22.09-py3 + DistributedDataParallel. (FIND was unable to find an engine) #5291
This might be related to the CUDA/cuDNN versions, but I can't reproduce locally with CUDA 10.2 or 11.7, cuDNN 7605 or 8500. Could you try to reproduce with a fresh environment and, if possible, report back?
This is on an NVIDIA V100 16GB x 8 NGC instance.
Yeah, you're right, the 22.08 PyTorch container (with an older PyTorch/cuDNN) is working fine. So it's related to the newer cuDNN or newer PyTorch (in combination with monai==1.0.0). But even with the newest 22.09 container, monai==0.9.0 works fine (only 1.0.0 fails).
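To pin down which component changed between the 22.08 and 22.09 containers, a quick environment check like the following can help (a convenience sketch, not taken from the thread):

```python
import torch
import cv2

print("torch:", torch.__version__)               # PyTorch build shipped in the container
print("cuda :", torch.version.cuda)              # CUDA toolkit torch was built against
print("cudnn:", torch.backends.cudnn.version())  # cuDNN version as an integer, e.g. 8500
print("cv2  :", cv2.__version__)                 # OpenCV build bundled in the image
```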
I tried to execute the test program on a V100-32G with the latest MONAI and the 22.09 docker image, and got the output below:
Thanks.
After further analysis, here is my finding:
Since any MONAI import triggers a lot of transitive importing, maybe some CUDA-related state is not shareable across spawn multiprocessing. Thanks.
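A minimal sketch of why top-of-script imports matter here: `torch.multiprocessing.spawn` uses the spawn start method, so each worker process re-imports the `__main__` module, and any module-level imports (cv2, MONAI, ...) and their import-time side effects run again in every child.

```python
import os
import torch.multiprocessing as mp

# Runs in the parent and again in every spawned child, because the spawn
# start method re-imports this module in each new process.
print(f"module-level code executing in pid {os.getpid()}")


def worker(rank):
    print(f"worker {rank} running in pid {os.getpid()}")


if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)  # expect three "module-level" prints: parent + 2 children
```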
Thanks @Nic-Ma. In my analysis, monai==0.9.0 works fine in the newest PyTorch container, so it's something specific to monai==1.0.0; it seems that this header import alone causes it.
It seems it's triggered by the following. To reproduce, launch:

```python
import torch
import torch.distributed as dist
import cv2  # importing cv2 next to torch is part of the reported trigger
from torch.cuda.amp import autocast

torch.autograd.set_detect_anomaly(True)


def main():
    # one worker process per visible GPU
    ngpus_per_node = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))


def main_worker(rank, ngpus_per_node):
    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456',
                            world_size=ngpus_per_node, rank=rank)
    torch.backends.cudnn.benchmark = True  # let cuDNN autotune the conv algorithm
    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[rank], output_device=rank, find_unused_parameters=False)
    x = torch.ones(1, 1, 192, 192, 192).to(rank)
    with autocast(enabled=True):
        out = model(x)  # the forward pass is where the failure is reported


if __name__ == "__main__":
    main()
```

Output:
I get the same error if I use a different import instead of cv2, and that import doesn't pull in cv2.
Hi @myron, the MONAI import logic is different: we import everything even if you only import one component. Thanks.
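A quick way to observe that behaviour (a diagnostic sketch, not from the thread): count how many modules a seemingly small MONAI import pulls in.

```python
import sys

before = set(sys.modules)
from monai.config.type_definitions import NdarrayTensor  # a "small" import
after = set(sys.modules)

new_modules = sorted(m for m in after - before if m.startswith("monai"))
print(len(new_modules))   # large: MONAI's __init__ walks and imports its submodules
print(new_modules[:10])   # a sample of what came along for the ride
```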
@Nic-Ma thanks for the reply. I see. We should reconsider this logic: if someone wants to import only a small component, why do we need to import everything? This seems slow and can lead to hard-to-debug bugs like this one in the future.
The current import is not lazy on the first run, but it always walks through the modules in the same import order and easily avoids circular imports. I tried to make it optional but don't have a good idea for dealing with the circular imports.
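For reference, one common pattern for lazy package imports is PEP 562's module-level `__getattr__`. This is only an illustrative sketch with hypothetical submodule names, not how MONAI's `__init__` actually works:

```python
# hypothetical package/__init__.py
import importlib

_SUBMODULES = {"transforms", "networks", "data"}  # hypothetical names for illustration


def __getattr__(name):
    # Import a submodule only the first time it is accessed,
    # e.g. `package.transforms` triggers importing package.transforms.
    if name in _SUBMODULES:
        module = importlib.import_module(f".{name}", __name__)
        globals()[name] = module  # cache so later lookups bypass __getattr__
        return module
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    return sorted(set(globals()) | _SUBMODULES)
```

The trade-off is the one mentioned above: with lazy loading the effective import order depends on which attribute is touched first, which makes circular imports harder to reason about than a fixed eager walk.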
Fixes #5269 #5291.

### Description

This PR updated the PyTorch base docker to 22.09.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Nic Ma <[email protected]>
Signed-off-by: monai-bot <[email protected]>
Signed-off-by: Wenqi Li <[email protected]>
Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
Thank you for the reply. cudnn.benchmark selects the best kernel variant; if we don't use it we may see a ~20% performance drop, so I don't think disabling it is an acceptable long-term solution. I can use the 22.08 container until we find a solution (that's acceptable to me for now), but we need a fix that doesn't compromise efficiency. And once again, monai==0.9.0 works fine in 22.09 with all these options, so it's something new in monai==1.0.0 that triggers it. I don't think we can close this issue yet.
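To illustrate what is being traded away, a small timing sketch (under my own assumptions about shapes and iteration counts, not a benchmark from the thread): `cudnn.benchmark=True` autotunes the convolution algorithm per input shape on the first calls and then reuses the fastest one.

```python
import time
import torch


def avg_conv_time(benchmark: bool, iters: int = 20) -> float:
    torch.backends.cudnn.benchmark = benchmark
    conv = torch.nn.Conv3d(1, 32, kernel_size=3).cuda()
    x = torch.ones(1, 1, 192, 192, 192, device="cuda")
    for _ in range(3):           # warm-up; autotuning happens here when benchmark=True
        conv(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        conv(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    print("benchmark=False:", avg_conv_time(False))
    print("benchmark=True :", avg_conv_time(True))
```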
It's already been addressed by #5293 (by avoiding the problematic import).
I see, very good, thank you guys |
EDIT: the bug is reproducible in the newest nvidia/pytorch:22.09-py3 docker container, but is not reproducible in older containers (older PyTorch/cuDNN).
Something in MetaTensor makes DistributedDataParallel fail
(this is in addition to this bug #5283)
For example, code along these lines fails with the cuDNN error from the issue title (FIND was unable to find an engine); a reconstructed sketch appears below.
The MetaTensor is actually never used/initialized here, but something in it (or its imports) makes the code fail. Since we import MetaTensor everywhere, any code with it fails. I've traced it down to this import (inside MetaTensor.py):

```python
from monai.config.type_definitions import NdarrayTensor
```

Importing this line alone also makes the code fail.
Somehow it confuses the conv3d operation, and possibly other operations.
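For completeness, a hedged reconstruction of the failing example (my own sketch based on the repro script earlier in the thread and the import traced above, not the original snippet): the cv2 import is replaced with the MONAI type import and the rest of the DDP/spawn setup is kept the same.

```python
import torch
import torch.distributed as dist
from torch.cuda.amp import autocast
from monai.config.type_definitions import NdarrayTensor  # the import traced above; never used below


def main_worker(rank, world_size):
    dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:23456",
                            world_size=world_size, rank=rank)
    torch.backends.cudnn.benchmark = True
    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)
    x = torch.ones(1, 1, 192, 192, 192).to(rank)
    with autocast(enabled=True):
        model(x)  # reportedly fails here in the 22.09 container with MONAI 1.0.0


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=world_size, args=(world_size,))
```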