From 5302fdcc7ca8f9aedae22ebcda13cf977ac038af Mon Sep 17 00:00:00 2001 From: snsd0805 Date: Wed, 29 Mar 2023 02:35:44 +0800 Subject: [PATCH] fix: add eng to tokens function --- predict.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/predict.py b/predict.py index 4a08396..adbe680 100644 --- a/predict.py +++ b/predict.py @@ -22,6 +22,25 @@ SHOW_NUM = 5 NUM_HEADS = 8 DROPOUT_RATE = 0.5 +def en2tokens(en_sentence, en_vocab, for_model=False, en_seq=50): + ''' + English to tokens + + Args: + en_sentence (str) + en_vocab (torchtext.Vocab) + + for_model (bool, default=False): if `True`, it will add , , tokens + en_seq (int): for padding + Outputs: + tokens (LongTensor): (b,) + ''' + tokenizer = torchtext.data.utils.get_tokenizer("basic_english") + tokens = en_vocab( tokenizer(en_sentence.lower()) ) + if for_model: + tokens = [ en_vocab[''] ] + tokens + [ en_vocab[''] ] + tokens = tokens + [ en_vocab[''] for _ in range(en_seq - len(tokens)) ] + return torch.LongTensor(tokens) def predict(en_str, model, en_vocab, ch_vocab):