Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Val loss improvement #1903

Merged
merged 19 commits into from
Feb 26, 2025
Merged

Val loss improvement #1903

merged 19 commits into from
Feb 26, 2025

Conversation

kohya-ss
Copy link
Owner

@kohya-ss kohya-ss commented Jan 27, 2025

  • train/eval state for the network and the optimizer.
  • stable timesteps
  • stable noise
  • support block swap

@stepfunction83
Copy link

stepfunction83 commented Jan 27, 2025

I love the approach to holding the rng_state aside, setting the validation state using the validation seed, and then restoring the rng_state afterwards. It's much more elegant than tracking the state separately and has no overhead.

@stepfunction83
Copy link

I would also add that once this is put in place, there won't be a need for a moving average to track the validation loss. Using consistent timesteps and noise will make it almost entirely stable, so displaying the mean of the validation loss amounts for each validation run should be all that's needed.

Since the validation set is subject to change if the core dataset changes, I've found tracking the validation loss relative to the initial loss is also helpful to make progress across different training runs comparable.

@rockerBOO
Copy link
Contributor

This looks great!

What are you using for formatting the code? I've been manually formatting but might be easier to align the formatting if I use the same formatting tool.

@kohya-ss
Copy link
Owner Author

I would also add that once this is put in place, there won't be a need for a moving average to track the validation loss. Using consistent timesteps and noise will make it almost entirely stable, so displaying the mean of the validation loss amounts for each validation run should be all that's needed.

That makes sense. Currently, there is a problem viewing logs in TensorBoard, but I would like to at least get the mean of the validation loss to be displayed correctly.

What are you using for formatting the code? I've been manually formatting but might be easier to align the formatting if I use the same formatting tool.

For formatting, I use black with the --line-length=132 option. I would like to at least provide a guideline on this.

@gesen2egee
Copy link
Contributor

gesen2egee commented Jan 28, 2025

It seems that correction for timestep sampling works better (I previously used debiased 1/√SNR, which is similar in meaning).
Perhaps averaging wouldn’t be necessary in this case.

Additionally, I have some thoughts on the args.
For validation_split, how about making it an integer greater than 1 to automatically represent the number of validation samples?
This would be more convenient.

@gesen2egee
Copy link
Contributor

gesen2egee commented Jan 28, 2025

image

https://github.com/[spacepxl/demystifying-sd-finetuning](https://github.com/spacepxl/demystifying-sd-finetuning)
Here's a suggestion for a function that, while not perfect, can normalize the losses across different timesteps to the same magnitude. I believe this approach is more reliable

@kohya-ss
Copy link
Owner Author

Here's a suggestion for a function that, while not perfect, can normalize the losses across different timesteps to the same magnitude. I believe this approach is more reliable

This makes some sense.
However, I believe that users already apply timestep weighting if necessary. For example, min snr gamma or debiased estimation etc.
Also, the validation loss should be the same as the training loss, so I think no additional correction should be necessary.

For validation_split, how about making it an integer greater than 1 to automatically represent the number of validation samples?

Although it means giving multiple meanings to a single setting value, it is worth considering.

@spacepxl
Copy link

spacepxl commented Jan 28, 2025

@gesen2egee you would need a different fit equation for each new model, and it's not really relevant when you make validation fully deterministic. I've tried applying it to training loss and it was extremely harmful.

You can also visualize the raw training loss by plotting it like so:

individualImage

That was done by storing all loss and timestep values, and coloring them by training step. Not sure if there's a way to do that natively in tensorboard/wandb, I did this with matplotlib and just logged it as an image.

@rockerBOO
Copy link
Contributor

File "/mnt/900/builds/sd-scripts/library/train_util.py", line 5968, in get_timesteps
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
RuntimeError: random_ expects 'from' to be less than 'to', but got from=200 >= to=200

In get_timesteps maybe

if min_timestep < max_timestep:
    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
else:
    timesteps = torch.ones(b_size, device="cpu") * min_timestep

I know this isn't completed but I tried it anyways.

@rockerBOO
Copy link
Contributor

rockerBOO commented Jan 29, 2025

Validation dataset: 20. Regularization images (1500 images) + training dataset (16 repeats)

Screenshot 2025-01-29 at 01-08-39 true-bee-59 cyberpunk-boo-kohya-lora – Weights   Biases

