Skip to content

Commit

Permalink
change pkl to pt
Browse files Browse the repository at this point in the history
  • Loading branch information
ines-chami committed May 4, 2020
1 parent d9b3d00 commit f846be3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def train(args):
counter = 0
best_epoch = step
logging.info("\t Saving model at epoch {} in {}".format(step, save_dir))
torch.save(model.cpu().state_dict(), os.path.join(save_dir, "model.pkl"))
torch.save(model.cpu().state_dict(), os.path.join(save_dir, "model.pt"))
model.cuda()
else:
counter += 1
Expand All @@ -171,10 +171,10 @@ def train(args):

logging.info("\t Optimization finished")
if not best_mrr:
torch.save(model.cpu().state_dict(), os.path.join(save_dir, "model.pkl"))
torch.save(model.cpu().state_dict(), os.path.join(save_dir, "model.pt"))
else:
logging.info("\t Loading best model saved at epoch {}".format(best_epoch))
model.load_state_dict(torch.load(os.path.join(save_dir, "model.pkl")))
model.load_state_dict(torch.load(os.path.join(save_dir, "model.pt")))
model.cuda()
model.eval()

Expand Down
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test(model_dir):
model = getattr(models, args.model)(args)
device = 'cuda'
model.to(device)
model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pkl')))
model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt')))

# eval
test_metrics = avg_both(*model.compute_metrics(test_examples, filters))
Expand Down

0 comments on commit f846be3

Please sign in to comment.