Feature/cyclegan v3 #2050

Draft · wants to merge 67 commits into base: master

Changes from 1 commit

Commits (67)
0b54b78
start work on cyclegan, WIP
mcgibbon Aug 17, 2022
3fe295b
still WIP, working on evaluation
mcgibbon Aug 18, 2022
e24e894
autoencoder can overfit, something wrong in evaluation, wip
mcgibbon Aug 19, 2022
abed41e
disable instance normalization, model now trains
mcgibbon Aug 19, 2022
c453017
restore instance normalization by using skip connection
mcgibbon Aug 19, 2022
9fa0c10
linting fixes, updated dumping type hints to use Dumpable instead of …
mcgibbon Aug 22, 2022
9f91b65
Merge branch 'feature/cyclegan' of github.com:ai2cm/fv3net into featu…
mcgibbon Aug 22, 2022
1b5ccd5
remove pdb call
mcgibbon Aug 22, 2022
a5552c3
fix IO test
mcgibbon Aug 22, 2022
8f8451c
move training configuration for pytorch models into the same file
mcgibbon Aug 22, 2022
3f22efe
deduplicate training loop logic
mcgibbon Aug 22, 2022
0aee33e
delete dead code, refactor to use ConvolutionFactoryFactory
mcgibbon Aug 22, 2022
a7f9563
fix failing test for IO
mcgibbon Aug 22, 2022
084a519
add autoencoder to special training types
mcgibbon Aug 22, 2022
f478eeb
fix logic to reset register between tests
mcgibbon Aug 22, 2022
b3fb256
wip cyclegan training code
mcgibbon Aug 24, 2022
c84d0d9
WIP, cyclegan is training but not converging
mcgibbon Aug 25, 2022
0653057
still wip, not converging
mcgibbon Aug 26, 2022
1f3efc1
ignore internal deprecation warnings during tests
mcgibbon Aug 26, 2022
a5948a6
Merge branch 'master' into feature/cyclegan_discriminate
mcgibbon Aug 26, 2022
b7992bf
cyclegan training code might be working, hard to test it
mcgibbon Aug 29, 2022
d1b2ec7
Merge branch 'master' into feature/cyclegan_discriminate
mcgibbon Aug 29, 2022
c4df7c4
fix test broken by merge
mcgibbon Aug 29, 2022
4f590b9
fix test_io.py by reverting to master
mcgibbon Aug 30, 2022
a6284ec
Merge branch 'master' into feature/cyclegan_discriminate
mcgibbon Sep 2, 2022
ed89ba7
working version of cyclegan manual test
mcgibbon Sep 2, 2022
e92aabd
remove non-functional overfitting test
mcgibbon Sep 2, 2022
d988f23
update cyclegan to work with mps acceleration
mcgibbon Sep 2, 2022
e6b2045
re-organize cyclegan code into more modules
mcgibbon Sep 3, 2022
2b17e88
cleanup merge leftover
mcgibbon Sep 3, 2022
5596903
add some docstrings to cyclegan_trainer.py
mcgibbon Sep 7, 2022
d8a3796
update n_convolutions in test to reflect new api behavior
mcgibbon Sep 7, 2022
8f18ccd
improve documentation for cyclegan model and training
mcgibbon Sep 7, 2022
c0c5d90
update shapes to their expected final form at each scope, document them
mcgibbon Sep 7, 2022
eb22eda
add test of StatsCollector
mcgibbon Sep 7, 2022
e88b482
remove unused output_scalers argument
mcgibbon Sep 7, 2022
c8f8d1c
add cyclegan symbols to fv3fit.pytorch namespace
mcgibbon Sep 7, 2022
233224c
add missing public symbols to fv3fit.data
mcgibbon Sep 7, 2022
188978a
further elaborate on the shape of data returned by synthetic loader
mcgibbon Sep 7, 2022
4221f55
fix bug in CycleGANLoader where batch_size option is ignored
mcgibbon Sep 7, 2022
01bc917
refactor get_Xy_dataset into more composeable get_Xy_map_fn
mcgibbon Sep 7, 2022
99e5290
update typing on PytorchAutoregressor to use str keys for scalers
mcgibbon Sep 7, 2022
2b121eb
add MPS support to fv3fit.pytorch.system for mac M1s
mcgibbon Sep 7, 2022
7c662a2
fix apparent bug in apply_to_tuple
mcgibbon Sep 7, 2022
19197ef
Merge branch 'feature/cyclegan_discriminate' into feature/cyclegan_v2
mcgibbon Sep 7, 2022
f260b7f
Merge branch 'master' into feature/cyclegan_v2
mcgibbon Sep 7, 2022
c049b85
add cyclegan to SPECIAL_TRAINING_TYPES
mcgibbon Sep 7, 2022
8a006e9
update autoencoder test and training function to channels first, fix …
mcgibbon Sep 7, 2022
fd02e10
updated training to closer match paper, model trained but predictor s…
mcgibbon Sep 12, 2022
37d7230
switch ReplayBuffer for ImagePool implementation by authors
mcgibbon Sep 12, 2022
97b4ab6
uncomment validation statistic outputs
mcgibbon Sep 12, 2022
bd9ae1e
add regtest coverage for cyclegan training
mcgibbon Sep 12, 2022
dc718fa
remove regtest from cyclegan test, non-deterministic
mcgibbon Sep 12, 2022
5d6e025
Merge branch 'master' into feature/cyclegan_v3
mcgibbon Sep 12, 2022
356c9c8
Merge branch 'feature/cyclegan_v2' into feature/cyclegan_v3
mcgibbon Sep 12, 2022
8be0697
state left wip at last session
mcgibbon Sep 16, 2022
ea3b13d
Merge branch 'master' into feature/cyclegan_v3
mcgibbon Sep 16, 2022
d10bed9
delete non-working test
mcgibbon Sep 16, 2022
4177e8c
Merge branch 'master' into feature/cyclegan_v3
mcgibbon Sep 22, 2022
954db8e
add halo updates to cyclegan
mcgibbon Sep 26, 2022
bf9330c
revert unneeded changes from main branch
mcgibbon Sep 26, 2022
f7cf3e9
add in_memory option for training CycleGAN
mcgibbon Sep 26, 2022
f86e2dc
continued attempts to get training working again on mps (not working)
mcgibbon Sep 26, 2022
5118281
updates to project directory, evaluation
mcgibbon Sep 26, 2022
e22e19f
Merge branch 'master' into feature/cyclegan_v3
mcgibbon Sep 26, 2022
682c353
use time-interpolated C384 data to match C48 time grid
mcgibbon Sep 27, 2022
38dae3c
Merge branch 'master' into feature/cyclegan_v3
mcgibbon Oct 20, 2022
deduplicate training loop logic
mcgibbon committed Aug 22, 2022
commit 3f22efe5b2619672d06bb5b3d78295138fb6f7b2
4 changes: 2 additions & 2 deletions external/fv3fit/fv3fit/pytorch/graph/train.py

@@ -108,11 +108,11 @@ def train_graph_model(
     )
 
     if validation_batches is not None:
-        val_state = get_state(data=validation_batches)
+        val_state = get_state(data=validation_batches).unbatch()
     else:
         val_state = None
 
-    train_state = get_state(data=train_batches)
+    train_state = get_state(data=train_batches).unbatch()
 
     train_model = build_model(
         hyperparameters.graph_network, n_state=next(iter(train_state)).shape[-1]
    )
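For readers unfamiliar with the tf.data API used here: Dataset.unbatch splits each batched element back into its individual samples, which is what lets the shared training loop re-batch the state at its own samples_per_batch. A minimal standalone sketch of that behavior (illustrative only, not code from this PR):

import tensorflow as tf

# Two samples of shape (2,), batched into a single element of shape (2, 2).
batched = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]]).batch(2)

# unbatch() splits the batch dimension back out, yielding the two (2,) samples,
# so a downstream loop can call .batch() again with its own batch size.
for sample in batched.unbatch():
    print(sample.numpy())  # [1 2] then [3 4]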
179 changes: 91 additions & 88 deletions external/fv3fit/fv3fit/pytorch/training_loop.py

@@ -1,6 +1,6 @@
 import dataclasses
 import logging
-from typing import Callable, Optional
+from typing import Any, Callable, Optional
 import numpy as np
 import torch
 import tensorflow_datasets as tfds

@@ -52,52 +52,24 @@ def fit_loop(
             optimizer: type of optimizer for the model
             loss_config: configuration of loss function
         """
-        train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size).batch(
-            self.samples_per_batch
+
+        def evaluate_on_batch(batch_state, model):
+            batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE)
+            batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE)
+            loss: torch.Tensor = loss_config.loss(model(batch_input), batch_output)
+            return loss
+
+        return _train_loop(
+            model=train_model,
+            train_data=train_data,
+            validation_data=validation_data,
+            evaluate_on_batch=evaluate_on_batch,
+            optimizer=optimizer,
+            n_epoch=self.n_epoch,
+            shuffle_buffer_size=self.shuffle_buffer_size,
+            samples_per_batch=self.samples_per_batch,
+            validation_batch_size=self.validation_batch_size,
         )
-        train_data = tfds.as_numpy(train_data)
-        if validation_data is not None:
-            if self.validation_batch_size is None:
-                validation_batch_size = sequence_size(validation_data)
-            else:
-                validation_batch_size = self.validation_batch_size
-            validation_data = validation_data.batch(validation_batch_size)
-            validation_data = tfds.as_numpy(validation_data)
-        min_val_loss = np.inf
-        best_weights = None
-        for i in range(1, self.n_epoch + 1):  # loop over the dataset multiple times
-            logger.info("starting epoch %d", i)
-            train_model = train_model.train()
-            train_losses = []
-            for batch_state in train_data:
-                batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE)
-                batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE)
-                optimizer.zero_grad()
-                loss: torch.Tensor = loss_config.loss(
-                    train_model(batch_input), batch_output
-                )
-                loss.backward()
-                train_losses.append(loss)
-                optimizer.step()
-            train_loss = torch.mean(torch.stack(train_losses))
-            logger.info("train loss: %f", train_loss)
-            if validation_data is not None:
-                val_model = train_model.eval()
-                val_losses = []
-                for batch_state in validation_data:
-                    batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE)
-                    batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE)
-                    with torch.no_grad():
-                        val_losses.append(
-                            loss_config.loss(val_model(batch_input), batch_output)
-                        )
-                val_loss = torch.mean(torch.stack(val_losses))
-                logger.info("val_loss %f", val_loss)
-                if val_loss < min_val_loss:
-                    min_val_loss = val_loss
-                    best_weights = train_model.state_dict()
-        if validation_data is not None:
-            train_model.load_state_dict(best_weights)
 
 
 @dataclasses.dataclass

@@ -113,10 +85,11 @@ class AutoregressiveTrainingConfig:
     """
 
     n_epoch: int = 20
-    buffer_size: int = 50_000
+    shuffle_buffer_size: int = 50_000
     samples_per_batch: int = 1
     save_path: str = "weight.pt"
     multistep: int = 1
+    validation_batch_size: Optional[int] = None
 
     def fit_loop(
         self,

@@ -136,48 +109,28 @@ def fit_loop(
             optimizer: type of optimizer for the model
             loss_config: configuration of loss function
         """
-        train_data = (
-            train_data.unbatch()
-            .shuffle(buffer_size=self.buffer_size)
-            .batch(self.samples_per_batch)
-        )
-        train_data = tfds.as_numpy(train_data)
-        if validation_data is not None:
-            validation_data = validation_data.unbatch()
-            n_validation = sequence_size(validation_data)
-            validation_state = (
-                torch.as_tensor(next(iter(validation_data.batch(n_validation))).numpy())
-                .float()
-                .to(DEVICE)
+
+        def evaluate_on_batch(batch_state, model):
+            batch_state = torch.as_tensor(batch_state).float().to(DEVICE)
+            loss: torch.Tensor = evaluate_model(
+                batch_state=batch_state,
+                model=train_model,
+                multistep=self.multistep,
+                loss=loss_config.loss,
             )
-            min_val_loss = np.inf
-        else:
-            validation_state = None
-        for _ in range(1, self.n_epoch + 1):  # loop over the dataset multiple times
-            train_model = train_model.train()
-            for batch_state in train_data:
-                batch_state = torch.as_tensor(batch_state).float().to(DEVICE)
-                optimizer.zero_grad()
-                loss = evaluate_model(
-                    batch_state=batch_state,
-                    model=train_model,
-                    multistep=self.multistep,
-                    loss=loss_config.loss,
-                )
-                loss.backward()
-                optimizer.step()
-            if validation_state is not None:
-                val_model = train_model.eval()
-                with torch.no_grad():
-                    val_loss = evaluate_model(
-                        validation_state,
-                        model=val_model,
-                        multistep=self.multistep,
-                        loss=loss_config.loss,
-                    )
-                if val_loss < min_val_loss:
-                    min_val_loss = val_loss
-                    torch.save(train_model.state_dict(), self.save_path)
+            return loss
+
+        return _train_loop(
+            model=train_model,
+            train_data=train_data,
+            validation_data=validation_data,
+            evaluate_on_batch=evaluate_on_batch,
+            optimizer=optimizer,
+            n_epoch=self.n_epoch,
+            shuffle_buffer_size=self.shuffle_buffer_size,
+            samples_per_batch=self.samples_per_batch,
+            validation_batch_size=self.validation_batch_size,
+        )
 
 
 def evaluate_model(

@@ -193,3 +146,53 @@
         target_state = batch_state[:, step + 1, :]
         total_loss += loss(state_snapshot, target_state)
     return total_loss / multistep
+
+
+def _train_loop(
+    model: torch.nn.Module,
+    train_data: tf.data.Dataset,
+    validation_data: tf.data.Dataset,
+    evaluate_on_batch: Callable[[Any, torch.nn.Module], torch.Tensor],
+    optimizer: torch.optim.Optimizer,
+    n_epoch: int,
+    shuffle_buffer_size: int,
+    samples_per_batch: int,
+    validation_batch_size: Optional[int] = None,
+):
+
+    train_data = train_data.shuffle(buffer_size=shuffle_buffer_size).batch(
+        samples_per_batch
+    )
+    train_data = tfds.as_numpy(train_data)
+    if validation_data is not None:
+        if validation_batch_size is None:
+            validation_batch_size = sequence_size(validation_data)
+        validation_data = validation_data.batch(validation_batch_size)
+        validation_data = tfds.as_numpy(validation_data)
+    min_val_loss = np.inf
+    best_weights = None
+    for i in range(1, n_epoch + 1):  # loop over the dataset multiple times
+        logger.info("starting epoch %d", i)
+        train_model = model.train()
+        train_losses = []
+        for batch_state in train_data:
+            optimizer.zero_grad()
+            loss = evaluate_on_batch(batch_state, train_model)
+            loss.backward()
+            train_losses.append(loss)
+            optimizer.step()
+        train_loss = torch.mean(torch.stack(train_losses))
+        logger.info("train loss: %f", train_loss)
+        if validation_data is not None:
+            val_model = model.eval()
+            val_losses = []
+            for batch_state in validation_data:
+                with torch.no_grad():
+                    val_losses.append(evaluate_on_batch(batch_state, val_model))
+            val_loss = torch.mean(torch.stack(val_losses))
+            logger.info("val_loss %f", val_loss)
+            if val_loss < min_val_loss:
+                min_val_loss = val_loss
+                best_weights = train_model.state_dict()
+    if validation_data is not None:
+        train_model.load_state_dict(best_weights)
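The net effect of this commit is a callback-plus-shared-loop pattern: each fit_loop builds a small evaluate_on_batch closure that knows how to turn one batch into a loss, and _train_loop owns the epoch, shuffling, batching, and validation mechanics. A minimal self-contained sketch of that pattern (hypothetical names and plain PyTorch only, not the fv3fit code):

from typing import Any, Callable, Sequence, Tuple

import torch


def shared_train_loop(
    model: torch.nn.Module,
    batches: Sequence[Tuple[torch.Tensor, torch.Tensor]],
    evaluate_on_batch: Callable[[Any, torch.nn.Module], torch.Tensor],
    optimizer: torch.optim.Optimizer,
    n_epoch: int,
) -> None:
    # The loop mechanics live here exactly once; only the loss computation,
    # passed in as evaluate_on_batch, differs between callers.
    for _ in range(n_epoch):
        model.train()
        for batch in batches:
            optimizer.zero_grad()
            loss = evaluate_on_batch(batch, model)
            loss.backward()
            optimizer.step()


model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batches = [(torch.randn(8, 3), torch.randn(8, 1)) for _ in range(4)]


def evaluate_on_batch(batch, model):
    # Caller-specific logic: unpack one batch and return a scalar loss tensor.
    inputs, targets = batch
    return torch.nn.functional.mse_loss(model(inputs), targets)


shared_train_loop(model, batches, evaluate_on_batch, optimizer, n_epoch=2)

With this shape, the supervised and autoregressive configs differ only in the closure they pass in, which is why the commit can delete both hand-rolled loops.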