diff --git a/torchbenchmark/models/cm3leon_generate/__init__.py b/torchbenchmark/models/cm3leon_generate/__init__.py index 4bd3ee8ed2..d36d16c523 100644 --- a/torchbenchmark/models/cm3leon_generate/__init__.py +++ b/torchbenchmark/models/cm3leon_generate/__init__.py @@ -7,6 +7,7 @@ class Model(BenchmarkModel): task = NLP.LANGUAGE_MODELING + DEFAULT_TRAIN_BSIZE = 1 DEFAULT_EVAL_BSIZE = 1 def __init__(self, test, device, batch_size=None, extra_args=[]):