Skip to content

Commit

Permalink
incorporated change from PR real-stanford#10
Browse files Browse the repository at this point in the history
  • Loading branch information
cheng-chi committed Sep 7, 2023
1 parent 0d00e02 commit dd2cbac
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions diffusion_policy/model/diffusion/conditional_unet1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ def forward(self,
for idx, (resnet, resnet2, upsample) in enumerate(self.up_modules):
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, global_feature)
# The correct condition should be:
# if idx == (len(self.up_modules)-1) and len(h_local) > 0:
# However this change will break compatibility with published checkpoints.
# Therefore it is left as a comment.
if idx == len(self.up_modules) and len(h_local) > 0:
x = x + h_local[1]
x = resnet2(x, global_feature)
Expand Down

0 comments on commit dd2cbac

Please sign in to comment.