feat: PyTorch model & Telegram Bot

Ting-Jun Wang 2022-06-05 03:51:43 +08:00
parent 5b89104cfa
commit 75a3545e2d
Signed by: snsd0805
GPG Key ID: 8DB0D22BC1217D33
3 changed files with 151 additions and 0 deletions

Bot/PTTBotModel.py Normal file (41 lines added)

@@ -0,0 +1,41 @@
import torch
from torch import nn

class PTTBotModel(nn.Module):
    def __init__(self, tokenizer):
        super(PTTBotModel, self).__init__()
        self.embedding = nn.Embedding(num_embeddings=len(tokenizer), embedding_dim=300, padding_idx=2)
        self.encoder = nn.LSTM(input_size=300, hidden_size=1024, batch_first=True)
        self.decoder = nn.LSTM(input_size=300, hidden_size=1024, batch_first=True)
        self.dense = nn.Linear(in_features=1024, out_features=len(tokenizer))
        self.dropout03 = nn.Dropout(0.3)
        self.dropout05 = nn.Dropout(0.5)

    def forward(self, q, a, lenq, lena):
        # q: (BATCH, 40), a: (BATCH, 30); lena is accepted but unused here
        q = self.embedding(q)   # (BATCH, 40, 300)
        a = self.embedding(a)   # (BATCH, 30, 300)
        q = self.dropout03(q)
        a = self.dropout03(a)
        # pack the question so the encoder skips the padding positions
        q = nn.utils.rnn.pack_padded_sequence(q, lenq, batch_first=True, enforce_sorted=False)
        out, (h, c) = self.encoder(q)
        # h: (1, BATCH, 1024), the last hidden state
        # c: (1, BATCH, 1024), the last cell state
        h = self.dropout03(h)
        c = self.dropout03(c)
        # decode the answer sequence, seeded with the encoder's final state
        out, (h, c) = self.decoder(a, (h, c))
        # out: (BATCH, 30, 1024)
        out = self.dense(self.dropout05(out))
        # (BATCH, 30, vocab_size) logits
        return out
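
A quick way to sanity-check the shapes above is a dummy forward pass. The sketch below is not part of this commit: DummyTokenizer is a hypothetical stand-in for the real Tokenizer (it only needs __len__ and a char_index dict, and assumes the 0/1/2 ids for <sos>/<end>/<pad> that padding_idx=2 implies for <pad>):

import torch
from torch import nn
from PTTBotModel import PTTBotModel

class DummyTokenizer:
    def __init__(self):
        # assumed special-token layout: 0 = <sos>, 1 = <end>, 2 = <pad>
        self.char_index = {'<sos>': 0, '<end>': 1, '<pad>': 2}
    def __len__(self):
        return 100  # hypothetical vocabulary size

tok = DummyTokenizer()
model = PTTBotModel(tok)

q = torch.randint(0, len(tok), (4, 40))   # padded questions, (BATCH, 40)
a = torch.randint(0, len(tok), (4, 30))   # padded answers,   (BATCH, 30)
lenq = torch.tensor([40, 35, 20, 12])     # true question lengths before padding
lena = torch.tensor([])                    # unused by forward()

out = model(q, a, lenq, lena)
print(out.shape)  # torch.Size([4, 30, 100])

# teacher-forcing loss: predict a[t+1] from a[t], ignoring <pad> positions
loss = nn.CrossEntropyLoss(ignore_index=2)(
    out[:, :-1].reshape(-1, len(tok)), a[:, 1:].reshape(-1)
)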

Bot/bot.py Normal file (81 lines added)

@@ -0,0 +1,81 @@
import torch
from PTTBotModel import PTTBotModel
from tokenizer import Tokenizer
from telegram.ext import Updater, MessageHandler, Filters

TOKEN = ""

class PTTBot():
    def __init__(self):
        self.updater = Updater(token=TOKEN, use_context=True)
        self.dispatcher = self.updater.dispatcher
        self.dispatcher.add_handler(MessageHandler(Filters.text, self.response))
        self.loadModel()

    def start(self):
        self.updater.start_polling()
        print("start...")

    def loadModel(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using {self.device} device")
        print("[INFO] loading tokenizer")
        self.tokenizer = Tokenizer()
        print("[INFO] loading model")
        self.model = PTTBotModel(self.tokenizer).to(self.device)
        # map_location lets the checkpoint load on CPU-only machines as well
        self.model.load_state_dict(torch.load('./char_based_state.pt', map_location=self.device))

    def predict(self, q):
        self.model.eval()
        # tokenize the question, wrap it with <sos>/<end>, then pad to length 40
        QTokens = self.tokenizer.chars_to_tokens(q)
        QTokens = [ self.tokenizer.char_index['<sos>'] ] + QTokens + [ self.tokenizer.char_index['<end>'] ]
        lenq = len(QTokens)
        QTokens += [ self.tokenizer.char_index['<pad>'] ] * (40-len(QTokens))
        QTokens = torch.tensor(QTokens).unsqueeze(0)
        lenq = torch.tensor([lenq])
        lena = torch.tensor([])   # unused by the model's forward pass
        ans = []
        with torch.no_grad():
            QTokens = QTokens.to(self.device)
            # greedy decoding: start from <sos> and feed each prediction back in
            a = [0]*30
            a[0] = self.tokenizer.char_index['<sos>']
            a = torch.tensor(a).unsqueeze(0).to(self.device)
            for i in range(30):
                out = self.model(QTokens, a, lenq, lena)
                out = out.view(-1, out.shape[-1])
                out_predict = torch.argmax(out, dim=1)
                if i != 29:
                    a[0][i+1] = out_predict[i]
        # strip the special tokens from the decoded answer
        words = self.tokenizer.tokens_to_chars(a[0])
        for word in words:
            if word not in ['<sos>', '<pad>', '<end>']:
                ans.append(word)
        return ans

    def response(self, update, context):
        user = update.effective_chat
        q = update.message.text
        prediction = self.predict(q)
        msg = "".join(prediction)
        context.bot.send_message(
            chat_id=user.id,
            text=msg
        )

if __name__ == '__main__':
    bot = PTTBot()
    bot.start()
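
TOKEN is committed as an empty string. One common pattern, shown here as an assumption rather than part of this commit, is to pull the token from an environment variable so it never lands in version control:

import os

# hypothetical variable name; export TELEGRAM_TOKEN before launching bot.py
TOKEN = os.environ.get('TELEGRAM_TOKEN', '')
if not TOKEN:
    raise RuntimeError('TELEGRAM_TOKEN is not set')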

Bot/tokenizer.py Normal file (29 lines added)

@@ -0,0 +1,29 @@
import torch
import pickle

class Tokenizer():
    def __init__(self):
        # char_index: char -> token id; index_char: token id -> char
        with open('./char_index.pkl', 'rb') as fp:
            self.char_index = pickle.load(fp)
        with open('./index_char.pkl', 'rb') as fp:
            self.index_char = pickle.load(fp)

    def __len__(self):
        return len(self.char_index)

    def chars_to_tokens(self, chars):
        # map each character to its id, falling back to <unk> for unseen characters
        ans = []
        for ch in chars:
            if ch in self.char_index:
                ans.append( self.char_index[ch] )
            else:
                ans.append( self.char_index['<unk>'] )
        return ans

    def tokens_to_chars(self, tokens):
        if type(tokens) is torch.Tensor:
            tokens = tokens.tolist()
        ans = []
        for token in tokens:
            ans.append( self.index_char[token] )
        return ans
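
A minimal round-trip check, assuming char_index.pkl and index_char.pkl sit next to tokenizer.py and contain the <unk> entry used above:

from tokenizer import Tokenizer

tok = Tokenizer()
ids = tok.chars_to_tokens("安安你好")      # characters missing from the vocab become <unk>
print(ids)
print("".join(tok.tokens_to_chars(ids)))  # maps the ids back to characters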