DDPM_Mnist/positionEncoding.py

import torch
import torch.nn as nn
import math
class PositionEncode(nn.Module):
    '''
    Sinusoidal position embedding: take a LongTensor of timesteps and
    return the corresponding position embeddings.
    Input:
        time_seq: LongTensor of shape (b,)
    Output:
        Tensor of shape (b, time_emb_dim)
    Args:
        time_emb_dim (int): dimension of the output embedding
        device (torch.device): device the frequency table is stored on
    '''
    def __init__(self, time_emb_dim, device):
        super(PositionEncode, self).__init__()
        self.time_emb_dim = time_emb_dim
        # Frequency table 1 / 10000^(2i/d) from the transformer formula, where
        # i = k // 2 indexes the (sin, cos) pair occupying channels 2i and 2i+1.
        # (The original exponent (k // 2) / d halved the intended 2i/d.)
        self.base = torch.tensor(
            [1 / math.pow(10000, 2 * (k // 2) / self.time_emb_dim) for k in range(self.time_emb_dim)]
        )  # (d,)
        self.base = self.base.to(device)
    def forward(self, time_seq):
        seq_len = len(time_seq)
        dim = self.base.reshape(1, self.time_emb_dim).repeat(seq_len, 1)  # (b, time_emb_dim)
        time_seq = time_seq[:, None].repeat(1, self.time_emb_dim)  # (b, time_emb_dim)
        ans = dim * time_seq  # (b, time_emb_dim)
        # Even channels get sin, odd channels get cos of the same frequency.
        ans[:, 0::2] = torch.sin(ans[:, 0::2])
        ans[:, 1::2] = torch.cos(ans[:, 1::2])
        return ans
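

# --- Usage sketch (illustrative addition, not part of the original file) ---
# A minimal shape check, assuming CPU and a batch of random diffusion
# timesteps; the embedding dimension 128, batch size 16, and timestep
# range 1000 are arbitrary example values.
if __name__ == "__main__":
    device = torch.device("cpu")
    encoder = PositionEncode(time_emb_dim=128, device=device)
    t = torch.randint(0, 1000, (16,))  # (b,) LongTensor of timesteps
    emb = encoder(t)                   # (16, 128)
    print(emb.shape)                   # torch.Size([16, 128])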