Skip to content

Commit

Permalink
add translate_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
CjangCjengh committed Sep 6, 2023
1 parent 0d2f1c2 commit 2f682ff
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,13 @@ def terminate(self):
self._is_terminated = True

def translate(self, text, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
return self.translate_batch([text], beam_size, device, input_cleaner, output_cleaner)[0]

def translate_batch(self, text, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
bos_idx = self.config['bos_idx']
eos_idx = self.config['eos_idx']
pad_idx = self.config['pad_idx']

if isinstance(text, str):
text = [text]

if self.input_cleaners is not None:
for c in self.input_cleaners:
text = [c(text_single) for text_single in text]
Expand Down Expand Up @@ -77,8 +78,6 @@ def translate(self, text, beam_size=3, device='cpu', input_cleaner=None, output_
text = getattr(cleaner, output_cleaner)(text)
texts.append(text)
texts_last.append(texts)
if len(texts_last) == 1:
return texts_last[0]
return texts_last

def translate_txt(self, file, output, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
Expand Down

0 comments on commit 2f682ff

Please sign in to comment.