Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add HF Auth mixin to Stable Diffusion (#1763)
Summary: Right now stale diffusion and lit-llama are not actually running in CI because they get rate limited by huggingface. since we've now added an auth token as a github secret we can move stable diffusion out of canary and do things like include it in blueberries dashboard We also added some nice errors so people running in torchbench locally know they will need to have a token to run these models Anyways auth is a mixin which seems like the right abstraction # Some relevant details about the model Torchbench has a function `get_module()` that has the intent of testing a `nn.Module` on an actual `torch.Tensor` Unfortunately a `StableDiffusionPipeline` is not an `nn.Module` it's a composition of a tokenizer and 3 seperate `nn.Modules` an encoder, vae and unet. ## text_encoder ```python def get_module(self): batch_size = 1 sequence_length = 10 vocab_size = 32000 # Generate random indices within the valid range input_tensor = torch.randint(low=0, high=vocab_size, size=(batch_size, sequence_length)) # Make sure the tensor has the correct data type input_tensor = input_tensor.long() print(self.pipe.text_encoder(input_tensor)) return self.pipe.text_encoder, input_tensor ``` Text encoder outputs a `BaseModelOutputWithPooling` which has multiple nn modules https://gist.github.com/msaroufim/51f0038863c5cce4cc3045e4d9f9c399 ``` ====================================================================== FAIL: test_stable_diffusion_example_cuda (__main__.TestBenchmark) ---------------------------------------------------------------------- components._impl.workers.subprocess_rpc.ChildTraceException: Traceback (most recent call last): File "/home/ubuntu/benchmark/components/_impl/workers/subprocess_rpc.py", line 482, in _run_block exec( # noqa: P204 File "<subprocess-worker>", line 35, in <module> File "<subprocess-worker>", line 12, in _run_in_worker_f File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 26, in __call__ obj.__post__init__() File "/home/ubuntu/benchmark/torchbenchmark/util/model.py", line 126, in __post__init__ self.accuracy = check_accuracy(self) File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 469, in check_accuracy model, example_inputs = maybe_cast(tbmodel, model, example_inputs) File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 424, in maybe_cast example_inputs = clone_inputs(example_inputs) File "/home/ubuntu/benchmark/torchbenchmark/util/env_check.py", line 297, in clone_inputs assert isinstance(value, torch.Tensor) AssertionError ``` ## vae ```python def get_module(self): print(self.pipe.vae(torch.randn(9,3,9,9))) ``` Same problem for vae https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/vae.py#L27 ## unet ```python def get_module(self): # This will only benchmark the unet since that's the biggest layer # Stable diffusion is a composition of a text encoder, unet and vae encoder_hidden_states = torch.randn(320, 1024) sample = torch.randn(4, 4, 4, 32) timestep = 5 inputs_to_pipe = {'timestep': timestep, 'encoder_hidden_states': encoder_hidden_states, 'sample': sample} result = self.pipe.unet(**inputs_to_pipe) return self.pipe, inputs_to_pipe ``` Unet unfortunately does not have a tensor input For VAE and encoder the test failure is particularly helpful ``` (sam) ubuntu@ip-172-31-9-217:~/benchmark$ python test.py -k "test_stable_diffusion_example_cuda" F ====================================================================== FAIL: test_stable_diffusion_example_cuda (__main__.TestBenchmark) ---------------------------------------------------------------------- Traceback (most recent call last): File "/home/ubuntu/benchmark/test.py", line 75, in example_fn assert accuracy == "pass" or accuracy == "eager_1st_run_OOM", f"Expected accuracy pass, get {accuracy}" AssertionError: Expected accuracy pass, get eager_1st_run_fail ---------------------------------------------------------------------- Ran 1 test in 7.402s FAILED (failures=1) ``` Pull Request resolved: #1763 Reviewed By: xuzhao9 Differential Revision: D47565523 Pulled By: msaroufim fbshipit-source-id: c949ce8a31c0a4706658937fc6603a22a4bc3ec6
- Loading branch information