Validation loss (dashed) looking pretty smooth from a test.

Screenshot 2025-01-29 at 11-44-07 dashing-sun-66 cyberpunk-boo-kohya-lora – Weights   Biases

Another example with a different loss profile.

@TheDuckingDuck
Copy link

TheDuckingDuck commented Jan 29, 2025

@gesen2egee you would need a different fit equation for each new model, and it's not really relevant when you make validation fully deterministic. I've tried applying it to training loss and it was extremely harmful.

You can also visualize the raw training loss by plotting it like so:

That was done by storing all loss and timestep values, and coloring them by training step. Not sure if there's a way to do that natively in tensorboard/wandb, I did this with matplotlib and just logged it as an image.

To add to this, here's a rough example of how loss per timestep looks like on flux at 1024x1024:
image

@stepfunction83
Copy link

stepfunction83 commented Jan 30, 2025

I just realized that if there are any repeats in the dataset, validation error will not be calculated properly as the images used for validation can still exist in the training data. Guidance should likely be provided that repeats should not be used when validation error is being calculated, otherwise it will not be a useful metric for identifying overfitting.

Since the validation set is now going through a dataloader, it might also be easier to set up a separate directory of validation images to use.

@rockerBOO
Copy link
Contributor

Repeats are done after the split. I also made sure regularization images do not go into the validation dataset when splitting.

@stepfunction83
Copy link

Nice! That's good to hear!

@stepfunction83
Copy link

stepfunction83 commented Jan 30, 2025

I am not able to get a stable loss from this PR:

image

In this, I calculate the average loss on each validation run and graph them out. As you can see, the line is quite volatile, and has been on all of the attempts I've tried so far.

It could be user error. Here are the arguments I'm using:

--validation_seed 43 --validation_split .05 --validate_every_n_steps 100

To calculate the error per cycle:

val_losses = [] # Before training loop
...
current_val_losses = [] # Before validation loop
...
current_loss = loss.detach().item()
current_val_losses.append(current_loss)
...
val_losses.append(sum(current_val_losses)/len(current_val_losses)) #After validation loop

 if is_tracking:
                        logs = {
                            "loss/validation/val_loss": val_losses[-1]
                        }

I also tried printing the individual losses and based on their magnitudes, it looks like the same images are being run each cycle. The timesteps are being set manually, so that's not the variance, which means it must be the noise which is varying from one iteration to the next.

I'll do some additional testing tomorrow to identify if it is the noise and if so, why it would be varying in consecutive runs even though the same seed is used.

@kohya-ss
Copy link
Owner Author

In SDXL training, the random numbers are generated using the device, so this may be the cause.

noise = torch.randn_like(latents, device=latents.device)

Also, torch.manual_seed initializes the random seed for all devices, but rng_state seems to only work for CPU.

https://pytorch.org/docs/stable/generated/torch.manual_seed.html#torch.manual_seed
https://pytorch.org/docs/stable/generated/torch.get_rng_state.html#torch.get_rng_state

Although it would be a breaking change, it would be better to unify the way random numbers are generated.

@stepfunction83
Copy link

In SDXL training, the random numbers are generated using the device, so this may be the cause.

noise = torch.randn_like(latents, device=latents.device)

Also, torch.manual_seed initializes the random seed for all devices, but rng_state seems to only work for CPU.

https://pytorch.org/docs/stable/generated/torch.manual_seed.html#torch.manual_seed https://pytorch.org/docs/stable/generated/torch.get_rng_state.html#torch.get_rng_state

Although it would be a breaking change, it would be better to unify the way random numbers are generated.

I probably should have specified, but I am doing Flux LoRAs in this case, not SDXL. Also, if manual_seed is setting it for all devices, then it should be providing the desired consistency, so it's odd that it's not. I'll do some digging to confirm whether what I'm experiencing is even due to the noise, or whether it's something else.

There is always the possibility of caching the noise and using it in repeat runs, which doesn't interfere with any random state. Based on my initial testing, this didn't seem to have a significant overhead, but isn't quite as pretty as just setting the seed.

@stepfunction83
Copy link

stepfunction83 commented Jan 30, 2025

After some digging, the noise is absolutely consistent over time, but the noisy_model_input is not. I'm trying to understand why it is changing from one iteration to the next.

@stepfunction83
Copy link

stepfunction83 commented Jan 30, 2025

