diff --git a/model/bert_metric.py b/model/bert_metric.py
index 8d12c01..f082f1a 100644
--- a/model/bert_metric.py
+++ b/model/bert_metric.py
@@ -34,8 +34,13 @@ def __init__(self, args):
             nn.Linear(mlp_hidden_size_2, 1),
             nn.Sigmoid())
 
-        self.device = torch.device("cuda:{}".format(args.gpu))
-        self.to(self.device)
+        if args.gpu is not None:
+            self.device = torch.device("cuda:{}".format(args.gpu))
+            self.to(self.device)
+            map_location = 'cuda:{}'.format(args.gpu)
+        else:
+            self.device = torch.device('cpu')
+            map_location = 'cpu'
 
         if hasattr(args, 'checkpoint_file_name'):
             # loads checkpoint
@@ -43,7 +48,7 @@ def __init__(self, args):
                 args.checkpoint_dir_path, args.checkpoint_file_name)
             state_dict = torch.load(
                 checkpoint_file_path,
-                map_location='cuda:{}'.format(args.gpu))
+                map_location=map_location)
             self.load_state_dict(state_dict)
             print('loading checkpoint from: {}'.format(checkpoint_file_path))
 
@@ -54,7 +59,7 @@ def __init__(self, args):
                 args.pretrain_checkpoint_file_name)
             state_dict = torch.load(
                 checkpoint_file_path,
-                map_location='cuda:{}'.format(args.gpu))
+                map_location=map_location)
             self.load_state_dict(state_dict)
             print('loading checkpoint from: {}'.format(checkpoint_file_path))
 
@@ -74,10 +79,10 @@ def forward(self, input_ids, token_type_ids, attention_mask):
         return output_dict, score
 
     @torch.no_grad()
-    def get_score(self, sample: dict):
+    def get_score(self, context: List[str], response: str):
         self.eval()
         input_ids, token_type_ids, attention_mask = self.encode_ctx_res_pair(
-            sample['context'], sample['hyp_response'])
+            context, response)
         _, score = self.forward(input_ids, token_type_ids, attention_mask)
         return score[0].item()
 
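For reviewers, a minimal usage sketch of the revised API: the optional-GPU construction path and the new `get_score(context, response)` signature. The class name `BertMetric`, the `args` fields, the checkpoint paths, and the example dialogue are assumptions for illustration inferred from `model/bert_metric.py`; they are not confirmed by this diff.

```python
# Usage sketch (not part of the diff); names and paths below are assumed.
from argparse import Namespace

from model.bert_metric import BertMetric  # assumed class name in this module

args = Namespace(
    gpu=None,  # None now falls back to CPU; an int such as 0 selects cuda:0
    checkpoint_dir_path='checkpoints',        # illustrative path
    checkpoint_file_name='bert_metric.ckpt',  # illustrative file name
    # ... plus whatever other fields __init__ expects (hidden sizes, etc.)
)
metric = BertMetric(args)

# get_score now takes the context turns and the candidate response directly,
# replacing the old {'context': ..., 'hyp_response': ...} dict argument.
score = metric.get_score(
    context=['Hi, how are you?', 'Doing well, thanks!'],
    response='Glad to hear it.')
print(score)  # a scalar in (0, 1) from the final Sigmoid
```

Routing `map_location` through the same branch that picks the device keeps `torch.load` from trying to materialize CUDA tensors on a CPU-only machine, which is what `map_location='cuda:{}'.format(args.gpu)` would previously have attempted regardless of hardware.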