Skip to content

Commit

Permalink
Update pipeline.py
Browse files Browse the repository at this point in the history
Specify the device for the `input_image_tensor` to avoid encountering the following error code when employing GPU for image-to-image operations.

```
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
```
  • Loading branch information
nick8592 authored Feb 22, 2024
1 parent c3fb755 commit e2f7e66
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sd/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def generate(
# (Height, Width, Channel)
input_image_tensor = np.array(input_image_tensor)
# (Height, Width, Channel) -> (Height, Width, Channel)
input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32)
input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
# (Height, Width, Channel) -> (Height, Width, Channel)
input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
# (Height, Width, Channel) -> (Batch_Size, Height, Width, Channel)
Expand Down Expand Up @@ -167,4 +167,4 @@ def get_time_embedding(timestep):
# Shape: (1, 160)
x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
# Shape: (1, 160 * 2)
return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

0 comments on commit e2f7e66

Please sign in to comment.