Compare commits
2 Commits
25cfa08b27
...
dfdc370b8e
| Author | SHA1 | Date | |
|---|---|---|---|
| dfdc370b8e | |||
| 5302fdcc7c |
5
README.md
Normal file
5
README.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# Transformer-based Translator
|
||||||
|
|
||||||
|
Simple Transformer
|
||||||
|
|
||||||
|
! [](image/Screenshot_20230329_023305.png)
|
||||||
BIN
image/Screenshot_20230329_023305.png
Normal file
BIN
image/Screenshot_20230329_023305.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 67 KiB |
19
predict.py
19
predict.py
@ -22,6 +22,25 @@ SHOW_NUM = 5
|
|||||||
NUM_HEADS = 8
|
NUM_HEADS = 8
|
||||||
DROPOUT_RATE = 0.5
|
DROPOUT_RATE = 0.5
|
||||||
|
|
||||||
|
def en2tokens(en_sentence, en_vocab, for_model=False, en_seq=50):
|
||||||
|
'''
|
||||||
|
English to tokens
|
||||||
|
|
||||||
|
Args:
|
||||||
|
en_sentence (str)
|
||||||
|
en_vocab (torchtext.Vocab)
|
||||||
|
|
||||||
|
for_model (bool, default=False): if `True`, it will add <SOS>, <END>, <PAD> tokens
|
||||||
|
en_seq (int): for padding <PAD>
|
||||||
|
Outputs:
|
||||||
|
tokens (LongTensor): (b,)
|
||||||
|
'''
|
||||||
|
tokenizer = torchtext.data.utils.get_tokenizer("basic_english")
|
||||||
|
tokens = en_vocab( tokenizer(en_sentence.lower()) )
|
||||||
|
if for_model:
|
||||||
|
tokens = [ en_vocab['<SOS>'] ] + tokens + [ en_vocab['<END>'] ]
|
||||||
|
tokens = tokens + [ en_vocab['<PAD>'] for _ in range(en_seq - len(tokens)) ]
|
||||||
|
return torch.LongTensor(tokens)
|
||||||
|
|
||||||
def predict(en_str, model, en_vocab, ch_vocab):
|
def predict(en_str, model, en_vocab, ch_vocab):
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user