
Commit d65c902: re-factor
Yu Zhang committed Feb 14, 2021
1 parent 99686fc commit d65c902
Showing 32 changed files with 202 additions and 398 deletions.
14 changes: 8 additions & 6 deletions configure/datasets/MAG.yaml
@@ -1,22 +1,24 @@
 name: MAG

 train:
-  texts: /shared/data2/yuz9/MATCH/MATCH/data/MAG/train_texts.npy
-  labels: /shared/data2/yuz9/MATCH/MATCH/data/MAG/train_labels.npy
+  texts: MAG/train_texts.npy
+  labels: MAG/train_labels.npy

 valid:
   size: 70534

 test:
-  texts: /shared/data2/yuz9/MATCH/MATCH/data/MAG/test_texts.npy
+  texts: MAG/test_texts.npy

 embedding:
-  emb_init: /shared/data2/yuz9/MATCH/MATCH/data/MAG/emb_init.npy
+  emb_init: MAG/emb_init.npy

+hierarchy: MAG/taxonomy.txt
+
 output:
-  res: /shared/data2/yuz9/MATCH/MATCH/data/MAG/results
+  res: MAG/results

-labels_binarizer: /shared/data2/yuz9/MATCH/MATCH/data/MAG/labels_binarizer
+labels_binarizer: MAG/labels_binarizer

 model:
   emb_size: 100
14 changes: 8 additions & 6 deletions configure/datasets/MeSH.yaml
@@ -1,22 +1,24 @@
 name: MeSH

 train:
-  texts: /shared/data2/yuz9/MATCH/MATCH/data/MeSH/train_texts.npy
-  labels: /shared/data2/yuz9/MATCH/MATCH/data/MeSH/train_labels.npy
+  texts: MeSH/train_texts.npy
+  labels: MeSH/train_labels.npy

 valid:
   size: 89855

 test:
-  texts: /shared/data2/yuz9/MATCH/MATCH/data/MeSH/test_texts.npy
+  texts: MeSH/test_texts.npy

 embedding:
-  emb_init: /shared/data2/yuz9/MATCH/MATCH/data/MeSH/emb_init.npy
+  emb_init: MeSH/emb_init.npy

+hierarchy: MeSH/taxonomy.txt
+
 output:
-  res: /shared/data2/yuz9/MATCH/MATCH/data/MeSH/results
+  res: MeSH/results

-labels_binarizer: /shared/data2/yuz9/MATCH/MATCH/data/MeSH/labels_binarizer
+labels_binarizer: MeSH/labels_binarizer

 model:
   emb_size: 100
2 changes: 1 addition & 1 deletion configure/models/MATCH-MAG.yaml
@@ -20,4 +20,4 @@ valid:
 predict:
   batch_size: 256

-path: /shared/data2/yuz9/MATCH/MATCH/data/MAG/models
+path: MAG/models
2 changes: 1 addition & 1 deletion configure/models/MATCH-MeSH.yaml
@@ -20,4 +20,4 @@ valid:
 predict:
   batch_size: 256

-path: /shared/data2/yuz9/MATCH/MATCH/data/MeSH/models
+path: MeSH/models
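
All four configs now use paths relative to a shared data directory instead of machine-specific absolute ones, so they work on any checkout. A minimal sketch of resolving such entries against a single data root at load time (PyYAML and the data/ root are assumptions, not necessarily how the repository's entry point does it):

# Hypothetical loader sketch: resolve the YAML's relative paths against one data root.
from pathlib import Path
import yaml

DATA_ROOT = Path('data')  # assumed parent of MAG/ and MeSH/

with open('configure/datasets/MAG.yaml') as fin:
    cfg = yaml.safe_load(fin)

train_texts = DATA_ROOT / cfg['train']['texts']      # data/MAG/train_texts.npy
emb_init = DATA_ROOT / cfg['embedding']['emb_init']  # data/MAG/emb_init.npy
taxonomy = DATA_ROOT / cfg['hierarchy']              # data/MAG/taxonomy.txt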
25 changes: 0 additions & 25 deletions data/MAG/prepare.py

This file was deleted.

25 changes: 0 additions & 25 deletions data/MeSH/prepare.py

This file was deleted.

Empty file modified deepxml/__init__.py
100755 → 100644
Empty file.
Binary file modified deepxml/__pycache__/data_utils.cpython-36.pyc
Binary file not shown.
Binary file modified deepxml/__pycache__/models.cpython-36.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion deepxml/data_utils.py
100755 → 100644
@@ -77,4 +77,4 @@ def get_sparse_feature(feature_file, label_file):
 def output_res(output_path, name, scores, labels):
     os.makedirs(output_path, exist_ok=True)
     np.save(os.path.join(output_path, F'{name}-scores'), scores)
-    np.save(os.path.join(output_path, F'{name}-labels'), labels)
\ No newline at end of file
+    np.save(os.path.join(output_path, F'{name}-labels'), labels)
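
For context, output_res is what produces the score and label files consumed by the refactored evaluation.py below. A toy, self-contained call (the model name, output directory, and label strings are assumptions for illustration):

import numpy as np
from deepxml.data_utils import output_res

# Two documents, top-3 predicted labels each (hypothetical values).
scores = np.array([[0.9, 0.5, 0.1], [0.8, 0.4, 0.2]])
labels = np.array([['cs.LG', 'cs.AI', 'cs.CL'], ['cs.CV', 'cs.LG', 'cs.AI']])
output_res('MAG/results', 'MATCH-MAG', scores, labels)
# writes MAG/results/MATCH-MAG-scores.npy and MAG/results/MATCH-MAG-labels.npy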
Empty file modified deepxml/dataset.py
100755 → 100644
Empty file.
Empty file modified deepxml/evaluation.py
100755 → 100644
Empty file.
32 changes: 23 additions & 9 deletions deepxml/models.py
100755 → 100644
@@ -8,28 +8,42 @@
 from logzero import logger
 from typing import Optional, Mapping

-from deepxml.evaluation import get_p_5, get_n_5
+from deepxml.evaluation import get_p_1, get_p_3, get_p_5, get_n_1, get_n_3, get_n_5
 from deepxml.optimizers import DenseSparseAdam


 class Model(object):
-    def __init__(self, network, model_path, gradient_clip_value=5.0, device_ids=None, **kwargs):
+    def __init__(self, network, model_path, mode, reg=False, hierarchy=set(), gradient_clip_value=5.0, device_ids=None, **kwargs):
         self.model = nn.DataParallel(network(**kwargs).cuda(), device_ids=device_ids)
         self.loss_fn = nn.BCEWithLogitsLoss()
         self.model_path, self.state = model_path, {}
         os.makedirs(os.path.split(self.model_path)[0], exist_ok=True)
         self.gradient_clip_value, self.gradient_norm_queue = gradient_clip_value, deque([np.inf], maxlen=5)
         self.optimizer = None

+        self.reg = reg
+        if mode == 'train' and reg:
+            self.hierarchy = hierarchy
+            self.lambda1 = 1e-8
+
     def train_step(self, train_x: torch.Tensor, train_y: torch.Tensor):
         self.optimizer.zero_grad()
         self.model.train()
         scores = self.model(train_x)
         loss = self.loss_fn(scores, train_y)
+
+        if self.reg:
+            probs = torch.sigmoid(scores)
+            # Penalize any child label predicted with higher probability than its parent.
+            regs = torch.zeros(len(probs), len(self.hierarchy), device=probs.device)
+            for i, (p, c) in enumerate(self.hierarchy):
+                regs[:, i] = probs[:, c] - probs[:, p]
+            # Keep the penalty as a tensor so it contributes to the gradient.
+            loss = loss + self.lambda1 * torch.sum(nn.functional.relu(regs))
+
         loss.backward()
         self.clip_gradient()
         self.optimizer.step(closure=None)
-        return loss.item()
+        return loss.item()

     def predict_step(self, data_x: torch.Tensor, k: int):
         self.model.eval()
@@ -44,14 +58,14 @@ def train(self, train_loader: DataLoader, valid_loader: DataLoader, opt_params:
               nb_epoch=100, step=100, k=5, early=100, verbose=True, swa_warmup=None, **kwargs):
         self.get_optimizer(**({} if opt_params is None else opt_params))
         global_step, best_n5, e = 0, 0.0, 0
-        print_loss = 0.0#
+        print_loss = 0.0
         for epoch_idx in range(nb_epoch):
             if epoch_idx == swa_warmup:
                 self.swa_init()
             for i, (train_x, train_y) in enumerate(train_loader, 1):
                 global_step += 1
                 loss = self.train_step(train_x, train_y.cuda())
-                print_loss += loss#
+                print_loss += loss
                 if global_step % step == 0:
                     self.swa_step()
                     self.swap_swa_params()
@@ -69,7 +83,7 @@ def train(self, train_loader: DataLoader, valid_loader: DataLoader, opt_params:
                     labels = np.concatenate(labels)

                     targets = valid_loader.dataset.data_y
-                    p5, n5 = get_p_5(labels, targets), get_n_5(labels, targets)
+                    p1, p3, p5, n3, n5 = get_p_1(labels, targets), get_p_3(labels, targets), get_p_5(labels, targets), get_n_3(labels, targets), get_n_5(labels, targets)
                     if n5 > best_n5:
                         self.save_model(True)
                         best_n5, e = n5, 0
@@ -79,8 +93,8 @@ def train(self, train_loader: DataLoader, valid_loader: DataLoader, opt_params:
                             return
                     self.swap_swa_params()
                     if verbose:
-                        log_msg = '%d %d train loss: %.7f valid loss: %.7f P@5: %.5f N@5: %.5f early stop: %d' % \
-                                  (epoch_idx, i * train_loader.batch_size, print_loss / step, valid_loss, round(p5, 5), round(n5, 5), e)
+                        log_msg = '%d %d train loss: %.7f valid loss: %.7f P@1: %.5f P@3: %.5f P@5: %.5f N@3: %.5f N@5: %.5f early stop: %d' % \
+                                  (epoch_idx, i * train_loader.batch_size, print_loss / step, valid_loss, round(p1, 5), round(p3, 5), round(p5, 5), round(n3, 5), round(n5, 5), e)
                         logger.info(log_msg)
                         print_loss = 0.0
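
The new reg path adds a hierarchy-consistency penalty to the BCE loss: for each (parent, child) pair in the taxonomy, a child label predicted with higher probability than its parent contributes max(0, sigmoid(score_child) - sigmoid(score_parent)), scaled by lambda1 = 1e-8. The constructor expects hierarchy as a set of (parent_index, child_index) tuples. A minimal sketch of building that set from the taxonomy.txt file referenced by the new config entry, assuming each line lists a parent label followed by its children and that label indices come from the fitted MultiLabelBinarizer (both the file format and the load_hierarchy name are assumptions, not the repository's code):

def load_hierarchy(taxonomy_path, classes):
    # Map taxonomy label names to (parent_idx, child_idx) index pairs.
    label2idx = {label: i for i, label in enumerate(classes)}
    pairs = set()
    with open(taxonomy_path) as fin:
        for line in fin:
            parent, *children = line.split()
            for child in children:
                # Skip labels absent from the training label space.
                if parent in label2idx and child in label2idx:
                    pairs.add((label2idx[parent], label2idx[child]))
    return pairs

# e.g. Model(network, model_path, mode='train', reg=True,
#            hierarchy=load_hierarchy('MAG/taxonomy.txt', mlb.classes_), ...)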
Empty file modified deepxml/optimizers.py
100755 → 100644
Empty file.
52 changes: 29 additions & 23 deletions evaluation.py
100755 → 100644
@@ -1,23 +1,29 @@
-import warnings
-warnings.filterwarnings('ignore')
-
-import click
-import numpy as np
-from sklearn.preprocessing import MultiLabelBinarizer
-
-from deepxml.evaluation import get_p_1, get_p_3, get_p_5, get_n_1, get_n_3, get_n_5
-
-@click.command()
-@click.option('-r', '--results', type=click.Path(exists=True), help='Path of results.')
-@click.option('-t', '--targets', type=click.Path(exists=True), help='Path of targets.')
-@click.option('--train-labels', type=click.Path(exists=True), default=None, help='Path of labels for training set.')
-
-def main(results, targets, train_labels):
-    res, targets = np.load(results, allow_pickle=True), np.load(targets, allow_pickle=True)
-    mlb = MultiLabelBinarizer(sparse_output=True)
-    targets = mlb.fit_transform(targets)
-    print('Precision@1,3,5:', get_p_1(res, targets, mlb), get_p_3(res, targets, mlb), get_p_5(res, targets, mlb))
-    print('nDCG@1,3,5:', get_n_1(res, targets, mlb), get_n_3(res, targets, mlb), get_n_5(res, targets, mlb))
-
-if __name__ == '__main__':
-    main()
+import warnings
+warnings.filterwarnings('ignore')
+
+import click
+import numpy as np
+from sklearn.preprocessing import MultiLabelBinarizer
+
+from deepxml.evaluation import get_p_1, get_p_3, get_p_5, get_n_1, get_n_3, get_n_5
+
+@click.command()
+@click.option('-r', '--results', type=click.Path(exists=True), help='Path of results.')
+@click.option('-t', '--targets', type=click.Path(exists=True), help='Path of targets.')
+@click.option('--train-labels', type=click.Path(exists=True), default=None, help='Path of labels for training set.')
+
+def main(results, targets, train_labels):
+    res, targets = np.load(results, allow_pickle=True), np.load(targets, allow_pickle=True)
+
+    topk = 5
+    with open('predictions.txt', 'w') as fout:
+        for labels in res:
+            fout.write(' '.join(labels[:topk])+'\n')
+
+    mlb = MultiLabelBinarizer(sparse_output=True)
+    targets = mlb.fit_transform(targets)
+    print('Precision@1,3,5:', get_p_1(res, targets, mlb), get_p_3(res, targets, mlb), get_p_5(res, targets, mlb))
+    print('nDCG@1,3,5:', get_n_1(res, targets, mlb), get_n_3(res, targets, mlb), get_n_5(res, targets, mlb))
+
+if __name__ == '__main__':
+    main()
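
The rewritten script now also dumps each document's top-5 predicted labels to predictions.txt before computing metrics. A typical invocation, with file paths assumed rather than taken from the repository (the -r file is the {name}-labels.npy array written by output_res; -t is the ground-truth test label file):

# Assumed paths; adjust to wherever output_res wrote the results.
python evaluation.py -r MAG/results/MATCH-MAG-labels.npy -t MAG/test_labels.npy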
29 changes: 14 additions & 15 deletions joint/PreprocessMAG.py → joint/Preprocess.py
@@ -1,43 +1,42 @@
 import json
 from collections import defaultdict
+import argparse

-folder = '/shared/data2/yuz9/MATCH/MAG_data/'
-tot = 705425
-thrs = 5
+parser = argparse.ArgumentParser(description='main', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--dataset', default='MAG', choices=['MAG', 'MeSH'])
+
+args = parser.parse_args()
+folder = '../'+args.dataset+'/'
+
+thrs = 5
 left = set()
 right = set()

 node2cnt = defaultdict(int)
-with open(folder+'MAG_CS.json') as fin:
+with open(folder+'train.json') as fin:
     for idx, line in enumerate(fin):
         if idx % 10000 == 0:
             print(idx)
-        if idx >= tot * 0.8:
-            continue

         js = json.loads(line)

         for W in js['text'].split():
             node2cnt[W] += 1

         for A0 in js['author']:
             A = 'AUTHOR_' + A0
             node2cnt[A] += 1

-with open(folder+'MAG_CS.json') as fin, open('mag/network.dat', 'w') as fout:
+with open(folder+'train.json') as fin, open('network.dat', 'w') as fout:
     for idx, line in enumerate(fin):
         if idx % 10000 == 0:
             print(idx)
-        if idx >= tot * 0.8:
-            continue

         js = json.loads(line)

         P = 'PAPER_'+js['paper']
         left.add(P)

         # P-L
-        for L0 in js['fos']:
+        for L0 in js['label']:
             L = 'LABEL_' + L0
             fout.write(P+' '+L+' 0 1 \n')
             right.add(L)
@@ -74,14 +73,14 @@
                 continue
             for j in range(i-5, i+6):
                 if j < 0 or j >= len(words) or j == i:
-                    break
+                    continue
                 Wj = words[j]
                 if node2cnt[Wj] < thrs:
                     continue
                 fout.write(Wj+' '+Wi+' 5 1 \n')
                 left.add(Wj)

-with open('mag/left.dat', 'w') as fou1, open('mag/right.dat', 'w') as fou2:
+with open('left.dat', 'w') as fou1, open('right.dat', 'w') as fou2:
     for x in left:
         fou1.write(x+'\n')
     for x in right:
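
With the hard-coded /shared/... folder and the tot * 0.8 split logic gone, the renamed script reads the pre-split train.json of whichever dataset is requested and writes network.dat, left.dat, and right.dat to the working directory. Assuming it is run from inside joint/ (implied by the '../' + dataset folder computation, but still an assumption):

cd joint
python Preprocess.py --dataset MAG    # reads ../MAG/train.json
python Preprocess.py --dataset MeSH   # reads ../MeSH/train.json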
