Commit: batched sd (ROCm#8)

Co-authored-by: Terry Chen <[email protected]>
terrychenism and Terry Chen committed Oct 4, 2022
1 parent c2bdbb2 commit a3114e3
Showing 4 changed files with 24 additions and 14 deletions.
12 changes: 9 additions & 3 deletions examples/05_stable_diffusion/README.md
@@ -112,12 +112,18 @@ _OOM = Out of Memory_

## Batched Version

A batched version of AIT Stable Diffusion can be found at: https://github.com/terrychenism/AIT_StableDiffusion/tree/main/examples/05_stable_diffusion

### A100-40GB / CUDA 11.6

- Stable Diffusion with AIT batch inference, 50 steps

Some reference results are taken from the repo:

| Batch size | PT Latency (ms) | AIT Latency (ms) |
|------------|-----------------|------------------|
| 1          | 3058.27         | 1282.98          |
| 3          | 7334.46         | 3121.88          |
| 8          | 17944.60        | 7492.81          |
| 16         | OOM             | 14931.95         |

### A100-40GB, 25 Steps

- AIT Faster rendering, 25 steps

| Batch size | AIT Latency (ms) | AVG im/s |
|------------|------------------|----------|
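The second table (truncated here) reports throughput as AVG im/s, which follows directly from batch size and latency. A minimal sketch of that relationship (not part of the commit, and assuming the latency column measures wall time for the whole batch):

```python
# Hypothetical helper, not from the repo: convert a whole-batch latency
# in milliseconds into average images per second.
def avg_images_per_second(batch_size: int, latency_ms: float) -> float:
    return batch_size / (latency_ms / 1000.0)

# Example with the 50-step numbers above: batch 8 in 7492.81 ms
print(round(avg_images_per_second(8, 7492.81), 2))  # ~1.07
```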
12 changes: 7 additions & 5 deletions examples/05_stable_diffusion/benchmark.py
```diff
@@ -160,7 +160,7 @@ def benchmark_clip(
     attention_mask[-1, -mask_seq:] = 0
     attention_mask = None
 
-    position_ids = torch.arange(seqlen).expand((1, -1)).cuda()
+    position_ids = torch.arange(seqlen).expand((batch_size, -1)).cuda()
     pt_ys = pt_mod(input_ids, attention_mask, position_ids)
     print("pt output:", pt_ys[0].shape)
 
```
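The `position_ids` fix works because `expand` broadcasts the single `torch.arange` row across a new batch dimension as a zero-copy view; the old hard-coded `(1, -1)` shape only matched batch size 1. A standalone sketch of the shapes involved (sizes are illustrative):

```python
import torch

seqlen, batch_size = 64, 3  # illustrative sizes
position_ids = torch.arange(seqlen).expand((batch_size, -1))

print(position_ids.shape)    # torch.Size([3, 64]) -- one row of 0..63 per sample
print(position_ids.stride()) # (0, 1): the batch dim is a stride-0 view, no copy
```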
```diff
@@ -274,9 +274,11 @@ def benchmark_vae(batch_size=1, height=64, width=64, benchmark_pt=False, verify=
 
 @click.command()
 @click.option("--token", default="", help="access token")
+@click.option("--batch-size", default=1, help="batch size")
 @click.option("--verify", type=bool, default=False, help="verify correctness")
 @click.option("--benchmark-pt", type=bool, default=False, help="run pt benchmark")
-def benchmark_diffusers(token, verify, benchmark_pt):
+def benchmark_diffusers(token, batch_size, verify, benchmark_pt):
+    assert batch_size == 1, "batch size must be 1 for submodule verification"
     logging.getLogger().setLevel(logging.INFO)
     np.random.seed(0)
     torch.manual_seed(4896)
```
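The new flag follows the usual click pattern: `--batch-size` maps to the `batch_size` parameter, with the option type inferred from the integer default. A minimal self-contained sketch of that pattern (illustrative, not the repo's code):

```python
import click

@click.command()
@click.option("--batch-size", default=1, help="batch size")
def main(batch_size):
    # click passes --batch-size as batch_size; submodule verification in the
    # benchmark only supports a single sample, hence the guard.
    assert batch_size == 1, "batch size must be 1 for submodule verification"
    click.echo(f"batch_size={batch_size}")

if __name__ == "__main__":
    main()
```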
```diff
@@ -293,11 +295,11 @@ def benchmark_diffusers(token, verify, benchmark_pt):
     ).to("cuda")
 
     # CLIP
-    benchmark_clip(benchmark_pt=benchmark_pt, verify=verify)
+    benchmark_clip(batch_size=batch_size, benchmark_pt=benchmark_pt, verify=verify)
     # UNet
-    benchmark_unet(batch_size=2, benchmark_pt=benchmark_pt, verify=verify)
+    benchmark_unet(batch_size=batch_size * 2, benchmark_pt=benchmark_pt, verify=verify)
     # VAE
-    benchmark_vae(benchmark_pt=benchmark_pt, verify=verify)
+    benchmark_vae(batch_size=batch_size, benchmark_pt=benchmark_pt, verify=verify)
 
 
 if __name__ == "__main__":
```
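A plausible reading of `batch_size * 2` (the commit itself does not spell this out): with classifier-free guidance, the UNet processes an unconditional and a conditional copy of every latent, so its effective batch is twice the number of images. A sketch of that doubling with illustrative shapes:

```python
import torch

batch_size = 4                                # images requested
latents = torch.randn(batch_size, 4, 64, 64)  # one latent per image
uncond = torch.randn(batch_size, 64, 768)     # unconditional text embeddings
cond = torch.randn(batch_size, 64, 768)       # prompt embeddings

# One UNet call covers both guidance branches, hence the 2x batch
# used at compile and benchmark time.
latent_model_input = torch.cat([latents, latents])  # (8, 4, 64, 64)
encoder_hidden_states = torch.cat([uncond, cond])   # (8, 64, 768)
print(latent_model_input.shape[0] == batch_size * 2)  # True
```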
11 changes: 6 additions & 5 deletions examples/05_stable_diffusion/compile.py
```diff
@@ -195,7 +195,7 @@ def compile_unet(
     latent_model_input_ait = Tensor(
         [batch_size, hh, ww, 4], name="input0", is_input=True
     )
-    timesteps_ait = Tensor([2], name="input1", is_input=True)
+    timesteps_ait = Tensor([batch_size], name="input1", is_input=True)
     text_embeddings_pt_ait = Tensor([batch_size, 64, 768], name="input2", is_input=True)
 
     Y = ait_mod(latent_model_input_ait, timesteps_ait, text_embeddings_pt_ait)
```
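The `input1` change replaces a hard-coded length-2 timestep tensor (matching the old fixed UNet batch of 2) with one entry per compiled batch element. On the PyTorch side, the corresponding runtime tensor would look roughly like this (illustrative, not from the repo):

```python
import torch

unet_batch = 8  # batch size the UNet was compiled with
t = 999         # current denoising timestep

# One copy of the scalar timestep per batch element, matching the
# compiled Tensor([batch_size], name="input1") shape.
timesteps = torch.full((unet_batch,), t, dtype=torch.long)
print(timesteps.shape)  # torch.Size([8])
```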
```diff
@@ -316,9 +316,10 @@ def compile_vae(
 
 @click.command()
 @click.option("--token", default="", help="access token")
+@click.option("--batch-size", default=1, help="batch size")
 @click.option("--use-fp16-acc", default=True, help="use fp16 accumulation")
 @click.option("--convert-conv-to-gemm", default=True, help="convert 1x1 conv to gemm")
-def compile_diffusers(token, use_fp16_acc=True, convert_conv_to_gemm=True):
+def compile_diffusers(token, batch_size, use_fp16_acc=True, convert_conv_to_gemm=True):
     logging.getLogger().setLevel(logging.INFO)
     np.random.seed(0)
     torch.manual_seed(4896)
```
```diff
@@ -338,15 +339,15 @@ def compile_diffusers(token, use_fp16_acc=True, convert_conv_to_gemm=True):
     ).to("cuda")
 
     # CLIP
-    compile_clip(use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm)
+    compile_clip(batch_size=batch_size, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm)
     # UNet
     compile_unet(
-        batch_size=2,
+        batch_size=batch_size * 2,
         use_fp16_acc=use_fp16_acc,
         convert_conv_to_gemm=convert_conv_to_gemm,
     )
     # VAE
-    compile_vae(use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm)
+    compile_vae(batch_size=batch_size, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm)
 
 
 if __name__ == "__main__":
```
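Since AIT modules are compiled for fixed shapes, the batch size is baked in at compile time and changing it means recompiling. Summarizing the input shapes the three submodules end up compiled with for a requested batch `B` (inferred from the diffs above; treat as a sketch, not a spec):

```python
B = 4  # value passed as --batch-size to compile.py

clip_inputs = {"input0": (B, 64), "input1": (B, 64)}  # token ids, position ids
unet_inputs = {
    "input0": (B * 2, 64, 64, 4),  # NHWC latents, doubled for guidance
    "input1": (B * 2,),            # per-sample timesteps
    "input2": (B * 2, 64, 768),    # text embeddings
}
vae_inputs = {"input0": (B, 64, 64, 4)}  # latents to decode
```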
3 changes: 2 additions & 1 deletion examples/05_stable_diffusion/pipeline_stable_diffusion_ait.py
```diff
@@ -126,7 +126,8 @@ def unet_inference(self, latent_model_input, timesteps, encoder_hidden_states):
 
     def clip_inference(self, input_ids, seqlen=64):
         exe_module = self.clip_ait_exe
-        position_ids = torch.arange(seqlen).expand((1, -1)).cuda()
+        bs = input_ids.shape[0]
+        position_ids = torch.arange(seqlen).expand((bs, -1)).cuda()
         inputs = {
             "input0": input_ids,
             "input1": position_ids,
```