
import cv2 + nvidia/pytorch:22.09-py3 + DistributedDataParallel. (FIND was unable to find an engine) #5291

Closed
myron opened this issue Oct 7, 2022 · 16 comments · Fixed by #5293
Labels
bug Something isn't working


@myron
Collaborator

myron commented Oct 7, 2022

EDIT: the bug is reproducible in the newest nvidia/pytorch:22.09-py3 docker container, but not in older containers (older PyTorch/cuDNN).

Something in MetaTensor makes DistributedDataParallel fail (this is in addition to bug #5283).

For example, this code fails:

import torch.distributed as dist
import torch

from monai.data import MetaTensor
#from monai.config.type_definitions import NdarrayTensor

from torch.cuda.amp import autocast  
torch.autograd.set_detect_anomaly(True)

def main():

    ngpus_per_node = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))

def main_worker(rank, ngpus_per_node):

    print(f"rank {rank}")

    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=ngpus_per_node, rank=rank)
    torch.backends.cudnn.benchmark = True

    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)

    x = torch.ones(1, 1, 192, 192, 192).to(rank)
    with autocast(enabled=True):
        out = model(x)

    print("Done.", out.shape)

if __name__ == "__main__":
    main()

with the error:

-- Process 6 terminated with the following error:                                                                                                               
Traceback (most recent call last):                                                                                                                              
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap                                                               
    fn(i, *args)                                                                                                                                                
  File "/mnt/amproj/Code/automl/tasks/hecktor22/autoconfig_segresnet/test_monai.py", line 29, in main_worker                                                    
    out = model(x)                                                                                                                                              
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl                                                            
    return forward_call(*input, **kwargs)                                                                                                                       
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1015, in forward                                                         
    output = self._run_ddp_forward(*inputs, **kwargs)                                                                                                           
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 976, in _run_ddp_forward                                                 
    return module_to_run(*inputs[0], **kwargs[0])                                                                                                               
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl                                                            
    return forward_call(*input, **kwargs)                                                                                                                       
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 613, in forward                                                                  
    return self._conv_forward(input, self.weight, self.bias)                                                                                                    
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
RuntimeError: FIND was unable to find an engine to execute this computation

MetaTensor is never actually used or initialized here, but something in it (or in its imports) makes the code fail. Since we import MetaTensor everywhere, any code that imports it fails. I've traced it down to this import (inside MetaTensor.py):
from monai.config.type_definitions import NdarrayTensor

Importing this line alone also makes the code fail.

Somehow it confuses the conv3d operation, and possibly other operations.

@myron myron added the bug Something isn't working label Oct 7, 2022
@myron myron added this to the Auto3D Seg framework [P0 v1.0] milestone Oct 7, 2022
@wyli
Contributor

wyli commented Oct 7, 2022

This might be related to the CUDA/cuDNN versions, but I can't reproduce it locally with CUDA 10.2 or 11.7 and cuDNN 7605 or 8500. Could you try to reproduce it in a fresh environment and, if possible, report the output of python -c 'import monai; monai.config.print_debug_info()'?

@wyli wyli removed their assignment Oct 7, 2022
@wyli wyli changed the title MetaTensor and DistributedDataParallel. bug2 MetaTensor and DistributedDataParallel. bug2 FIND was unable to find an engine Oct 7, 2022
@myron
Collaborator Author

myron commented Oct 7, 2022

This is on an NGC instance with 8x NVIDIA V100 16GB,
using the NVIDIA PyTorch container (either nvidian/pytorch:22.09-py3 or nvidia/pytorch:22.0-py3)
and the latest MONAI 1.0.0 (via pip install monai).

Printing MONAI config...

MONAI version: 1.0.0
Numpy version: 1.22.2
Pytorch version: 1.13.0a0+d0d6b1f
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 170093375ce29267e45681fcec09dfa856e1d7e7
MONAI __file__: /opt/conda/lib/python3.8/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 4.0.2
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: 2.10.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.14.0a0
tqdm version: 4.64.1
lmdb version: 1.3.0
psutil version: 5.9.2
pandas version: 1.4.4
einops version: 0.5.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



Printing system config...

System: Linux
Linux version: Ubuntu 20.04.5 LTS
Platform: Linux-5.4.0-109-generic-x86_64-with-glibc2.10
Processor: x86_64
Machine: x86_64
Python version: 3.8.13
Process name: python
Command: ['python', '-c', 'import monai; monai.config.print_debug_info()']
Open files: []
Num physical CPUs: 40
Num logical CPUs: 80
Num usable CPUs: 80
CPU usage (%): [0.6, 0.3, 0.6, 1.3, 0.0, 0.6, 0.3, 0.6, 2.9, 1.0, 1.6, 0.6, 0.6, 1.6, 0.3, 0.6, 1.9, 0.3, 0.3, 0.6, 0.3, 0.0, 0.0, 0.3, 0.3, 0.3, 0.6, 0.6, 1.3, 7.3, 0.3, 0.9, 0.0, 0.3, 0.6, 0.3, 0.3, 0.3, 0.0, 1.6, 0.0, 0.0, 0.3, 1.3, 0.6, 1.3, 1.9, 1.3, 0.6, 7.3, 6.6, 0.6, 0.0, 0.6, 0.3, 0.3, 0.3, 0.3, 2.5, 0.9, 0.3, 0.0, 0.0, 0.0, 0.0, 0.3, 0.3, 0.0, 0.3, 0.3, 0.3, 0.3, 0.0, 0.3, 0.6, 0.3, 0.3, 0.3, 0.9, 100.0]
CPU freq. (MHz): 2694
Load avg. in last 1, 5, 15 mins (%): [1.6, 1.8, 2.3]
Disk usage (%): 46.2
Avg. sensor temp. (Celsius): UNKNOWN for given OS
Total physical memory (GB): 503.8
Available memory (GB): 484.9
Used memory (GB): 16.2


Printing GPU config...

Num GPUs: 8
Has CUDA: True
CUDA version: 11.8
cuDNN enabled: True
cuDNN version: 8600
Current device: 0
Library compiled for CUDA architectures: ['sm_52', 'sm_60', 'sm_61', 'sm_70', 'sm_75', 'sm_80', 'sm_86', 'sm_90', 'compute_90']
GPU 0 Name: Tesla V100-SXM2-16GB-N
GPU 0 Is integrated: False
GPU 0 Is multi GPU board: False
GPU 0 Multi processor count: 80
GPU 0 Total memory (GB): 15.8
GPU 0 CUDA capability (maj.min): 7.0
GPU 7 Name: Tesla V100-SXM2-16GB-N
GPU 7 Is integrated: False
GPU 7 Is multi GPU board: False
GPU 7 Multi processor count: 80
GPU 7 Total memory (GB): 15.8
GPU 7 CUDA capability (maj.min): 7.0

@wyli
Contributor

wyli commented Oct 7, 2022

This is on NVIDIA V100 16gb x 8 ngc instance, using NVIDIA pytorch contrainer (either nvidian/pytorch:22.09-py3 or nvidia/pytorch:22.0-py3) and latest MONAI 1.0.0 (via pip install monai)

Thanks, 22.09 hasn't been tested yet (#5269), cc @Nic-Ma.

@myron
Collaborator Author

myron commented Oct 7, 2022

Yeah, you're right: the 22.08 PyTorch container works fine. It includes:

CUDA version: 11.7
cuDNN version: 8500
Pytorch version: 1.13.0a0+d321be6

So it's related to the newer cuDNN or the newer PyTorch (in combination with monai==1.0.0).

But even with the newest 22.09 container, monai==0.9.0 works fine (only 1.0.0 fails).
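The comparison above comes down to the integers printed by print_debug_info() (cuDNN 8500 in the working 22.08 container, 8600 in the failing 22.09 one, and 7605/8500 in wyli's local setups). A small helper to decode them, as a sketch: it assumes the cuDNN 7.x/8.x encoding major * 1000 + minor * 100 + patch, and the 9.x branch (major * 10000 + minor * 100 + patch) is an assumption not exercised anywhere in this thread.

```python
from typing import Tuple


def decode_cudnn_version(v: int) -> Tuple[int, int, int]:
    """Split the integer reported as "cuDNN version" (e.g. 8600) into (major, minor, patch)."""
    if v >= 10000:
        # Assumed cuDNN 9.x encoding: major * 10000 + minor * 100 + patch.
        major, rest = divmod(v, 10000)
    else:
        # cuDNN 7.x/8.x encoding: major * 1000 + minor * 100 + patch.
        major, rest = divmod(v, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


# Values reported in this thread.
for reported in (8500, 8600, 7605):
    print(reported, "->", decode_cudnn_version(reported))
# 8500 -> (8, 5, 0)
# 8600 -> (8, 6, 0)
# 7605 -> (7, 6, 5)
```

Inside a container, the same helper can be applied to torch.backends.cudnn.version(), which returns the integer (or None when cuDNN is unavailable).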

@Nic-Ma
Contributor

Nic-Ma commented Oct 8, 2022

Hi @myron, @wyli,

I ran the test program on a V100-32G with the latest MONAI and the 22.09 docker image, and got the output below:

root@apt-sh-ai:/workspace/data/medical/MONAI# python test_ddp.py 
rank 0
rank 1
2022-10-08 10:13:58,995 - Added key: store_based_barrier_key:1 to store for rank: 0
2022-10-08 10:13:59,005 - Added key: store_based_barrier_key:1 to store for rank: 1
2022-10-08 10:13:59,005 - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
2022-10-08 10:13:59,005 - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 2 nodes.
is_namedtuple is deprecated, please use the python checks instead
is_namedtuple is deprecated, please use the python checks instead
Traceback (most recent call last):
  File "test_ddp.py", line 32, in <module>
    main()
  File "test_ddp.py", line 13, in main
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/workspace/data/medical/MONAI/test_ddp.py", line 27, in main_worker
    out = model(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1015, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 976, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
RuntimeError: FIND was unable to find an engine to execute this computation

Thanks.

@Nic-Ma
Contributor

Nic-Ma commented Oct 8, 2022

After further analysis, here are my findings:

  1. Any MONAI import causes the error; for example, changing from monai.data import MetaTensor to from monai.config.deviceconfig import print_config also shows the error.
  2. If the import is moved into the subprocess (into the function main_worker()), everything works.
  3. If the backend is changed from nccl to gloo, everything works.

Since any MONAI import triggers many other imports, perhaps some CUDA-related state is not shareable across processes started with spawn.
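Finding 2 depends on where an import sits relative to the spawned workers. torch.multiprocessing.spawn uses the spawn start method (the traceback above shows start_method='spawn'), under which each child re-executes the main script's module-level code before calling the worker function. The stdlib-only sketch below (no torch, no GPU) counts how many processes run a script's top-level code: a module-level import runs in the parent and in every child, while an import inside the worker runs only in the children. It is an illustration of the start method, not an explanation of why cv2 breaks cuDNN, which this thread leaves open; the script and names in it are hypothetical.

```python
import os
import subprocess
import sys
import tempfile
import textwrap

# Demo script written to disk and run in a fresh interpreter; NPROCS is substituted first.
DEMO_SCRIPT = textwrap.dedent(
    """
    import multiprocessing as mp

    # Stand-in for a top-level import such as `import cv2` or `from monai.data import MetaTensor`.
    print("TOP-LEVEL", flush=True)


    def main_worker(rank):
        # An import placed here would run only inside the spawned child.
        pass


    if __name__ == "__main__":
        ctx = mp.get_context("spawn")  # same start method torch.multiprocessing.spawn uses
        procs = [ctx.Process(target=main_worker, args=(r,)) for r in range(NPROCS)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
    """
)


def count_top_level_runs(nprocs: int) -> int:
    """Return how many processes executed the demo script's module-level code."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "spawn_demo.py")
        with open(path, "w") as f:
            f.write(DEMO_SCRIPT.replace("NPROCS", str(nprocs)))
        result = subprocess.run([sys.executable, path], capture_output=True, text=True, check=True)
    return result.stdout.count("TOP-LEVEL")


print(count_top_level_runs(2))  # 3: the parent plus one per spawned child
```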

Thanks.

@myron
Collaborator Author

myron commented Oct 8, 2022

Thanks @Nic-Ma. In my analysis, monai==0.9.0 works fine in the newest PyTorch container, so it's something specific to monai==1.0.0.

It seems that this top-level import alone causes it:
from monai.config.type_definitions import NdarrayTensor

@wyli
Contributor

wyli commented Oct 8, 2022

It seems to be triggered by import cv2, on driver 470.82.01 with nvcr.io/nvidia/pytorch:22.09-py3 (the root cause is not really in MONAI; perhaps we should report this to the framework team instead).

To reproduce, launch nvcr.io/nvidia/pytorch:22.09-py3, and run python test.py, where test.py has the following content:

import torch.distributed as dist
import torch

import cv2

from torch.cuda.amp import autocast
torch.autograd.set_detect_anomaly(True)

def main():

    ngpus_per_node = torch.cuda.device_count()
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))

def main_worker(rank, ngpus_per_node):

    dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', world_size=ngpus_per_node, rank=rank)
    torch.backends.cudnn.benchmark = True

    model = torch.nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, bias=True).to(rank)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)

    x = torch.ones(1, 1, 192, 192, 192).to(rank)
    with autocast(enabled=True):
        out = model(x)

if __name__ == "__main__":
    main()

output:

root@3512928:/workspace# python test.py
/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/scatter_gather.py:9: UserWarning: is_namedtuple is deprecated, please use the python checks instead
  warnings.warn("is_namedtuple is deprecated, please use the python checks instead")
Traceback (most recent call last):
  File "test.py", line 27, in <module>
    main()
  File "test.py", line 12, in main
    torch.multiprocessing.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node,))
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/workspace/test.py", line 24, in main_worker
    out = model(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1015, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 976, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1185, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 613, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 608, in _conv_forward
    return F.conv3d(
RuntimeError: FIND was unable to find an engine to execute this computation

@wyli wyli changed the title MetaTensor and DistributedDataParallel. bug2 FIND was unable to find an engine import cv2 and DistributedDataParallel. (FIND was unable to find an engine) Oct 8, 2022
@myron
Collaborator Author

myron commented Oct 9, 2022

I get the same error if, instead of import cv2, I import
from monai.config.type_definitions import NdarrayTensor

and that module doesn't import cv2, so it seems there are several ways to trigger this error.

@Nic-Ma
Copy link
Contributor

Nic-Ma commented Oct 9, 2022

Hi @myron,

MONAI's import logic is different: we import everything even if you only import one component:
https://github.com/Project-MONAI/MONAI/blob/dev/monai/__init__.py#L48
So import cv2 may be executed somewhere else in the codebase, for example:
https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/video_dataset.py#L28
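Whether a given import transitively loads cv2 can be checked directly instead of inferred from the module's own import lines. A minimal sketch, assuming only the standard library: import the module in a fresh interpreter (so nothing already loaded by the caller skews the result) and look for the suspect in sys.modules. The helper name pulls_in is hypothetical.

```python
import subprocess
import sys


def pulls_in(module: str, suspect: str) -> bool:
    """Return True if importing `module` in a fresh interpreter also loads `suspect`."""
    code = (
        "import importlib, sys\n"
        f"importlib.import_module({module!r})\n"
        f"print({suspect!r} in sys.modules)\n"
    )
    # A fresh interpreter starts with a clean sys.modules, so only `module`'s own imports count.
    result = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True, check=True)
    # Take the last line in case the imported module prints something itself.
    return result.stdout.strip().splitlines()[-1] == "True"


print(pulls_in("json", "json.decoder"))  # True: json/__init__.py imports json.decoder
```

Inside the affected container, pulls_in("monai.config.type_definitions", "cv2") would show whether that single import reaches cv2 through the package-wide import; the result depends on the installed MONAI version, so it is not asserted here.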

Thanks.

@myron
Copy link
Collaborator Author

myron commented Oct 9, 2022

@Nic-Ma thanks for the reply. I see.

We should reconsider this logic. If someone wants to import only a small component, why do we need to import everything? This seems slow, and it can lead to hard-to-debug bugs like this one in the future.

@wyli
Contributor

wyli commented Oct 9, 2022

The current import is not lazy on the first run, but it always walks through the modules in the same order, which easily avoids circular imports. I tried to make it optional but don't have a good way to deal with the circular imports.
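For comparison, the standard library can defer a module's execution until first attribute access via importlib.util.LazyLoader. The sketch below follows the lazy-import recipe from the Python importlib documentation; it is a generic illustration, not MONAI's import mechanism, and it does not address the circular-import ordering concern above. In principle, a lazily loaded module that contains import cv2 would not run that import until something actually touches the module.

```python
import importlib.util
import sys
from types import ModuleType


def lazy_import(name: str) -> ModuleType:
    """Import `name` lazily: its module body runs on first attribute access (LazyLoader recipe)."""
    if name in sys.modules:
        return sys.modules[name]  # already imported; nothing left to defer
    spec = importlib.util.find_spec(name)
    if spec is None or spec.loader is None:
        raise ModuleNotFoundError(f"No module named {name!r}")
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)  # registers the module but defers executing its body
    return module


colorsys = lazy_import("colorsys")  # body not executed yet (unless it was already imported)
print(colorsys.hls_to_rgb(0.0, 0.5, 0.0))  # first attribute access executes it; prints (0.5, 0.5, 0.5)
```

The trade-off is that module execution order then follows attribute-access order at runtime, which is exactly the kind of non-determinism a fixed walk-through import avoids.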

@Nic-Ma
Copy link
Contributor

Nic-Ma commented Oct 11, 2022

Hi @myron @wyli,

After more analysis, I found that this issue only occurs when you set:
torch.backends.cudnn.benchmark = True
To unblock your work for now, you can remove this line or set it to False.

Thanks.

Nic-Ma added a commit that referenced this issue Oct 11, 2022
Fixes #5269 #5291 .

### Description

This PR updated the PyTorch base docker to 22.09.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Nic Ma <[email protected]>
Signed-off-by: monai-bot <[email protected]>
Signed-off-by: Wenqi Li <[email protected]>
Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
wyli added a commit that referenced this issue Oct 11, 2022
@myron
Collaborator Author

myron commented Oct 11, 2022

Thank you for the reply.

cudnn.benchmark selects the best kernel variant; without it we may see a 20% performance drop, so I don't think it's an acceptable long-term solution.

I can use the 22.08 container until we find a solution (that's acceptable to me for now).

But we need a solution that doesn't compromise efficiency. And once again, monai==0.9.0 works fine with 22.09 with all these options, so it's something new in monai==1.0.0 that triggers it.

I don't think we can close this issue yet.

@myron myron reopened this Oct 11, 2022
@myron myron changed the title import cv2 and DistributedDataParallel. (FIND was unable to find an engine) monai==1.0.0 + nvidia/pytorch:22.09-py3 + DistributedDataParallel. (FIND was unable to find an engine) Oct 11, 2022
@wyli
Contributor

wyli commented Oct 11, 2022

It's already been addressed by #5293 (by not importing cv2, https://github.com/Project-MONAI/MONAI/blob/dev/monai/__init__.py#L50), with a test case included. What Nic mentions is a possible alternative workaround in case cv2 is imported for some other purpose.
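The fix keeps cv2 out by not importing the module that needs it during MONAI's package-wide import (see the monai/__init__.py link). The sketch below is a hypothetical, one-level version of that idea: list a package's direct submodules with pkgutil.iter_modules and eagerly import every one whose name does not match an exclude regex. import_submodules is an illustrative name rather than MONAI's helper, and MONAI's actual exclude list may differ.

```python
import importlib
import json
import pkgutil
import re
from types import ModuleType
from typing import Iterable, List


def import_submodules(package: ModuleType, excludes: Iterable[str] = ()) -> List[str]:
    """Eagerly import the direct submodules of `package`, skipping names that match any exclude regex."""
    # `__main__` modules are always skipped because they are meant to be run, not imported.
    pattern = re.compile("|".join([r".*\.__main__$", *excludes]))
    imported = []
    # iter_modules only lists names; nothing is imported until import_module below.
    for info in pkgutil.iter_modules(package.__path__, prefix=package.__name__ + "."):
        if pattern.match(info.name):
            continue  # an excluded module's own imports (e.g. cv2) are never triggered
        importlib.import_module(info.name)
        imported.append(info.name)
    return sorted(imported)


print(import_submodules(json, excludes=[r".*\.tool$"]))
```

Keeping the walk eager and ordered preserves the deterministic import order discussed above, while the exclude list keeps optional heavy dependencies from loading unless their module is imported explicitly.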

@wyli wyli closed this as completed Oct 11, 2022
@wyli wyli changed the title monai==1.0.0 + nvidia/pytorch:22.09-py3 + DistributedDataParallel. (FIND was unable to find an engine) import cv2 + nvidia/pytorch:22.09-py3 + DistributedDataParallel. (FIND was unable to find an engine) Oct 11, 2022
@myron
Collaborator Author

myron commented Oct 11, 2022

I see, very good, thank you guys

KumoLiu pushed a commit that referenced this issue Nov 2, 2022
Signed-off-by: KumoLiu <[email protected]>

3 participants