
Commit

compatible with PyTorch 1.0
jiesutd committed Jan 7, 2019
1 parent 44715fb commit 6017bdd
Showing 13 changed files with 1,648 additions and 122 deletions.
53 changes: 53 additions & 0 deletions demo.clf.config
@@ -0,0 +1,53 @@
### use # to comment out a configuration item

sentence_classification=True

### I/O ###
train_dir=sample_data/sentence_classification_train.txt
dev_dir=sample_data/sentence_classification_dev.txt
test_dir=sample_data/sentence_classification_test.txt
model_dir=sample_data/clf
word_emb_dir=sample_data/sample.word.emb

#raw_dir=
#decode_dir=
#dset_dir=
#load_model_dir=
#char_emb_dir=

norm_word_emb=False
norm_char_emb=False
number_normalized=True
seg=False
word_emb_dim=50
char_emb_dim=30

###NetworkConfiguration###
use_crf=True
use_char=True
word_seq_feature=CNN
char_seq_feature=CNN
#feature=[POS] emb_size=20
#feature=[Cap] emb_size=20
#nbest=1

###TrainingSetting###
status=train
optimizer=SGD
iteration=1
batch_size=10
ave_batch_loss=False

###Hyperparameters###
cnn_layer=4
char_hidden_dim=50
hidden_dim=200
dropout=0.5
lstm_layer=1
bilstm=True
learning_rate=0.015
lr_decay=0.05
momentum=0
l2=1e-8
#gpu
#clip=
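
The file above is a plain key=value configuration, with '#' commenting out an item, as its first line notes. A minimal parsing sketch (assuming simple line-by-line handling; this is not NCRF++'s actual loader in utils/data.py):

import re

def read_simple_config(path):
    """Parse key=value lines; blank lines and lines starting with '#' are skipped."""
    config = {}
    with open(path) as fin:
        for line in fin:
            line = line.strip()
            if not line or line.startswith("#"):
                continue  # commented-out configuration item
            key, _, value = line.partition("=")
            config[key.strip()] = value.strip()
    return config

# e.g. read_simple_config("demo.clf.config")["sentence_classification"] == "True"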
162 changes: 134 additions & 28 deletions main.py
@@ -2,7 +2,7 @@
# @Author: Jie
# @Date: 2017-06-15 14:11:08
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2018-12-16 22:35:00
# @Last Modified time: 2019-01-01 23:58:38

from __future__ import print_function
import time
@@ -15,7 +15,8 @@
import torch.optim as optim
import numpy as np
from utils.metric import get_ner_fmeasure
from model.seqmodel import SeqModel
from model.seqlabel import SeqLabel
from model.sentclassifier import SentClassifier
from utils.data import Data

try:
@@ -38,7 +39,7 @@ def data_initialization(data):
data.fix_alphabet()


def predict_check(pred_variable, gold_variable, mask_variable):
def predict_check(pred_variable, gold_variable, mask_variable, sentence_classification=False):
"""
input:
pred_variable (batch_size, sent_len): pred tag result, in numpy format
@@ -49,37 +50,46 @@ def predict_check(pred_variable, gold_variable, mask_variable):
gold = gold_variable.cpu().data.numpy()
mask = mask_variable.cpu().data.numpy()
overlaped = (pred == gold)
right_token = np.sum(overlaped * mask)
total_token = mask.sum()
if sentence_classification:
right_token = np.sum(overlaped)
total_token = overlaped.shape[0] ## =batch_size
else:
right_token = np.sum(overlaped * mask)
total_token = mask.sum()
# print("right: %s, total: %s"%(right_token, total_token))
return right_token, total_token
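
The sentence_classification branch changes only the counting: sequence labeling counts matches on unmasked tokens, while sentence classification counts one prediction per sentence and ignores the mask. A toy illustration with made-up ids (not actual model output):

import numpy as np

pred = np.array([[1, 2, 0], [3, 0, 0]])
gold = np.array([[1, 2, 0], [4, 0, 0]])
mask = np.array([[1, 1, 0], [1, 0, 0]])  # 1 marks real tokens, 0 marks padding

# sequence labeling: matches counted only where mask == 1
right_token = np.sum((pred == gold) * mask)   # 2 (second sentence's only real token is wrong)
total_token = mask.sum()                      # 3

# sentence classification: one label per sentence, mask ignored
pred_s, gold_s = np.array([1, 3]), np.array([1, 4])
right_s = np.sum(pred_s == gold_s)            # 1
total_s = pred_s.shape[0]                     # 2 (= batch_size)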


def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, word_recover):
def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, word_recover, sentence_classification=False):
"""
input:
pred_variable (batch_size, sent_len): pred tag result
gold_variable (batch_size, sent_len): gold result variable
mask_variable (batch_size, sent_len): mask variable
"""

pred_variable = pred_variable[word_recover]
gold_variable = gold_variable[word_recover]
mask_variable = mask_variable[word_recover]
batch_size = gold_variable.size(0)
seq_len = gold_variable.size(1)
mask = mask_variable.cpu().data.numpy()
pred_tag = pred_variable.cpu().data.numpy()
gold_tag = gold_variable.cpu().data.numpy()
batch_size = mask.shape[0]
pred_label = []
gold_label = []
for idx in range(batch_size):
pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
assert(len(pred)==len(gold))
pred_label.append(pred)
gold_label.append(gold)
if sentence_classification:
pred_tag = pred_variable.cpu().data.numpy().tolist()
gold_tag = gold_variable.cpu().data.numpy().tolist()
pred_label = [label_alphabet.get_instance(pred) for pred in pred_tag]
gold_label = [label_alphabet.get_instance(gold) for gold in gold_tag]
else:
seq_len = gold_variable.size(1)
mask = mask_variable.cpu().data.numpy()
pred_tag = pred_variable.cpu().data.numpy()
gold_tag = gold_variable.cpu().data.numpy()
batch_size = mask.shape[0]
pred_label = []
gold_label = []
for idx in range(batch_size):
pred = [label_alphabet.get_instance(pred_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
gold = [label_alphabet.get_instance(gold_tag[idx][idy]) for idy in range(seq_len) if mask[idx][idy] != 0]
assert(len(pred)==len(gold))
pred_label.append(pred)
gold_label.append(gold)
return pred_label, gold_label
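
In both branches recover_label simply maps ids back to strings through label_alphabet.get_instance: one string per sentence for classification, one list of strings per sentence (padding dropped via the mask) for labeling. A hedged sketch with a stand-in dictionary in place of the real alphabet (ids and tag names are illustrative):

# stand-in for label_alphabet.get_instance
id2label = {1: "B-PER", 2: "O", 3: "positive", 4: "negative"}
get_instance = id2label.get

# sentence classification: one id per sentence -> one label string per sentence
pred_label_clf = [get_instance(p) for p in [3, 4]]        # ["positive", "negative"]

# sequence labeling: one id per token, padded positions removed via the mask
pred_ids, mask_row = [1, 2, 2, 0], [1, 1, 1, 0]
pred_label_seq = [get_instance(p) for p, m in zip(pred_ids, mask_row) if m != 0]  # ["B-PER", "O", "O"]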


@@ -178,7 +188,7 @@ def evaluate(data, model, name, nbest=None):
instance = instances[start:end]
if not instance:
continue
batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, False)
batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, False, data.sentence_classification)
if nbest:
scores, nbest_tag_seq = model.decode_nbest(batch_word,batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, mask, nbest)
nbest_pred_result = recover_nbest_label(nbest_tag_seq, mask, data.label_alphabet, batch_wordrecover)
@@ -189,7 +199,7 @@
else:
tag_seq = model(batch_word, batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, mask)
# print("tag:",tag_seq)
pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet, batch_wordrecover)
pred_label, gold_label = recover_label(tag_seq, batch_label, mask, data.label_alphabet, batch_wordrecover, data.sentence_classification)
pred_results += pred_label
gold_results += gold_label
decode_time = time.time() - start_time
@@ -200,14 +210,25 @@
return speed, acc, p, r, f, pred_results, pred_scores


def batchify_with_label(input_batch_list, gpu, if_train=True):
def batchify_with_label(input_batch_list, gpu, if_train=True, sentence_classification=False):
if sentence_classification:
return batchify_sentence_classification_with_label(input_batch_list, gpu, if_train)
else:
return batchify_sequence_labeling_with_label(input_batch_list, gpu, if_train)
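
Both helpers return the same 9-tuple, so the call sites in evaluate() and train() only need to thread the flag through. A call sketch (instance and data stand for the objects already used elsewhere in main.py):

# instance: a slice of data.train_Ids (or dev/test Ids); data: the global Data object
(batch_word, batch_features, batch_wordlen, batch_wordrecover,
 batch_char, batch_charlen, batch_charrecover, batch_label,
 mask) = batchify_with_label(instance, data.HP_gpu, True, data.sentence_classification)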


def batchify_sequence_labeling_with_label(input_batch_list, gpu, if_train=True):
"""
input: list of words, chars and labels, various length. [[words,chars, labels],[words,chars,labels],...]
input: list of words, features, chars and labels of various lengths. [[words, features, chars, labels],[words, features, chars, labels],...]
words: word ids for one sentence. (batch_size, sent_len)
features: features ids for one sentence. (batch_size, sent_len, feature_num)
chars: char ids for one sentence, various lengths. (batch_size, sent_len, each_word_length)
labels: label ids for one sentence. (batch_size, sent_len)
output:
zero padding for word and char, with their batch length
word_seq_tensor: (batch_size, max_sent_len) Variable
feature_seq_tensors: [(batch_size, max_sent_len),...] list of Variable
word_seq_lengths: (batch_size,1) Tensor
char_seq_tensor: (batch_size*max_sent_len, max_word_len) Variable
char_seq_lengths: (batch_size*max_sent_len,1) Tensor
@@ -274,13 +295,97 @@ def batchify_with_label(input_batch_list, gpu, if_train=True):
return word_seq_tensor,feature_seq_tensors, word_seq_lengths, word_seq_recover, char_seq_tensor, char_seq_lengths, char_seq_recover, label_seq_tensor, mask


def batchify_sentence_classification_with_label(input_batch_list, gpu, if_train=True):
"""
input: list of words, features, chars and labels of various lengths. [[words, features, chars, labels],[words, features, chars, labels],...]
words: word ids for one sentence. (batch_size, sent_len)
features: feature ids for one sentence. (batch_size, feature_num); each sentence has one set of features
chars: char ids for one sentence, various lengths. (batch_size, sent_len, each_word_length)
labels: label id for one sentence. (batch_size,); each sentence has one label
output:
zero padding for word and char, with their batch length
word_seq_tensor: (batch_size, max_sent_len) Variable
feature_seq_tensors: [(batch_size, max_sent_len), ...] list of Variable
word_seq_lengths: (batch_size,1) Tensor
char_seq_tensor: (batch_size*max_sent_len, max_word_len) Variable
char_seq_lengths: (batch_size*max_sent_len,1) Tensor
char_seq_recover: (batch_size*max_sent_len,1) recover char sequence order
label_seq_tensor: (batch_size, )
mask: (batch_size, max_sent_len)
"""

batch_size = len(input_batch_list)
words = [sent[0] for sent in input_batch_list]
features = [np.asarray(sent[1]) for sent in input_batch_list]
feature_num = len(features[0])
chars = [sent[2] for sent in input_batch_list]
labels = [sent[3] for sent in input_batch_list]
word_seq_lengths = torch.LongTensor(list(map(len, words)))
max_seq_len = word_seq_lengths.max().item()
word_seq_tensor = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).long()
label_seq_tensor = torch.zeros((batch_size, ), requires_grad = if_train).long()
feature_seq_tensors = []
for idx in range(feature_num):
feature_seq_tensors.append(torch.zeros((batch_size, max_seq_len),requires_grad = if_train).long())
mask = torch.zeros((batch_size, max_seq_len), requires_grad = if_train).byte()
label_seq_tensor = torch.LongTensor(labels)
# exit(0)
for idx, (seq, seqlen) in enumerate(zip(words, word_seq_lengths)):
seqlen = seqlen.item()
word_seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
mask[idx, :seqlen] = torch.Tensor([1]*seqlen)
for idy in range(feature_num):
feature_seq_tensors[idy][idx,:seqlen] = torch.LongTensor(features[idx][:,idy])
word_seq_lengths, word_perm_idx = word_seq_lengths.sort(0, descending=True)
word_seq_tensor = word_seq_tensor[word_perm_idx]
for idx in range(feature_num):
feature_seq_tensors[idx] = feature_seq_tensors[idx][word_perm_idx]
label_seq_tensor = label_seq_tensor[word_perm_idx]
mask = mask[word_perm_idx]
### deal with char
# pad_chars (batch_size, max_seq_len)
pad_chars = [chars[idx] + [[0]] * (max_seq_len-len(chars[idx])) for idx in range(len(chars))]
length_list = [list(map(len, pad_char)) for pad_char in pad_chars]
max_word_len = max(map(max, length_list))
char_seq_tensor = torch.zeros((batch_size, max_seq_len, max_word_len), requires_grad = if_train).long()
char_seq_lengths = torch.LongTensor(length_list)
for idx, (seq, seqlen) in enumerate(zip(pad_chars, char_seq_lengths)):
for idy, (word, wordlen) in enumerate(zip(seq, seqlen)):
# print len(word), wordlen
char_seq_tensor[idx, idy, :wordlen] = torch.LongTensor(word)

char_seq_tensor = char_seq_tensor[word_perm_idx].view(batch_size*max_seq_len,-1)
char_seq_lengths = char_seq_lengths[word_perm_idx].view(batch_size*max_seq_len,)
char_seq_lengths, char_perm_idx = char_seq_lengths.sort(0, descending=True)
char_seq_tensor = char_seq_tensor[char_perm_idx]
_, char_seq_recover = char_perm_idx.sort(0, descending=False)
_, word_seq_recover = word_perm_idx.sort(0, descending=False)
if gpu:
word_seq_tensor = word_seq_tensor.cuda()
for idx in range(feature_num):
feature_seq_tensors[idx] = feature_seq_tensors[idx].cuda()
word_seq_lengths = word_seq_lengths.cuda()
word_seq_recover = word_seq_recover.cuda()
label_seq_tensor = label_seq_tensor.cuda()
char_seq_tensor = char_seq_tensor.cuda()
char_seq_recover = char_seq_recover.cuda()
mask = mask.cuda()
return word_seq_tensor,feature_seq_tensors, word_seq_lengths, word_seq_recover, char_seq_tensor, char_seq_lengths, char_seq_recover, label_seq_tensor, mask
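
A small usage sketch of the new batching helper with made-up ids and no extra per-token features (so feature_num is 0 and the feature loop is skipped):

batch = [
    # [word ids, feature ids, char ids per word, sentence label id]
    [[4, 7, 2], [], [[3, 4], [5], [6, 7, 8]], 1],   # 3-word sentence
    [[5, 9],    [], [[2], [9, 9]],            0],   # 2-word sentence
]
(word_seq_tensor, feature_seq_tensors, word_seq_lengths, word_seq_recover,
 char_seq_tensor, char_seq_lengths, char_seq_recover, label_seq_tensor,
 mask) = batchify_sentence_classification_with_label(batch, gpu=False, if_train=False)
# word_seq_tensor: shape (2, 3), the shorter sentence zero-padded
# label_seq_tensor: tensor([1, 0]), one label id per sentence (already length-sorted here)
# char_seq_tensor: shape (2*3, 3), flattened to batch_size*max_sent_len words
# mask: shape (2, 3) byte tensor marking real tokens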




def train(data):
print("Training model...")
data.show_data_summary()
save_data_name = data.model_dir +".dset"
data.save(save_data_name)
model = SeqModel(data)
loss_function = nn.NLLLoss()
if data.sentence_classification:
model = SentClassifier(data)
else:
model = SeqLabel(data)
# loss_function = nn.NLLLoss()
if data.optimizer.lower() == "sgd":
optimizer = optim.SGD(model.parameters(), lr=data.HP_lr, momentum=data.HP_momentum,weight_decay=data.HP_l2)
elif data.optimizer.lower() == "adagrad":
@@ -325,10 +430,10 @@ def train(data):
instance = data.train_Ids[start:end]
if not instance:
continue
batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu)
batch_word, batch_features, batch_wordlen, batch_wordrecover, batch_char, batch_charlen, batch_charrecover, batch_label, mask = batchify_with_label(instance, data.HP_gpu, True, data.sentence_classification)
instance_count += 1
loss, tag_seq = model.neg_log_likelihood_loss(batch_word,batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, batch_label, mask)
right, whole = predict_check(tag_seq, batch_label, mask)
right, whole = predict_check(tag_seq, batch_label, mask, data.sentence_classification)
right_token += right
whole_token += whole
# print("loss:",loss.item())
@@ -426,6 +531,7 @@ def load_model_decode(data, name):
data = Data()
data.HP_gpu = torch.cuda.is_available()
data.read_config(args.config)
data.show_data_summary()
status = data.status.lower()
print("Seed num:",seed_num)

6 changes: 3 additions & 3 deletions main_parse.py
@@ -2,7 +2,7 @@
# @Author: Jie
# @Date: 2017-06-15 14:11:08
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2018-09-05 23:05:37
# @Last Modified time: 2019-01-01 21:09:38

from __future__ import print_function
import time
@@ -18,7 +18,7 @@
import torch.optim as optim
import numpy as np
from utils.metric import get_ner_fmeasure
from model.seqmodel import SeqModel
from model.seqlabel import SeqLabel
from utils.data import Data

try:
Expand Down Expand Up @@ -283,7 +283,7 @@ def train(data):
data.show_data_summary()
save_data_name = data.model_dir +".dset"
data.save(save_data_name)
model = SeqModel(data)
model = SeqLabel(data)
loss_function = nn.NLLLoss()
if data.optimizer.lower() == "sgd":
optimizer = optim.SGD(model.parameters(), lr=data.HP_lr, momentum=data.HP_momentum,weight_decay=data.HP_l2)
2 changes: 1 addition & 1 deletion model/charcnn.py
@@ -2,7 +2,7 @@
# @Author: Jie Yang
# @Date: 2017-10-17 16:47:32
# @Last Modified by: Jie Yang, Contact: [email protected]
# @Last Modified time: 2018-04-26 13:21:40
# @Last Modified time: 2019-01-02 00:31:54
from __future__ import print_function
import torch
import torch.nn as nn
(diffs for the remaining 9 changed files are not shown here)
