fix: set n in sample() default value to self.batch_size

This commit is contained in:
snsd0805 2023-03-14 22:57:44 +08:00
parent 37adcc7e97
commit a330666a17
Signed by: snsd0805
GPG Key ID: 569349933C77A854

View File

@ -53,14 +53,17 @@ class DDPM(nn.Module):
return mu + sigma * epsilon, epsilon # (b, c, w, h)
def sample(self, model, generate_iteration_pic=False, n=self.batch_size):
def sample(self, model, generate_iteration_pic=False, n=None):
'''
Inputs:
model (nn.Module): Unet instance
n (int): want to sample n pictures
generate_iteration_pic (bool): whether generate 10 pic on different denoising time
n (int, default=self.batch_size): want to sample n pictures
Outputs:
x_0 (nn.Tensor): (n, c, h, w)
'''
if n == None:
n = self.batch_size
c, h, w = 1, 28, 28
model.eval()
with torch.no_grad():