image

That's the culprit. It wasn't the noise changing, it was the latents changing!

Digging deeper:

def crop_target(...):
    p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range  # -range ~ +range までのいい感じの乱数

random.randint(0, range)

It's due to the random cropping coming from random and not torch, so setting the manual_seed doesn't fix the cropping in place.

I'm going to turn it off and do a longer run to confirm.

If you want to test it yourself, you'll need the fix from 8dfb42f to allow flux to run with validation and without latents cached.

Turning it off leads to a perfectly smooth and beautiful loss curve:

image

@stepfunction83
Copy link

stepfunction83 commented Jan 30, 2025

I was able to resolve this by setting random.seed() in addition to setting torch.manual_seed():

rng_state = torch.get_rng_state()
random_seed_state = random.randint(0,100000)
torch.manual_seed(args.validation_seed if args.validation_seed is not None else args.seed)
random.seed(args.validation_seed if args.validation_seed is not None else args.seed)

and then

torch.set_rng_state(rng_state)
random.seed(random_seed_state)

It's a bit hacky, but storing the state for random as random_seed_state = random.randint(0,100000) does get the job done since on repeat runs, it will always play out consistently, though you would want to move this step outside of the validation so that the state for random increments the same way independently of validation running.

Everything's enabled (and I am also training clip as well):
image

And a nice, smooth curve (Note: With an aggressive LR of 0.001 and no graph smoothing or averaging applied)!
image

With a more normal LR of 0.0001, the curve is perfectly smooth:
image

@kohya-ss
Copy link
Owner Author

kohya-ss commented Feb 9, 2025

Fix Python random seed is not set.
Support block swap with validation.

I think it's probably ready to be merged, so please let me know if you have any suggestions.

@rockerBOO
Copy link
Contributor

For the main progress_bar if you do progress_bar.unpause() after validation steps it will reset the timer for main training iterations.

https://tqdm.github.io/docs/tqdm/#unpause

Otherwise has been working for me and block swap is working as well.

@stepfunction83
Copy link

stepfunction83 commented Feb 9, 2025

Not at my computer today to test, but were the fixes from #1900 implemented to fix the issue with running validation when latents are not cached and to fix bar positioning?

@kohya-ss
Copy link
Owner Author

Added unpause, and should work if latents are not cached.

The TensorBoard logs are still broken...

@rockerBOO
Copy link
Contributor

What is the expectation for tensorboard logs? I was mentioning that we may need a different approach for different "experiment trackers". Like tensorboard may be handled differently than wandb, and not have everything through the generic accelerate.log. We do this already for wandb specific features like sample images. Though I want to consider what seems appropriate for the other trackers they support.

For Validation, the logging of the individual validation steps isn't being represented. This is due to them all happening on the same "step", which we are using global step. Iterating through step for each accelerate.log (and not setting steps) would work better to record them all. Meaning accelerate.log(logs) but do not pass the steps. Downside is the graphs will have gaps in their representation.

Example from another trainer. Notice the flat parts.
Screenshot 2025-02-12 at 12-32-19 noble-jazz-84 t5-distill – Weights   Biases
Screenshot 2025-02-12 at 12-32-06 noble-jazz-84 t5-distill – Weights   Biases

I do not know what is best here in the end. I think compromise and making things work for the trackers how we want would be the best. Make tensorboard work how you'd expect (via the specific experiment tracker) and then we can make it work for the other trackers maybe more generically.

@stepfunction83
Copy link

To add on, I'd also like to recommend the metric of relative validation loss compared to the initial loss calculated.

Since the validation set is not constant, it is important to calculate the loss relative to where it started to more accurately track progress on different training runs.

rel_val_loss=curr_val_loss / val_losses[0]

@kohya-ss
Copy link
Owner Author

I do not know what is best here in the end. I think compromise and making things work for the trackers how we want would be the best. Make tensorboard work how you'd expect (via the specific experiment tracker) and then we can make it work for the other trackers maybe more generically.

I think this makes sense. Since I mainly use TensorBoard, I think it might be a good idea to implement it so that it outputs the desired logs for that, and then ask the community for help to see if it outputs correctly for other trackers.

It would be great if accelerate abstracted the log output sufficiently, but if it didn't, we might have to deal with each tracker using if then, which would be a headache.

