Transformer-Translator/predict.py

import torch
import torch.nn as nn
import torchtext
import numpy
from utils.Vocab import get_vocabs
from utils.Dataset import TranslateDataset
from models.Transformer import Transformer
from models.Translator import TranslatorModel
import seaborn as sns
import matplotlib.pyplot as plt
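
# Hyper-parameters; these must match the settings used to train 'translate_model.pth',
# since that checkpoint's weights are loaded into a model built from these values.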
BATCH_SIZE = 128
EPOCH_NUM = 100
LEARNING_RATE = 1e-4
ENGLISH_SEQ = 50
CHINESE_SEQ = 40
WORD_EMB_DIM = 100
DIM_IN_TRANSFORMER = 256
FFN_HIDDEN_DIM = 512
DEVICE = torch.device('cuda')  # a CUDA-capable GPU is assumed to be available
SHOW_NUM = 5
NUM_HEADS = 8
DROPOUT_RATE = 0.5

def en2tokens(en_sentence, en_vocab, for_model=False, en_seq=50):
    '''
    Convert an English sentence into a sequence of vocabulary indices.
    Args:
        en_sentence (str)
        en_vocab (torchtext.vocab.Vocab)
        for_model (bool, default=False): if `True`, wrap the tokens in <SOS>/<END> and pad with <PAD> up to `en_seq`
        en_seq (int): target length when padding with <PAD>
    Outputs:
        tokens (LongTensor): shape (en_seq,) when `for_model=True`, otherwise (number of tokens,)
    '''
    tokenizer = torchtext.data.utils.get_tokenizer("basic_english")
    tokens = en_vocab(tokenizer(en_sentence.lower()))
    if for_model:
        tokens = [en_vocab['<SOS>']] + tokens + [en_vocab['<END>']]
        tokens = tokens + [en_vocab['<PAD>'] for _ in range(en_seq - len(tokens))]
    return torch.LongTensor(tokens)
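
# Illustrative usage of en2tokens: the toy vocabulary below is hypothetical (the real
# vocabularies come from utils.Vocab.get_vocabs()), but it shows the <SOS>/<END>/<PAD>
# layout that `for_model=True` produces.
#
#   from torchtext.vocab import build_vocab_from_iterator
#   toy_vocab = build_vocab_from_iterator(
#       [['i', 'love', 'you']], specials=['<PAD>', '<SOS>', '<END>', '<UNK>'])
#   toy_vocab.set_default_index(toy_vocab['<UNK>'])
#   en2tokens('I love you', toy_vocab, for_model=True, en_seq=8)
#   # -> LongTensor of length 8: [<SOS>, i, love, you, <END>, <PAD>, <PAD>, <PAD>]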

def predict(en_str, model, en_vocab, ch_vocab):
    '''
    Greedily translate `en_str` into Chinese one token per step, collecting the
    model's attention scores over the English tokens for later visualization.
    Outputs:
        chinese_words, english_words, english_len, chinese_len, att
        att (Tensor): (CHINESE_SEQ, num_heads, ENGLISH_SEQ) attention scores
    '''
    en_tokens = en2tokens(en_str, en_vocab, for_model=True, en_seq=ENGLISH_SEQ)
    en_tokens = en_tokens.unsqueeze(0).to(DEVICE)
    # Start from an all-<PAD> target sequence with <SOS> in the first position
    ch_tokens = torch.LongTensor([ch_vocab['<PAD>'] for _ in range(CHINESE_SEQ)]).unsqueeze(0).to(DEVICE)
    ch_tokens[0][0] = torch.tensor(ch_vocab['<SOS>'])
    model.eval()
    att = []
    with torch.no_grad():
        for index in range(0, CHINESE_SEQ):
            predict, score = model(en_tokens, ch_tokens)  # predict: (b, seq, vocab); score: (b, #head, CHINESE_SEQ, ENGLISH_SEQ)
            predict = torch.argmax(predict, dim=2)  # (b, seq)
            att.append(score[0, :, index, :].unsqueeze(0))
            if index != (CHINESE_SEQ - 1):
                # Feed the token predicted at this position back in as the next decoder input
                ch_tokens[0][index + 1] = predict[0][index]
    att = torch.cat(att, dim=0)  # (CHINESE_SEQ, #head, ENGLISH_SEQ)
    english_words = en_vocab.lookup_tokens(en_tokens[0].tolist())
    chinese_words = ch_vocab.lookup_tokens(ch_tokens[0].tolist())
    # Effective lengths: count tokens up to and including the first <END>
    english_len, chinese_len = 0, 0
    for i in english_words:
        english_len += 1
        if i == '<END>':
            break
    for i in chinese_words:
        chinese_len += 1
        if i == '<END>':
            break
    return chinese_words, english_words, english_len, chinese_len, att

if __name__ == '__main__':
    # load tokenizer & vocabs
    tokenizer = torchtext.data.utils.get_tokenizer("basic_english")
    en_vocab, ch_vocab = get_vocabs()
    print("English Vocab size: {}\nChinese Vocab size: {}\n".format(len(en_vocab), len(ch_vocab)))

    # load model
    model = TranslatorModel(
        num_emb_en=len(en_vocab),
        num_emb_ch=len(ch_vocab),
        emb_dim=WORD_EMB_DIM,
        en_vocab=en_vocab,
        dim_in_transformer=DIM_IN_TRANSFORMER,
        ffn_hidden_dim=FFN_HIDDEN_DIM,
        en_seq=ENGLISH_SEQ,
        ch_seq=CHINESE_SEQ,
        device=DEVICE,
        num_heads=NUM_HEADS,
        dropout_rate=DROPOUT_RATE
    ).to(DEVICE)
    model.load_state_dict(torch.load('translate_model.pth', map_location=DEVICE))

    while True:
        s = input("English: ")
        chinese_words, english_words, english_len, chinese_len, att = predict(s, model, en_vocab, ch_vocab)
        att = att.to('cpu').numpy()
        print("Chinese: {}".format(chinese_words))

        # One heatmap per attention head on a 4 x 2 grid: rows are the generated
        # Chinese tokens, columns the English tokens being attended to.
        plt.rcParams['font.sans-serif'] = ['Taipei Sans TC Beta']  # a font that can render the Chinese tick labels
        for i in range(NUM_HEADS):
            mapping = att[:chinese_len, i, :english_len]
            plt.subplot(421 + i)
            sns.heatmap(mapping, xticklabels=english_words[:english_len], yticklabels=chinese_words[:chinese_len])
        plt.show()
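        # Note: the 4 x 2 grid (plt.subplot(421 + i)) assumes NUM_HEADS == 8; a
        # different head count would need a different subplot layout.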