fix: set n in sample() default value to self.batch_size
This commit is contained in:
parent
37adcc7e97
commit
a330666a17
7
ddpm.py
7
ddpm.py
@ -53,14 +53,17 @@ class DDPM(nn.Module):
|
|||||||
|
|
||||||
return mu + sigma * epsilon, epsilon # (b, c, w, h)
|
return mu + sigma * epsilon, epsilon # (b, c, w, h)
|
||||||
|
|
||||||
def sample(self, model, generate_iteration_pic=False, n=self.batch_size):
|
def sample(self, model, generate_iteration_pic=False, n=None):
|
||||||
'''
|
'''
|
||||||
Inputs:
|
Inputs:
|
||||||
model (nn.Module): Unet instance
|
model (nn.Module): Unet instance
|
||||||
n (int): want to sample n pictures
|
generate_iteration_pic (bool): whether generate 10 pic on different denoising time
|
||||||
|
n (int, default=self.batch_size): want to sample n pictures
|
||||||
Outputs:
|
Outputs:
|
||||||
x_0 (nn.Tensor): (n, c, h, w)
|
x_0 (nn.Tensor): (n, c, h, w)
|
||||||
'''
|
'''
|
||||||
|
if n == None:
|
||||||
|
n = self.batch_size
|
||||||
c, h, w = 1, 28, 28
|
c, h, w = 1, 28, 28
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user