
Why is dit.safetensors 40 GB in size? #115

Open
david-beckham-315 opened this issue Dec 16, 2024 · 3 comments
Labels: question (Further information is requested)

Comments

@david-beckham-315

Hi,

The diffusion model has 10B parameters, but I found that dit.safetensors is 40 GB in size. What dtype is stored in the model? Is it TF32?

Looking forward to your feedback, thanks.

@ajayjain
Contributor

ajayjain commented Dec 16, 2024

The checkpoint is stored in float32, not TF32. It should be fine to convert it to bfloat16, but a few parameters should ideally stay in float32 (one pos_frequencies tensor, and q_norm_x.weight, q_norm_y.weight, k_norm_x.weight, k_norm_y.weight in each block).
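
For context on the size: roughly 10 billion parameters at 4 bytes each in float32 comes to about 40 GB, which matches the file; the same weights in bfloat16 would be roughly 20 GB. Below is a minimal conversion sketch based on the comment above. It assumes the checkpoint is a single safetensors file and that the tensors to keep in float32 can be identified by the name substrings listed there; neither assumption is confirmed in this thread.

```python
# Sketch: convert dit.safetensors to bfloat16, keeping the numerically
# sensitive tensors in float32 (per the maintainer's comment above).
import torch
from safetensors.torch import load_file, save_file

# Name substrings of tensors to leave in float32 (assumed to match
# the actual checkpoint keys).
KEEP_FP32 = (
    "pos_frequencies",
    "q_norm_x.weight",
    "q_norm_y.weight",
    "k_norm_x.weight",
    "k_norm_y.weight",
)

def convert_to_bf16(src_path: str, dst_path: str) -> None:
    state_dict = load_file(src_path)  # loads all tensors onto CPU
    converted = {}
    for name, tensor in state_dict.items():
        if tensor.is_floating_point() and not any(k in name for k in KEEP_FP32):
            converted[name] = tensor.to(torch.bfloat16)  # 4 bytes -> 2 bytes per weight
        else:
            converted[name] = tensor  # keep listed tensors (and non-float data) as-is
    save_file(converted, dst_path)

convert_to_bf16("dit.safetensors", "dit.bf16.safetensors")
```

The tensors kept in float32 are small (positional frequencies and per-block QK-norm scales), so they add almost nothing to the file size while presumably avoiding precision loss in those operations.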

@david-beckham-315
Author

> The checkpoint is stored in float32, not TF32. It should be fine to convert it to bfloat16, but a few parameters should ideally stay in float32 (one pos_frequencies tensor, and q_norm_x.weight, q_norm_y.weight, k_norm_x.weight, k_norm_y.weight in each block).

Thanks for your answer!

@david-beckham-315
Author

> The checkpoint is stored in float32, not TF32. It should be fine to convert it to bfloat16, but a few parameters should ideally stay in float32 (one pos_frequencies tensor, and q_norm_x.weight, q_norm_y.weight, k_norm_x.weight, k_norm_y.weight in each block).

Hi,
I'd like to ask two questions:

  1. The checkpoint is stored in float32, so why is inference run in bfloat16?
  2. How can I run inference in float32? I changed model_dtype="bf16" to model_dtype="fp32" in cli.py, but it fails with the assertion: assert self.kwargs["model_dtype"] == "bf16", "FP8 is not supported for multi-GPU inference"

ajayjain added the question label Dec 20, 2024