Skip to content

Commit

Permalink
Revive SD downloads from shark_tank. (#1465)
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet authored May 25, 2023
1 parent 6d64b8e commit 54e57f7
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ def load_clip(self):
self.text_encoder = self.sd_model.clip()
else:
try:
breakpoint()
self.text_encoder = get_clip()
except:
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.text_encoder = self.sd_model.clip()

Expand All @@ -104,7 +106,8 @@ def load_unet(self):
else:
try:
self.unet = get_unet()
except:
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.unet = self.sd_model.unet()

Expand All @@ -121,7 +124,8 @@ def load_vae(self):
else:
try:
self.vae = get_vae()
except:
except Exception as e:
print(e)
print("download pipeline failed, falling back to import_mlir")
self.vae = self.sd_model.vae()

Expand Down
1 change: 0 additions & 1 deletion apps/stable_diffusion/src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ def get_shark_model(tank_url, model_name, extra_args=[]):

# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache

from shark.shark_downloader import download_model

if "cuda" in args.device:
Expand Down
8 changes: 8 additions & 0 deletions shark/shark_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def download_public_file(
continue

destination_filename = os.path.join(destination_folder_name, blob_name)
if os.path.isdir(destination_filename):
continue
with open(destination_filename, "wb") as f:
with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
storage_client.download_blob_to_file(blob, file_obj)
Expand Down Expand Up @@ -210,6 +212,9 @@ def download_model(
+ "_BS"
+ str(import_args["batch_size"])
)
elif any(model in model_name for model in ["clip", "unet", "vae"]):
# TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
model_dir_name = model_name
else:
model_dir_name = model_name + "_" + frontend
model_dir = os.path.join(WORKDIR, model_dir_name)
Expand Down Expand Up @@ -270,6 +275,9 @@ def download_model(
tuned_str = "" if tuned is None else "_" + tuned
suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
filename = os.path.join(model_dir, model_name + suffix)
print(
f"Verifying that model artifacts were downloaded successfully to {filename}..."
)
if not os.path.exists(filename):
from tank.generate_sharktank import gen_shark_files

Expand Down
2 changes: 1 addition & 1 deletion tank_version.json
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
{
"version": "2023-03-31_02d52bb"
"version": "nightly"
}

0 comments on commit 54e57f7

Please sign in to comment.