forked from hiroki13/instance-based-ner
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_knn_models.py
141 lines (130 loc) · 4.8 KB
/
train_knn_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import os
import json
from models.knn_models import KnnModel
from utils.batchers.knn_batchers import BaseKnnBatcher
from utils.preprocessors.knn_preprocessors import KnnPreprocessor
def set_config(args, config):
    """Overlay command-line options from ``args`` onto ``config``.

    Only options whose value is truthy are applied, so an option left at
    ``None`` (or given as ``0`` / an empty string) keeps whatever ``config``
    already holds.  ``config`` is modified in place and also returned.

    Args:
        args: object exposing the parsed command-line options as attributes.
        config: dict of settings loaded from the configuration file.

    Returns:
        The same ``config`` dict, after the overrides have been applied.
    """

    def override(*option_names):
        # Copy each truthy option into config under the key of the same name.
        for option in option_names:
            value = getattr(args, option)
            if value:
                config[option] = value

    override("raw_path")

    save_dir = args.save_path
    if save_dir:
        config["save_path"] = save_dir
        # Derive the data file locations from the save directory; the more
        # specific options handled just below may still replace them.
        derived_files = (
            ("train_set", "train.json"),
            ("valid_set", "valid.json"),
            ("vocab", "vocab.json"),
            ("pretrained_emb", "glove_emb.npz"),
        )
        for key, filename in derived_files:
            config[key] = os.path.join(save_dir, filename)

    override("train_set", "valid_set", "pretrained_emb", "vocab")

    checkpoint_dir = args.checkpoint_path
    if checkpoint_dir:
        config["checkpoint_path"] = checkpoint_dir
        # Summary directory defaults to a sub-directory of the checkpoint
        # directory unless --summary_path is given explicitly.
        config["summary_path"] = os.path.join(checkpoint_dir, "summary")

    override(
        "summary_path",
        "model_name",
        "batch_size",
        "data_size",
        "bilstm_type",
        "keep_prob",
        "k",
        "predict",
        "max_span_len",
        "max_n_spans",
        "knn_sampling",
    )
    return config
def main(args):
    """Load the configuration, preprocess the data if needed, and train.

    Args:
        args: parsed command-line options.  ``args.config_file`` is the path
            of the JSON configuration file; the remaining options are
            applied on top of it by ``set_config``.
    """
    # Read the config inside a context manager so the file handle is closed
    # deterministically instead of being left to the garbage collector.
    with open(args.config_file) as config_file:
        config = json.load(config_file)
    config = set_config(args, config)

    preprocessor = KnnPreprocessor(config)
    # create dataset from raw data files, but only when the save directory
    # does not exist yet
    if not os.path.exists(config["save_path"]):
        preprocessor.preprocess()

    model = KnnModel(config, BaseKnnBatcher(config))
    model.train()
if __name__ == '__main__':
    # Command-line entry point: every option except --config_file and
    # --predict defaults to None, which set_config treats as "keep the value
    # from the configuration file".
    parser = argparse.ArgumentParser()
    # Not marked as required: a required option never uses its default, so
    # combining both made the default path below unreachable.
    parser.add_argument('--config_file',
                        default='data/config/config.json',
                        help='Configuration file')
    parser.add_argument('--raw_path',
                        default=None,
                        help='Raw data directory')
    parser.add_argument('--save_path',
                        default=None,
                        help='Save directory')
    parser.add_argument('--checkpoint_path',
                        default=None,
                        help='Checkpoint directory')
    parser.add_argument('--summary_path',
                        default=None,
                        help='Summary directory')
    parser.add_argument('--model_name',
                        default=None,
                        help='Model name')
    parser.add_argument('--batch_size',
                        default=None,
                        type=int,
                        help='Batch size')
    parser.add_argument('--train_set',
                        default=None,
                        help='path to training set')
    parser.add_argument('--valid_set',
                        default=None,
                        help='path to validation set')
    parser.add_argument('--pretrained_emb',
                        default=None,
                        help='path to pretrained embeddings')
    parser.add_argument('--vocab',
                        default=None,
                        help='path to vocabulary')
    parser.add_argument('--data_size',
                        default=None,
                        type=int,
                        help='Data size')
    parser.add_argument('--bilstm_type',
                        default=None,
                        help='standard/interleave')
    parser.add_argument('--keep_prob',
                        default=None,
                        type=float,
                        help='Keep (dropout) probability')
    parser.add_argument('--k',
                        default=None,
                        type=int,
                        help='k-NN sentences')
    # NOTE(review): unlike the other options this one has a non-None default,
    # so set_config always overwrites config["predict"] with it -- confirm
    # that overriding the configuration file here is intended.
    parser.add_argument('--predict',
                        default='max_margin',
                        help='prediction methods')
    parser.add_argument('--max_span_len',
                        default=None,
                        type=int,
                        help='max span length')
    parser.add_argument('--max_n_spans',
                        default=None,
                        type=int,
                        help='max num of spans')
    parser.add_argument('--knn_sampling',
                        default=None,
                        help='k-NN sentence sampling')
    main(parser.parse_args())