[bitsandbytes] follow-ups (huggingface#9730)
* bnb follow ups.

* add a warning when dtypes mismatch.

* fix-copies

* clear cache.

* check_if_quantized_param

* add a check on shape.

* updates

* docs

* improve readability.

* resources.

* fix
sayakpaul authored Oct 22, 2024
1 parent 0f079b9 commit 60ffa84
Showing 8 changed files with 123 additions and 65 deletions.
23 changes: 8 additions & 15 deletions docs/source/en/quantization/bitsandbytes.md
@@ -59,19 +59,7 @@ model_8bit = FluxTransformer2DModel.from_pretrained(
model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
```

Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/FLUX.1-dev",
subfolder="transformer",
quantization_config=quantization_config
)
```
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 8-bit model locally with [`~ModelMixin.save_pretrained`].
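
For example, a minimal sketch assuming the `model_8bit` from above and a placeholder repository id:

```py
# pushes the quantization config first, then the quantized weights
model_8bit.push_to_hub("your-username/flux.1-dev-transformer-8bit")

# or serialize the quantized model to a local folder instead
model_8bit.save_pretrained("flux.1-dev-transformer-8bit")
```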

</hfoption>
<hfoption id="4-bit">
@@ -131,7 +119,7 @@ from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
"sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
"hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
```

@@ -264,4 +252,9 @@ double_quant_model = SD3Transformer2DModel.from_pretrained(
quantization_config=double_quant_config,
)
model.dequantize()
```
```

## Resources

* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
25 changes: 16 additions & 9 deletions src/diffusers/models/model_loading_utils.py
@@ -211,21 +211,28 @@ def load_model_dict_into_meta(
set_module_kwargs["dtype"] = dtype

# bnb params are flattened.
if not is_quant_method_bnb and empty_state_dict[param_name].shape != param.shape:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)
if empty_state_dict[param_name].shape != param.shape:
if (
is_quant_method_bnb
and hf_quantizer.pre_quantized
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
elif not is_quant_method_bnb:
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
raise ValueError(
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
)

if not is_quantized or (
not hf_quantizer.check_quantized_param(model, param, param_name, state_dict, param_device=device)
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
else:
set_module_tensor_to_device(model, param_name, device, value=param)
else:
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)

return unexpected_keys

2 changes: 1 addition & 1 deletion src/diffusers/quantizers/__init__.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .auto import DiffusersAutoQuantizationConfig, DiffusersAutoQuantizer
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
41 changes: 15 additions & 26 deletions src/diffusers/quantizers/auto.py
@@ -33,10 +33,10 @@
}


class DiffusersAutoQuantizationConfig:
class DiffusersAutoQuantizer:
"""
The auto diffusers quantization config class that takes care of automatically dispatching to the correct
quantization config given a quantization config stored in a dictionary.
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
`DiffusersQuantizer` given the `QuantizationConfig`.
"""

@classmethod
@@ -60,31 +60,11 @@ def from_dict(cls, quantization_config_dict: Dict):
target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
return target_cls.from_dict(quantization_config_dict)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
if getattr(model_config, "quantization_config", None) is None:
raise ValueError(
f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
)
quantization_config_dict = model_config.quantization_config
quantization_config = cls.from_dict(quantization_config_dict)
# Update with potential kwargs that are passed through from_pretrained.
quantization_config.update(kwargs)
return quantization_config


class DiffusersAutoQuantizer:
"""
The auto diffusers quantizer class that takes care of automatically instantiating to the correct
`DiffusersQuantizer` given the `QuantizationConfig`.
"""

@classmethod
def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
# Convert it to a QuantizationConfig if the q_config is a dict
if isinstance(quantization_config, dict):
quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
quantization_config = cls.from_dict(quantization_config)

quant_method = quantization_config.quant_method

@@ -107,7 +87,16 @@ def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict],

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
quantization_config = DiffusersAutoQuantizationConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
if getattr(model_config, "quantization_config", None) is None:
raise ValueError(
f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
)
quantization_config_dict = model_config.quantization_config
quantization_config = cls.from_dict(quantization_config_dict)
# Update with potential kwargs that are passed through from_pretrained.
quantization_config.update(kwargs)

return cls.from_config(quantization_config)

@classmethod
@@ -129,7 +118,7 @@ def merge_quantization_configs(
warning_msg = ""

if isinstance(quantization_config, dict):
quantization_config = DiffusersAutoQuantizationConfig.from_dict(quantization_config)
quantization_config = cls.from_dict(quantization_config)

if warning_msg != "":
warnings.warn(warning_msg)
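With `DiffusersAutoQuantizationConfig` folded into `DiffusersAutoQuantizer`, config parsing and quantizer dispatch now live on a single class. A minimal sketch of the merged entry point, using an illustrative config:

```py
from diffusers import BitsAndBytesConfig
from diffusers.quantizers import DiffusersAutoQuantizer

# plain dicts are routed through `from_dict` internally; config objects are used as-is
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
quantizer = DiffusersAutoQuantizer.from_config(bnb_config)
print(type(quantizer).__name__)  # the 4-bit bitsandbytes quantizer class
```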
13 changes: 8 additions & 5 deletions src/diffusers/quantizers/base.py
@@ -134,7 +134,7 @@ def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str,
"""adjust max_memory argument for infer_auto_device_map() if extra memory is needed for quantization"""
return max_memory

def check_quantized_param(
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
@@ -152,10 +152,13 @@ def create_quantized_param(self, *args, **kwargs) -> "torch.nn.Parameter":
"""
takes needed components from state_dict and creates quantized param.
"""
if not hasattr(self, "check_quantized_param"):
raise AttributeError(
f"`.create_quantized_param()` method is not supported by quantizer class {self.__class__.__name__}."
)
return

def check_quantized_param_shape(self, *args, **kwargs):
"""
checks if the quantized param has expected shape.
"""
return True

def validate_environment(self, *args, **kwargs):
"""
15 changes: 12 additions & 3 deletions src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -106,7 +106,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
else:
raise ValueError(f"Wrong `target_dtype` ({target_dtype}) provided.")

def check_quantized_param(
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
@@ -204,6 +204,16 @@ def create_quantized_param(

module._parameters[tensor_name] = new_value

def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
n = current_param_shape.numel()
inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
if loaded_param_shape != inferred_shape:
raise ValueError(
f"Expected the flattened shape of the current param ({param_name}) to be {loaded_param_shape} but is {inferred_shape}."
)
else:
return True

def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
# need more space for buffers that are created during quantization
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
@@ -330,7 +340,6 @@ def __init__(self, quantization_config, **kwargs):
if self.quantization_config.llm_int8_skip_modules is not None:
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.validate_environment with 4-bit->8-bit
def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
@@ -404,7 +413,7 @@ def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
logger.info("target_dtype {target_dtype} is replaced by `torch.int8` for 8-bit BnB quantization")
return torch.int8

def check_quantized_param(
def check_if_quantized_param(
self,
model: "ModelMixin",
param_value: "torch.Tensor",
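The new `check_quantized_param_shape` encodes how bitsandbytes serializes 4-bit weights: two 4-bit values are packed per byte, so a non-bias tensor with `n` elements is stored as a flattened `((n + 1) // 2, 1)` blob. A small sketch of that invariant, assuming `bitsandbytes` is installed and a CUDA device is available (the tensor size is arbitrary):

```py
import torch
import bitsandbytes as bnb

weight = torch.randn(64, 64, dtype=torch.float16)  # 4096 values
# Params4bit quantizes (and packs) the data when it is moved to the GPU
param = bnb.nn.Params4bit(weight, requires_grad=False).cuda()

print(param.shape)                # torch.Size([2048, 1])
print((weight.numel() + 1) // 2)  # 2048, i.e. the inferred_shape used by the check above
```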
49 changes: 46 additions & 3 deletions tests/quantization/bnb/test_4bit.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import unittest

import numpy as np
import safetensors.torch

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import logging
@@ -118,6 +120,9 @@ def get_dummy_inputs(self):

class BnB4BitBasicTests(Base4bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()

# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -232,7 +237,7 @@ def test_linear_are_4bit(self):

def test_config_from_pretrained(self):
transformer_4bit = FluxTransformer2DModel.from_pretrained(
"sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
"hf-internal-testing/flux.1-dev-nf4-pkg", subfolder="transformer"
)
linear = get_some_linear_layer(transformer_4bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)
@@ -312,9 +317,42 @@ def test_bnb_4bit_wrong_config(self):
with self.assertRaises(ValueError):
_ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add")

def test_bnb_4bit_errors_loading_incorrect_state_dict(self):
r"""
Test if loading with an incorrect state dict raises an error.
"""
with tempfile.TemporaryDirectory() as tmpdirname:
nf4_config = BitsAndBytesConfig(load_in_4bit=True)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config
)
model_4bit.save_pretrained(tmpdirname)
del model_4bit

with self.assertRaises(ValueError) as err_context:
state_dict = safetensors.torch.load_file(
os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
)

# corrupt the state dict
key_to_target = "context_embedder.weight" # can be other keys too.
compatible_param = state_dict[key_to_target]
corrupted_param = torch.randn(compatible_param.shape[0] - 1, 1)
state_dict[key_to_target] = bnb.nn.Params4bit(corrupted_param, requires_grad=False)
safetensors.torch.save_file(
state_dict, os.path.join(tmpdirname, "diffusion_pytorch_model.safetensors")
)

_ = SD3Transformer2DModel.from_pretrained(tmpdirname)

assert key_to_target in str(err_context.exception)


class BnB4BitTrainingTests(Base4bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()

nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@@ -360,6 +398,9 @@ def test_training(self):
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitTests(Base4bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()

nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
@@ -447,8 +488,10 @@ def test_moving_to_cpu_throws_warning(self):
@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
def setUp(self) -> None:
# TODO: Copy sayakpaul/flux.1-dev-nf4-pkg to testing repo.
model_id = "sayakpaul/flux.1-dev-nf4-pkg"
gc.collect()
torch.cuda.empty_cache()

model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
self.pipeline_4bit = DiffusionPipeline.from_pretrained(
20 changes: 17 additions & 3 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -117,6 +117,9 @@ def get_dummy_inputs(self):

class BnB8bitBasicTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()

# Models
self.model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
@@ -238,7 +241,7 @@ def test_llm_skip(self):

def test_config_from_pretrained(self):
transformer_8bit = FluxTransformer2DModel.from_pretrained(
"sayakpaul/flux.1-dev-int8-pkg", subfolder="transformer"
"hf-internal-testing/flux.1-dev-int8-pkg", subfolder="transformer"
)
linear = get_some_linear_layer(transformer_8bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
@@ -296,6 +299,9 @@ def test_device_and_dtype_assignment(self):

class BnB8bitTrainingTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
self.model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
@@ -337,6 +343,9 @@ def test_training(self):
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitTests(Base8bitTests):
def setUp(self) -> None:
gc.collect()
torch.cuda.empty_cache()

mixed_int8_config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=mixed_int8_config
@@ -427,8 +436,10 @@ def test_generate_quality_dequantize(self):
@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):
def setUp(self) -> None:
# TODO: Copy sayakpaul/flux.1-dev-int8-pkg to testing repo.
model_id = "sayakpaul/flux.1-dev-int8-pkg"
gc.collect()
torch.cuda.empty_cache()

model_id = "hf-internal-testing/flux.1-dev-int8-pkg"
t5_8bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_8bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
self.pipeline_8bit = DiffusionPipeline.from_pretrained(
@@ -466,6 +477,9 @@ def test_quality(self):
@slow
class BaseBnb8bitSerializationTests(Base8bitTests):
def setUp(self):
gc.collect()
torch.cuda.empty_cache()

quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
)
