32 lines
1.0 KiB
Python
32 lines
1.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import math
|
|
|
|
class PositionEncode(nn.Module):
    """Sinusoidal position / timestep encoding.

    Each position ``t`` is mapped to a vector of length ``emb_dim`` whose
    element ``j`` is ``sin(t * base[j])`` for even ``j`` and
    ``cos(t * base[j])`` for odd ``j``, where
    ``base[j] = 1 / 10000 ** ((j // 2) / emb_dim)``. Adjacent (sin, cos)
    columns therefore share one frequency.

    NOTE(review): the exponent here is ``(j // 2) / emb_dim``; the formula in
    "Attention Is All You Need" uses ``2 * (j // 2) / emb_dim``, so the lowest
    frequency produced here is about ``1 / 100`` rather than ``1 / 10000``.
    Kept unchanged so existing models see identical encodings -- confirm
    whether this is intentional.

    Args:
        emb_dim (int): length of each encoding vector.
        device (torch.device or str): device the frequency table is created on.

    Inputs:
        time_seq: tensor of positions, typically shape ``(b,)``; integer or
            floating dtype. Any shape ``(*)`` is accepted.

    Returns:
        Floating tensor of shape ``(*time_seq.shape, emb_dim)`` -- ``(b, emb_dim)``
        for the usual 1-D input.
    """

    def __init__(self, emb_dim, device):
        super().__init__()
        # (j // 2) / emb_dim for j in [0, emb_dim): pairs of columns share an exponent.
        seq = (torch.arange(emb_dim) // 2) / emb_dim
        # Non-persistent buffer: follows module.to()/.cuda() but adds no
        # state_dict key, so existing checkpoints still load unchanged.
        self.register_buffer(
            "base", 1 / torch.pow(10000, seq).to(device), persistent=False
        )  # (dim,)
        self.emb_dim = emb_dim

    def forward(self, time_seq):
        # (*, 1) * (dim,) broadcasts to (*, dim); no explicit repeat needed.
        angles = time_seq[..., None] * self.base
        # Write into a separate tensor instead of overwriting `angles` in place,
        # so autograd's saved inputs for sin/cos stay valid when time_seq
        # requires grad.
        ans = torch.empty_like(angles)
        ans[..., 0::2] = torch.sin(angles[..., 0::2])
        ans[..., 1::2] = torch.cos(angles[..., 1::2])
        return ans