
Train loop takes exponentially longer as the number of layers increases #8824

Open
rplsbo opened this issue Mar 12, 2025 · 3 comments

rplsbo commented Mar 12, 2025

🐛 Bug

Training step time increases exponentially as the number of layers grows.

To Reproduce

The following code takes around 16 seconds per step at layers=1000, whereas it takes about 0.38 seconds per step at layers=500.

import os

os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "1"
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch
import torch_xla.core.xla_model as xm
import time

class SimpleModel(torch.nn.Module):
    def __init__(self, layers=1000):
        super().__init__()
        self.input_layer = torch.nn.Linear(10, 100)
        self.input_activation = torch.nn.ReLU()
        
        self.hidden_layers = torch.nn.ModuleList()
        for _ in range(layers):
            self.hidden_layers.append(torch.nn.Linear(100, 100))
            self.hidden_layers.append(torch.nn.ReLU())
        
        self.output_layer = torch.nn.Linear(100, 1)
    
    def forward(self, x):
        x = self.input_activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

def train():
    device = xm.xla_device()
    for layers in [500, 1000, 1500, 2000, 2500, 3000]:
        model = SimpleModel(layers).to(device)

        optimizer = torch.optim.Adam(model.parameters())
        loss_fn = torch.nn.MSELoss()

        step_times = []

        for i in range(10):
            start_time = time.time()
            input = torch.randn(10).to(device)
            y = torch.randn(1).to(device)
            output = model(input)
            total_loss = loss_fn(output, y)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
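            # xm.mark_step() cuts the lazily built graph here and hands it to XLA for
            # compilation and execution; xm.wait_device_ops() blocks until the device
            # finishes, so step_time below includes compile time on uncached steps.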
            xm.mark_step()
            xm.wait_device_ops()
            step_time = time.time() - start_time
            step_times.append(step_time)
        
        median_time = sorted(step_times)[len(step_times)//2]
        print(f"Median step time: {median_time:.4f}s")
        print(step_times)
    

if __name__ == "__main__":
    train()

Expected behavior

Is it expected that the training step time increases exponentially (rather than linearly) with more layers?
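
One way to narrow this down (not part of the original report; a minimal sketch assuming the torch_xla.debug.metrics module described in the 2.x troubleshooting guide, where metric names may vary by release) is to check whether the extra time is spent compiling the HLO graph or executing it:

import torch_xla.debug.metrics as met

def report_xla_metrics():
    # Print accumulated compile/execute timings; call this right after
    # xm.wait_device_ops() in the training loop above.
    for name in ("CompileTime", "ExecuteTime"):
        if name in met.metric_names():
            count, total_ns, _samples = met.metric_data(name)
            print(f"{name}: count={count}, total={total_ns / 1e9:.3f}s")
    # met.metrics_report() prints the full report, including compile counters.

If CompileTime dominates and keeps accumulating across steps, the slowdown would be in (re)compilation rather than in running the compiled executable.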

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 2.5.1

Additional context

@ysiraichi
Collaborator

Thank you for filing this issue.
Could you try using nightly PyTorch/XLA? Even better if you could also compare it with the PyTorch CUDA device.


rplsbo commented Mar 12, 2025

The issue does NOT happen if I use the PyTorch CUDA device. The code below demonstrates that the issue does not occur when XLA is not used:

import torch
import time

class SimpleModel(torch.nn.Module):
    def __init__(self, layers=1000):
        super().__init__()
        self.input_layer = torch.nn.Linear(10, 100)
        self.input_activation = torch.nn.ReLU()
        
        self.hidden_layers = torch.nn.ModuleList()
        for _ in range(layers):
            self.hidden_layers.append(torch.nn.Linear(100, 100))
            self.hidden_layers.append(torch.nn.ReLU())
        
        self.output_layer = torch.nn.Linear(100, 1)
    
    def forward(self, x):
        x = self.input_activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for layers in [500, 1000, 1500, 2000, 2500, 3000]:
        model = SimpleModel(layers).to(device)

        optimizer = torch.optim.Adam(model.parameters())
        loss_fn = torch.nn.MSELoss()

        step_times = []

        for i in range(10):
            start_time = time.time()
            input = torch.randn(10).to(device)
            y = torch.randn(1).to(device)
            output = model(input)
            total_loss = loss_fn(output, y)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            step_time = time.time() - start_time
            step_times.append(step_time)
        
        median_time = sorted(step_times)[len(step_times)//2]
        print(f"Layers: {layers}, Median step time: {median_time:.4f}s")
        print(f"All times: {step_times}")
    

if __name__ == "__main__":
    train()

The behavior is slightly different when running on a TPU v6e. With the nightly build, I see compile times getting pretty large in the initial steps, but the execution time only increases exponentially in the stable build (2.6.0), not in the nightly build (2.7.0).

That seems to suggest this was fixed in a later build. Any idea what the fix could be?
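
One thing that could help confirm this (a rough sketch, not from the original report; the counter names are assumed from the torch_xla troubleshooting docs and may differ between 2.6.0 and nightly) is to compare how often each build hits the compilation cache:

import torch_xla.debug.metrics as met

def print_compile_counters():
    # If UncachedCompile keeps growing step after step, the graph is being
    # recompiled every step instead of reusing the cached executable.
    for name in ("CachedCompile", "UncachedCompile"):
        if name in met.counter_names():
            print(f"{name}: {met.counter_value(name)}")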

@ysiraichi
Collaborator

If I understood correctly, you are saying that the issue seems to be solved in the nightly build on TPU v6e, correct? Is this also true for CUDA (the accelerator used in the original post)?

Any idea what the fix could be?

Unfortunately, I don't know.
@miladm @tengyifei @bhavya01 @ManfeiBai Do you know what could have caused this?
