Skip to content

Commit

Permalink
Merge pull request #417 from bghira/main
Browse files Browse the repository at this point in the history
prompt library: rewrite all prompts, focusing on concept diversity and density, reducing 'sameness' complaints of prompt library | logging: reduce logspam in `INFO` log level | aspect bucketing: ability to randomise aspect buckets without distorting the images (experimental) | validations: ability to disable uncond generation for a slight speed-up on slow hardware when not necessary | aspect bucketing: ability to customise the aspect resolution mappings and enforce the resolutions you wish to train on | captioning toolkit: new scripts for gemini-pro-vision, paligemma 3B and BLIP3 | bugfix: dataloader metadata retrieval would occasionally return the wrong values if filenames match across multiple datasets
  • Loading branch information
bghira authored May 23, 2024
2 parents d8ba270 + ef7fd80 commit 67dc2a8
Show file tree
Hide file tree
Showing 18 changed files with 779 additions and 279 deletions.
46 changes: 38 additions & 8 deletions OPTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--lora_type {Standard}]
[--lora_init_type {default,gaussian,loftq}]
[--lora_rank LORA_RANK] [--lora_alpha LORA_ALPHA]
[--lora_dropout LORA_DROPOUT]
[--lora_dropout LORA_DROPOUT] [--controlnet]
[--controlnet_model_name_or_path]
--pretrained_model_name_or_path
PRETRAINED_MODEL_NAME_OR_PATH
[--pretrained_vae_model_name_or_path PRETRAINED_VAE_MODEL_NAME_OR_PATH]
Expand All @@ -201,8 +202,9 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--vae_cache_scan_behaviour {recreate,sync}]
[--vae_cache_preprocess] [--keep_vae_loaded]
[--skip_file_discovery SKIP_FILE_DISCOVERY]
[--revision REVISION] [--preserve_data_backend_cache]
[--use_dora] [--override_dataset_config]
[--revision REVISION] [--variant VARIANT]
[--preserve_data_backend_cache] [--use_dora]
[--override_dataset_config]
[--cache_dir_text CACHE_DIR_TEXT]
[--cache_dir_vae CACHE_DIR_VAE] --data_backend_config
DATA_BACKEND_CONFIG [--write_batch_size WRITE_BATCH_SIZE]
Expand Down Expand Up @@ -256,12 +258,12 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--adam_weight_decay ADAM_WEIGHT_DECAY]
[--adam_epsilon ADAM_EPSILON] [--adam_bfloat16]
[--max_grad_norm MAX_GRAD_NORM] [--push_to_hub]
[--push_checkpoints_to_hub]
[--hub_model_id HUB_MODEL_ID] [--logging_dir LOGGING_DIR]
[--push_checkpoints_to_hub] [--hub_model_id HUB_MODEL_ID]
[--logging_dir LOGGING_DIR]
[--validation_torch_compile VALIDATION_TORCH_COMPILE]
[--validation_torch_compile_mode {max-autotune,reduce-overhead,default}]
[--allow_tf32] [--webhook_config WEBHOOK_CONFIG]
[--report_to REPORT_TO]
[--allow_tf32] [--validation_using_datasets]
[--webhook_config WEBHOOK_CONFIG] [--report_to REPORT_TO]
[--tracker_run_name TRACKER_RUN_NAME]
[--tracker_project_name TRACKER_PROJECT_NAME]
[--validation_prompt VALIDATION_PROMPT]
Expand Down Expand Up @@ -296,6 +298,8 @@ usage: train_sdxl.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--freeze_encoder FREEZE_ENCODER] [--save_text_encoder]
[--text_encoder_limit TEXT_ENCODER_LIMIT]
[--prepend_instance_prompt] [--only_instance_prompt]
[--data_aesthetic_score DATA_AESTHETIC_SCORE]
[--sdxl_refiner_uses_full_range]
[--caption_dropout_probability CAPTION_DROPOUT_PROBABILITY]
[--input_perturbation INPUT_PERTURBATION]
[--input_perturbation_probability INPUT_PERTURBATION_PROBABILITY]
Expand Down Expand Up @@ -345,6 +349,13 @@ options:
--lora_dropout LORA_DROPOUT
LoRA dropout randomly ignores neurons during training.
This can help prevent overfitting.
--controlnet If set, ControlNet style training will be used, where
a conditioning input image is required alongside the
training data.
--controlnet_model_name_or_path
When provided alongside --controlnet, this will
specify ControlNet model weights to preload from the
hub.
--pretrained_model_name_or_path PRETRAINED_MODEL_NAME_OR_PATH
Path to pretrained model or model identifier from
huggingface.co/models.
Expand Down Expand Up @@ -455,7 +466,10 @@ options:
aspect, vae, text, metadata.
--revision REVISION Revision of pretrained model identifier from
huggingface.co/models. Trainable model components
should be float32 precision.
should be at least bfloat16 precision.
--variant VARIANT Variant of pretrained model identifier from
huggingface.co/models. Trainable model components
should be at least bfloat16 precision.
--preserve_data_backend_cache
For very large cloud storage buckets that will never
change, enabling this option will prevent the trainer
Expand Down Expand Up @@ -753,6 +767,11 @@ options:
used to speed up training. For more information, see h
ttps://pytorch.org/docs/stable/notes/cuda.html#tensorf
loat-32-tf32-on-ampere-devices
--validation_using_datasets
When set, validation will use images sampled randomly
from each dataset for validation. Be mindful of
privacy issues when publishing training data to the
internet.
--webhook_config WEBHOOK_CONFIG
The path to the webhook configuration file. This file
should be a JSON file with the following format:
Expand Down Expand Up @@ -930,6 +949,17 @@ options:
--only_instance_prompt
Use the instance prompt instead of the caption from
filename.
--data_aesthetic_score DATA_AESTHETIC_SCORE
Since currently we do not calculate aesthetic scores
for data, we will statically set it to one value. This
is only used by the SDXL Refiner.
--sdxl_refiner_uses_full_range
If set, the SDXL Refiner will use the full range of
the model, rather than the design value of 20 percent.
This is useful for training models that will be used
for inference from end-to-end of the noise schedule.
You may use this for example, to turn the SDXL refiner
into a full text-to-image model.
--caption_dropout_probability CAPTION_DROPOUT_PROBABILITY
Caption dropout will randomly drop captions and, for
SDXL, size conditioning inputs based on this
Expand Down
63 changes: 59 additions & 4 deletions documentation/DATALOADER.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ Here is an example dataloader configuration file, as `multidatabackend.example.j
"instance_data_dir": "/path/to/data/tree",
"crop": false,
"crop_style": "random|center|corner|face",
"crop_aspect": "square|preserve",
"crop_aspect": "square|preserve|random",
"crop_aspect_buckets": [
0.33, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75
],
"resolution": 1.0,
"resolution_type": "area|pixel",
"minimum_image_size": 1.0,
Expand Down Expand Up @@ -88,7 +91,8 @@ Here is an example dataloader configuration file, as `multidatabackend.example.j
### Cropping Options
- `crop`: Enables or disables image cropping.
- `crop_style`: Selects the cropping style (`random`, `center`, `corner`, `face`).
- `crop_aspect`: Chooses the cropping aspect (`square` or `preserve`).
- `crop_aspect`: Chooses the cropping aspect (`random`, `square` or `preserve`).
- `crop_aspect_buckets`: When `crop_aspect` is set to `random`, a bucket from this list will be selected, so long as the resulting image size would not result more than 20% upscaling.

