
Commit

Merge pull request jiesutd#32 from abbottLane/master
Changes for python3+ compatibility
jiesutd committed Jun 19, 2018
2 parents 94e704b + f13b2d6 commit a42dcd9
Showing 4 changed files with 74 additions and 52 deletions.
50 changes: 26 additions & 24 deletions main.py
@@ -12,7 +12,6 @@
import copy
import torch
import gc
import cPickle as pickle
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
@@ -22,6 +21,12 @@
from model.seqmodel import SeqModel
from utils.data import Data

try:
import cPickle as pickle
except ModuleNotFoundError:
import pickle as pickle


seed_num = 42
random.seed(seed_num)
torch.manual_seed(seed_num)
@@ -60,7 +65,7 @@ def recover_label(pred_variable, gold_variable, mask_variable, label_alphabet, w
gold_variable (batch_size, sent_len): gold result variable
mask_variable (batch_size, sent_len): mask variable
"""

pred_variable = pred_variable[word_recover]
gold_variable = gold_variable[word_recover]
mask_variable = mask_variable[word_recover]
@@ -169,7 +174,7 @@ def evaluate(data, model, name, nbest=None):
total_batch = train_num//batch_size+1
for batch_id in range(total_batch):
start = batch_id*batch_size
end = (batch_id+1)*batch_size
end = (batch_id+1)*batch_size
if end > train_num:
end = train_num
instance = instances[start:end]
@@ -179,7 +184,7 @@ def evaluate(data, model, name, nbest=None):
if nbest:
scores, nbest_tag_seq = model.decode_nbest(batch_word,batch_features, batch_wordlen, batch_char, batch_charlen, batch_charrecover, mask, nbest)
nbest_pred_result = recover_nbest_label(nbest_tag_seq, mask, data.label_alphabet, batch_wordrecover)
nbest_pred_results += nbest_pred_result
nbest_pred_results += nbest_pred_result
pred_scores += scores[batch_wordrecover].cpu().data.numpy().tolist()
## select the best sequence to evaluate
tag_seq = nbest_tag_seq[:,:,0]
@@ -200,25 +205,25 @@ def evaluate(data, model, name, nbest=None):
def batchify_with_label(input_batch_list, gpu, volatile_flag=False):
"""
input: list of words, chars and labels, various length. [[words,chars, labels],[words,chars,labels],...]
words: word ids for one sentence. (batch_size, sent_len)
words: word ids for one sentence. (batch_size, sent_len)
chars: char ids for one sentence, various length. (batch_size, sent_len, each_word_length)
output:
zero padding for word and char, with their batch length
word_seq_tensor: (batch_size, max_sent_len) Variable
word_seq_lengths: (batch_size,1) Tensor
char_seq_tensor: (batch_size*max_sent_len, max_word_len) Variable
char_seq_lengths: (batch_size*max_sent_len,1) Tensor
char_seq_recover: (batch_size*max_sent_len,1) recover char sequence order
char_seq_recover: (batch_size*max_sent_len,1) recover char sequence order
label_seq_tensor: (batch_size, max_sent_len)
mask: (batch_size, max_sent_len)
mask: (batch_size, max_sent_len)
"""
batch_size = len(input_batch_list)
words = [sent[0] for sent in input_batch_list]
features = [np.asarray(sent[1]) for sent in input_batch_list]
feature_num = len(features[0][0])
chars = [sent[2] for sent in input_batch_list]
labels = [sent[3] for sent in input_batch_list]
word_seq_lengths = torch.LongTensor(map(len, words))
word_seq_lengths = torch.LongTensor(list(map(len, words)))
max_seq_len = word_seq_lengths.max()
word_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len)), volatile = volatile_flag).long()
label_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len)),volatile = volatile_flag).long()
@@ -242,15 +247,15 @@ def batchify_with_label(input_batch_list, gpu, volatile_flag=False):
### deal with char
# pad_chars (batch_size, max_seq_len)
pad_chars = [chars[idx] + [[0]] * (max_seq_len-len(chars[idx])) for idx in range(len(chars))]
length_list = [map(len, pad_char) for pad_char in pad_chars]
length_list = [list(map(len, pad_char)) for pad_char in pad_chars]
max_word_len = max(map(max, length_list))
char_seq_tensor = autograd.Variable(torch.zeros((batch_size, max_seq_len, max_word_len)), volatile = volatile_flag).long()
char_seq_lengths = torch.LongTensor(length_list)
for idx, (seq, seqlen) in enumerate(zip(pad_chars, char_seq_lengths)):
for idy, (word, wordlen) in enumerate(zip(seq, seqlen)):
# print len(word), wordlen
char_seq_tensor[idx, idy, :wordlen] = torch.LongTensor(word)

char_seq_tensor = char_seq_tensor[word_perm_idx].view(batch_size*max_seq_len,-1)
char_seq_lengths = char_seq_lengths[word_perm_idx].view(batch_size*max_seq_len,)
char_seq_lengths, char_perm_idx = char_seq_lengths.sort(0, descending=True)
@@ -315,7 +320,7 @@ def train(data):
total_batch = train_num//batch_size+1
for batch_id in range(total_batch):
start = batch_id*batch_size
end = (batch_id+1)*batch_size
end = (batch_id+1)*batch_size
if end >train_num:
end = train_num
instance = data.train_Ids[start:end]
@@ -344,8 +349,8 @@ def train(data):
model.zero_grad()
temp_time = time.time()
temp_cost = temp_time - temp_start
print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f"%(end, temp_cost, sample_loss, right_token, whole_token,(right_token+0.)/whole_token))
print(" Instance: %s; Time: %.2fs; loss: %.4f; acc: %s/%s=%.4f"%(end, temp_cost, sample_loss, right_token, whole_token,(right_token+0.)/whole_token))

epoch_finish = time.time()
epoch_cost = epoch_finish - epoch_start
print("Epoch: %s training finished. Time: %.2fs, speed: %.2fst/s, total loss: %s"%(idx, epoch_cost, train_num/epoch_cost, total_loss))
@@ -373,7 +378,7 @@ def train(data):
model_name = data.model_dir +'.'+ str(idx) + ".model"
print("Save current best model in file:", model_name)
torch.save(model.state_dict(), model_name)
best_dev = current_score
best_dev = current_score
# ## decode test
speed, acc, p, r, f, _,_ = evaluate(data, model, "test")
test_finish = time.time()
@@ -382,7 +387,7 @@ def train(data):
print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f, p: %.4f, r: %.4f, f: %.4f"%(test_cost, speed, acc, p, r, f))
else:
print("Test: time: %.2fs, speed: %.2fst/s; acc: %.4f"%(test_cost, speed, acc))
gc.collect()
gc.collect()


def load_model_decode(data, name):
@@ -416,14 +421,14 @@ def load_model_decode(data, name):
parser = argparse.ArgumentParser(description='Tuning with NCRF++')
# parser.add_argument('--status', choices=['train', 'decode'], help='update algorithm', default='train')
parser.add_argument('--config', help='Configuration File' )

args = parser.parse_args()
data = Data()
data.read_config(args.config)
status = data.status.lower()
data.HP_gpu = torch.cuda.is_available()
print("Seed num:",seed_num)

if status == 'train':
print("MODEL: train")
data_initialization(data)
@@ -432,12 +437,12 @@ def load_model_decode(data, name):
data.generate_instance('test')
data.build_pretrain_emb()
train(data)
elif status == 'decode':
elif status == 'decode':
print("MODEL: decode")
data.load(data.dset_dir)
data.read_config(args.config)
data.load(data.dset_dir)
data.read_config(args.config)
print(data.raw_dir)
# exit(0)
# exit(0)
data.show_data_summary()
data.generate_instance('raw')
print("nbest: %s"%(data.nbest))
@@ -449,6 +454,3 @@ def load_model_decode(data, name):
else:
print("Invalid argument! Please use valid arguments! (train/test/decode)")




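The cPickle fallback added at the top of main.py is the standard two-and-three compatibility idiom. Below is a minimal, self-contained sketch of how it behaves; the demo dictionary is illustrative, and the sketch catches the broader ImportError (which also exists on Python 2 and covers Python 3.0-3.5), whereas the diff catches ModuleNotFoundError, available from Python 3.6 onward.

try:
    import cPickle as pickle   # Python 2: C-accelerated pickle implementation
except ImportError:            # Python 3: cPickle was folded into the stdlib pickle
    import pickle

obj = {"seed_num": 42}
blob = pickle.dumps(obj)            # serialize with whichever module was imported
assert pickle.loads(blob) == obj    # round-trips identically on either interpreter
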
8 changes: 6 additions & 2 deletions utils/alphabet.py
@@ -11,6 +11,7 @@
from __future__ import print_function
import json
import os
import sys


class Alphabet:
@@ -36,7 +37,7 @@ def clear(self, keep_growing=True):
# Index 0 is occupied by default, all else following.
self.default_index = 0
self.next_index = 1

def add(self, instance):
if instance not in self.instance2index:
self.instances.append(instance)
@@ -73,7 +74,10 @@ def size(self):
return len(self.instances) + 1

def iteritems(self):
return self.instance2index.iteritems()
if sys.version_info[0] < 3: # If using python3, dict item access uses different syntax
return self.instance2index.iteritems()
else:
return self.instance2index.items()

def enumerate_items(self, start=1):
if start < 1 or start >= self.size():
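The iteritems() change above branches on the interpreter version because Python 3 dictionaries expose items() (a view) instead of iteritems(). A standalone sketch of the same branch, using a hypothetical stand-in for Alphabet.instance2index:

import sys

instance2index = {"</s>": 1, "hello": 2}   # hypothetical stand-in for Alphabet.instance2index

if sys.version_info[0] < 3:
    pairs = instance2index.iteritems()     # Python 2: lazy iterator over (key, value) pairs
else:
    pairs = instance2index.items()         # Python 3: dict view, iterable the same way

for instance, index in pairs:
    print(instance, index)
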
51 changes: 29 additions & 22 deletions utils/data.py
@@ -9,7 +9,11 @@
import numpy as np
from .alphabet import Alphabet
from .functions import *
import cPickle as pickle

try:
import cPickle as pickle
except ModuleNotFoundError:
import pickle as pickle


START = "</s>"
@@ -34,21 +38,21 @@ def __init__(self):

self.label_alphabet = Alphabet('label',True)
self.tagScheme = "NoSeg" ## BMES/BIO

self.seg = True

### I/O
self.train_dir = None
self.dev_dir = None
self.test_dir = None
self.train_dir = None
self.dev_dir = None
self.test_dir = None
self.raw_dir = None

self.decode_dir = None
self.dset_dir = None ## data vocabulary related file
self.model_dir = None ## model save file
self.load_model_dir = None ## model load file

self.word_emb_dir = None
self.word_emb_dir = None
self.char_emb_dir = None
self.feature_emb_dirs = []

@@ -82,7 +86,7 @@ def __init__(self):
self.char_feature_extractor = "CNN" ## "LSTM"/"CNN"/"GRU"/None
self.use_crf = True
self.nbest = None

## Training
self.average_batch_loss = False
self.optimizer = "SGD" ## "SGD"/"AdaGrad"/"AdaDelta"/"RMSProp"/"Adam"
@@ -96,14 +100,14 @@ def __init__(self):
self.HP_dropout = 0.5
self.HP_lstm_layer = 1
self.HP_bilstm = True

self.HP_gpu = False
self.HP_lr = 0.015
self.HP_lr_decay = 0.05
self.HP_clip = None
self.HP_momentum = 0
self.HP_l2 = 1e-8

def show_data_summary(self):
print("++"*50)
print("DATA SUMMARY START:")
@@ -156,7 +160,7 @@ def show_data_summary(self):

print(" "+"++"*20)
print(" Hyperparameters:")

print(" Hyper lr: %s"%(self.HP_lr))
print(" Hyper lr_decay: %s"%(self.HP_lr_decay))
print(" Hyper HP_clip: %s"%(self.HP_clip))
@@ -166,7 +170,7 @@ def show_data_summary(self):
print(" Hyper dropout: %s"%(self.HP_dropout))
print(" Hyper lstm_layer: %s"%(self.HP_lstm_layer))
print(" Hyper bilstm: %s"%(self.HP_bilstm))
print(" Hyper GPU: %s"%(self.HP_gpu))
print(" Hyper GPU: %s"%(self.HP_gpu))
print("DATA SUMMARY END.")
print("++"*50)
sys.stdout.flush()
@@ -180,11 +184,11 @@ def initial_feature_alphabets(self):
feature_prefix = items[idx].split(']',1)[0]+"]"
self.feature_alphabets.append(Alphabet(feature_prefix))
self.feature_name.append(feature_prefix)
print("Find feature: ", feature_prefix)
print("Find feature: ", feature_prefix)
self.feature_num = len(self.feature_alphabets)
self.pretrain_feature_embeddings = [None]*self.feature_num
self.feature_emb_dims = [20]*self.feature_num
self.feature_emb_dirs = [None]*self.feature_num
self.feature_emb_dirs = [None]*self.feature_num
self.norm_feature_embs = [False]*self.feature_num
self.feature_alphabet_sizes = [0]*self.feature_num
if self.feat_config:
@@ -201,13 +205,13 @@ def build_alphabet(self, input_file):
for line in in_lines:
if len(line) > 2:
pairs = line.strip().split()
word = pairs[0].decode('utf-8')
word = pairs[0]
if self.number_normalized:
word = normalize_word(word)
label = pairs[-1]
self.label_alphabet.add(label)
self.word_alphabet.add(word)
## build feature alphabet
## build feature alphabet
for idx in range(self.feature_num):
feat_idx = pairs[idx+1].split(']',1)[-1]
self.feature_alphabets[idx].add(feat_idx)
@@ -235,9 +239,9 @@ def build_alphabet(self, input_file):
def fix_alphabet(self):
self.word_alphabet.close()
self.char_alphabet.close()
self.label_alphabet.close()
self.label_alphabet.close()
for idx in range(self.feature_num):
self.feature_alphabets[idx].close()
self.feature_alphabets[idx].close()


def build_pretrain_emb(self):
@@ -332,7 +336,7 @@ def write_nbest_decoded_results(self, predict_results, pred_scores, name):
fout.write(score_string.strip() + "\n")

for idy in range(sent_length):
label_string = content_list[idx][0][idy].encode('utf-8') + " "
try: # Will fail with python3
label_string = content_list[idx][0][idy].encode('utf-8') + " "
except:
label_string = content_list[idx][0][idy] + " "
for idz in range(nbest):
label_string += predict_results[idx][idz][idy]+" "
label_string = label_string.strip() + "\n"
@@ -425,7 +432,7 @@ def read_config(self,config_file):

the_item = 'feature'
if the_item in config:
self.feat_config = config[the_item] ## feat_config is a dict
self.feat_config = config[the_item] ## feat_config is a dict



@@ -503,7 +510,7 @@ def config_file_to_dict(input_file):
if item=="feature":
if item not in config:
feat_dict = {}
config[item]= feat_dict
config[item]= feat_dict
feat_dict = config[item]
new_pair = pair[-1].split()
feat_name = new_pair[0]
@@ -525,12 +532,12 @@ def config_file_to_dict(input_file):
else:
if item in config:
print("Warning: duplicated config item found: %s, updated."%(pair[0]))
config[item] = pair[-1]
config[item] = pair[-1]
return config


def str2bool(string):
if string == "True" or string == "true" or string == "TRUE":
return True
return True
else:
return False
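
The try/except added around encode('utf-8') in write_nbest_decoded_results works because str.encode() returns bytes on Python 3, and concatenating bytes with str fails. A minimal sketch with an illustrative label value; the diff uses a bare except, while the sketch narrows it to the TypeError actually raised on Python 3:

label = "B-PER"                                 # illustrative label, not repository data
try:
    label_string = label.encode('utf-8') + " "  # Python 2: str + str, succeeds
except TypeError:                               # Python 3: bytes + str raises TypeError
    label_string = label + " "
print(label_string.strip())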
