This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

TransformerDataModule.setup() run more than once unnecessarily #299

Open
RR-28023 opened this issue Oct 27, 2022 · 0 comments
Labels
enhancement New feature or request help wanted Extra attention is needed

Comments

@RR-28023
Contributor

RR-28023 commented Oct 27, 2022

🐛 Bug

TransformerDataModule.setup() runs more than once when it does not need to. For example, in the code included below, setup() runs once when dm.num_classes is accessed and again when trainer.fit(model, dm) is called.

setup() in turn calls self.load_dataset(), self.split_dataset(dataset) and self.process_data(dataset, stage=stage). Calling self.load_dataset() several times is not a big deal because the dataset is loaded from the cache, but the other two methods are expensive, and I don't think it makes sense to run them again (they just overwrite whatever was in self.ds before).

To Reproduce

Take the example below from the docs and check the console output, or run it in debug mode with a breakpoint. TransformerDataModule.setup() and the methods it calls, load_dataset(), split_dataset() and process_data(), are run more than once.

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.text_classification import (
    TextClassificationDataModule,
    TextClassificationTransformer,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
dm = TextClassificationDataModule(
    batch_size=1,
    dataset_name="glue",
    dataset_config_name="sst2",
    max_length=512,
    tokenizer=tokenizer,
)
model = TextClassificationTransformer(pretrained_model_name_or_path="bert-base-uncased", num_labels=dm.num_classes)
trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)

trainer.fit(model, dm)
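If a debugger is not handy, the repeated calls can also be counted directly. The following is a minimal sketch, not part of the library: `count_calls` is a hypothetical helper, and `ToyDataModule` is a stand-in that only mirrors the shape of `setup()` described in this issue. In a real run, the same helper could presumably be applied to `TextClassificationDataModule` before creating `dm`.

```python
import functools
from collections import Counter


def count_calls(cls, method_names):
    """Wrap the named methods of `cls` so every call is tallied.

    Hypothetical helper for illustration; returns the Counter holding the tallies.
    """
    counts = Counter()

    def make_wrapper(name, original):
        @functools.wraps(original)
        def wrapper(self, *args, **kwargs):
            counts[name] += 1
            return original(self, *args, **kwargs)

        return wrapper

    for name in method_names:
        setattr(cls, name, make_wrapper(name, getattr(cls, name)))
    return counts


class ToyDataModule:
    """Stand-in for TransformerDataModule (assumption: same setup() structure)."""

    def load_dataset(self):
        return {"train": [1, 2, 3]}

    def split_dataset(self, dataset):
        return dataset

    def process_data(self, dataset, stage=None):
        return dataset

    def setup(self, stage=None):
        dataset = self.load_dataset()
        dataset = self.split_dataset(dataset)
        dataset = self.process_data(dataset, stage=stage)
        self.ds = dataset


counts = count_calls(
    ToyDataModule, ["setup", "load_dataset", "split_dataset", "process_data"]
)
dm = ToyDataModule()
dm.setup("fit")  # e.g. triggered by an attribute access
dm.setup("fit")  # e.g. triggered again by the trainer
print(dict(counts))
```

Any count above 1 for split_dataset or process_data on the real data module would confirm the redundant work.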

Expected behavior

Given that TransformerDataModule.setup() currently does the following:

def setup(self, stage: Optional[str] = None): 
  dataset = self.load_dataset()
  dataset = self.split_dataset(dataset)
  dataset = self.process_data(dataset, stage=stage)
  self.ds = dataset

Perhaps a way to avoid running it again would be to create an instance attribute self.setup_stages_run = [] when the object is initialized and then define the setup method as:

def setup(self, stage: Optional[str] = None):
    # Load and split dataset only if setup has not been run before
    if len(self.setup_stages_run) == 0:
        dataset = self.load_dataset()
        dataset = self.split_dataset(dataset)
    else:
        dataset = self.ds

    # Process dataset only if setup has not been run before for this stage
    if stage not in self.setup_stages_run:
        self.ds = self.process_data(dataset, stage=stage)
        self.setup_stages_run.append(stage)
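
To check that the guard behaves as intended, here is a self-contained sketch of the same idea on a toy class. `GuardedToyDataModule` and its `calls` counter are hypothetical and exist only for illustration; the real load_dataset(), split_dataset() and process_data() are replaced by trivial stubs.

```python
from typing import Optional


class GuardedToyDataModule:
    """Toy stand-in (assumption) applying the proposed setup() guard."""

    def __init__(self):
        self.setup_stages_run = []  # stages for which setup() has already run
        self.ds = None
        self.calls = {"load": 0, "split": 0, "process": 0}  # illustration only

    def load_dataset(self):
        self.calls["load"] += 1
        return {"train": [1, 2, 3]}

    def split_dataset(self, dataset):
        self.calls["split"] += 1
        return dataset

    def process_data(self, dataset, stage: Optional[str] = None):
        self.calls["process"] += 1
        return dataset

    def setup(self, stage: Optional[str] = None):
        # Load and split only on the very first call
        if len(self.setup_stages_run) == 0:
            dataset = self.load_dataset()
            dataset = self.split_dataset(dataset)
        else:
            dataset = self.ds

        # Process only once per stage
        if stage not in self.setup_stages_run:
            self.ds = self.process_data(dataset, stage=stage)
            self.setup_stages_run.append(stage)


dm = GuardedToyDataModule()
dm.setup("fit")
dm.setup("fit")   # skipped entirely: "fit" was already set up
dm.setup("test")  # processes again, but does not reload or re-split
print(dm.calls, dm.setup_stages_run)
```

One thing to verify before adopting this: on a later stage, process_data() receives self.ds, which has already been processed once. Whether that is safe depends on process_data() being idempotent, which this sketch does not establish.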

I can create a PR if you think this makes sense.
Thanks!

@RR-28023 RR-28023 added bug / fix Something isn't working help wanted Extra attention is needed labels Oct 27, 2022
@RR-28023 RR-28023 changed the title TransformerDataModule.setup() run more than once unnecessarily `TransformerDataModule.setup() run more than once unnecessarily Oct 27, 2022
@RR-28023 RR-28023 changed the title `TransformerDataModule.setup() run more than once unnecessarily TransformerDataModule.setup() run more than once unnecessarily Oct 27, 2022
@Borda Borda added enhancement New feature or request and removed bug / fix Something isn't working labels Nov 21, 2022