[BUG] Pytorch Memory Keeps Creeping up in Long running Jobs #79

Closed
VibhuJawa opened this issue Aug 27, 2024 · 8 comments

@VibhuJawa
Member

While running classifiers like the quality classifier (https://github.com/NVIDIA/NeMo-Curator/blob/main/nemo_curator/classifiers/quality.py) in long-running jobs, PyTorch memory allocations seem to keep creeping up.

@yury-tokpanov was kind enough to run a job for a long time using #78.

At the end of the first partition, around 1639 MiB is allocated:

[Image: PyTorch memory snapshot after the first partition]

After a few hours, PyTorch memory has crept up to 36746 MiB:

[Image: PyTorch memory snapshot after a few hours]

This is a problem for data annotation at foundation-model scale, since it can cause OOMs in these long-running jobs.
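For anyone reproducing this without the screenshots, one illustrative way to see the growth is to print PyTorch's allocator counters between partitions; log_torch_mem is a hypothetical helper, not part of NeMo-Curator:

import torch

def log_torch_mem(tag: str) -> None:
    # memory_allocated(): bytes currently held by live tensors on this device.
    # memory_reserved(): bytes held by PyTorch's caching allocator, including cached free blocks.
    allocated_mib = torch.cuda.memory_allocated() / (1024 ** 2)
    reserved_mib = torch.cuda.memory_reserved() / (1024 ** 2)
    print(f"{tag}: allocated={allocated_mib:.0f} MiB, reserved={reserved_mib:.0f} MiB", flush=True)

Calling this before and after each partition would make the climb visible in plain logs, without the dashboard.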

@ayushdg
Member

ayushdg commented Aug 29, 2024

cc: @arhamm1

@yury-tokpanov

Thanks, @VibhuJawa, for filing this. It seems like it also depends on the size of the Dask partitions: the bigger they are, the faster memory usage grows (for the same number of documents).

@VibhuJawa
Member Author

VibhuJawa commented Sep 4, 2024

This also seems to happen when using RMM together with the PyTorch allocator.

cluster = LocalCUDACluster(rmm_pool_size="25GB", rmm_async=False)
set_torch_to_use_rmm()
client = Client(cluster)

def enable_rmm_stats():
    import rmm
    rmm.statistics.enable_statistics()

client.run(set_torch_to_use_rmm)
client.run(enable_rmm_stats)
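set_torch_to_use_rmm is not shown in the snippet above; a minimal sketch of what it presumably does, mirroring the allocator swap used in the MREs further down this thread:

def set_torch_to_use_rmm():
    # Route PyTorch's CUDA allocations through RMM so PyTorch and RAPIDS
    # draw from the same pool (same calls as in the MREs below).
    import torch
    from rmm.allocators.torch import rmm_torch_allocator
    torch.cuda.memory.change_current_allocator(rmm_torch_allocator)

The diff below instruments Predictor.call with rmm.statistics prints (and comments out the per-call torch.cuda.memory_summary output); the log after it shows the resulting per-partition growth.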
 from typing import Optional
 
 import cudf
+import rmm
 import cupy as cp
 import torch
 
@@ -54,11 +55,14 @@ class Predictor(Op):
 
     @torch.no_grad()
     def call(self, data, partition_info=None):
+        cleanup_torch_cache()
+        start_stats = rmm.statistics.get_statistics()
+        print("RMM MB allocations in the start:", start_stats.current_bytes/(1024 * 1024), flush=True)
         # Get the current CUDA device
-        current_device = torch.cuda.current_device()
+        # current_device = torch.cuda.current_device()
         # Print CUDA memory at the beginning of the method
-        print(f"CUDA memory at start (device {current_device}):")
-        print(torch.cuda.memory_summary(device=current_device))
+        # print(f"CUDA memory at start (device {current_device}):")
+        # print(torch.cuda.memory_summary(device=current_device))
 
         index = data.index.copy()
         if self.sorted_data_loader:
@@ -107,10 +111,11 @@ class Predictor(Op):
             raise RuntimeError(f"Unexpected output shape: {output.shape}")
         del outputs, _index
         cleanup_torch_cache()
-
+        end_stats = rmm.statistics.get_statistics()
+        print("RMM MB allocations in the end:", end_stats.current_bytes/(1024 * 1024), flush=True)        
         # Print CUDA memory at the end of the method
-        print(f"CUDA memory at end (device {current_device}):")
-        print(torch.cuda.memory_summary(device=current_device))
+        # print(f"CUDA memory at end (device {current_device}):")
+        # print(torch.cuda.memory_summary(device=current_device))
         return out
Reading 4000 files
Starting domain classifier inference
RMM MB allocations in the start: 48.897552490234375
GPU: 0, Part: 0: 100%|██████████| 5908/5908 [00:28<00:00, 207.99it/s]
RMM MB allocations in the end: 782.1744155883789
Writing to disk complete for 1 partitions
---------------------------------------------------------------------------------------------------
Reading 4000 files
Starting domain classifier inference
RMM MB allocations in the start: 781.4426956176758
GPU: 0, Part: 0: 100%|██████████| 5908/5908 [00:25<00:00, 227.37it/s]
RMM MB allocations in the end: 805.252571105957
Writing to disk complete for 1 partitions
---------------------------------------------------------------------------------------------------
Reading 4000 files
Starting domain classifier inference
RMM MB allocations in the start: 804.5208206176758
GPU: 0, Part: 0: 100%|██████████| 5908/5908 [00:26<00:00, 226.74it/s]
RMM MB allocations in the end: 828.330696105957
Writing to disk complete for 1 partitions
---------------------------------------------------------------------------------------------------
Reading 4000 files
Starting domain classifier inference
RMM MB allocations in the start: 827.5989456176758
GPU: 0, Part: 0: 100%|██████████| 5908/5908 [00:28<00:00, 204.00it/s]
RMM MB allocations in the end: 851.408821105957
Writing to disk complete for 1 partitions
---------------------------------------------------------------------------------------------------
Reading 4000 files
Starting domain classifier inference
RMM MB allocations in the start: 850.6770706176758
GPU: 0, Part: 0: 100%|██████████| 5908/5908 [00:26<00:00, 226.56it/s]
RMM MB allocations in the end: 874.486946105957
Writing to disk complete for 1 partitions
---------------------------------------------------------------------------------------------------
Reading 4000 files
Starting domain classifier inference
RMM MB allocations in the start: 873.7551956176758
GPU: 0, Part: 0:  52%|█████▏    | 3072/5908 [00:15<00:14, 190.38it/s]
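Reading the deltas from this log: the first partition jumps from ~48.9 MB to ~782.2 MB of RMM allocations (presumably the model being placed on the GPU plus working buffers), and each subsequent partition then adds a steady ~23.8 MB that is never released (e.g. 805.25 − 781.44 ≈ 23.81 MB), which matches the per-iteration growth in the MREs below.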

@VibhuJawa
Member Author

I think this is an MRE of the potential issue:

import gc
import cudf
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch
from crossfit.backend.torch.loader import SortedSeqLoader
from nemo_curator.classifiers import DomainClassifier

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
rmm.statistics.enable_statistics()
classifier = DomainClassifier(batch_size=1024)
model = classifier.model
data = cudf.read_parquet("loader_helper.parquet")

for _ in range(0,10):
    stats = rmm.statistics.get_statistics()
    print("RMM MB allocations before creating loader:", stats.current_bytes/(1024 * 1024), flush=True)
    loader = SortedSeqLoader(
        data[["input_ids", "attention_mask"]],
        model=model,
        initial_batch_size=1024
    )
   
    del loader
    gc.collect()
    stats = rmm.statistics.get_statistics()
    print("RMM MB allocations after creating loader:", stats.current_bytes/(1024 * 1024), flush=True)
    print("---"*33)
RMM MB allocations before creating loader: 48.895599365234375
RMM MB allocations after creating loader: 71.97381591796875
---------------------------------------------------------------------------------------------------
RMM MB allocations before creating loader: 71.97381591796875
RMM MB allocations after creating loader: 95.05194091796875
---------------------------------------------------------------------------------------------------
RMM MB allocations before creating loader: 95.05194091796875
RMM MB allocations after creating loader: 118.13006591796875
---------------------------------------------------------------------------------------------------
RMM MB allocations before creating loader: 118.13006591796875
RMM MB allocations after creating loader: 141.20819091796875
---------------------------------------------------------------------------------------------------
RMM MB allocations before creating loader: 141.20819091796875
RMM MB allocations after creating loader: 164.28631591796875
---------------------------------------------------------------------------------------------------
RMM MB allocations before creating loader: 164.28631591796875
RMM MB allocations after creating loader: 187.36444091796875
---------------------------------------------------------------------------------------------------
RMM MB allocations before creating loader: 187.36444091796875
RMM MB allocations after creating loader: 210.44256591796875
---------------------------------------------------------------------------------------------------
RMM MB allocations before creating loader: 210.44256591796875
RMM MB allocations after creating loader: 233.52069091796875
---------------------------------------------------------------------------------------------------
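So simply constructing and then deleting a SortedSeqLoader leaks ~23.08 MB of RMM allocations per iteration (71.97 − 48.90 ≈ 23.08, and the same delta every round), even after gc.collect(); this matches the per-partition growth in the full pipeline above.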

@VibhuJawa
Member Author

import gc
import cudf
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch


torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
rmm.statistics.enable_statistics()


from crossfit.data.dataframe.dispatch import CrossFrame


data = cudf.read_parquet("loader_helper.parquet")
for _ in range(0,10):
    stats = rmm.statistics.get_statistics()
    print("RMM MB allocations before assign", stats.current_bytes/(1024 * 1024), flush=True)
    frame = CrossFrame(data[["input_ids", "attention_mask"]]).cast(torch.Tensor)
    seq_length = (frame["input_ids"] != 1).sum(axis=1)
    sorted_indices = seq_length.argsort(descending=True)
    frame = frame.apply(lambda x: x[sorted_indices])
    frame = frame.assign(seq_length=seq_length[sorted_indices])
    stats = rmm.statistics.get_statistics()
    print("RMM MB allocations after assign", stats.current_bytes/(1024 * 1024), flush=True)
RMM MB allocations before assign 48.895599365234375
RMM MB allocations after assign 72.10903930664062
RMM MB allocations before assign 72.10903930664062
RMM MB allocations after assign 95.18716430664062
RMM MB allocations before assign 95.18716430664062
RMM MB allocations after assign 118.26528930664062
RMM MB allocations before assign 118.26528930664062
RMM MB allocations after assign 141.34341430664062
RMM MB allocations before assign 141.34341430664062
RMM MB allocations after assign 164.42153930664062
RMM MB allocations before assign 164.42153930664062
RMM MB allocations after assign 187.49966430664062
RMM MB allocations before assign 187.49966430664062
RMM MB allocations after assign 210.57778930664062
RMM MB allocations before assign 210.57778930664062
RMM MB allocations after assign 233.65591430664062
RMM MB allocations before assign 233.65591430664062
RMM MB allocations after assign 256.7340393066406
RMM MB allocations before assign 256.7340393066406
RMM MB allocations after assign 279.8121643066406
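The growth here is again ~23 MB per iteration, so the loader itself is not needed to reproduce the leak: the CrossFrame cast/sort/assign path alone is enough, and the smaller MRE below narrows it down further.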

@VibhuJawa
Member Author

A much smaller MRE:

import gc
import cudf
import rmm
from rmm.allocators.torch import rmm_torch_allocator
import torch

torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
rmm.statistics.enable_statistics()

from crossfit.data.dataframe.dispatch import CrossFrame
data = cudf.read_parquet("loader_helper.parquet")
frame = CrossFrame(data[["input_ids", "attention_mask"]]).cast(torch.Tensor)

def assign(frame, **kwargs):
    data = frame.data.copy()
    # The loop below is what produces the memory leak (see the note after the output)
    for k, v in kwargs.items():
        if frame.columns and len(v) != len(frame):
            raise ValueError(f"Column {k} was length {len(v)}, but expected length {len(frame)}")
    data.update(**kwargs)
    return frame.__class__(data)

for _ in range(0,10):    
    stats = rmm.statistics.get_statistics()    
    print("RMM MB allocations before assign", stats.current_bytes/(1024 * 1024), flush=True)    
    seq_length = (frame["input_ids"] != 1).sum(axis=1)
    sorted_indices = seq_length.argsort(descending=True)
    frame = frame.apply(lambda x: x[sorted_indices])
    frame = assign(frame, seq_length=seq_length[sorted_indices])
    stats = rmm.statistics.get_statistics()
    print("RMM MB allocations after assign", stats.current_bytes/(1024 * 1024), flush=True)
RMM MB allocations before assign 48.89569091796875
RMM MB allocations after assign 72.10903930664062
RMM MB allocations before assign 72.10903930664062
RMM MB allocations after assign 95.23223876953125
RMM MB allocations before assign 95.23223876953125
RMM MB allocations after assign 118.35543823242188
RMM MB allocations before assign 118.35543823242188
RMM MB allocations after assign 141.4786376953125
RMM MB allocations before assign 141.4786376953125
RMM MB allocations after assign 164.60183715820312
RMM MB allocations before assign 164.60183715820312
RMM MB allocations after assign 187.72503662109375
RMM MB allocations before assign 187.72503662109375
RMM MB allocations after assign 210.84823608398438
RMM MB allocations before assign 210.84823608398438
RMM MB allocations after assign 233.971435546875
RMM MB allocations before assign 233.971435546875
RMM MB allocations after assign 257.0946350097656
RMM MB allocations before assign 257.0946350097656
RMM MB allocations after assign 280.21783447265625

Somehow the snippet below produces a memory leak:

    for k, v in kwargs.items():
        if frame.columns and len(v) != len(frame):
            raise ValueError(f"Column {k} was length {len(v)}, but expected length {len(frame)}")

@VibhuJawa
Member Author

With the fix in https://github.com/rapidsai/crossfit/pull/80/files, memory no longer seems to grow.

After 10 iterations, allocated memory stays consistent at around 727434 KiB, whereas previously it grew linearly for me.

After the 1st iteration:

Starting domain classifier inference
CUDA memory at end (device 0):
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 4            |        cudaMalloc retries: 5         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 727515 KiB |  27916 MiB |   4121 GiB |   4120 GiB |
|       from large pool | 726872 KiB |  27915 MiB |   4119 GiB |   4118 GiB |
|       from small pool |    643 KiB |      5 MiB |      2 GiB |      2 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 727515 KiB |  27916 MiB |   4121 GiB |   4120 GiB |
|       from large pool | 726872 KiB |  27915 MiB |   4119 GiB |   4118 GiB |
|       from small pool |    643 KiB |      5 MiB |      2 GiB |      2 GiB |
|---------------------------------------------------------------------------|
| Requested memory      | 726574 KiB |  27902 MiB |   4121 GiB |   4120 GiB |
|       from large pool | 725932 KiB |  27901 MiB |   4119 GiB |   4118 GiB |
|       from small pool |    642 KiB |      5 MiB |      2 GiB |      2 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    782 MiB |  30386 MiB |  39616 MiB |  38834 MiB |
|       from large pool |    780 MiB |  30378 MiB |  39594 MiB |  38814 MiB |
|       from small pool |      2 MiB |      8 MiB |     22 MiB |     20 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory |  73253 KiB |  19011 MiB |   4528 GiB |   4528 GiB |
|       from large pool |  71848 KiB |  19010 MiB |   4525 GiB |   4525 GiB |
|       from small pool |   1405 KiB |      5 MiB |      2 GiB |      2 GiB |
|---------------------------------------------------------------------------|
| Allocations           |     203    |     382    |    9622    |    9419    |
|       from large pool |      75    |     174    |    5913    |    5838    |
|       from small pool |     128    |     214    |    3709    |    3581    |
|---------------------------------------------------------------------------|
| Active allocs         |     203    |     382    |    9622    |    9419    |
|       from large pool |      75    |     174    |    5913    |    5838    |
|       from small pool |     128    |     214    |    3709    |    3581    |
|---------------------------------------------------------------------------|
| GPU reserved segments |      22    |      34    |      43    |      21    |
|       from large pool |      21    |      30    |      32    |      11    |
|       from small pool |       1    |       4    |      11    |      10    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      25    |      59    |    4342    |    4317    |
|       from large pool |      22    |      34    |    2978    |    2956    |
|       from small pool |       3    |      28    |    1364    |    1361    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

After 10 iterations:

CUDA memory at start (device 0):
|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 36           |        cudaMalloc retries: 45        |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      | 727434 KiB |  27916 MiB |  37087 GiB |  37086 GiB |
|       from large pool | 726872 KiB |  27915 MiB |  37067 GiB |  37067 GiB |
|       from small pool |    562 KiB |      5 MiB |     19 GiB |     19 GiB |
|---------------------------------------------------------------------------|
| Active memory         | 727434 KiB |  27916 MiB |  37087 GiB |  37086 GiB |
|       from large pool | 726872 KiB |  27915 MiB |  37067 GiB |  37067 GiB |
|       from small pool |    562 KiB |      5 MiB |     19 GiB |     19 GiB |
|---------------------------------------------------------------------------|
| Requested memory      | 726494 KiB |  27902 MiB |  37084 GiB |  37084 GiB |
|       from large pool | 725932 KiB |  27901 MiB |  37065 GiB |  37064 GiB |
|       from small pool |    562 KiB |      5 MiB |     19 GiB |     19 GiB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    782 MiB |  30386 MiB | 350480 MiB | 349698 MiB |
|       from large pool |    780 MiB |  30378 MiB | 350298 MiB | 349518 MiB |
|       from small pool |      2 MiB |      8 MiB |    182 MiB |    180 MiB |
|---------------------------------------------------------------------------|
| Non-releasable memory |  73333 KiB |  19011 MiB |  40751 GiB |  40751 GiB |
|       from large pool |  71848 KiB |  19010 MiB |  40730 GiB |  40730 GiB |
|       from small pool |   1485 KiB |      5 MiB |     20 GiB |     20 GiB |
|---------------------------------------------------------------------------|
| Allocations           |     202    |     382    |   84982    |   84780    |
|       from large pool |      75    |     174    |   52617    |   52542    |
|       from small pool |     127    |     214    |   32365    |   32238    |
|---------------------------------------------------------------------------|
| Active allocs         |     202    |     382    |   84982    |   84780    |
|       from large pool |      75    |     174    |   52617    |   52542    |
|       from small pool |     127    |     214    |   32365    |   32238    |
|---------------------------------------------------------------------------|
| GPU reserved segments |      22    |      34    |     219    |     197    |
|       from large pool |      21    |      30    |     128    |     107    |
|       from small pool |       1    |       4    |      91    |      90    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      24    |      60    |   38950    |   38926    |
|       from large pool |      22    |      34    |   26634    |   26612    |
|       from small pool |       2    |      28    |   12316    |   12314    |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

I still need to understand why, though, before merging in the fix.

@sarahyurick
Collaborator

Closed by #80.
