Skip to content

Commit

Permalink
[{Up,Down}sample1d] explicit view kernel size as number elements in f…
Browse files Browse the repository at this point in the history
…lattened indices (huggingface#3479)

explicit view kernel size as number elements in flattened indices
  • Loading branch information
williamberman authored May 19, 2023
1 parent e589bdb commit 85eff63
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/diffusers/models/unet_1d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def forward(self, hidden_states):
hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv1d(hidden_states, weight, stride=2)


Expand All @@ -316,7 +317,8 @@ def forward(self, hidden_states, temb=None):
hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode)
weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]])
indices = torch.arange(hidden_states.shape[1], device=hidden_states.device)
weight[indices, indices] = self.kernel.to(weight)
kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1)
weight[indices, indices] = kernel
return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1)


Expand Down

0 comments on commit 85eff63

Please sign in to comment.