60 lines
2.6 KiB
Python
60 lines
2.6 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention block.

    Acts as self-attention when ``k``, ``q`` and ``v`` are the same tensor and
    as cross-attention when ``q`` comes from a different sequence (its length
    may differ from that of ``k``/``v``).

    Args:
        dim (int): input & output feature dimension; must be divisible by
            ``num_heads``.
        num_heads (int, default=8): number of attention heads.

    Inputs:
        k: (b, seq_k, dim) embedding that is projected by ``wk`` inside this
            module (it is not a precomputed key).
        q: (b, seq_q, dim) embedding, projected by ``wq``.
        v: (b, seq_k, dim) embedding, projected by ``wv``; must have the same
            sequence length as ``k``.
        mask (default None): BoolTensor in which ``True`` marks positions that
            must NOT be attended to. Shape (b, seq_q, seq_k) or (b, 1, seq_k)
            (a head axis is inserted for 3-D masks), or any shape that already
            broadcasts to (b, num_heads, seq_q, seq_k).

    Outputs:
        ans: (b, seq_q, dim)
        score: (b, num_heads, seq_q, seq_k) attention weights after softmax.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        # Without this check, the head reshape in forward() can silently pick a
        # wrong sequence length instead of failing.
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.wk = nn.Linear(dim, dim)  # b, seq, dim -> b, seq, dim
        self.wq = nn.Linear(dim, dim)  # b, seq, dim -> b, seq, dim
        self.wv = nn.Linear(dim, dim)  # b, seq, dim -> b, seq, dim
        self.fc = nn.Linear(dim, dim)  # output projection after merging heads

    def _split_heads(self, x):
        """Reshape (b, seq, dim) into (b, num_heads, seq, head_dim)."""
        batch = x.size(0)
        return x.view(batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, k, q, v, mask=None):
        batch = q.size(0)

        k = self._split_heads(self.wk(k))  # b, heads, seq_k, head_dim
        q = self._split_heads(self.wq(q))  # b, heads, seq_q, head_dim
        v = self._split_heads(self.wv(v))  # b, heads, seq_k, head_dim

        # b, heads, seq_q, seq_k
        score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            if mask.dim() == 3:
                # Insert a head axis; broadcasting replaces the old repeat() copy.
                mask = mask.unsqueeze(1)
            # dtype-aware "minus infinity": a hard-coded -1e20 overflows to -inf
            # in fp16, which turns fully masked rows into NaN after softmax.
            score = score.masked_fill(mask, torch.finfo(score.dtype).min)

        score = F.softmax(score, dim=-1)

        ans = torch.matmul(score, v)  # b, heads, seq_q, head_dim
        # Merge heads back: b, seq_q, dim
        ans = ans.transpose(1, 2).reshape(batch, -1, self.num_heads * self.head_dim)
        ans = self.fc(ans)  # b, seq_q, dim

        return ans, score
|