
Commit

update notebooks
amirhertz committed Nov 29, 2023
1 parent b88e32b commit 733572a
Showing 3 changed files with 41 additions and 54 deletions.
18 changes: 13 additions & 5 deletions style_aligned_sdxl.ipynb
@@ -57,7 +57,11 @@
"cell_type": "code",
"execution_count": null,
"id": "c2f6f1e6-445f-47bc-b9db-0301caeb7490",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# init models\n",
@@ -70,8 +74,8 @@
").to(\"cuda\")\n",
"\n",
"handler = sa_handler.Handler(pipeline)\n",
"sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,\n",
" share_layer_norm=True,\n",
"sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,\n",
" share_layer_norm=False,\n",
" share_attention=True,\n",
" adain_queries=True,\n",
" adain_keys=True,\n",
@@ -85,16 +89,20 @@
"cell_type": "code",
"execution_count": null,
"id": "5cca9256-0ce0-45c3-9cba-68c7eff1452f",
"metadata": {},
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# run StyleAligned\n",
"\n",
"sets_of_prompts = [\n",
" \"a toy train. macro photo. 3d game asset\",\n",
" \"a toy airplane. macro photo. 3d game asset\",\n",
" \"a toy bicycle. macro photo. 3d game asset\",\n",
" \"a toy car. macro photo. 3d game asset\",\n",
" \"a toy boat. macro photo. 3d game asset\",\n",
"] \n",
"g_cuda = torch.Generator(device='cuda')\n",
"images = pipeline(sets_of_prompts,).images\n",
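For orientation only (not part of the commit): a minimal sketch of how the updated arguments are typically wired up in the SDXL notebook. The checkpoint name, the Handler.register call, and any parameter not visible in the hunks above are assumptions based on the rest of this repository, not on this diff.

import torch
from diffusers import StableDiffusionXLPipeline
import sa_handler

# load SDXL and attach the StyleAligned handler (assumed checkpoint and repo API)
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
).to("cuda")
handler = sa_handler.Handler(pipeline)

# norm sharing is disabled after this commit; shared attention plus AdaIN on
# queries and keys still transfer the first prompt's style to the other prompts
sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,
                                      share_layer_norm=False,
                                      share_attention=True,
                                      adain_queries=True,
                                      adain_keys=True)
handler.register(sa_args)

# all prompts in the batch are denoised together, so they share one style
sets_of_prompts = ["a toy train. macro photo. 3d game asset",
                   "a toy boat. macro photo. 3d game asset"]
images = pipeline(sets_of_prompts).images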
33 changes: 20 additions & 13 deletions style_aligned_w_controlnet.ipynb
@@ -83,8 +83,8 @@
").to(\"cuda\")\n",
"pipeline.enable_model_cpu_offload()\n",
"\n",
"sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,\n",
" share_layer_norm=True,\n",
"sa_args = sa_handler.StyleAlignedArgs(share_group_norm=False,\n",
" share_layer_norm=False,\n",
" share_attention=True,\n",
" adain_queries=True,\n",
" adain_keys=True,\n",
@@ -101,11 +101,12 @@
"metadata": {},
"outputs": [],
"source": [
"# get depth map\n",
"# get depth maps\n",
"\n",
"image = load_image(\"./example_image/train.png\")\n",
"depth_image = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)\n",
"mediapy.show_images([image, depth_image])"
"depth_image1 = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)\n",
"depth_image2 = load_image(\"./example_image/sun.png\").resize((1024, 1024))\n",
"mediapy.show_images([depth_image1, depth_image2])"
]
},
{
@@ -117,15 +118,21 @@
"source": [
"# run ControlNet depth with StyleAligned\n",
"\n",
"prompts = [\"a poster in flat design style\", \"a train in flat design style\"]\n",
"reference_prompt = \"a poster in flat design style\"\n",
"target_prompts = [\"a train in flat design style\", \"the sun in flat design style\"]\n",
"controlnet_conditioning_scale = 0.8\n",
"images = pipeline_calls.controlnet_call(pipeline, prompts,\n",
" image=depth_image,\n",
" num_inference_steps=50,\n",
" controlnet_conditioning_scale=controlnet_conditioning_scale,\n",
" num_images_per_prompt=3)\n",
"\n",
"mediapy.show_images([images[0], depth_image] + images[1:], titles=[\"reference\", \"depth\"] + [f'result {i}' for i in range(1, len(images))])\n"
"num_images_per_prompt = 3 # adjust according to VRAM size\n",
"# latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)\n",
"latents[1:] = torch.randn(num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)\n",
"for deph_map, target_prompt in zip((depth_image1, depth_image2), target_prompts):\n",
" images = pipeline_calls.controlnet_call(pipeline, [reference_prompt, target_prompt],\n",
" image=deph_map,\n",
" num_inference_steps=50,\n",
" controlnet_conditioning_scale=controlnet_conditioning_scale,\n",
" num_images_per_prompt=num_images_per_prompt,\n",
" latents=latents)\n",
" \n",
" mediapy.show_images([images[0], deph_map] + images[1:], titles=[\"reference\", \"depth\"] + [f'result {i}' for i in range(1, len(images))])\n"
]
},
{
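A note on the new latents tensor above (a sketch, not part of the commit): the batch packs one fixed reference latent at index 0 and num_images_per_prompt target latents behind it, so the reference image is regenerated alongside each target and anchors the shared attention. The 128x128 spatial size assumes SDXL's 1024x1024 output divided by the VAE's 8x downsampling factor.

import torch

height = width = 1024            # output resolution used in the notebook
vae_scale_factor = 8             # assumed SDXL VAE downsampling factor
num_images_per_prompt = 3

latent_h, latent_w = height // vae_scale_factor, width // vae_scale_factor  # 128, 128
latents = torch.randn(1 + num_images_per_prompt, 4, latent_h, latent_w)

# keep index 0 (the reference) fixed; redraw only the target latents
latents[1:] = torch.randn(num_images_per_prompt, 4, latent_h, latent_w)
print(latents.shape)             # torch.Size([4, 4, 128, 128])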
44 changes: 8 additions & 36 deletions style_aligned_w_multidiffusion.ipynb
@@ -3,11 +3,7 @@
{
"cell_type": "markdown",
"id": "50fa980f-1bae-40c1-a1f3-f5f89bef60d3",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"metadata": {},
"source": [
"## Copyright 2023 Google LLC"
]
@@ -16,11 +12,7 @@
"cell_type": "code",
"execution_count": null,
"id": "5da5f038-057f-4475-a783-95660f98238c",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2023 Google LLC\n",
@@ -41,11 +33,7 @@
{
"cell_type": "markdown",
"id": "c3a7c069-c441-4204-a905-59cbd9edc13a",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"metadata": {},
"source": [
"# MultiDiffusion with StyleAligned over SD v2"
]
@@ -54,11 +42,7 @@
"cell_type": "code",
"execution_count": null,
"id": "14178de7-d4c8-4881-ac1d-ff84bae57c6f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
@@ -72,11 +56,7 @@
"cell_type": "code",
"execution_count": null,
"id": "738cee0e-4d6e-4875-b4df-eadff6e27e7f",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"metadata": {},
"outputs": [],
"source": [
"# init models\n",
Expand All @@ -101,11 +81,7 @@
"cell_type": "code",
"execution_count": null,
"id": "ea61e789-2814-4820-8ae7-234c3c6640a0",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"metadata": {},
"outputs": [],
"source": [
"reference_prompt = \"a beautiful papercut art design\"\n",
Expand All @@ -121,11 +97,7 @@
"cell_type": "code",
"execution_count": null,
"id": "791a9b28-f0ce-4fd0-9f3c-594281c2ae56",
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"metadata": {},
"outputs": [],
"source": []
}
@@ -151,4 +123,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
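For completeness (not from the diff, which only strips PyCharm cell metadata in this file): a hedged sketch of how the MultiDiffusion notebook's setup typically looks. The checkpoint name, the use of diffusers' StableDiffusionPanoramaPipeline, and the sa_handler calls are assumptions based on the rest of this repository; the notebook's actual init cell is collapsed above.

import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler
import sa_handler

# SD v2 base with the panorama (MultiDiffusion) pipeline -- assumed checkpoint
model_ckpt = "stabilityai/stable-diffusion-2-base"
scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
pipeline = StableDiffusionPanoramaPipeline.from_pretrained(
    model_ckpt, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

# register StyleAligned so every generated panorama shares the reference style
handler = sa_handler.Handler(pipeline)
handler.register(sa_handler.StyleAlignedArgs(share_attention=True,
                                             adain_queries=True,
                                             adain_keys=True))

reference_prompt = "a beautiful papercut art design"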
