diff --git a/spkatt-2023.ipynb b/spkatt-2023.ipynb index 2b3ff00..968a8e8 100644 --- a/spkatt-2023.ipynb +++ b/spkatt-2023.ipynb @@ -44,6 +44,8 @@ "from tqdm import tqdm\n", "import json\n", "import warnings\n", + "import gc\n", + "import torch\n", "\n", "from datasets import (\n", " load_dataset,\n", @@ -1647,6 +1649,42 @@ "- `save_steps` and `max_steps`: set `max_steps` to control the length of training (`save_steps` determines when checkpoints are created)\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# choose config for cue model\n", + "cues_training_config = \"./configs/7b_cues.args\" # 7b model\n", + "# cues_training_config = \"./configs/70b_cues.args\" # 70b model\n", + "\n", + "train(cues_training_config)\n", + "\n", + "# free vram after training\n", + "gc.collect()\n", + "torch.cuda.empty_cache()\n", + "gc.collect()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# choose config for roles model\n", + "roles_training_config = \"./configs/7b_roles.args\" # 7b model\n", + "# roles_training_config = \"./configs/70b_roles.args\" # 70b model\n", + "\n", + "train(roles_training_config)\n", + "\n", + "# free vram after training\n", + "gc.collect()\n", + "torch.cuda.empty_cache()\n", + "gc.collect()" + ] + }, { "cell_type": "code", "execution_count": null,