Skip to content

Commit

Permalink
set fp16 for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
yisol committed Apr 22, 2024
1 parent c6555b6 commit c053929
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
14 changes: 7 additions & 7 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,13 @@ def main():
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
args.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
args.mixed_precision = accelerator.mixed_precision
weight_dtype = torch.float16
# if accelerator.mixed_precision == "fp16":
# weight_dtype = torch.float16
# args.mixed_precision = accelerator.mixed_precision
# elif accelerator.mixed_precision == "bf16":
# weight_dtype = torch.bfloat16
# args.mixed_precision = accelerator.mixed_precision

# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
Expand Down
14 changes: 7 additions & 7 deletions inference_dc.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,13 @@ def main():
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
args.mixed_precision = accelerator.mixed_precision
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
args.mixed_precision = accelerator.mixed_precision
weight_dtype = torch.float16
# if accelerator.mixed_precision == "fp16":
# weight_dtype = torch.float16
# args.mixed_precision = accelerator.mixed_precision
# elif accelerator.mixed_precision == "bf16":
# weight_dtype = torch.bfloat16
# args.mixed_precision = accelerator.mixed_precision

# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
Expand Down

0 comments on commit c053929

Please sign in to comment.