### `resolution`
- **Area-Based:** Cropping/sizing is done by megapixel count.
Expand All @@ -113,7 +117,7 @@ Here is an example dataloader configuration file, as `multidatabackend.example.j
- Specifies the number of times all samples in the dataset are seen during an epoch. Useful for giving more impact to smaller datasets or maximizing the usage of VAE cache objects.

### `vae_cache_clear_each_epoch`
- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets.
- When enabled, all VAE cache objects are deleted from the filesystem at the end of each dataset repeat cycle. This can be resource-intensive for large datasets, but combined with `crop_style=random` and/or `crop_aspect=random` you'll want this enabled to ensure you sample a full range of crops from each image.

### `ignore_epochs`
- When enabled, this dataset will not hold up the rest of the datasets from completing an epoch. This will inherently make the value for the current epoch inaccurate, as it reflects only the number of times any datasets *without* this flag have completed all of their repeats. The state of the ignored dataset isn't reset upon the next epoch, it is simply ignored. It will eventually run out of samples as a dataset typically does. At that time it will be removed from consideration until the next natural epoch completes.
Expand Down Expand Up @@ -151,6 +155,8 @@ In order, the lines behave as follows:

> ❗Use [regex 101](https://regex101.com) for help debugging and testing regular expressions.
# Advanced techniques

## Parquet caption strategy

> ⚠️ This is an advanced feature, and will not be necessary for most users.
Expand Down Expand Up @@ -203,4 +209,53 @@ In this configuration:
As with other dataloader configurations:

- `prepend_instance_prompt` and `instance_prompt` behave as normal.
- Updating a sample's caption in between training runs will cache the new embed, but not remove the old (orphaned) unit.
- Updating a sample's caption in between training runs will cache the new embed, but not remove the old (orphaned) unit.

## Custom aspect ratio-to-resolution mapping

When SimpleTuner first launches, it generates resolution-specific aspect mapping lists that link a decimal aspect-ratio value to its target pixel size.

It's possible to create a custom mapping that forces the trainer to adjust to your chosen target resolution instead of its own calculations. This functionality is provided at your own risk, as it can obviously cause great harm if configured incorrectly.

To create the custom mapping:

- Create a file that follows the example (below)
- Name the file using the format `aspect_ratio_map-{resolution}.json`
- For a configuration value of `resolution=1.0` / `resolution_type=area`, the mapping filename will be `aspect_resolution_map-1.0.json`
- Place this file in the location specified as `--output_dir`
- This is the same location where your checkpoints and validation images will be found.
- No additional configuration flags or options are required. It will be automatically discovered and used, as long as the name and location are correct.

### Example mapping configuration

This is an example aspect ratio mapping generated by SimpleTuner.

- The dataset had more than 1 million images
- The dataloader `resolution` was set to `1.0`
- The dataloader `resolution_type` was set to `area`

This is the most common configuration, and list of aspect buckets trainable for a 1 megapixel model.

```json
{
"0.07": [320, 4544], "0.38": [640, 1664], "0.88": [960, 1088], "1.92": [1472, 768], "3.11": [1792, 576], "5.71": [2560, 448],
"0.08": [320, 3968], "0.4": [640, 1600], "0.89": [1024, 1152], "2.09": [1472, 704], "3.22": [1856, 576], "6.83": [2624, 384],
"0.1": [320, 3328], "0.41": [704, 1728], "0.94": [1024, 1088], "2.18": [1536, 704], "3.33": [1920, 576], "7.0": [2688, 384],
"0.11": [384, 3520], "0.42": [704, 1664], "1.06": [1088, 1024], "2.27": [1600, 704], "3.44": [1984, 576], "8.0": [3072, 384],
"0.12": [384, 3200], "0.44": [704, 1600], "1.12": [1152, 1024], "2.5": [1600, 640], "3.88": [1984, 512],
"0.14": [384, 2688], "0.46": [704, 1536], "1.13": [1088, 960], "2.6": [1664, 640], "4.0": [2048, 512],
"0.15": [448, 3008], "0.48": [704, 1472], "1.2": [1152, 960], "2.7": [1728, 640], "4.12": [2112, 512],
"0.16": [448, 2816], "0.5": [768, 1536], "1.36": [1216, 896], "2.8": [1792, 640], "4.25": [2176, 512],
"0.19": [448, 2304], "0.52": [768, 1472], "1.46": [1216, 832], "3.11": [1792, 576], "4.38": [2240, 512],
"0.24": [512, 2112], "0.55": [768, 1408], "1.54": [1280, 832], "3.22": [1856, 576], "5.0": [2240, 448],
"0.26": [512, 1984], "0.59": [832, 1408], "1.83": [1408, 768], "3.33": [1920, 576], "5.14": [2304, 448],
"0.29": [576, 1984], "0.62": [832, 1344], "1.92": [1472, 768], "3.44": [1984, 576], "5.71": [2560, 448],
"0.31": [576, 1856], "0.65": [832, 1280], "2.09": [1472, 704], "3.88": [1984, 512], "6.83": [2624, 384],
"0.34": [640, 1856], "0.68": [832, 1216], "2.18": [1536, 704], "4.0": [2048, 512], "7.0": [2688, 384],
"0.38": [640, 1664], "0.74": [896, 1216], "2.27": [1600, 704], "4.12": [2112, 512], "8.0": [3072, 384],
"0.4": [640, 1600], "0.83": [960, 1152], "2.5": [1600, 640], "4.25": [2176, 512],
"0.41": [704, 1728], "0.88": [960, 1088], "2.6": [1664, 640], "4.38": [2240, 512],
"0.42": [704, 1664], "0.89": [1024, 1152], "2.7": [1728, 640], "5.0": [2240, 448],
"0.44": [704, 1600], "0.94": [1024, 1088], "2.8": [1792, 640], "5.14": [2304, 448]
}
```
12 changes: 11 additions & 1 deletion helpers/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,12 +1012,22 @@ def parse_args(input_args=None):
" Default: ddim"
),
)
parser.add_argument(
"--validation_disable_unconditional",
action="store_true",
help=(
"When set, the validation pipeline will not generate unconditional samples."
" This is useful to speed up validations with a single prompt on slower systems, or if you are not"
" interested in unconditional space generations."
),
)
parser.add_argument(
"--disable_compel",
action="store_true",
help=(
"If provided, prompts will be handled using the typical prompt encoding strategy."
"If provided, validation pipeline prompts will be handled using the typical prompt encoding strategy."
" Otherwise, the default behaviour is to use Compel for prompt embed generation."
" Note that the training input text embeds are not generated using Compel, and will be truncated to 77 tokens."
),
)
parser.add_argument(
Expand Down
4 changes: 2 additions & 2 deletions helpers/caching/sdxl_embeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def create_hash(self, caption):
# Reuse the hash object
md5_hash = hashlib.md5()
md5_hash.update(caption.encode())
logger.debug(f"Hashing caption: {caption}")
# logger.debug(f"Hashing caption: {caption}")
result = md5_hash.hexdigest() + hash_format
logger.debug(f"-> {result}")
# logger.debug(f"-> {result}")
return result

def hash_prompt(self, caption):
Expand Down
20 changes: 14 additions & 6 deletions helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,20 @@


def prepare_sample(image: Image.Image, data_backend_id: str, filepath: str):
metadata = StateTracker.get_metadata_by_filepath(filepath)
metadata = StateTracker.get_metadata_by_filepath(
filepath, data_backend_id=data_backend_id
)
logger.debug(
f"Preparing sample {image} from {filepath} with data backend {data_backend_id}. Metadata: {metadata}"
)
training_sample = TrainingSample(
image=image,
data_backend_id=data_backend_id,
image_metadata=metadata,
image_path=filepath,
)
prepared_sample = training_sample.prepare()
logger.debug(f"Prepared: {prepared_sample.to_dict()}")
return (
prepared_sample.image,
prepared_sample.crop_coordinates,
Expand Down Expand Up @@ -443,11 +449,13 @@ def encode_images(self, images, filepaths, load_from_cache=True):
if len(missing_images) > 0 and not self.vae_cache_preprocess:
missing_image_paths = [filepaths[i] for i in missing_images]
logger.debug(f"Missing image paths: {missing_image_paths}")
missing_image_data_generator = self._read_from_storage_concurrently(
missing_image_paths, hide_errors=True
)
# extract images from generator:
missing_image_data = [
self._read_from_storage_concurrently(
missing_image_paths, hide_errors=True
).__next__()[1]
for i in missing_images
retrieved_image_data[1]
for retrieved_image_data in missing_image_data_generator
]
logger.debug(f"Missing image data: {missing_image_data}")
missing_image_pixel_values = self._process_images_in_batch(
Expand Down Expand Up @@ -629,7 +637,7 @@ def _process_images_in_batch(
f"Skipping {filepath} because it does not meet the minimum image size requirement of {self.minimum_image_size}"
)
continue

# image.save(f"test_{os.path.basename(filepath)}.png")
initial_data.append((filepath, image, aspect_bucket))

# Process Pool Execution
Expand Down
23 changes: 22 additions & 1 deletion helpers/data_backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,22 @@ def init_backend_config(backend: dict, args: dict, accelerator) -> dict:
else:
output["config"]["crop"] = args.crop
if "crop_aspect" in backend:
choices = ["square", "preserve", "random"]
if backend.get("crop_aspect", None) not in choices:
raise ValueError(
f"(id={backend['id']}) crop_aspect must be one of {choices}."
)
output["config"]["crop_aspect"] = backend["crop_aspect"]
if (
output["config"]["crop_aspect"] == "random"
and "crop_aspect_buckets" not in backend
):
raise ValueError(
f"(id={backend['id']}) crop_aspect_buckets must be provided when crop_aspect is set to 'random'."
" This should be a list of float values or a list of dictionaries following the format: {'aspect_bucket': float, 'weight': float}."
" The weight represents how likely this bucket is to be chosen, and all weights should add up to 1.0 collectively."
)
output["config"]["crop_aspect_buckets"] = backend.get("crop_aspect_buckets")
else:
output["config"]["crop_aspect"] = "square"
if "crop_style" in backend:
Expand Down Expand Up @@ -391,6 +406,9 @@ def configure_multi_databackend(
data_backend_id=init_backend["id"],
preserve_data_backend_cache=preserve_data_backend_cache,
)
StateTracker.load_aspect_resolution_map(
dataloader_resolution=init_backend["config"]["resolution"],
)

if backend["type"] == "local":
init_backend["data_backend"] = get_local_backend(
Expand Down Expand Up @@ -419,7 +437,10 @@ def configure_multi_databackend(
raise ValueError(f"Unknown data backend type: {backend['type']}")

# Assign a TextEmbeddingCache to this dataset. it might be undefined.
text_embed_id = backend.get("text_embeds", default_text_embed_backend_id)
text_embed_id = backend.get(
"text_embeds",
backend.get("text_embed_cache", default_text_embed_backend_id),
)
if text_embed_id not in text_embed_backends:
raise ValueError(
f"Text embed backend {text_embed_id} not found in data backend config file."
Expand Down
Loading

0 comments on commit 67dc2a8

Please sign in to comment.