Skip to content

Commit

Permalink
Merge pull request #1 from ChoYongchae/main
Browse files Browse the repository at this point in the history
Add 'low_vram' flag to optimize VRAM usage in rb-modulation.ipynb and fix typo in README.md
  • Loading branch information
LituRout committed Aug 27, 2024
2 parents 4658c64 + 94fa48f commit 51b7f55
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ cd ..
# Install dependencies following the original [StableCascade](https://github.com/Stability-AI/StableCascade/blob/master/inference/readme.md)
conda create -n rbm python==3.9
pip install -r requirements.txt
pip install jupyter notebook opencv-python matplotlib ffty
pip install jupyter notebook opencv-python matplotlib ftfy
# Download [pre-trained CSD weights](https://drive.google.com/file/d/1FX0xs8p-C7Ob-h5Y4cUhTeOepHzXv_46/view) and put it under `third_party/CSD/checkpoint.pth`.
Expand Down
51 changes: 50 additions & 1 deletion rb-modulation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,36 @@
"from gdf.targets import EpsilonTarget\n",
"\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)"
"print(device)\n",
"\n",
"# Turn on this flag if you don't have enough memory.\n",
"low_vram = False"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c24086d9",
"metadata": {},
"outputs": [],
"source": [
"if low_vram:\n",
" def models_to(model, device=\"cpu\", excepts=None):\n",
" # Change the device of nn.Modules within a class.\n",
" for attr_name in dir(model):\n",
" if attr_name.startswith('__') and attr_name.endswith('__'):\n",
" continue # skip special attributes\n",
"\n",
" attr_value = getattr(model, attr_name, None)\n",
"\n",
" if isinstance(attr_value, torch.nn.Module):\n",
" if excepts and attr_name in excepts:\n",
" print(f\"Except '{attr_name}'\")\n",
" continue\n",
" print(f\"Change device of '{attr_name}' to {device}\")\n",
" attr_value.to(device)\n",
" \n",
" torch.cuda.empty_cache()"
]
},
{
Expand Down Expand Up @@ -4256,6 +4285,11 @@
}
],
"source": [
"if low_vram:\n",
" # Off-load old generator (which is not used in models_rbm)\n",
" models.generator.to(\"cpu\")\n",
" torch.cuda.empty_cache()\n",
"\n",
"generator_rbm = StageCRBM()\n",
"for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():\n",
" set_module_tensor_to_device(generator_rbm, param_name, \"cpu\", value=param)\n",
Expand Down Expand Up @@ -4428,6 +4462,10 @@
}
],
"source": [
"if low_vram:\n",
" # The sampling process uses more vram, so we offload everything except two modules to the cpu.\n",
" models_to(models_rbm, device=\"cpu\", excepts=[\"generator\", \"previewer\"])\n",
"\n",
"# Stage C reverse process.\n",
"sampling_c = extras.gdf.sample(\n",
" models_rbm.generator, conditions, stage_c_latent_shape,\n",
Expand Down Expand Up @@ -4566,6 +4604,11 @@
}
],
"source": [
"if low_vram:\n",
" # Revert the devices of the modules back to their original state\n",
"    # (Assume some modules have already been offloaded in the above sampling block.)\n",
" models_to(models_rbm, device)\n",
"\n",
"batch_size = 1\n",
"height, width = 1024, 1024\n",
"stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)\n",
Expand Down Expand Up @@ -4645,6 +4688,12 @@
}
],
"source": [
"if low_vram:\n",
" # The sampling process uses more vram, so we offload everything except two modules to the cpu.\n",
" models_to(models_rbm, device=\"cpu\", excepts=[\"generator\", \"previewer\"])\n",
" models_to(sam_model, device=\"cpu\")\n",
" models_to(sam_model.sam, device=\"cpu\")\n",
"\n",
"# Stage C reverse process.\n",
"sampling_c = extras.gdf.sample(\n",
" models_rbm.generator, conditions, stage_c_latent_shape,\n",
Expand Down

0 comments on commit 51b7f55

Please sign in to comment.