import os

import torch
import torch.nn as nn

from models.Transformer import Transformer


class TranslatorModel(nn.Module):
    '''
    English-to-Chinese translation model: a frozen GloVe embedding for the
    English source, a learned embedding for the Chinese target, an
    encoder-decoder Transformer, and an MLP head that projects to the
    Chinese vocabulary.

    Args:
        num_emb_en (int): size of the English vocabulary
        num_emb_ch (int): size of the Chinese vocabulary
        emb_dim (int): word-embedding dimension (shared by the English and Chinese embeddings)
        en_vocab (torchtext.vocab.Vocab): used to load the pretrained GloVe embeddings
        dim_in_transformer (int): hidden dimension used inside the Transformer
        ffn_hidden_dim (int): hidden dimension of the Transformer's feed-forward (FFN) module
        en_seq (int): length of the English token sequence
        ch_seq (int): length of the Chinese token sequence
        device (torch.device)
        num_heads (int, default: 8): number of attention heads
        dropout_rate (float, default: 0.1)

    Inputs:
        en_tokens: (b, en_seq) LongTensor of English token indices
        ch_tokens: (b, ch_seq) LongTensor of Chinese token indices

    Outputs:
        x: (b, ch_seq, num_emb_ch) unnormalized logits over the Chinese vocabulary
        score: attention weights returned by the Transformer
    '''

    def __init__(self, num_emb_en, num_emb_ch, emb_dim, en_vocab, dim_in_transformer,
                 ffn_hidden_dim, en_seq, ch_seq, device, num_heads=8, dropout_rate=0.1):
        super(TranslatorModel, self).__init__()

        # English embedding: initialised from pretrained GloVe vectors and frozen
        weight = self.get_glove_weight(en_vocab, num_emb_en, emb_dim)
        self.en_word_embedding = nn.Embedding.from_pretrained(weight, freeze=True)

        # Chinese embedding: learned from scratch
        self.ch_word_embedding = nn.Embedding(num_emb_ch, emb_dim)

        # encoder-decoder Transformer
        self.transformer = Transformer(
            emb_dim=emb_dim,
            dim=dim_in_transformer,
            ffn_hidden_dim=ffn_hidden_dim,
            encoder_seq=en_seq,
            decoder_seq=ch_seq,
            device=device,
            num_heads=num_heads,
            dropout_rate=dropout_rate
        )

        # MLP head projecting the Transformer output to the Chinese vocabulary
        self.fc1 = nn.Linear(dim_in_transformer, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, num_emb_ch)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()

    def forward(self, en_tokens, ch_tokens):
        # look up word embeddings for the source (English) and target (Chinese) tokens
        en_tokens = self.en_word_embedding(en_tokens)
        ch_tokens = self.ch_word_embedding(ch_tokens)

        # encoder-decoder Transformer; score holds the attention weights
        x, score = self.transformer(en_tokens, ch_tokens)

        # MLP head: (b, ch_seq, dim_in_transformer) -> (b, ch_seq, num_emb_ch) logits
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x, score

    def get_glove_weight(self, en_vocab, num_emb_en, emb_dim):
        '''
        Build the English embedding matrix from pretrained GloVe vectors.

        The assembled matrix is cached to data/word_embedding.pth so the GloVe
        text file only has to be parsed once. Words not found in GloVe keep a
        random initialisation. Note that emb_dim must match the dimension of
        the GloVe file (100 for glove.6B.100d.txt).

        Args:
            en_vocab (torchtext.vocab.Vocab)
            num_emb_en (int): size of the English vocabulary
            emb_dim (int): word-embedding dimension
        '''
        if os.path.isfile("data/word_embedding.pth"):
            weight = torch.load("data/word_embedding.pth")
        else:
            weight = torch.randn((num_emb_en, emb_dim))
            # each GloVe line is: <word> <v_1> <v_2> ... <v_emb_dim>
            with open('data/glove.6B.100d.txt', encoding='utf-8') as fp:
                for line in fp:
                    parts = line.rstrip().split(" ")
                    word = parts[0]
                    emb = torch.tensor([float(v) for v in parts[1:]])
                    if word in en_vocab:
                        weight[en_vocab[word]] = emb
            torch.save(weight, "data/word_embedding.pth")
        return weight
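

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes data/glove.6B.100d.txt is available, that torchtext is installed
# for building en_vocab, and it uses made-up vocabulary and sequence sizes;
# the real values come from the project's data preprocessing.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchtext.vocab import build_vocab_from_iterator

    # toy English vocabulary, just enough to exercise get_glove_weight
    sentences = [["i", "like", "machine", "translation"],
                 ["the", "model", "is", "a", "transformer"]]
    en_vocab = build_vocab_from_iterator(sentences, specials=["<pad>", "<unk>"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = TranslatorModel(
        num_emb_en=len(en_vocab),
        num_emb_ch=5000,            # illustrative Chinese vocab size
        emb_dim=100,                # must match glove.6B.100d.txt
        en_vocab=en_vocab,
        dim_in_transformer=256,     # illustrative
        ffn_hidden_dim=512,         # illustrative
        en_seq=16,
        ch_seq=20,
        device=device,
    ).to(device)

    # dummy batch of token indices: (batch, seq)
    en_tokens = torch.randint(0, len(en_vocab), (4, 16), device=device)
    ch_tokens = torch.randint(0, 5000, (4, 20), device=device)

    logits, score = model(en_tokens, ch_tokens)   # logits: (4, 20, 5000)

    # the logits go straight into CrossEntropyLoss, which applies the softmax
    target = torch.randint(0, 5000, (4, 20), device=device)
    loss = nn.CrossEntropyLoss()(logits.reshape(-1, 5000), target.reshape(-1))
    print(logits.shape, loss.item())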