-
Notifications
You must be signed in to change notification settings - Fork 23
/
data_utils.py
80 lines (64 loc) · 3.04 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import numpy as np
import joblib
from collections import Counter
from sklearn.preprocessing import MultiLabelBinarizer, normalize
from sklearn.datasets import load_svmlight_file
from gensim.models import KeyedVectors
from tqdm import tqdm
from typing import Union, Iterable
def build_vocab(texts: Iterable, w2v_model: Union[KeyedVectors, str], vocab_size=500000,
                pad='<PAD>', unknown='<UNK>', sep='/SEP/', max_times=1, freq_times=1):
    """Build a token vocabulary and a matching initial embedding matrix.

    Tokens are taken from whitespace-split ``texts`` and ordered by document
    frequency (descending), tie-broken so that words present in ``w2v_model``
    come first. Index 0 is the padding token (zero vector) and index 1 the
    unknown token (random vector).

    :param texts: iterable of whitespace-tokenized text lines.
    :param w2v_model: a gensim ``KeyedVectors`` instance, or a path to a
        text-format word2vec file to load.
    :param vocab_size: hard cap on vocabulary size (including pad/unknown).
    :param pad: padding symbol placed at index 0.
    :param unknown: out-of-vocabulary symbol placed at index 1.
    :param sep: separator symbol; embedded using the vector of '.'.
    :param max_times: stop adding words once frequency drops below this value.
        NOTE(review): with the default of 1 this cutoff never triggers, since
        every counted word has frequency >= 1 — confirm intended semantics.
    :param freq_times: minimum frequency for words absent from ``w2v_model``.
    :return: ``(vocab, emb_init)`` as two numpy arrays of equal length.
    """
    if isinstance(w2v_model, str):
        w2v_model = KeyedVectors.load_word2vec_format(w2v_model, binary=False)
    emb_size = w2v_model.vector_size
    # Slots 0/1 are reserved: pad -> zeros, unknown -> random uniform.
    vocab, emb_init = [pad, unknown], [np.zeros(emb_size), np.random.uniform(-1.0, 1.0, emb_size)]
    # set(t.split()) counts each word at most once per line (document frequency).
    counter = Counter(token for t in texts for token in set(t.split()))
    # Sort key (freq, in-model) with reverse=True: higher frequency first,
    # and among equal frequencies, words with pretrained vectors first.
    for word, freq in sorted(counter.items(), key=lambda x: (x[1], x[0] in w2v_model), reverse=True):
        # Keep a word if it has a pretrained vector, or is frequent enough.
        if word in w2v_model or freq >= freq_times:
            vocab.append(word)
            # We used embedding of '.' as embedding of '/SEP/' symbol.
            word = '.' if word == sep else word
            emb_init.append(w2v_model[word] if word in w2v_model else np.random.uniform(-1.0, 1.0, emb_size))
        # Stop on the frequency floor or once the vocabulary is full; because
        # iteration is frequency-descending, all later words are rarer.
        if freq < max_times or vocab_size == len(vocab):
            break
    return np.asarray(vocab), np.asarray(emb_init)
def get_word_emb(vec_path, vocab_path=None):
    """Load a pre-computed embedding matrix, optionally with its vocabulary.

    :param vec_path: path to a ``.npy`` file holding the embedding matrix.
    :param vocab_path: optional path to a text file with one word per line.
    :return: the embedding array alone, or ``(embeddings, vocab)`` where
        ``vocab`` maps each raw line (newline included, matching the writer's
        format) to its zero-based row index.
    """
    if vocab_path is None:
        return np.load(vec_path)
    word_to_idx = {}
    with open(vocab_path) as fp:
        # NOTE: lines are used verbatim as keys (trailing '\n' kept),
        # matching the original lookup convention.
        for row, line in enumerate(fp):
            word_to_idx[line] = row
    return np.load(vec_path), word_to_idx
def get_data(text_file, label_file=None):
    """Load pickled text data and, when given, the matching labels.

    :param text_file: ``.npy`` file with the (object-dtype) text array.
    :param label_file: optional ``.npy`` file with the label array.
    :return: ``(texts, labels)``; ``labels`` is ``None`` if no file was given.
    """
    texts = np.load(text_file, allow_pickle=True)
    if label_file is None:
        return texts, None
    labels = np.load(label_file, allow_pickle=True)
    return texts, labels
