Merge pull request CjangCjengh#9 from Kern-steins/main
add batch function
CjangCjengh authored Sep 5, 2023
2 parents 3f69c8d + 3153b30 commit 252b7cf
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions utils.py
@@ -6,6 +6,7 @@
 import glob, os, shutil
 from model import init_model
 from beam_decoder import beam_search
+import torch.nn.utils.rnn as rnn_utils
 
 class Translator:
     def __init__(self, model_dir, device='cpu'):
@@ -45,28 +46,38 @@ def translate(self, text, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
         bos_idx = self.config['bos_idx']
         eos_idx = self.config['eos_idx']
         pad_idx = self.config['pad_idx']
 
+        if isinstance(text, str):
+            text = [text]
         if self.input_cleaners is not None:
             for c in self.input_cleaners:
-                text = c(text)
+                text = [c(text_single) for text_single in text]
 
         if input_cleaner:
             text = getattr(cleaner, input_cleaner)(text)
-        src_tokens = torch.LongTensor([[bos_idx] + self.encode(text) + [eos_idx]])
-        src_mask = (src_tokens != pad_idx).unsqueeze(-2)
-        results, _ = beam_search(self.model.to(device), src_tokens.to(device), src_mask.to(device), self.config['max_len'][1],
+
+        src_tokens = rnn_utils.pad_sequence((torch.LongTensor([bos_idx] + self.encode(t) + [eos_idx]) for t in text),
+                                            batch_first=True, padding_value=pad_idx).to(device)
+        src_mask = (src_tokens != pad_idx).unsqueeze(-2).to(device)
+
+        results, _ = beam_search(self.model.to(device), src_tokens, src_mask, self.config['max_len'][1],
                                  pad_idx, bos_idx, eos_idx, beam_size, device, self.is_terminated)
         if results is None:
             return None
-        texts = []
-        for result in results[0]:
-            index_of_eos = result.index(2) if 2 in result else len(result)
-            result = result[:index_of_eos + 1]
-            text = self.decode(result)
-            for c in self.output_cleaners:
-                text = c(text)
-            if output_cleaner:
-                text = getattr(cleaner, output_cleaner)(text)
-            texts.append(text)
-        return texts
+        texts_last = []
+        for result_idx in results:
+            texts = []
+            for result in result_idx:
+                index_of_eos = result.index(2) if 2 in result else len(result)
+                result = result[:index_of_eos + 1]
+                text = self.decode(result)
+                for c in self.output_cleaners:
+                    text = c(text)
+                if output_cleaner:
+                    text = getattr(cleaner, output_cleaner)(text)
+                texts.append(text)
+            texts_last.append(texts)
+        return texts_last
 
     def translate_txt(self, file, output, beam_size=3, device='cpu', input_cleaner=None, output_cleaner=None):
         def translate_and_write(text):
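Below is a minimal, standalone sketch of the padding scheme the new code relies on: torch.nn.utils.rnn.pad_sequence stacks variable-length token sequences into a single (batch, max_len) LongTensor, and the padding mask is derived from it as in the diff. The special-token ids (pad_idx=0, bos_idx=1, eos_idx=2) and the encoded sequences are assumptions for illustration only, not values taken from this repository's config.

import torch
import torch.nn.utils.rnn as rnn_utils

pad_idx, bos_idx, eos_idx = 0, 1, 2   # assumed special-token ids
encoded = [[5, 6, 7], [8, 9]]         # stand-in for self.encode() output of two sentences

# pad the variable-length sequences into one (batch, max_len) tensor
src_tokens = rnn_utils.pad_sequence(
    [torch.LongTensor([bos_idx] + ids + [eos_idx]) for ids in encoded],
    batch_first=True, padding_value=pad_idx)

# padding mask with a broadcastable middle dimension, as in the diff
src_mask = (src_tokens != pad_idx).unsqueeze(-2)

print(src_tokens)       # tensor([[1, 5, 6, 7, 2],
                        #         [1, 8, 9, 2, 0]])
print(src_mask.shape)   # torch.Size([2, 1, 5])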

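The batch support also changes the shape of translate()'s result: a plain string is wrapped into a one-element batch, a list of strings is translated as one padded batch, and the return value is a list with one entry per input sentence, each holding that sentence's decoded beam candidates. A hypothetical call pattern (the model directory path and the example sentences are placeholders, not taken from the repository):

translator = Translator('path/to/model_dir', device='cpu')   # placeholder model directory

single = translator.translate('first sentence', beam_size=3)          # str -> one-element batch
batch = translator.translate(['first sentence', 'second sentence'])   # list -> padded batch

# one list of beam candidates per input sentence (or None if decoding failed)
for candidates in batch:
    print(candidates)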