fix: set n in sample() default value to self.batch_size
This commit is contained in:
parent
37adcc7e97
commit
a330666a17
7
ddpm.py
7
ddpm.py
@ -53,14 +53,17 @@ class DDPM(nn.Module):
|
||||
|
||||
return mu + sigma * epsilon, epsilon # (b, c, w, h)
|
||||
|
||||
def sample(self, model, generate_iteration_pic=False, n=self.batch_size):
|
||||
def sample(self, model, generate_iteration_pic=False, n=None):
|
||||
'''
|
||||
Inputs:
|
||||
model (nn.Module): Unet instance
|
||||
n (int): want to sample n pictures
|
||||
generate_iteration_pic (bool): whether generate 10 pic on different denoising time
|
||||
n (int, default=self.batch_size): want to sample n pictures
|
||||
Outputs:
|
||||
x_0 (nn.Tensor): (n, c, h, w)
|
||||
'''
|
||||
if n == None:
|
||||
n = self.batch_size
|
||||
c, h, w = 1, 28, 28
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user