-
Notifications
You must be signed in to change notification settings - Fork 2
/
10 - Training the AmazonCat-13k Dataset.py
58 lines (45 loc) · 1.42 KB
/
10 - Training the AmazonCat-13k Dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# %%
import json
from itertools import chain
from pathlib import Path
from timeit import default_timer as timer

import utils.dataset as ds
import utils.storage as st
from utils.models import create_classifiers
# %% [markdown]
# # Training on the AmazonCat-13k Dataset
# %%
# Length argument handed to ds.import_amazoncat13k when loading both splits.
INPUT_LENGTH = 10
# Presumably the size of the AmazonCat-13k label space; not referenced
# anywhere else in this script (kept for reference) — TODO confirm.
CLASS_COUNT = 13_330
# %% [markdown]
# Import the dataset and the embedding layer
# %%
# Load the training split ('trn.processed') first, then the test split ('tst');
# each importer call returns an (X, y) pair.
(X_train, y_train), (X_test, y_test) = (
    ds.import_amazoncat13k(split_name, INPUT_LENGTH)
    for split_name in ('trn.processed', 'tst')
)
# %% [markdown]
# Define the classifiers, then train each one and save its test-set predictions.
# %%
# De-duplicated union of ds.amazoncat13k_top10_labels and
# ds.amazoncat13k_threshold_labels.
labels = {*ds.amazoncat13k_top10_labels, *ds.amazoncat13k_threshold_labels}
def classifiers():
    """Return an iterator over every classifier to train, in training order.

    First come the classifiers for label 8842 with the variants
    '50%positive', '20%positive', '10%positive' and 'unbalanced'; then, for
    each label in the module-level ``labels`` set, the '50%positive' and
    'unbalanced' variants.

    The per-label ``create_classifiers`` calls are made lazily as the
    iterator is consumed, instead of all being made up front.
    """
    # NOTE(review): if 8842 is also a member of ``labels``, its
    # '50%positive'/'unbalanced' variants will be produced twice — confirm
    # whether that is intended.
    return chain(
        create_classifiers(8842, ['50%positive', '20%positive', '10%positive', 'unbalanced']),
        # A generator (not a list) keeps construction on demand; ``labels`` is
        # still captured when classifiers() is called.
        chain.from_iterable(
            create_classifiers(label, ['50%positive', 'unbalanced'])
            for label in labels
        ),
    )
# %%
# Wall-clock training time in seconds, keyed by classifier name.
durations = {}
try:
    for c in classifiers():
        print(f"Training classifier '{c.name}'.")
        start = timer()
        c.train(X_train, y_train)
        duration = timer() - start
        durations[c.name] = duration
        print(f"Training took {duration} seconds.")
        # Presumably selects the test data relevant to this classifier's
        # label id — TODO confirm against ds.get_dataset. The label half of
        # the pair is not needed here.
        Xi_test, _ = ds.get_dataset(X_test, y_test, c.id)
        yi_predict = c.get_prediction(Xi_test)
        st.save_prediction(c.id, c.type_name, yi_predict)
finally:
    # Persist whatever timings were collected, even if training stopped
    # part-way. Ensure the output directory exists first: an exception raised
    # by open() inside this finally would replace the original error.
    Path("results").mkdir(parents=True, exist_ok=True)
    with open("results/durations.json", 'w', encoding='utf-8') as fp:
        json.dump(durations, fp)
# %%