def convert_to_binary(text_file, label_file=None, max_len=None, vocab=None, pad='<PAD>', unknown='<UNK>'):
    """Convert a tokenized text file to padded id sequences (plus raw labels).

    :param text_file: path to a file with one whitespace-tokenized line per sample.
    :param label_file: optional path to a file with whitespace-separated labels per line.
    :param max_len: target length passed to :func:`truncate_text`; ``None`` keeps
        sequences unpadded.
    :param vocab: mapping from token to integer id; REQUIRED (must contain
        ``pad`` and ``unknown``).
    :param pad: padding symbol looked up in ``vocab``.
    :param unknown: fallback symbol for out-of-vocabulary tokens.
    :return: ``(texts, labels)`` where ``texts`` are id sequences truncated/padded
        via :func:`truncate_text` and ``labels`` is ``None`` when no label file is given.
    :raises ValueError: if ``vocab`` is not provided.
    """
    # Fix: the parameter defaults to None but was dereferenced unconditionally,
    # which raised an opaque AttributeError; fail fast with a clear message.
    if vocab is None:
        raise ValueError("convert_to_binary requires a `vocab` mapping")
    with open(text_file) as fp:
        # NOTE(review): rows have varying lengths; np.asarray on ragged lists
        # yields an object array (and raises on NumPy >= 1.24 without
        # dtype=object) — truncate_text re-builds a rectangular array anyway.
        texts = np.asarray([[vocab.get(word, vocab[unknown]) for word in line.split()]
                            for line in tqdm(fp, desc='Converting token to id', leave=False)])
    labels = None
    if label_file is not None:
        with open(label_file) as fp:
            labels = np.asarray([[label for label in line.split()]
                                 for line in tqdm(fp, desc='Converting labels', leave=False)])
    return truncate_text(texts, max_len, vocab[pad], vocab[unknown]), labels
def truncate_text(texts, max_len=500, padding_idx=0, unknown_idx=1):
    """Clip/pad every id sequence to exactly ``max_len`` tokens.

    Sequences longer than ``max_len`` are truncated; shorter ones are padded
    on the right with ``padding_idx``. Any resulting row consisting entirely
    of ``padding_idx`` (e.g. an empty input line) gets ``unknown_idx`` written
    at position 0 so downstream code never sees an all-padding sample.

    :param texts: iterable of token-id sequences.
    :param max_len: target length; ``None`` returns ``texts`` unchanged.
    :param padding_idx: id used for right-padding.
    :param unknown_idx: id planted in empty rows.
    :return: ``(n, max_len)`` numpy array, or the input when ``max_len`` is None.
    """
    if max_len is None:
        return texts
    rows = []
    for seq in texts:
        head = list(seq[:max_len])
        rows.append(head + [padding_idx] * (max_len - len(head)))
    matrix = np.asarray(rows)
    all_padding = (matrix == padding_idx).all(axis=1)
    matrix[all_padding, 0] = unknown_idx
    return matrix
def get_mlb(mlb_path, labels=None) -> MultiLabelBinarizer:
    """Return a fitted sparse MultiLabelBinarizer, caching it on disk.

    If ``mlb_path`` already exists it is loaded and returned; otherwise a new
    binarizer is fitted on ``labels``, dumped to ``mlb_path``, and returned.

    :param mlb_path: joblib cache file for the fitted binarizer.
    :param labels: label lists to fit on; required only on a cache miss.
    """
    if not os.path.exists(mlb_path):
        binarizer = MultiLabelBinarizer(sparse_output=True)
        binarizer.fit(labels)
        joblib.dump(binarizer, mlb_path)
        return binarizer
    return joblib.load(mlb_path)
def get_sparse_feature(feature_file, label_file=None):
    """Load L2-normalized sparse features and, optionally, their labels.

    :param feature_file: svmlight/libsvm-format feature file.
    :param label_file: optional ``.npy`` label file. Fix: the body already
        guarded ``label_file is not None`` but the parameter was required,
        making the guard dead code — it now defaults to ``None``
        (backward compatible for existing two-argument callers).
    :return: ``(features, labels)`` with row-normalized sparse features;
        ``labels`` is ``None`` when no label file is given.
    """
    sparse_x, _ = load_svmlight_file(feature_file, multilabel=True)
    # NOTE(review): labels load without allow_pickle, so the file must hold a
    # plain (non-object) array — confirm against the data pipeline.
    labels = np.load(label_file) if label_file is not None else None
    return normalize(sparse_x), labels
def output_res(output_path, name, scores, labels):
    """Save prediction scores and labels under ``output_path``.

    Creates the directory if needed and writes ``{name}-scores.npy`` and
    ``{name}-labels.npy``.

    :param output_path: destination directory.
    :param name: run/model prefix for the two output files.
    :param scores: score array to persist.
    :param labels: label array to persist.
    """
    os.makedirs(output_path, exist_ok=True)
    for suffix, payload in (('scores', scores), ('labels', labels)):
        np.save(os.path.join(output_path, f'{name}-{suffix}'), payload)