Workaround for SD alt compilation/demo for T4 SM75
apivovarov committed Jun 21, 2023
1 parent f5896d0 commit 3040a50
Showing 2 changed files with 8 additions and 4 deletions.
@@ -71,13 +71,15 @@ def compile_clip(
 
     pt_mod = pt_mod.eval()
    params_ait = map_clip_params(pt_mod, batch_size, seqlen, depth)
-    batch_size = IntVar(values=list(batch_size), name="batch_size")
+    # The upper bound of the batch dim should be at least 8;
+    # otherwise the output image will be messy on T4 GPUs (SM75).
+    batch_size_d = IntVar(values=[batch_size[0], max(8, batch_size[1])], name="batch_size")
 
     input_ids_ait = Tensor(
-        [batch_size, seqlen], name="input0", dtype="int64", is_input=True
+        [batch_size_d, seqlen], name="input0", dtype="int64", is_input=True
     )
     position_ids_ait = Tensor(
-        [batch_size, seqlen], name="input1", dtype="int64", is_input=True
+        [batch_size_d, seqlen], name="input1", dtype="int64", is_input=True
     )
     Y = ait_mod(input_ids=input_ids_ait, position_ids=position_ids_ait)
     mark_output(Y)
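For context, a minimal sketch of the CLIP-side workaround, assuming AITemplate's aitemplate.frontend API; the clamp_batch_dim helper name is hypothetical and not part of this commit. The idea is that the compiled profile must advertise a batch upper bound of at least 8, even when the demo itself only runs batch sizes 1-2.

# Minimal sketch, assuming aitemplate.frontend exports IntVar and Tensor;
# clamp_batch_dim is a hypothetical helper, not part of this commit.
from aitemplate.frontend import IntVar, Tensor

def clamp_batch_dim(batch_size, min_upper=8, name="batch_size"):
    # batch_size is a (lower, upper) pair of static bounds.
    # On T4 (SM75), compiling with an upper bound below 8 produced
    # messy output images, so pad the upper bound to at least min_upper.
    lower, upper = batch_size
    return IntVar(values=[lower, max(min_upper, upper)], name=name)

# Usage mirroring compile_clip: the clamped dim only widens the compiled
# dynamic range; callers can still run with their real batch size.
seqlen = 64
batch_size_d = clamp_batch_dim((1, 2))  # bounds become [1, 8]
input_ids_ait = Tensor(
    [batch_size_d, seqlen], name="input0", dtype="int64", is_input=True
)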
@@ -93,7 +93,9 @@ def compile_unet(
         height_d = height
         width_d = width
     else:
-        batch_size = (batch_size[0], batch_size[1] * 2)  # double batch size for unet
+        # Double the batch size for the unet. Both the lower and upper
+        # dims should be doubled; otherwise the output image will be messy on T4 GPUs (SM75).
+        batch_size = (batch_size[0] * 2, batch_size[1] * 2)
     batch_size = IntVar(values=list(batch_size), name="batch_size")
     height = height[0] // 8, height[1] // 8
     width = width[0] // 8, width[1] // 8
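Similarly, a sketch of the unet-side fix, assuming the same (lower, upper) bound convention as compile_unet; double_unet_batch is a hypothetical name. The unet batch is doubled (presumably because the demo runs the unconditional and text-conditioned halves of classifier-free guidance in one batch), and on SM75 both bounds of the dynamic dim must be doubled, not just the upper one.

# Minimal sketch; double_unet_batch is a hypothetical helper name.
from aitemplate.frontend import IntVar

def double_unet_batch(batch_size, name="batch_size"):
    # Before this commit only the upper bound was doubled:
    #     (batch_size[0], batch_size[1] * 2)
    # which produced messy output images on T4 (SM75). Doubling both
    # bounds keeps the compiled profile consistent with the doubled
    # runtime batch.
    return IntVar(
        values=[batch_size[0] * 2, batch_size[1] * 2],
        name=name,
    )

batch_size = double_unet_batch((1, 2))  # bounds become [2, 4]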