To add on, I'd also like to recommend the metric of relative validation loss compared to the initial loss calculated.

This is a difficult question. I think that relative values ​​can also be evaluated using absolute value graphs, so I personally think that it would be better to display the raw values.

@kohya-ss
Copy link
Owner Author

Now the log on TensorBoard seems to be more consistent.

image
image

step_current fluctuates greatly depending on changes in timestep, so it may not be very meaningful.

@kohya-ss
Copy link
Owner Author

I believe this can be merge.

@rockerBOO
Copy link
Contributor

Only issue with your approach for using the step in that way is it blocks those logs from being recorded on Wandb at all because they are not incrementing or they increment past the end preventing other logs using accelerate.log(*, step=global_step) from being logged as well. It's pretty annoying that Wandb works in this way but would probably require a pass to get them both working like how you have it presented. Otherwise if you use validation, normally logging would fail after the validation pass because global_step + val step would move the step increment past global_step.

@stepfunction83
Copy link

I would agree with that. Logging the individual validation steps is almost entirely a waste of time. The only metric of interest is the average validation loss.

@kohya-ss
Copy link
Owner Author

Thanks for the details. The wandb specifications are a bit annoying...

In my understanding, by not recording the val loss per step, we can make sure that the step always increases. The log display in TensorBoard will be less intuitive, but I think it will be still understandable.

It's late, but I'll be able to resolve it tomorrow.

@rockerBOO
Copy link
Contributor

rockerBOO commented Feb 20, 2025

The secondary issue was setting the epoch, since the step would be lower than the global step it would not record the epochs as well.

I think they (wandb) generally want you to not set the steps at all, and set the "global_step" or "epoch" in as actual values. Then you can set them as "metric" and then you can use them in the graphs to organize them based on global_step, epoch, and/or val_step.

accelerator.log({"global_step": global_step, "epoch": epoch + 1})

So for tensorboard would need to be like:

    if "tensorboard" in [tracker.name for tracker in accelerator.trackers]:
        tensorboard_tracker = accelerator.get_tracker("tensorboard")
        with tensorboard_tracker.as_default(step=global_step + val_step):
            tf.summary.scalar("loss/validation/step_current": current_loss)

And wandb:

    if "wandb" in [tracker.name for tracker in accelerator.trackers]:
        wandb_tracker = accelerator.get_tracker("wandb")
        wandb_tracker.log(
            {
               "loss/validation/step_current": current_loss,
               "global_step": global_step,
               "epoch": epoch + 1,
               "val_step": val_step
            }
        )

Which then we would add global_step to the metric for wandb

        if "wandb" in [tracker.name for tracker in accelerator.trackers]:
            import wandb 
            wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)

            # Define specific metrics to handle validation and epochs "steps"
            wandb_tracker.define_metric("epoch", hidden=True)
            wandb_tracker.define_metric("val_step", hidden=True)

            wandb_tracker.define_metric("global_step", hidden=True)

I think we could make this happen for both but would need to drop accelerator.log() to be able to do it. Make a new function that handles these distinctions so we can utilize the trackers to their most efficient methods and not work around the lack of functionality of accelerator.log.

@kohya-ss
Copy link
Owner Author

kohya-ss commented Feb 21, 2025

Thank you for the detailed explanation. I was a little surprised that accelerate doesn't wrap anything regarding logs.

I removed the output of val logs for each step.

I believe now the logs will be displayed correctly on TensorBoard and WandB (not tested though.)

It may be necessary to take similar measures for other training scripts in the future.

@rockerBOO
Copy link
Contributor

Looks good to me! I'll try to test it in a little bit. Thanks for doing this!

@heinrichI
Copy link

Val loss not compatible with masked_loss?

@rockerBOO
Copy link
Contributor

Screenshot 2025-02-24 at 14-08-18 easy-mountain-428 women-kohya-lora – Weights   Biases

It is working for me with the 3 loss types, and then I can set the global_step and works. Only 4 epochs and every 32 steps with only 80 steps in the test. Looks good to me to merge.

@kohya-ss kohya-ss merged commit 4965189 into sd3 Feb 26, 2025
2 checks passed
@kohya-ss kohya-ss deleted the val-loss-improvement branch February 26, 2025 12:15
@kohya-ss
Copy link
Owner Author

Sorry for the late merge. Thank you for your detailed reviews!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants