Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
CjangCjengh committed Aug 12, 2023
1 parent 574d563 commit 4112c2f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
9 changes: 2 additions & 7 deletions ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def __init__(self, settings):
super().__init__()
self.translate_func = []
self.max_text_length = 0
self.beam_size = 3
self.device = "cpu"
self.beam_size = int(settings.value('beam_size')) if settings.contains('beam_size') else 3
self.device = settings.value('device') if settings.contains('device') else 'cpu'
self.settings = settings
self.translator = None
self.init_ui()
Expand Down Expand Up @@ -370,11 +370,6 @@ def init_font(widget, label):
font.setFamily(self.settings.value(f'{label}_font'))
widget.setFont(font)

if self.settings.contains('device'):
self.device = self.settings.value('device')
if self.settings.contains('beam_size'):
self.beam_size = int(self.settings.value('beam_size'))

init_font(QApplication, 'global')
init_font(self.original_text_edit, 'original')
init_font(self.translated_text_edit, 'translated')
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def translate(self, text, beam_size=3, device='cpu'):
text = self.cleaner(text)
src_tokens = torch.LongTensor([[bos_idx] + self.encode(text) + [eos_idx]])
src_mask = (src_tokens != pad_idx).unsqueeze(-2)
results, _ = beam_search(self.model.to(device), src_tokens, src_mask, self.config['max_len'],
results, _ = beam_search(self.model.to(device), src_tokens.to(device), src_mask.to(device), self.config['max_len'],
pad_idx, bos_idx, eos_idx, beam_size, device, self.is_terminated)
if results is None:
return None
Expand Down

0 comments on commit 4112c2f

Please sign in to comment.