Skip to content

Commit

Permalink
support running QuantiDCE on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
James-Yip committed Oct 25, 2021
1 parent cf2e842 commit 433bea8
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions model/bert_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,21 @@ def __init__(self, args):
nn.Linear(mlp_hidden_size_2, 1),
nn.Sigmoid())

self.device = torch.device("cuda:{}".format(args.gpu))
self.to(self.device)
if args.gpu:
self.device = torch.device("cuda:{}".format(args.gpu))
self.to(self.device)
map_location = 'cuda:{}'.format(args.gpu)
else:
self.device = None
map_location = None

if hasattr(args, 'checkpoint_file_name'):
# loads checkpoint
checkpoint_file_path = os.path.join(
args.checkpoint_dir_path, args.checkpoint_file_name)
state_dict = torch.load(
checkpoint_file_path,
map_location='cuda:{}'.format(args.gpu))
map_location=map_location)
self.load_state_dict(state_dict)
print('loading checkpoint from: {}'.format(checkpoint_file_path))

Expand All @@ -54,7 +59,7 @@ def __init__(self, args):
args.pretrain_checkpoint_file_name)
state_dict = torch.load(
checkpoint_file_path,
map_location='cuda:{}'.format(args.gpu))
map_location=map_location)
self.load_state_dict(state_dict)
print('loading checkpoint from: {}'.format(checkpoint_file_path))

Expand All @@ -74,10 +79,10 @@ def forward(self, input_ids, token_type_ids, attention_mask):
return output_dict, score

@torch.no_grad()
def get_score(self, sample: dict):
def get_score(self, context: List[str], response: str):
self.eval()
input_ids, token_type_ids, attention_mask = self.encode_ctx_res_pair(
sample['context'], sample['hyp_response'])
context, response)
_, score = self.forward(input_ids, token_type_ids, attention_mask)
return score[0].item()

Expand Down

0 comments on commit 433bea8

Please sign in to comment.