Pipeline has no attribute '_execution_device' #9180

Open
choidaedae opened this issue Aug 14, 2024 · 33 comments
Labels: bug (Something isn't working), stale (Issues that haven't received updates)

Comments

@choidaedae commented Aug 14, 2024

Describe the bug

Hello, I implemented my own custom pipeline (RepDiffusionPipeline) by referring to StableDiffusionPipeline, but I am running into an issue.
I called "accelerator.prepare" properly and moved the models onto the device (with ".to(accelerator.device)").
However, when I call the pipeline and its __call__ method runs, I sometimes hit the error below.
It is not only a multi-GPU problem; it also occurs when I use a single GPU.
For example, I defined the pipeline for validation in my training code like this:

val_pipe = RepDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    unet=accelerator.unwrap_model(unet),
    rep_encoder=accelerator.unwrap_model(rep_encoder),
    vae=accelerator.unwrap_model(vae),
    revision=None,
    variant=None,
    torch_dtype=weight_dtype,
    safety_checker=None,
).to(accelerator.device)

Then, when I called val_pipe like this:

model_pred = val_pipe(
    image=condition_original_image if args.val_mask_op else data["original_images"],
    representation=representation,
    prompt="",
    num_inference_steps=20,
    image_guidance_scale=1.5,
    guidance_scale=scale,
    generator=generator,
).images[0]

At that time, the error "RepDiffusionPipeline has no attribute '_execution_device'" occurs. (Not always, just randomly.)
How can I solve this issue, or which part of my code should I suspect and fix?
Thank you for reading :)

Reproduction

It occurs randomly, so I have no reliable way to reproduce it...

It simply happens at random whenever I call the pipeline defined above.

Logs

RepDiffusionPipeline has no attribute '_execution_device'

System Info

I tried testing with various diffusers and Python versions, but the problem still occurs.
Currently, I am running my code with diffusers 0.27.2 and Python 3.10.14.

WARNING[XFORMERS]: xFormers can't load C++/CUDA extensions. xFormers was built for:
PyTorch 2.2.2+cu121 with CUDA 1201 (you have 2.2.2+cu118)
Python 3.10.14 (you have 3.10.14)
Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
Memory-efficient attention, SwiGLU, sparse and more won't be available.
Set XFORMERS_MORE_DETAILS=1 for more details


  • diffusers version: 0.27.2
  • Platform: Linux-5.4.0-132-generic-x86_64-with-glibc2.31
  • Python version: 3.10.14
  • PyTorch version (GPU?): 2.2.2+cu118 (True)
  • Huggingface_hub version: 0.24.3
  • Transformers version: 4.43.3
  • Accelerate version: 0.33.0
  • xFormers version: 0.0.25.post1
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul @yiyixuxu

@choidaedae choidaedae added the bug Something isn't working label Aug 14, 2024
@yiyixuxu (Collaborator) commented:

hi @choidaedae

would you be able to provide a minimum reproducible code example? i.e. a script I can run on my end to get the same error that you're getting?

it would really help us understand the problem :)

@choidaedae (Author) commented:

hi @yiyixuxu
Thanks for replying.
It occurs when I run validation with my customized pipeline, inside the training code of my private repo.
Could I invite you to my private repository so you can run the code?
If that works, I'll clean up the code and data so it is easy to run, and then invite you.
Thanks :)

@LinB203 commented Sep 6, 2024

same question.

@MuQY1818 commented Sep 7, 2024

Same question

@yiyixuxu (Collaborator) commented Sep 9, 2024

@LinB203 @MuQY1818
A minimal reproducible code example to help me understand the problem would be appreciated!

@luyvlei commented Sep 30, 2024

I had a similar problem; it happened randomly, so it was hard to reproduce, and I have now stopped relying on this attribute.

@asomoza (Member) commented Sep 30, 2024

So can we assume this issue is resolved? The OP also didn't provide any code to reproduce the error.

@choidaedae (Author) commented:

I am sorry for not providing any code, but it occurs randomly and I'm not actually working with this pipeline anymore, so I can't provide it.

If anyone still has this problem, providing their code would be useful for the Hugging Face team.
I will let you know if a similar error occurs again later.

Thank you for your hard work : )

@asomoza asomoza closed this as completed Sep 30, 2024
@yiyixuxu (Collaborator) commented:

I think we can try to catch the exception for that property, so if it happens again in the future we will have some visibility into the problem.
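
In the meantime, a minimal sketch of what catching it on the caller's side could look like (not an actual diffusers change; pipe stands for any loaded pipeline), printing which registered components are missing a device so the next report carries that information:

try:
    device = pipe._execution_device
except AttributeError:
    # pipe.components is the public dict of everything registered on the pipeline,
    # so this shows which module is missing a .device attribute when the error fires
    for name, component in pipe.components.items():
        print(name, type(component).__name__, getattr(component, "device", "<no device attribute>"))
    raise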

@yiyixuxu yiyixuxu self-assigned this Sep 30, 2024
@Lum1104 commented Nov 8, 2024

Same problem. This bug will disappear if I run it again.

@myendless1 commented:

Same problem, occurs randomly.

@Arlaz (Contributor) commented Nov 19, 2024

Same problem, occurs about half of the time with the same script

@sayakpaul (Member) commented:

Again, having a minimal reproducible snippet would be really amazing!

@a-r-r-o-w (Member) commented:

I am re-opening this due to the 3 comments in the last two weeks; it seems like a legitimate problem now.

Could any of you please provide a minimal script of what you're running that causes this error? I don't see enough information here to start debugging. If it happens on a single GPU as well, as mentioned above, I think it is safe to rule out race conditions, but maybe not, since there are mentions that it happens randomly. It would also be nice if you could check with the latest diffusers release.

@a-r-r-o-w a-r-r-o-w reopened this Nov 19, 2024
@yiyixuxu (Collaborator) commented:

@a-r-r-o-w, see my comment here: #9180 (comment)
Let's catch the exception so that they can report back.

@zer0int commented Nov 25, 2024

Yes, me too, but in my case it happens always and is reproducible:

Traceback (most recent call last):
...
    out = pipe(
  File "C:\Users\zer0int\AppData\Roaming\Python\Python310\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\zer0int\AppData\Roaming\Python\Python310\site-packages\diffusers\pipelines\flux\pipeline_flux.py", line 660, in __call__
    device = self._execution_device
  File "C:\Users\zer0int\AppData\Roaming\Python\Python310\site-packages\diffusers\configuration_utils.py", line 143, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'FluxPipeline' object has no attribute '_execution_device'

So I tried to make you a minimal example with a dummy model + dummy embeds -- but found myself flummoxed, because I suddenly didn't get that error anymore; instead I got the error I actually expected, informing me about what is expected from me (as I had hoped!):

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x768 and 4096x3072)

So now I am saving the dummies and loading them from file, and I get the error maybe 20% of the time (primary drive).

However, if I save + load the dummies from an HDD instead of the fast NVMe SSD, I can get the
'FluxPipeline' object has no attribute '_execution_device' error back! :)

I should also note that I am loading Flux from a secondary, not-as-fast internal SSD, which may also help it error out.

I remember very occasionally getting AttributeError: 'FluxPipeline' object has no attribute 'device' (with SDXL and such, too), and that also seemed to be related to access speed / drive caching. Just run it again, and it disappears.

So yeah, as it appears to me, this new issue is again related to the speed of loading the weights, for some weird reason. Maybe the two errors are even related? I've been seeing the 'device' one since "many versions" of PyTorch ago, and even a Python version ago.

Still, current:

Win11
Python 3.10.11, no venv

Name: torch
Version: 2.5.1+cu124

Name: diffusers
Version: 0.31.0

Name: transformers
Version: 4.46.3

Name: accelerate
Version: 1.1.1

Here's the code:

import os
import torch
from diffusers import FluxPipeline
from transformers import CLIPModel, CLIPProcessor, CLIPConfig
from transformers import T5EncoderModel, T5Config, AutoTokenizer
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, lambda_val):
        super(SparseAutoencoder, self).__init__()
        self.lambda_val = lambda_val

        # Encoder
        self.encoder_weights = nn.Parameter(torch.randn(hidden_dim, input_dim, dtype=torch.float32))
        self.encoder_bias = nn.Parameter(torch.zeros(hidden_dim, dtype=torch.float32))

        # Decoder
        self.decoder_weights = nn.Parameter(torch.randn(input_dim, hidden_dim, dtype=torch.float32))
        self.decoder_bias = nn.Parameter(torch.zeros(input_dim, dtype=torch.float32))

        # Symmetric initialization
        nn.init.xavier_uniform_(self.decoder_weights)
        self.decoder_weights.data = self.encoder_weights.t()

    def forward(self, x):
        # Encoder activations
        activations = torch.relu(torch.matmul(x, self.encoder_weights.t()) + self.encoder_bias)

        # Decoder reconstruction
        reconstruction = torch.matmul(activations, self.decoder_weights.t()) + self.decoder_bias
        return reconstruction, activations

    def loss_function(self, x, reconstruction, activations):
        # L2 reconstruction loss
        reconstruction_loss = torch.mean(torch.sum((x - reconstruction) ** 2, dim=1))

        # L1 sparsity penalty
        sparsity_penalty = torch.sum(
            torch.norm(self.decoder_weights, dim=0) * torch.sum(activations, dim=0)
        )
        total_loss = reconstruction_loss + self.lambda_val * sparsity_penalty
        return total_loss


def load_model(model, load_path, device):
    model.load_state_dict(torch.load(load_path, map_location=device))
    model.to(device)
    print(model)
    print(f"Model loaded from {load_path}")
    return model

def reconstruct_embeddings(model, embeddings, device):
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():
        embeddings = embeddings.to(device).float()
        reconstructed, _ = model(embeddings)
    return reconstructed.cpu()


torch_dtype = torch.bfloat16
local_files_only = False

seed = 4255330358390669
guidance_scale=3.5
num_inference_steps=20

model_id = "zer0int/CLIP-GmP-ViT-L-14"
maxtokens = 77


# Load CLIP model and processor
config = CLIPConfig.from_pretrained(model_id)
config.text_config.max_position_embeddings = maxtokens

clip_model = CLIPModel.from_pretrained(model_id, torch_dtype=torch_dtype, config=config, local_files_only=local_files_only).to(device)
clip_processor = CLIPProcessor.from_pretrained(model_id, padding="max_length", max_length=maxtokens, return_tensors="pt", truncation=True, local_files_only=local_files_only)

# Load FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch_dtype, local_files_only=local_files_only)

pipe.tokenizer = clip_processor.tokenizer
pipe.tokenizer_2 = clip_processor.tokenizer
pipe.text_encoder = clip_model.text_model
pipe.text_encoder_2 = clip_model.text_model
pipe.tokenizer_max_length = maxtokens
pipe.text_encoder.dtype = torch_dtype
pipe.text_encoder_2.dtype = torch_dtype

pipe.to(device)

pipe.vae.enable_slicing()
pipe.vae.enable_tiling()


if __name__ == "__main__":
    lambda_val = 5.0
    hidden_dim = 8192

    text_embeddings = torch.randn(13, 768, dtype=torch.float32)
    torch.save(text_embeddings, "dummyembeds.pt")
    dummymodel = SparseAutoencoder(input_dim=768, hidden_dim=hidden_dim, lambda_val=lambda_val)  
    torch.save(dummymodel.state_dict(), "dummymodel.pth")
    dummysae = "dummymodel.pth"

    text_embeddings = torch.load("dummyembeds.pt").to(device)

    sae = SparseAutoencoder(input_dim=768, hidden_dim=hidden_dim, lambda_val=lambda_val)
    sae = load_model(sae, dummysae, device)

    reconstructed_embeddings = reconstruct_embeddings(sae, text_embeddings, device)
    text_embeddings = text_embeddings.to(device)
    reconstructed_embeddings = reconstructed_embeddings.to(device)

    num_embeddings = text_embeddings.size(0)
    for selected_embedding_idx in range(num_embeddings):
        print(f"Processing embedding index: {selected_embedding_idx}")
        selected_embedding = text_embeddings[selected_embedding_idx:selected_embedding_idx + 1].float()

        with torch.no_grad():
            reconstruction, activations = sae(selected_embedding)

        _, sae_activations = sae(text_embeddings.float())


        normalized_embeddings = reconstruction / reconstruction.norm(dim=-1, keepdim=True)
        normalized_embeddings = normalized_embeddings.to(device).bfloat16()

        # Please ignore that I am doing nonsense here. WIP, and there's no prompt.
        pooled_prompt_embeds = reconstructed_embeddings / reconstructed_embeddings.norm(dim=-1, keepdim=True)
        pooled_prompt_embeds = pooled_prompt_embeds.bfloat16()      

        generator = torch.manual_seed(seed)
        out = pipe(
            prompt_embeds=normalized_embeddings,
            pooled_prompt_embeds=pooled_prompt_embeds,
            guidance_scale=guidance_scale,
            height=1024,
            width=1024,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

        out.save(f"dummyimage.png")

PS: If you have a tip for constructing pooled_prompt_embeds, I'll happily take it (can also be a zero-weight T5 if need be). :)
PPS: I have no prompt. The original embeddings were made by CLIP via gradient ascent for a given image, and there's no way I can get accurate tokens / words out of that.

Either way, I hope this helps your quest in squashing this bug - thanks in advance for your work! 👍

@sayakpaul (Member) commented:

Thank you so much!

PS: If you have a tip for constructing pooled_prompt_embeds, I'll happily take it (can also be a zero-weight T5 if need be). :)

You can compute the text embeddings needed to run the FluxPipeline like so:

from diffusers import FluxPipeline
import torch 

ckpt_id = "black-forest-labs/FLUX.1-dev"
prompt = "a photo of a dog with cat-like look"

pipeline = FluxPipeline.from_pretrained(
    ckpt_id,
    vae=None,  # the VAE isn't needed just to encode prompts
    torch_dtype=torch.bfloat16,
).to("cuda")

with torch.no_grad():
    print("Encoding prompts.")
    prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
        prompt=prompt, prompt_2=None, max_sequence_length=512
    )

Does this help?
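
A possible follow-up sketch for feeding those precomputed embeddings back in (assuming the prompt_embeds and pooled_prompt_embeds from the snippet above; reloading without the text encoders is only a memory optimization, not a requirement):

import gc

# free the text-encoder-only pipeline before loading the denoiser + VAE
del pipeline
gc.collect()
torch.cuda.empty_cache()

pipe = FluxPipeline.from_pretrained(
    ckpt_id,
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    num_inference_steps=20,
    guidance_scale=3.5,
).images[0]
image.save("flux_from_precomputed_embeds.png")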

@HilaManor commented:

For me it happens as well, but with something even simpler:

pipe = StableDiffusionPipeline.from_pretrained(..)
print(pipe.dtype)

sometimes works, and sometimes raises AttributeError: 'StableDiffusionPipeline' object has no attribute 'dtype'.

It only happens as part of a script, not in IPython, for example. I think it may be related to accelerate and offloading?

@github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Jan 10, 2025
@asomoza asomoza removed the stale Issues that haven't received updates label Jan 10, 2025
@asomoza (Member) commented Jan 10, 2025

@HilaManor Hi, that example code is too simple; I can execute it 100 times and it will work all 100 times for me, so there's probably something in your environment that is making this happen.

IMO this should only be treated as an issue or bug when it happens with our pipelines and code. If someone is using part of it or writing a custom pipeline, this error probably means something is wrong in their code or environment, but it would be ideal to properly catch it and give a better explanation of the error.

@HilaManor commented:

It occurs randomly, so I think it's a race condition. When I caught it and tried to debug it, it disappeared, which is part of the reason I think it's a race thing.

@asomoza (Member) commented Jan 10, 2025

If that were the case, it should have happened at least once for me in a 100-run experiment, no? What I mean is that we probably need more information about what you're doing, for example, your environment and hardware.

@HilaManor commented:

@asomoza it happened on 5 different computers with 3 different types of GPUs (an A6000 server, a 2080Ti server, and a 3090Ti in a personal computer), with different CPUs and generally different hardware.

The env is roughly the same on all of them (except that the NVIDIA driver differs on each server); see below. Please let me know what more info could help.

Package                   Version              Editable project location
------------------------- -------------------- --------------------------------------
absl-py                   2.1.0
accelerate                1.1.1
addict                    2.4.0
aiohappyeyeballs          2.4.3
aiohttp                   3.10.10
aiosignal                 1.3.1
anyio                     4.6.2.post1
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 2.4.1
async-lru                 2.0.4
attrs                     24.2.0
av                        13.1.0
babel                     2.16.0
basicsr                   1.4.2
bcrypt                    4.2.1
beautifulsoup4            4.12.3
bitsandbytes              0.45.0
bleach                    6.2.0
boto3                     1.35.82
botocore                  1.35.82
certifi                   2024.8.30
cffi                      1.17.1
cfgv                      3.4.0
charset-normalizer        3.4.0
click                     8.1.7
cmake                     3.31.2
colorama                  0.4.6
comm                      0.2.2
compressai                1.2.6
contourpy                 1.3.0
cryptography              44.0.0
cycler                    0.12.1
dctorch                   0.1.2
debugpy                   1.8.8
decorator                 5.1.1
defusedxml                0.7.1
diffusers                 0.31.0
distlib                   0.3.9
DISTS-pytorch             0.1
docker-pycreds            0.4.0
einops                    0.8.0
executing                 2.1.0
facexlib                  0.3.0
fastcore                  1.7.19
fastjsonschema            2.21.1
filelock                  3.16.1
filterpy                  1.4.5
fonttools                 4.54.1
fqdn                      1.5.1
frozenlist                1.5.0
fsspec                    2024.10.0
ftfy                      6.3.1
future                    1.0.0
fvcore                    0.1.5.post20221221
gitdb                     4.0.11
GitPython                 3.1.43
grpcio                    1.67.1
h11                       0.14.0
httpcore                  1.0.7
httpx                     0.28.0
huggingface-hub           0.26.2
icecream                  2.1.3
identify                  2.6.4
idna                      3.10
imageio                   2.36.0
importlib_metadata        8.5.0
iniconfig                 2.0.0
inquirerpy                0.3.4
iopath                    0.1.10
ipykernel                 6.29.5
ipython                   8.29.0
isoduration               20.11.0
jedi                      0.19.2
Jinja2                    3.1.4
jmespath                  1.0.1
json5                     0.10.0
jsonpointer               3.0.0
jsonschema                4.23.0
jsonschema-specifications 2024.10.1
jupyter_client            8.6.3
jupyter_core              5.7.2
jupyter-events            0.10.0
jupyter-lsp               2.2.5
jupyter_server            2.14.2
jupyter_server_terminals  0.5.3
jupyterlab                4.3.2
jupyterlab_pygments       0.3.0
jupyterlab_server         2.27.3
keyboard                  0.13.5
kiwisolver                1.4.7
lazy_loader               0.4
lightning                 2.4.0
lightning-utilities       0.11.8
llvmlite                  0.43.0
lmdb                      1.5.1
lovely-numpy              0.2.13
lovely-tensors            0.1.17
lpips                     0.1.4
Markdown                  3.7
MarkupSafe                3.0.2
matplotlib                3.9.2
matplotlib-inline         0.1.7
mistune                   3.0.2
mpmath                    1.3.0
multidict                 6.1.0
natten                    0.17.3+torch250cu121
nbclient                  0.10.1
nbconvert                 7.16.4
nbformat                  5.10.4
nest-asyncio              1.6.0
networkx                  3.4.2
neuralcompression         0.3.0
nodeenv                   1.9.1
notebook                  7.3.1
notebook_shim             0.2.4
numba                     0.60.0
numpy                     1.26.4
nvidia-cublas-cu12        12.4.5.8
nvidia-cuda-cupti-cu12    12.4.127
nvidia-cuda-nvrtc-cu12    12.4.127
nvidia-cuda-runtime-cu12  12.4.127
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.2.1.3
nvidia-curand-cu12        10.3.5.147
nvidia-cusolver-cu12      11.6.1.9
nvidia-cusparse-cu12      12.3.1.170
nvidia-nccl-cu12          2.21.5
nvidia-nvjitlink-cu12     12.4.127
nvidia-nvtx-cu12          12.4.127
openai-clip               1.0.1
opencv-python             4.10.0.84
opencv-python-headless    4.10.0.84
overrides                 7.7.0
packaging                 24.2
pandas                    2.2.3
pandocfilters             1.5.1
parameterized             0.9.0
paramiko                  3.5.0
parso                     0.8.4
pexpect                   4.9.0
pfzy                      0.3.4
pillow                    11.0.0
pip                       24.0
piq                       0.8.0
platformdirs              4.3.6
pluggy                    1.5.0
portalocker               3.0.0
pre_commit                4.0.1
prometheus_client         0.21.1
prompt_toolkit            3.0.48
propcache                 0.2.0
protobuf                  5.28.3
psutil                    6.1.0
ptyprocess                0.7.0
pure_eval                 0.2.3
pycparser                 2.22
Pygments                  2.18.0
pyiqa                     0.1.13
PyNaCl                    1.5.0
pyparsing                 3.2.0
pytest                    8.3.4
python-dateutil           2.9.0.post0
python-json-logger        2.0.7
pytorch-lightning         2.4.0
pytorch-msssim            1.0.0
pytorchvideo              0.1.5
pytz                      2024.2
PyYAML                    6.0.2
pyzmq                     26.2.0
referencing               0.35.1
regex                     2024.11.6
requests                  2.32.3
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rpds-py                   0.22.3
s3transfer                0.10.4
safetensors               0.4.5
scikit-image              0.24.0
scipy                     1.14.1
seaborn                   0.13.2
Send2Trash                1.8.3
sentencepiece             0.2.0
sentry-sdk                2.18.0
setproctitle              1.3.3
setuptools                75.3.0
six                       1.16.0
smmap                     5.0.1
sniffio                   1.3.1
soupsieve                 2.6
stack-data                0.6.3
sympy                     1.13.1
tabulate                  0.9.0
tb-nightly                2.19.0a20241112
tensorboard               2.18.0
tensorboard-data-server   0.7.2
termcolor                 2.5.0
terminado                 0.18.1
tifffile                  2024.9.20
timm                      1.0.11
tinycss2                  1.4.0
tokenizers                0.15.2
tomli                     2.1.0
torch                     2.5.1
torch-ema                 0.3
torch_fidelity            0.4.0b0              
torch-geometric           2.6.1
torchmetrics              1.5.2
torchvision               0.20.1
tornado                   6.4.1
tqdm                      4.67.0
traitlets                 5.14.3
transformers              4.37.2
triton                    3.1.0
types-python-dateutil     2.9.0.20241003
typing_extensions         4.12.2
tzdata                    2024.2
uri-template              1.3.0
urllib3                   2.2.3
virtualenv                20.28.1
wandb                     0.18.6
wcwidth                   0.2.13
webcolors                 24.11.1
webencodings              0.5.1
websocket-client          1.8.0
Werkzeug                  3.1.3
yacs                      0.1.8
yapf                      0.40.2
yarl                      1.17.1
zipp                      3.21.0

@asomoza (Member) commented Jan 31, 2025

Thanks for the detailed answer. Just to be clear, you are getting Pipeline has no attribute '_execution_device' with just this code:

pipe = StableDiffusionPipeline.from_pretrained(..)
print(pipe.dtype)

Ideally we should have a fully reproducible code snippet that gives the error every time; since you're saying this is random, I'll try to use an AWS instance to test this over a longer period.

As I told you, I constantly use diffusers (hundreds of times a day) on Linux and Windows, with a 3090, an A100, and a mobile 4090, and never, not once, have I had this error, so it seems this is going to be a really hard bug to catch.

You can also tell because this issue has a low number of active users; if this were a recurrent problem, we would have a lot of people/issues reporting it.

@HilaManor commented:

Clarifying - for the code you quoted I get AttributeError: 'StableDiffusionPipeline' object has no attribute 'dtype', not the execution-device error.
I found that reading the dtype off the model directly (pipe.model.dtype) always works for me instead.
So I think what happens is that, for both attributes, there's some race condition between loading the model and the point at which the model's attributes become available on the pipeline object. I have no idea if this is easily fixable.
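
A minimal sketch of the kind of fallback I mean (a hypothetical helper; it assumes the pipeline exposes a unet component to read the dtype from, which holds for StableDiffusionPipeline but not for every pipeline):

import torch

def safe_pipeline_dtype(pipe) -> torch.dtype:
    # getattr swallows the AttributeError raised when the attribute isn't
    # available on the pipeline yet and falls back to a concrete module
    dtype = getattr(pipe, "dtype", None)
    if dtype is None:
        dtype = pipe.unet.dtype
    return dtype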

@asomoza (Member) commented Jan 31, 2025

Oh, I look at a lot of issues, so I usually go by the issue title; that's why it's not a good idea to piggyback on another issue to add yours. This should probably go in a new one, since it's totally unrelated to the one in this discussion.

What happens to you is even weirder, since dtype is just a property of the pipeline and it returns torch.float32 by default, so there's something else happening here. There's no race condition here.

You're using plain, unmodified diffusers for this, right?

You seem to have solved it anyway, but if you want to keep digging into this, please open a new issue with your problem. Also, we don't need the full library list of your environment; you can just run diffusers-cli env and post the output, which gives us all the information we need.

@github-actions (bot) commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Feb 25, 2025
@asomoza (Member) commented Feb 25, 2025

Closing this since there are no more updates on the main issue; assuming it was resolved or was a problem in the env.

@asomoza asomoza closed this as completed Feb 25, 2025
@mirix commented Mar 11, 2025

Intriguing. The bug is still present today but, for me, only with LoRAs. And it happens half of the time with exactly the same script.

@hosjiu1702 commented Mar 21, 2025

confirm @asomoza

@asomoza (Member) commented Mar 21, 2025

Hi, in case this is not clear yet: we don't need a confirmation or "I have the same problem". What we need in order to help you is a snippet of code that reproduces the error. Also, we can't support custom pipelines or custom code because we don't have the bandwidth, so the snippet needs to use just diffusers code.

Sadly, this is something that never happens to us or to the majority of people using diffusers, so we need your help to be able to help you.

@mashrurmorshed commented:

This happened to me when I was trying to trick diffusers' DDIMPipeline into working with a non-diffusers diffusion model (a bit unrelated to the OP's original use case, but it may still be useful to someone).

The finding is that _execution_device is only available if whatever model you pass in has a .device attribute (the device is picked up when register_modules is called in the pipeline's __init__). Minimal reproducible example:

import torch
import diffusers

class DummyNoDevice(torch.nn.Module):
    def __init__(self):
        super().__init__()
   
    def forward(self, sample, timestep):
        return (sample,)

class DummyWithDevice(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
   
    def forward(self, sample, timestep):
        return (sample,)

noise_scheduler = diffusers.DDPMScheduler(1000)
pipe_no_device = diffusers.DDIMPipeline(unet=DummyNoDevice(), scheduler=noise_scheduler)
pipe_with_device = diffusers.DDIMPipeline(unet=DummyWithDevice(), scheduler=noise_scheduler)

try:
    print(f"Success: W/ device: {pipe_with_device._execution_device}")
except Exception as e:
    print(f"Failed: W/ device: {e}")

try:
    print(f"Success: no device: {pipe_no_device._execution_device}")
except Exception as e:
    print(f"Failed: no device: {e}")

Result: (diffusers 0.32.2)

Success: W/ device: cuda:0
Failed: no device: 'DDIMPipeline' object has no attribute '_execution_device'

@asomoza asomoza reopened this Mar 31, 2025
@asomoza (Member) commented Mar 31, 2025

Thanks a lot @mashrurmorshed!

@yiyixuxu we have a reproducible snippet of code now.
