Commit 2194560
clean up
nesasio93 committed Oct 7, 2023
1 parent 3a9c5e3 commit 2194560
Showing 1 changed file with 30 additions and 49 deletions.

spkatt-2023.ipynb
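Context for the diff below: most of its single-line changes only toggle the trailing "\n" on the final source line of a notebook code cell, and one cell (the one defining the training config paths and calling train()) is removed. A newline normalization of this shape could be scripted, for example with nbformat; the sketch below is a hypothetical illustration under that assumption, not necessarily how this commit was actually produced.

# Hypothetical sketch (not taken from this repository): ensure every cell's
# source ends with exactly one newline, which is the kind of mechanical change
# most single-line hunks in this diff represent.
import nbformat

def normalize_cell_endings(path: str) -> None:
    nb = nbformat.read(path, as_version=4)   # cell.source is a single string here
    for cell in nb.cells:
        # strip any trailing newlines, then append exactly one
        cell.source = cell.source.rstrip("\n") + "\n"
    nbformat.write(nb, path)                 # re-serialized final source lines end in "\n"

normalize_cell_endings("spkatt-2023.ipynb")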
@@ -24,7 +24,7 @@
"import os\n",
"\n",
"os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3\""
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1,2,3\"\n"
]
},
{
@@ -59,7 +59,7 @@
"from transformers import AutoTokenizer\n",
"from transformers import logging as trans_logging\n",
"\n",
"from qlora import train"
"from qlora import train\n"
]
},
{
@@ -78,7 +78,7 @@
"ds_logging.set_verbosity_error()\n",
"ds_logging.disable_progress_bar()\n",
"trans_logging.set_verbosity_error()\n",
"warnings.filterwarnings(\"ignore\")"
"warnings.filterwarnings(\"ignore\")\n"
]
},
{
@@ -134,7 +134,7 @@
" features=features,\n",
" )\n",
" ds = ds.add_column(\"FileName\", [file] * len(ds))\n",
" return ds"
" return ds\n"
]
},
{
@@ -149,7 +149,7 @@
" )\n",
" ds = ds.add_column(\"FileName\", [file] * len(ds))\n",
" ds = ds.add_column(\"Sentence\", [\" \".join(t) for t in ds[\"Tokens\"]])\n",
" return ds"
" return ds\n"
]
},
{
@@ -169,7 +169,7 @@
" [dataset, read_annotations_from_file(path, file)]\n",
" )\n",
"\n",
" return dataset"
" return dataset\n"
]
},
{
@@ -190,7 +190,7 @@
" )\n",
"\n",
" dataset = dataset.add_column(\"id\", range(len(dataset)))\n",
" return dataset"
" return dataset\n"
]
},
{
@@ -214,7 +214,7 @@
" os.makedirs(path_to_dataset, exist_ok=True)\n",
" result.save_to_disk(path_to_dataset)\n",
"\n",
" return result"
" return result\n"
]
},
{
@@ -237,7 +237,7 @@
" )\n",
" os.makedirs(path_to_dataset, exist_ok=True)\n",
" result.save_to_disk(path_to_dataset)\n",
" return result"
" return result\n"
]
},
{
@@ -248,7 +248,7 @@
"source": [
"train_sentences_dataset = read_sentences_dataset(\"train\")\n",
"val_sentences_dataset = read_sentences_dataset(\"dev\")\n",
"test_sentences_dataset = read_sentences_dataset(\"eval\")"
"test_sentences_dataset = read_sentences_dataset(\"eval\")\n"
]
},
{
@@ -258,7 +258,7 @@
"outputs": [],
"source": [
"train_annotations_dataset = read_annotations_dataset(\"train\")\n",
"val_annotations_dataset = read_annotations_dataset(\"dev\")"
"val_annotations_dataset = read_annotations_dataset(\"dev\")\n"
]
},
{
@@ -285,7 +285,7 @@
" and r[\"SentenceId\"] == int(anno.split(\":\")[0])\n",
" )[0]\n",
" tokens.append(temp_row[\"Tokens\"][int(anno.split(\":\")[1])])\n",
" return tokens"
" return tokens\n"
]
},
{
@@ -466,7 +466,7 @@
" os.makedirs(path_to_dataset, exist_ok=True)\n",
" res.save_to_disk(path_to_dataset)\n",
"\n",
" return res"
" return res\n"
]
},
{
@@ -479,7 +479,7 @@
" train_sentences_dataset, train_annotations_dataset, \"train\"\n",
")\n",
"val_ds = build_complete_dataset(val_sentences_dataset, val_annotations_dataset, \"dev\")\n",
"test_ds = build_complete_dataset(test_sentences_dataset, None, \"eval\")"
"test_ds = build_complete_dataset(test_sentences_dataset, None, \"eval\")\n"
]
},
{
@@ -488,7 +488,7 @@
"metadata": {},
"outputs": [],
"source": [
"inputs = test_sentences_dataset.rename_column(\"Sentence\", \"Satz\")"
"inputs = test_sentences_dataset.rename_column(\"Sentence\", \"Satz\")\n"
]
},
{
@@ -562,7 +562,7 @@
}
],
"source": [
"train_ds[52]"
"train_ds[52]\n"
]
},
{
@@ -1003,7 +1003,7 @@
}
],
"source": [
"train_ds[15]"
"train_ds[15]\n"
]
},
{
@@ -1022,7 +1022,7 @@
"def map_cues_to_string(mapped):\n",
" if mapped == []:\n",
" return \"#UNK#\"\n",
" return \", \".join([\"[\" + \", \".join(val) + \"]\" for val in mapped])"
" return \", \".join([\"[\" + \", \".join(val) + \"]\" for val in mapped])\n"
]
},
{
@@ -1034,7 +1034,7 @@
"def map_roles_to_string(mapped):\n",
" if mapped == []:\n",
" return \"#UNK#\"\n",
" return \", \".join(mapped)"
" return \", \".join(mapped)\n"
]
},
{
@@ -1113,7 +1113,7 @@
" result.append(element)\n",
"\n",
" with open(lmsys_data_path, \"w\", encoding=\"utf8\") as outfile:\n",
" json.dump(result, outfile, indent=3, ensure_ascii=False)"
" json.dump(result, outfile, indent=3, ensure_ascii=False)\n"
]
},
{
@@ -1122,7 +1122,7 @@
"metadata": {},
"outputs": [],
"source": [
"build_lmsys_format(train_ds, val_ds)"
"build_lmsys_format(train_ds, val_ds)\n"
]
},
{
@@ -1201,7 +1201,7 @@
" f.write(\"\\n\".join(all_prompts_cues))\n",
"\n",
"with open(parsed_roles_file, \"w\") as f:\n",
" f.write(\"\\n\".join(all_prompts_roles))"
" f.write(\"\\n\".join(all_prompts_roles))\n"
]
},
{
@@ -1302,7 +1302,7 @@
"for l in lines[:5]:\n",
" print(\"=== in: ===\\n\" + json.loads(l)[\"input\"] + \"\\n\")\n",
" print(\"=== out: ===\\n\" + json.loads(l)[\"output\"] + \"\\n\")\n",
" print()"
" print()\n"
]
},
{
@@ -1463,7 +1463,7 @@
"for l in lines[:5]:\n",
" print(\"=== in: ===\\n\" + json.loads(l)[\"input\"] + \"\\n\")\n",
" print(\"=== out: ===\\n\" + json.loads(l)[\"output\"] + \"\\n\")\n",
" print()"
" print()\n"
]
},
{
@@ -1568,7 +1568,7 @@
" enc_in = tokenizer.encode(json.loads(l)[\"input\"])\n",
" encoded_inputs_roles.append(enc_in)\n",
" enc_out = tokenizer.encode(json.loads(l)[\"output\"])\n",
" encoded_outputs_roles.append(enc_out)"
" encoded_outputs_roles.append(enc_out)\n"
]
},
{
@@ -1598,7 +1598,7 @@
"print(f\"mean length: {np.mean(len_enc)}\")\n",
"print(\n",
" f\"number of samples longer than {max_length_source_roles}: {sum(np.array(len_enc) > max_length_source_roles)}\"\n",
")"
")\n"
]
},
{
@@ -1628,7 +1628,7 @@
"print(f\"mean length: {np.mean(len_enc)}\")\n",
"print(\n",
" f\"number of samples longer than {max_length_target_roles}: {sum(np.array(len_enc) > max_length_target_roles)}\"\n",
")"
")\n"
]
},
{
@@ -1664,7 +1664,7 @@
"# free vram after training\n",
"gc.collect()\n",
"torch.cuda.empty_cache()\n",
"gc.collect()\n"
"gc.collect()"
]
},
{
@@ -1682,26 +1682,7 @@
"# free vram after training\n",
"gc.collect()\n",
"torch.cuda.empty_cache()\n",
"gc.collect()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define config files for training\n",
"# 7B models\n",
"cues_training_config = \"./configs/7b_cues.args\"\n",
"roles_training_config = \"./configs/7b_roles.args\"\n",
"\n",
"# 70B models\n",
"# cues_training_config = \"./configs/70b_cues.args\"\n",
"# roles_training_config = \"./configs/70b_roles.args\"\n",
"\n",
"train(cues_training_config)\n",
"# train(roles_training_config)"
"gc.collect()\n"
]
},
{
@@ -1744,7 +1744,7 @@
"# )\n",
"# from langchain import HuggingFacePipeline\n",
"\n",
"# llm = HuggingFacePipeline(pipeline=pipe)\n"
"# llm = HuggingFacePipeline(pipeline=pipe)"
]
}
],
