From 0b54b78666acc980d93d2a1a19fe08a00d94a5b0 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 17 Aug 2022 18:08:21 +0000 Subject: [PATCH 01/55] start work on cyclegan, WIP --- external/fv3fit/fv3fit/data/base.py | 1 + external/fv3fit/fv3fit/data/tfdataset.py | 47 +++++++- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 69 +++++++++++ .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 107 ++++++++++++++++++ 4 files changed, 221 insertions(+), 3 deletions(-) create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/network.py create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/train.py diff --git a/external/fv3fit/fv3fit/data/base.py b/external/fv3fit/fv3fit/data/base.py index 8085efaa52..d9f31c3770 100644 --- a/external/fv3fit/fv3fit/data/base.py +++ b/external/fv3fit/fv3fit/data/base.py @@ -51,6 +51,7 @@ def tfdataset_loader_from_dict(d: dict) -> TFDatasetLoader: TypeError, ValueError, AttributeError, + RecursionError, dacite.exceptions.MissingValueError, dacite.exceptions.UnexpectedDataError, ): diff --git a/external/fv3fit/fv3fit/data/tfdataset.py b/external/fv3fit/fv3fit/data/tfdataset.py index ca72e32264..7fea830a31 100644 --- a/external/fv3fit/fv3fit/data/tfdataset.py +++ b/external/fv3fit/fv3fit/data/tfdataset.py @@ -1,7 +1,8 @@ +import contextlib import dataclasses -from typing import Mapping, Sequence, Optional +from typing import List, Mapping, Sequence, Optional import tensorflow as tf -from .base import TFDatasetLoader, register_tfdataset_loader +from .base import TFDatasetLoader, register_tfdataset_loader, tfdataset_loader_from_dict import dacite from ..tfdataset import iterable_to_tfdataset import tempfile @@ -56,6 +57,45 @@ def get_n_windows(n_times: int, window_size: int) -> int: return (n_times - 1) // (window_size - 1) +@register_tfdataset_loader +@dataclasses.dataclass +class CycleGANLoader(TFDatasetLoader): + + domain_configs: List[TFDatasetLoader] = dataclasses.field(default_factory=list) + batch_size: int = 1 + + def 
open_tfdataset( + self, local_download_path: Optional[str], variable_names: Sequence[str], + ) -> tf.data.Dataset: + datasets = [] + for config in self.domain_configs: + datasets.append(config.open_tfdataset(local_download_path, variable_names)) + return tf.data.Dataset.zip(tuple(datasets)) + + @classmethod + def from_dict(cls, d: dict) -> "CycleGANLoader": + with prevent_recursion(): + domain_configs = [ + tfdataset_loader_from_dict(domain_config) + for domain_config in d["domain_configs"] + ] + return CycleGANLoader(domain_configs=domain_configs) + + +RECURSING = False + + +@contextlib.contextmanager +def prevent_recursion(): + global RECURSING + if RECURSING: + raise RecursionError("recursion detected") + else: + RECURSING = True + yield + RECURSING = False + + @register_tfdataset_loader @dataclasses.dataclass class WindowedZarrLoader(TFDatasetLoader): @@ -86,6 +126,7 @@ class WindowedZarrLoader(TFDatasetLoader): variable_configs: Mapping[str, VariableConfig] = dataclasses.field( default_factory=dict ) + batch_size: int = 1 n_windows: Optional[int] = None def open_tfdataset( @@ -105,7 +146,7 @@ def open_tfdataset( # if local_download_path is given, cache on disk if local_download_path is not None: tfdataset = tfdataset.cache(local_download_path) - return tfdataset + return tfdataset.batch(self.batch_size) def _convert_to_tfdataset( self, ds: xr.Dataset, variable_names: Sequence[str], diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py new file mode 100644 index 0000000000..2bf1c2c29a --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -0,0 +1,69 @@ +import dataclasses +from typing import Callable +import torch.nn.functional as F +import torch +import torch.nn as nn +from dgl.nn.pytorch import SAGEConv +from ..graph import build_dgl_graph, CubedSphereGraphOperation + + +def relu_activation(): + return nn.ReLU() + + +class ResnetBlock(nn.Module): + def __init__( + self, + 
n_filters: int, + convolution_factory: Callable[[int], nn.Module], + activation_factory: Callable[[], nn.Module] = relu_activation, + ): + super(ResnetBlock, self).__init__() + self.conv_block = nn.Sequential( + convolution_factory(n_filters), + nn.InstanceNorm2d(n_filters), + activation_factory(), + convolution_factory(n_filters), + nn.InstanceNorm2d(n_filters), + ) + + def forward(self, inputs): + g = self.conv_block(inputs) + # skip-connection + g = torch.concat([g, inputs], dim=-1) + return g + + +class ConvBlock(nn.Module): + def __init__(self): + pass + + def forward(self, inputs): + pass + + +class CycleGenerator(nn.Module): + def __init__(self, config, n_features_in: int, n_features_out: int): + super(CycleGenerator, self).__init__() + self.conv1 = CubedSphereGraphOperation( + SAGEConv(n_features_in, config.n_hidden, config.aggregator) + ) + self.config = config + + def forward(self, inputs): + for _ in range(self.config.num_blocks): + h1 = self.conv1(inputs) + out = self.config.activation(h1) + return out + + +def define_generator(): + pass + + +def define_discriminator(): + pass + + +def define_composite_model(): + pass diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py new file mode 100644 index 0000000000..f9775bffb9 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -0,0 +1,107 @@ +import tensorflow as tf +import numpy as np +import dataclasses +from fv3fit._shared.training_config import Hyperparameters +from toolz.functoolz import curry +from fv3fit.pytorch.predict import PytorchModel +from fv3fit.pytorch.graph.network import GraphNetwork, GraphNetworkConfig +from fv3fit.pytorch.loss import LossConfig +from fv3fit.pytorch.optimizer import OptimizerConfig +from fv3fit.pytorch.training_loop import TrainingLoopConfig +from fv3fit._shared.scaler import StandardScaler +from ..system import DEVICE + +from fv3fit._shared import register_training_function +from typing import ( 
+ Callable, + List, + Optional, + Sequence, + Set, + Mapping, +) +from fv3fit.tfdataset import select_keys, ensure_nd, apply_to_mapping +from .network import define_generator, define_discriminator, define_composite_model + + +@dataclasses.dataclass +class CycleGANHyperparameters(Hyperparameters): + """ + Args: + state_variables: names of variables to evolve forward in time + optimizer_config: selection of algorithm to be used in gradient descent + graph_network: configuration of graph network + training_loop: configuration of training loop + loss: configuration of loss functions, will be applied separately to + each output variable + """ + + state_variables: List[str] + normalization_fit_samples: int = 50_000 + optimizer_config: OptimizerConfig = dataclasses.field( + default_factory=lambda: OptimizerConfig("AdamW") + ) + graph_network: GraphNetworkConfig = dataclasses.field( + default_factory=lambda: GraphNetworkConfig() + ) + training_loop: TrainingLoopConfig = dataclasses.field( + default_factory=lambda: TrainingLoopConfig() + ) + loss: LossConfig = LossConfig(loss_type="mse") + + @property + def variables(self) -> Set[str]: + return set(self.state_variables) + + +def train( + d_model_A, + d_model_B, + g_model_AtoB, + g_model_BtoA, + c_model_AtoB, + c_model_BtoA, + dataset, + n_batch: int, + n_epochs: int, +): + pass + + +@register_training_function("cyclegan", CycleGANHyperparameters) +def train_cyclegan( + hyperparameters: CycleGANHyperparameters, + train_batches: tf.data.Dataset, + validation_batches: Optional[tf.data.Dataset], +) -> PytorchModel: + # define input shape based on the loaded dataset + image_shape = dataset[0].shape[1:] + # generator: A -> B + g_model_AtoB = define_generator(image_shape) + # generator: B -> A + g_model_BtoA = define_generator(image_shape) + # discriminator: A -> [real/fake] + d_model_A = define_discriminator(image_shape) + # discriminator: B -> [real/fake] + d_model_B = define_discriminator(image_shape) + # composite: A -> B -> 
[real/fake, A] + c_model_AtoB = define_composite_model( + g_model_AtoB, d_model_B, g_model_BtoA, image_shape + ) + # composite: B -> A -> [real/fake, B] + c_model_BtoA = define_composite_model( + g_model_BtoA, d_model_A, g_model_AtoB, image_shape + ) + + # train models + train( + d_model_A, + d_model_B, + g_model_AtoB, + g_model_BtoA, + c_model_AtoB, + c_model_BtoA, + dataset, + n_batch=1, + n_epochs=n_epochs, + ) From 3fe295b6914c9327a43ac89a671ac5119bd6cda3 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 18 Aug 2022 20:18:56 +0000 Subject: [PATCH 02/55] still WIP, working on evaluation --- external/fv3fit/fv3fit/data/base.py | 1 + external/fv3fit/fv3fit/data/tfdataset.py | 43 ++- external/fv3fit/fv3fit/pytorch/__init__.py | 4 +- .../fv3fit/pytorch/cyclegan/__init__.py | 1 + .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 248 +++++++++++++--- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 276 +++++++++++++----- external/fv3fit/fv3fit/pytorch/graph/train.py | 6 +- external/fv3fit/fv3fit/pytorch/loss.py | 8 +- external/fv3fit/fv3fit/pytorch/predict.py | 226 +++++++++++--- .../fv3fit/fv3fit/pytorch/training_loop.py | 46 +-- external/fv3fit/fv3fit/tfdataset.py | 39 +++ external/fv3fit/fv3fit/train.py | 2 +- external/fv3fit/tests/pytorch/test_model.py | 6 +- .../fv3fit/tests/training/test_autoencoder.py | 125 ++++++++ 14 files changed, 832 insertions(+), 199 deletions(-) create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py create mode 100644 external/fv3fit/tests/training/test_autoencoder.py diff --git a/external/fv3fit/fv3fit/data/base.py b/external/fv3fit/fv3fit/data/base.py index d9f31c3770..0dd0477863 100644 --- a/external/fv3fit/fv3fit/data/base.py +++ b/external/fv3fit/fv3fit/data/base.py @@ -52,6 +52,7 @@ def tfdataset_loader_from_dict(d: dict) -> TFDatasetLoader: ValueError, AttributeError, RecursionError, + KeyError, dacite.exceptions.MissingValueError, dacite.exceptions.UnexpectedDataError, ): diff --git 
a/external/fv3fit/fv3fit/data/tfdataset.py b/external/fv3fit/fv3fit/data/tfdataset.py index 7fea830a31..86403a7b16 100644 --- a/external/fv3fit/fv3fit/data/tfdataset.py +++ b/external/fv3fit/fv3fit/data/tfdataset.py @@ -4,11 +4,12 @@ import tensorflow as tf from .base import TFDatasetLoader, register_tfdataset_loader, tfdataset_loader_from_dict import dacite -from ..tfdataset import iterable_to_tfdataset +from ..tfdataset import generator_to_tfdataset import tempfile import xarray as xr import numpy as np -from fv3fit._shared.stacking import stack +from fv3fit._shared.stacking import stack, SAMPLE_DIM_NAME +from toolz import curry @dataclasses.dataclass @@ -28,16 +29,22 @@ def __post_init__(self): raise TypeError("times must be one of 'window' or 'start'") def get_record(self, name: str, ds: xr.Dataset, unstacked_dims: Sequence[str]): + for dim in unstacked_dims[:-1]: + if dim not in ds[name].dims: + raise ValueError("variable {} has no dimension {}".format(name, dim)) if self.times == "start": ds = ds.isel(time=0) - data = stack(ds[name], unstacked_dims=unstacked_dims).values + dims = [d for d in unstacked_dims if d in ds[name].dims] + data = ds[name].transpose(*dims).values return data def open_zarr_using_filecache(url: str): cachedir = tempfile.mkdtemp() return xr.open_zarr( - "filecache::" + url, storage_options={"filecache": {"cache_storage": cachedir}} + "filecache::" + url, + storage_options={"filecache": {"cache_storage": cachedir}}, + decode_times=False, ) @@ -160,7 +167,7 @@ def _convert_to_tfdataset( variable name to variable value, and each value is a tensor whose first dimension is the batch dimension """ - tfdataset = iterable_to_tfdataset( + tfdataset = generator_to_tfdataset( records( n_windows=self.n_windows, window_size=self.window_size, @@ -189,14 +196,18 @@ def records( variable_configs: Mapping[str, VariableConfig], unstacked_dims: Sequence[str], ): - n_times = ds.dims["time"] - if n_windows is None: - n_windows = get_n_windows(n_times, 
window_size) - starts = np.random.randint(0, n_times - window_size, n_windows) - for i_start in starts: - record = {} - window_ds = ds.isel(time=range(i_start, i_start + window_size)) - for name in variable_names: - config = variable_configs.get(name, default_variable_config) - record[name] = config.get_record(name, window_ds, unstacked_dims) - yield record + def generator(): + nonlocal n_windows + n_times = ds.dims["time"] + if n_windows is None: + n_windows = get_n_windows(n_times, window_size) + starts = np.random.randint(0, n_times - window_size, n_windows) + for i_start in starts: + record = {} + window_ds = ds.isel(time=range(i_start, i_start + window_size)) + for name in variable_names: + config = variable_configs.get(name, default_variable_config) + record[name] = config.get_record(name, window_ds, unstacked_dims) + yield record + + return generator diff --git a/external/fv3fit/fv3fit/pytorch/__init__.py b/external/fv3fit/fv3fit/pytorch/__init__.py index aeca5579c5..0cf16f3b1b 100644 --- a/external/fv3fit/fv3fit/pytorch/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/__init__.py @@ -1,3 +1,5 @@ from .graph import GraphHyperparameters, train_graph_model, GraphNetworkConfig from .system import DEVICE -from .predict import PytorchModel +from .predict import PytorchAutoregressor, PytorchPredictor +from .cyclegan import train_autoencoder, AutoencoderHyperparameters, GeneratorConfig +from .optimizer import OptimizerConfig diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py new file mode 100644 index 0000000000..4f2f47a491 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py @@ -0,0 +1 @@ +from .train import train_autoencoder, AutoencoderHyperparameters, GeneratorConfig diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index 2bf1c2c29a..6bca76382a 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py 
+++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -1,69 +1,249 @@ import dataclasses -from typing import Callable +from typing import Callable, Literal, Optional, Protocol import torch.nn.functional as F import torch import torch.nn as nn from dgl.nn.pytorch import SAGEConv from ..graph import build_dgl_graph, CubedSphereGraphOperation +from toolz import curry def relu_activation(): return nn.ReLU() +def tanh_activation(): + return nn.Tanh() + + +def leakyrelu_activation(**kwargs): + def factory(): + return nn.LeakyReLU(**kwargs) + + return factory + + +def no_activation(): + return nn.Identity() + + +class ConvolutionFactory(Protocol): + def __call__(self, in_channels: int, out_channels: int) -> nn.Module: + ... + + +class ConvolutionFactoryFactory(Protocol): + def __call__( + self, + kernel_size: int, + padding: int, + stride: int = 1, + stride_type: Literal["regular", "transpose"] = "regular", + bias: bool = True, + ) -> ConvolutionFactory: + ... + + +@curry +def strided_convolution( + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + bias: bool = True, +): + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + + +@curry +def transpose_convolution( + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int, + padding: int, + output_padding: int, + bias: bool = True, +): + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=(padding, padding), + output_padding=output_padding, + bias=bias, + ) + + +@curry +def flat_convolution(in_channels: int, out_channels: int, kernel_size: int, bias=True): + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding="same", + bias=bias, + ) + + class ResnetBlock(nn.Module): def __init__( self, n_filters: int, - convolution_factory: Callable[[int], nn.Module], + convolution_factory: 
ConvolutionFactory, activation_factory: Callable[[], nn.Module] = relu_activation, ): super(ResnetBlock, self).__init__() self.conv_block = nn.Sequential( - convolution_factory(n_filters), - nn.InstanceNorm2d(n_filters), - activation_factory(), - convolution_factory(n_filters), - nn.InstanceNorm2d(n_filters), + ConvBlock( + in_channels=n_filters, + out_channels=n_filters, + convolution_factory=convolution_factory, + activation_factory=activation_factory, + ), + ConvBlock( + in_channels=n_filters, + out_channels=n_filters, + convolution_factory=convolution_factory, + activation_factory=no_activation, + ), ) + self.identity = nn.Identity() def forward(self, inputs): g = self.conv_block(inputs) - # skip-connection - g = torch.concat([g, inputs], dim=-1) - return g + return g + self.identity(inputs) class ConvBlock(nn.Module): - def __init__(self): - pass + def __init__( + self, + in_channels: int, + out_channels: int, + convolution_factory: ConvolutionFactory, + activation_factory: Callable[[], nn.Module] = relu_activation, + ): + super(ConvBlock, self).__init__() + self.conv_block = nn.Sequential( + convolution_factory(in_channels=in_channels, out_channels=out_channels), + nn.InstanceNorm2d(out_channels), + activation_factory(), + ) def forward(self, inputs): - pass - - -class CycleGenerator(nn.Module): - def __init__(self, config, n_features_in: int, n_features_out: int): - super(CycleGenerator, self).__init__() - self.conv1 = CubedSphereGraphOperation( - SAGEConv(n_features_in, config.n_hidden, config.aggregator) + return self.conv_block(inputs) + + +class Discriminator(nn.Module): + def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): + super(Discriminator, self).__init__() + # max_filters = min_filters * 2 ** (n_convolutions - 1), therefore + min_filters = int(max_filters / 2 ** (n_convolutions - 1)) + convs = [ + ConvBlock( + in_channels=in_channels, + out_channels=min_filters, + convolution_factory=strided_convolution( + kernel_size=3, 
stride=2, padding=1 + ), + activation_factory=leakyrelu_activation(alpha=0.2), + ) + ] + for i in range(1, n_convolutions): + convs.append( + ConvBlock( + in_channels=min_filters * 2 ** (i - 1), + out_channels=min_filters * 2 ** i, + convolution_factory=strided_convolution( + kernel_size=3, stride=2, padding=1 + ), + activation_factory=leakyrelu_activation(alpha=0.2), + ) + ) + final_conv = ConvBlock( + in_channels=max_filters, + out_channels=max_filters, + convolution_factory=flat_convolution(kernel_size=3), + activation_factory=leakyrelu_activation(alpha=0.2), + ) + patch_output = ConvBlock( + in_channels=max_filters, + out_channels=1, + convolution_factory=flat_convolution(kernel_size=3), + activation_factory=leakyrelu_activation(alpha=0.2), ) - self.config = config + self._sequential = nn.Sequential(*convs, final_conv, patch_output) def forward(self, inputs): - for _ in range(self.config.num_blocks): - h1 = self.conv1(inputs) - out = self.config.activation(h1) - return out + return self._sequential(inputs) -def define_generator(): - pass - - -def define_discriminator(): - pass - +class Generator(nn.Module): + def __init__( + self, channels: int, n_convolutions: int, n_resnet: int, max_filters: int, + ): + super(Generator, self).__init__() + min_filters = int(max_filters / 2 ** (n_convolutions - 1)) + convs = [ + ConvBlock( + in_channels=channels, + out_channels=min_filters, + convolution_factory=flat_convolution(kernel_size=7), + activation_factory=relu_activation, + ) + ] + for i in range(1, n_convolutions): + convs.append( + ConvBlock( + in_channels=min_filters * 2 ** (i - 1), + out_channels=min_filters * 2 ** i, + convolution_factory=strided_convolution( + kernel_size=3, stride=2, padding=1 + ), + activation_factory=relu_activation, + ) + ) + resnet_blocks = [ + ResnetBlock( + n_filters=max_filters, + convolution_factory=flat_convolution(kernel_size=3), + activation_factory=relu_activation, + ) + for i in range(n_resnet) + ] + transpose_convs = [] + for i in 
range(1, n_convolutions): + transpose_convs.append( + ConvBlock( + in_channels=max_filters // (2 ** (i - 1)), + out_channels=max_filters // (2 ** i), + convolution_factory=transpose_convolution( + kernel_size=3, stride=2, padding=1, output_padding=1 + ), + activation_factory=relu_activation, + ) + ) + out_conv = ConvBlock( + in_channels=min_filters, + out_channels=channels, + convolution_factory=flat_convolution(kernel_size=7), + activation_factory=tanh_activation, + ) + self._sequential = nn.Sequential( + *convs, *resnet_blocks, *transpose_convs, out_conv + ) -def define_composite_model(): - pass + def forward(self, inputs: torch.Tensor): + # data will have channels last, model requires channels first + inputs = inputs.permute(0, 3, 1, 2) + outputs: torch.Tensor = self._sequential(inputs) + return outputs.permute(0, 2, 3, 1) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index f9775bffb9..b08cb6019c 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -1,107 +1,233 @@ +from fv3fit._shared.hyperparameters import Hyperparameters +import dataclasses import tensorflow as tf -import numpy as np import dataclasses -from fv3fit._shared.training_config import Hyperparameters -from toolz.functoolz import curry -from fv3fit.pytorch.predict import PytorchModel -from fv3fit.pytorch.graph.network import GraphNetwork, GraphNetworkConfig +from fv3fit.pytorch.predict import PytorchPredictor from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig -from fv3fit.pytorch.training_loop import TrainingLoopConfig -from fv3fit._shared.scaler import StandardScaler +import tensorflow_datasets as tfds +from fv3fit.tfdataset import sequence_size +import torch +import numpy as np from ..system import DEVICE from fv3fit._shared import register_training_function from typing import ( - Callable, List, Optional, - Sequence, - 
Set, - Mapping, + Tuple, ) -from fv3fit.tfdataset import select_keys, ensure_nd, apply_to_mapping -from .network import define_generator, define_discriminator, define_composite_model +from fv3fit.tfdataset import ensure_nd, apply_to_mapping +from .network import Generator +from fv3fit.pytorch.graph.train import ( + get_scalers, + get_mapping_scale_func, + get_Xy_dataset, +) +from toolz import curry +import logging + +logger = logging.getLogger(__name__) @dataclasses.dataclass -class CycleGANHyperparameters(Hyperparameters): - """ - Args: - state_variables: names of variables to evolve forward in time - optimizer_config: selection of algorithm to be used in gradient descent - graph_network: configuration of graph network - training_loop: configuration of training loop - loss: configuration of loss functions, will be applied separately to - each output variable - """ +class GeneratorConfig: + n_convolutions: int = 3 + n_resnet: int = 3 + max_filters: int = 256 + + +@dataclasses.dataclass +class AutoencoderHyperparameters(Hyperparameters): state_variables: List[str] normalization_fit_samples: int = 50_000 optimizer_config: OptimizerConfig = dataclasses.field( default_factory=lambda: OptimizerConfig("AdamW") ) - graph_network: GraphNetworkConfig = dataclasses.field( - default_factory=lambda: GraphNetworkConfig() + generator: GeneratorConfig = dataclasses.field( + default_factory=lambda: GeneratorConfig() ) - training_loop: TrainingLoopConfig = dataclasses.field( + training_loop: "TrainingLoopConfig" = dataclasses.field( default_factory=lambda: TrainingLoopConfig() ) loss: LossConfig = LossConfig(loss_type="mse") @property - def variables(self) -> Set[str]: - return set(self.state_variables) - - -def train( - d_model_A, - d_model_B, - g_model_AtoB, - g_model_BtoA, - c_model_AtoB, - c_model_BtoA, - dataset, - n_batch: int, - n_epochs: int, -): - pass - - -@register_training_function("cyclegan", CycleGANHyperparameters) -def train_cyclegan( - hyperparameters: 
CycleGANHyperparameters, + def variables(self): + return tuple(self.state_variables) + + +@dataclasses.dataclass +class TrainingLoopConfig: + """ + Attributes: + epochs: number of times to run through the batches when training + shuffle_buffer_size: size of buffer to use when shuffling samples + save_path: name of the file to save the best weights + do_multistep: if True, use multistep loss calculation + multistep: number of steps in multistep loss calculation + validation_batch_size: if given, process validation data in batches + of this size, otherwise process it all at once + """ + + n_epoch: int = 20 + shuffle_buffer_size: int = 10 + samples_per_batch: int = 1 + save_path: str = "weight.pt" + validation_batch_size: Optional[int] = None + + def fit_loop( + self, + train_model: torch.nn.Module, + train_data: tf.data.Dataset, + validation_data: tf.data.Dataset, + optimizer: torch.optim.Optimizer, + loss_config: LossConfig, + ) -> None: + """ + Args: + train_model: pytorch model to train + train_data: training dataset containing samples to be passed to the model, + samples should be tuples with two tensors of shape [sample, time, tile, x, y, z] + validation_data: validation dataset containing samples to be passed + to the model, samples should be tuples with two tensors + of shape [sample, time, tile, x, y, z] + optimizer: type of optimizer for the model + loss_config: configuration of loss function + """ + train_data = ( + flatten_dims(train_data) + .shuffle(buffer_size=self.shuffle_buffer_size) + .batch(self.samples_per_batch) + ) + train_data = tfds.as_numpy(train_data) + if validation_data is not None: + if self.validation_batch_size is None: + validation_batch_size = sequence_size(validation_data) + else: + validation_batch_size = self.validation_batch_size + validation_data = flatten_dims(validation_data).batch(validation_batch_size) + validation_data = tfds.as_numpy(validation_data) + min_val_loss = np.inf + best_weights = None + for i in range(1, 
self.n_epoch + 1): # loop over the dataset multiple times + logger.info("starting epoch %d", i) + train_model = train_model.train() + train_losses = [] + for batch_state in train_data: + batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) + batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) + optimizer.zero_grad() + loss: torch.Tensor = loss_config.loss( + train_model(batch_input), batch_output + ) + loss.backward() + train_losses.append(loss) + optimizer.step() + train_loss = torch.mean(torch.stack(train_losses)) + logger.info("train loss: %f", train_loss) + if validation_data is not None: + val_model = train_model.eval() + val_losses = [] + for batch_state in validation_data: + batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) + batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) + with torch.no_grad(): + val_losses.append( + loss_config.loss(val_model(batch_input), batch_output) + ) + val_loss = torch.mean(torch.stack(val_losses)) + logger.info("val_loss %f", val_loss) + if val_loss < min_val_loss: + min_val_loss = val_loss + best_weights = train_model.state_dict() + if validation_data is not None: + train_model.load_state_dict(best_weights) + + +@curry +def define_noisy_input(data: tf.Tensor, stdev=0.1) -> Tuple[tf.Tensor, tf.Tensor]: + """ + Given data, return a tuple with a noisy version of the data and the original data. 
+ """ + noisy = data + tf.random.normal(shape=tf.shape(data), stddev=stdev) + return (noisy, data) + + +def flatten_dims(dataset: tf.data.Dataset) -> tf.data.Dataset: + """Transform [batch, time, tile, x, y, z] to [sample, x, y, z]""" + return dataset.unbatch().unbatch().unbatch() + + +@register_training_function("autoencoder", AutoencoderHyperparameters) +def train_autoencoder( + hyperparameters: AutoencoderHyperparameters, train_batches: tf.data.Dataset, validation_batches: Optional[tf.data.Dataset], -) -> PytorchModel: - # define input shape based on the loaded dataset - image_shape = dataset[0].shape[1:] - # generator: A -> B - g_model_AtoB = define_generator(image_shape) - # generator: B -> A - g_model_BtoA = define_generator(image_shape) - # discriminator: A -> [real/fake] - d_model_A = define_discriminator(image_shape) - # discriminator: B -> [real/fake] - d_model_B = define_discriminator(image_shape) - # composite: A -> B -> [real/fake, A] - c_model_AtoB = define_composite_model( - g_model_AtoB, d_model_B, g_model_BtoA, image_shape +) -> PytorchPredictor: + """ + Train a denoising autoencoder for cubed sphere data. 
+ + Args: + hyperparameters: configuration for training + train_batches: training data, as a dataset of Mapping[str, tf.Tensor] + where each tensor has dimensions [sample, time, tile, x, y(, z)] + validation_batches: validation data, as a dataset of Mapping[str, tf.Tensor] + where each tensor has dimensions [sample, time, tile, x, y(, z)] + """ + train_batches = train_batches.map(apply_to_mapping(ensure_nd(6))) + sample_batch = next( + iter(train_batches.unbatch().batch(hyperparameters.normalization_fit_samples)) ) - # composite: B -> A -> [real/fake, B] - c_model_BtoA = define_composite_model( - g_model_BtoA, d_model_A, g_model_AtoB, image_shape + + scalers = get_scalers(sample_batch) + mapping_scale_func = get_mapping_scale_func(scalers) + + get_state = curry(get_Xy_dataset)( + state_variables=hyperparameters.state_variables, + n_dims=6, # [batch, time, tile, x, y, z] + mapping_scale_func=mapping_scale_func, ) - # train models - train( - d_model_A, - d_model_B, - g_model_AtoB, - g_model_BtoA, - c_model_AtoB, - c_model_BtoA, - dataset, - n_batch=1, - n_epochs=n_epochs, + if validation_batches is not None: + val_state = get_state(data=validation_batches) + else: + val_state = None + + train_state = get_state(data=train_batches) + + train_model = build_model( + hyperparameters.generator, n_state=next(iter(train_state)).shape[-1] + ) + print(train_model) + optimizer = hyperparameters.optimizer_config + + train_state = train_state.map(define_noisy_input(stdev=0.5)) + if validation_batches is not None: + val_state = val_state.map(define_noisy_input(stdev=0.5)) + + hyperparameters.training_loop.fit_loop( + train_model=train_model, + train_data=train_state, + validation_data=val_state, + optimizer=optimizer.instance(train_model.parameters()), + loss_config=hyperparameters.loss, + ) + + predictor = PytorchPredictor( + input_variables=hyperparameters.state_variables, + output_variables=hyperparameters.state_variables, + model=train_model, + scalers=scalers, + ) + return 
predictor + + +def build_model(config: GeneratorConfig, n_state: int) -> Generator: + return Generator( + channels=n_state, + n_convolutions=config.n_convolutions, + n_resnet=config.n_resnet, + max_filters=config.max_filters, ) diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py b/external/fv3fit/fv3fit/pytorch/graph/train.py index 7cb29358df..ca02982776 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -3,7 +3,7 @@ import dataclasses from fv3fit._shared.training_config import Hyperparameters from toolz.functoolz import curry -from fv3fit.pytorch.predict import PytorchModel +from fv3fit.pytorch.predict import PytorchAutoregressor from fv3fit.pytorch.graph.network import GraphNetwork, GraphNetworkConfig from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig @@ -82,7 +82,7 @@ def train_graph_model( hyperparameters: GraphHyperparameters, train_batches: tf.data.Dataset, validation_batches: Optional[tf.data.Dataset], -) -> PytorchModel: +) -> PytorchAutoregressor: """ Train a graph network. @@ -127,7 +127,7 @@ def train_graph_model( loss_config=hyperparameters.loss, ) - predictor = PytorchModel( + predictor = PytorchAutoregressor( state_variables=hyperparameters.state_variables, model=train_model, scalers=scalers, diff --git a/external/fv3fit/fv3fit/pytorch/loss.py b/external/fv3fit/fv3fit/pytorch/loss.py index 94dfba9025..8c7618ce2e 100644 --- a/external/fv3fit/fv3fit/pytorch/loss.py +++ b/external/fv3fit/fv3fit/pytorch/loss.py @@ -16,6 +16,9 @@ def __post_init__(self): raise ValueError( f"loss_type must be 'mse' or 'mae', got '{self.loss_type}'" ) + + @property + def loss(self) -> torch.nn.Module: """ Returns the loss function described by the configuration. 
@@ -25,8 +28,9 @@ def __post_init__(self): loss: pytorch loss function """ if self.loss_type == "mse": - self.loss = torch.nn.MSELoss() + loss = torch.nn.MSELoss() elif self.loss_type == "mae": - self.loss = torch.nn.L1Loss() + loss = torch.nn.L1Loss() else: raise NotImplementedError(f"loss_type {self.loss_type} is not implemented") + return loss diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index f126384537..3c0461fac1 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -1,15 +1,27 @@ -from fv3fit._shared.predictor import Dumpable, Loadable +from fv3fit._shared.predictor import Dumpable, Loadable, Predictor from .._shared.scaler import StandardScaler import numpy as np import torch import torch.nn as nn import xarray as xr -from typing import Hashable, Iterable, Mapping, Tuple, TypeVar, Type, IO, Protocol +from typing import ( + Any, + Hashable, + Iterable, + Mapping, + Sequence, + Tuple, + TypeVar, + Type, + IO, + Protocol, +) import zipfile from fv3fit.pytorch.system import DEVICE import os import yaml import vcm +from fv3fit._shared import io L = TypeVar("L", bound="BinaryLoadable") @@ -43,7 +55,94 @@ def load_mapping(cls: Type[L], f: IO[bytes]) -> Mapping[Hashable, L]: return {name: cls.load(archive.open(name, "r")) for name in archive.namelist()} -class PytorchModel(Dumpable, Loadable): +@io.register("pytorch_predictor") +class PytorchPredictor(Predictor): + + _MODEL_FILENAME = "weight.pt" + _CONFIG_FILENAME = "config.yaml" + _SCALERS_FILENAME = "scalers.zip" + + def __init__( + self, + input_variables: Iterable[Hashable], + output_variables: Iterable[Hashable], + model: nn.Module, + scalers: Mapping[Hashable, StandardScaler], + ): + """Initialize the predictor + Args: + state_variables: names of state variables + model: pytorch model to wrap + scalers: normalization data for each of the state variables + """ + self.input_variables = input_variables + 
self.output_variables = output_variables + self.model = model + self.scalers = scalers + + def predict(self, X: xr.Dataset) -> xr.Dataset: + """ + Predict an output xarray dataset from an input xarray dataset. + + Note that returned datasets include the initial state of the prediction, + where by definition the model will have perfect skill. + + Args: + X: input dataset + timesteps: number of timesteps to predict + + Returns: + predicted: predicted timeseries data + reference: true timeseries data from the input dataset + """ + tensor = self.pack_to_tensor(X) + with torch.no_grad(): + outputs = self.model(tensor) + predicted = self.unpack_tensor(outputs) + return predicted + + def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: + timeseries_packed = _pack_to_tensor( + ds=X, + timesteps=1, + state_variables=self.input_variables, + scalers=self.scalers, + ) + # dimensions are [window, timestep, tile, x, y, z], + # we must select first timestep and squash all but (x, y, z) into a sample dim + first_times = timeseries_packed[:, 0, :] + return torch.reshape( + first_times, + (first_times.shape[0] * first_times.shape[1],) + + tuple(first_times.shape[2:]), + ) + + def unpack_tensor(self, data: torch.Tensor) -> xr.Dataset: + data = torch.reshape(data, (-1, 6) + tuple(data.shape[1:])) + return _unpack_tensor( + data, + varnames=self.output_variables, + scalers=self.scalers, + dims=["time", "tile", "x", "y", "z"], + ) + + @classmethod + def load(cls, path: str) -> "PytorchAutoregressor": + """Load a serialized model from a directory.""" + return _load_pytorch(cls, path) + + def dump(self, path: str) -> None: + _dump_pytorch(self, path) + + def get_config(self): + return { + "input_variables": self.input_variables, + "output_variables": self.output_variables, + } + + +@io.register("pytorch_autoregressor") +class PytorchAutoregressor(Dumpable, Loadable): _MODEL_FILENAME = "weight.pt" _CONFIG_FILENAME = "config.yaml" @@ -99,25 +198,12 @@ def unpack_tensor(self, data: 
torch.Tensor) -> xr.Dataset: Returns: xarray dataset with values of shape [window, time, tile, x, y, feature] """ - i_feature = 0 - data_vars = {} - all_dims = ["window", "time", "tile", "x", "y", "z"] - for varname in self.state_variables: - mean_value = self.scalers[varname].mean - if mean_value is None: - raise RuntimeError(f"scaler for {varname} has not been fit") - else: - if len(mean_value.shape) > 0 and mean_value.shape[0] > 1: - n_features = mean_value.shape[0] - var_data = data[..., i_feature : i_feature + n_features] - else: - n_features = 1 - var_data = data[..., i_feature] - data_vars[varname] = xr.DataArray( - data=var_data, dims=all_dims[: len(var_data.shape)] - ) - i_feature += n_features - return xr.Dataset(data_vars=data_vars) + return _unpack_tensor( + data, + varnames=self.state_variables, + scalers=self.scalers, + dims=["window", "time", "tile", "x", "y", "z"], + ) def step_model(self, state: torch.Tensor, timesteps: int): """ @@ -160,30 +246,58 @@ def predict(self, X: xr.Dataset, timesteps: int) -> Tuple[xr.Dataset, xr.Dataset return predicted, reference @classmethod - def load(cls, path: str) -> "PytorchModel": + def load(cls, path: str) -> "PytorchAutoregressor": """Load a serialized model from a directory.""" - fs = vcm.get_fs(path) - model_filename = os.path.join(path, cls._MODEL_FILENAME) - with fs.open(model_filename, "rb") as f: - model = torch.load(f) - with fs.open(os.path.join(path, cls._SCALERS_FILENAME), "rb") as f: - scalers = load_mapping(StandardScaler, f) - with open(os.path.join(path, cls._CONFIG_FILENAME), "r") as f: - config = yaml.load(f, Loader=yaml.Loader) - obj = cls( - state_variables=config["state_variables"], model=model, scalers=scalers, - ) - return obj + return _load_pytorch(cls, path) + + def dump(self, path: str) -> None: + _dump_pytorch(self, path) + + def get_config(self) -> Mapping[str, Any]: + return {"state_variables": self.state_variables} + + +class PytorchDumpable(Protocol): + _MODEL_FILENAME: str + 
_SCALERS_FILENAME: str + _CONFIG_FILENAME: str + state_variables: Iterable[Hashable] + scalers: Mapping[Hashable, StandardScaler] + model: torch.nn.Module def dump(self, path: str) -> None: - fs = vcm.get_fs(path) - model_filename = os.path.join(path, self._MODEL_FILENAME) - with fs.open(model_filename, "wb") as f: - torch.save(self.model, model_filename) - with fs.open(os.path.join(path, self._SCALERS_FILENAME), "wb") as f: - dump_mapping(self.scalers, f) - with fs.open(os.path.join(path, self._CONFIG_FILENAME), "w") as f: - f.write(yaml.dump({"state_variables": self.state_variables})) + ... + + def get_config(self) -> Mapping[str, Any]: + """ + Returns additional keyword arguments needed to initialize this object. + """ + ... + + +def _load_pytorch(cls: Type[PytorchDumpable], path: str): + """Load a serialized model from a directory.""" + fs = vcm.get_fs(path) + model_filename = os.path.join(path, cls._MODEL_FILENAME) + with fs.open(model_filename, "rb") as f: + model = torch.load(f) + with fs.open(os.path.join(path, cls._SCALERS_FILENAME), "rb") as f: + scalers = load_mapping(StandardScaler, f) + with open(os.path.join(path, cls._CONFIG_FILENAME), "r") as f: + config = yaml.load(f, Loader=yaml.Loader) + obj = cls(model=model, scalers=scalers, **config) + return obj + + +def _dump_pytorch(obj: PytorchDumpable, path: str) -> None: + fs = vcm.get_fs(path) + model_filename = os.path.join(path, obj._MODEL_FILENAME) + with fs.open(model_filename, "wb") as f: + torch.save(obj.model, model_filename) + with fs.open(os.path.join(path, obj._SCALERS_FILENAME), "wb") as f: + dump_mapping(obj.scalers, f) + with fs.open(os.path.join(path, obj._CONFIG_FILENAME), "w") as f: + f.write(yaml.dump(obj.get_config())) def _pack_to_tensor( @@ -237,3 +351,29 @@ def _pack_to_tensor( all_data.append(data) concatenated_data = np.concatenate(all_data, axis=-1) return torch.as_tensor(concatenated_data).float().to(DEVICE) + + +def _unpack_tensor( + data: torch.Tensor, + varnames: 
Iterable[Hashable], + scalers: Mapping[Hashable, StandardScaler], + dims: Sequence[Hashable], +) -> xr.Dataset: + i_feature = 0 + data_vars = {} + for varname in varnames: + mean_value = scalers[varname].mean + if mean_value is None: + raise RuntimeError(f"scaler for {varname} has not been fit") + else: + if len(mean_value.shape) > 0 and mean_value.shape[0] > 1: + n_features = mean_value.shape[0] + var_data = data[..., i_feature : i_feature + n_features] + else: + n_features = 1 + var_data = data[..., i_feature] + data_vars[varname] = xr.DataArray( + data=var_data, dims=dims[: len(var_data.shape)] + ) + i_feature += n_features + return xr.Dataset(data_vars=data_vars) diff --git a/external/fv3fit/fv3fit/pytorch/training_loop.py b/external/fv3fit/fv3fit/pytorch/training_loop.py index fca9482c41..f164825cbd 100644 --- a/external/fv3fit/fv3fit/pytorch/training_loop.py +++ b/external/fv3fit/fv3fit/pytorch/training_loop.py @@ -1,6 +1,6 @@ import dataclasses import logging -from typing import Callable +from typing import Callable, Optional import numpy as np import torch import tensorflow_datasets as tfds @@ -32,7 +32,7 @@ def fit_loop( self, train_model: torch.nn.Module, train_data: tf.data.Dataset, - validation_data: tf.data.Dataset, + validation_data: Optional[tf.data.Dataset], optimizer: torch.optim.Optimizer, loss_config: LossConfig, ) -> None: @@ -52,14 +52,17 @@ def fit_loop( .batch(self.samples_per_batch) ) train_data = tfds.as_numpy(train_data) - validation_data = validation_data.unbatch() - n_validation = sequence_size(validation_data) - validation_state = ( - torch.as_tensor(next(iter(validation_data.batch(n_validation))).numpy()) - .float() - .to(DEVICE) - ) - min_val_loss = np.inf + if validation_data is not None: + validation_data = validation_data.unbatch() + n_validation = sequence_size(validation_data) + validation_state = ( + torch.as_tensor(next(iter(validation_data.batch(n_validation))).numpy()) + .float() + .to(DEVICE) + ) + min_val_loss = np.inf + 
else: + validation_state = None for _ in range(1, self.n_epoch + 1): # loop over the dataset multiple times train_model = train_model.train() for batch_state in train_data: @@ -73,17 +76,18 @@ def fit_loop( ) loss.backward() optimizer.step() - val_model = train_model.eval() - with torch.no_grad(): - val_loss = evaluate_model( - validation_state, - model=val_model, - multistep=self.multistep, - loss=loss_config.loss, - ) - if val_loss < min_val_loss: - min_val_loss = val_loss - torch.save(train_model.state_dict(), self.save_path) + if validation_state is not None: + val_model = train_model.eval() + with torch.no_grad(): + val_loss = evaluate_model( + validation_state, + model=val_model, + multistep=self.multistep, + loss=loss_config.loss, + ) + if val_loss < min_val_loss: + min_val_loss = val_loss + torch.save(train_model.state_dict(), self.save_path) def evaluate_model( diff --git a/external/fv3fit/fv3fit/tfdataset.py b/external/fv3fit/fv3fit/tfdataset.py index 11e2eaad90..1b51590eb8 100644 --- a/external/fv3fit/fv3fit/tfdataset.py +++ b/external/fv3fit/fv3fit/tfdataset.py @@ -2,6 +2,7 @@ from fv3fit._shared.packer import clip_sample import tensorflow as tf from typing import ( + Generator, Hashable, Iterable, Mapping, @@ -141,6 +142,44 @@ def process_shape(shape): ) +def generator_to_tfdataset( + source: Generator, varying_first_dim: bool = False, +) -> tf.data.Dataset: + """ + A general function to convert from a generator into a tensorflow dataset. 
+ + Args: + source: data items to be included in the dataset + varying_first_dim: if True, the first dimension of the produced tensors + can be of varying length + """ + + try: + sample = next(iter(source())) + except StopIteration: + raise NotImplementedError("can only make tfdataset from non-empty batches") + + # if batches have different numbers of samples, we need to set the dimension size + # to None to indicate the size can be different across generated tensors + if varying_first_dim: + + def process_shape(shape): + return (None,) + shape[1:] + + else: + + def process_shape(shape): + return shape + + return tf.data.Dataset.from_generator( + source, + output_signature={ + key: tf.TensorSpec(process_shape(val.shape), dtype=val.dtype) + for key, val in sample.items() + }, + ) + + def dataset_to_tensor_dict(ds): return {key: tf.convert_to_tensor(val) for key, val in ds.items()} diff --git a/external/fv3fit/fv3fit/train.py b/external/fv3fit/fv3fit/train.py index cf139484fe..11cc193b67 100644 --- a/external/fv3fit/fv3fit/train.py +++ b/external/fv3fit/fv3fit/train.py @@ -181,7 +181,7 @@ def main(args, unknown_args=None): os.makedirs("artifacts", exist_ok=True) logging.basicConfig( level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", + format="%(asctime)s [%(levelname)s] %(filename)s::L%(lineno)d : %(message)s", handlers=[ logging.FileHandler("artifacts/training.log"), logging.StreamHandler(), diff --git a/external/fv3fit/tests/pytorch/test_model.py b/external/fv3fit/tests/pytorch/test_model.py index af4f805f28..71d5d68898 100644 --- a/external/fv3fit/tests/pytorch/test_model.py +++ b/external/fv3fit/tests/pytorch/test_model.py @@ -1,4 +1,4 @@ -from fv3fit.pytorch import PytorchModel +from fv3fit.pytorch import PytorchAutoregressor from fv3fit.pytorch.predict import _pack_to_tensor from torch import nn import fv3fit @@ -28,13 +28,13 @@ def test_pytorch_model_dump_load(tmpdir): state_variables = ["u"] data = np.random.uniform(low=-1, high=1, 
size=(n_samples, n_features)) scaler.fit(data) - model = PytorchModel( + model = PytorchAutoregressor( state_variables=state_variables, model=nn.Linear(n_features, n_features), scalers={"u": scaler}, ) model.dump(str(tmpdir)) - reloaded_model = PytorchModel.load(str(tmpdir)) + reloaded_model = PytorchAutoregressor.load(str(tmpdir)) assert model.state_variables == reloaded_model.state_variables assert same_state(model.model, reloaded_model.model) assert model.scalers == reloaded_model.scalers diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py new file mode 100644 index 0000000000..542a894249 --- /dev/null +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -0,0 +1,125 @@ +import matplotlib + +matplotlib.use("TkAgg") +import numpy as np +import xarray as xr +from typing import Sequence +from fv3fit.pytorch.cyclegan import AutoencoderHyperparameters, train_autoencoder +from fv3fit.pytorch.cyclegan.train import TrainingLoopConfig +from fv3fit.tfdataset import iterable_to_tfdataset +import collections +import os +import fv3fit.pytorch +import matplotlib.pyplot as plt + + +def get_tfdataset(nsamples, nbatch, ntime, nx, ny, nz): + ntile = 6 + + grid_x = np.arange(0, nx, dtype=np.float32) + grid_y = np.arange(0, ny, dtype=np.float32) + grid_x, grid_y = np.broadcast_arrays(grid_x[:, None], grid_y[None, :]) + grid_x = grid_x[None, None, None, :, :, None] + grid_y = grid_y[None, None, None, :, :, None] + + def sample_iterator(): + # creates a timeseries where each time is the negation of time before it + for _ in range(nsamples): + ax = np.random.uniform(0.5, 1.5, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + bx = np.random.uniform(6, 8, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + cx = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + ay = np.random.uniform(0.5, 1.5, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + 
by = np.random.uniform(6, 8, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + cy = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + a = ( + ax + * np.sin(grid_x / (2 * np.pi * bx) + cx) + * ay + * np.sin(grid_y / (2 * np.pi * by) + cy) + ) + start = { + "a": a.astype(np.float32), + "b": -a[..., 0].astype(np.float32), + } + out = {key: [value] for key, value in start.items()} + for _ in range(ntime - 1): + for varname in start.keys(): + out[varname].append(out[varname][-1] * -1.0) + for varname in out: + out[varname] = np.concatenate(out[varname], axis=1) + yield out + + return iterable_to_tfdataset(list(sample_iterator())) + + +def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): + """ + Returns a [time, tile, x, y, z] dataset needed for evaluation. + + Assumes input samples have shape [sample, time, tile, x, y(, z)], will + concatenate samples along the time axis before returning. + """ + data_sequences = collections.defaultdict(list) + for sample in tfdataset: + for name, value in sample.items(): + data_sequences[name].append( + value.numpy().reshape( + [value.shape[0] * value.shape[1]] + list(value.shape[2:]) + ) + ) + data_vars = {} + for name in data_sequences: + data = np.concatenate(data_sequences[name]) + data_vars[name] = xr.DataArray(data, dims=dims[: len(data.shape)]) + return xr.Dataset(data_vars) + + +def test_autoencoder(tmpdir): + matplotlib.use("TkAgg") + # run the test in a temporary directory to delete artifacts when done + os.chdir(tmpdir) + # need a larger nx, ny for the sample data here since we're training + # on whether we can autoencode sin waves, and need to resolve full cycles + sizes = {"nbatch": 2, "ntime": 2, "nx": 32, "ny": 32, "nz": 2} + state_variables = ["a", "b"] + train_tfdataset = get_tfdataset(nsamples=20, **sizes) + val_tfdataset = get_tfdataset(nsamples=3, **sizes) + hyperparameters = AutoencoderHyperparameters( + state_variables=state_variables, + 
generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, n_resnet=3, max_filters=32 + ), + training_loop=TrainingLoopConfig(n_epoch=1, samples_per_batch=2), + optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), + ) + predictor = train_autoencoder(hyperparameters, train_tfdataset, val_tfdataset) + # for test, need one continuous series so we consistently flip sign + test_sizes = {"nbatch": 1, "ntime": 100, "nx": 8, "ny": 8, "nz": 2} + test_xrdataset = tfdataset_to_xr_dataset( + get_tfdataset(nsamples=10, **test_sizes), dims=["time", "tile", "x", "y", "z"] + ) + predicted = predictor.predict(test_xrdataset) + reference = test_xrdataset + fig, ax = plt.subplots(1, 2) + ax[0].imshow(reference["a"][0, 0, :, :, 0].values) + ax[1].imshow(predicted["a"][0, 0, :, :, 0].values) + plt.tight_layout() + plt.show() + bias = predicted.isel(time=1) - reference.isel(time=1) + mean_bias: xr.Dataset = bias.mean() + rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 + for varname in state_variables: + assert np.abs(mean_bias[varname]) < 0.1 + assert rmse[varname] < 0.1 From e24e894039c86c47265394b262e2fd2420e3f338 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 18 Aug 2022 17:37:43 -0700 Subject: [PATCH 03/55] autoencoder can overfit, something wrong in evaluation, wip --- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 141 ++++++++++++++++-- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 9 +- external/fv3fit/fv3fit/pytorch/predict.py | 43 +++--- .../fv3fit/tests/training/test_autoencoder.py | 76 ++++++++-- 4 files changed, 216 insertions(+), 53 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index 6bca76382a..a6cf5b45ec 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -8,8 +8,11 @@ from toolz import curry -def relu_activation(): - return nn.ReLU() +def relu_activation(**kwargs): + def relu_factory(): + return 
nn.ReLU(**kwargs) + + return relu_factory def tanh_activation(): @@ -17,10 +20,10 @@ def tanh_activation(): def leakyrelu_activation(**kwargs): - def factory(): + def leakyrelu_factory(): return nn.LeakyReLU(**kwargs) - return factory + return leakyrelu_factory def no_activation(): @@ -100,7 +103,7 @@ def __init__( self, n_filters: int, convolution_factory: ConvolutionFactory, - activation_factory: Callable[[], nn.Module] = relu_activation, + activation_factory: Callable[[], nn.Module] = relu_activation(), ): super(ResnetBlock, self).__init__() self.conv_block = nn.Sequential( @@ -130,7 +133,7 @@ def __init__( in_channels: int, out_channels: int, convolution_factory: ConvolutionFactory, - activation_factory: Callable[[], nn.Module] = relu_activation, + activation_factory: Callable[[], nn.Module] = relu_activation(), ): super(ConvBlock, self).__init__() self.conv_block = nn.Sequential( @@ -155,7 +158,7 @@ def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): convolution_factory=strided_convolution( kernel_size=3, stride=2, padding=1 ), - activation_factory=leakyrelu_activation(alpha=0.2), + activation_factory=leakyrelu_activation(alpha=0.2,), ) ] for i in range(1, n_convolutions): @@ -166,20 +169,20 @@ def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): convolution_factory=strided_convolution( kernel_size=3, stride=2, padding=1 ), - activation_factory=leakyrelu_activation(alpha=0.2), + activation_factory=leakyrelu_activation(alpha=0.2,), ) ) final_conv = ConvBlock( in_channels=max_filters, out_channels=max_filters, convolution_factory=flat_convolution(kernel_size=3), - activation_factory=leakyrelu_activation(alpha=0.2), + activation_factory=leakyrelu_activation(alpha=0.2,), ) patch_output = ConvBlock( in_channels=max_filters, out_channels=1, convolution_factory=flat_convolution(kernel_size=3), - activation_factory=leakyrelu_activation(alpha=0.2), + activation_factory=leakyrelu_activation(alpha=0.2,), ) self._sequential 
= nn.Sequential(*convs, final_conv, patch_output) @@ -187,7 +190,7 @@ def forward(self, inputs): return self._sequential(inputs) -class Generator(nn.Module): +class SequentialGenerator(nn.Module): def __init__( self, channels: int, n_convolutions: int, n_resnet: int, max_filters: int, ): @@ -198,7 +201,7 @@ def __init__( in_channels=channels, out_channels=min_filters, convolution_factory=flat_convolution(kernel_size=7), - activation_factory=relu_activation, + activation_factory=relu_activation(), ) ] for i in range(1, n_convolutions): @@ -209,14 +212,14 @@ def __init__( convolution_factory=strided_convolution( kernel_size=3, stride=2, padding=1 ), - activation_factory=relu_activation, + activation_factory=relu_activation(), ) ) resnet_blocks = [ ResnetBlock( n_filters=max_filters, convolution_factory=flat_convolution(kernel_size=3), - activation_factory=relu_activation, + activation_factory=relu_activation(), ) for i in range(n_resnet) ] @@ -229,21 +232,127 @@ def __init__( convolution_factory=transpose_convolution( kernel_size=3, stride=2, padding=1, output_padding=1 ), - activation_factory=relu_activation, + activation_factory=relu_activation(), ) ) out_conv = ConvBlock( in_channels=min_filters, out_channels=channels, convolution_factory=flat_convolution(kernel_size=7), - activation_factory=tanh_activation, + activation_factory=no_activation, ) self._sequential = nn.Sequential( *convs, *resnet_blocks, *transpose_convs, out_conv ) + self._identity = nn.Identity() def forward(self, inputs: torch.Tensor): # data will have channels last, model requires channels first + # return self._identity(inputs) inputs = inputs.permute(0, 3, 1, 2) outputs: torch.Tensor = self._sequential(inputs) return outputs.permute(0, 2, 3, 1) + + +class Generator(nn.Module): + def __init__( + self, channels: int, n_convolutions: int, n_resnet: int, max_filters: int, + ): + super(Generator, self).__init__() + + def resnet(in_channels: int): + resnet_blocks = [ + ResnetBlock( + 
n_filters=in_channels, + convolution_factory=flat_convolution(kernel_size=3), + activation_factory=relu_activation(), + ) + for _ in range(n_resnet) + ] + return nn.Sequential(*resnet_blocks) + + def down(in_channels: int, out_channels: int): + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + convolution_factory=strided_convolution( + kernel_size=3, stride=2, padding=1 + ), + activation_factory=relu_activation(), + ) + + def up(in_channels: int, out_channels: int): + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + convolution_factory=transpose_convolution( + kernel_size=3, stride=2, padding=1, output_padding=1 + ), + activation_factory=relu_activation(), + ) + + min_filters = int(max_filters / 2 ** (n_convolutions - 1)) + self._first_conv = ConvBlock( + in_channels=channels, + out_channels=min_filters, + convolution_factory=flat_convolution(kernel_size=3), + activation_factory=relu_activation(), + ) + + self._unet = UNet( + down_factory=down, + up_factory=up, + bottom_factory=resnet, + depth=n_convolutions - 1, + in_channels=min_filters, + ) + + # self._out_conv = ConvBlock( + # in_channels=2 *min_filters, + # out_channels=channels, + # convolution_factory=flat_convolution(kernel_size=3), + # activation_factory=no_activation, + # ) + + self._out_conv = flat_convolution(kernel_size=3)( + in_channels=2 * min_filters, out_channels=channels + ) + + def forward(self, inputs): + # data will have channels last, model requires channels first + # return self._identity(inputs) + inputs = inputs.permute(0, 3, 1, 2) + x = self._first_conv(inputs) + x = self._unet(x) + outputs = self._out_conv(x) + return outputs.permute(0, 2, 3, 1) + + +class UNet(nn.Module): + def __init__( + self, down_factory, up_factory, bottom_factory, depth: int, in_channels: int, + ): + super(UNet, self).__init__() + lower_channels = 2 * in_channels + self._down = down_factory(in_channels=in_channels, out_channels=lower_channels) + self._up = 
up_factory(in_channels=lower_channels, out_channels=in_channels) + if depth == 1: + self._lower = bottom_factory(in_channels=lower_channels) + elif depth <= 0: + raise ValueError(f"depth must be at least 1, got {depth}") + else: + self._lower = UNet( + down_factory, + up_factory, + bottom_factory, + depth=depth - 1, + in_channels=lower_channels, + ) + + def forward(self, inputs): + x = self._down(inputs) + x = self._lower(x) + x = self._up(x) + # skip connection + x = torch.concat([x, inputs], dim=1) + return x diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index b08cb6019c..a5bfb69c83 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -52,6 +52,7 @@ class AutoencoderHyperparameters(Hyperparameters): default_factory=lambda: TrainingLoopConfig() ) loss: LossConfig = LossConfig(loss_type="mse") + noise_amount: float = 0.5 @property def variables(self): @@ -203,9 +204,13 @@ def train_autoencoder( print(train_model) optimizer = hyperparameters.optimizer_config - train_state = train_state.map(define_noisy_input(stdev=0.5)) + train_state = train_state.map( + define_noisy_input(stdev=hyperparameters.noise_amount) + ) if validation_batches is not None: - val_state = val_state.map(define_noisy_input(stdev=0.5)) + val_state = val_state.map( + define_noisy_input(stdev=hyperparameters.noise_amount) + ) hyperparameters.training_loop.fit_loop( train_model=train_model, diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 3c0461fac1..0dc876dfa3 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -99,22 +99,22 @@ def predict(self, X: xr.Dataset) -> xr.Dataset: with torch.no_grad(): outputs = self.model(tensor) predicted = self.unpack_tensor(outputs) + import pdb + + pdb.set_trace() return predicted def pack_to_tensor(self, X: xr.Dataset) -> 
torch.Tensor: - timeseries_packed = _pack_to_tensor( + packed = _pack_to_tensor( ds=X, - timesteps=1, + timesteps=0, state_variables=self.input_variables, scalers=self.scalers, ) - # dimensions are [window, timestep, tile, x, y, z], - # we must select first timestep and squash all but (x, y, z) into a sample dim - first_times = timeseries_packed[:, 0, :] + # dimensions are [time, tile, x, y, z], + # we must combine [time, tile] into one sample dimension return torch.reshape( - first_times, - (first_times.shape[0] * first_times.shape[1],) - + tuple(first_times.shape[2:]), + packed, (packed.shape[0] * packed.shape[1],) + tuple(packed.shape[2:]), ) def unpack_tensor(self, data: torch.Tensor) -> xr.Dataset: @@ -325,10 +325,11 @@ def _pack_to_tensor( expected_dims = ("time", "tile", "x", "y", "z") ds = ds.transpose(*expected_dims) - n_times = ds.time.size - n_windows = int((n_times - 1) // timesteps) - # times need to be evenly divisible into windows - ds = ds.isel(time=slice(None, n_windows * timesteps + 1)) + if timesteps > 0: + n_times = ds.time.size + n_windows = int((n_times - 1) // timesteps) + # times need to be evenly divisible into windows + ds = ds.isel(time=slice(None, n_windows * timesteps + 1)) all_data = [] for varname in state_variables: var_dims = ds[varname].dims @@ -338,13 +339,16 @@ def _pack_to_tensor( ) data = ds[varname].values normalized_data = scalers[varname].normalize(data) - # segment time axis into windows, excluding last time of each window - data = normalized_data[:-1, :].reshape(n_windows, timesteps, *data.shape[1:]) - # append first time of next window to end of each window - end_data = np.concatenate( - [data[1:, :1, :], normalized_data[None, -1:, :]], axis=0 - ) - data = np.concatenate([data, end_data], axis=1) + if timesteps > 0: + # segment time axis into windows, excluding last time of each window + data = normalized_data[:-1, :].reshape( + n_windows, timesteps, *data.shape[1:] + ) + # append first time of next window to end of each 
window + end_data = np.concatenate( + [data[1:, :1, :], normalized_data[None, -1:, :]], axis=0 + ) + data = np.concatenate([data, end_data], axis=1) if "z" not in var_dims: # need a z-axis for concatenation into feature axis data = data[..., np.newaxis] @@ -372,6 +376,7 @@ def _unpack_tensor( else: n_features = 1 var_data = data[..., i_feature] + var_data = scalers[varname].denormalize(var_data) data_vars[varname] = xr.DataArray( data=var_data, dims=dims[: len(var_data.shape)] ) diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index 542a894249..9d6209a78c 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -1,6 +1,3 @@ -import matplotlib - -matplotlib.use("TkAgg") import numpy as np import xarray as xr from typing import Sequence @@ -28,7 +25,7 @@ def sample_iterator(): ax = np.random.uniform(0.5, 1.5, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] - bx = np.random.uniform(6, 8, size=(nbatch, 1, ntile, nz))[ + bx = np.random.uniform(8, 16, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] cx = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ @@ -37,7 +34,7 @@ def sample_iterator(): ay = np.random.uniform(0.5, 1.5, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] - by = np.random.uniform(6, 8, size=(nbatch, 1, ntile, nz))[ + by = np.random.uniform(8, 16, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] cy = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ @@ -45,9 +42,9 @@ def sample_iterator(): ] a = ( ax - * np.sin(grid_x / (2 * np.pi * bx) + cx) + * np.sin(2 * np.pi * grid_x / bx + cx) * ay - * np.sin(grid_y / (2 * np.pi * by) + cy) + * np.sin(2 * np.pi * grid_y / by + cy) ) start = { "a": a.astype(np.float32), @@ -87,12 +84,12 @@ def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): def test_autoencoder(tmpdir): - matplotlib.use("TkAgg") # run the test in a 
temporary directory to delete artifacts when done os.chdir(tmpdir) # need a larger nx, ny for the sample data here since we're training # on whether we can autoencode sin waves, and need to resolve full cycles - sizes = {"nbatch": 2, "ntime": 2, "nx": 32, "ny": 32, "nz": 2} + nx, ny = 32, 32 + sizes = {"nbatch": 2, "ntime": 2, "nx": nx, "ny": ny, "nz": 2} state_variables = ["a", "b"] train_tfdataset = get_tfdataset(nsamples=20, **sizes) val_tfdataset = get_tfdataset(nsamples=3, **sizes) @@ -101,25 +98,72 @@ def test_autoencoder(tmpdir): generator=fv3fit.pytorch.GeneratorConfig( n_convolutions=2, n_resnet=3, max_filters=32 ), - training_loop=TrainingLoopConfig(n_epoch=1, samples_per_batch=2), + training_loop=TrainingLoopConfig(n_epoch=5, samples_per_batch=2), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), ) predictor = train_autoencoder(hyperparameters, train_tfdataset, val_tfdataset) # for test, need one continuous series so we consistently flip sign - test_sizes = {"nbatch": 1, "ntime": 100, "nx": 8, "ny": 8, "nz": 2} + test_sizes = {"nbatch": 1, "ntime": 100, "nx": nx, "ny": ny, "nz": 2} test_xrdataset = tfdataset_to_xr_dataset( get_tfdataset(nsamples=10, **test_sizes), dims=["time", "tile", "x", "y", "z"] ) predicted = predictor.predict(test_xrdataset) reference = test_xrdataset - fig, ax = plt.subplots(1, 2) - ax[0].imshow(reference["a"][0, 0, :, :, 0].values) - ax[1].imshow(predicted["a"][0, 0, :, :, 0].values) - plt.tight_layout() - plt.show() + for i in range(6): + fig, ax = plt.subplots(1, 2) + vmin = reference["a"][0, i, :, :, 0].values.min() + vmax = reference["a"][0, i, :, :, 0].values.max() + ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + plt.tight_layout() + plt.show() bias = predicted.isel(time=1) - reference.isel(time=1) mean_bias: xr.Dataset = bias.mean() rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 for varname in state_variables: 
assert np.abs(mean_bias[varname]) < 0.1 assert rmse[varname] < 0.1 + + +def test_autoencoder_overfit(tmpdir): + # run the test in a temporary directory to delete artifacts when done + os.chdir(tmpdir) + # need a larger nx, ny for the sample data here since we're training + # on whether we can autoencode sin waves, and need to resolve full cycles + nx, ny = 32, 32 + sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "ny": ny, "nz": 2} + state_variables = ["a", "b"] + train_tfdataset = get_tfdataset(nsamples=1, **sizes) + train_tfdataset = train_tfdataset.cache() # needed to keep sample identical + hyperparameters = AutoencoderHyperparameters( + state_variables=state_variables, + generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, n_resnet=1, max_filters=32 + ), + training_loop=TrainingLoopConfig(n_epoch=1000, samples_per_batch=6), + optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), + noise_amount=0.0, + ) + predictor = train_autoencoder( + hyperparameters, train_tfdataset, validation_batches=None + ) + # for test, need one continuous series so we consistently flip sign + test_xrdataset = tfdataset_to_xr_dataset( + train_tfdataset, dims=["time", "tile", "x", "y", "z"] + ) + predicted = predictor.predict(test_xrdataset) + reference = test_xrdataset + for i in range(6): + fig, ax = plt.subplots(1, 2) + vmin = reference["a"][0, i, :, :, 0].values.min() + vmax = reference["a"][0, i, :, :, 0].values.max() + ax[0].imshow(reference["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) + ax[1].imshow(predicted["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) + plt.tight_layout() + plt.show() + bias = predicted - reference + mean_bias: xr.Dataset = bias.mean() + rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 + for varname in state_variables: + assert np.abs(mean_bias[varname]) < 0.1 + assert rmse[varname] < 0.1 From abed41e6d01c3bd75e8170512cbdd6021b55f74b Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 18 Aug 2022 18:06:03 -0700 Subject: [PATCH 
04/55] disable instance normalization, model now trains --- .../fv3fit/fv3fit/_shared/training_config.py | 2 ++ .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 10 +++---- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 4 +-- external/fv3fit/fv3fit/pytorch/predict.py | 4 --- external/fv3fit/tests/pytorch/test_model.py | 20 ++++++++++++- .../fv3fit/tests/training/test_autoencoder.py | 30 +++++++++++-------- 6 files changed, 45 insertions(+), 25 deletions(-) diff --git a/external/fv3fit/fv3fit/_shared/training_config.py b/external/fv3fit/fv3fit/_shared/training_config.py index edf02e2397..02683dd47d 100644 --- a/external/fv3fit/fv3fit/_shared/training_config.py +++ b/external/fv3fit/fv3fit/_shared/training_config.py @@ -29,6 +29,7 @@ import random import warnings import vcm +import torch # TODO: move all keras configs under fv3fit.keras import tensorflow as tf @@ -47,6 +48,7 @@ def set_random_seed(seed: Union[float, int] = 0): np.random.seed(seed + 1) random.seed(seed + 2) tf.random.set_seed(seed + 3) + torch.manual_seed(seed + 4) # TODO: delete this routine by refactoring the tests to no longer depend on it diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index a6cf5b45ec..43237ee3d9 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -138,7 +138,7 @@ def __init__( super(ConvBlock, self).__init__() self.conv_block = nn.Sequential( convolution_factory(in_channels=in_channels, out_channels=out_channels), - nn.InstanceNorm2d(out_channels), + # nn.InstanceNorm2d(out_channels), activation_factory(), ) @@ -158,7 +158,7 @@ def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): convolution_factory=strided_convolution( kernel_size=3, stride=2, padding=1 ), - activation_factory=leakyrelu_activation(alpha=0.2,), + activation_factory=leakyrelu_activation(alpha=0.2), ) ] for i in range(1, n_convolutions): @@ -169,20 
+169,20 @@ def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): convolution_factory=strided_convolution( kernel_size=3, stride=2, padding=1 ), - activation_factory=leakyrelu_activation(alpha=0.2,), + activation_factory=leakyrelu_activation(alpha=0.2), ) ) final_conv = ConvBlock( in_channels=max_filters, out_channels=max_filters, convolution_factory=flat_convolution(kernel_size=3), - activation_factory=leakyrelu_activation(alpha=0.2,), + activation_factory=leakyrelu_activation(alpha=0.2), ) patch_output = ConvBlock( in_channels=max_filters, out_channels=1, convolution_factory=flat_convolution(kernel_size=3), - activation_factory=leakyrelu_activation(alpha=0.2,), + activation_factory=leakyrelu_activation(alpha=0.2), ) self._sequential = nn.Sequential(*convs, final_conv, patch_output) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index a5bfb69c83..5d722beb04 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -143,8 +143,8 @@ def fit_loop( if val_loss < min_val_loss: min_val_loss = val_loss best_weights = train_model.state_dict() - if validation_data is not None: - train_model.load_state_dict(best_weights) + # if validation_data is not None: + # train_model.load_state_dict(best_weights) @curry diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 0dc876dfa3..1c2162c2aa 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -99,9 +99,6 @@ def predict(self, X: xr.Dataset) -> xr.Dataset: with torch.no_grad(): outputs = self.model(tensor) predicted = self.unpack_tensor(outputs) - import pdb - - pdb.set_trace() return predicted def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: @@ -376,7 +373,6 @@ def _unpack_tensor( else: n_features = 1 var_data = data[..., i_feature] - var_data = 
scalers[varname].denormalize(var_data) data_vars[varname] = xr.DataArray( data=var_data, dims=dims[: len(var_data.shape)] ) diff --git a/external/fv3fit/tests/pytorch/test_model.py b/external/fv3fit/tests/pytorch/test_model.py index 71d5d68898..66adc3a79f 100644 --- a/external/fv3fit/tests/pytorch/test_model.py +++ b/external/fv3fit/tests/pytorch/test_model.py @@ -1,4 +1,4 @@ -from fv3fit.pytorch import PytorchAutoregressor +from fv3fit.pytorch import PytorchAutoregressor, PytorchPredictor from fv3fit.pytorch.predict import _pack_to_tensor from torch import nn import fv3fit @@ -72,3 +72,21 @@ def _helper_test_pack_to_tensor_one_var(data): np.testing.assert_almost_equal(tensor[2, -1, :], data[6, :]) # check a full window np.testing.assert_almost_equal(tensor[2, :], data[4:7, :]) + + +def test_predictor_identity(): + ntime, ntiles, nx, ny, nz = 11, 6, 8, 8, 3 + data = np.random.uniform(low=10, high=20, size=(ntime, ntiles, nx, ny, nz)) + ds = xr.Dataset( + data_vars={"u": xr.DataArray(data, dims=["time", "tile", "x", "y", "z"])} + ) + scaler = fv3fit.StandardScaler() + scaler.fit(data) + predictor = PytorchPredictor( + input_variables=["u"], + output_variables=["u"], + model=nn.Identity(), + scalers={"u": scaler}, + ) + prediction = predictor.predict(ds) + np.testing.assert_almost_equal(prediction.u.values, data, decimal=5) diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index 9d6209a78c..e8521f7ee3 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -8,6 +8,7 @@ import os import fv3fit.pytorch import matplotlib.pyplot as plt +import fv3fit def get_tfdataset(nsamples, nbatch, ntime, nx, ny, nz): @@ -84,6 +85,7 @@ def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): def test_autoencoder(tmpdir): + fv3fit.set_random_seed(0) # run the test in a temporary directory to delete artifacts when done os.chdir(tmpdir) # need a 
larger nx, ny for the sample data here since we're training @@ -100,12 +102,13 @@ def test_autoencoder(tmpdir): ), training_loop=TrainingLoopConfig(n_epoch=5, samples_per_batch=2), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), + noise_amount=0.5, ) predictor = train_autoencoder(hyperparameters, train_tfdataset, val_tfdataset) # for test, need one continuous series so we consistently flip sign test_sizes = {"nbatch": 1, "ntime": 100, "nx": nx, "ny": ny, "nz": 2} test_xrdataset = tfdataset_to_xr_dataset( - get_tfdataset(nsamples=10, **test_sizes), dims=["time", "tile", "x", "y", "z"] + get_tfdataset(nsamples=1, **test_sizes), dims=["time", "tile", "x", "y", "z"] ) predicted = predictor.predict(test_xrdataset) reference = test_xrdataset @@ -119,13 +122,14 @@ def test_autoencoder(tmpdir): plt.show() bias = predicted.isel(time=1) - reference.isel(time=1) mean_bias: xr.Dataset = bias.mean() - rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 + mse: xr.Dataset = (bias ** 2).mean() ** 0.5 for varname in state_variables: assert np.abs(mean_bias[varname]) < 0.1 - assert rmse[varname] < 0.1 + assert mse[varname] < 0.1 def test_autoencoder_overfit(tmpdir): + fv3fit.set_random_seed(0) # run the test in a temporary directory to delete artifacts when done os.chdir(tmpdir) # need a larger nx, ny for the sample data here since we're training @@ -140,7 +144,7 @@ def test_autoencoder_overfit(tmpdir): generator=fv3fit.pytorch.GeneratorConfig( n_convolutions=2, n_resnet=1, max_filters=32 ), - training_loop=TrainingLoopConfig(n_epoch=1000, samples_per_batch=6), + training_loop=TrainingLoopConfig(n_epoch=100, samples_per_batch=6), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), noise_amount=0.0, ) @@ -153,17 +157,17 @@ def test_autoencoder_overfit(tmpdir): ) predicted = predictor.predict(test_xrdataset) reference = test_xrdataset - for i in range(6): - fig, ax = plt.subplots(1, 2) - vmin = reference["a"][0, i, :, :, 0].values.min() - vmax = reference["a"][0, i, 
:, :, 0].values.max() - ax[0].imshow(reference["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) - ax[1].imshow(predicted["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) - plt.tight_layout() - plt.show() + # for i in range(6): + # fig, ax = plt.subplots(1, 2) + # vmin = reference["a"][0, i, :, :, 0].values.min() + # vmax = reference["a"][0, i, :, :, 0].values.max() + # ax[0].imshow(reference["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) + # ax[1].imshow(predicted["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) + # plt.tight_layout() + # plt.show() bias = predicted - reference mean_bias: xr.Dataset = bias.mean() - rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 + rmse: xr.Dataset = (bias ** 2).mean() for varname in state_variables: assert np.abs(mean_bias[varname]) < 0.1 assert rmse[varname] < 0.1 From c453017efa8e32cb6a2b5d1fe3129ed842a4e7c9 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 18 Aug 2022 18:28:40 -0700 Subject: [PATCH 05/55] restore instance normalization by using skip connection --- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 20 +++++++++++++------ external/fv3fit/fv3fit/pytorch/predict.py | 6 ++++++ .../fv3fit/tests/training/test_autoencoder.py | 4 ++-- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index 43237ee3d9..edda00fe7e 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -138,7 +138,7 @@ def __init__( super(ConvBlock, self).__init__() self.conv_block = nn.Sequential( convolution_factory(in_channels=in_channels, out_channels=out_channels), - # nn.InstanceNorm2d(out_channels), + nn.InstanceNorm2d(out_channels), activation_factory(), ) @@ -292,11 +292,19 @@ def up(in_channels: int, out_channels: int): ) min_filters = int(max_filters / 2 ** (n_convolutions - 1)) - self._first_conv = ConvBlock( - in_channels=channels, - 
out_channels=min_filters, - convolution_factory=flat_convolution(kernel_size=3), - activation_factory=relu_activation(), + + # self._first_conv = ConvBlock( + # in_channels=channels, + # out_channels=min_filters, + # convolution_factory=flat_convolution(kernel_size=3), + # activation_factory=relu_activation(), + # ) + + self._first_conv = nn.Sequential( + flat_convolution(kernel_size=3)( + in_channels=channels, out_channels=min_filters + ), + relu_activation()(), ) self._unet = UNet( diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 1c2162c2aa..8c4542c73d 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -99,6 +99,9 @@ def predict(self, X: xr.Dataset) -> xr.Dataset: with torch.no_grad(): outputs = self.model(tensor) predicted = self.unpack_tensor(outputs) + import pdb + + pdb.set_trace() return predicted def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: @@ -346,6 +349,8 @@ def _pack_to_tensor( [data[1:, :1, :], normalized_data[None, -1:, :]], axis=0 ) data = np.concatenate([data, end_data], axis=1) + else: + data = normalized_data if "z" not in var_dims: # need a z-axis for concatenation into feature axis data = data[..., np.newaxis] @@ -373,6 +378,7 @@ def _unpack_tensor( else: n_features = 1 var_data = data[..., i_feature] + var_data = scalers[varname].denormalize(var_data) data_vars[varname] = xr.DataArray( data=var_data, dims=dims[: len(var_data.shape)] ) diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index e8521f7ee3..9f15f6ede7 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -23,7 +23,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, ny, nz): def sample_iterator(): # creates a timeseries where each time is the negation of time before it for _ in range(nsamples): - ax = np.random.uniform(0.5, 1.5, 
size=(nbatch, 1, ntile, nz))[ + ax = np.random.uniform(0.1, 1.5, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] bx = np.random.uniform(8, 16, size=(nbatch, 1, ntile, nz))[ @@ -32,7 +32,7 @@ def sample_iterator(): cx = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] - ay = np.random.uniform(0.5, 1.5, size=(nbatch, 1, ntile, nz))[ + ay = np.random.uniform(0.1, 1.5, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] by = np.random.uniform(8, 16, size=(nbatch, 1, ntile, nz))[ From 9fa0c1044e3a2ecfa7fce9a2c1234be567dbcb52 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 18:01:40 +0000 Subject: [PATCH 06/55] linting fixes, updated dumping type hints to use Dumpable instead of Predictor --- external/fv3fit/fv3fit/_shared/io.py | 25 +++++++++++-------- external/fv3fit/fv3fit/data/tfdataset.py | 2 -- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 6 +---- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 4 +-- external/fv3fit/fv3fit/pytorch/predict.py | 11 ++++++-- external/fv3fit/fv3fit/tfdataset.py | 4 +-- external/fv3fit/tests/test_io.py | 4 +-- .../fv3fit/tests/training/test_autoencoder.py | 16 +++++------- 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/external/fv3fit/fv3fit/_shared/io.py b/external/fv3fit/fv3fit/_shared/io.py index 2e30d3f9e9..4e5734c52d 100644 --- a/external/fv3fit/fv3fit/_shared/io.py +++ b/external/fv3fit/fv3fit/_shared/io.py @@ -3,7 +3,7 @@ import fsspec import warnings -from .predictor import Predictor +from .predictor import Predictor, Dumpable from functools import partial _NAME_PATH = "name" @@ -18,8 +18,9 @@ class _Register: def __init__(self) -> None: self._model_types: MutableMapping[str, Type[Predictor]] = {} + self._dump_types: MutableMapping[str, Type[Dumpable]] = {} - def __call__(self, name: str) -> Callable[[Type[Predictor]], Type[Predictor]]: + def __call__(self, name: str) -> Callable[[Type[Dumpable]], Type[Dumpable]]: if name in 
self._model_types: raise ValueError( f"{name} is already registered by {self._model_types[name]}." @@ -27,8 +28,10 @@ def __call__(self, name: str) -> Callable[[Type[Predictor]], Type[Predictor]]: else: return partial(self._register_class, name=name) - def _register_class(self, cls: Type[Predictor], name: str) -> Type[Predictor]: - self._model_types[name] = cls + def _register_class(self, cls: Type[Dumpable], name: str) -> Type[Dumpable]: + if isinstance(cls, Predictor): + self._model_types[name] = cls + self._dump_types[name] = cls return cls def _load_by_name(self, name: str, path: str) -> Predictor: @@ -40,10 +43,10 @@ def _load_by_name(self, name: str, path: str) -> Predictor: ) return self._model_types[name].load(path) - def get_name(self, obj: Predictor) -> str: + def get_dumpable_name(self, obj: Dumpable) -> str: return_name = None name_cls = None - for name, cls in self._model_types.items(): + for name, cls in self._dump_types.items(): if isinstance(obj, cls): # always return the most specific class name / deepest subclass if name_cls is None or issubclass(cls, name_cls): @@ -61,9 +64,9 @@ def get_name(self, obj: Predictor) -> str: def _get_predictor_name(path: str) -> str: return fsspec.get_mapper(path)[_NAME_PATH].decode(_NAME_ENCODING).strip() - def _dump_predictor_name(self, obj: Predictor, path: str): + def _dump_dumpable_name(self, obj: Dumpable, path: str): mapper = fsspec.get_mapper(path) - name = self.get_name(obj) + name = self.get_dumpable_name(obj) mapper[_NAME_PATH] = name.encode(_NAME_ENCODING) def load(self, path: str) -> Predictor: @@ -87,9 +90,9 @@ def load(self, path: str) -> Predictor: else: return self._load_by_name(name, path) - def dump(self, obj: Predictor, path: str): - """Dump a Predictor to a path""" - self._dump_predictor_name(obj, path) + def dump(self, obj: Dumpable, path: str): + """Dump a Dumpable to a path""" + self._dump_dumpable_name(obj, path) obj.dump(path) diff --git a/external/fv3fit/fv3fit/data/tfdataset.py 
b/external/fv3fit/fv3fit/data/tfdataset.py index 86403a7b16..dc9715520d 100644 --- a/external/fv3fit/fv3fit/data/tfdataset.py +++ b/external/fv3fit/fv3fit/data/tfdataset.py @@ -8,8 +8,6 @@ import tempfile import xarray as xr import numpy as np -from fv3fit._shared.stacking import stack, SAMPLE_DIM_NAME -from toolz import curry @dataclasses.dataclass diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index 6bca76382a..6079a9736a 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -1,10 +1,6 @@ -import dataclasses -from typing import Callable, Literal, Optional, Protocol -import torch.nn.functional as F +from typing import Callable, Literal, Protocol import torch import torch.nn as nn -from dgl.nn.pytorch import SAGEConv -from ..graph import build_dgl_graph, CubedSphereGraphOperation from toolz import curry diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index b08cb6019c..23984ce665 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -1,7 +1,6 @@ from fv3fit._shared.hyperparameters import Hyperparameters import dataclasses import tensorflow as tf -import dataclasses from fv3fit.pytorch.predict import PytorchPredictor from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig @@ -89,7 +88,8 @@ def fit_loop( Args: train_model: pytorch model to train train_data: training dataset containing samples to be passed to the model, - samples should be tuples with two tensors of shape [sample, time, tile, x, y, z] + samples should be tuples with two tensors of shape + [sample, time, tile, x, y, z] validation_data: validation dataset containing samples to be passed to the model, samples should be tuples with two tensors of shape [sample, time, tile, x, y, z] diff --git 
a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 3c0461fac1..a3bd31f808 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -127,7 +127,7 @@ def unpack_tensor(self, data: torch.Tensor) -> xr.Dataset: ) @classmethod - def load(cls, path: str) -> "PytorchAutoregressor": + def load(cls, path: str) -> "PytorchPredictor": """Load a serialized model from a directory.""" return _load_pytorch(cls, path) @@ -261,10 +261,17 @@ class PytorchDumpable(Protocol): _MODEL_FILENAME: str _SCALERS_FILENAME: str _CONFIG_FILENAME: str - state_variables: Iterable[Hashable] scalers: Mapping[Hashable, StandardScaler] model: torch.nn.Module + def __init__( + self, + model: torch.nn.Module, + scalers: Mapping[Hashable, StandardScaler], + **kwargs, + ): + ... + def dump(self, path: str) -> None: ... diff --git a/external/fv3fit/fv3fit/tfdataset.py b/external/fv3fit/fv3fit/tfdataset.py index 1b51590eb8..5fe20f0bdd 100644 --- a/external/fv3fit/fv3fit/tfdataset.py +++ b/external/fv3fit/fv3fit/tfdataset.py @@ -143,13 +143,13 @@ def process_shape(shape): def generator_to_tfdataset( - source: Generator, varying_first_dim: bool = False, + source: Callable[[], Generator], varying_first_dim: bool = False, ) -> tf.data.Dataset: """ A general function to convert from a generator into a tensorflow dataset. 
Args: - source: data items to be included in the dataset + source: function which provides data items to be included in the dataset varying_first_dim: if True, the first dimension of the produced tensors can be of varying length """ diff --git a/external/fv3fit/tests/test_io.py b/external/fv3fit/tests/test_io.py index ed3374643f..07b9da577b 100644 --- a/external/fv3fit/tests/test_io.py +++ b/external/fv3fit/tests/test_io.py @@ -12,7 +12,7 @@ class Mock: pass mock = Mock() - assert register.get_name(mock) == "mock" + assert register.get_dumpable_name(mock) == "mock" def test_registering_twice_fails(): @@ -41,7 +41,7 @@ class MockSubclass(Mock): pass mock = MockSubclass() - assert register.get_name(mock) == "mock-subclass" + assert register.get_dumpable_name(mock) == "mock-subclass" def test_register_dump_load(tmpdir): diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index 542a894249..3c3a557046 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -1,6 +1,3 @@ -import matplotlib - -matplotlib.use("TkAgg") import numpy as np import xarray as xr from typing import Sequence @@ -10,7 +7,6 @@ import collections import os import fv3fit.pytorch -import matplotlib.pyplot as plt def get_tfdataset(nsamples, nbatch, ntime, nx, ny, nz): @@ -87,7 +83,6 @@ def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): def test_autoencoder(tmpdir): - matplotlib.use("TkAgg") # run the test in a temporary directory to delete artifacts when done os.chdir(tmpdir) # need a larger nx, ny for the sample data here since we're training @@ -112,11 +107,12 @@ def test_autoencoder(tmpdir): ) predicted = predictor.predict(test_xrdataset) reference = test_xrdataset - fig, ax = plt.subplots(1, 2) - ax[0].imshow(reference["a"][0, 0, :, :, 0].values) - ax[1].imshow(predicted["a"][0, 0, :, :, 0].values) - plt.tight_layout() - plt.show() + # plotting code to 
uncomment if you'd like to manually check the results: + # fig, ax = plt.subplots(1, 2) + # ax[0].imshow(reference["a"][0, 0, :, :, 0].values) + # ax[1].imshow(predicted["a"][0, 0, :, :, 0].values) + # plt.tight_layout() + # plt.show() bias = predicted.isel(time=1) - reference.isel(time=1) mean_bias: xr.Dataset = bias.mean() rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 From 1b5ccd5d189f49d439948d7ac1e8cebd73a1af19 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 18:18:52 +0000 Subject: [PATCH 07/55] remove pdb call --- external/fv3fit/fv3fit/pytorch/predict.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 30c6c6b6ca..f6717f3560 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -99,9 +99,6 @@ def predict(self, X: xr.Dataset) -> xr.Dataset: with torch.no_grad(): outputs = self.model(tensor) predicted = self.unpack_tensor(outputs) - import pdb - - pdb.set_trace() return predicted def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: From a5552c367e70ef5c4d91d3aedf74c5322b0b1832 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 18:33:08 +0000 Subject: [PATCH 08/55] fix IO test --- external/fv3fit/tests/test_io.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/external/fv3fit/tests/test_io.py b/external/fv3fit/tests/test_io.py index 07b9da577b..e1c004232e 100644 --- a/external/fv3fit/tests/test_io.py +++ b/external/fv3fit/tests/test_io.py @@ -2,6 +2,7 @@ import pytest from fv3fit._shared.io import dump, load, register, _Register +import fv3fit def test_Register_get_name(): @@ -51,10 +52,13 @@ def test_register_dump_load(tmpdir): relative_path = "some_path" @register("mock1") - class Mock1: + class Mock1(fv3fit.Predictor): def __init__(self, data): self.data = data + def predict(self, X): + pass + @staticmethod def load(path: str): with 
open(os.path.join(path, relative_path)) as f: From 8f8451ce34d9ba96055ee9fd1595ea2864e38f88 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 18:34:21 +0000 Subject: [PATCH 09/55] move training configuration for pytorch models into the same file --- external/fv3fit/fv3fit/_shared/io.py | 2 +- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 107 ++---------------- .../fv3fit/fv3fit/pytorch/graph/__init__.py | 2 +- external/fv3fit/fv3fit/pytorch/graph/train.py | 6 +- .../fv3fit/fv3fit/pytorch/training_loop.py | 92 ++++++++++++++- .../fv3fit/tests/training/test_autoencoder.py | 6 +- 6 files changed, 106 insertions(+), 109 deletions(-) diff --git a/external/fv3fit/fv3fit/_shared/io.py b/external/fv3fit/fv3fit/_shared/io.py index 4e5734c52d..0742af3a23 100644 --- a/external/fv3fit/fv3fit/_shared/io.py +++ b/external/fv3fit/fv3fit/_shared/io.py @@ -29,7 +29,7 @@ def __call__(self, name: str) -> Callable[[Type[Dumpable]], Type[Dumpable]]: return partial(self._register_class, name=name) def _register_class(self, cls: Type[Dumpable], name: str) -> Type[Dumpable]: - if isinstance(cls, Predictor): + if issubclass(cls, Predictor): self._model_types[name] = cls self._dump_types[name] = cls return cls diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index 08a3adf66a..0dcf56e153 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -4,11 +4,7 @@ from fv3fit.pytorch.predict import PytorchPredictor from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig -import tensorflow_datasets as tfds -from fv3fit.tfdataset import sequence_size -import torch -import numpy as np -from ..system import DEVICE +from fv3fit.pytorch.training_loop import TrainingConfig from fv3fit._shared import register_training_function from typing import ( @@ -47,8 +43,8 @@ class AutoencoderHyperparameters(Hyperparameters): 
generator: GeneratorConfig = dataclasses.field( default_factory=lambda: GeneratorConfig() ) - training_loop: "TrainingLoopConfig" = dataclasses.field( - default_factory=lambda: TrainingLoopConfig() + training_loop: "TrainingConfig" = dataclasses.field( + default_factory=lambda: TrainingConfig() ) loss: LossConfig = LossConfig(loss_type="mse") noise_amount: float = 0.5 @@ -58,95 +54,6 @@ def variables(self): return tuple(self.state_variables) -@dataclasses.dataclass -class TrainingLoopConfig: - """ - Attributes: - epochs: number of times to run through the batches when training - shuffle_buffer_size: size of buffer to use when shuffling samples - save_path: name of the file to save the best weights - do_multistep: if True, use multistep loss calculation - multistep: number of steps in multistep loss calculation - validation_batch_size: if given, process validation data in batches - of this size, otherwise process it all at once - """ - - n_epoch: int = 20 - shuffle_buffer_size: int = 10 - samples_per_batch: int = 1 - save_path: str = "weight.pt" - validation_batch_size: Optional[int] = None - - def fit_loop( - self, - train_model: torch.nn.Module, - train_data: tf.data.Dataset, - validation_data: tf.data.Dataset, - optimizer: torch.optim.Optimizer, - loss_config: LossConfig, - ) -> None: - """ - Args: - train_model: pytorch model to train - train_data: training dataset containing samples to be passed to the model, - samples should be tuples with two tensors of shape - [sample, time, tile, x, y, z] - validation_data: validation dataset containing samples to be passed - to the model, samples should be tuples with two tensors - of shape [sample, time, tile, x, y, z] - optimizer: type of optimizer for the model - loss_config: configuration of loss function - """ - train_data = ( - flatten_dims(train_data) - .shuffle(buffer_size=self.shuffle_buffer_size) - .batch(self.samples_per_batch) - ) - train_data = tfds.as_numpy(train_data) - if validation_data is not None: - if 
self.validation_batch_size is None: - validation_batch_size = sequence_size(validation_data) - else: - validation_batch_size = self.validation_batch_size - validation_data = flatten_dims(validation_data).batch(validation_batch_size) - validation_data = tfds.as_numpy(validation_data) - min_val_loss = np.inf - best_weights = None - for i in range(1, self.n_epoch + 1): # loop over the dataset multiple times - logger.info("starting epoch %d", i) - train_model = train_model.train() - train_losses = [] - for batch_state in train_data: - batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) - batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) - optimizer.zero_grad() - loss: torch.Tensor = loss_config.loss( - train_model(batch_input), batch_output - ) - loss.backward() - train_losses.append(loss) - optimizer.step() - train_loss = torch.mean(torch.stack(train_losses)) - logger.info("train loss: %f", train_loss) - if validation_data is not None: - val_model = train_model.eval() - val_losses = [] - for batch_state in validation_data: - batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) - batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) - with torch.no_grad(): - val_losses.append( - loss_config.loss(val_model(batch_input), batch_output) - ) - val_loss = torch.mean(torch.stack(val_losses)) - logger.info("val_loss %f", val_loss) - if val_loss < min_val_loss: - min_val_loss = val_loss - best_weights = train_model.state_dict() - if validation_data is not None: - train_model.load_state_dict(best_weights) - - @curry def define_noisy_input(data: tf.Tensor, stdev=0.1) -> Tuple[tf.Tensor, tf.Tensor]: """ @@ -204,12 +111,12 @@ def train_autoencoder( print(train_model) optimizer = hyperparameters.optimizer_config - train_state = train_state.map( - define_noisy_input(stdev=hyperparameters.noise_amount) + train_state = flatten_dims( + train_state.map(define_noisy_input(stdev=hyperparameters.noise_amount)) ) if validation_batches is 
not None: - val_state = val_state.map( - define_noisy_input(stdev=hyperparameters.noise_amount) + val_state = flatten_dims( + val_state.map(define_noisy_input(stdev=hyperparameters.noise_amount)) ) hyperparameters.training_loop.fit_loop( diff --git a/external/fv3fit/fv3fit/pytorch/graph/__init__.py b/external/fv3fit/fv3fit/pytorch/graph/__init__.py index 8604808da2..c5abae99e7 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/graph/__init__.py @@ -1,3 +1,3 @@ from .network import GraphNetwork, GraphNetworkConfig, CubedSphereGraphOperation -from .train import train_graph_model, TrainingLoopConfig, GraphHyperparameters +from .train import train_graph_model, AutoregressiveTrainingConfig, GraphHyperparameters from .graph_builder import build_dgl_graph, build_graph diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py b/external/fv3fit/fv3fit/pytorch/graph/train.py index ca02982776..fc72bcc082 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -7,7 +7,7 @@ from fv3fit.pytorch.graph.network import GraphNetwork, GraphNetworkConfig from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig -from fv3fit.pytorch.training_loop import TrainingLoopConfig +from fv3fit.pytorch.training_loop import AutoregressiveTrainingConfig from fv3fit._shared.scaler import StandardScaler from ..system import DEVICE @@ -43,8 +43,8 @@ class GraphHyperparameters(Hyperparameters): graph_network: GraphNetworkConfig = dataclasses.field( default_factory=lambda: GraphNetworkConfig() ) - training_loop: TrainingLoopConfig = dataclasses.field( - default_factory=lambda: TrainingLoopConfig() + training_loop: AutoregressiveTrainingConfig = dataclasses.field( + default_factory=lambda: AutoregressiveTrainingConfig() ) loss: LossConfig = LossConfig(loss_type="mse") diff --git a/external/fv3fit/fv3fit/pytorch/training_loop.py 
b/external/fv3fit/fv3fit/pytorch/training_loop.py index f164825cbd..48f605a7f1 100644 --- a/external/fv3fit/fv3fit/pytorch/training_loop.py +++ b/external/fv3fit/fv3fit/pytorch/training_loop.py @@ -13,8 +13,98 @@ @dataclasses.dataclass -class TrainingLoopConfig: +class TrainingConfig: """ + Training configuration. + + Attributes: + epochs: number of times to run through the batches when training + shuffle_buffer_size: size of buffer to use when shuffling samples + samples_per_batch: number of samples to use in each training batch + save_path: name of the file to save the best weights + validation_batch_size: if given, process validation data in batches + of this size, otherwise process it all at once + """ + + n_epoch: int = 20 + shuffle_buffer_size: int = 10 + samples_per_batch: int = 1 + save_path: str = "weight.pt" + validation_batch_size: Optional[int] = None + + def fit_loop( + self, + train_model: torch.nn.Module, + train_data: tf.data.Dataset, + validation_data: tf.data.Dataset, + optimizer: torch.optim.Optimizer, + loss_config: LossConfig, + ) -> None: + """ + Args: + train_model: pytorch model to train + train_data: training dataset containing samples to be passed to the model, + samples should be tuples with two tensors corresponding to the model + input and output + validation_data: validation dataset containing samples to be passed + to the model, samples should be tuples with two tensors + corresponding to the model input and output + optimizer: type of optimizer for the model + loss_config: configuration of loss function + """ + train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size).batch( + self.samples_per_batch + ) + train_data = tfds.as_numpy(train_data) + if validation_data is not None: + if self.validation_batch_size is None: + validation_batch_size = sequence_size(validation_data) + else: + validation_batch_size = self.validation_batch_size + validation_data = validation_data.batch(validation_batch_size) + validation_data = 
tfds.as_numpy(validation_data) + min_val_loss = np.inf + best_weights = None + for i in range(1, self.n_epoch + 1): # loop over the dataset multiple times + logger.info("starting epoch %d", i) + train_model = train_model.train() + train_losses = [] + for batch_state in train_data: + batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) + batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) + optimizer.zero_grad() + loss: torch.Tensor = loss_config.loss( + train_model(batch_input), batch_output + ) + loss.backward() + train_losses.append(loss) + optimizer.step() + train_loss = torch.mean(torch.stack(train_losses)) + logger.info("train loss: %f", train_loss) + if validation_data is not None: + val_model = train_model.eval() + val_losses = [] + for batch_state in validation_data: + batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) + batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) + with torch.no_grad(): + val_losses.append( + loss_config.loss(val_model(batch_input), batch_output) + ) + val_loss = torch.mean(torch.stack(val_losses)) + logger.info("val_loss %f", val_loss) + if val_loss < min_val_loss: + min_val_loss = val_loss + best_weights = train_model.state_dict() + if validation_data is not None: + train_model.load_state_dict(best_weights) + + +@dataclasses.dataclass +class AutoregressiveTrainingConfig: + """ + Training configuration for autoregressive models. 
+ Attributes: epochs: number of times to run through the batches when training shuffle_buffer_size: size of buffer to use when shuffling samples diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index e1a16ccd90..cc58b2cfa7 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -2,7 +2,7 @@ import xarray as xr from typing import Sequence from fv3fit.pytorch.cyclegan import AutoencoderHyperparameters, train_autoencoder -from fv3fit.pytorch.cyclegan.train import TrainingLoopConfig +from fv3fit.pytorch.cyclegan.train import TrainingConfig from fv3fit.tfdataset import iterable_to_tfdataset import collections import os @@ -99,7 +99,7 @@ def test_autoencoder(tmpdir): generator=fv3fit.pytorch.GeneratorConfig( n_convolutions=2, n_resnet=3, max_filters=32 ), - training_loop=TrainingLoopConfig(n_epoch=5, samples_per_batch=2), + training_loop=TrainingConfig(n_epoch=5, samples_per_batch=2), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), noise_amount=0.5, ) @@ -144,7 +144,7 @@ def test_autoencoder_overfit(tmpdir): generator=fv3fit.pytorch.GeneratorConfig( n_convolutions=2, n_resnet=1, max_filters=32 ), - training_loop=TrainingLoopConfig(n_epoch=100, samples_per_batch=6), + training_loop=TrainingConfig(n_epoch=100, samples_per_batch=6), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), noise_amount=0.0, ) From 3f22efe5b2619672d06bb5b3d78295138fb6f7b2 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 19:10:52 +0000 Subject: [PATCH 10/55] deduplicate training loop logic --- external/fv3fit/fv3fit/pytorch/graph/train.py | 4 +- .../fv3fit/fv3fit/pytorch/training_loop.py | 179 +++++++++--------- 2 files changed, 93 insertions(+), 90 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py b/external/fv3fit/fv3fit/pytorch/graph/train.py index fc72bcc082..fc76d8c7ff 100644 --- 
a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -108,11 +108,11 @@ def train_graph_model( ) if validation_batches is not None: - val_state = get_state(data=validation_batches) + val_state = get_state(data=validation_batches).unbatch() else: val_state = None - train_state = get_state(data=train_batches) + train_state = get_state(data=train_batches).unbatch() train_model = build_model( hyperparameters.graph_network, n_state=next(iter(train_state)).shape[-1] diff --git a/external/fv3fit/fv3fit/pytorch/training_loop.py b/external/fv3fit/fv3fit/pytorch/training_loop.py index 48f605a7f1..a3b84aadc0 100644 --- a/external/fv3fit/fv3fit/pytorch/training_loop.py +++ b/external/fv3fit/fv3fit/pytorch/training_loop.py @@ -1,6 +1,6 @@ import dataclasses import logging -from typing import Callable, Optional +from typing import Any, Callable, Optional import numpy as np import torch import tensorflow_datasets as tfds @@ -52,52 +52,24 @@ def fit_loop( optimizer: type of optimizer for the model loss_config: configuration of loss function """ - train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size).batch( - self.samples_per_batch + + def evaluate_on_batch(batch_state, model): + batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) + batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) + loss: torch.Tensor = loss_config.loss(model(batch_input), batch_output) + return loss + + return _train_loop( + model=train_model, + train_data=train_data, + validation_data=validation_data, + evaluate_on_batch=evaluate_on_batch, + optimizer=optimizer, + n_epoch=self.n_epoch, + shuffle_buffer_size=self.shuffle_buffer_size, + samples_per_batch=self.samples_per_batch, + validation_batch_size=self.validation_batch_size, ) - train_data = tfds.as_numpy(train_data) - if validation_data is not None: - if self.validation_batch_size is None: - validation_batch_size = sequence_size(validation_data) - else: - 
validation_batch_size = self.validation_batch_size - validation_data = validation_data.batch(validation_batch_size) - validation_data = tfds.as_numpy(validation_data) - min_val_loss = np.inf - best_weights = None - for i in range(1, self.n_epoch + 1): # loop over the dataset multiple times - logger.info("starting epoch %d", i) - train_model = train_model.train() - train_losses = [] - for batch_state in train_data: - batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) - batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) - optimizer.zero_grad() - loss: torch.Tensor = loss_config.loss( - train_model(batch_input), batch_output - ) - loss.backward() - train_losses.append(loss) - optimizer.step() - train_loss = torch.mean(torch.stack(train_losses)) - logger.info("train loss: %f", train_loss) - if validation_data is not None: - val_model = train_model.eval() - val_losses = [] - for batch_state in validation_data: - batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) - batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) - with torch.no_grad(): - val_losses.append( - loss_config.loss(val_model(batch_input), batch_output) - ) - val_loss = torch.mean(torch.stack(val_losses)) - logger.info("val_loss %f", val_loss) - if val_loss < min_val_loss: - min_val_loss = val_loss - best_weights = train_model.state_dict() - if validation_data is not None: - train_model.load_state_dict(best_weights) @dataclasses.dataclass @@ -113,10 +85,11 @@ class AutoregressiveTrainingConfig: """ n_epoch: int = 20 - buffer_size: int = 50_000 + shuffle_buffer_size: int = 50_000 samples_per_batch: int = 1 save_path: str = "weight.pt" multistep: int = 1 + validation_batch_size: Optional[int] = None def fit_loop( self, @@ -136,48 +109,28 @@ def fit_loop( optimizer: type of optimizer for the model loss_config: configuration of loss function """ - train_data = ( - train_data.unbatch() - .shuffle(buffer_size=self.buffer_size) - 
.batch(self.samples_per_batch) - ) - train_data = tfds.as_numpy(train_data) - if validation_data is not None: - validation_data = validation_data.unbatch() - n_validation = sequence_size(validation_data) - validation_state = ( - torch.as_tensor(next(iter(validation_data.batch(n_validation))).numpy()) - .float() - .to(DEVICE) + + def evaluate_on_batch(batch_state, model): + batch_state = torch.as_tensor(batch_state).float().to(DEVICE) + loss: torch.Tensor = evaluate_model( + batch_state=batch_state, + model=train_model, + multistep=self.multistep, + loss=loss_config.loss, ) - min_val_loss = np.inf - else: - validation_state = None - for _ in range(1, self.n_epoch + 1): # loop over the dataset multiple times - train_model = train_model.train() - for batch_state in train_data: - batch_state = torch.as_tensor(batch_state).float().to(DEVICE) - optimizer.zero_grad() - loss = evaluate_model( - batch_state=batch_state, - model=train_model, - multistep=self.multistep, - loss=loss_config.loss, - ) - loss.backward() - optimizer.step() - if validation_state is not None: - val_model = train_model.eval() - with torch.no_grad(): - val_loss = evaluate_model( - validation_state, - model=val_model, - multistep=self.multistep, - loss=loss_config.loss, - ) - if val_loss < min_val_loss: - min_val_loss = val_loss - torch.save(train_model.state_dict(), self.save_path) + return loss + + return _train_loop( + model=train_model, + train_data=train_data, + validation_data=validation_data, + evaluate_on_batch=evaluate_on_batch, + optimizer=optimizer, + n_epoch=self.n_epoch, + shuffle_buffer_size=self.shuffle_buffer_size, + samples_per_batch=self.samples_per_batch, + validation_batch_size=self.validation_batch_size, + ) def evaluate_model( @@ -193,3 +146,53 @@ def evaluate_model( target_state = batch_state[:, step + 1, :] total_loss += loss(state_snapshot, target_state) return total_loss / multistep + + +def _train_loop( + model: torch.nn.Module, + train_data: tf.data.Dataset, + 
validation_data: tf.data.Dataset, + evaluate_on_batch: Callable[[Any, torch.nn.Module], torch.Tensor], + optimizer: torch.optim.Optimizer, + n_epoch: int, + shuffle_buffer_size: int, + samples_per_batch: int, + validation_batch_size: Optional[int] = None, +): + + train_data = train_data.shuffle(buffer_size=shuffle_buffer_size).batch( + samples_per_batch + ) + train_data = tfds.as_numpy(train_data) + if validation_data is not None: + if validation_batch_size is None: + validation_batch_size = sequence_size(validation_data) + validation_data = validation_data.batch(validation_batch_size) + validation_data = tfds.as_numpy(validation_data) + min_val_loss = np.inf + best_weights = None + for i in range(1, n_epoch + 1): # loop over the dataset multiple times + logger.info("starting epoch %d", i) + train_model = model.train() + train_losses = [] + for batch_state in train_data: + optimizer.zero_grad() + loss = evaluate_on_batch(batch_state, train_model) + loss.backward() + train_losses.append(loss) + optimizer.step() + train_loss = torch.mean(torch.stack(train_losses)) + logger.info("train loss: %f", train_loss) + if validation_data is not None: + val_model = model.eval() + val_losses = [] + for batch_state in validation_data: + with torch.no_grad(): + val_losses.append(evaluate_on_batch(batch_state, val_model)) + val_loss = torch.mean(torch.stack(val_losses)) + logger.info("val_loss %f", val_loss) + if val_loss < min_val_loss: + min_val_loss = val_loss + best_weights = train_model.state_dict() + if validation_data is not None: + train_model.load_state_dict(best_weights) From 0aee33ea4a0c6d6a3e887929de2976b913d0d8eb Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 19:21:20 +0000 Subject: [PATCH 11/55] delete dead code, refactor to use ConvolutionFactoryFactory --- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 169 ++++++++---------- 1 file changed, 73 insertions(+), 96 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py 
b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index ed26fd910f..a620e90297 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -35,14 +35,61 @@ class ConvolutionFactoryFactory(Protocol): def __call__( self, kernel_size: int, - padding: int, + padding: int = 0, + output_padding: int = 0, stride: int = 1, stride_type: Literal["regular", "transpose"] = "regular", bias: bool = True, ) -> ConvolutionFactory: + """ + Create a factory for creating convolution layers. + + Args: + kernel_size: size of the convolution kernel + padding: padding to apply to the input + output_padding: argument used for transpose convolution + stride: stride of the convolution + stride_type: type of stride, one of "regular" or "transpose" + bias: whether to include a bias vector in the produced layers + """ ... +def regular_convolution( + kernel_size: int, + padding: int = 0, + output_padding: int = 0, + stride: int = 1, + stride_type: Literal["regular", "transpose"] = "regular", + bias: bool = True, +) -> ConvolutionFactory: + """ + Produces convolution factories for regular (image) data. 
+ + Args: + kernel_size: size of the convolution kernel + padding: padding to apply to the input + output_padding: argument used for transpose convolution + stride: stride of the convolution + stride_type: type of stride, one of "regular" or "transpose" + bias: whether to include a bias vector in the produced layers + """ + if stride == 1: + return flat_convolution(kernel_size=kernel_size, bias=bias) + elif stride_type == "regular": + return strided_convolution( + kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + elif stride_type == "transpose": + return transpose_convolution( + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + bias=bias, + ) + + @curry def strided_convolution( in_channels: int, @@ -143,7 +190,13 @@ def forward(self, inputs): class Discriminator(nn.Module): - def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): + def __init__( + self, + in_channels: int, + n_convolutions: int, + max_filters: int, + convolution: ConvolutionFactoryFactory = regular_convolution, + ): super(Discriminator, self).__init__() # max_filters = min_filters * 2 ** (n_convolutions - 1), therefore min_filters = int(max_filters / 2 ** (n_convolutions - 1)) @@ -151,9 +204,7 @@ def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): ConvBlock( in_channels=in_channels, out_channels=min_filters, - convolution_factory=strided_convolution( - kernel_size=3, stride=2, padding=1 - ), + convolution_factory=convolution(kernel_size=3, stride=2, padding=1), activation_factory=leakyrelu_activation(alpha=0.2), ) ] @@ -162,22 +213,20 @@ def __init__(self, in_channels: int, n_convolutions: int, max_filters: int): ConvBlock( in_channels=min_filters * 2 ** (i - 1), out_channels=min_filters * 2 ** i, - convolution_factory=strided_convolution( - kernel_size=3, stride=2, padding=1 - ), + convolution_factory=convolution(kernel_size=3, stride=2, padding=1), 
activation_factory=leakyrelu_activation(alpha=0.2), ) ) final_conv = ConvBlock( in_channels=max_filters, out_channels=max_filters, - convolution_factory=flat_convolution(kernel_size=3), + convolution_factory=convolution(kernel_size=3), activation_factory=leakyrelu_activation(alpha=0.2), ) patch_output = ConvBlock( in_channels=max_filters, out_channels=1, - convolution_factory=flat_convolution(kernel_size=3), + convolution_factory=convolution(kernel_size=3), activation_factory=leakyrelu_activation(alpha=0.2), ) self._sequential = nn.Sequential(*convs, final_conv, patch_output) @@ -186,73 +235,14 @@ def forward(self, inputs): return self._sequential(inputs) -class SequentialGenerator(nn.Module): - def __init__( - self, channels: int, n_convolutions: int, n_resnet: int, max_filters: int, - ): - super(SequentialGenerator, self).__init__() - min_filters = int(max_filters / 2 ** (n_convolutions - 1)) - convs = [ - ConvBlock( - in_channels=channels, - out_channels=min_filters, - convolution_factory=flat_convolution(kernel_size=7), - activation_factory=relu_activation(), - ) - ] - for i in range(1, n_convolutions): - convs.append( - ConvBlock( - in_channels=min_filters * 2 ** (i - 1), - out_channels=min_filters * 2 ** i, - convolution_factory=strided_convolution( - kernel_size=3, stride=2, padding=1 - ), - activation_factory=relu_activation(), - ) - ) - resnet_blocks = [ - ResnetBlock( - n_filters=max_filters, - convolution_factory=flat_convolution(kernel_size=3), - activation_factory=relu_activation(), - ) - for i in range(n_resnet) - ] - transpose_convs = [] - for i in range(1, n_convolutions): - transpose_convs.append( - ConvBlock( - in_channels=max_filters // (2 ** (i - 1)), - out_channels=max_filters // (2 ** i), - convolution_factory=transpose_convolution( - kernel_size=3, stride=2, padding=1, output_padding=1 - ), - activation_factory=relu_activation(), - ) - ) - out_conv = ConvBlock( - in_channels=min_filters, - out_channels=channels, - 
convolution_factory=flat_convolution(kernel_size=7), - activation_factory=no_activation, - ) - self._sequential = nn.Sequential( - *convs, *resnet_blocks, *transpose_convs, out_conv - ) - self._identity = nn.Identity() - - def forward(self, inputs: torch.Tensor): - # data will have channels last, model requires channels first - # return self._identity(inputs) - inputs = inputs.permute(0, 3, 1, 2) - outputs: torch.Tensor = self._sequential(inputs) - return outputs.permute(0, 2, 3, 1) - - class Generator(nn.Module): def __init__( - self, channels: int, n_convolutions: int, n_resnet: int, max_filters: int, + self, + channels: int, + n_convolutions: int, + n_resnet: int, + max_filters: int, + convolution: ConvolutionFactoryFactory = regular_convolution, ): super(Generator, self).__init__() @@ -260,7 +250,7 @@ def resnet(in_channels: int): resnet_blocks = [ ResnetBlock( n_filters=in_channels, - convolution_factory=flat_convolution(kernel_size=3), + convolution_factory=convolution(kernel_size=3), activation_factory=relu_activation(), ) for _ in range(n_resnet) @@ -271,9 +261,7 @@ def down(in_channels: int, out_channels: int): return ConvBlock( in_channels=in_channels, out_channels=out_channels, - convolution_factory=strided_convolution( - kernel_size=3, stride=2, padding=1 - ), + convolution_factory=convolution(kernel_size=3, stride=2, padding=1), activation_factory=relu_activation(), ) @@ -281,21 +269,18 @@ def up(in_channels: int, out_channels: int): return ConvBlock( in_channels=in_channels, out_channels=out_channels, - convolution_factory=transpose_convolution( - kernel_size=3, stride=2, padding=1, output_padding=1 + convolution_factory=convolution( + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + stride_type="transpose", ), activation_factory=relu_activation(), ) min_filters = int(max_filters / 2 ** (n_convolutions - 1)) - # self._first_conv = ConvBlock( - # in_channels=channels, - # out_channels=min_filters, - # 
convolution_factory=flat_convolution(kernel_size=3), - # activation_factory=relu_activation(), - # ) - self._first_conv = nn.Sequential( flat_convolution(kernel_size=3)( in_channels=channels, out_channels=min_filters @@ -311,20 +296,12 @@ def up(in_channels: int, out_channels: int): in_channels=min_filters, ) - # self._out_conv = ConvBlock( - # in_channels=2 *min_filters, - # out_channels=channels, - # convolution_factory=flat_convolution(kernel_size=3), - # activation_factory=no_activation, - # ) - self._out_conv = flat_convolution(kernel_size=3)( in_channels=2 * min_filters, out_channels=channels ) def forward(self, inputs): # data will have channels last, model requires channels first - # return self._identity(inputs) inputs = inputs.permute(0, 3, 1, 2) x = self._first_conv(inputs) x = self._unet(x) From a7f95631b5be01559ab2195304807e8165b0f7d2 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 19:29:30 +0000 Subject: [PATCH 12/55] fix failing test for IO --- external/fv3fit/fv3fit/_shared/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/fv3fit/fv3fit/_shared/io.py b/external/fv3fit/fv3fit/_shared/io.py index 0742af3a23..ffcd500ecd 100644 --- a/external/fv3fit/fv3fit/_shared/io.py +++ b/external/fv3fit/fv3fit/_shared/io.py @@ -21,9 +21,9 @@ def __init__(self) -> None: self._dump_types: MutableMapping[str, Type[Dumpable]] = {} def __call__(self, name: str) -> Callable[[Type[Dumpable]], Type[Dumpable]]: - if name in self._model_types: + if name in self._dump_types: raise ValueError( - f"{name} is already registered by {self._model_types[name]}." + f"{name} is already registered by {self._dump_types[name]}." 
) else: return partial(self._register_class, name=name) From 084a5196d79780f8dc6625964242ac321087ecac Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 20:07:10 +0000 Subject: [PATCH 13/55] add autoencoder to special training types --- external/fv3fit/tests/training/test_train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/external/fv3fit/tests/training/test_train.py b/external/fv3fit/tests/training/test_train.py index c4ba078972..65c84653c0 100644 --- a/external/fv3fit/tests/training/test_train.py +++ b/external/fv3fit/tests/training/test_train.py @@ -38,7 +38,12 @@ # cannot be used in generic tests below # you must write a separate file that specializes each of the tests # for models in this list -SPECIAL_TRAINING_TYPES = ["graph", "min_max_novelty_detector", "ocsvm_novelty_detector"] +SPECIAL_TRAINING_TYPES = [ + "graph", + "min_max_novelty_detector", + "ocsvm_novelty_detector", + "autoencoder", +] # automatically test on every registered training class From f478eebc07e7955727a0036dddc547049b32d2ec Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 22 Aug 2022 22:57:47 +0000 Subject: [PATCH 14/55] fix logic to reset register between tests --- external/fv3fit/tests/training/test_main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/external/fv3fit/tests/training/test_main.py b/external/fv3fit/tests/training/test_main.py index 9b8c9a6465..27fdfe3830 100644 --- a/external/fv3fit/tests/training/test_main.py +++ b/external/fv3fit/tests/training/test_main.py @@ -100,6 +100,7 @@ def mock_train_dense_model(): "dense", fv3fit.DenseHyperparameters )(original_func) register._model_types.pop("mock") + register._dump_types.pop("mock") @pytest.fixture From b3fb2563cac2226b7a08d92f67f1e2a11cbdc983 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 24 Aug 2022 00:10:08 +0000 Subject: [PATCH 15/55] wip cyclegan training code --- external/fv3fit/fv3fit/data/__init__.py | 3 +- 
external/fv3fit/fv3fit/data/synthetic.py | 126 ++++++ external/fv3fit/fv3fit/data/tfdataset.py | 1 - .../fv3fit/pytorch/cyclegan/__init__.py | 1 + .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 26 ++ .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 16 +- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 397 ++++++++++++++++++ .../fv3fit/tests/training/test_autoencoder.py | 76 +--- .../fv3fit/tests/training/test_cyclegan.py | 156 +++++++ 9 files changed, 731 insertions(+), 71 deletions(-) create mode 100644 external/fv3fit/fv3fit/data/synthetic.py create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py create mode 100644 external/fv3fit/tests/training/test_cyclegan.py diff --git a/external/fv3fit/fv3fit/data/__init__.py b/external/fv3fit/fv3fit/data/__init__.py index c9d742838e..0d147ca4c6 100644 --- a/external/fv3fit/fv3fit/data/__init__.py +++ b/external/fv3fit/fv3fit/data/__init__.py @@ -1,3 +1,4 @@ from .base import TFDatasetLoader, tfdataset_loader_from_dict, register_tfdataset_loader from .batches import FromBatches -from .tfdataset import WindowedZarrLoader, VariableConfig +from .tfdataset import WindowedZarrLoader, VariableConfig, CycleGANLoader +from .synthetic import SyntheticWaves diff --git a/external/fv3fit/fv3fit/data/synthetic.py b/external/fv3fit/fv3fit/data/synthetic.py new file mode 100644 index 0000000000..d30a4c8f59 --- /dev/null +++ b/external/fv3fit/fv3fit/data/synthetic.py @@ -0,0 +1,126 @@ +from .base import TFDatasetLoader, register_tfdataset_loader +import dataclasses +from typing import Optional, Sequence, List +import tensorflow as tf +import numpy as np +from ..tfdataset import iterable_to_tfdataset +import dacite + + +@register_tfdataset_loader +@dataclasses.dataclass +class SyntheticWaves(TFDatasetLoader): + + nsamples: int + nbatch: int + ntime: int + nx: int + nz: int + scalar_names: List[str] = dataclasses.field(default_factory=list) + scale_min: float = 0.0 + scale_max: float = 1.0 + period_min: float = 8.0 + 
period_max: float = 16.0 + + def open_tfdataset( + self, local_download_path: Optional[str], variable_names: Sequence[str], + ) -> tf.data.Dataset: + """ + Args: + local_download_path: if provided, cache data locally at this path + variable_names: names of variables to include when loading data + Returns: + dataset containing requested variables, each record is a mapping from + variable name to variable value, and each value is a tensor whose + first dimension is the batch dimension + """ + dataset = get_tfdataset( + variable_names, + scalar_names=self.scalar_names, + nsamples=self.nsamples, + nbatch=self.nbatch, + ntime=self.ntime, + nx=self.nx, + ny=self.nx, + nz=self.nz, + scale_min=self.scale_min, + scale_max=self.scale_max, + period_min=self.period_min, + period_max=self.period_max, + ) + if local_download_path is not None: + dataset = dataset.cache(local_download_path) + return dataset + + @classmethod + def from_dict(cls, d: dict) -> "TFDatasetLoader": + return dacite.from_dict( + data_class=cls, data=d, config=dacite.Config(strict=True) + ) + + +def get_tfdataset( + variable_names, + *, + scalar_names, + nsamples: int, + nbatch: int, + ntime: int, + nx: int, + ny: int, + nz: int, + scale_min: float, + scale_max: float, + period_min: float, + period_max: float +): + ntile = 6 + + grid_x = np.arange(0, nx, dtype=np.float32) + grid_y = np.arange(0, ny, dtype=np.float32) + grid_x, grid_y = np.broadcast_arrays(grid_x[:, None], grid_y[None, :]) + grid_x = grid_x[None, None, None, :, :, None] + grid_y = grid_y[None, None, None, :, :, None] + + def sample_iterator(): + # creates a timeseries where each time is the negation of time before it + for _ in range(nsamples): + ax = np.random.uniform(scale_min, scale_max, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + bx = np.random.uniform(period_min, period_max, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + cx = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, 
None, : + ] + ay = np.random.uniform(scale_min, scale_max, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + by = np.random.uniform(period_min, period_max, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + cy = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ + :, :, :, None, None, : + ] + data = ( + ax + * np.sin(2 * np.pi * grid_x / bx + cx) + * ay + * np.sin(2 * np.pi * grid_y / by + cy) + ) + start = {} + for varname in variable_names: + if varname in scalar_names: + start[varname] = data[..., 0].astype(np.float32) + else: + start[varname] = data.astype(np.float32) + out = {key: [value] for key, value in start.items()} + for _ in range(ntime - 1): + for varname in start.keys(): + out[varname].append(out[varname][-1] * -1.0) + for varname in out: + out[varname] = np.concatenate(out[varname], axis=1) + yield out + + return iterable_to_tfdataset(list(sample_iterator())) diff --git a/external/fv3fit/fv3fit/data/tfdataset.py b/external/fv3fit/fv3fit/data/tfdataset.py index dc9715520d..e95f98d414 100644 --- a/external/fv3fit/fv3fit/data/tfdataset.py +++ b/external/fv3fit/fv3fit/data/tfdataset.py @@ -67,7 +67,6 @@ def get_n_windows(n_times: int, window_size: int) -> int: class CycleGANLoader(TFDatasetLoader): domain_configs: List[TFDatasetLoader] = dataclasses.field(default_factory=list) - batch_size: int = 1 def open_tfdataset( self, local_download_path: Optional[str], variable_names: Sequence[str], diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py index 4f2f47a491..88f8be5d0e 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py @@ -1 +1,2 @@ from .train import train_autoencoder, AutoencoderHyperparameters, GeneratorConfig +from .train_cyclegan import train_cyclegan, CycleGANHyperparameters diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py 
b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index a620e90297..010117ff05 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn from toolz import curry +import dataclasses def relu_activation(**kwargs): @@ -141,6 +142,31 @@ def flat_convolution(in_channels: int, out_channels: int, kernel_size: int, bias ) +@dataclasses.dataclass +class GeneratorConfig: + n_convolutions: int = 3 + n_resnet: int = 3 + max_filters: int = 256 + + def instance( + self, + channels: int, + convolution: ConvolutionFactoryFactory = regular_convolution, + ): + return Generator( + channels=channels, + n_convolutions=self.n_convolutions, + n_resnet=self.n_resnet, + max_filters=self.max_filters, + convolution=convolution, + ) + + +@dataclasses.dataclass +class DiscriminatorConfig: + pass + + class ResnetBlock(nn.Module): def __init__( self, diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index 0dcf56e153..882d0943ce 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -13,7 +13,7 @@ Tuple, ) from fv3fit.tfdataset import ensure_nd, apply_to_mapping -from .network import Generator +from .network import Generator, GeneratorConfig from fv3fit.pytorch.graph.train import ( get_scalers, get_mapping_scale_func, @@ -25,13 +25,6 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass -class GeneratorConfig: - n_convolutions: int = 3 - n_resnet: int = 3 - max_filters: int = 256 - - @dataclasses.dataclass class AutoencoderHyperparameters(Hyperparameters): @@ -137,9 +130,4 @@ def train_autoencoder( def build_model(config: GeneratorConfig, n_state: int) -> Generator: - return Generator( - channels=n_state, - n_convolutions=config.n_convolutions, - n_resnet=config.n_resnet, - max_filters=config.max_filters, - ) + return 
config.instance(channels=n_state) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py new file mode 100644 index 0000000000..fd267aeedf --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -0,0 +1,397 @@ +from fv3fit._shared.hyperparameters import Hyperparameters +import random +import dataclasses +import tensorflow as tf +from fv3fit.pytorch.predict import PytorchPredictor +from fv3fit.pytorch.loss import LossConfig +from fv3fit.pytorch.optimizer import OptimizerConfig +import torch +from fv3fit.pytorch.system import DEVICE +import tensorflow_datasets as tfds +from fv3fit.tfdataset import sequence_size + +from fv3fit._shared import register_training_function +from typing import ( + Dict, + List, + Optional, +) +from fv3fit.tfdataset import ensure_nd, apply_to_mapping +from .network import Discriminator, Generator, GeneratorConfig, DiscriminatorConfig +from fv3fit.pytorch.graph.train import ( + get_scalers, + get_mapping_scale_func, + get_Xy_dataset, +) +from toolz import curry +import logging +import numpy as np + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class CycleGANHyperparameters(Hyperparameters): + + state_variables: List[str] + normalization_fit_samples: int = 50_000 + optimizer_config: OptimizerConfig = dataclasses.field( + default_factory=lambda: OptimizerConfig("AdamW") + ) + generator: GeneratorConfig = dataclasses.field( + default_factory=lambda: GeneratorConfig() + ) + discriminator: DiscriminatorConfig = dataclasses.field( + default_factory=lambda: DiscriminatorConfig() + ) + training_loop: "CycleGANTrainingConfig" = dataclasses.field( + default_factory=lambda: CycleGANTrainingConfig() + ) + loss: LossConfig = LossConfig(loss_type="mse") + + @property + def variables(self): + return tuple(self.state_variables) + + +def flatten_dims(dataset: tf.data.Dataset) -> tf.data.Dataset: + """Transform [batch, time, tile, x, y, z] 
to [sample, x, y, z]""" + return dataset.unbatch().unbatch().unbatch() + + +class CycleGANTrainingConfig: + + n_epoch: int = 20 + shuffle_buffer_size: int = 10 + samples_per_batch: int = 1 + validation_batch_size: Optional[int] = None + + def fit_loop( + self, + train_model: "CycleGAN", + train_data: tf.data.Dataset, + validation_data: Optional[tf.data.Dataset], + ) -> None: + """ + Args: + train_model: cycle-GAN to train + train_data: training dataset containing samples to be passed to the model, + should have dimensions [sample, time, tile, x, y, z] + validation_data: validation dataset containing samples to be passed + to the model, should have dimensions [sample, time, tile, x, y, z] + """ + + train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size).batch( + self.samples_per_batch + ) + train_data = tfds.as_numpy(train_data) + if validation_data is not None: + if self.validation_batch_size is None: + validation_batch_size = sequence_size(validation_data) + else: + validation_batch_size = self.validation_batch_size + validation_data = validation_data.batch(validation_batch_size) + validation_data = tfds.as_numpy(validation_data) + for i in range(1, self.n_epoch + 1): + logger.info("starting epoch %d", i) + train_losses = [] + for batch_state in train_data: + train_losses.append(train_model.train_on_batch(*batch_state)) + train_loss = torch.mean(torch.stack(train_losses)) + logger.info("train_loss: %f", train_loss) + if validation_data is not None: + val_loss = train_model.evaluate_on_dataset(validation_data) + logger.info("val_loss %f", val_loss) + + +@register_training_function("cyclegan", CycleGANHyperparameters) +def train_cyclegan( + hyperparameters: CycleGANHyperparameters, + train_batches: tf.data.Dataset, + validation_batches: Optional[tf.data.Dataset], +) -> PytorchPredictor: + """ + Train a denoising autoencoder for cubed sphere data. 
+ + Args: + hyperparameters: configuration for training + train_batches: training data, as a dataset of Mapping[str, tf.Tensor] + where each tensor has dimensions [sample, time, tile, x, y(, z)] + validation_batches: validation data, as a dataset of Mapping[str, tf.Tensor] + where each tensor has dimensions [sample, time, tile, x, y(, z)] + """ + train_batches = train_batches.map(apply_to_mapping(ensure_nd(6))) + sample_batch = next( + iter(train_batches.unbatch().batch(hyperparameters.normalization_fit_samples)) + ) + + scalers = get_scalers(sample_batch) + mapping_scale_func = get_mapping_scale_func(scalers) + + get_state = curry(get_Xy_dataset)( + state_variables=hyperparameters.state_variables, + n_dims=6, # [batch, time, tile, x, y, z] + mapping_scale_func=mapping_scale_func, + ) + + if validation_batches is not None: + val_state = get_state(data=validation_batches) + else: + val_state = None + + train_state = get_state(data=train_batches) + + train_model = build_model( + hyperparameters.generator, n_state=next(iter(train_state)).shape[-1] + ) + + train_state = flatten_dims(train_state) + if validation_batches is not None: + val_state = flatten_dims(val_state) + + hyperparameters.training_loop.fit_loop( + train_model=train_model, train_data=train_state, validation_data=val_state, + ) + + predictor = PytorchPredictor( + input_variables=hyperparameters.state_variables, + output_variables=hyperparameters.state_variables, + model=train_model, + scalers=scalers, + ) + return predictor + + +class ReplayBuffer: + + # To reduce model oscillation during training, we update the discriminator + # using a history of generated data instead of the most recently generated data + # according to Shrivastava et al. (2017). 
+ + def __init__(self, max_size=50): + if max_size <= 0: + raise ValueError("max_size must be positive") + self.max_size = max_size + self.data = [] + + def push_and_pop(self, data: torch.Tensor) -> torch.autograd.Variable: + to_return = [] + for element in data.data: + element = torch.unsqueeze(element, 0) + if len(self.data) < self.max_size: + self.data.append(element) + to_return.append(element) + else: + if random.uniform(0, 1) > 0.5: + i = random.randint(0, self.max_size - 1) + to_return.append(self.data[i].clone()) + self.data[i] = element + else: + to_return.append(element) + return torch.autograd.Variable(torch.cat(to_return)) + + +class StatsCollector: + def __init__(self, n_dims_keep: int): + self.n_dims_keep = n_dims_keep + self._sum = np.asarray(0.0, dtype=np.float64) + self._sum_squared = np.asarray(0.0, dtype=np.float64) + self._count = 0 + + def observe(self, data: np.ndarray): + mean_dims = tuple(range(0, len(data.shape) - self.n_dims_keep)) + data = data.astype(np.float64) + self._sum += data.mean(dims=mean_dims) + self._sum_squared += (data ** 2).mean(dims=mean_dims) + self._count += 1 + + @property + def mean(self) -> np.ndarray: + return self._sum / self._count + + @property + def std(self) -> np.ndarray: + return np.sqrt(self._sum_squared / self._count - self.mean() ** 2) + + +def get_r2(predicted, target) -> float: + """ + Compute the R^2 statistic for the predicted and target data. 
+ """ + return ( + 1.0 + - ((target - predicted) ** 2).mean() / ((target - target.mean()) ** 2).mean() + ) + + +class CycleGAN: + + # This class based loosely on + # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py + + # Copyright Facebook, BSD license + # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/c99ce7c4e781712e0252c6127ad1a4e8021cc489/LICENSE + + generator_a_to_b: Generator + generator_b_to_a: Generator + discriminator_a: Discriminator + discriminator_b: Discriminator + optimizer_generator: torch.optim.Optimizer + optimizer_discriminator: torch.optim.Optimizer + identity_loss: torch.nn.Module + cycle_loss: torch.nn.Module + gan_loss: torch.nn.Module + batch_size: int + identity_weight: float = 0.5 + cycle_weight: float = 1.0 + gan_weight: float = 1.0 + + def __post_init__(self): + self.target_real = torch.autograd.Variable( + torch.Tensor(self.batch_size).fill_(1.0), requires_grad=False + ) + self.target_fake = torch.autograd.Variable( + torch.Tensor(self.batch_size).fill_(0.0), requires_grad=False + ) + self.fake_a_buffer = ReplayBuffer() + self.fake_b_buffer = ReplayBuffer() + + def evaluate_on_dataset( + self, dataset: tf.data.Dataset, n_dims_keep: int = 4 + ) -> Dict[str, float]: + stats_real_a = StatsCollector(n_dims_keep) + stats_real_b = StatsCollector(n_dims_keep) + stats_gen_a = StatsCollector(n_dims_keep) + stats_gen_b = StatsCollector(n_dims_keep) + real_a: np.ndarray + real_b: np.ndarray + for real_a, real_b in dataset: + stats_real_a.observe(real_a) + stats_real_b.observe(real_b) + gen_a: torch.Tensor = self.generator_a_to_b( + torch.as_tensor(real_a).float().to(DEVICE) + ) + gen_b: torch.Tensor = self.generator_b_to_a( + torch.as_tensor(real_b).float().to(DEVICE) + ) + stats_gen_a.observe(gen_a.detach().cpu().numpy()) + stats_gen_b.observe(gen_b.detach().cpu().numpy()) + metrics = { + "r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean), + "bias_mean_a": np.mean(stats_real_a.mean - 
stats_gen_a.mean), + "r2_mean_b": get_r2(stats_real_b.mean, stats_gen_b.mean), + "bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), + "r2_std_a": get_r2(stats_real_a.std, stats_gen_a.std), + "bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), + "r2_std_b": get_r2(stats_real_b.std, stats_gen_b.std), + "bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), + } + return metrics + + def train_on_batch(self, real_a: torch.Tensor, real_b: torch.Tensor) -> float: + fake_b = self.generator_a_to_b(real_a) + fake_a = self.generator_b_to_a(real_b) + reconstructed_a = self.generator_b_to_a(fake_b) + reconstructed_b = self.generator_a_to_b(fake_a) + + # Generators A2B and B2A ###### + + # don't update discriminators when training generators to fool them + set_requires_grad( + [self.discriminator_a, self.discriminator_b], requires_grad=False + ) + + self.optimizer_generator.zero_grad() + + # Identity loss + # G_A2B(B) should equal B if real B is fed + same_b = self.generator_a_to_b(real_b) + loss_identity_b = self.identity_loss(same_b, real_b) * self.identity_weight + # G_B2A(A) should equal A if real A is fed + same_a = self.generator_b_to_a(real_b) + loss_identity_a = self.identity_loss(same_a, real_a) * self.identity_weight + + # GAN loss + fake_b = self.generator_a_to_b(real_a) + pred_fake = self.discriminator_b(fake_b) + loss_gan_a_to_b = self.gan_loss(pred_fake, self.target_real) + + fake_A = self.generator_b_to_a(real_b) + pred_fake = self.discriminator_a(fake_A) + loss_gan_b_to_a = self.gan_loss(pred_fake, self.target_real) + + # Cycle loss + loss_cycle_a_b_a = self.cycle_loss(reconstructed_a, real_a) * self.cycle_weight + loss_cycle_b_a_b = self.cycle_loss(reconstructed_b, real_b) * self.cycle_weight + + # Total loss + loss_g: torch.Tensor = ( + loss_identity_a + + loss_identity_b + + loss_gan_a_to_b + + loss_gan_b_to_a + + loss_cycle_a_b_a + + loss_cycle_b_a_b + ) + loss_g.backward() + + self.optimizer_generator.step() + + # Discriminators A and B 
###### + + # do update discriminators when training them to identify samples + set_requires_grad( + [self.discriminator_a, self.discriminator_b], requires_grad=True + ) + + self.optimizer_discriminator.zero_grad() + + # Real loss + pred_real = self.discriminator_a(real_a) + loss_d_a_real = self.gan_loss(pred_real, self.target_real) + + # Fake loss + fake_a = self.fake_a_buffer.push_and_pop(fake_a) + pred_a_fake = self.discriminator_a(fake_a.detach()) + loss_d_a_fake = self.gan_loss(pred_a_fake, self.target_fake) + + # Real loss + pred_real = self.discriminator_b(real_b) + loss_d_b_real = self.gan_loss(pred_real, self.target_real) + + # Fake loss + fake_b = self.fake_b_buffer.push_and_pop(fake_b) + pred_b_fake = self.discriminator_b(fake_b.detach()) + loss_d_b_fake = self.gan_loss(pred_b_fake, self.target_fake) + + # Total loss + loss_d: torch.Tensor = ( + loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake + ) * 0.5 + loss_d.backward() + + self.optimizer_discriminator.step() + return float(loss_g + loss_d) + + +def build_model(config: GeneratorConfig, n_state: int) -> CycleGAN: + return Generator( + channels=n_state, + n_convolutions=config.n_convolutions, + n_resnet=config.n_resnet, + max_filters=config.max_filters, + ) + + +def set_requires_grad(nets: List[torch.nn.Module], requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index cc58b2cfa7..f1c74306d7 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -3,62 +3,28 @@ from 
typing import Sequence from fv3fit.pytorch.cyclegan import AutoencoderHyperparameters, train_autoencoder from fv3fit.pytorch.cyclegan.train import TrainingConfig -from fv3fit.tfdataset import iterable_to_tfdataset +from fv3fit.data.synthetic import SyntheticWaves import collections import os import fv3fit.pytorch import fv3fit -def get_tfdataset(nsamples, nbatch, ntime, nx, ny, nz): - ntile = 6 - - grid_x = np.arange(0, nx, dtype=np.float32) - grid_y = np.arange(0, ny, dtype=np.float32) - grid_x, grid_y = np.broadcast_arrays(grid_x[:, None], grid_y[None, :]) - grid_x = grid_x[None, None, None, :, :, None] - grid_y = grid_y[None, None, None, :, :, None] - - def sample_iterator(): - # creates a timeseries where each time is the negation of time before it - for _ in range(nsamples): - ax = np.random.uniform(0.1, 1.5, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] - bx = np.random.uniform(8, 16, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] - cx = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] - ay = np.random.uniform(0.1, 1.5, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] - by = np.random.uniform(8, 16, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] - cy = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] - a = ( - ax - * np.sin(2 * np.pi * grid_x / bx + cx) - * ay - * np.sin(2 * np.pi * grid_y / by + cy) - ) - start = { - "a": a.astype(np.float32), - "b": -a[..., 0].astype(np.float32), - } - out = {key: [value] for key, value in start.items()} - for _ in range(ntime - 1): - for varname in start.keys(): - out[varname].append(out[varname][-1] * -1.0) - for varname in out: - out[varname] = np.concatenate(out[varname], axis=1) - yield out - - return iterable_to_tfdataset(list(sample_iterator())) +def get_tfdataset(nsamples, nbatch, ntime, nx, nz): + config = SyntheticWaves( + nsamples=nsamples, + nbatch=nbatch, + ntime=ntime, + nx=nx, 
+ nz=nz, + scalar_names=["b"], + scale_min=0.5, + scale_max=1.5, + period_min=8, + period_max=16, + ) + dataset = config.open_tfdataset(local_download_path=None, variable_names=["a", "b"]) + return dataset def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): @@ -89,8 +55,8 @@ def test_autoencoder(tmpdir): os.chdir(tmpdir) # need a larger nx, ny for the sample data here since we're training # on whether we can autoencode sin waves, and need to resolve full cycles - nx, ny = 32, 32 - sizes = {"nbatch": 2, "ntime": 2, "nx": nx, "ny": ny, "nz": 2} + nx = 32 + sizes = {"nbatch": 2, "ntime": 2, "nx": nx, "nz": 2} state_variables = ["a", "b"] train_tfdataset = get_tfdataset(nsamples=20, **sizes) val_tfdataset = get_tfdataset(nsamples=3, **sizes) @@ -105,7 +71,7 @@ def test_autoencoder(tmpdir): ) predictor = train_autoencoder(hyperparameters, train_tfdataset, val_tfdataset) # for test, need one continuous series so we consistently flip sign - test_sizes = {"nbatch": 1, "ntime": 100, "nx": nx, "ny": ny, "nz": 2} + test_sizes = {"nbatch": 1, "ntime": 100, "nx": nx, "nz": 2} test_xrdataset = tfdataset_to_xr_dataset( get_tfdataset(nsamples=1, **test_sizes), dims=["time", "tile", "x", "y", "z"] ) @@ -134,8 +100,8 @@ def test_autoencoder_overfit(tmpdir): os.chdir(tmpdir) # need a larger nx, ny for the sample data here since we're training # on whether we can autoencode sin waves, and need to resolve full cycles - nx, ny = 32, 32 - sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "ny": ny, "nz": 2} + nx = 32 + sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} state_variables = ["a", "b"] train_tfdataset = get_tfdataset(nsamples=1, **sizes) train_tfdataset = train_tfdataset.cache() # needed to keep sample identical diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py new file mode 100644 index 0000000000..823f3b6495 --- /dev/null +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -0,0 +1,156 @@ +import numpy as 
np +import xarray as xr +from typing import Sequence +from fv3fit.pytorch.cyclegan import CycleGANHyperparameters, train_cyclegan +from fv3fit.pytorch.cyclegan.train import TrainingConfig +from fv3fit.data import CycleGANLoader, SyntheticWaves +import collections +import os +import fv3fit.pytorch +import fv3fit + + +def get_tfdataset(nsamples, nbatch, ntime, nx, nz): + config = CycleGANLoader( + domain_configs=[ + SyntheticWaves( + nsamples=nsamples, + nbatch=nbatch, + ntime=ntime, + nx=nx, + nz=nz, + scalar_names=["b"], + scale_min=0.1, + scale_max=1.0, + period_min=4, + period_max=7, + ), + SyntheticWaves( + nsamples=nsamples, + nbatch=nbatch, + ntime=ntime, + nx=nx, + nz=nz, + scalar_names=["b"], + scale_min=0.5, + scale_max=1.5, + period_min=8, + period_max=16, + ), + ] + ) + dataset = config.open_tfdataset(local_download_path=None, variable_names=["a", "b"]) + return dataset + + +def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): + """ + Returns a [time, tile, x, y, z] dataset needed for evaluation. + + Assumes input samples have shape [sample, time, tile, x, y(, z)], will + concatenate samples along the time axis before returning. 
+ """ + data_sequences = collections.defaultdict(list) + for sample in tfdataset: + for name, value in sample.items(): + data_sequences[name].append( + value.numpy().reshape( + [value.shape[0] * value.shape[1]] + list(value.shape[2:]) + ) + ) + data_vars = {} + for name in data_sequences: + data = np.concatenate(data_sequences[name]) + data_vars[name] = xr.DataArray(data, dims=dims[: len(data.shape)]) + return xr.Dataset(data_vars) + + +def test_cyclegan(tmpdir): + fv3fit.set_random_seed(0) + # run the test in a temporary directory to delete artifacts when done + os.chdir(tmpdir) + # need a larger nx, ny for the sample data here since we're training + # on whether we can autoencode sin waves, and need to resolve full cycles + nx, ny = 32, 32 + sizes = {"nbatch": 2, "ntime": 2, "nx": nx, "ny": ny, "nz": 2} + state_variables = ["a", "b"] + train_tfdataset = get_tfdataset(nsamples=20, **sizes) + val_tfdataset = get_tfdataset(nsamples=3, **sizes) + hyperparameters = CycleGANHyperparameters( + state_variables=state_variables, + generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, n_resnet=3, max_filters=32 + ), + training_loop=TrainingConfig(n_epoch=5, samples_per_batch=2), + optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), + noise_amount=0.5, + ) + predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) + # for test, need one continuous series so we consistently flip sign + test_sizes = {"nbatch": 1, "ntime": 100, "nx": nx, "ny": ny, "nz": 2} + test_xrdataset = tfdataset_to_xr_dataset( + get_tfdataset(nsamples=1, **test_sizes), dims=["time", "tile", "x", "y", "z"] + ) + predicted = predictor.predict(test_xrdataset) + reference = test_xrdataset + # plotting code to uncomment if you'd like to manually check the results: + # for i in range(6): + # fig, ax = plt.subplots(1, 2) + # vmin = reference["a"][0, i, :, :, 0].values.min() + # vmax = reference["a"][0, i, :, :, 0].values.max() + # ax[0].imshow(reference["a"][0, i, :, :, 
0].values, vmin=vmin, vmax=vmax) + # ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + # plt.tight_layout() + # plt.show() + bias = predicted.isel(time=1) - reference.isel(time=1) + mean_bias: xr.Dataset = bias.mean() + mse: xr.Dataset = (bias ** 2).mean() ** 0.5 + for varname in state_variables: + assert np.abs(mean_bias[varname]) < 0.1 + assert mse[varname] < 0.1 + + +def test_cyclegan_overfit(tmpdir): + fv3fit.set_random_seed(0) + # run the test in a temporary directory to delete artifacts when done + os.chdir(tmpdir) + # need a larger nx, ny for the sample data here since we're training + # on whether we can autoencode sin waves, and need to resolve full cycles + nx, ny = 32, 32 + sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "ny": ny, "nz": 2} + state_variables = ["a", "b"] + train_tfdataset = get_tfdataset(nsamples=1, **sizes) + train_tfdataset = train_tfdataset.cache() # needed to keep sample identical + hyperparameters = CycleGANHyperparameters( + state_variables=state_variables, + generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, n_resnet=1, max_filters=32 + ), + training_loop=TrainingConfig(n_epoch=100, samples_per_batch=6), + optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), + noise_amount=0.0, + ) + predictor = train_cyclegan( + hyperparameters, train_tfdataset, validation_batches=None + ) + # for test, need one continuous series so we consistently flip sign + test_xrdataset = tfdataset_to_xr_dataset( + train_tfdataset, dims=["time", "tile", "x", "y", "z"] + ) + predicted = predictor.predict(test_xrdataset) + reference = test_xrdataset + # plotting code to uncomment if you'd like to manually check the results: + # for i in range(6): + # fig, ax = plt.subplots(1, 2) + # vmin = reference["a"][0, i, :, :, 0].values.min() + # vmax = reference["a"][0, i, :, :, 0].values.max() + # ax[0].imshow(reference["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) + # ax[1].imshow(predicted["a"][0, i, :, :, 0].values) # 
, vmin=vmin, vmax=vmax) + # plt.tight_layout() + # plt.show() + bias = predicted - reference + mean_bias: xr.Dataset = bias.mean() + rmse: xr.Dataset = (bias ** 2).mean() + for varname in state_variables: + assert np.abs(mean_bias[varname]) < 0.1 + assert rmse[varname] < 0.1 From c84d0d94d5737956d64ec5cbf568540d703db1c5 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Thu, 25 Aug 2022 21:33:59 +0000 Subject: [PATCH 16/55] WIP, cyclegan is training but not converging --- external/fv3fit/fv3fit/data/synthetic.py | 33 +++- .../fv3fit/pytorch/cyclegan/__init__.py | 7 +- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 31 +++- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 2 +- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 169 +++++++++++++----- external/fv3fit/fv3fit/pytorch/graph/train.py | 40 ++++- external/fv3fit/fv3fit/pytorch/loss.py | 2 +- external/fv3fit/fv3fit/pytorch/predict.py | 6 + .../fv3fit/fv3fit/pytorch/training_loop.py | 4 +- external/fv3fit/fv3fit/tfdataset.py | 12 +- .../fv3fit/tests/training/test_cyclegan.py | 128 +++++++------ 11 files changed, 301 insertions(+), 133 deletions(-) diff --git a/external/fv3fit/fv3fit/data/synthetic.py b/external/fv3fit/fv3fit/data/synthetic.py index d30a4c8f59..597000f000 100644 --- a/external/fv3fit/fv3fit/data/synthetic.py +++ b/external/fv3fit/fv3fit/data/synthetic.py @@ -10,6 +10,22 @@ @register_tfdataset_loader @dataclasses.dataclass class SyntheticWaves(TFDatasetLoader): + """ + Attributes: + nsamples: number of samples to generate per batch + nbatch: number of batches to generate + nx: length of x- and y-dimensions to generate + nz: length of z-dimension to generate + scalar_names: names to generate as scalars instead of + vertically-resolved variables + scale_min: minimum amplitude of waves + scale_max: maximum amplitude of waves + period_min: minimum period of waves + period_max: maximum period of waves + phase_range: fraction of 2*pi to use for possible range of + random phase, should be a value between 
0 and 1. + + """ nsamples: int nbatch: int @@ -21,6 +37,7 @@ class SyntheticWaves(TFDatasetLoader): scale_max: float = 1.0 period_min: float = 8.0 period_max: float = 16.0 + phase_range: float = 1.0 def open_tfdataset( self, local_download_path: Optional[str], variable_names: Sequence[str], @@ -47,6 +64,7 @@ def open_tfdataset( scale_max=self.scale_max, period_min=self.period_min, period_max=self.period_max, + phase_range=self.phase_range, ) if local_download_path is not None: dataset = dataset.cache(local_download_path) @@ -72,7 +90,8 @@ def get_tfdataset( scale_min: float, scale_max: float, period_min: float, - period_max: float + period_max: float, + phase_range: float, ): ntile = 6 @@ -91,18 +110,18 @@ def sample_iterator(): bx = np.random.uniform(period_min, period_max, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] - cx = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] + cx = np.random.uniform( + 0.0, 2 * np.pi * phase_range, size=(nbatch, 1, ntile, nz) + )[:, :, :, None, None, :] ay = np.random.uniform(scale_min, scale_max, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] by = np.random.uniform(period_min, period_max, size=(nbatch, 1, ntile, nz))[ :, :, :, None, None, : ] - cy = np.random.uniform(0.0, 2 * np.pi, size=(nbatch, 1, ntile, nz))[ - :, :, :, None, None, : - ] + cy = np.random.uniform( + 0.0, 2 * np.pi * phase_range, size=(nbatch, 1, ntile, nz) + )[:, :, :, None, None, :] data = ( ax * np.sin(2 * np.pi * grid_x / bx + cx) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py index 88f8be5d0e..ea0eb6db5c 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py @@ -1,2 +1,7 @@ from .train import train_autoencoder, AutoencoderHyperparameters, GeneratorConfig -from .train_cyclegan import train_cyclegan, CycleGANHyperparameters +from .train_cyclegan import ( + 
train_cyclegan, + CycleGANHyperparameters, + CycleGANNetworkConfig, + CycleGANTrainingConfig, +) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index 010117ff05..d40927a1bf 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -3,6 +3,7 @@ import torch.nn as nn from toolz import curry import dataclasses +from fv3fit.pytorch.optimizer import OptimizerConfig def relu_activation(**kwargs): @@ -148,7 +149,7 @@ class GeneratorConfig: n_resnet: int = 3 max_filters: int = 256 - def instance( + def build( self, channels: int, convolution: ConvolutionFactoryFactory = regular_convolution, @@ -164,7 +165,21 @@ def instance( @dataclasses.dataclass class DiscriminatorConfig: - pass + + n_convolutions: int = 3 + max_filters: int = 256 + + def build( + self, + channels: int, + convolution: ConvolutionFactoryFactory = regular_convolution, + ): + return Discriminator( + in_channels=channels, + n_convolutions=self.n_convolutions, + max_filters=self.max_filters, + convolution=convolution, + ) class ResnetBlock(nn.Module): @@ -231,7 +246,7 @@ def __init__( in_channels=in_channels, out_channels=min_filters, convolution_factory=convolution(kernel_size=3, stride=2, padding=1), - activation_factory=leakyrelu_activation(alpha=0.2), + activation_factory=leakyrelu_activation(negative_slope=0.2), ) ] for i in range(1, n_convolutions): @@ -240,25 +255,27 @@ def __init__( in_channels=min_filters * 2 ** (i - 1), out_channels=min_filters * 2 ** i, convolution_factory=convolution(kernel_size=3, stride=2, padding=1), - activation_factory=leakyrelu_activation(alpha=0.2), + activation_factory=leakyrelu_activation(negative_slope=0.2), ) ) final_conv = ConvBlock( in_channels=max_filters, out_channels=max_filters, convolution_factory=convolution(kernel_size=3), - activation_factory=leakyrelu_activation(alpha=0.2), + 
activation_factory=leakyrelu_activation(negative_slope=0.2), ) patch_output = ConvBlock( in_channels=max_filters, out_channels=1, convolution_factory=convolution(kernel_size=3), - activation_factory=leakyrelu_activation(alpha=0.2), + activation_factory=leakyrelu_activation(negative_slope=0.2), ) self._sequential = nn.Sequential(*convs, final_conv, patch_output) def forward(self, inputs): - return self._sequential(inputs) + inputs = inputs.permute(0, 3, 1, 2) + outputs = self._sequential(inputs) + return outputs.permute(0, 2, 3, 1) class Generator(nn.Module): diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index 882d0943ce..f660ab0d03 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -130,4 +130,4 @@ def train_autoencoder( def build_model(config: GeneratorConfig, n_state: int) -> Generator: - return config.instance(channels=n_state) + return config.build(channels=n_state) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index fd267aeedf..b736cb08d9 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -1,3 +1,4 @@ +import itertools from fv3fit._shared.hyperparameters import Hyperparameters import random import dataclasses @@ -12,18 +13,21 @@ from fv3fit._shared import register_training_function from typing import ( + Callable, Dict, List, + Mapping, Optional, + Sequence, + Tuple, ) -from fv3fit.tfdataset import ensure_nd, apply_to_mapping +from fv3fit.tfdataset import ensure_nd, apply_to_mapping, apply_to_tuple from .network import Discriminator, Generator, GeneratorConfig, DiscriminatorConfig from fv3fit.pytorch.graph.train import ( get_scalers, get_mapping_scale_func, - get_Xy_dataset, + get_Xy_map_fn as get_Xy_map_fn_single_domain, ) -from toolz import curry import 
logging import numpy as np @@ -35,14 +39,8 @@ class CycleGANHyperparameters(Hyperparameters): state_variables: List[str] normalization_fit_samples: int = 50_000 - optimizer_config: OptimizerConfig = dataclasses.field( - default_factory=lambda: OptimizerConfig("AdamW") - ) - generator: GeneratorConfig = dataclasses.field( - default_factory=lambda: GeneratorConfig() - ) - discriminator: DiscriminatorConfig = dataclasses.field( - default_factory=lambda: DiscriminatorConfig() + network: "CycleGANNetworkConfig" = dataclasses.field( + default_factory=lambda: CycleGANNetworkConfig() ) training_loop: "CycleGANTrainingConfig" = dataclasses.field( default_factory=lambda: CycleGANTrainingConfig() @@ -54,11 +52,7 @@ def variables(self): return tuple(self.state_variables) -def flatten_dims(dataset: tf.data.Dataset) -> tf.data.Dataset: - """Transform [batch, time, tile, x, y, z] to [sample, x, y, z]""" - return dataset.unbatch().unbatch().unbatch() - - +@dataclasses.dataclass class CycleGANTrainingConfig: n_epoch: int = 20 @@ -96,12 +90,46 @@ def fit_loop( logger.info("starting epoch %d", i) train_losses = [] for batch_state in train_data: - train_losses.append(train_model.train_on_batch(*batch_state)) - train_loss = torch.mean(torch.stack(train_losses)) + state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) + state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) + train_losses.append(train_model.train_on_batch(state_a, state_b)) + train_loss = np.mean(train_losses) logger.info("train_loss: %f", train_loss) if validation_data is not None: val_loss = train_model.evaluate_on_dataset(validation_data) - logger.info("val_loss %f", val_loss) + logger.info("val_loss %s", val_loss) + + +def apply_to_tuple_mapping(func): + # not sure why, but tensorflow doesn't like parsing + # apply_to_tuple(apply_to_maping(func)), so we do it manually + def wrapped(*tuple_of_mapping): + return tuple( + {name: func(value) for name, value in mapping.items()} + for mapping in 
tuple_of_mapping + ) + + return wrapped + + +def get_Xy_map_fn( + state_variables: Sequence[str], + n_dims: int, # [batch, time, tile, x, y, z] + mapping_scale_funcs: Tuple[ + Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], ... + ], +): + funcs = tuple( + get_Xy_map_fn_single_domain( + state_variables=state_variables, n_dims=n_dims, mapping_scale_func=func + ) + for func in mapping_scale_funcs + ) + + def Xy_map_fn(*data: Mapping[str, np.ndarray]): + return tuple(func(entry) for func, entry in zip(funcs, data)) + + return Xy_map_fn @register_training_function("cyclegan", CycleGANHyperparameters) @@ -120,34 +148,36 @@ def train_cyclegan( validation_batches: validation data, as a dataset of Mapping[str, tf.Tensor] where each tensor has dimensions [sample, time, tile, x, y(, z)] """ - train_batches = train_batches.map(apply_to_mapping(ensure_nd(6))) + train_batches = train_batches.map(apply_to_tuple_mapping(ensure_nd(6))) sample_batch = next( iter(train_batches.unbatch().batch(hyperparameters.normalization_fit_samples)) ) - scalers = get_scalers(sample_batch) - mapping_scale_func = get_mapping_scale_func(scalers) + scalers = tuple(get_scalers(entry) for entry in sample_batch) + mapping_scale_funcs = tuple(get_mapping_scale_func(scaler) for scaler in scalers) - get_state = curry(get_Xy_dataset)( + get_Xy = get_Xy_map_fn( state_variables=hyperparameters.state_variables, n_dims=6, # [batch, time, tile, x, y, z] - mapping_scale_func=mapping_scale_func, + mapping_scale_funcs=mapping_scale_funcs, ) if validation_batches is not None: - val_state = get_state(data=validation_batches) + val_state = validation_batches.map(get_Xy).unbatch() else: val_state = None - train_state = get_state(data=train_batches) + train_state = train_batches.map(get_Xy).unbatch() - train_model = build_model( - hyperparameters.generator, n_state=next(iter(train_state)).shape[-1] + train_model = hyperparameters.network.build( + n_state=next(iter(train_state))[0].shape[-1], + 
n_batch=hyperparameters.training_loop.samples_per_batch, ) - train_state = flatten_dims(train_state) + # remove time and tile dimensions, while we're using regular convolution + train_state = train_state.unbatch().unbatch() if validation_batches is not None: - val_state = flatten_dims(val_state) + val_state = val_state.unbatch().unbatch() hyperparameters.training_loop.fit_loop( train_model=train_model, train_data=train_state, validation_data=val_state, @@ -156,8 +186,9 @@ def train_cyclegan( predictor = PytorchPredictor( input_variables=hyperparameters.state_variables, output_variables=hyperparameters.state_variables, - model=train_model, - scalers=scalers, + model=train_model.generator_a_to_b, + scalers=scalers[0], + output_scalers=scalers[1], ) return predictor @@ -194,15 +225,15 @@ def push_and_pop(self, data: torch.Tensor) -> torch.autograd.Variable: class StatsCollector: def __init__(self, n_dims_keep: int): self.n_dims_keep = n_dims_keep - self._sum = np.asarray(0.0, dtype=np.float64) - self._sum_squared = np.asarray(0.0, dtype=np.float64) + self._sum = 0.0 + self._sum_squared = 0.0 self._count = 0 def observe(self, data: np.ndarray): mean_dims = tuple(range(0, len(data.shape) - self.n_dims_keep)) data = data.astype(np.float64) - self._sum += data.mean(dims=mean_dims) - self._sum_squared += (data ** 2).mean(dims=mean_dims) + self._sum += data.mean(axis=mean_dims) + self._sum_squared += (data ** 2).mean(axis=mean_dims) self._count += 1 @property @@ -211,19 +242,68 @@ def mean(self) -> np.ndarray: @property def std(self) -> np.ndarray: - return np.sqrt(self._sum_squared / self._count - self.mean() ** 2) + return np.sqrt(self._sum_squared / self._count - self.mean ** 2) def get_r2(predicted, target) -> float: """ Compute the R^2 statistic for the predicted and target data. 
""" - return ( - 1.0 - - ((target - predicted) ** 2).mean() / ((target - target.mean()) ** 2).mean() + return 1.0 - np.var(predicted - target) / np.var(target) + + +@dataclasses.dataclass +class CycleGANNetworkConfig: + generator_optimizer: OptimizerConfig = dataclasses.field( + default_factory=lambda: OptimizerConfig("Adam") + ) + discriminator_optimizer: OptimizerConfig = dataclasses.field( + default_factory=lambda: OptimizerConfig("Adam") ) + generator: "GeneratorConfig" = dataclasses.field( + default_factory=lambda: GeneratorConfig() + ) + discriminator: "DiscriminatorConfig" = dataclasses.field( + default_factory=lambda: DiscriminatorConfig() + ) + identity_loss: LossConfig = dataclasses.field(default_factory=LossConfig) + cycle_loss: LossConfig = dataclasses.field(default_factory=LossConfig) + gan_loss: LossConfig = dataclasses.field(default_factory=LossConfig) + identity_weight: float = 0.5 + cycle_weight: float = 1.0 + gan_weight: float = 1.0 + + def build(self, n_state: int, n_batch: int) -> "CycleGAN": + generator_a_to_b = self.generator.build(n_state) + generator_b_to_a = self.generator.build(n_state) + discriminator_a = self.discriminator.build(n_state) + discriminator_b = self.discriminator.build(n_state) + optimizer_generator = self.generator_optimizer.instance( + itertools.chain( + generator_a_to_b.parameters(), generator_b_to_a.parameters() + ) + ) + optimizer_discriminator = self.discriminator_optimizer.instance( + itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()) + ) + return CycleGAN( + generator_a_to_b=generator_a_to_b, + generator_b_to_a=generator_b_to_a, + discriminator_a=discriminator_a, + discriminator_b=discriminator_b, + optimizer_generator=optimizer_generator, + optimizer_discriminator=optimizer_discriminator, + identity_loss=self.identity_loss.instance, + cycle_loss=self.cycle_loss.instance, + gan_loss=self.gan_loss.instance, + batch_size=n_batch, + identity_weight=self.identity_weight, + 
cycle_weight=self.cycle_weight, + gan_weight=self.gan_weight, + ) +@dataclasses.dataclass class CycleGAN: # This class based loosely on @@ -374,15 +454,6 @@ def train_on_batch(self, real_a: torch.Tensor, real_b: torch.Tensor) -> float: return float(loss_g + loss_d) -def build_model(config: GeneratorConfig, n_state: int) -> CycleGAN: - return Generator( - channels=n_state, - n_convolutions=config.n_convolutions, - n_resnet=config.n_resnet, - max_filters=config.max_filters, - ) - - def set_requires_grad(nets: List[torch.nn.Module], requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Parameters: diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py b/external/fv3fit/fv3fit/pytorch/graph/train.py index fc76d8c7ff..6fcb276f3b 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -101,18 +101,18 @@ def train_graph_model( scalers = get_scalers(sample) mapping_scale_func = get_mapping_scale_func(scalers) - get_state = curry(get_Xy_dataset)( + get_Xy = get_Xy_map_fn( state_variables=hyperparameters.state_variables, n_dims=6, # [batch, time, tile, x, y, z] mapping_scale_func=mapping_scale_func, ) if validation_batches is not None: - val_state = get_state(data=validation_batches).unbatch() + val_state = validation_batches.map(get_Xy).unbatch() else: val_state = None - train_state = get_state(data=train_batches).unbatch() + train_state = train_batches.map(get_Xy).unbatch() train_model = build_model( hyperparameters.graph_network, n_state=next(iter(train_state)).shape[-1] @@ -180,3 +180,37 @@ def map_fn(data): return data return data.map(map_fn) + + +def get_Xy_map_fn( + state_variables: Sequence[str], + n_dims: int, + mapping_scale_func: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], +): + """ + Given a tf.data.Dataset with mappings from variable name to samples + return a tf.data.Dataset whose entries are tensors of the requested + state 
variables concatenated along the feature dimension. + + Args: + state_variables: names of variables to include in returned tensor + n_dims: number of dimensions of each sample, including feature dimension + mapping_scale_func: function which scales data stored as a mapping + from variable name to array + data: tf.data.Dataset with mappings from variable name + to sample tensors + + Returns: + tf.data.Dataset where each sample is a single tensor + containing normalized and concatenated state variables + """ + ensure_dims = apply_to_mapping(ensure_nd(n_dims)) + + def map_fn(data): + data = mapping_scale_func(data) + data = ensure_dims(data) + data = select_keys(state_variables, data) + data = tf.concat(data, axis=-1) + return data + + return map_fn diff --git a/external/fv3fit/fv3fit/pytorch/loss.py b/external/fv3fit/fv3fit/pytorch/loss.py index 8c7618ce2e..1823551fcf 100644 --- a/external/fv3fit/fv3fit/pytorch/loss.py +++ b/external/fv3fit/fv3fit/pytorch/loss.py @@ -18,7 +18,7 @@ def __post_init__(self): ) @property - def loss(self) -> torch.nn.Module: + def instance(self) -> torch.nn.Module: """ Returns the loss function described by the configuration. 
diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index f6717f3560..65e76e7b82 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -9,6 +9,7 @@ Hashable, Iterable, Mapping, + Optional, Sequence, Tuple, TypeVar, @@ -68,6 +69,7 @@ def __init__( output_variables: Iterable[Hashable], model: nn.Module, scalers: Mapping[Hashable, StandardScaler], + output_scalers: Optional[Mapping[Hashable, StandardScaler]] = None, ): """Initialize the predictor Args: @@ -79,6 +81,10 @@ def __init__( self.output_variables = output_variables self.model = model self.scalers = scalers + if output_scalers is None: + self.output_scalers = output_scalers + else: + self.output_scalers = scalers def predict(self, X: xr.Dataset) -> xr.Dataset: """ diff --git a/external/fv3fit/fv3fit/pytorch/training_loop.py b/external/fv3fit/fv3fit/pytorch/training_loop.py index a3b84aadc0..8bfade3d28 100644 --- a/external/fv3fit/fv3fit/pytorch/training_loop.py +++ b/external/fv3fit/fv3fit/pytorch/training_loop.py @@ -56,7 +56,7 @@ def fit_loop( def evaluate_on_batch(batch_state, model): batch_input = torch.as_tensor(batch_state[0]).float().to(DEVICE) batch_output = torch.as_tensor(batch_state[1]).float().to(DEVICE) - loss: torch.Tensor = loss_config.loss(model(batch_input), batch_output) + loss: torch.Tensor = loss_config.instance(model(batch_input), batch_output) return loss return _train_loop( @@ -116,7 +116,7 @@ def evaluate_on_batch(batch_state, model): batch_state=batch_state, model=train_model, multistep=self.multistep, - loss=loss_config.loss, + loss=loss_config.instance, ) return loss diff --git a/external/fv3fit/fv3fit/tfdataset.py b/external/fv3fit/fv3fit/tfdataset.py index 5fe20f0bdd..c160b2d331 100644 --- a/external/fv3fit/fv3fit/tfdataset.py +++ b/external/fv3fit/fv3fit/tfdataset.py @@ -11,22 +11,26 @@ Optional, Sequence, Tuple, + TypeVar, ) from toolz.functoolz import curry import loaders.typing 
+T_in = TypeVar("T_in") +T_out = TypeVar("T_out") + @curry def apply_to_mapping( - tensor_func: Callable[[tf.Tensor], tf.Tensor], data: Mapping[str, tf.Tensor] -) -> Dict[str, tf.Tensor]: + tensor_func: Callable[[T_in], T_out], data: Mapping[str, T_in] +) -> Dict[str, T_out]: return {name: tensor_func(tensor) for name, tensor in data.items()} @curry def apply_to_tuple( - tensor_func: Callable[[tf.Tensor], tf.Tensor], data: Tuple[tf.Tensor, ...] -) -> Tuple[tf.Tensor, ...]: + tensor_func: Callable[[T_in], T_out], data: Tuple[T_in, ...] +) -> Tuple[T_out, ...]: return tuple(tensor_func(tensor) for tensor in data) diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index 823f3b6495..5d1e6331f2 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -1,8 +1,12 @@ import numpy as np import xarray as xr from typing import Sequence -from fv3fit.pytorch.cyclegan import CycleGANHyperparameters, train_cyclegan -from fv3fit.pytorch.cyclegan.train import TrainingConfig +from fv3fit.pytorch.cyclegan import ( + CycleGANHyperparameters, + CycleGANNetworkConfig, + CycleGANTrainingConfig, + train_cyclegan, +) from fv3fit.data import CycleGANLoader, SyntheticWaves import collections import os @@ -24,6 +28,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): scale_max=1.0, period_min=4, period_max=7, + phase_range=0.1, ), SyntheticWaves( nsamples=nsamples, @@ -36,6 +41,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): scale_max=1.5, period_min=8, period_max=16, + phase_range=0.1, ), ] ) @@ -65,80 +71,83 @@ def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): return xr.Dataset(data_vars) -def test_cyclegan(tmpdir): - fv3fit.set_random_seed(0) - # run the test in a temporary directory to delete artifacts when done - os.chdir(tmpdir) - # need a larger nx, ny for the sample data here since we're training - # on whether we can autoencode sin waves, 
and need to resolve full cycles - nx, ny = 32, 32 - sizes = {"nbatch": 2, "ntime": 2, "nx": nx, "ny": ny, "nz": 2} - state_variables = ["a", "b"] - train_tfdataset = get_tfdataset(nsamples=20, **sizes) - val_tfdataset = get_tfdataset(nsamples=3, **sizes) - hyperparameters = CycleGANHyperparameters( - state_variables=state_variables, - generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=3, max_filters=32 - ), - training_loop=TrainingConfig(n_epoch=5, samples_per_batch=2), - optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), - noise_amount=0.5, - ) - predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) - # for test, need one continuous series so we consistently flip sign - test_sizes = {"nbatch": 1, "ntime": 100, "nx": nx, "ny": ny, "nz": 2} - test_xrdataset = tfdataset_to_xr_dataset( - get_tfdataset(nsamples=1, **test_sizes), dims=["time", "tile", "x", "y", "z"] - ) - predicted = predictor.predict(test_xrdataset) - reference = test_xrdataset - # plotting code to uncomment if you'd like to manually check the results: - # for i in range(6): - # fig, ax = plt.subplots(1, 2) - # vmin = reference["a"][0, i, :, :, 0].values.min() - # vmax = reference["a"][0, i, :, :, 0].values.max() - # ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - # ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - # plt.tight_layout() - # plt.show() - bias = predicted.isel(time=1) - reference.isel(time=1) - mean_bias: xr.Dataset = bias.mean() - mse: xr.Dataset = (bias ** 2).mean() ** 0.5 - for varname in state_variables: - assert np.abs(mean_bias[varname]) < 0.1 - assert mse[varname] < 0.1 +# def test_cyclegan(tmpdir): +# fv3fit.set_random_seed(0) +# # run the test in a temporary directory to delete artifacts when done +# os.chdir(tmpdir) +# # need a larger nx, ny for the sample data here since we're training +# # on whether we can autoencode sin waves, and need to resolve full cycles +# nx, ny = 
32, 32 +# sizes = {"nbatch": 2, "ntime": 2, "nx": nx, "ny": ny, "nz": 2} +# state_variables = ["a", "b"] +# train_tfdataset = get_tfdataset(nsamples=20, **sizes) +# val_tfdataset = get_tfdataset(nsamples=3, **sizes) +# hyperparameters = CycleGANHyperparameters( +# state_variables=state_variables, +# generator=fv3fit.pytorch.GeneratorConfig( +# n_convolutions=2, n_resnet=3, max_filters=32 +# ), +# training_loop=TrainingConfig(n_epoch=5, samples_per_batch=2), +# optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), +# noise_amount=0.5, +# ) +# predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) +# # for test, need one continuous series so we consistently flip sign +# test_sizes = {"nbatch": 1, "ntime": 100, "nx": nx, "ny": ny, "nz": 2} +# test_xrdataset = tfdataset_to_xr_dataset( +# get_tfdataset(nsamples=1, **test_sizes), dims=["time", "tile", "x", "y", "z"] +# ) +# predicted = predictor.predict(test_xrdataset) +# reference = test_xrdataset +# # plotting code to uncomment if you'd like to manually check the results: +# # for i in range(6): +# # fig, ax = plt.subplots(1, 2) +# # vmin = reference["a"][0, i, :, :, 0].values.min() +# # vmax = reference["a"][0, i, :, :, 0].values.max() +# # ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) +# # ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) +# # plt.tight_layout() +# # plt.show() +# bias = predicted.isel(time=1) - reference.isel(time=1) +# mean_bias: xr.Dataset = bias.mean() +# mse: xr.Dataset = (bias ** 2).mean() ** 0.5 +# for varname in state_variables: +# assert np.abs(mean_bias[varname]) < 0.1 +# assert mse[varname] < 0.1 def test_cyclegan_overfit(tmpdir): fv3fit.set_random_seed(0) # run the test in a temporary directory to delete artifacts when done os.chdir(tmpdir) - # need a larger nx, ny for the sample data here since we're training + # need a larger nx for the sample data here since we're training # on whether we can autoencode sin 
waves, and need to resolve full cycles - nx, ny = 32, 32 - sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "ny": ny, "nz": 2} + nx = 32 + sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} state_variables = ["a", "b"] train_tfdataset = get_tfdataset(nsamples=1, **sizes) train_tfdataset = train_tfdataset.cache() # needed to keep sample identical hyperparameters = CycleGANHyperparameters( state_variables=state_variables, - generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=1, max_filters=32 + network=CycleGANNetworkConfig( + generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, n_resnet=1, max_filters=32 + ), ), - training_loop=TrainingConfig(n_epoch=100, samples_per_batch=6), - optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), - noise_amount=0.0, + training_loop=CycleGANTrainingConfig(n_epoch=1000, samples_per_batch=6), ) predictor = train_cyclegan( - hyperparameters, train_tfdataset, validation_batches=None + hyperparameters, train_tfdataset, validation_batches=train_tfdataset ) # for test, need one continuous series so we consistently flip sign - test_xrdataset = tfdataset_to_xr_dataset( - train_tfdataset, dims=["time", "tile", "x", "y", "z"] + test_xrdataset_in = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: a), dims=["time", "tile", "x", "y", "z"] + ) + test_xrdataset_out = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] ) - predicted = predictor.predict(test_xrdataset) - reference = test_xrdataset + predicted = predictor.predict(test_xrdataset_in) + reference = test_xrdataset_out # plotting code to uncomment if you'd like to manually check the results: # for i in range(6): # fig, ax = plt.subplots(1, 2) @@ -154,3 +163,6 @@ def test_cyclegan_overfit(tmpdir): for varname in state_variables: assert np.abs(mean_bias[varname]) < 0.1 assert rmse[varname] < 0.1 + import pdb + + pdb.set_trace() From 0653057f93e86a3124c7b4993df2595445579606 Mon Sep 17 00:00:00 2001 
From: Jeremy McGibbon Date: Thu, 25 Aug 2022 17:06:13 -0700 Subject: [PATCH 17/55] still wip, not converging --- external/fv3fit/fv3fit/data/__init__.py | 2 +- external/fv3fit/fv3fit/data/synthetic.py | 84 +++++- .../fv3fit/pytorch/cyclegan/__init__.py | 1 + .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 274 ++++++++++++++---- .../fv3fit/tests/training/test_cyclegan.py | 92 ++++-- 5 files changed, 373 insertions(+), 80 deletions(-) diff --git a/external/fv3fit/fv3fit/data/__init__.py b/external/fv3fit/fv3fit/data/__init__.py index 0d147ca4c6..7f2e10beb8 100644 --- a/external/fv3fit/fv3fit/data/__init__.py +++ b/external/fv3fit/fv3fit/data/__init__.py @@ -1,4 +1,4 @@ from .base import TFDatasetLoader, tfdataset_loader_from_dict, register_tfdataset_loader from .batches import FromBatches from .tfdataset import WindowedZarrLoader, VariableConfig, CycleGANLoader -from .synthetic import SyntheticWaves +from .synthetic import SyntheticWaves, SyntheticNoise diff --git a/external/fv3fit/fv3fit/data/synthetic.py b/external/fv3fit/fv3fit/data/synthetic.py index 597000f000..6b7dd8a666 100644 --- a/external/fv3fit/fv3fit/data/synthetic.py +++ b/external/fv3fit/fv3fit/data/synthetic.py @@ -7,6 +7,51 @@ import dacite +@register_tfdataset_loader +@dataclasses.dataclass +class SyntheticNoise(TFDatasetLoader): + nsamples: int + nbatch: int + ntime: int + nx: int + nz: int + scalar_names: List[str] = dataclasses.field(default_factory=list) + noise_amplitude: float = 1.0 + + def open_tfdataset( + self, local_download_path: Optional[str], variable_names: Sequence[str], + ) -> tf.data.Dataset: + """ + Args: + local_download_path: if provided, cache data locally at this path + variable_names: names of variables to include when loading data + Returns: + dataset containing requested variables, each record is a mapping from + variable name to variable value, and each value is a tensor whose + first dimension is the batch dimension + """ + dataset = get_noise_tfdataset( + variable_names, + 
scalar_names=self.scalar_names, + nsamples=self.nsamples, + nbatch=self.nbatch, + ntime=self.ntime, + nx=self.nx, + ny=self.nx, + nz=self.nz, + noise_amplitude=self.noise_amplitude, + ) + if local_download_path is not None: + dataset = dataset.cache(local_download_path) + return dataset + + @classmethod + def from_dict(cls, d: dict) -> "TFDatasetLoader": + return dacite.from_dict( + data_class=cls, data=d, config=dacite.Config(strict=True) + ) + + @register_tfdataset_loader @dataclasses.dataclass class SyntheticWaves(TFDatasetLoader): @@ -51,7 +96,7 @@ def open_tfdataset( variable name to variable value, and each value is a tensor whose first dimension is the batch dimension """ - dataset = get_tfdataset( + dataset = get_waves_tfdataset( variable_names, scalar_names=self.scalar_names, nsamples=self.nsamples, @@ -77,7 +122,7 @@ def from_dict(cls, d: dict) -> "TFDatasetLoader": ) -def get_tfdataset( +def get_waves_tfdataset( variable_names, *, scalar_names, @@ -143,3 +188,38 @@ def sample_iterator(): yield out return iterable_to_tfdataset(list(sample_iterator())) + + +def get_noise_tfdataset( + variable_names, + *, + scalar_names, + nsamples: int, + nbatch: int, + ntime: int, + nx: int, + ny: int, + nz: int, + noise_amplitude: float, +): + ntile = 6 + + def sample_iterator(): + # creates a timeseries where each time is the negation of time before it + for _ in range(nsamples): + data = noise_amplitude * np.random.randn(nbatch, 1, ntile, nx, ny, nz) + start = {} + for varname in variable_names: + if varname in scalar_names: + start[varname] = data[..., 0].astype(np.float32) + else: + start[varname] = data.astype(np.float32) + out = {key: [value] for key, value in start.items()} + for _ in range(ntime - 1): + for varname in start.keys(): + out[varname].append(out[varname][-1] * -1.0) + for varname in out: + out[varname] = np.concatenate(out[varname], axis=1) + yield out + + return iterable_to_tfdataset(list(sample_iterator())) diff --git 
a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py index ea0eb6db5c..fa8964d198 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py @@ -4,4 +4,5 @@ CycleGANHyperparameters, CycleGANNetworkConfig, CycleGANTrainingConfig, + CycleGAN, ) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index b736cb08d9..d93b2cbb4a 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -6,15 +6,24 @@ from fv3fit.pytorch.predict import PytorchPredictor from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig +import xarray as xr import torch from fv3fit.pytorch.system import DEVICE import tensorflow_datasets as tfds from fv3fit.tfdataset import sequence_size +from fv3fit.pytorch.predict import ( + _load_pytorch, + _dump_pytorch, + _pack_to_tensor, + _unpack_tensor, +) -from fv3fit._shared import register_training_function +from fv3fit._shared import register_training_function, io, StandardScaler from typing import ( Callable, Dict, + Hashable, + Iterable, List, Mapping, Optional, @@ -62,7 +71,7 @@ class CycleGANTrainingConfig: def fit_loop( self, - train_model: "CycleGAN", + train_model: "CycleGANTrainer", train_data: tf.data.Dataset, validation_data: Optional[tf.data.Dataset], ) -> None: @@ -93,8 +102,11 @@ def fit_loop( state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) train_losses.append(train_model.train_on_batch(state_a, state_b)) - train_loss = np.mean(train_losses) - logger.info("train_loss: %f", train_loss) + train_loss = { + name: np.mean([data[name] for data in train_losses]) + for name in train_losses[0] + } + logger.info("train_loss: %s", train_loss) if validation_data is not 
None: val_loss = train_model.evaluate_on_dataset(validation_data) logger.info("val_loss %s", val_loss) @@ -137,7 +149,7 @@ def train_cyclegan( hyperparameters: CycleGANHyperparameters, train_batches: tf.data.Dataset, validation_batches: Optional[tf.data.Dataset], -) -> PytorchPredictor: +) -> "CycleGAN": """ Train a denoising autoencoder for cubed sphere data. @@ -172,6 +184,8 @@ def train_cyclegan( train_model = hyperparameters.network.build( n_state=next(iter(train_state))[0].shape[-1], n_batch=hyperparameters.training_loop.samples_per_batch, + state_variables=hyperparameters.state_variables, + scalers=scalers, ) # remove time and tile dimensions, while we're using regular convolution @@ -182,15 +196,7 @@ def train_cyclegan( hyperparameters.training_loop.fit_loop( train_model=train_model, train_data=train_state, validation_data=val_state, ) - - predictor = PytorchPredictor( - input_variables=hyperparameters.state_variables, - output_variables=hyperparameters.state_variables, - model=train_model.generator_a_to_b, - scalers=scalers[0], - output_scalers=scalers[1], - ) - return predictor + return train_model.cycle_gan class ReplayBuffer: @@ -273,7 +279,9 @@ class CycleGANNetworkConfig: cycle_weight: float = 1.0 gan_weight: float = 1.0 - def build(self, n_state: int, n_batch: int) -> "CycleGAN": + def build( + self, n_state: int, n_batch: int, state_variables, scalers + ) -> "CycleGANTrainer": generator_a_to_b = self.generator.build(n_state) generator_b_to_a = self.generator.build(n_state) discriminator_a = self.discriminator.build(n_state) @@ -286,11 +294,17 @@ def build(self, n_state: int, n_batch: int) -> "CycleGAN": optimizer_discriminator = self.discriminator_optimizer.instance( itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()) ) - return CycleGAN( - generator_a_to_b=generator_a_to_b, - generator_b_to_a=generator_b_to_a, - discriminator_a=discriminator_a, - discriminator_b=discriminator_b, + return CycleGANTrainer( + cycle_gan=CycleGAN( 
+ model=CycleGANModule( + generator_a_to_b=generator_a_to_b, + generator_b_to_a=generator_b_to_a, + discriminator_a=discriminator_a, + discriminator_b=discriminator_b, + ), + state_variables=state_variables, + scalers=_merge_scaler_mappings(scalers), + ), optimizer_generator=optimizer_generator, optimizer_discriminator=optimizer_discriminator, identity_loss=self.identity_loss.instance, @@ -303,19 +317,173 @@ def build(self, n_state: int, n_batch: int) -> "CycleGAN": ) -@dataclasses.dataclass +def _merge_scaler_mappings( + scaler_tuple: Tuple[Mapping[str, StandardScaler], Mapping[str, StandardScaler]] +) -> Mapping[str, StandardScaler]: + scalers = {} + for prefix, scaler_map in zip(("a_", "b_"), scaler_tuple): + for key, scaler in scaler_map.items(): + scalers[prefix + key] = scaler + return scalers + + +class CycleGANModule(torch.nn.Module): + def __init__( + self, + generator_a_to_b: Generator, + generator_b_to_a: Generator, + discriminator_a: Discriminator, + discriminator_b: Discriminator, + ): + super(CycleGANModule, self).__init__() + self.generator_a_to_b = generator_a_to_b + self.generator_b_to_a = generator_b_to_a + self.discriminator_a = discriminator_a + self.discriminator_b = discriminator_b + + +@io.register("cycle_gan") class CycleGAN: + _MODEL_FILENAME = "weight.pt" + _CONFIG_FILENAME = "config.yaml" + _SCALERS_FILENAME = "scalers.zip" + + def __init__( + self, + model: CycleGANModule, + scalers: Mapping[Hashable, StandardScaler], + state_variables: Iterable[Hashable], + ): + """ + Args: + model: pytorch model + scalers: scalers for the state variables, keys are prepended with "a_" + or "b_" to denote the domain of the scaler, followed by the name of + the state variable it scales + state_variables: name of variables to be used as state variables in + the order expected by the model + """ + self.model = model + self.scalers = scalers + self.state_variables = state_variables + + @property + def generator_a_to_b(self) -> torch.nn.Module: + return 
self.model.generator_a_to_b + + @property + def generator_b_to_a(self) -> torch.nn.Module: + return self.model.generator_b_to_a + + @property + def discriminator_a(self) -> torch.nn.Module: + return self.model.discriminator_a + + @property + def discriminator_b(self) -> torch.nn.Module: + return self.model.discriminator_b + + @classmethod + def load(cls, path: str) -> "CycleGAN": + """Load a serialized model from a directory.""" + return _load_pytorch(cls, path) + + def dump(self, path: str) -> None: + _dump_pytorch(self, path) + + def get_config(self): + return {} + + def pack_to_tensor(self, ds: xr.Dataset, domain: str = "a") -> torch.Tensor: + """ + Packs the dataset into a tensor to be used by the pytorch model. + + Subdivides the dataset evenly into windows + of size (timesteps + 1) with overlapping start and end points. + Overlapping the window start and ends is necessary so that every + timestep (evolution from one time to the next) is included within + one of the windows. + + Args: + ds: dataset containing values to pack + domain: one of "a" or "b" + + Returns: + tensor of shape [window, time, tile, x, y, feature] + """ + scalers = { + name[2:]: scaler + for name, scaler in self.scalers.items() + if name.startswith(f"{domain}_") + } + return _pack_to_tensor( + ds=ds, timesteps=0, state_variables=self.state_variables, scalers=scalers, + ) + + def unpack_tensor(self, data: torch.Tensor, domain: str = "b") -> xr.Dataset: + """ + Unpacks the tensor into a dataset. 
+ + Args: + data: tensor of shape [window, time, tile, x, y, feature] + domain: one of "a" or "b" + + Returns: + xarray dataset with values of shape [window, time, tile, x, y, feature] + """ + scalers = { + name[2:]: scaler + for name, scaler in self.scalers.items() + if name.startswith(f"{domain}_") + } + return _unpack_tensor( + data, + varnames=self.state_variables, + scalers=scalers, + dims=["time", "tile", "x", "y", "z"], + ) + + def predict(self, X: xr.Dataset, reverse: bool = False) -> xr.Dataset: + """ + Predict a state in the output domain from a state in the input domain. + + Args: + X: input dataset + reverse: if True, transform from the output domain to the input domain + + Returns: + predicted: predicted dataset + """ + if reverse: + input_domain, output_domain = "b", "a" + else: + input_domain, output_domain = "a", "b" + + tensor = self.pack_to_tensor(X, domain=input_domain) + reshaped_tensor = tensor.reshape( + [tensor.shape[0] * tensor.shape[1]] + list(tensor.shape[2:]) + ) + with torch.no_grad(): + if reverse: + outputs = self.generator_b_to_a(reshaped_tensor) + else: + outputs = self.generator_a_to_b(reshaped_tensor) + outputs = outputs.reshape(tensor.shape) + predicted = self.unpack_tensor(outputs, domain=output_domain) + return predicted + + +@dataclasses.dataclass +class CycleGANTrainer: + # This class based loosely on # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py # Copyright Facebook, BSD license # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/c99ce7c4e781712e0252c6127ad1a4e8021cc489/LICENSE - generator_a_to_b: Generator - generator_b_to_a: Generator - discriminator_a: Discriminator - discriminator_b: Discriminator + cycle_gan: CycleGAN optimizer_generator: torch.optim.Optimizer optimizer_discriminator: torch.optim.Optimizer identity_loss: torch.nn.Module @@ -335,6 +503,10 @@ def __post_init__(self): ) self.fake_a_buffer = ReplayBuffer() self.fake_b_buffer = ReplayBuffer() + 
self.generator_a_to_b = self.cycle_gan.generator_a_to_b + self.generator_b_to_a = self.cycle_gan.generator_b_to_a + self.discriminator_a = self.cycle_gan.discriminator_a + self.discriminator_b = self.cycle_gan.discriminator_b def evaluate_on_dataset( self, dataset: tf.data.Dataset, n_dims_keep: int = 4 @@ -348,10 +520,10 @@ def evaluate_on_dataset( for real_a, real_b in dataset: stats_real_a.observe(real_a) stats_real_b.observe(real_b) - gen_a: torch.Tensor = self.generator_a_to_b( + gen_b: torch.Tensor = self.generator_a_to_b( torch.as_tensor(real_a).float().to(DEVICE) ) - gen_b: torch.Tensor = self.generator_b_to_a( + gen_a: torch.Tensor = self.generator_b_to_a( torch.as_tensor(real_b).float().to(DEVICE) ) stats_gen_a.observe(gen_a.detach().cpu().numpy()) @@ -368,7 +540,9 @@ def evaluate_on_dataset( } return metrics - def train_on_batch(self, real_a: torch.Tensor, real_b: torch.Tensor) -> float: + def train_on_batch( + self, real_a: torch.Tensor, real_b: torch.Tensor + ) -> Mapping[str, float]: fake_b = self.generator_a_to_b(real_a) fake_a = self.generator_b_to_a(real_b) reconstructed_a = self.generator_b_to_a(fake_b) @@ -381,8 +555,6 @@ def train_on_batch(self, real_a: torch.Tensor, real_b: torch.Tensor) -> float: [self.discriminator_a, self.discriminator_b], requires_grad=False ) - self.optimizer_generator.zero_grad() - # Identity loss # G_A2B(B) should equal B if real B is fed same_b = self.generator_a_to_b(real_b) @@ -390,31 +562,25 @@ def train_on_batch(self, real_a: torch.Tensor, real_b: torch.Tensor) -> float: # G_B2A(A) should equal A if real A is fed same_a = self.generator_b_to_a(real_b) loss_identity_a = self.identity_loss(same_a, real_a) * self.identity_weight + loss_identity = loss_identity_a + loss_identity_b # GAN loss - fake_b = self.generator_a_to_b(real_a) pred_fake = self.discriminator_b(fake_b) - loss_gan_a_to_b = self.gan_loss(pred_fake, self.target_real) + loss_gan_a_to_b = self.gan_loss(pred_fake, self.target_real) * self.gan_weight - 
fake_A = self.generator_b_to_a(real_b) - pred_fake = self.discriminator_a(fake_A) - loss_gan_b_to_a = self.gan_loss(pred_fake, self.target_real) + pred_fake = self.discriminator_a(fake_a) + loss_gan_b_to_a = self.gan_loss(pred_fake, self.target_real) * self.gan_weight + loss_gan = loss_gan_a_to_b + loss_gan_b_to_a # Cycle loss loss_cycle_a_b_a = self.cycle_loss(reconstructed_a, real_a) * self.cycle_weight loss_cycle_b_a_b = self.cycle_loss(reconstructed_b, real_b) * self.cycle_weight + loss_cycle = loss_cycle_a_b_a + loss_cycle_b_a_b # Total loss - loss_g: torch.Tensor = ( - loss_identity_a - + loss_identity_b - + loss_gan_a_to_b - + loss_gan_b_to_a - + loss_cycle_a_b_a - + loss_cycle_b_a_b - ) + loss_g: torch.Tensor = (loss_identity + loss_gan + loss_cycle) + self.optimizer_generator.zero_grad() loss_g.backward() - self.optimizer_generator.step() # Discriminators A and B ###### @@ -424,8 +590,6 @@ def train_on_batch(self, real_a: torch.Tensor, real_b: torch.Tensor) -> float: [self.discriminator_a, self.discriminator_b], requires_grad=True ) - self.optimizer_discriminator.zero_grad() - # Real loss pred_real = self.discriminator_a(real_a) loss_d_a_real = self.gan_loss(pred_real, self.target_real) @@ -447,11 +611,23 @@ def train_on_batch(self, real_a: torch.Tensor, real_b: torch.Tensor) -> float: # Total loss loss_d: torch.Tensor = ( loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake - ) * 0.5 + ) + self.optimizer_discriminator.zero_grad() loss_d.backward() - self.optimizer_discriminator.step() - return float(loss_g + loss_d) + + return { + # "gan_loss": float(loss_gan), + "b_to_a_gan_loss": float(loss_gan_b_to_a), + "a_to_b_gan_loss": float(loss_gan_a_to_b), + "discriminator_a_loss": float(loss_d_a_fake + loss_d_a_real), + "discriminator_b_loss": float(loss_d_b_fake + loss_d_b_real), + # "cycle_loss": float(loss_cycle), + # "identity_loss": float(loss_identity), + # "generator_loss": float(loss_g), + # "discriminator_loss": float(loss_d), + "train_loss": 
float(loss_g + loss_d), + } def set_requires_grad(nets: List[torch.nn.Module], requires_grad=False): diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index 5d1e6331f2..9fd1cb67a1 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -6,12 +6,14 @@ CycleGANNetworkConfig, CycleGANTrainingConfig, train_cyclegan, + CycleGAN, ) -from fv3fit.data import CycleGANLoader, SyntheticWaves +from fv3fit.data import CycleGANLoader, SyntheticWaves, SyntheticNoise import collections import os import fv3fit.pytorch import fv3fit +import matplotlib.pyplot as plt def get_tfdataset(nsamples, nbatch, ntime, nx, nz): @@ -24,7 +26,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): nx=nx, nz=nz, scalar_names=["b"], - scale_min=0.1, + scale_min=0.5, scale_max=1.0, period_min=4, period_max=7, @@ -37,7 +39,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): nx=nx, nz=nz, scalar_names=["b"], - scale_min=0.5, + scale_min=1.0, scale_max=1.5, period_min=8, period_max=16, @@ -49,6 +51,31 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): return dataset +def get_noise_tfdataset(nsamples, nbatch, ntime, nx, nz): + config = CycleGANLoader( + domain_configs=[ + SyntheticNoise( + nsamples=nsamples, + nbatch=nbatch, + ntime=ntime, + nx=nx, + nz=nz, + noise_amplitude=1.0, + ), + SyntheticNoise( + nsamples=nsamples, + nbatch=nbatch, + ntime=ntime, + nx=nx, + nz=nz, + noise_amplitude=1.0, + ), + ] + ) + dataset = config.open_tfdataset(local_download_path=None, variable_names=["a", "b"]) + return dataset + + def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): """ Returns a [time, tile, x, y, z] dataset needed for evaluation. 
@@ -122,47 +149,56 @@ def test_cyclegan_overfit(tmpdir): os.chdir(tmpdir) # need a larger nx for the sample data here since we're training # on whether we can autoencode sin waves, and need to resolve full cycles - nx = 32 + nx = 16 sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} state_variables = ["a", "b"] - train_tfdataset = get_tfdataset(nsamples=1, **sizes) + train_tfdataset = get_noise_tfdataset(nsamples=1, **sizes) train_tfdataset = train_tfdataset.cache() # needed to keep sample identical hyperparameters = CycleGANHyperparameters( state_variables=state_variables, network=CycleGANNetworkConfig( generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=1, max_filters=32 + n_convolutions=2, n_resnet=1, max_filters=128 + ), + generator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.01} + ), + discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.01} ), ), - training_loop=CycleGANTrainingConfig(n_epoch=1000, samples_per_batch=6), + training_loop=CycleGANTrainingConfig(n_epoch=100, samples_per_batch=6), ) predictor = train_cyclegan( hyperparameters, train_tfdataset, validation_batches=train_tfdataset ) # for test, need one continuous series so we consistently flip sign - test_xrdataset_in = tfdataset_to_xr_dataset( + real_a = tfdataset_to_xr_dataset( train_tfdataset.map(lambda a, b: a), dims=["time", "tile", "x", "y", "z"] ) - test_xrdataset_out = tfdataset_to_xr_dataset( + real_b = tfdataset_to_xr_dataset( train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] ) - predicted = predictor.predict(test_xrdataset_in) - reference = test_xrdataset_out + output_b = predictor.predict(real_a) + output_a = predictor.predict(real_b, reverse=True) # plotting code to uncomment if you'd like to manually check the results: - # for i in range(6): - # fig, ax = plt.subplots(1, 2) - # vmin = reference["a"][0, i, :, :, 0].values.min() - # vmax = reference["a"][0, i, :, :, 0].values.max() 
- # ax[0].imshow(reference["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) - # ax[1].imshow(predicted["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) - # plt.tight_layout() - # plt.show() - bias = predicted - reference - mean_bias: xr.Dataset = bias.mean() - rmse: xr.Dataset = (bias ** 2).mean() - for varname in state_variables: - assert np.abs(mean_bias[varname]) < 0.1 - assert rmse[varname] < 0.1 - import pdb - - pdb.set_trace() + for i in range(3): + fig, ax = plt.subplots(2, 2) + vmin = -1.5 + vmax = 1.5 + ax[0, 0].imshow(real_a["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + ax[0, 1].imshow(real_b["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + ax[1, 0].imshow(output_a["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + ax[1, 1].imshow(output_b["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + ax[0, 0].set_title("real a") + ax[0, 1].set_title("real b") + ax[1, 0].set_title("output a") + ax[1, 1].set_title("output b") + plt.tight_layout() + plt.show() + # bias = predicted - reference + # mean_bias: xr.Dataset = bias.mean() + # rmse: xr.Dataset = (bias ** 2).mean() + # for varname in state_variables: + # assert np.abs(mean_bias[varname]) < 0.1 + # assert rmse[varname] < 0.1 From 1f3efc1c3cdcd445cbb0df67b98fce4087648268 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 26 Aug 2022 09:15:59 -0700 Subject: [PATCH 18/55] ignore internal deprecation warnings during tests --- pytest.ini | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytest.ini b/pytest.ini index c346d23da8..5969cb54ee 100644 --- a/pytest.ini +++ b/pytest.ini @@ -22,3 +22,5 @@ markers = filterwarnings = ignore:distutils Version classes are deprecated:DeprecationWarning + ignore:Call to deprecated create function:DeprecationWarning + ignore:.*is deprecated and will be removed in Pillow 10:DeprecationWarning From b7992bf3f0b53148b41c2b2941c0552a04b6e2de Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 29 Aug 2022 10:01:11 -0700 Subject: [PATCH 19/55] cyclegan 
training code might be working, hard to test it --- external/fv3fit/fv3fit/data/synthetic.py | 26 ++- external/fv3fit/fv3fit/pytorch/__init__.py | 7 +- .../fv3fit/pytorch/cyclegan/__init__.py | 3 +- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 61 +++-- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 3 +- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 112 +++++++-- external/fv3fit/fv3fit/pytorch/graph/train.py | 3 +- external/fv3fit/fv3fit/pytorch/predict.py | 33 ++- external/fv3fit/fv3fit/pytorch/system.py | 9 +- .../fv3fit/tests/training/test_autoencoder.py | 17 +- .../fv3fit/tests/training/test_cyclegan.py | 220 +++++++++++++----- 11 files changed, 349 insertions(+), 145 deletions(-) diff --git a/external/fv3fit/fv3fit/data/synthetic.py b/external/fv3fit/fv3fit/data/synthetic.py index 6b7dd8a666..955c52ab4a 100644 --- a/external/fv3fit/fv3fit/data/synthetic.py +++ b/external/fv3fit/fv3fit/data/synthetic.py @@ -3,7 +3,7 @@ from typing import Optional, Sequence, List import tensorflow as tf import numpy as np -from ..tfdataset import iterable_to_tfdataset +from ..tfdataset import generator_to_tfdataset import dacite @@ -69,7 +69,7 @@ class SyntheticWaves(TFDatasetLoader): period_max: maximum period of waves phase_range: fraction of 2*pi to use for possible range of random phase, should be a value between 0 and 1. 
- + type: one of "sinusoidal" or "square" """ nsamples: int @@ -83,6 +83,7 @@ class SyntheticWaves(TFDatasetLoader): period_min: float = 8.0 period_max: float = 16.0 phase_range: float = 1.0 + type: str = "sinusoidal" def open_tfdataset( self, local_download_path: Optional[str], variable_names: Sequence[str], @@ -96,6 +97,13 @@ def open_tfdataset( variable name to variable value, and each value is a tensor whose first dimension is the batch dimension """ + if self.type == "sinusoidal": + func = np.sin + elif self.type == "square": + + def func(x): + return np.sign(np.sin(x)) + dataset = get_waves_tfdataset( variable_names, scalar_names=self.scalar_names, @@ -110,6 +118,7 @@ def open_tfdataset( period_min=self.period_min, period_max=self.period_max, phase_range=self.phase_range, + func=func, ) if local_download_path is not None: dataset = dataset.cache(local_download_path) @@ -137,6 +146,7 @@ def get_waves_tfdataset( period_min: float, period_max: float, phase_range: float, + func=np.sin, ): ntile = 6 @@ -146,7 +156,7 @@ def get_waves_tfdataset( grid_x = grid_x[None, None, None, :, :, None] grid_y = grid_y[None, None, None, :, :, None] - def sample_iterator(): + def sample_generator(): # creates a timeseries where each time is the negation of time before it for _ in range(nsamples): ax = np.random.uniform(scale_min, scale_max, size=(nbatch, 1, ntile, nz))[ @@ -169,9 +179,9 @@ def sample_iterator(): )[:, :, :, None, None, :] data = ( ax - * np.sin(2 * np.pi * grid_x / bx + cx) + * func(2 * np.pi * grid_x / bx + cx) * ay - * np.sin(2 * np.pi * grid_y / by + cy) + * func(2 * np.pi * grid_y / by + cy) ) start = {} for varname in variable_names: @@ -187,7 +197,7 @@ def sample_iterator(): out[varname] = np.concatenate(out[varname], axis=1) yield out - return iterable_to_tfdataset(list(sample_iterator())) + return generator_to_tfdataset(sample_generator) def get_noise_tfdataset( @@ -204,7 +214,7 @@ def get_noise_tfdataset( ): ntile = 6 - def sample_iterator(): + def 
sample_generator(): # creates a timeseries where each time is the negation of time before it for _ in range(nsamples): data = noise_amplitude * np.random.randn(nbatch, 1, ntile, nx, ny, nz) @@ -222,4 +232,4 @@ def sample_iterator(): out[varname] = np.concatenate(out[varname], axis=1) yield out - return iterable_to_tfdataset(list(sample_iterator())) + return generator_to_tfdataset(sample_generator) diff --git a/external/fv3fit/fv3fit/pytorch/__init__.py b/external/fv3fit/fv3fit/pytorch/__init__.py index 0cf16f3b1b..15e16ff36f 100644 --- a/external/fv3fit/fv3fit/pytorch/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/__init__.py @@ -1,5 +1,10 @@ from .graph import GraphHyperparameters, train_graph_model, GraphNetworkConfig from .system import DEVICE from .predict import PytorchAutoregressor, PytorchPredictor -from .cyclegan import train_autoencoder, AutoencoderHyperparameters, GeneratorConfig +from .cyclegan import ( + train_autoencoder, + AutoencoderHyperparameters, + GeneratorConfig, + DiscriminatorConfig, +) from .optimizer import OptimizerConfig diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py index fa8964d198..250e6cbd1d 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py @@ -1,4 +1,5 @@ -from .train import train_autoencoder, AutoencoderHyperparameters, GeneratorConfig +from .train import train_autoencoder, AutoencoderHyperparameters +from .network import GeneratorConfig, DiscriminatorConfig from .train_cyclegan import ( train_cyclegan, CycleGANHyperparameters, diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index d40927a1bf..43b927719e 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -1,9 +1,10 @@ from typing import Callable, Literal, Protocol -import torch import torch.nn as nn 
from toolz import curry import dataclasses -from fv3fit.pytorch.optimizer import OptimizerConfig +import logging + +logger = logging.getLogger(__name__) def relu_activation(**kwargs): @@ -147,6 +148,7 @@ def flat_convolution(in_channels: int, out_channels: int, kernel_size: int, bias class GeneratorConfig: n_convolutions: int = 3 n_resnet: int = 3 + kernel_size: int = 3 max_filters: int = 256 def build( @@ -158,6 +160,7 @@ def build( channels=channels, n_convolutions=self.n_convolutions, n_resnet=self.n_resnet, + kernel_size=self.kernel_size, max_filters=self.max_filters, convolution=convolution, ) @@ -167,6 +170,7 @@ def build( class DiscriminatorConfig: n_convolutions: int = 3 + kernel_size: int = 3 max_filters: int = 256 def build( @@ -177,6 +181,7 @@ def build( return Discriminator( in_channels=channels, n_convolutions=self.n_convolutions, + kernel_size=self.kernel_size, max_filters=self.max_filters, convolution=convolution, ) @@ -235,6 +240,7 @@ def __init__( self, in_channels: int, n_convolutions: int, + kernel_size: int, max_filters: int, convolution: ConvolutionFactoryFactory = regular_convolution, ): @@ -245,8 +251,12 @@ def __init__( ConvBlock( in_channels=in_channels, out_channels=min_filters, - convolution_factory=convolution(kernel_size=3, stride=2, padding=1), - activation_factory=leakyrelu_activation(negative_slope=0.2), + convolution_factory=convolution( + kernel_size=kernel_size, stride=2, padding=1 + ), + activation_factory=leakyrelu_activation( + negative_slope=0.2, inplace=True + ), ) ] for i in range(1, n_convolutions): @@ -254,21 +264,22 @@ def __init__( ConvBlock( in_channels=min_filters * 2 ** (i - 1), out_channels=min_filters * 2 ** i, - convolution_factory=convolution(kernel_size=3, stride=2, padding=1), - activation_factory=leakyrelu_activation(negative_slope=0.2), + convolution_factory=convolution( + kernel_size=kernel_size, stride=2, padding=1 + ), + activation_factory=leakyrelu_activation( + negative_slope=0.2, inplace=True + ), ) ) 
final_conv = ConvBlock( in_channels=max_filters, out_channels=max_filters, - convolution_factory=convolution(kernel_size=3), - activation_factory=leakyrelu_activation(negative_slope=0.2), + convolution_factory=convolution(kernel_size=kernel_size), + activation_factory=leakyrelu_activation(negative_slope=0.2, inplace=True), ) - patch_output = ConvBlock( - in_channels=max_filters, - out_channels=1, - convolution_factory=convolution(kernel_size=3), - activation_factory=leakyrelu_activation(negative_slope=0.2), + patch_output = convolution(kernel_size=3)( + in_channels=max_filters, out_channels=1 ) self._sequential = nn.Sequential(*convs, final_conv, patch_output) @@ -284,6 +295,7 @@ def __init__( channels: int, n_convolutions: int, n_resnet: int, + kernel_size: int, max_filters: int, convolution: ConvolutionFactoryFactory = regular_convolution, ): @@ -293,7 +305,7 @@ def resnet(in_channels: int): resnet_blocks = [ ResnetBlock( n_filters=in_channels, - convolution_factory=convolution(kernel_size=3), + convolution_factory=convolution(kernel_size=kernel_size), activation_factory=relu_activation(), ) for _ in range(n_resnet) @@ -304,7 +316,9 @@ def down(in_channels: int, out_channels: int): return ConvBlock( in_channels=in_channels, out_channels=out_channels, - convolution_factory=convolution(kernel_size=3, stride=2, padding=1), + convolution_factory=convolution( + kernel_size=kernel_size, stride=2, padding=1 + ), activation_factory=relu_activation(), ) @@ -313,7 +327,7 @@ def up(in_channels: int, out_channels: int): in_channels=in_channels, out_channels=out_channels, convolution_factory=convolution( - kernel_size=3, + kernel_size=kernel_size, stride=2, padding=1, output_padding=1, @@ -325,7 +339,7 @@ def up(in_channels: int, out_channels: int): min_filters = int(max_filters / 2 ** (n_convolutions - 1)) self._first_conv = nn.Sequential( - flat_convolution(kernel_size=3)( + flat_convolution(kernel_size=7)( in_channels=channels, out_channels=min_filters ), 
relu_activation()(), @@ -339,16 +353,21 @@ def up(in_channels: int, out_channels: int): in_channels=min_filters, ) - self._out_conv = flat_convolution(kernel_size=3)( - in_channels=2 * min_filters, out_channels=channels + self._out_conv = flat_convolution(kernel_size=7)( + in_channels=min_filters, out_channels=channels ) + self._identity = nn.Linear(in_features=channels, out_features=channels) def forward(self, inputs): # data will have channels last, model requires channels first inputs = inputs.permute(0, 3, 1, 2) + logger.info(inputs.shape) + # outputs = self._identity(inputs) x = self._first_conv(inputs) - x = self._unet(x) + logger.info(x.shape) + # x = self._unet(x) outputs = self._out_conv(x) + logger.info(outputs.shape) return outputs.permute(0, 2, 3, 1) @@ -378,5 +397,5 @@ def forward(self, inputs): x = self._lower(x) x = self._up(x) # skip connection - x = torch.concat([x, inputs], dim=1) + # x = torch.concat([x, inputs], dim=1) return x diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index f660ab0d03..4f5ad4ee1a 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -1,5 +1,6 @@ from fv3fit._shared.hyperparameters import Hyperparameters import dataclasses +from fv3fit.pytorch.system import DEVICE import tensorflow as tf from fv3fit.pytorch.predict import PytorchPredictor from fv3fit.pytorch.loss import LossConfig @@ -130,4 +131,4 @@ def train_autoencoder( def build_model(config: GeneratorConfig, n_state: int) -> Generator: - return config.build(channels=n_state) + return config.build(channels=n_state).to(DEVICE) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index d93b2cbb4a..f4dcb5896b 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -2,8 +2,8 @@ from 
fv3fit._shared.hyperparameters import Hyperparameters import random import dataclasses +from fv3fit._shared.predictor import Dumpable import tensorflow as tf -from fv3fit.pytorch.predict import PytorchPredictor from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig import xarray as xr @@ -22,7 +22,6 @@ from typing import ( Callable, Dict, - Hashable, Iterable, List, Mapping, @@ -30,7 +29,7 @@ Sequence, Tuple, ) -from fv3fit.tfdataset import ensure_nd, apply_to_mapping, apply_to_tuple +from fv3fit.tfdataset import ensure_nd from .network import Discriminator, Generator, GeneratorConfig, DiscriminatorConfig from fv3fit.pytorch.graph.train import ( get_scalers, @@ -107,6 +106,47 @@ def fit_loop( for name in train_losses[0] } logger.info("train_loss: %s", train_loss) + + # real_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) + # real_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) + # fake_b = train_model.generator_a_to_b(real_a) + # fake_a = train_model.generator_b_to_a(real_b) + # reconstructed_a = train_model.generator_b_to_a(fake_b) + # reconstructed_b = train_model.generator_a_to_b(fake_a) + + # import matplotlib.pyplot as plt + + # fig, ax = plt.subplots(3, 2, figsize=(8, 8)) + # i = 0 + # iz = 0 + # vmin = -1.5 + # vmax = 1.5 + # ax[0, 0].imshow( + # real_a[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax + # ) + # ax[0, 1].imshow( + # real_b[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax + # ) + # ax[1, 0].imshow( + # fake_b[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax + # ) + # ax[1, 1].imshow( + # fake_a[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax + # ) + # ax[2, 0].imshow( + # reconstructed_a[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax + # ) + # ax[2, 1].imshow( + # reconstructed_b[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax + # ) + # ax[0, 0].set_title("real a") + # ax[0, 1].set_title("real b") + # ax[1, 0].set_title("output b") + # ax[1, 1].set_title("output a") + 
# ax[2, 0].set_title("reconstructed a") + # ax[2, 1].set_title("reconstructed b") + # plt.tight_layout() + # plt.show() if validation_data is not None: val_loss = train_model.evaluate_on_dataset(validation_data) logger.info("val_loss %s", val_loss) @@ -128,7 +168,8 @@ def get_Xy_map_fn( state_variables: Sequence[str], n_dims: int, # [batch, time, tile, x, y, z] mapping_scale_funcs: Tuple[ - Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], ... + Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], + ..., # noqa: W504 ], ): funcs = tuple( @@ -278,6 +319,7 @@ class CycleGANNetworkConfig: identity_weight: float = 0.5 cycle_weight: float = 1.0 gan_weight: float = 1.0 + discriminator_weight: float = 1.0 def build( self, n_state: int, n_batch: int, state_variables, scalers @@ -301,7 +343,7 @@ def build( generator_b_to_a=generator_b_to_a, discriminator_a=discriminator_a, discriminator_b=discriminator_b, - ), + ).to(DEVICE), state_variables=state_variables, scalers=_merge_scaler_mappings(scalers), ), @@ -314,6 +356,7 @@ def build( identity_weight=self.identity_weight, cycle_weight=self.cycle_weight, gan_weight=self.gan_weight, + discriminator_weight=self.discriminator_weight, ) @@ -343,7 +386,7 @@ def __init__( @io.register("cycle_gan") -class CycleGAN: +class CycleGAN(Dumpable): _MODEL_FILENAME = "weight.pt" _CONFIG_FILENAME = "config.yaml" @@ -352,8 +395,8 @@ class CycleGAN: def __init__( self, model: CycleGANModule, - scalers: Mapping[Hashable, StandardScaler], - state_variables: Iterable[Hashable], + scalers: Mapping[str, StandardScaler], + state_variables: Iterable[str], ): """ Args: @@ -493,13 +536,14 @@ class CycleGANTrainer: identity_weight: float = 0.5 cycle_weight: float = 1.0 gan_weight: float = 1.0 + discriminator_weight: float = 1.0 def __post_init__(self): self.target_real = torch.autograd.Variable( - torch.Tensor(self.batch_size).fill_(1.0), requires_grad=False + torch.Tensor(self.batch_size).fill_(1.0).to(DEVICE), 
requires_grad=False ) self.target_fake = torch.autograd.Variable( - torch.Tensor(self.batch_size).fill_(0.0), requires_grad=False + torch.Tensor(self.batch_size).fill_(0.0).to(DEVICE), requires_grad=False ) self.fake_a_buffer = ReplayBuffer() self.fake_b_buffer = ReplayBuffer() @@ -509,7 +553,7 @@ def __post_init__(self): self.discriminator_b = self.cycle_gan.discriminator_b def evaluate_on_dataset( - self, dataset: tf.data.Dataset, n_dims_keep: int = 4 + self, dataset: tf.data.Dataset, n_dims_keep: int = 3 ) -> Dict[str, float]: stats_real_a = StatsCollector(n_dims_keep) stats_real_b = StatsCollector(n_dims_keep) @@ -529,14 +573,15 @@ def evaluate_on_dataset( stats_gen_a.observe(gen_a.detach().cpu().numpy()) stats_gen_b.observe(gen_b.detach().cpu().numpy()) metrics = { + # "r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean), "r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean), - "bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean), + # "bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean), "r2_mean_b": get_r2(stats_real_b.mean, stats_gen_b.mean), - "bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), + # "bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), "r2_std_a": get_r2(stats_real_a.std, stats_gen_a.std), - "bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), + # "bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), "r2_std_b": get_r2(stats_real_b.std, stats_gen_b.std), - "bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), + # "bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), } return metrics @@ -560,16 +605,16 @@ def train_on_batch( same_b = self.generator_a_to_b(real_b) loss_identity_b = self.identity_loss(same_b, real_b) * self.identity_weight # G_B2A(A) should equal A if real A is fed - same_a = self.generator_b_to_a(real_b) + same_a = self.generator_b_to_a(real_a) loss_identity_a = self.identity_loss(same_a, real_a) * self.identity_weight loss_identity = loss_identity_a 
+ loss_identity_b # GAN loss - pred_fake = self.discriminator_b(fake_b) - loss_gan_a_to_b = self.gan_loss(pred_fake, self.target_real) * self.gan_weight + pred_fake_b = self.discriminator_b(fake_b) + loss_gan_a_to_b = self.gan_loss(pred_fake_b, self.target_real) * self.gan_weight - pred_fake = self.discriminator_a(fake_a) - loss_gan_b_to_a = self.gan_loss(pred_fake, self.target_real) * self.gan_weight + pred_fake_a = self.discriminator_a(fake_a) + loss_gan_b_to_a = self.gan_loss(pred_fake_a, self.target_real) * self.gan_weight loss_gan = loss_gan_a_to_b + loss_gan_b_to_a # Cycle loss @@ -592,26 +637,43 @@ def train_on_batch( # Real loss pred_real = self.discriminator_a(real_a) - loss_d_a_real = self.gan_loss(pred_real, self.target_real) + loss_d_a_real = ( + self.gan_loss(pred_real, self.target_real) + * self.gan_weight + * self.discriminator_weight + ) # Fake loss fake_a = self.fake_a_buffer.push_and_pop(fake_a) pred_a_fake = self.discriminator_a(fake_a.detach()) - loss_d_a_fake = self.gan_loss(pred_a_fake, self.target_fake) + loss_d_a_fake = ( + self.gan_loss(pred_a_fake, self.target_fake) + * self.gan_weight + * self.discriminator_weight + ) # Real loss pred_real = self.discriminator_b(real_b) - loss_d_b_real = self.gan_loss(pred_real, self.target_real) + loss_d_b_real = ( + self.gan_loss(pred_real, self.target_real) + * self.gan_weight + * self.discriminator_weight + ) # Fake loss fake_b = self.fake_b_buffer.push_and_pop(fake_b) pred_b_fake = self.discriminator_b(fake_b.detach()) - loss_d_b_fake = self.gan_loss(pred_b_fake, self.target_fake) + loss_d_b_fake = ( + self.gan_loss(pred_b_fake, self.target_fake) + * self.gan_weight + * self.discriminator_weight + ) # Total loss loss_d: torch.Tensor = ( loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake - ) + ) * self.discriminator_weight + self.optimizer_discriminator.zero_grad() loss_d.backward() self.optimizer_discriminator.step() diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py 
b/external/fv3fit/fv3fit/pytorch/graph/train.py index 6fcb276f3b..bd9b49dd17 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -2,7 +2,6 @@ import numpy as np import dataclasses from fv3fit._shared.training_config import Hyperparameters -from toolz.functoolz import curry from fv3fit.pytorch.predict import PytorchAutoregressor from fv3fit.pytorch.graph.network import GraphNetwork, GraphNetworkConfig from fv3fit.pytorch.loss import LossConfig @@ -207,7 +206,7 @@ def get_Xy_map_fn( ensure_dims = apply_to_mapping(ensure_nd(n_dims)) def map_fn(data): - data = mapping_scale_func(data) + # data = mapping_scale_func(data) data = ensure_dims(data) data = select_keys(state_variables, data) data = tf.concat(data, axis=-1) diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 65e76e7b82..1220748c98 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -38,7 +38,7 @@ def load(cls: Type[L], f: IO[bytes]) -> L: ... -def dump_mapping(mapping: Mapping[Hashable, StandardScaler], f: IO[bytes]) -> None: +def dump_mapping(mapping: Mapping[str, StandardScaler], f: IO[bytes]) -> None: """ Serialize a mapping to a zip file. """ @@ -48,7 +48,7 @@ def dump_mapping(mapping: Mapping[Hashable, StandardScaler], f: IO[bytes]) -> No value.dump(f_dump) -def load_mapping(cls: Type[L], f: IO[bytes]) -> Mapping[Hashable, L]: +def load_mapping(cls: Type[L], f: IO[bytes]) -> Mapping[str, L]: """ Load a mapping from a zip file. 
""" @@ -68,8 +68,8 @@ def __init__( input_variables: Iterable[Hashable], output_variables: Iterable[Hashable], model: nn.Module, - scalers: Mapping[Hashable, StandardScaler], - output_scalers: Optional[Mapping[Hashable, StandardScaler]] = None, + scalers: Mapping[str, StandardScaler], + output_scalers: Optional[Mapping[str, StandardScaler]] = None, ): """Initialize the predictor Args: @@ -111,7 +111,7 @@ def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: packed = _pack_to_tensor( ds=X, timesteps=0, - state_variables=self.input_variables, + state_variables=tuple(str(item) for item in self.input_variables), scalers=self.scalers, ) # dimensions are [time, tile, x, y, z], @@ -124,7 +124,7 @@ def unpack_tensor(self, data: torch.Tensor) -> xr.Dataset: data = torch.reshape(data, (-1, 6) + tuple(data.shape[1:])) return _unpack_tensor( data, - varnames=self.output_variables, + varnames=tuple(str(item) for item in self.output_variables), scalers=self.scalers, dims=["time", "tile", "x", "y", "z"], ) @@ -153,9 +153,9 @@ class PytorchAutoregressor(Dumpable, Loadable): def __init__( self, - state_variables: Iterable[Hashable], + state_variables: Iterable[str], model: nn.Module, - scalers: Mapping[Hashable, StandardScaler], + scalers: Mapping[str, StandardScaler], ): """Initialize the predictor Args: @@ -264,14 +264,11 @@ class PytorchDumpable(Protocol): _MODEL_FILENAME: str _SCALERS_FILENAME: str _CONFIG_FILENAME: str - scalers: Mapping[Hashable, StandardScaler] + scalers: Mapping[str, StandardScaler] model: torch.nn.Module def __init__( - self, - model: torch.nn.Module, - scalers: Mapping[Hashable, StandardScaler], - **kwargs, + self, model: torch.nn.Module, scalers: Mapping[str, StandardScaler], **kwargs, ): ... 
@@ -313,8 +310,8 @@ def _dump_pytorch(obj: PytorchDumpable, path: str) -> None: def _pack_to_tensor( ds: xr.Dataset, timesteps: int, - state_variables: Iterable[Hashable], - scalers: Mapping[Hashable, StandardScaler], + state_variables: Iterable[str], + scalers: Mapping[str, StandardScaler], ) -> torch.Tensor: """ Packs the dataset into a tensor to be used by the pytorch model. @@ -371,8 +368,8 @@ def _pack_to_tensor( def _unpack_tensor( data: torch.Tensor, - varnames: Iterable[Hashable], - scalers: Mapping[Hashable, StandardScaler], + varnames: Iterable[str], + scalers: Mapping[str, StandardScaler], dims: Sequence[Hashable], ) -> xr.Dataset: i_feature = 0 @@ -388,7 +385,7 @@ def _unpack_tensor( else: n_features = 1 var_data = data[..., i_feature] - var_data = scalers[varname].denormalize(var_data) + var_data = scalers[varname].denormalize(var_data.to("cpu").numpy()) data_vars[varname] = xr.DataArray( data=var_data, dims=dims[: len(var_data.shape)] ) diff --git a/external/fv3fit/fv3fit/pytorch/system.py b/external/fv3fit/fv3fit/pytorch/system.py index fdc2c15b6e..7e03607b3e 100644 --- a/external/fv3fit/fv3fit/pytorch/system.py +++ b/external/fv3fit/fv3fit/pytorch/system.py @@ -1,3 +1,10 @@ +import torch.backends import torch -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +DEVICE = torch.device( + "cuda:0" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" +) diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index f1c74306d7..a961bab2b2 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -65,7 +65,7 @@ def test_autoencoder(tmpdir): generator=fv3fit.pytorch.GeneratorConfig( n_convolutions=2, n_resnet=3, max_filters=32 ), - training_loop=TrainingConfig(n_epoch=5, samples_per_batch=2), + training_loop=TrainingConfig(n_epoch=10, samples_per_batch=2), 
optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), noise_amount=0.5, ) @@ -78,6 +78,7 @@ def test_autoencoder(tmpdir): predicted = predictor.predict(test_xrdataset) reference = test_xrdataset # plotting code to uncomment if you'd like to manually check the results: + # import matplotlib.pyplot as plt # for i in range(6): # fig, ax = plt.subplots(1, 2) # vmin = reference["a"][0, i, :, :, 0].values.min() @@ -114,9 +115,12 @@ def test_autoencoder_overfit(tmpdir): optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), noise_amount=0.0, ) - predictor = train_autoencoder( - hyperparameters, train_tfdataset, validation_batches=None - ) + import torch + + with torch.amp.autocast("cpu", enabled=False): + predictor = train_autoencoder( + hyperparameters, train_tfdataset, validation_batches=None + ) # for test, need one continuous series so we consistently flip sign test_xrdataset = tfdataset_to_xr_dataset( train_tfdataset, dims=["time", "tile", "x", "y", "z"] @@ -124,12 +128,13 @@ def test_autoencoder_overfit(tmpdir): predicted = predictor.predict(test_xrdataset) reference = test_xrdataset # plotting code to uncomment if you'd like to manually check the results: + # import matplotlib.pyplot as plt # for i in range(6): # fig, ax = plt.subplots(1, 2) # vmin = reference["a"][0, i, :, :, 0].values.min() # vmax = reference["a"][0, i, :, :, 0].values.max() - # ax[0].imshow(reference["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) - # ax[1].imshow(predicted["a"][0, i, :, :, 0].values) # , vmin=vmin, vmax=vmax) + # ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + # ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) # plt.tight_layout() # plt.show() bias = predicted - reference diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index 9fd1cb67a1..fd1b0b5dbc 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ 
b/external/fv3fit/tests/training/test_cyclegan.py @@ -6,9 +6,10 @@ CycleGANNetworkConfig, CycleGANTrainingConfig, train_cyclegan, - CycleGAN, ) from fv3fit.data import CycleGANLoader, SyntheticWaves, SyntheticNoise +import fv3fit.tfdataset +import tensorflow as tf import collections import os import fv3fit.pytorch @@ -28,9 +29,9 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): scalar_names=["b"], scale_min=0.5, scale_max=1.0, - period_min=4, - period_max=7, - phase_range=0.1, + period_min=8, + period_max=16, + type="sinusoidal", ), SyntheticWaves( nsamples=nsamples, @@ -39,11 +40,11 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): nx=nx, nz=nz, scalar_names=["b"], - scale_min=1.0, - scale_max=1.5, + scale_min=0.5, + scale_max=1.0, period_min=8, period_max=16, - phase_range=0.1, + type="square", ), ] ) @@ -98,49 +99,83 @@ def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): return xr.Dataset(data_vars) -# def test_cyclegan(tmpdir): -# fv3fit.set_random_seed(0) -# # run the test in a temporary directory to delete artifacts when done -# os.chdir(tmpdir) -# # need a larger nx, ny for the sample data here since we're training -# # on whether we can autoencode sin waves, and need to resolve full cycles -# nx, ny = 32, 32 -# sizes = {"nbatch": 2, "ntime": 2, "nx": nx, "ny": ny, "nz": 2} -# state_variables = ["a", "b"] -# train_tfdataset = get_tfdataset(nsamples=20, **sizes) -# val_tfdataset = get_tfdataset(nsamples=3, **sizes) -# hyperparameters = CycleGANHyperparameters( -# state_variables=state_variables, -# generator=fv3fit.pytorch.GeneratorConfig( -# n_convolutions=2, n_resnet=3, max_filters=32 -# ), -# training_loop=TrainingConfig(n_epoch=5, samples_per_batch=2), -# optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), -# noise_amount=0.5, -# ) -# predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) -# # for test, need one continuous series so we consistently flip sign -# test_sizes = {"nbatch": 1, "ntime": 100, 
"nx": nx, "ny": ny, "nz": 2} -# test_xrdataset = tfdataset_to_xr_dataset( -# get_tfdataset(nsamples=1, **test_sizes), dims=["time", "tile", "x", "y", "z"] -# ) -# predicted = predictor.predict(test_xrdataset) -# reference = test_xrdataset -# # plotting code to uncomment if you'd like to manually check the results: -# # for i in range(6): -# # fig, ax = plt.subplots(1, 2) -# # vmin = reference["a"][0, i, :, :, 0].values.min() -# # vmax = reference["a"][0, i, :, :, 0].values.max() -# # ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) -# # ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) -# # plt.tight_layout() -# # plt.show() -# bias = predicted.isel(time=1) - reference.isel(time=1) -# mean_bias: xr.Dataset = bias.mean() -# mse: xr.Dataset = (bias ** 2).mean() ** 0.5 -# for varname in state_variables: -# assert np.abs(mean_bias[varname]) < 0.1 -# assert mse[varname] < 0.1 +def test_cyclegan(tmpdir): + fv3fit.set_random_seed(0) + # run the test in a temporary directory to delete artifacts when done + os.chdir(tmpdir) + # need a larger nx, ny for the sample data here since we're training + # on whether we can autoencode sin waves, and need to resolve full cycles + nx = 32 + sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} + state_variables = ["a", "b"] + train_tfdataset = get_tfdataset(nsamples=100, **sizes) + val_tfdataset = get_tfdataset(nsamples=3, **sizes) + hyperparameters = CycleGANHyperparameters( + state_variables=state_variables, + network=CycleGANNetworkConfig( + generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, n_resnet=5, max_filters=128, kernel_size=3 + ), + generator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.001} + ), + discriminator=fv3fit.pytorch.DiscriminatorConfig(kernel_size=3), + discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.001} + ), + # identity_weight=0.01, + # cycle_weight=0.3, + # gan_weight=1.0, + 
discriminator_weight=0.5, + ), + training_loop=CycleGANTrainingConfig(n_epoch=10, samples_per_batch=1), + ) + predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) + # for test, need one continuous series so we consistently flip sign + real_a = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: a), dims=["time", "tile", "x", "y", "z"] + ) + real_b = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] + ) + output_a = predictor.predict(real_b, reverse=True) + reconstructed_b = predictor.predict(output_a) + # print("output a") + # print_compare(output_a, real_a) + # print("reconstructed b") + # print_compare(reconstructed_b, real_b) + output_b = predictor.predict(real_a) + reconstructed_a = predictor.predict(output_b, reverse=True) + # plotting code to uncomment if you'd like to manually check the results: + iz = 0 + for i in range(1): + fig, ax = plt.subplots(3, 2, figsize=(8, 8)) + vmin = -1.5 + vmax = 1.5 + ax[0, 0].imshow(real_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[0, 1].imshow(real_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 0].imshow(output_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 1].imshow(output_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[2, 0].imshow( + reconstructed_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax + ) + ax[2, 1].imshow( + reconstructed_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax + ) + ax[0, 0].set_title("real a") + ax[0, 1].set_title("real b") + ax[1, 0].set_title("output b") + ax[1, 1].set_title("output a") + ax[2, 0].set_title("reconstructed a") + ax[2, 1].set_title("reconstructed b") + plt.tight_layout() + plt.show() + # bias = predicted.isel(time=1) - reference.isel(time=1) + # mean_bias: xr.Dataset = bias.mean() + # mse: xr.Dataset = (bias ** 2).mean() ** 0.5 + # for varname in state_variables: + # assert np.abs(mean_bias[varname]) < 0.1 + # assert mse[varname] < 0.1 def 
test_cyclegan_overfit(tmpdir): @@ -158,16 +193,19 @@ def test_cyclegan_overfit(tmpdir): state_variables=state_variables, network=CycleGANNetworkConfig( generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=1, max_filters=128 + n_convolutions=2, n_resnet=1, max_filters=64 ), generator_optimizer=fv3fit.pytorch.OptimizerConfig( - name="Adam", kwargs={"lr": 0.01} + name="Adam", kwargs={"lr": 0.001} ), discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( - name="Adam", kwargs={"lr": 0.01} + name="Adam", kwargs={"lr": 0.0001} ), + identity_weight=0.001, + cycle_weight=0.3, + gan_weight=1.0, ), - training_loop=CycleGANTrainingConfig(n_epoch=100, samples_per_batch=6), + training_loop=CycleGANTrainingConfig(n_epoch=200, samples_per_batch=6), ) predictor = train_cyclegan( hyperparameters, train_tfdataset, validation_batches=train_tfdataset @@ -179,21 +217,41 @@ def test_cyclegan_overfit(tmpdir): real_b = tfdataset_to_xr_dataset( train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] ) - output_b = predictor.predict(real_a) output_a = predictor.predict(real_b, reverse=True) + reconstructed_b = predictor.predict(output_a) + print("output a") + print_compare(output_a, real_a) + print("reconstructed b") + print_compare(reconstructed_b, real_b) + output_b = predictor.predict(real_a) + reconstructed_a = predictor.predict(output_b, reverse=True) + print("reconstructed a") + print_compare(reconstructed_a, real_a) + print("output b") + print_compare(output_b, real_b) # plotting code to uncomment if you'd like to manually check the results: - for i in range(3): - fig, ax = plt.subplots(2, 2) + # import pdb; pdb.set_trace() + iz = 0 + for i in range(1): + fig, ax = plt.subplots(3, 2, figsize=(12, 7)) vmin = -1.5 vmax = 1.5 - ax[0, 0].imshow(real_a["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - ax[0, 1].imshow(real_b["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - ax[1, 0].imshow(output_a["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - 
ax[1, 1].imshow(output_b["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + ax[0, 0].imshow(real_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[0, 1].imshow(real_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 0].imshow(output_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 1].imshow(output_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[2, 0].imshow( + reconstructed_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax + ) + ax[2, 1].imshow( + reconstructed_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax + ) ax[0, 0].set_title("real a") ax[0, 1].set_title("real b") ax[1, 0].set_title("output a") ax[1, 1].set_title("output b") + ax[2, 0].set_title("reconstructed a") + ax[2, 1].set_title("reconstructed b") plt.tight_layout() plt.show() # bias = predicted - reference @@ -202,3 +260,43 @@ def test_cyclegan_overfit(tmpdir): # for varname in state_variables: # assert np.abs(mean_bias[varname]) < 0.1 # assert rmse[varname] < 0.1 + + +def assert_close(a: xr.Dataset, b: xr.Dataset): + rmse = ((a - b) ** 2).mean() ** 0.5 + bias = (a - b).mean() + for varname in rmse.data_vars: + assert rmse[varname] < 0.1 + for varname in bias.data_vars: + assert bias[varname] < 0.1 + + +def print_compare(a: xr.Dataset, b: xr.Dataset): + rmse = ((a - b) ** 2).mean() ** 0.5 + bias = (a - b).mean() + print("compare") + for varname in rmse.data_vars: + print(varname, rmse[varname], bias[varname]) + + +def test_tuple_map(): + """ + External package test demonstrating that for map operations on tuples + of functions, tuple entries are passed as independent arguments + and must be collected with *args. 
+ """ + + def generator(): + for entry in [(1, 1), (2, 2), (3, 3)]: + yield entry + + dataset = tf.data.Dataset.from_generator( + generator, output_types=(tf.int32, tf.int32) + ) + + def map_fn(x, y): + return x * 2, y * 3 + + mapped = dataset.map(map_fn) + out = list(mapped) + assert out == [(2, 3), (4, 6), (6, 9)] From c4df7c4f0d80400cfd2ce54cb1d2e00b8a2fafaf Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 29 Aug 2022 12:43:00 -0700 Subject: [PATCH 20/55] fix test broken by merge --- external/fv3fit/tests/training/test_main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/external/fv3fit/tests/training/test_main.py b/external/fv3fit/tests/training/test_main.py index 27fdfe3830..9b8c9a6465 100644 --- a/external/fv3fit/tests/training/test_main.py +++ b/external/fv3fit/tests/training/test_main.py @@ -100,7 +100,6 @@ def mock_train_dense_model(): "dense", fv3fit.DenseHyperparameters )(original_func) register._model_types.pop("mock") - register._dump_types.pop("mock") @pytest.fixture From 4f590b9eec68919acc73833a0d8d2c8ca6160643 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 30 Aug 2022 11:44:20 -0700 Subject: [PATCH 21/55] fix test_io.py by reverting to master --- external/fv3fit/tests/test_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/fv3fit/tests/test_io.py b/external/fv3fit/tests/test_io.py index e1c004232e..0ef536f9d0 100644 --- a/external/fv3fit/tests/test_io.py +++ b/external/fv3fit/tests/test_io.py @@ -13,7 +13,7 @@ class Mock: pass mock = Mock() - assert register.get_dumpable_name(mock) == "mock" + assert register.get_name(mock) == "mock" def test_registering_twice_fails(): @@ -42,7 +42,7 @@ class MockSubclass(Mock): pass mock = MockSubclass() - assert register.get_dumpable_name(mock) == "mock-subclass" + assert register.get_name(mock) == "mock-subclass" def test_register_dump_load(tmpdir): From ed89ba7659104ef9d95257d7204c7426632fd470 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 2 
Sep 2022 14:04:33 -0700 Subject: [PATCH 22/55] working version of cyclegan manual test --- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 8 +-- external/fv3fit/fv3fit/pytorch/graph/train.py | 37 +--------- .../fv3fit/tests/training/test_cyclegan.py | 70 ++++++++----------- external/fv3fit/tests/training/test_graph.py | 4 +- 4 files changed, 38 insertions(+), 81 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index 09bd54e2d7..5242caa26b 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -14,7 +14,7 @@ Tuple, ) from .network import Generator, GeneratorConfig -from fv3fit.pytorch.graph.train import get_Xy_dataset +from fv3fit.pytorch.graph.train import get_Xy_map_fn from fv3fit._shared.scaler import ( get_standard_scaler_mapping, get_mapping_standard_scale_func, @@ -84,18 +84,18 @@ def train_autoencoder( scalers = get_standard_scaler_mapping(sample_batch) mapping_scale_func = get_mapping_standard_scale_func(scalers) - get_state = curry(get_Xy_dataset)( + get_state = get_Xy_map_fn( state_variables=hyperparameters.state_variables, n_dims=6, # [batch, time, tile, x, y, z] mapping_scale_func=mapping_scale_func, ) if validation_batches is not None: - val_state = get_state(data=validation_batches) + val_state = validation_batches.map(get_state) else: val_state = None - train_state = get_state(data=train_batches) + train_state = train_batches.map(get_state) train_model = build_model( hyperparameters.generator, n_state=next(iter(train_state)).shape[-1] diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py b/external/fv3fit/fv3fit/pytorch/graph/train.py index 51a0a5d33d..2bf663e925 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -130,41 +130,6 @@ def build_model(graph_network, n_state: int, nx: int): ) -def get_Xy_dataset( - state_variables: Sequence[str], - 
n_dims: int, - mapping_scale_func: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], - data: tf.data.Dataset, -): - """ - Given a tf.data.Dataset with mappings from variable name to samples - return a tf.data.Dataset whose entries are tensors of the requested - state variables concatenated along the feature dimension. - - Args: - state_variables: names of variables to include in returned tensor - n_dims: number of dimensions of each sample, including feature dimension - mapping_scale_func: function which scales data stored as a mapping - from variable name to array - data: tf.data.Dataset with mappings from variable name - to sample tensors - - Returns: - tf.data.Dataset where each sample is a single tensor - containing normalized and concatenated state variables - """ - ensure_dims = apply_to_mapping(ensure_nd(n_dims)) - - def map_fn(data): - data = mapping_scale_func(data) - data = ensure_dims(data) - data = select_keys(state_variables, data) - data = tf.concat(data, axis=-1) - return data - - return data.map(map_fn) - - def get_Xy_map_fn( state_variables: Sequence[str], n_dims: int, @@ -190,7 +155,7 @@ def get_Xy_map_fn( ensure_dims = apply_to_mapping(ensure_nd(n_dims)) def map_fn(data): - # data = mapping_scale_func(data) + data = mapping_scale_func(data) data = ensure_dims(data) data = select_keys(state_variables, data) data = tf.concat(data, axis=-1) diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index fd1b0b5dbc..68b469ef62 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -15,6 +15,7 @@ import fv3fit.pytorch import fv3fit import matplotlib.pyplot as plt +import pytest def get_tfdataset(nsamples, nbatch, ntime, nx, nz): @@ -31,7 +32,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): scale_max=1.0, period_min=8, period_max=16, - type="sinusoidal", + wave_type="sinusoidal", ), SyntheticWaves( nsamples=nsamples, 
@@ -44,7 +45,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): scale_max=1.0, period_min=8, period_max=16, - type="square", + wave_type="square", ), ] ) @@ -99,6 +100,7 @@ def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): return xr.Dataset(data_vars) +@pytest.mark.skip("test is designed to run manually to visualize results") def test_cyclegan(tmpdir): fv3fit.set_random_seed(0) # run the test in a temporary directory to delete artifacts when done @@ -109,12 +111,12 @@ def test_cyclegan(tmpdir): sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} state_variables = ["a", "b"] train_tfdataset = get_tfdataset(nsamples=100, **sizes) - val_tfdataset = get_tfdataset(nsamples=3, **sizes) + val_tfdataset = get_tfdataset(nsamples=20, **sizes) hyperparameters = CycleGANHyperparameters( state_variables=state_variables, network=CycleGANNetworkConfig( generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=5, max_filters=128, kernel_size=3 + n_convolutions=3, n_resnet=5, max_filters=128, kernel_size=3 ), generator_optimizer=fv3fit.pytorch.OptimizerConfig( name="Adam", kwargs={"lr": 0.001} @@ -128,7 +130,9 @@ def test_cyclegan(tmpdir): # gan_weight=1.0, discriminator_weight=0.5, ), - training_loop=CycleGANTrainingConfig(n_epoch=10, samples_per_batch=1), + training_loop=CycleGANTrainingConfig( + n_epoch=20, samples_per_batch=1, validation_batch_size=10 + ), ) predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) # for test, need one continuous series so we consistently flip sign @@ -140,42 +144,27 @@ def test_cyclegan(tmpdir): ) output_a = predictor.predict(real_b, reverse=True) reconstructed_b = predictor.predict(output_a) - # print("output a") - # print_compare(output_a, real_a) - # print("reconstructed b") - # print_compare(reconstructed_b, real_b) output_b = predictor.predict(real_a) reconstructed_a = predictor.predict(output_b, reverse=True) - # plotting code to uncomment if you'd like to manually check the results: iz = 
0 - for i in range(1): - fig, ax = plt.subplots(3, 2, figsize=(8, 8)) - vmin = -1.5 - vmax = 1.5 - ax[0, 0].imshow(real_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[0, 1].imshow(real_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[1, 0].imshow(output_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[1, 1].imshow(output_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[2, 0].imshow( - reconstructed_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax - ) - ax[2, 1].imshow( - reconstructed_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax - ) - ax[0, 0].set_title("real a") - ax[0, 1].set_title("real b") - ax[1, 0].set_title("output b") - ax[1, 1].set_title("output a") - ax[2, 0].set_title("reconstructed a") - ax[2, 1].set_title("reconstructed b") - plt.tight_layout() - plt.show() - # bias = predicted.isel(time=1) - reference.isel(time=1) - # mean_bias: xr.Dataset = bias.mean() - # mse: xr.Dataset = (bias ** 2).mean() ** 0.5 - # for varname in state_variables: - # assert np.abs(mean_bias[varname]) < 0.1 - # assert mse[varname] < 0.1 + i = 0 + fig, ax = plt.subplots(3, 2, figsize=(8, 8)) + vmin = -1.5 + vmax = 1.5 + ax[0, 0].imshow(real_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[0, 1].imshow(real_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 0].imshow(output_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 1].imshow(output_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[2, 0].imshow(reconstructed_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[2, 1].imshow(reconstructed_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[0, 0].set_title("real a") + ax[0, 1].set_title("real b") + ax[1, 0].set_title("output b") + ax[1, 1].set_title("output a") + ax[2, 0].set_title("reconstructed a") + ax[2, 1].set_title("reconstructed b") + plt.tight_layout() + plt.show() def test_cyclegan_overfit(tmpdir): @@ -193,11 +182,14 @@ def test_cyclegan_overfit(tmpdir): 
state_variables=state_variables, network=CycleGANNetworkConfig( generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=1, max_filters=64 + n_convolutions=2, n_resnet=1, max_filters=256, kernel_size=3 ), generator_optimizer=fv3fit.pytorch.OptimizerConfig( name="Adam", kwargs={"lr": 0.001} ), + discriminator=fv3fit.pytorch.DiscriminatorConfig( + kernel_size=3, n_convolutions=2, max_filters=256 + ), discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( name="Adam", kwargs={"lr": 0.0001} ), diff --git a/external/fv3fit/tests/training/test_graph.py b/external/fv3fit/tests/training/test_graph.py index 7c84885cea..d455972bf2 100644 --- a/external/fv3fit/tests/training/test_graph.py +++ b/external/fv3fit/tests/training/test_graph.py @@ -83,8 +83,8 @@ def test_train_graph_network(tmpdir, network_type): training_config = AutoregressiveTrainingConfig(n_epoch=100) optimizer = OptimizerConfig(kwargs={"lr": 0.001}) elif network_type == "UNet": - graph_network = GraphUNetConfig(depth=1, min_filters=8) - training_config = AutoregressiveTrainingConfig(n_epoch=50) + graph_network = GraphUNetConfig() + training_config = AutoregressiveTrainingConfig(n_epoch=30) optimizer = OptimizerConfig(kwargs={"lr": 0.005}) hyperparameters = GraphHyperparameters( From e92aabdbfd4e3735bece318930b07b4f7c83c21d Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 2 Sep 2022 14:05:01 -0700 Subject: [PATCH 23/55] remove non-functional overfitting test --- .../fv3fit/tests/training/test_cyclegan.py | 104 ------------------ 1 file changed, 104 deletions(-) diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index 68b469ef62..6daefffe79 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -167,110 +167,6 @@ def test_cyclegan(tmpdir): plt.show() -def test_cyclegan_overfit(tmpdir): - fv3fit.set_random_seed(0) - # run the test in a temporary directory to delete 
artifacts when done - os.chdir(tmpdir) - # need a larger nx for the sample data here since we're training - # on whether we can autoencode sin waves, and need to resolve full cycles - nx = 16 - sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} - state_variables = ["a", "b"] - train_tfdataset = get_noise_tfdataset(nsamples=1, **sizes) - train_tfdataset = train_tfdataset.cache() # needed to keep sample identical - hyperparameters = CycleGANHyperparameters( - state_variables=state_variables, - network=CycleGANNetworkConfig( - generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=1, max_filters=256, kernel_size=3 - ), - generator_optimizer=fv3fit.pytorch.OptimizerConfig( - name="Adam", kwargs={"lr": 0.001} - ), - discriminator=fv3fit.pytorch.DiscriminatorConfig( - kernel_size=3, n_convolutions=2, max_filters=256 - ), - discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( - name="Adam", kwargs={"lr": 0.0001} - ), - identity_weight=0.001, - cycle_weight=0.3, - gan_weight=1.0, - ), - training_loop=CycleGANTrainingConfig(n_epoch=200, samples_per_batch=6), - ) - predictor = train_cyclegan( - hyperparameters, train_tfdataset, validation_batches=train_tfdataset - ) - # for test, need one continuous series so we consistently flip sign - real_a = tfdataset_to_xr_dataset( - train_tfdataset.map(lambda a, b: a), dims=["time", "tile", "x", "y", "z"] - ) - real_b = tfdataset_to_xr_dataset( - train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] - ) - output_a = predictor.predict(real_b, reverse=True) - reconstructed_b = predictor.predict(output_a) - print("output a") - print_compare(output_a, real_a) - print("reconstructed b") - print_compare(reconstructed_b, real_b) - output_b = predictor.predict(real_a) - reconstructed_a = predictor.predict(output_b, reverse=True) - print("reconstructed a") - print_compare(reconstructed_a, real_a) - print("output b") - print_compare(output_b, real_b) - # plotting code to uncomment if you'd like to 
manually check the results: - # import pdb; pdb.set_trace() - iz = 0 - for i in range(1): - fig, ax = plt.subplots(3, 2, figsize=(12, 7)) - vmin = -1.5 - vmax = 1.5 - ax[0, 0].imshow(real_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[0, 1].imshow(real_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[1, 0].imshow(output_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[1, 1].imshow(output_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[2, 0].imshow( - reconstructed_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax - ) - ax[2, 1].imshow( - reconstructed_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax - ) - ax[0, 0].set_title("real a") - ax[0, 1].set_title("real b") - ax[1, 0].set_title("output a") - ax[1, 1].set_title("output b") - ax[2, 0].set_title("reconstructed a") - ax[2, 1].set_title("reconstructed b") - plt.tight_layout() - plt.show() - # bias = predicted - reference - # mean_bias: xr.Dataset = bias.mean() - # rmse: xr.Dataset = (bias ** 2).mean() - # for varname in state_variables: - # assert np.abs(mean_bias[varname]) < 0.1 - # assert rmse[varname] < 0.1 - - -def assert_close(a: xr.Dataset, b: xr.Dataset): - rmse = ((a - b) ** 2).mean() ** 0.5 - bias = (a - b).mean() - for varname in rmse.data_vars: - assert rmse[varname] < 0.1 - for varname in bias.data_vars: - assert bias[varname] < 0.1 - - -def print_compare(a: xr.Dataset, b: xr.Dataset): - rmse = ((a - b) ** 2).mean() ** 0.5 - bias = (a - b).mean() - print("compare") - for varname in rmse.data_vars: - print(varname, rmse[varname], bias[varname]) - - def test_tuple_map(): """ External package test demonstrating that for map operations on tuples From d988f23ffb4f7f0a8330af5478a1d3faa4581809 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 2 Sep 2022 15:31:19 -0700 Subject: [PATCH 24/55] update cyclegan to work with mps acceleration --- .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 16 ++++----- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 4 +++ 
.../fv3fit/pytorch/cyclegan/train_cyclegan.py | 33 ++++++++++++------- external/fv3fit/fv3fit/pytorch/system.py | 4 +-- external/fv3fit/fv3fit/tfdataset.py | 10 +++--- .../fv3fit/tests/training/test_cyclegan.py | 4 +-- 6 files changed, 42 insertions(+), 29 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py index a9265b2f2d..062b7a9afb 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py @@ -4,6 +4,7 @@ from typing import Callable, Literal, Protocol, Union import torch.nn as nn from toolz import curry +import torch logger = logging.getLogger(__name__) @@ -256,10 +257,8 @@ def __init__( ) self._sequential = nn.Sequential(*convs, final_conv, patch_output) - def forward(self, inputs): - inputs = inputs.permute(0, 3, 1, 2) - outputs = self._sequential(inputs) - return outputs.permute(0, 2, 3, 1) + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self._sequential(inputs) class Generator(nn.Module): @@ -338,14 +337,11 @@ def up(in_channels: int, out_channels: int): padding="same", ) - def forward(self, inputs): - # permute [batch, x, y, channels] to [batch, channels, x, y] - inputs = inputs.permute(0, 3, 1, 2) + def forward(self, inputs: torch.Tensor) -> torch.Tensor: x = self._first_conv(inputs) x = self._unet(x) - outputs = self._out_conv(x) - # permute [batch, channels, x, y] to [batch, x, y, channels] - return outputs.permute(0, 2, 3, 1) + outputs: torch.Tensor = self._out_conv(x) + return outputs class UNet(nn.Module): diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index 5242caa26b..59f1f58dcd 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -130,5 +130,9 @@ def train_autoencoder( return predictor +def channels_first(data: tf.Tensor) -> tf.Tensor: + return 
tf.transpose(data, perm=[0, 3, 1, 2]) + + def build_model(config: GeneratorConfig, n_state: int) -> Generator: return config.build(channels=n_state).to(DEVICE) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index 80012b4f29..7f647e580b 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -10,7 +10,7 @@ import torch from fv3fit.pytorch.system import DEVICE import tensorflow_datasets as tfds -from fv3fit.tfdataset import sequence_size +from fv3fit.tfdataset import sequence_size, apply_to_tuple from fv3fit.pytorch.predict import ( _load_pytorch, _dump_pytorch, @@ -186,6 +186,10 @@ def Xy_map_fn(*data: Mapping[str, np.ndarray]): return Xy_map_fn +def channels_first(data: tf.Tensor) -> tf.Tensor: + return tf.transpose(data, perm=[0, 3, 1, 2]) + + @register_training_function("cyclegan", CycleGANHyperparameters) def train_cyclegan( hyperparameters: CycleGANHyperparameters, @@ -233,9 +237,9 @@ def train_cyclegan( ) # remove time and tile dimensions, while we're using regular convolution - train_state = train_state.unbatch().unbatch() + train_state = train_state.unbatch().map(apply_to_tuple(channels_first)).unbatch() if validation_batches is not None: - val_state = val_state.unbatch().unbatch() + val_state = val_state.unbatch().map(apply_to_tuple(channels_first)).unbatch() hyperparameters.training_loop.fit_loop( train_model=train_model, train_data=train_state, validation_data=val_state, @@ -463,9 +467,10 @@ def pack_to_tensor(self, ds: xr.Dataset, domain: str = "a") -> torch.Tensor: for name, scaler in self.scalers.items() if name.startswith(f"{domain}_") } - return _pack_to_tensor( + tensor = _pack_to_tensor( ds=ds, timesteps=0, state_variables=self.state_variables, scalers=scalers, ) + return tensor.permute([0, 1, 4, 2, 3]) def unpack_tensor(self, data: torch.Tensor, domain: str = "b") -> xr.Dataset: """ 
@@ -484,7 +489,7 @@ def unpack_tensor(self, data: torch.Tensor, domain: str = "b") -> xr.Dataset: if name.startswith(f"{domain}_") } return _unpack_tensor( - data, + data.permute([0, 1, 3, 4, 2]), varnames=self.state_variables, scalers=scalers, dims=["time", "tile", "x", "y", "z"], @@ -542,12 +547,8 @@ class CycleGANTrainer: discriminator_weight: float = 1.0 def __post_init__(self): - self.target_real = torch.autograd.Variable( - torch.Tensor(self.batch_size).fill_(1.0).to(DEVICE), requires_grad=False - ) - self.target_fake = torch.autograd.Variable( - torch.Tensor(self.batch_size).fill_(0.0).to(DEVICE), requires_grad=False - ) + self.target_real: Optional[torch.autograd.Variable] = None + self.target_fake: Optional[torch.autograd.Variable] = None self.fake_a_buffer = ReplayBuffer() self.fake_b_buffer = ReplayBuffer() self.generator_a_to_b = self.cycle_gan.generator_a_to_b @@ -555,6 +556,14 @@ def __post_init__(self): self.discriminator_a = self.cycle_gan.discriminator_a self.discriminator_b = self.cycle_gan.discriminator_b + def _init_targets(self, shape: Tuple[int, ...]): + self.target_real = torch.autograd.Variable( + torch.Tensor(shape).fill_(1.0).to(DEVICE), requires_grad=False + ) + self.target_fake = torch.autograd.Variable( + torch.Tensor(shape).fill_(0.0).to(DEVICE), requires_grad=False + ) + def evaluate_on_dataset( self, dataset: tf.data.Dataset, n_dims_keep: int = 3 ) -> Dict[str, float]: @@ -614,6 +623,8 @@ def train_on_batch( # GAN loss pred_fake_b = self.discriminator_b(fake_b) + if self.target_real is None: + self._init_targets(pred_fake_b.shape) loss_gan_a_to_b = self.gan_loss(pred_fake_b, self.target_real) * self.gan_weight pred_fake_a = self.discriminator_a(fake_a) diff --git a/external/fv3fit/fv3fit/pytorch/system.py b/external/fv3fit/fv3fit/pytorch/system.py index 36ff17e1a5..7e03607b3e 100644 --- a/external/fv3fit/fv3fit/pytorch/system.py +++ b/external/fv3fit/fv3fit/pytorch/system.py @@ -4,7 +4,7 @@ DEVICE = torch.device( "cuda:0" if 
torch.cuda.is_available() - # else "mps" - # if torch.backends.mps.is_available() + else "mps" + if torch.backends.mps.is_available() else "cpu" ) diff --git a/external/fv3fit/fv3fit/tfdataset.py b/external/fv3fit/fv3fit/tfdataset.py index 9cf1b22d9c..35dd93e54c 100644 --- a/external/fv3fit/fv3fit/tfdataset.py +++ b/external/fv3fit/fv3fit/tfdataset.py @@ -27,11 +27,13 @@ def apply_to_mapping( return {name: tensor_func(tensor) for name, tensor in data.items()} -@curry def apply_to_tuple( - tensor_func: Callable[[T_in], T_out], data: Tuple[T_in, ...] -) -> Tuple[T_out, ...]: - return tuple(tensor_func(tensor) for tensor in data) + tensor_func: Callable[[T_in], T_out], +) -> Callable[[Tuple[T_in, ...]], Tuple[T_out, ...]]: + def wrapped(*data): + return tuple(tensor_func(tensor) for tensor in data) + + return wrapped def sequence_size(seq): diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index 6daefffe79..e2ac0e2bab 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -110,7 +110,7 @@ def test_cyclegan(tmpdir): nx = 32 sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} state_variables = ["a", "b"] - train_tfdataset = get_tfdataset(nsamples=100, **sizes) + train_tfdataset = get_tfdataset(nsamples=200, **sizes) val_tfdataset = get_tfdataset(nsamples=20, **sizes) hyperparameters = CycleGANHyperparameters( state_variables=state_variables, @@ -131,7 +131,7 @@ def test_cyclegan(tmpdir): discriminator_weight=0.5, ), training_loop=CycleGANTrainingConfig( - n_epoch=20, samples_per_batch=1, validation_batch_size=10 + n_epoch=30, samples_per_batch=20, validation_batch_size=10 ), ) predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) From e6b2045ade3a3635c34381ee5ea8b05f3b586a34 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 2 Sep 2022 17:20:50 -0700 Subject: [PATCH 25/55] re-organize cyclegan code into more modules 
--- external/fv3fit/fv3fit/pytorch/__init__.py | 1 + .../fv3fit/pytorch/cyclegan/__init__.py | 9 +- .../pytorch/cyclegan/cyclegan_trainer.py | 339 +++++++++++ .../fv3fit/pytorch/cyclegan/discriminator.py | 82 +++ .../fv3fit/pytorch/cyclegan/generator.py | 148 +++++ .../fv3fit/fv3fit/pytorch/cyclegan/modules.py | 167 ++++++ .../fv3fit/fv3fit/pytorch/cyclegan/network.py | 374 ------------ .../fv3fit/pytorch/cyclegan/reloadable.py | 162 ++++++ .../{train.py => train_autoencoder.py} | 12 +- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 534 +----------------- external/fv3fit/fv3fit/pytorch/predict.py | 7 +- .../fv3fit/tests/training/test_autoencoder.py | 6 +- 12 files changed, 924 insertions(+), 917 deletions(-) create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/generator.py create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/modules.py delete mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/network.py create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py rename external/fv3fit/fv3fit/pytorch/cyclegan/{train.py => train_autoencoder.py} (92%) diff --git a/external/fv3fit/fv3fit/pytorch/__init__.py b/external/fv3fit/fv3fit/pytorch/__init__.py index e1277300b7..883c1e5df0 100644 --- a/external/fv3fit/fv3fit/pytorch/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/__init__.py @@ -14,3 +14,4 @@ ) from .optimizer import OptimizerConfig from .activation import ActivationConfig +from .loss import LossConfig diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py index 250e6cbd1d..bd8ca6d5f6 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py @@ -1,9 +1,10 @@ -from .train import train_autoencoder, AutoencoderHyperparameters -from .network import 
GeneratorConfig, DiscriminatorConfig +from .train_autoencoder import train_autoencoder, AutoencoderHyperparameters from .train_cyclegan import ( train_cyclegan, CycleGANHyperparameters, - CycleGANNetworkConfig, CycleGANTrainingConfig, - CycleGAN, ) +from .discriminator import DiscriminatorConfig +from .generator import GeneratorConfig +from .cyclegan_trainer import CycleGANNetworkConfig +from .reloadable import CycleGAN diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py new file mode 100644 index 0000000000..78a802d3e8 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -0,0 +1,339 @@ +import random +from typing import Dict, List, Mapping, Tuple, Optional +import tensorflow as tf +from fv3fit._shared.scaler import StandardScaler +from .reloadable import CycleGAN, CycleGANModule +import torch +from .generator import GeneratorConfig +from .discriminator import DiscriminatorConfig +import dataclasses +from fv3fit.pytorch.loss import LossConfig +from fv3fit.pytorch.optimizer import OptimizerConfig +from fv3fit.pytorch.system import DEVICE +import itertools +import numpy as np + + +@dataclasses.dataclass +class CycleGANNetworkConfig: + generator_optimizer: OptimizerConfig = dataclasses.field( + default_factory=lambda: OptimizerConfig("Adam") + ) + discriminator_optimizer: OptimizerConfig = dataclasses.field( + default_factory=lambda: OptimizerConfig("Adam") + ) + generator: "GeneratorConfig" = dataclasses.field( + default_factory=lambda: GeneratorConfig() + ) + discriminator: "DiscriminatorConfig" = dataclasses.field( + default_factory=lambda: DiscriminatorConfig() + ) + identity_loss: LossConfig = dataclasses.field(default_factory=LossConfig) + cycle_loss: LossConfig = dataclasses.field(default_factory=LossConfig) + gan_loss: LossConfig = dataclasses.field(default_factory=LossConfig) + identity_weight: float = 0.5 + cycle_weight: float = 1.0 + gan_weight: 
float = 1.0 + discriminator_weight: float = 1.0 + + def build( + self, n_state: int, n_batch: int, state_variables, scalers + ) -> "CycleGANTrainer": + generator_a_to_b = self.generator.build(n_state) + generator_b_to_a = self.generator.build(n_state) + discriminator_a = self.discriminator.build(n_state) + discriminator_b = self.discriminator.build(n_state) + optimizer_generator = self.generator_optimizer.instance( + itertools.chain( + generator_a_to_b.parameters(), generator_b_to_a.parameters() + ) + ) + optimizer_discriminator = self.discriminator_optimizer.instance( + itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()) + ) + return CycleGANTrainer( + cycle_gan=CycleGAN( + model=CycleGANModule( + generator_a_to_b=generator_a_to_b, + generator_b_to_a=generator_b_to_a, + discriminator_a=discriminator_a, + discriminator_b=discriminator_b, + ).to(DEVICE), + state_variables=state_variables, + scalers=_merge_scaler_mappings(scalers), + ), + optimizer_generator=optimizer_generator, + optimizer_discriminator=optimizer_discriminator, + identity_loss=self.identity_loss.instance, + cycle_loss=self.cycle_loss.instance, + gan_loss=self.gan_loss.instance, + batch_size=n_batch, + identity_weight=self.identity_weight, + cycle_weight=self.cycle_weight, + gan_weight=self.gan_weight, + discriminator_weight=self.discriminator_weight, + ) + + +def _merge_scaler_mappings( + scaler_tuple: Tuple[Mapping[str, StandardScaler], Mapping[str, StandardScaler]] +) -> Mapping[str, StandardScaler]: + scalers = {} + for prefix, scaler_map in zip(("a_", "b_"), scaler_tuple): + for key, scaler in scaler_map.items(): + scalers[prefix + key] = scaler + return scalers + + +class ReplayBuffer: + + # To reduce model oscillation during training, we update the discriminator + # using a history of generated data instead of the most recently generated data + # according to Shrivastava et al. (2017). 
+ + def __init__(self, max_size=50): + if max_size <= 0: + raise ValueError("max_size must be positive") + self.max_size = max_size + self.data = [] + + def push_and_pop(self, data: torch.Tensor) -> torch.autograd.Variable: + to_return = [] + for element in data.data: + element = torch.unsqueeze(element, 0) + if len(self.data) < self.max_size: + self.data.append(element) + to_return.append(element) + else: + if random.uniform(0, 1) > 0.5: + i = random.randint(0, self.max_size - 1) + to_return.append(self.data[i].clone()) + self.data[i] = element + else: + to_return.append(element) + return torch.autograd.Variable(torch.cat(to_return)) + + +class StatsCollector: + def __init__(self, n_dims_keep: int): + self.n_dims_keep = n_dims_keep + self._sum = 0.0 + self._sum_squared = 0.0 + self._count = 0 + + def observe(self, data: np.ndarray): + mean_dims = tuple(range(0, len(data.shape) - self.n_dims_keep)) + data = data.astype(np.float64) + self._sum += data.mean(axis=mean_dims) + self._sum_squared += (data ** 2).mean(axis=mean_dims) + self._count += 1 + + @property + def mean(self) -> np.ndarray: + return self._sum / self._count + + @property + def std(self) -> np.ndarray: + return np.sqrt(self._sum_squared / self._count - self.mean ** 2) + + +def get_r2(predicted, target) -> float: + """ + Compute the R^2 statistic for the predicted and target data. 
+ """ + return 1.0 - np.var(predicted - target) / np.var(target) + + +@dataclasses.dataclass +class CycleGANTrainer: + + # This class based loosely on + # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py + + # Copyright Facebook, BSD license + # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/c99ce7c4e781712e0252c6127ad1a4e8021cc489/LICENSE + + cycle_gan: CycleGAN + optimizer_generator: torch.optim.Optimizer + optimizer_discriminator: torch.optim.Optimizer + identity_loss: torch.nn.Module + cycle_loss: torch.nn.Module + gan_loss: torch.nn.Module + batch_size: int + identity_weight: float = 0.5 + cycle_weight: float = 1.0 + gan_weight: float = 1.0 + discriminator_weight: float = 1.0 + + def __post_init__(self): + self.target_real: Optional[torch.autograd.Variable] = None + self.target_fake: Optional[torch.autograd.Variable] = None + self.fake_a_buffer = ReplayBuffer() + self.fake_b_buffer = ReplayBuffer() + self.generator_a_to_b = self.cycle_gan.generator_a_to_b + self.generator_b_to_a = self.cycle_gan.generator_b_to_a + self.discriminator_a = self.cycle_gan.discriminator_a + self.discriminator_b = self.cycle_gan.discriminator_b + + def _init_targets(self, shape: Tuple[int, ...]): + self.target_real = torch.autograd.Variable( + torch.Tensor(shape).fill_(1.0).to(DEVICE), requires_grad=False + ) + self.target_fake = torch.autograd.Variable( + torch.Tensor(shape).fill_(0.0).to(DEVICE), requires_grad=False + ) + + def evaluate_on_dataset( + self, dataset: tf.data.Dataset, n_dims_keep: int = 3 + ) -> Dict[str, float]: + stats_real_a = StatsCollector(n_dims_keep) + stats_real_b = StatsCollector(n_dims_keep) + stats_gen_a = StatsCollector(n_dims_keep) + stats_gen_b = StatsCollector(n_dims_keep) + real_a: np.ndarray + real_b: np.ndarray + for real_a, real_b in dataset: + stats_real_a.observe(real_a) + stats_real_b.observe(real_b) + gen_b: torch.Tensor = self.generator_a_to_b( + 
torch.as_tensor(real_a).float().to(DEVICE) + ) + gen_a: torch.Tensor = self.generator_b_to_a( + torch.as_tensor(real_b).float().to(DEVICE) + ) + stats_gen_a.observe(gen_a.detach().cpu().numpy()) + stats_gen_b.observe(gen_b.detach().cpu().numpy()) + metrics = { + # "r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean), + "r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean), + # "bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean), + "r2_mean_b": get_r2(stats_real_b.mean, stats_gen_b.mean), + # "bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), + "r2_std_a": get_r2(stats_real_a.std, stats_gen_a.std), + # "bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), + "r2_std_b": get_r2(stats_real_b.std, stats_gen_b.std), + # "bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), + } + return metrics + + def train_on_batch( + self, real_a: torch.Tensor, real_b: torch.Tensor + ) -> Mapping[str, float]: + fake_b = self.generator_a_to_b(real_a) + fake_a = self.generator_b_to_a(real_b) + reconstructed_a = self.generator_b_to_a(fake_b) + reconstructed_b = self.generator_a_to_b(fake_a) + + # Generators A2B and B2A ###### + + # don't update discriminators when training generators to fool them + set_requires_grad( + [self.discriminator_a, self.discriminator_b], requires_grad=False + ) + + # Identity loss + # G_A2B(B) should equal B if real B is fed + same_b = self.generator_a_to_b(real_b) + loss_identity_b = self.identity_loss(same_b, real_b) * self.identity_weight + # G_B2A(A) should equal A if real A is fed + same_a = self.generator_b_to_a(real_a) + loss_identity_a = self.identity_loss(same_a, real_a) * self.identity_weight + loss_identity = loss_identity_a + loss_identity_b + + # GAN loss + pred_fake_b = self.discriminator_b(fake_b) + if self.target_real is None: + self._init_targets(pred_fake_b.shape) + loss_gan_a_to_b = self.gan_loss(pred_fake_b, self.target_real) * self.gan_weight + + pred_fake_a = self.discriminator_a(fake_a) + 
loss_gan_b_to_a = self.gan_loss(pred_fake_a, self.target_real) * self.gan_weight + loss_gan = loss_gan_a_to_b + loss_gan_b_to_a + + # Cycle loss + loss_cycle_a_b_a = self.cycle_loss(reconstructed_a, real_a) * self.cycle_weight + loss_cycle_b_a_b = self.cycle_loss(reconstructed_b, real_b) * self.cycle_weight + loss_cycle = loss_cycle_a_b_a + loss_cycle_b_a_b + + # Total loss + loss_g: torch.Tensor = (loss_identity + loss_gan + loss_cycle) + self.optimizer_generator.zero_grad() + loss_g.backward() + self.optimizer_generator.step() + + # Discriminators A and B ###### + + # do update discriminators when training them to identify samples + set_requires_grad( + [self.discriminator_a, self.discriminator_b], requires_grad=True + ) + + # Real loss + pred_real = self.discriminator_a(real_a) + loss_d_a_real = ( + self.gan_loss(pred_real, self.target_real) + * self.gan_weight + * self.discriminator_weight + ) + + # Fake loss + fake_a = self.fake_a_buffer.push_and_pop(fake_a) + pred_a_fake = self.discriminator_a(fake_a.detach()) + loss_d_a_fake = ( + self.gan_loss(pred_a_fake, self.target_fake) + * self.gan_weight + * self.discriminator_weight + ) + + # Real loss + pred_real = self.discriminator_b(real_b) + loss_d_b_real = ( + self.gan_loss(pred_real, self.target_real) + * self.gan_weight + * self.discriminator_weight + ) + + # Fake loss + fake_b = self.fake_b_buffer.push_and_pop(fake_b) + pred_b_fake = self.discriminator_b(fake_b.detach()) + loss_d_b_fake = ( + self.gan_loss(pred_b_fake, self.target_fake) + * self.gan_weight + * self.discriminator_weight + ) + + # Total loss + loss_d: torch.Tensor = ( + loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake + ) * self.discriminator_weight + + self.optimizer_discriminator.zero_grad() + loss_d.backward() + self.optimizer_discriminator.step() + + return { + # "gan_loss": float(loss_gan), + "b_to_a_gan_loss": float(loss_gan_b_to_a), + "a_to_b_gan_loss": float(loss_gan_a_to_b), + "discriminator_a_loss": float(loss_d_a_fake + 
loss_d_a_real), + "discriminator_b_loss": float(loss_d_b_fake + loss_d_b_real), + # "cycle_loss": float(loss_cycle), + # "identity_loss": float(loss_identity), + # "generator_loss": float(loss_g), + # "discriminator_loss": float(loss_d), + "train_loss": float(loss_g + loss_d), + } + + +def set_requires_grad(nets: List[torch.nn.Module], requires_grad=False): + """Set requires_grad=False for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py new file mode 100644 index 0000000000..adf764b2e8 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py @@ -0,0 +1,82 @@ +import dataclasses + +import torch.nn as nn +from toolz import curry +import torch +from .modules import ( + ConvolutionFactory, + single_tile_convolution, + leakyrelu_activation, + ConvBlock, +) + + +@dataclasses.dataclass +class DiscriminatorConfig: + + n_convolutions: int = 3 + kernel_size: int = 3 + max_filters: int = 256 + + def build( + self, channels: int, convolution: ConvolutionFactory = single_tile_convolution, + ): + return Discriminator( + in_channels=channels, + n_convolutions=self.n_convolutions, + kernel_size=self.kernel_size, + max_filters=self.max_filters, + convolution=convolution, + ) + + +class Discriminator(nn.Module): + def __init__( + self, + in_channels: int, + n_convolutions: int, + kernel_size: int, + max_filters: int, + convolution: ConvolutionFactory = single_tile_convolution, + ): + super(Discriminator, self).__init__() + # max_filters = min_filters * 2 ** (n_convolutions - 1), therefore + min_filters = int(max_filters / 2 ** (n_convolutions
- 1)) + convs = [ + ConvBlock( + in_channels=in_channels, + out_channels=min_filters, + convolution_factory=curry(convolution)( + kernel_size=kernel_size, stride=2, padding=1 + ), + activation_factory=leakyrelu_activation( + negative_slope=0.2, inplace=True + ), + ) + ] + for i in range(1, n_convolutions): + convs.append( + ConvBlock( + in_channels=min_filters * 2 ** (i - 1), + out_channels=min_filters * 2 ** i, + convolution_factory=curry(convolution)( + kernel_size=kernel_size, stride=2, padding=1 + ), + activation_factory=leakyrelu_activation( + negative_slope=0.2, inplace=True + ), + ) + ) + final_conv = ConvBlock( + in_channels=max_filters, + out_channels=max_filters, + convolution_factory=curry(convolution)(kernel_size=kernel_size), + activation_factory=leakyrelu_activation(negative_slope=0.2, inplace=True), + ) + patch_output = convolution( + kernel_size=3, in_channels=max_filters, out_channels=1, padding="same" + ) + self._sequential = nn.Sequential(*convs, final_conv, patch_output) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self._sequential(inputs) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py new file mode 100644 index 0000000000..a768d0d588 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py @@ -0,0 +1,148 @@ +import dataclasses +import torch.nn as nn +from toolz import curry +import torch +from .modules import ( + ConvBlock, + ConvolutionFactory, + single_tile_convolution, + relu_activation, + ResnetBlock, +) + + +@dataclasses.dataclass +class GeneratorConfig: + n_convolutions: int = 3 + n_resnet: int = 3 + kernel_size: int = 3 + max_filters: int = 256 + + def build( + self, channels: int, convolution: ConvolutionFactory = single_tile_convolution, + ): + return Generator( + channels=channels, + n_convolutions=self.n_convolutions, + n_resnet=self.n_resnet, + kernel_size=self.kernel_size, + max_filters=self.max_filters, + 
convolution=convolution, + ) + + +class Generator(nn.Module): + def __init__( + self, + channels: int, + n_convolutions: int, + n_resnet: int, + kernel_size: int, + max_filters: int, + convolution: ConvolutionFactory = single_tile_convolution, + ): + super(Generator, self).__init__() + + def resnet(in_channels: int): + resnet_blocks = [ + ResnetBlock( + n_filters=in_channels, + convolution_factory=curry(convolution)( + kernel_size=3, padding="same" + ), + activation_factory=relu_activation(), + ) + for _ in range(n_resnet) + ] + return nn.Sequential(*resnet_blocks) + + def down(in_channels: int, out_channels: int): + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + convolution_factory=curry(convolution)( + kernel_size=3, stride=2, padding=1 + ), + activation_factory=relu_activation(), + ) + + def up(in_channels: int, out_channels: int): + return ConvBlock( + in_channels=in_channels, + out_channels=out_channels, + convolution_factory=curry(convolution)( + kernel_size=kernel_size, + stride=2, + padding=1, + output_padding=1, + stride_type="transpose", + ), + activation_factory=relu_activation(), + ) + + min_filters = int(max_filters / 2 ** (n_convolutions - 1)) + + self._first_conv = nn.Sequential( + convolution( + kernel_size=7, + in_channels=channels, + out_channels=min_filters, + padding="same", + ), + relu_activation()(), + ) + + self._encoder_decoder = SymmetricEncoderDecoder( + down_factory=down, + up_factory=up, + bottom_factory=resnet, + depth=n_convolutions - 1, + in_channels=min_filters, + ) + + self._out_conv = convolution( + kernel_size=7, + in_channels=min_filters, + out_channels=channels, + padding="same", + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + x = self._first_conv(inputs) + x = self._encoder_decoder(x) + outputs: torch.Tensor = self._out_conv(x) + return outputs + + +class SymmetricEncoderDecoder(nn.Module): + """ + Encoder-decoder network with a symmetric structure. 
+ + Not a u-net because it does not have skip connections. + """ + + def __init__( + self, down_factory, up_factory, bottom_factory, depth: int, in_channels: int, + ): + super(SymmetricEncoderDecoder, self).__init__() + lower_channels = 2 * in_channels + self._down = down_factory(in_channels=in_channels, out_channels=lower_channels) + self._up = up_factory(in_channels=lower_channels, out_channels=in_channels) + if depth == 1: + self._lower = bottom_factory(in_channels=lower_channels) + elif depth <= 0: + raise ValueError(f"depth must be at least 1, got {depth}") + else: + self._lower = SymmetricEncoderDecoder( + down_factory, + up_factory, + bottom_factory, + depth=depth - 1, + in_channels=lower_channels, + ) + + def forward(self, inputs): + x = self._down(inputs) + x = self._lower(x) + x = self._up(x) + return x diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py new file mode 100644 index 0000000000..49fab74d69 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py @@ -0,0 +1,167 @@ +import logging + +from typing import Callable, Literal, Protocol, Union +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +def relu_activation(**kwargs): + def relu_factory(): + return nn.ReLU(**kwargs) + + return relu_factory + + +def tanh_activation(): + return nn.Tanh() + + +def leakyrelu_activation(**kwargs): + def leakyrelu_factory(): + return nn.LeakyReLU(**kwargs) + + return leakyrelu_factory + + +def no_activation(): + return nn.Identity() + + +class ConvolutionFactory(Protocol): + def __call__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: Union[str, int] = 0, + output_padding: int = 0, + stride: int = 1, + stride_type: Literal["regular", "transpose"] = "regular", + bias: bool = True, + ) -> nn.Module: + """ + Create a convolutional layer. 
+ + Args: + in_channels: number of input channels + out_channels: number of output channels + kernel_size: size of the convolution kernel + padding: padding to apply to the input, should be an integer or "same" + output_padding: argument used for transpose convolution + stride: stride of the convolution + stride_type: type of stride, one of "regular" or "transpose" + bias: whether to include a bias vector in the produced layers + """ + ... + + +class CurriedConvolutionFactory(Protocol): + def __call__(self, in_channels: int, out_channels: int,) -> nn.Module: + """ + Create a convolutional layer. + + Args: + in_channels: number of input channels + out_channels: number of output channels + """ + ... + + +def single_tile_convolution( + in_channels: int, + out_channels: int, + kernel_size: int, + padding: Union[str, int] = 0, + output_padding: int = 0, + stride: int = 1, + stride_type: Literal["regular", "transpose"] = "regular", + bias: bool = True, +) -> ConvolutionFactory: + """ + Construct a convolutional layer for single tile data (like images). 
+ + Args: + kernel_size: size of the convolution kernel + padding: padding to apply to the input, should be an integer or "same" + output_padding: argument used for transpose convolution + stride: stride of the convolution + stride_type: type of stride, one of "regular" or "transpose" + bias: whether to include a bias vector in the produced layers + """ + if stride == 1: + return nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + bias=bias, + ) + + elif stride_type == "regular": + return nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + elif stride_type == "transpose": + return nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=(padding, padding), + output_padding=output_padding, + bias=bias, + ) + + +class ResnetBlock(nn.Module): + def __init__( + self, + n_filters: int, + convolution_factory: CurriedConvolutionFactory, + activation_factory: Callable[[], nn.Module] = relu_activation(), + ): + super(ResnetBlock, self).__init__() + self.conv_block = nn.Sequential( + ConvBlock( + in_channels=n_filters, + out_channels=n_filters, + convolution_factory=convolution_factory, + activation_factory=activation_factory, + ), + ConvBlock( + in_channels=n_filters, + out_channels=n_filters, + convolution_factory=convolution_factory, + activation_factory=no_activation, + ), + ) + self.identity = nn.Identity() + + def forward(self, inputs): + g = self.conv_block(inputs) + return g + self.identity(inputs) + + +class ConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + convolution_factory: CurriedConvolutionFactory, + activation_factory: Callable[[], nn.Module] = relu_activation(), + ): + super(ConvBlock, self).__init__() + self.conv_block = nn.Sequential( + convolution_factory(in_channels=in_channels, out_channels=out_channels), + nn.InstanceNorm2d(out_channels), + 
activation_factory(), + ) + + def forward(self, inputs): + return self.conv_block(inputs) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py b/external/fv3fit/fv3fit/pytorch/cyclegan/network.py deleted file mode 100644 index 062b7a9afb..0000000000 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/network.py +++ /dev/null @@ -1,374 +0,0 @@ -import dataclasses -import logging - -from typing import Callable, Literal, Protocol, Union -import torch.nn as nn -from toolz import curry -import torch - -logger = logging.getLogger(__name__) - - -def relu_activation(**kwargs): - def relu_factory(): - return nn.ReLU(**kwargs) - - return relu_factory - - -def tanh_activation(): - return nn.Tanh() - - -def leakyrelu_activation(**kwargs): - def leakyrelu_factory(): - return nn.LeakyReLU(**kwargs) - - return leakyrelu_factory - - -def no_activation(): - return nn.Identity() - - -class ConvolutionFactory(Protocol): - def __call__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - padding: Union[str, int] = 0, - output_padding: int = 0, - stride: int = 1, - stride_type: Literal["regular", "transpose"] = "regular", - bias: bool = True, - ) -> nn.Module: - """ - Create a convolutional layer. - - Args: - in_channels: number of input channels - out_channels: number of output channels - kernel_size: size of the convolution kernel - padding: padding to apply to the input, should be an integer or "same" - output_padding: argument used for transpose convolution - stride: stride of the convolution - stride_type: type of stride, one of "regular" or "transpose" - bias: whether to include a bias vector in the produced layers - """ - ... - - -class CurriedConvolutionFactory(Protocol): - def __call__(self, in_channels: int, out_channels: int,) -> nn.Module: - """ - Create a convolutional layer. - - Args: - in_channels: number of input channels - out_channels: number of output channels - """ - ... 
- - -def single_tile_convolution( - in_channels: int, - out_channels: int, - kernel_size: int, - padding: Union[str, int] = 0, - output_padding: int = 0, - stride: int = 1, - stride_type: Literal["regular", "transpose"] = "regular", - bias: bool = True, -) -> ConvolutionFactory: - """ - Construct a convolutional layer for single tile data (like images). - - Args: - kernel_size: size of the convolution kernel - padding: padding to apply to the input, should be an integer or "same" - output_padding: argument used for transpose convolution - stride: stride of the convolution - stride_type: type of stride, one of "regular" or "transpose" - bias: whether to include a bias vector in the produced layers - """ - if stride == 1: - return nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - padding=padding, - bias=bias, - ) - - elif stride_type == "regular": - return nn.Conv2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - bias=bias, - ) - elif stride_type == "transpose": - return nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=(padding, padding), - output_padding=output_padding, - bias=bias, - ) - - -@dataclasses.dataclass -class GeneratorConfig: - n_convolutions: int = 3 - n_resnet: int = 3 - kernel_size: int = 3 - max_filters: int = 256 - - def build( - self, channels: int, convolution: ConvolutionFactory = single_tile_convolution, - ): - return Generator( - channels=channels, - n_convolutions=self.n_convolutions, - n_resnet=self.n_resnet, - kernel_size=self.kernel_size, - max_filters=self.max_filters, - convolution=convolution, - ) - - -@dataclasses.dataclass -class DiscriminatorConfig: - - n_convolutions: int = 3 - kernel_size: int = 3 - max_filters: int = 256 - - def build( - self, channels: int, convolution: ConvolutionFactory = single_tile_convolution, - ): - return Discriminator( - in_channels=channels, - n_convolutions=self.n_convolutions, - 
kernel_size=self.kernel_size, - max_filters=self.max_filters, - convolution=convolution, - ) - - -class ResnetBlock(nn.Module): - def __init__( - self, - n_filters: int, - convolution_factory: CurriedConvolutionFactory, - activation_factory: Callable[[], nn.Module] = relu_activation(), - ): - super(ResnetBlock, self).__init__() - self.conv_block = nn.Sequential( - ConvBlock( - in_channels=n_filters, - out_channels=n_filters, - convolution_factory=convolution_factory, - activation_factory=activation_factory, - ), - ConvBlock( - in_channels=n_filters, - out_channels=n_filters, - convolution_factory=convolution_factory, - activation_factory=no_activation, - ), - ) - self.identity = nn.Identity() - - def forward(self, inputs): - g = self.conv_block(inputs) - return g + self.identity(inputs) - - -class ConvBlock(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - convolution_factory: CurriedConvolutionFactory, - activation_factory: Callable[[], nn.Module] = relu_activation(), - ): - super(ConvBlock, self).__init__() - self.conv_block = nn.Sequential( - convolution_factory(in_channels=in_channels, out_channels=out_channels), - nn.InstanceNorm2d(out_channels), - activation_factory(), - ) - - def forward(self, inputs): - return self.conv_block(inputs) - - -class Discriminator(nn.Module): - def __init__( - self, - in_channels: int, - n_convolutions: int, - kernel_size: int, - max_filters: int, - convolution: ConvolutionFactory = single_tile_convolution, - ): - super(Discriminator, self).__init__() - # max_filters = min_filters * 2 ** (n_convolutions - 1), therefore - min_filters = int(max_filters / 2 ** (n_convolutions - 1)) - convs = [ - ConvBlock( - in_channels=in_channels, - out_channels=min_filters, - convolution_factory=curry(convolution)( - kernel_size=kernel_size, stride=2, padding=1 - ), - activation_factory=leakyrelu_activation( - negative_slope=0.2, inplace=True - ), - ) - ] - for i in range(1, n_convolutions): - convs.append( - 
ConvBlock( - in_channels=min_filters * 2 ** (i - 1), - out_channels=min_filters * 2 ** i, - convolution_factory=curry(convolution)( - kernel_size=kernel_size, stride=2, padding=1 - ), - activation_factory=leakyrelu_activation( - negative_slope=0.2, inplace=True - ), - ) - ) - final_conv = ConvBlock( - in_channels=max_filters, - out_channels=max_filters, - convolution_factory=curry(convolution)(kernel_size=kernel_size), - activation_factory=leakyrelu_activation(negative_slope=0.2, inplace=True), - ) - patch_output = convolution( - kernel_size=3, in_channels=max_filters, out_channels=1, padding="same" - ) - self._sequential = nn.Sequential(*convs, final_conv, patch_output) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - return self._sequential(inputs) - - -class Generator(nn.Module): - def __init__( - self, - channels: int, - n_convolutions: int, - n_resnet: int, - kernel_size: int, - max_filters: int, - convolution: ConvolutionFactory = single_tile_convolution, - ): - super(Generator, self).__init__() - - def resnet(in_channels: int): - resnet_blocks = [ - ResnetBlock( - n_filters=in_channels, - convolution_factory=curry(convolution)( - kernel_size=3, padding="same" - ), - activation_factory=relu_activation(), - ) - for _ in range(n_resnet) - ] - return nn.Sequential(*resnet_blocks) - - def down(in_channels: int, out_channels: int): - return ConvBlock( - in_channels=in_channels, - out_channels=out_channels, - convolution_factory=curry(convolution)( - kernel_size=3, stride=2, padding=1 - ), - activation_factory=relu_activation(), - ) - - def up(in_channels: int, out_channels: int): - return ConvBlock( - in_channels=in_channels, - out_channels=out_channels, - convolution_factory=curry(convolution)( - kernel_size=kernel_size, - stride=2, - padding=1, - output_padding=1, - stride_type="transpose", - ), - activation_factory=relu_activation(), - ) - - min_filters = int(max_filters / 2 ** (n_convolutions - 1)) - - self._first_conv = nn.Sequential( - 
convolution( - kernel_size=7, - in_channels=channels, - out_channels=min_filters, - padding="same", - ), - relu_activation()(), - ) - - self._unet = UNet( - down_factory=down, - up_factory=up, - bottom_factory=resnet, - depth=n_convolutions - 1, - in_channels=min_filters, - ) - - self._out_conv = convolution( - kernel_size=7, - in_channels=min_filters, - out_channels=channels, - padding="same", - ) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - x = self._first_conv(inputs) - x = self._unet(x) - outputs: torch.Tensor = self._out_conv(x) - return outputs - - -class UNet(nn.Module): - def __init__( - self, down_factory, up_factory, bottom_factory, depth: int, in_channels: int, - ): - super(UNet, self).__init__() - lower_channels = 2 * in_channels - self._down = down_factory(in_channels=in_channels, out_channels=lower_channels) - self._up = up_factory(in_channels=lower_channels, out_channels=in_channels) - if depth == 1: - self._lower = bottom_factory(in_channels=lower_channels) - elif depth <= 0: - raise ValueError(f"depth must be at least 1, got {depth}") - else: - self._lower = UNet( - down_factory, - up_factory, - bottom_factory, - depth=depth - 1, - in_channels=lower_channels, - ) - - def forward(self, inputs): - x = self._down(inputs) - x = self._lower(x) - x = self._up(x) - # skip connection - # x = torch.concat([x, inputs], dim=1) - return x diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py new file mode 100644 index 0000000000..369d1ef612 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py @@ -0,0 +1,162 @@ +from fv3fit._shared import io +from fv3fit._shared.predictor import Reloadable +from fv3fit._shared.scaler import StandardScaler +from fv3fit.pytorch.predict import ( + _dump_pytorch, + _load_pytorch, + _pack_to_tensor, + _unpack_tensor, +) +from .generator import Generator +from .discriminator import Discriminator +from typing import Mapping, 
Iterable +import torch +import xarray as xr + + +class CycleGANModule(torch.nn.Module): + def __init__( + self, + generator_a_to_b: Generator, + generator_b_to_a: Generator, + discriminator_a: Discriminator, + discriminator_b: Discriminator, + ): + super(CycleGANModule, self).__init__() + self.generator_a_to_b = generator_a_to_b + self.generator_b_to_a = generator_b_to_a + self.discriminator_a = discriminator_a + self.discriminator_b = discriminator_b + + +@io.register("cycle_gan") +class CycleGAN(Reloadable): + + _MODEL_FILENAME = "weight.pt" + _CONFIG_FILENAME = "config.yaml" + _SCALERS_FILENAME = "scalers.zip" + + def __init__( + self, + model: CycleGANModule, + scalers: Mapping[str, StandardScaler], + state_variables: Iterable[str], + ): + """ + Args: + model: pytorch model + scalers: scalers for the state variables, keys are prepended with "a_" + or "b_" to denote the domain of the scaler, followed by the name of + the state variable it scales + state_variables: name of variables to be used as state variables in + the order expected by the model + """ + self.model = model + self.scalers = scalers + self.state_variables = state_variables + + @property + def generator_a_to_b(self) -> torch.nn.Module: + return self.model.generator_a_to_b + + @property + def generator_b_to_a(self) -> torch.nn.Module: + return self.model.generator_b_to_a + + @property + def discriminator_a(self) -> torch.nn.Module: + return self.model.discriminator_a + + @property + def discriminator_b(self) -> torch.nn.Module: + return self.model.discriminator_b + + @classmethod + def load(cls, path: str) -> "CycleGAN": + """Load a serialized model from a directory.""" + return _load_pytorch(cls, path) + + def dump(self, path: str) -> None: + _dump_pytorch(self, path) + + def get_config(self): + return {} + + def pack_to_tensor(self, ds: xr.Dataset, domain: str = "a") -> torch.Tensor: + """ + Packs the dataset into a tensor to be used by the pytorch model. 
+ + Subdivides the dataset evenly into windows + of size (timesteps + 1) with overlapping start and end points. + Overlapping the window start and ends is necessary so that every + timestep (evolution from one time to the next) is included within + one of the windows. + + Args: + ds: dataset containing values to pack + domain: one of "a" or "b" + + Returns: + tensor of shape [window, time, tile, x, y, feature] + """ + scalers = { + name[2:]: scaler + for name, scaler in self.scalers.items() + if name.startswith(f"{domain}_") + } + tensor = _pack_to_tensor( + ds=ds, timesteps=0, state_variables=self.state_variables, scalers=scalers, + ) + return tensor.permute([0, 1, 4, 2, 3]) + + def unpack_tensor(self, data: torch.Tensor, domain: str = "b") -> xr.Dataset: + """ + Unpacks the tensor into a dataset. + + Args: + data: tensor of shape [window, time, tile, x, y, feature] + domain: one of "a" or "b" + + Returns: + xarray dataset with values of shape [window, time, tile, x, y, feature] + """ + scalers = { + name[2:]: scaler + for name, scaler in self.scalers.items() + if name.startswith(f"{domain}_") + } + return _unpack_tensor( + data.permute([0, 1, 3, 4, 2]), + varnames=self.state_variables, + scalers=scalers, + dims=["time", "tile", "x", "y", "z"], + ) + + def predict(self, X: xr.Dataset, reverse: bool = False) -> xr.Dataset: + """ + Predict a state in the output domain from a state in the input domain. 
+ + Args: + X: input dataset + reverse: if True, transform from the output domain to the input domain + + Returns: + predicted: predicted dataset + """ + if reverse: + input_domain, output_domain = "b", "a" + else: + input_domain, output_domain = "a", "b" + + tensor = self.pack_to_tensor(X, domain=input_domain) + reshaped_tensor = tensor.reshape( + [tensor.shape[0] * tensor.shape[1]] + list(tensor.shape[2:]) + ) + with torch.no_grad(): + if reverse: + outputs = self.generator_b_to_a(reshaped_tensor) + else: + outputs = self.generator_a_to_b(reshaped_tensor) + outputs = outputs.reshape(tensor.shape) + predicted = self.unpack_tensor(outputs, domain=output_domain) + return predicted diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py similarity index 92% rename from external/fv3fit/fv3fit/pytorch/cyclegan/train.py rename to external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py index 59f1f58dcd..b6c6678cf5 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py @@ -13,7 +13,7 @@ Optional, Tuple, ) -from .network import Generator, GeneratorConfig +from .generator import Generator, GeneratorConfig from fv3fit.pytorch.graph.train import get_Xy_map_fn from fv3fit._shared.scaler import ( get_standard_scaler_mapping, @@ -106,11 +106,15 @@ def train_autoencoder( optimizer = hyperparameters.optimizer_config train_state = flatten_dims( - train_state.map(define_noisy_input(stdev=hyperparameters.noise_amount)) + train_state.map(channels_first).map( + define_noisy_input(stdev=hyperparameters.noise_amount) + ) ) if validation_batches is not None: val_state = flatten_dims( - val_state.map(define_noisy_input(stdev=hyperparameters.noise_amount)) + val_state.map(channels_first).map( + define_noisy_input(stdev=hyperparameters.noise_amount) + ) ) hyperparameters.training_loop.fit_loop( @@ -131,7 +135,7 @@ def train_autoencoder( def 
channels_first(data: tf.Tensor) -> tf.Tensor: - return tf.transpose(data, perm=[0, 3, 1, 2]) + return tf.transpose(data, perm=[0, 1, 2, 5, 3, 4]) def build_model(config: GeneratorConfig, n_state: int) -> Generator: diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index 7f647e580b..8c0f3b5e1f 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -1,28 +1,15 @@ -import itertools from fv3fit._shared.hyperparameters import Hyperparameters -import random import dataclasses -from fv3fit._shared.predictor import Reloadable import tensorflow as tf from fv3fit.pytorch.loss import LossConfig -from fv3fit.pytorch.optimizer import OptimizerConfig -import xarray as xr import torch from fv3fit.pytorch.system import DEVICE import tensorflow_datasets as tfds from fv3fit.tfdataset import sequence_size, apply_to_tuple -from fv3fit.pytorch.predict import ( - _load_pytorch, - _dump_pytorch, - _pack_to_tensor, - _unpack_tensor, -) -from fv3fit._shared import register_training_function, io +from fv3fit._shared import register_training_function from typing import ( Callable, - Dict, - Iterable, List, Mapping, Optional, @@ -30,15 +17,15 @@ Tuple, ) from fv3fit.tfdataset import ensure_nd -from .network import Discriminator, Generator, GeneratorConfig, DiscriminatorConfig from fv3fit.pytorch.graph.train import get_Xy_map_fn as get_Xy_map_fn_single_domain from fv3fit._shared.scaler import ( get_standard_scaler_mapping, get_mapping_standard_scale_func, - StandardScaler, ) import logging import numpy as np +from .reloadable import CycleGAN +from .cyclegan_trainer import CycleGANNetworkConfig, CycleGANTrainer logger = logging.getLogger(__name__) @@ -71,7 +58,7 @@ class CycleGANTrainingConfig: def fit_loop( self, - train_model: "CycleGANTrainer", + train_model: CycleGANTrainer, train_data: tf.data.Dataset, validation_data: 
Optional[tf.data.Dataset], ) -> None: @@ -108,46 +95,6 @@ def fit_loop( } logger.info("train_loss: %s", train_loss) - # real_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) - # real_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) - # fake_b = train_model.generator_a_to_b(real_a) - # fake_a = train_model.generator_b_to_a(real_b) - # reconstructed_a = train_model.generator_b_to_a(fake_b) - # reconstructed_b = train_model.generator_a_to_b(fake_a) - - # import matplotlib.pyplot as plt - - # fig, ax = plt.subplots(3, 2, figsize=(8, 8)) - # i = 0 - # iz = 0 - # vmin = -1.5 - # vmax = 1.5 - # ax[0, 0].imshow( - # real_a[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax - # ) - # ax[0, 1].imshow( - # real_b[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax - # ) - # ax[1, 0].imshow( - # fake_b[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax - # ) - # ax[1, 1].imshow( - # fake_a[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax - # ) - # ax[2, 0].imshow( - # reconstructed_a[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax - # ) - # ax[2, 1].imshow( - # reconstructed_b[0, :, :, iz].detach().numpy(), vmin=vmin, vmax=vmax - # ) - # ax[0, 0].set_title("real a") - # ax[0, 1].set_title("real b") - # ax[1, 0].set_title("output b") - # ax[1, 1].set_title("output a") - # ax[2, 0].set_title("reconstructed a") - # ax[2, 1].set_title("reconstructed b") - # plt.tight_layout() - # plt.show() if validation_data is not None: val_loss = train_model.evaluate_on_dataset(validation_data) logger.info("val_loss %s", val_loss) @@ -245,476 +192,3 @@ def train_cyclegan( train_model=train_model, train_data=train_state, validation_data=val_state, ) return train_model.cycle_gan - - -class ReplayBuffer: - - # To reduce model oscillation during training, we update the discriminator - # using a history of generated data instead of the most recently generated data - # according to Shrivastava et al. (2017). 
- - def __init__(self, max_size=50): - if max_size <= 0: - raise ValueError("max_size must be positive") - self.max_size = max_size - self.data = [] - - def push_and_pop(self, data: torch.Tensor) -> torch.autograd.Variable: - to_return = [] - for element in data.data: - element = torch.unsqueeze(element, 0) - if len(self.data) < self.max_size: - self.data.append(element) - to_return.append(element) - else: - if random.uniform(0, 1) > 0.5: - i = random.randint(0, self.max_size - 1) - to_return.append(self.data[i].clone()) - self.data[i] = element - else: - to_return.append(element) - return torch.autograd.Variable(torch.cat(to_return)) - - -class StatsCollector: - def __init__(self, n_dims_keep: int): - self.n_dims_keep = n_dims_keep - self._sum = 0.0 - self._sum_squared = 0.0 - self._count = 0 - - def observe(self, data: np.ndarray): - mean_dims = tuple(range(0, len(data.shape) - self.n_dims_keep)) - data = data.astype(np.float64) - self._sum += data.mean(axis=mean_dims) - self._sum_squared += (data ** 2).mean(axis=mean_dims) - self._count += 1 - - @property - def mean(self) -> np.ndarray: - return self._sum / self._count - - @property - def std(self) -> np.ndarray: - return np.sqrt(self._sum_squared / self._count - self.mean ** 2) - - -def get_r2(predicted, target) -> float: - """ - Compute the R^2 statistic for the predicted and target data. 
- """ - return 1.0 - np.var(predicted - target) / np.var(target) - - -@dataclasses.dataclass -class CycleGANNetworkConfig: - generator_optimizer: OptimizerConfig = dataclasses.field( - default_factory=lambda: OptimizerConfig("Adam") - ) - discriminator_optimizer: OptimizerConfig = dataclasses.field( - default_factory=lambda: OptimizerConfig("Adam") - ) - generator: "GeneratorConfig" = dataclasses.field( - default_factory=lambda: GeneratorConfig() - ) - discriminator: "DiscriminatorConfig" = dataclasses.field( - default_factory=lambda: DiscriminatorConfig() - ) - identity_loss: LossConfig = dataclasses.field(default_factory=LossConfig) - cycle_loss: LossConfig = dataclasses.field(default_factory=LossConfig) - gan_loss: LossConfig = dataclasses.field(default_factory=LossConfig) - identity_weight: float = 0.5 - cycle_weight: float = 1.0 - gan_weight: float = 1.0 - discriminator_weight: float = 1.0 - - def build( - self, n_state: int, n_batch: int, state_variables, scalers - ) -> "CycleGANTrainer": - generator_a_to_b = self.generator.build(n_state) - generator_b_to_a = self.generator.build(n_state) - discriminator_a = self.discriminator.build(n_state) - discriminator_b = self.discriminator.build(n_state) - optimizer_generator = self.generator_optimizer.instance( - itertools.chain( - generator_a_to_b.parameters(), generator_b_to_a.parameters() - ) - ) - optimizer_discriminator = self.discriminator_optimizer.instance( - itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()) - ) - return CycleGANTrainer( - cycle_gan=CycleGAN( - model=CycleGANModule( - generator_a_to_b=generator_a_to_b, - generator_b_to_a=generator_b_to_a, - discriminator_a=discriminator_a, - discriminator_b=discriminator_b, - ).to(DEVICE), - state_variables=state_variables, - scalers=_merge_scaler_mappings(scalers), - ), - optimizer_generator=optimizer_generator, - optimizer_discriminator=optimizer_discriminator, - identity_loss=self.identity_loss.instance, - 
cycle_loss=self.cycle_loss.instance, - gan_loss=self.gan_loss.instance, - batch_size=n_batch, - identity_weight=self.identity_weight, - cycle_weight=self.cycle_weight, - gan_weight=self.gan_weight, - discriminator_weight=self.discriminator_weight, - ) - - -def _merge_scaler_mappings( - scaler_tuple: Tuple[Mapping[str, StandardScaler], Mapping[str, StandardScaler]] -) -> Mapping[str, StandardScaler]: - scalers = {} - for prefix, scaler_map in zip(("a_", "b_"), scaler_tuple): - for key, scaler in scaler_map.items(): - scalers[prefix + key] = scaler - return scalers - - -class CycleGANModule(torch.nn.Module): - def __init__( - self, - generator_a_to_b: Generator, - generator_b_to_a: Generator, - discriminator_a: Discriminator, - discriminator_b: Discriminator, - ): - super(CycleGANModule, self).__init__() - self.generator_a_to_b = generator_a_to_b - self.generator_b_to_a = generator_b_to_a - self.discriminator_a = discriminator_a - self.discriminator_b = discriminator_b - - -@io.register("cycle_gan") -class CycleGAN(Reloadable): - - _MODEL_FILENAME = "weight.pt" - _CONFIG_FILENAME = "config.yaml" - _SCALERS_FILENAME = "scalers.zip" - - def __init__( - self, - model: CycleGANModule, - scalers: Mapping[str, StandardScaler], - state_variables: Iterable[str], - ): - """ - Args: - model: pytorch model - scalers: scalers for the state variables, keys are prepended with "a_" - or "b_" to denote the domain of the scaler, followed by the name of - the state variable it scales - state_variables: name of variables to be used as state variables in - the order expected by the model - """ - self.model = model - self.scalers = scalers - self.state_variables = state_variables - - @property - def generator_a_to_b(self) -> torch.nn.Module: - return self.model.generator_a_to_b - - @property - def generator_b_to_a(self) -> torch.nn.Module: - return self.model.generator_b_to_a - - @property - def discriminator_a(self) -> torch.nn.Module: - return self.model.discriminator_a - - @property - 
def discriminator_b(self) -> torch.nn.Module: - return self.model.discriminator_b - - @classmethod - def load(cls, path: str) -> "CycleGAN": - """Load a serialized model from a directory.""" - return _load_pytorch(cls, path) - - def dump(self, path: str) -> None: - _dump_pytorch(self, path) - - def get_config(self): - return {} - - def pack_to_tensor(self, ds: xr.Dataset, domain: str = "a") -> torch.Tensor: - """ - Packs the dataset into a tensor to be used by the pytorch model. - - Subdivides the dataset evenly into windows - of size (timesteps + 1) with overlapping start and end points. - Overlapping the window start and ends is necessary so that every - timestep (evolution from one time to the next) is included within - one of the windows. - - Args: - ds: dataset containing values to pack - domain: one of "a" or "b" - - Returns: - tensor of shape [window, time, tile, x, y, feature] - """ - scalers = { - name[2:]: scaler - for name, scaler in self.scalers.items() - if name.startswith(f"{domain}_") - } - tensor = _pack_to_tensor( - ds=ds, timesteps=0, state_variables=self.state_variables, scalers=scalers, - ) - return tensor.permute([0, 1, 4, 2, 3]) - - def unpack_tensor(self, data: torch.Tensor, domain: str = "b") -> xr.Dataset: - """ - Unpacks the tensor into a dataset. - - Args: - data: tensor of shape [window, time, tile, x, y, feature] - domain: one of "a" or "b" - - Returns: - xarray dataset with values of shape [window, time, tile, x, y, feature] - """ - scalers = { - name[2:]: scaler - for name, scaler in self.scalers.items() - if name.startswith(f"{domain}_") - } - return _unpack_tensor( - data.permute([0, 1, 3, 4, 2]), - varnames=self.state_variables, - scalers=scalers, - dims=["time", "tile", "x", "y", "z"], - ) - - def predict(self, X: xr.Dataset, reverse: bool = False) -> xr.Dataset: - """ - Predict a state in the output domain from a state in the input domain. 
- - Args: - X: input dataset - reverse: if True, transform from the output domain to the input domain - - Returns: - predicted: predicted dataset - """ - if reverse: - input_domain, output_domain = "b", "a" - else: - input_domain, output_domain = "a", "b" - - tensor = self.pack_to_tensor(X, domain=input_domain) - reshaped_tensor = tensor.reshape( - [tensor.shape[0] * tensor.shape[1]] + list(tensor.shape[2:]) - ) - with torch.no_grad(): - if reverse: - outputs = self.generator_b_to_a(reshaped_tensor) - else: - outputs = self.generator_a_to_b(reshaped_tensor) - outputs = outputs.reshape(tensor.shape) - predicted = self.unpack_tensor(outputs, domain=output_domain) - return predicted - - -@dataclasses.dataclass -class CycleGANTrainer: - - # This class based loosely on - # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py - - # Copyright Facebook, BSD license - # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/c99ce7c4e781712e0252c6127ad1a4e8021cc489/LICENSE - - cycle_gan: CycleGAN - optimizer_generator: torch.optim.Optimizer - optimizer_discriminator: torch.optim.Optimizer - identity_loss: torch.nn.Module - cycle_loss: torch.nn.Module - gan_loss: torch.nn.Module - batch_size: int - identity_weight: float = 0.5 - cycle_weight: float = 1.0 - gan_weight: float = 1.0 - discriminator_weight: float = 1.0 - - def __post_init__(self): - self.target_real: Optional[torch.autograd.Variable] = None - self.target_fake: Optional[torch.autograd.Variable] = None - self.fake_a_buffer = ReplayBuffer() - self.fake_b_buffer = ReplayBuffer() - self.generator_a_to_b = self.cycle_gan.generator_a_to_b - self.generator_b_to_a = self.cycle_gan.generator_b_to_a - self.discriminator_a = self.cycle_gan.discriminator_a - self.discriminator_b = self.cycle_gan.discriminator_b - - def _init_targets(self, shape: Tuple[int, ...]): - self.target_real = torch.autograd.Variable( - torch.Tensor(shape).fill_(1.0).to(DEVICE), requires_grad=False - ) - 
self.target_fake = torch.autograd.Variable( - torch.Tensor(shape).fill_(0.0).to(DEVICE), requires_grad=False - ) - - def evaluate_on_dataset( - self, dataset: tf.data.Dataset, n_dims_keep: int = 3 - ) -> Dict[str, float]: - stats_real_a = StatsCollector(n_dims_keep) - stats_real_b = StatsCollector(n_dims_keep) - stats_gen_a = StatsCollector(n_dims_keep) - stats_gen_b = StatsCollector(n_dims_keep) - real_a: np.ndarray - real_b: np.ndarray - for real_a, real_b in dataset: - stats_real_a.observe(real_a) - stats_real_b.observe(real_b) - gen_b: torch.Tensor = self.generator_a_to_b( - torch.as_tensor(real_a).float().to(DEVICE) - ) - gen_a: torch.Tensor = self.generator_b_to_a( - torch.as_tensor(real_b).float().to(DEVICE) - ) - stats_gen_a.observe(gen_a.detach().cpu().numpy()) - stats_gen_b.observe(gen_b.detach().cpu().numpy()) - metrics = { - # "r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean), - "r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean), - # "bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean), - "r2_mean_b": get_r2(stats_real_b.mean, stats_gen_b.mean), - # "bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), - "r2_std_a": get_r2(stats_real_a.std, stats_gen_a.std), - # "bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), - "r2_std_b": get_r2(stats_real_b.std, stats_gen_b.std), - # "bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), - } - return metrics - - def train_on_batch( - self, real_a: torch.Tensor, real_b: torch.Tensor - ) -> Mapping[str, float]: - fake_b = self.generator_a_to_b(real_a) - fake_a = self.generator_b_to_a(real_b) - reconstructed_a = self.generator_b_to_a(fake_b) - reconstructed_b = self.generator_a_to_b(fake_a) - - # Generators A2B and B2A ###### - - # don't update discriminators when training generators to fool them - set_requires_grad( - [self.discriminator_a, self.discriminator_b], requires_grad=False - ) - - # Identity loss - # G_A2B(B) should equal B if real B is fed - same_b = 
self.generator_a_to_b(real_b) - loss_identity_b = self.identity_loss(same_b, real_b) * self.identity_weight - # G_B2A(A) should equal A if real A is fed - same_a = self.generator_b_to_a(real_a) - loss_identity_a = self.identity_loss(same_a, real_a) * self.identity_weight - loss_identity = loss_identity_a + loss_identity_b - - # GAN loss - pred_fake_b = self.discriminator_b(fake_b) - if self.target_real is None: - self._init_targets(pred_fake_b.shape) - loss_gan_a_to_b = self.gan_loss(pred_fake_b, self.target_real) * self.gan_weight - - pred_fake_a = self.discriminator_a(fake_a) - loss_gan_b_to_a = self.gan_loss(pred_fake_a, self.target_real) * self.gan_weight - loss_gan = loss_gan_a_to_b + loss_gan_b_to_a - - # Cycle loss - loss_cycle_a_b_a = self.cycle_loss(reconstructed_a, real_a) * self.cycle_weight - loss_cycle_b_a_b = self.cycle_loss(reconstructed_b, real_b) * self.cycle_weight - loss_cycle = loss_cycle_a_b_a + loss_cycle_b_a_b - - # Total loss - loss_g: torch.Tensor = (loss_identity + loss_gan + loss_cycle) - self.optimizer_generator.zero_grad() - loss_g.backward() - self.optimizer_generator.step() - - # Discriminators A and B ###### - - # do update discriminators when training them to identify samples - set_requires_grad( - [self.discriminator_a, self.discriminator_b], requires_grad=True - ) - - # Real loss - pred_real = self.discriminator_a(real_a) - loss_d_a_real = ( - self.gan_loss(pred_real, self.target_real) - * self.gan_weight - * self.discriminator_weight - ) - - # Fake loss - fake_a = self.fake_a_buffer.push_and_pop(fake_a) - pred_a_fake = self.discriminator_a(fake_a.detach()) - loss_d_a_fake = ( - self.gan_loss(pred_a_fake, self.target_fake) - * self.gan_weight - * self.discriminator_weight - ) - - # Real loss - pred_real = self.discriminator_b(real_b) - loss_d_b_real = ( - self.gan_loss(pred_real, self.target_real) - * self.gan_weight - * self.discriminator_weight - ) - - # Fake loss - fake_b = self.fake_b_buffer.push_and_pop(fake_b) - pred_b_fake 
= self.discriminator_b(fake_b.detach()) - loss_d_b_fake = ( - self.gan_loss(pred_b_fake, self.target_fake) - * self.gan_weight - * self.discriminator_weight - ) - - # Total loss - loss_d: torch.Tensor = ( - loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake - ) * self.discriminator_weight - - self.optimizer_discriminator.zero_grad() - loss_d.backward() - self.optimizer_discriminator.step() - - return { - # "gan_loss": float(loss_gan), - "b_to_a_gan_loss": float(loss_gan_b_to_a), - "a_to_b_gan_loss": float(loss_gan_a_to_b), - "discriminator_a_loss": float(loss_d_a_fake + loss_d_a_real), - "discriminator_b_loss": float(loss_d_b_fake + loss_d_b_real), - # "cycle_loss": float(loss_cycle), - # "identity_loss": float(loss_identity), - # "generator_loss": float(loss_g), - # "discriminator_loss": float(loss_d), - "train_loss": float(loss_g + loss_d), - } - - -def set_requires_grad(nets: List[torch.nn.Module], requires_grad=False): - """Set requies_grad=Fasle for all the networks to avoid unnecessary computations - Parameters: - nets (network list) -- a list of networks - requires_grad (bool) -- whether the networks require gradients or not - """ - if not isinstance(nets, list): - nets = [nets] - for net in nets: - if net is not None: - for param in net.parameters(): - param.requires_grad = requires_grad diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index f01cf94855..1a5b0f3c94 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -113,14 +113,17 @@ def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: ) # dimensions are [time, tile, x, y, z], # we must combine [time, tile] into one sample dimension - return torch.reshape( + reshaped = torch.reshape( packed, (packed.shape[0] * packed.shape[1],) + tuple(packed.shape[2:]), ) + # torch expects channels before x, y so we have to transpose + transposed = reshaped.permute([0, 3, 1, 2]) + return transposed def 
unpack_tensor(self, data: torch.Tensor) -> xr.Dataset: data = torch.reshape(data, (-1, 6) + tuple(data.shape[1:])) return _unpack_tensor( - data, + data.permute([0, 1, 3, 4, 2]), # convert from channels (z) first to last varnames=tuple(str(item) for item in self.output_variables), scalers=self.scalers, dims=["time", "tile", "x", "y", "z"], diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index 3158d3b8aa..7f80e1819b 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -2,7 +2,7 @@ import xarray as xr from typing import Sequence from fv3fit.pytorch.cyclegan import AutoencoderHyperparameters, train_autoencoder -from fv3fit.pytorch.cyclegan.train import TrainingConfig +from fv3fit.pytorch.cyclegan.train_autoencoder import TrainingConfig import pytest from fv3fit.data.synthetic import SyntheticWaves import collections @@ -78,14 +78,14 @@ def test_autoencoder(tmpdir): # doesn't particularly matter what the input data is, as long as the denoising # autoencoder can learn to remove noise from its samples. A dataset of # pure synthetic noise would not work, it must have some structure. 
- train_tfdataset = get_synthetic_waves_tfdataset(nsamples=20, **sizes) + train_tfdataset = get_synthetic_waves_tfdataset(nsamples=100, **sizes) val_tfdataset = get_synthetic_waves_tfdataset(nsamples=3, **sizes) hyperparameters = AutoencoderHyperparameters( state_variables=state_variables, generator=fv3fit.pytorch.GeneratorConfig( n_convolutions=2, n_resnet=3, max_filters=32 ), - training_loop=TrainingConfig(n_epoch=10, samples_per_batch=2), + training_loop=TrainingConfig(n_epoch=5, samples_per_batch=10), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), noise_amount=0.5, ) From 2b17e8869b4ff30e4b277d2632f27e880344d1cc Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 2 Sep 2022 17:23:50 -0700 Subject: [PATCH 26/55] cleanup merge leftover --- external/fv3fit/fv3fit/data/synthetic.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/external/fv3fit/fv3fit/data/synthetic.py b/external/fv3fit/fv3fit/data/synthetic.py index 46c6ede8b1..edcc7d61e4 100644 --- a/external/fv3fit/fv3fit/data/synthetic.py +++ b/external/fv3fit/fv3fit/data/synthetic.py @@ -69,11 +69,7 @@ class SyntheticWaves(TFDatasetLoader): period_max: maximum period of waves phase_range: fraction of 2*pi to use for possible range of random phase, should be a value between 0 and 1. 
-<<<<<<< HEAD - type: one of "sinusoidal" or "square" -======= wave_type: one of "sinusoidal" or "square" ->>>>>>> master """ nsamples: int @@ -88,7 +84,6 @@ class SyntheticWaves(TFDatasetLoader): period_min: float = 8.0 period_max: float = 16.0 phase_range: float = 1.0 - type: str = "sinusoidal" def open_tfdataset( self, local_download_path: Optional[str], variable_names: Sequence[str], From 5596903dde8b8eb25b5cf9629a15fc5f0773cf33 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 6 Sep 2022 17:11:56 -0700 Subject: [PATCH 27/55] add some docstrings to cyclegan_trainer.py --- .../pytorch/cyclegan/cyclegan_trainer.py | 96 +++++++++++++++---- 1 file changed, 78 insertions(+), 18 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index 78a802d3e8..37f9e0663a 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -16,6 +16,30 @@ @dataclasses.dataclass class CycleGANNetworkConfig: + """ + Configuration for building and training a CycleGAN network. 
+ + Attributes: + generator_optimizer: configuration for the optimizer used to train the + generator + discriminator_optimizer: configuration for the optimizer used to train the + discriminator + generator: configuration for building the generator network + discriminator: configuration for building the discriminator network + identity_loss: loss function used to make the generator which outputs + a given domain behave as an identity function when given data from + that domain as input + cycle_loss: loss function used on the difference between a round-trip + of the CycleGAN network and the original input + gan_loss: loss function used on output of the discriminator when + training the discriminator identify samples correctly or when training + the generator to fool the discriminator + identity_weight: weight of the identity loss + cycle_weight: weight of the cycle loss + generator_weight: weight of the generator's gan loss + discriminator_weight: weight of the discriminator gan loss + """ + generator_optimizer: OptimizerConfig = dataclasses.field( default_factory=lambda: OptimizerConfig("Adam") ) @@ -33,7 +57,7 @@ class CycleGANNetworkConfig: gan_loss: LossConfig = dataclasses.field(default_factory=LossConfig) identity_weight: float = 0.5 cycle_weight: float = 1.0 - gan_weight: float = 1.0 + generator_weight: float = 1.0 discriminator_weight: float = 1.0 def build( @@ -70,7 +94,7 @@ def build( batch_size=n_batch, identity_weight=self.identity_weight, cycle_weight=self.cycle_weight, - gan_weight=self.gan_weight, + generator_weight=self.generator_weight, discriminator_weight=self.discriminator_weight, ) @@ -98,6 +122,12 @@ def __init__(self, max_size=50): self.data = [] def push_and_pop(self, data: torch.Tensor) -> torch.autograd.Variable: + """ + Push data into the buffer and return a random sample of the buffer. + + If there are at least max_size elements in the buffer, the returned sample + is removed from the buffer. 
+ """ to_return = [] for element in data.data: element = torch.unsqueeze(element, 0) @@ -115,6 +145,10 @@ def push_and_pop(self, data: torch.Tensor) -> torch.autograd.Variable: class StatsCollector: + """ + Object to track the mean and standard deviation of sampled arrays. + """ + def __init__(self, n_dims_keep: int): self.n_dims_keep = n_dims_keep self._sum = 0.0 @@ -122,6 +156,9 @@ def __init__(self, n_dims_keep: int): self._count = 0 def observe(self, data: np.ndarray): + """ + Add a new sample to the statistics. + """ mean_dims = tuple(range(0, len(data.shape) - self.n_dims_keep)) data = data.astype(np.float64) self._sum += data.mean(axis=mean_dims) @@ -130,10 +167,16 @@ def observe(self, data: np.ndarray): @property def mean(self) -> np.ndarray: + """ + Mean of the observed samples. + """ return self._sum / self._count @property def std(self) -> np.ndarray: + """ + Standard deviation of the observed samples. + """ return np.sqrt(self._sum_squared / self._count - self.mean ** 2) @@ -146,6 +189,27 @@ def get_r2(predicted, target) -> float: @dataclasses.dataclass class CycleGANTrainer: + """ + A trainer for a CycleGAN model. 
+ + Attributes: + cycle_gan: the CycleGAN model to train + optimizer_generator: the optimizer for the generator + optimizer_discriminator: the optimizer for the discriminator + identity_loss: loss function used to make the generator which outputs + a given domain behave as an identity function when given data from + that domain as input + cycle_loss: loss function used on the difference between a round-trip + of the CycleGAN network and the original input + gan_loss: loss function used on output of the discriminator when + training the discriminator identify samples correctly or when training + the generator to fool the discriminator + batch_size: the number of samples to use in each batch when training + identity_weight: weight of the identity loss + cycle_weight: weight of the cycle loss + generator_weight: weight of the generator's gan loss + discriminator_weight: weight of the discriminator gan loss + """ # This class based loosely on # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py @@ -162,7 +226,7 @@ class CycleGANTrainer: batch_size: int identity_weight: float = 0.5 cycle_weight: float = 1.0 - gan_weight: float = 1.0 + generator_weight: float = 1.0 discriminator_weight: float = 1.0 def __post_init__(self): @@ -244,10 +308,14 @@ def train_on_batch( pred_fake_b = self.discriminator_b(fake_b) if self.target_real is None: self._init_targets(pred_fake_b.shape) - loss_gan_a_to_b = self.gan_loss(pred_fake_b, self.target_real) * self.gan_weight + loss_gan_a_to_b = ( + self.gan_loss(pred_fake_b, self.target_real) * self.generator_weight + ) pred_fake_a = self.discriminator_a(fake_a) - loss_gan_b_to_a = self.gan_loss(pred_fake_a, self.target_real) * self.gan_weight + loss_gan_b_to_a = ( + self.gan_loss(pred_fake_a, self.target_real) * self.generator_weight + ) loss_gan = loss_gan_a_to_b + loss_gan_b_to_a # Cycle loss @@ -271,41 +339,33 @@ def train_on_batch( # Real loss pred_real = self.discriminator_a(real_a) 
loss_d_a_real = ( - self.gan_loss(pred_real, self.target_real) - * self.gan_weight - * self.discriminator_weight + self.gan_loss(pred_real, self.target_real) * self.discriminator_weight ) # Fake loss fake_a = self.fake_a_buffer.push_and_pop(fake_a) pred_a_fake = self.discriminator_a(fake_a.detach()) loss_d_a_fake = ( - self.gan_loss(pred_a_fake, self.target_fake) - * self.gan_weight - * self.discriminator_weight + self.gan_loss(pred_a_fake, self.target_fake) * self.discriminator_weight ) # Real loss pred_real = self.discriminator_b(real_b) loss_d_b_real = ( - self.gan_loss(pred_real, self.target_real) - * self.gan_weight - * self.discriminator_weight + self.gan_loss(pred_real, self.target_real) * self.discriminator_weight ) # Fake loss fake_b = self.fake_b_buffer.push_and_pop(fake_b) pred_b_fake = self.discriminator_b(fake_b.detach()) loss_d_b_fake = ( - self.gan_loss(pred_b_fake, self.target_fake) - * self.gan_weight - * self.discriminator_weight + self.gan_loss(pred_b_fake, self.target_fake) * self.discriminator_weight ) # Total loss loss_d: torch.Tensor = ( loss_d_b_real + loss_d_b_fake + loss_d_a_real + loss_d_a_fake - ) * self.discriminator_weight + ) self.optimizer_discriminator.zero_grad() loss_d.backward() From d8a3796470d43395a498f4e7b51358258f93e4c5 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 11:45:56 -0700 Subject: [PATCH 28/55] update n_convolutions in test to reflect new api behavior --- external/fv3fit/tests/training/test_autoencoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index 7f80e1819b..12d218308c 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -83,7 +83,7 @@ def test_autoencoder(tmpdir): hyperparameters = AutoencoderHyperparameters( state_variables=state_variables, generator=fv3fit.pytorch.GeneratorConfig( - 
n_convolutions=2, n_resnet=3, max_filters=32 + n_convolutions=1, n_resnet=3, max_filters=32 ), training_loop=TrainingConfig(n_epoch=5, samples_per_batch=10), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), @@ -138,7 +138,7 @@ def test_autoencoder_overfit(tmpdir): hyperparameters = AutoencoderHyperparameters( state_variables=state_variables, generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, n_resnet=1, max_filters=32 + n_convolutions=1, n_resnet=1, max_filters=32 ), training_loop=TrainingConfig(n_epoch=100, samples_per_batch=6), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), From 8f18ccd2a651e6e261ea7b2de6c8db410cd9919f Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 12:10:55 -0700 Subject: [PATCH 29/55] improve documentation for cyclegan model and training --- .../fv3fit/pytorch/cyclegan/discriminator.py | 40 +++++++++ .../fv3fit/pytorch/cyclegan/generator.py | 90 +++++++++++++++++-- .../fv3fit/fv3fit/pytorch/cyclegan/modules.py | 49 +++++++--- .../fv3fit/pytorch/cyclegan/reloadable.py | 8 +- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 30 +++++-- .../fv3fit/tests/training/test_cyclegan.py | 36 +++++--- 6 files changed, 217 insertions(+), 36 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py index adf764b2e8..a43614ce53 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py @@ -13,6 +13,20 @@ @dataclasses.dataclass class DiscriminatorConfig: + """ + Configuration for a discriminator network. + + Follows the architecture of Zhu et al. 2017, https://arxiv.org/abs/1703.10593. + Uses a series of strided convolutions with leaky ReLU activations, followed + by two convolutional layers. 
+ + Args: + n_convolutions: number of strided convolutional layers before the + final convolutional output layer, must be at least 1 + kernel_size: size of convolutional kernels + max_filters: maximum number of filters in any convolutional layer, + equal to the number of filters in the final strided convolutional layer + """ n_convolutions: int = 3 kernel_size: int = 3 @@ -31,6 +45,10 @@ def build( class Discriminator(nn.Module): + + # analogous to NLayerDiscriminator at + # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + def __init__( self, in_channels: int, @@ -39,7 +57,20 @@ def __init__( max_filters: int, convolution: ConvolutionFactory = single_tile_convolution, ): + """ + Args: + in_channels: number of input channels + n_convolutions: number of strided convolutional layers before the + final convolutional output layers, must be at least 1 + kernel_size: size of convolutional kernels + max_filters: maximum number of filters in any convolutional layer, + equal to the number of filters in the final strided convolutional layer + and in the convolutional layer just before the output layer + convolution: factory for creating all convolutional layers + """ super(Discriminator, self).__init__() + if n_convolutions < 1: + raise ValueError("n_convolutions must be at least 1") # max_filters = min_filters * 2 ** (n_convolutions - 1), therefore min_filters = int(max_filters / 2 ** (n_convolutions - 1)) convs = [ @@ -54,6 +85,7 @@ def __init__( ), ) ] + # we've already defined the first strided convolutional layer, so start at 1 for i in range(1, n_convolutions): convs.append( ConvBlock( @@ -67,6 +99,7 @@ def __init__( ), ) ) + # final_conv isn't strided so it's not included in the n_convolutions count final_conv = ConvBlock( in_channels=max_filters, out_channels=max_filters, @@ -79,4 +112,11 @@ def __init__( self._sequential = nn.Sequential(*convs, final_conv, patch_output) def forward(self, inputs: torch.Tensor) -> torch.Tensor: + 
""" + Args: + inputs: tensor of shape (batch, in_channels, height, width) + + Returns: + tensor of shape (batch, 1, height, width) + """ return self._sequential(inputs) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py index a768d0d588..22671bed29 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py @@ -8,11 +8,32 @@ single_tile_convolution, relu_activation, ResnetBlock, + CurriedModuleFactory, ) @dataclasses.dataclass class GeneratorConfig: + """ + Configuration for a generator network. + + Follows the architecture of Zhu et al. 2017, https://arxiv.org/abs/1703.10593. + This network contains an initial convolutional layer with kernel size 7, + strided convolutions with stride of 2, multiple residual blocks, + fractionally strided convolutions with stride 1/2, followed by an output + convolutional layer with kernel size 7 to map to the output channels. 
+ + Attributes: + n_convolutions: number of strided convolutional layers after the initial + convolutional layer and before the residual blocks + n_resnet: number of residual blocks + kernel_size: size of convolutional kernels in the strided convolutions + and resnet blocks + max_filters: maximum number of filters in any convolutional layer, + equal to the number of filters in the final strided convolutional layer + and in the resnet blocks + """ + n_convolutions: int = 3 n_resnet: int = 3 kernel_size: int = 3 @@ -21,6 +42,12 @@ class GeneratorConfig: def build( self, channels: int, convolution: ConvolutionFactory = single_tile_convolution, ): + """ + Args: + channels: number of input channels + convolution: factory for creating all convolutional layers + used by the network + """ return Generator( channels=channels, n_convolutions=self.n_convolutions, @@ -41,12 +68,31 @@ def __init__( max_filters: int, convolution: ConvolutionFactory = single_tile_convolution, ): + """ + Args: + channels: number of input and output channels + n_convolutions: number of strided convolutional layers after the initial + convolutional layer and before the residual blocks + n_resnet: number of residual blocks + kernel_size: size of convolutional kernels in the strided convolutions + and resnet blocks + max_filters: maximum number of filters in any convolutional layer, + equal to the number of filters in the final strided convolutional layer + and in the resnet blocks + convolution: factory for creating all convolutional layers + used by the network + """ super(Generator, self).__init__() - def resnet(in_channels: int): + def resnet(in_channels: int, out_channels: int): + if in_channels != out_channels: + raise ValueError( + "resnet must have same number of output channels as " + "input channels, since the inputs are added to the outputs" + ) resnet_blocks = [ ResnetBlock( - n_filters=in_channels, + channels=in_channels, convolution_factory=curry(convolution)( kernel_size=3, 
padding="same" ), @@ -80,7 +126,7 @@ def up(in_channels: int, out_channels: int): activation_factory=relu_activation(), ) - min_filters = int(max_filters / 2 ** (n_convolutions - 1)) + min_filters = int(max_filters / 2 ** n_convolutions) self._first_conv = nn.Sequential( convolution( @@ -96,7 +142,7 @@ def up(in_channels: int, out_channels: int): down_factory=down, up_factory=up, bottom_factory=resnet, - depth=n_convolutions - 1, + depth=n_convolutions, in_channels=min_filters, ) @@ -108,6 +154,13 @@ def up(in_channels: int, out_channels: int): ) def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Args: + inputs: tensor of shape (batch, channels, height, width) + + Returns: + tensor of shape (batch, channels, height, width) + """ x = self._first_conv(inputs) x = self._encoder_decoder(x) outputs: torch.Tensor = self._out_conv(x) @@ -122,14 +175,30 @@ class SymmetricEncoderDecoder(nn.Module): """ def __init__( - self, down_factory, up_factory, bottom_factory, depth: int, in_channels: int, + self, + down_factory: CurriedModuleFactory, + up_factory: CurriedModuleFactory, + bottom_factory: CurriedModuleFactory, + depth: int, + in_channels: int, ): + """ + Args: + down_factory: factory for creating a downsample module which reduces + height and width by a factor of 2, such as strided convolution + up_factory: factory for creating an upsample module which doubles + height and width, such as fractionally strided convolution + bottom_factory: factory for creating the bottom module which keeps + height and width constant + """ super(SymmetricEncoderDecoder, self).__init__() lower_channels = 2 * in_channels self._down = down_factory(in_channels=in_channels, out_channels=lower_channels) self._up = up_factory(in_channels=lower_channels, out_channels=in_channels) if depth == 1: - self._lower = bottom_factory(in_channels=lower_channels) + self._lower = bottom_factory( + in_channels=lower_channels, out_channels=lower_channels + ) elif depth <= 0: raise 
ValueError(f"depth must be at least 1, got {depth}") else: @@ -141,7 +210,14 @@ def __init__( in_channels=lower_channels, ) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Args: + inputs: tensor of shape (batch, channels, height, width) + + Returns: + tensor of shape (batch, channels, height, width) + """ x = self._down(inputs) x = self._lower(x) x = self._up(x) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py index 49fab74d69..87ad751b2a 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py @@ -56,10 +56,10 @@ def __call__( ... -class CurriedConvolutionFactory(Protocol): - def __call__(self, in_channels: int, out_channels: int,) -> nn.Module: +class CurriedModuleFactory(Protocol): + def __call__(self, in_channels: int, out_channels: int) -> nn.Module: """ - Create a convolutional layer. + Create a torch module. Args: in_channels: number of input channels @@ -120,23 +120,38 @@ def single_tile_convolution( class ResnetBlock(nn.Module): + """ + Residual network block as defined in He et al. 2016, + https://arxiv.org/abs/1512.03385. + + Contains two convolutional layers with instance normalization, and an + activation function applied to the first layer's instance-normalized output. + The input to the block is added to the output of the final convolutional layer. 
+ """ + def __init__( self, - n_filters: int, - convolution_factory: CurriedConvolutionFactory, + channels: int, + convolution_factory: CurriedModuleFactory, activation_factory: Callable[[], nn.Module] = relu_activation(), ): + """ + Args: + channels: number of input channels and filters in the convolutional layers + convolution_factory: factory for creating convolutional layers + activation_factory: factory for creating activation layers + """ super(ResnetBlock, self).__init__() self.conv_block = nn.Sequential( ConvBlock( - in_channels=n_filters, - out_channels=n_filters, + in_channels=channels, + out_channels=channels, convolution_factory=convolution_factory, activation_factory=activation_factory, ), ConvBlock( - in_channels=n_filters, - out_channels=n_filters, + in_channels=channels, + out_channels=channels, convolution_factory=convolution_factory, activation_factory=no_activation, ), @@ -149,14 +164,28 @@ def forward(self, inputs): class ConvBlock(nn.Module): + """ + Module packaging a convolutional layer with instance normalization and activation. + """ + def __init__( self, in_channels: int, out_channels: int, - convolution_factory: CurriedConvolutionFactory, + convolution_factory: CurriedModuleFactory, activation_factory: Callable[[], nn.Module] = relu_activation(), ): + """ + Args: + in_channels: number of input channels + out_channels: number of output channels + convolution_factory: factory for creating convolutional layers + activation_factory: factory for creating activation layers + """ super(ConvBlock, self).__init__() + # it's helpful to package this code into a class so that we can e.g. see what + # happens when globally disabling InstanceNorm2d or switching to another type + # of normalization, while debugging. 
self.conv_block = nn.Sequential( convolution_factory(in_channels=in_channels, out_channels=out_channels), nn.InstanceNorm2d(out_channels), diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py index 369d1ef612..478d6f9d4a 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py @@ -15,6 +15,12 @@ class CycleGANModule(torch.nn.Module): + """ + Torch module containing the components of a CycleGAN. + """ + + # we package this in this way so we can easily transform the model + # to different devices, and save/load the model as one module def __init__( self, generator_a_to_b: Generator, @@ -154,7 +160,7 @@ def predict(self, X: xr.Dataset, reverse: bool = False) -> xr.Dataset: ) with torch.no_grad(): if reverse: - outputs = self.generator_b_to_a(reshaped_tensor) + outputs: torch.Tensor = self.generator_b_to_a(reshaped_tensor) else: outputs = self.generator_a_to_b(reshaped_tensor) outputs = outputs.reshape(tensor.shape) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index 8c0f3b5e1f..ea10f097fe 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -1,7 +1,6 @@ from fv3fit._shared.hyperparameters import Hyperparameters import dataclasses import tensorflow as tf -from fv3fit.pytorch.loss import LossConfig import torch from fv3fit.pytorch.system import DEVICE import tensorflow_datasets as tfds @@ -32,16 +31,25 @@ @dataclasses.dataclass class CycleGANHyperparameters(Hyperparameters): + """ + Hyperparameters for CycleGAN training. 
+ + Attributes: + state_variables: list of variables to be transformed by the model + normalization_fit_samples: number of samples to use when fitting the + normalization + network: configuration for the CycleGAN network + training: configuration for the CycleGAN training + """ state_variables: List[str] normalization_fit_samples: int = 50_000 network: "CycleGANNetworkConfig" = dataclasses.field( default_factory=lambda: CycleGANNetworkConfig() ) - training_loop: "CycleGANTrainingConfig" = dataclasses.field( + training: "CycleGANTrainingConfig" = dataclasses.field( default_factory=lambda: CycleGANTrainingConfig() ) - loss: LossConfig = LossConfig(loss_type="mse") @property def variables(self): @@ -50,6 +58,15 @@ def variables(self): @dataclasses.dataclass class CycleGANTrainingConfig: + """ + Attributes: + n_epoch: number of epochs to train for + shuffle_buffer_size: number of samples to use for shuffling the training data + samples_per_batch: number of samples to use per batch + validation_batch_size: number of samples to use per batch for validation, + does not affect training result but allows the use of out-of-sample + validation data + """ n_epoch: int = 20 shuffle_buffer_size: int = 10 @@ -178,17 +195,20 @@ def train_cyclegan( train_model = hyperparameters.network.build( n_state=next(iter(train_state))[0].shape[-1], - n_batch=hyperparameters.training_loop.samples_per_batch, + n_batch=hyperparameters.training.samples_per_batch, state_variables=hyperparameters.state_variables, scalers=scalers, ) # remove time and tile dimensions, while we're using regular convolution + # MPS backend has a bug where it doesn't properly read striding information when + # doing 2d convolutions, so we need to use a channels-first data layout + # from the get-go and do transformations before and after while in numpy/tf space. 
train_state = train_state.unbatch().map(apply_to_tuple(channels_first)).unbatch() if validation_batches is not None: val_state = val_state.unbatch().map(apply_to_tuple(channels_first)).unbatch() - hyperparameters.training_loop.fit_loop( + hyperparameters.training.fit_loop( train_model=train_model, train_data=train_state, validation_data=val_state, ) return train_model.cycle_gan diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index e2ac0e2bab..9cc8a5af90 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -27,7 +27,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): ntime=ntime, nx=nx, nz=nz, - scalar_names=["b"], + scalar_names=["var_2d"], scale_min=0.5, scale_max=1.0, period_min=8, @@ -40,7 +40,7 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): ntime=ntime, nx=nx, nz=nz, - scalar_names=["b"], + scalar_names=["var_2d"], scale_min=0.5, scale_max=1.0, period_min=8, @@ -49,7 +49,9 @@ def get_tfdataset(nsamples, nbatch, ntime, nx, nz): ), ] ) - dataset = config.open_tfdataset(local_download_path=None, variable_names=["a", "b"]) + dataset = config.open_tfdataset( + local_download_path=None, variable_names=["var_3d", "var_2d"] + ) return dataset @@ -62,6 +64,7 @@ def get_noise_tfdataset(nsamples, nbatch, ntime, nx, nz): ntime=ntime, nx=nx, nz=nz, + scalar_names=["var_2d"], noise_amplitude=1.0, ), SyntheticNoise( @@ -70,11 +73,14 @@ def get_noise_tfdataset(nsamples, nbatch, ntime, nx, nz): ntime=ntime, nx=nx, nz=nz, + scalar_names=["var_2d"], noise_amplitude=1.0, ), ] ) - dataset = config.open_tfdataset(local_download_path=None, variable_names=["a", "b"]) + dataset = config.open_tfdataset( + local_download_path=None, variable_names=["var_3d", "var_2d"] + ) return dataset @@ -109,14 +115,14 @@ def test_cyclegan(tmpdir): # on whether we can autoencode sin waves, and need to resolve full cycles nx = 32 sizes = {"nbatch": 1, "ntime": 1, 
"nx": nx, "nz": 2} - state_variables = ["a", "b"] + state_variables = ["var_3d", "var_2d"] train_tfdataset = get_tfdataset(nsamples=200, **sizes) val_tfdataset = get_tfdataset(nsamples=20, **sizes) hyperparameters = CycleGANHyperparameters( state_variables=state_variables, network=CycleGANNetworkConfig( generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=3, n_resnet=5, max_filters=128, kernel_size=3 + n_convolutions=2, n_resnet=5, max_filters=128, kernel_size=3 ), generator_optimizer=fv3fit.pytorch.OptimizerConfig( name="Adam", kwargs={"lr": 0.001} @@ -130,7 +136,7 @@ def test_cyclegan(tmpdir): # gan_weight=1.0, discriminator_weight=0.5, ), - training_loop=CycleGANTrainingConfig( + training=CycleGANTrainingConfig( n_epoch=30, samples_per_batch=20, validation_batch_size=10 ), ) @@ -151,12 +157,16 @@ def test_cyclegan(tmpdir): fig, ax = plt.subplots(3, 2, figsize=(8, 8)) vmin = -1.5 vmax = 1.5 - ax[0, 0].imshow(real_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[0, 1].imshow(real_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[1, 0].imshow(output_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[1, 1].imshow(output_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[2, 0].imshow(reconstructed_a["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) - ax[2, 1].imshow(reconstructed_b["a"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[0, 0].imshow(real_a["var_3d"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[0, 1].imshow(real_b["var_3d"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 0].imshow(output_b["var_3d"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[1, 1].imshow(output_a["var_3d"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax) + ax[2, 0].imshow( + reconstructed_a["var_3d"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax + ) + ax[2, 1].imshow( + reconstructed_b["var_3d"][0, i, :, :, iz].values, vmin=vmin, vmax=vmax + ) ax[0, 0].set_title("real a") ax[0, 1].set_title("real b") ax[1, 0].set_title("output b") 
From c0c5d9065e725432d876a99c87adb701b7b8d909 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 13:54:32 -0700 Subject: [PATCH 30/55] update shapes to their expected final form at each scope, document them --- external/fv3fit/fv3fit/data/synthetic.py | 4 +- external/fv3fit/fv3fit/data/tfdataset.py | 6 ++- .../pytorch/cyclegan/cyclegan_trainer.py | 35 +++++++++++-- .../fv3fit/pytorch/cyclegan/discriminator.py | 4 +- .../fv3fit/pytorch/cyclegan/generator.py | 8 +-- .../fv3fit/fv3fit/pytorch/cyclegan/modules.py | 51 ++++++++++++++++--- .../fv3fit/pytorch/cyclegan/reloadable.py | 10 ++-- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 31 +++++++---- external/fv3fit/fv3fit/pytorch/system.py | 18 ++++--- 9 files changed, 123 insertions(+), 44 deletions(-) diff --git a/external/fv3fit/fv3fit/data/synthetic.py b/external/fv3fit/fv3fit/data/synthetic.py index edcc7d61e4..e13ed509af 100644 --- a/external/fv3fit/fv3fit/data/synthetic.py +++ b/external/fv3fit/fv3fit/data/synthetic.py @@ -94,8 +94,8 @@ def open_tfdataset( variable_names: names of variables to include when loading data Returns: dataset containing requested variables, each record is a mapping from - variable name to variable value, and each value is a tensor whose - first dimension is the batch dimension + variable name to variable value, and each value is a tensor + of shape [nbatch, nsamples, ntime, nx, ny(, nz)] """ if self.wave_type == "sinusoidal": func = np.sin diff --git a/external/fv3fit/fv3fit/data/tfdataset.py b/external/fv3fit/fv3fit/data/tfdataset.py index 69712c0d66..f64bbee946 100644 --- a/external/fv3fit/fv3fit/data/tfdataset.py +++ b/external/fv3fit/fv3fit/data/tfdataset.py @@ -73,8 +73,10 @@ def open_tfdataset( ) -> tf.data.Dataset: datasets = [] for config in self.domain_configs: - datasets.append(config.open_tfdataset(local_download_path, variable_names)) - return tf.data.Dataset.zip(tuple(datasets)) + datasets.append( + config.open_tfdataset(local_download_path, 
variable_names).unbatch() + ) + return tf.data.Dataset.zip(tuple(datasets)).batch(batch_size=self.batch_size) @classmethod def from_dict(cls, d: dict) -> "CycleGANLoader": diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index 37f9e0663a..e07bb85948 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -257,6 +257,14 @@ def evaluate_on_dataset( real_a: np.ndarray real_b: np.ndarray for real_a, real_b in dataset: + # for now there is no time-evolution-based loss, so we fold the time + # dimension into the sample dimension + real_a = real_a.reshape( + [real_a.shape[0] * real_a.shape[1]] + list(real_a.shape[2:]) + ) + real_b = real_b.reshape( + [real_b.shape[0] * real_b.shape[1]] + list(real_b.shape[2:]) + ) stats_real_a.observe(real_a) stats_real_b.observe(real_b) gen_b: torch.Tensor = self.generator_a_to_b( @@ -283,6 +291,24 @@ def evaluate_on_dataset( def train_on_batch( self, real_a: torch.Tensor, real_b: torch.Tensor ) -> Mapping[str, float]: + """ + Train the CycleGAN on a batch of data. 
+ + Args: + real_a: a batch of data from domain A, should have shape + [sample, time, tile, channel, y, x] + real_b: a batch of data from domain B, should have shape + [sample, time, tile, channel, y, x] + """ + # for now there is no time-evolution-based loss, so we fold the time + # dimension into the sample dimension + real_a = real_a.reshape( + [real_a.shape[0] * real_a.shape[1]] + list(real_a.shape[2:]) + ) + real_b = real_b.reshape( + [real_b.shape[0] * real_b.shape[1]] + list(real_b.shape[2:]) + ) + fake_b = self.generator_a_to_b(real_a) fake_a = self.generator_b_to_a(real_b) reconstructed_a = self.generator_b_to_a(fake_b) @@ -372,15 +398,14 @@ def train_on_batch( self.optimizer_discriminator.step() return { - # "gan_loss": float(loss_gan), "b_to_a_gan_loss": float(loss_gan_b_to_a), "a_to_b_gan_loss": float(loss_gan_a_to_b), "discriminator_a_loss": float(loss_d_a_fake + loss_d_a_real), "discriminator_b_loss": float(loss_d_b_fake + loss_d_b_real), - # "cycle_loss": float(loss_cycle), - # "identity_loss": float(loss_identity), - # "generator_loss": float(loss_g), - # "discriminator_loss": float(loss_d), + "cycle_loss": float(loss_cycle), + "identity_loss": float(loss_identity), + "generator_loss": float(loss_g), + "discriminator_loss": float(loss_d), "train_loss": float(loss_g + loss_d), } diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py index a43614ce53..3361382eaf 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py @@ -114,9 +114,9 @@ def __init__( def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Args: - inputs: tensor of shape (batch, in_channels, height, width) + inputs: tensor of shape (batch, tile, in_channels, height, width) Returns: - tensor of shape (batch, 1, height, width) + tensor of shape (batch, tile, 1, height, width) """ return self._sequential(inputs) diff --git 
a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py index 22671bed29..30947e9ab5 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py @@ -156,10 +156,10 @@ def up(in_channels: int, out_channels: int): def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Args: - inputs: tensor of shape (batch, channels, height, width) + inputs: tensor of shape [batch, tile, channels, x, y] Returns: - tensor of shape (batch, channels, height, width) + tensor of shape [batch, tile, channels, x, y] """ x = self._first_conv(inputs) x = self._encoder_decoder(x) @@ -213,10 +213,10 @@ def __init__( def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Args: - inputs: tensor of shape (batch, channels, height, width) + inputs: tensor of shape [batch, tile, channels, x, y] Returns: - tensor of shape (batch, channels, height, width) + tensor of shape [batch, tile, channels, x, y] """ x = self._down(inputs) x = self._lower(x) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py index 87ad751b2a..a0367f5317 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py @@ -2,6 +2,7 @@ from typing import Callable, Literal, Protocol, Union import torch.nn as nn +import torch logger = logging.getLogger(__name__) @@ -43,6 +44,8 @@ def __call__( """ Create a convolutional layer. + Layer takes in and returns tensors of shape [batch, tile, channels, x, y]. + Args: in_channels: number of input channels out_channels: number of output channels @@ -68,6 +71,23 @@ def __call__(self, in_channels: int, out_channels: int) -> nn.Module: ... 
+class FoldTileDimension(nn.Module): + """ + Module wrapping a module which takes [batch, channel, x, y] data into one + which takes [batch, tile, channel, x, y] data by folding the tile dimension + into the batch dimension. + """ + + def __init__(self, wrapped): + super(FoldTileDimension, self).__init__() + self._wrapped = wrapped + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + x = inputs.reshape(-1, *inputs.shape[2:]) + x = self._wrapped(x) + return x.reshape(inputs.shape[0], inputs.shape[1], *x.shape[1:]) + + def single_tile_convolution( in_channels: int, out_channels: int, @@ -81,6 +101,8 @@ def single_tile_convolution( """ Construct a convolutional layer for single tile data (like images). + Layer takes in and returns tensors of shape [batch, tile, channels, x, y]. + Args: kernel_size: size of the convolution kernel padding: padding to apply to the input, should be an integer or "same" @@ -90,7 +112,7 @@ def single_tile_convolution( bias: whether to include a bias vector in the produced layers """ if stride == 1: - return nn.Conv2d( + conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -99,7 +121,7 @@ def single_tile_convolution( ) elif stride_type == "regular": - return nn.Conv2d( + conv = nn.Conv2d( in_channels, out_channels, kernel_size, @@ -108,7 +130,7 @@ def single_tile_convolution( bias=bias, ) elif stride_type == "transpose": - return nn.ConvTranspose2d( + conv = nn.ConvTranspose2d( in_channels, out_channels, kernel_size, @@ -117,6 +139,9 @@ def single_tile_convolution( output_padding=output_padding, bias=bias, ) + else: + raise ValueError(f"Invalid stride_type: {stride_type}") + return FoldTileDimension(conv) class ResnetBlock(nn.Module): @@ -158,7 +183,14 @@ def __init__( ) self.identity = nn.Identity() - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Args: + inputs: tensor of shape [batch, tile, channels, x, y] + + Returns: + tensor of shape 
[batch, tile, channels, x, y] + """ g = self.conv_block(inputs) return g + self.identity(inputs) @@ -188,9 +220,16 @@ def __init__( # of normalization, while debugging. self.conv_block = nn.Sequential( convolution_factory(in_channels=in_channels, out_channels=out_channels), - nn.InstanceNorm2d(out_channels), + FoldTileDimension(nn.InstanceNorm2d(out_channels)), activation_factory(), ) - def forward(self, inputs): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Args: + inputs: tensor of shape [batch, tile, channels, x, y] + + Returns: + tensor of shape [batch, tile, channels, x, y] + """ return self.conv_block(inputs) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py index 478d6f9d4a..aee5d459ff 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py @@ -17,6 +17,9 @@ class CycleGANModule(torch.nn.Module): """ Torch module containing the components of a CycleGAN. + + All modules expect inputs and produce outputs of shape + (batch, tile, channels, x, y). 
""" # we package this in this way so we can easily transform the model @@ -155,14 +158,11 @@ def predict(self, X: xr.Dataset, reverse: bool = False) -> xr.Dataset: input_domain, output_domain = "a", "b" tensor = self.pack_to_tensor(X, domain=input_domain) - reshaped_tensor = tensor.reshape( - [tensor.shape[0] * tensor.shape[1]] + list(tensor.shape[2:]) - ) with torch.no_grad(): if reverse: - outputs: torch.Tensor = self.generator_b_to_a(reshaped_tensor) + outputs: torch.Tensor = self.generator_b_to_a(tensor) else: - outputs = self.generator_a_to_b(reshaped_tensor) + outputs = self.generator_a_to_b(tensor) outputs = outputs.reshape(tensor.shape) predicted = self.unpack_tensor(outputs, domain=output_domain) return predicted diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index ea10f097fe..6268378dad 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -81,13 +81,13 @@ def fit_loop( ) -> None: """ Args: - train_model: cycle-GAN to train + train_model: Cycle-GAN to train train_data: training dataset containing samples to be passed to the model, - should have dimensions [sample, time, tile, x, y, z] + should be unbatched and have dimensions [time, tile, z, x, y] validation_data: validation dataset containing samples to be passed - to the model, should have dimensions [sample, time, tile, x, y, z] + to the model, should be unbatched and have dimensions + [time, tile, z, x, y] """ - train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size).batch( self.samples_per_batch ) @@ -151,7 +151,8 @@ def Xy_map_fn(*data: Mapping[str, np.ndarray]): def channels_first(data: tf.Tensor) -> tf.Tensor: - return tf.transpose(data, perm=[0, 3, 1, 2]) + # [batch, time, tile, x, y, z] -> [batch, time, tile, z, x, y] + return tf.transpose(data, perm=[0, 1, 2, 5, 3, 4]) @register_training_function("cyclegan", 
CycleGANHyperparameters) @@ -182,16 +183,16 @@ def train_cyclegan( get_Xy = get_Xy_map_fn( state_variables=hyperparameters.state_variables, - n_dims=6, # [batch, time, tile, x, y, z] + n_dims=6, # [batch, sample, tile, x, y, z] mapping_scale_funcs=mapping_scale_funcs, ) if validation_batches is not None: - val_state = validation_batches.map(get_Xy).unbatch() + val_state = validation_batches.map(get_Xy) else: val_state = None - train_state = train_batches.map(get_Xy).unbatch() + train_state = train_batches.map(get_Xy) train_model = hyperparameters.network.build( n_state=next(iter(train_state))[0].shape[-1], @@ -200,13 +201,21 @@ def train_cyclegan( scalers=scalers, ) - # remove time and tile dimensions, while we're using regular convolution + # time and tile dimensions aren't being used yet while we're using single-tile + # convolution without a motion constraint, but they will be used in the future + # MPS backend has a bug where it doesn't properly read striding information when # doing 2d convolutions, so we need to use a channels-first data layout # from the get-go and do transformations before and after while in numpy/tf space. 
- train_state = train_state.unbatch().map(apply_to_tuple(channels_first)).unbatch() + train_state = train_state.map(apply_to_tuple(channels_first)) + if validation_batches is not None: + val_state = val_state.map(apply_to_tuple(channels_first)) + + # batching from the loader is undone here, so we can do our own batching + # in fit_loop + train_state = train_state.unbatch() if validation_batches is not None: - val_state = val_state.unbatch().map(apply_to_tuple(channels_first)).unbatch() + val_state = val_state.unbatch() hyperparameters.training.fit_loop( train_model=train_model, train_data=train_state, validation_data=val_state, diff --git a/external/fv3fit/fv3fit/pytorch/system.py b/external/fv3fit/fv3fit/pytorch/system.py index 7e03607b3e..54fc7b1635 100644 --- a/external/fv3fit/fv3fit/pytorch/system.py +++ b/external/fv3fit/fv3fit/pytorch/system.py @@ -1,10 +1,14 @@ import torch.backends import torch +import os -DEVICE = torch.device( - "cuda:0" - if torch.cuda.is_available() - else "mps" - if torch.backends.mps.is_available() - else "cpu" -) +if os.environ.get("TORCH_CPU_ONLY", False): + DEVICE = torch.device("cpu") +else: + DEVICE = torch.device( + "cuda:0" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" + ) From eb22eda3486e159c900bcc3f4241f5db5add4ab2 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 13:59:07 -0700 Subject: [PATCH 31/55] add test of StatsCollector --- external/fv3fit/tests/test_stats_collector.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 external/fv3fit/tests/test_stats_collector.py diff --git a/external/fv3fit/tests/test_stats_collector.py b/external/fv3fit/tests/test_stats_collector.py new file mode 100644 index 0000000000..8e5986c71e --- /dev/null +++ b/external/fv3fit/tests/test_stats_collector.py @@ -0,0 +1,12 @@ +from fv3fit.pytorch.cyclegan.cyclegan_trainer import StatsCollector +import numpy as np + + +def test_stats_collector(): + 
np.random.seed(0) + values = np.random.uniform(low=5.0, high=1.0, size=(100, 10)) + stats_collector = StatsCollector(n_dims_keep=1) + for i in range(values.shape[0]): + stats_collector.observe(values[i, :]) + assert np.allclose(stats_collector.mean, np.mean(values, axis=0)) + assert np.allclose(stats_collector.std, np.std(values, axis=0)) From e88b482711f8524239c381def35602c9b6ac08c7 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 14:15:17 -0700 Subject: [PATCH 32/55] remove unused output_scalers argument --- external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py | 2 +- external/fv3fit/fv3fit/pytorch/predict.py | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index 6268378dad..c46c4dcdef 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -119,7 +119,7 @@ def fit_loop( def apply_to_tuple_mapping(func): # not sure why, but tensorflow doesn't like parsing - # apply_to_tuple(apply_to_maping(func)), so we do it manually + # apply_to_tuple(apply_to_mapping(func)), so we do it manually def wrapped(*tuple_of_mapping): return tuple( {name: func(value) for name, value in mapping.items()} diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 1a5b0f3c94..439adbc3c8 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -9,7 +9,6 @@ Hashable, Iterable, Mapping, - Optional, Sequence, Tuple, TypeVar, @@ -69,7 +68,6 @@ def __init__( output_variables: Iterable[Hashable], model: nn.Module, scalers: Mapping[str, StandardScaler], - output_scalers: Optional[Mapping[str, StandardScaler]] = None, ): """Initialize the predictor Args: @@ -81,10 +79,6 @@ def __init__( self.output_variables = output_variables self.model = model self.scalers = 
scalers - if output_scalers is None: - self.output_scalers = output_scalers - else: - self.output_scalers = scalers def predict(self, X: xr.Dataset) -> xr.Dataset: """ From c8f8d1c45ba70754d60f65bde648ff6ca4a61631 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 14:18:09 -0700 Subject: [PATCH 33/55] add cyclegan symbols to fv3fit.pytorch namespace --- external/fv3fit/fv3fit/pytorch/__init__.py | 5 +++++ external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/external/fv3fit/fv3fit/pytorch/__init__.py b/external/fv3fit/fv3fit/pytorch/__init__.py index 883c1e5df0..2e3a34ca29 100644 --- a/external/fv3fit/fv3fit/pytorch/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/__init__.py @@ -11,6 +11,11 @@ AutoencoderHyperparameters, GeneratorConfig, DiscriminatorConfig, + CycleGANHyperparameters, + CycleGANTrainingConfig, + CycleGANNetworkConfig, + CycleGAN, + CycleGANModule, ) from .optimizer import OptimizerConfig from .activation import ActivationConfig diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py index bd8ca6d5f6..87862a8186 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/__init__.py @@ -7,4 +7,4 @@ from .discriminator import DiscriminatorConfig from .generator import GeneratorConfig from .cyclegan_trainer import CycleGANNetworkConfig -from .reloadable import CycleGAN +from .reloadable import CycleGAN, CycleGANModule From 233224c95b5d35b1182321f1bbeb8d2d89c53936 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 15:19:47 -0700 Subject: [PATCH 34/55] add missing public symbols to fv3fit.data --- external/fv3fit/fv3fit/data/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/external/fv3fit/fv3fit/data/__init__.py b/external/fv3fit/fv3fit/data/__init__.py index c9d742838e..91196b618c 100644 --- 
a/external/fv3fit/fv3fit/data/__init__.py +++ b/external/fv3fit/fv3fit/data/__init__.py @@ -1,3 +1,4 @@ from .base import TFDatasetLoader, tfdataset_loader_from_dict, register_tfdataset_loader from .batches import FromBatches -from .tfdataset import WindowedZarrLoader, VariableConfig +from .tfdataset import WindowedZarrLoader, VariableConfig, CycleGANLoader +from .synthetic import SyntheticNoise, SyntheticWaves From 188978acfb536fbbe16d5da4d9122471560537bd Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 15:23:03 -0700 Subject: [PATCH 35/55] further elaborate on the shape of data returned by synthetic loader --- external/fv3fit/fv3fit/data/synthetic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/external/fv3fit/fv3fit/data/synthetic.py b/external/fv3fit/fv3fit/data/synthetic.py index edcc7d61e4..e13ed509af 100644 --- a/external/fv3fit/fv3fit/data/synthetic.py +++ b/external/fv3fit/fv3fit/data/synthetic.py @@ -94,8 +94,8 @@ def open_tfdataset( variable_names: names of variables to include when loading data Returns: dataset containing requested variables, each record is a mapping from - variable name to variable value, and each value is a tensor whose - first dimension is the batch dimension + variable name to variable value, and each value is a tensor + of shape [nbatch, nsamples, ntime, nx, ny(, nz)] """ if self.wave_type == "sinusoidal": func = np.sin From 4221f55ed28fc905682e02946ff67509e266b74b Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 15:23:59 -0700 Subject: [PATCH 36/55] fix bug in CycleGANLoader where batch_size option is ignored --- external/fv3fit/fv3fit/data/tfdataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/external/fv3fit/fv3fit/data/tfdataset.py b/external/fv3fit/fv3fit/data/tfdataset.py index 69712c0d66..f64bbee946 100644 --- a/external/fv3fit/fv3fit/data/tfdataset.py +++ b/external/fv3fit/fv3fit/data/tfdataset.py @@ -73,8 +73,10 @@ def 
open_tfdataset( ) -> tf.data.Dataset: datasets = [] for config in self.domain_configs: - datasets.append(config.open_tfdataset(local_download_path, variable_names)) - return tf.data.Dataset.zip(tuple(datasets)) + datasets.append( + config.open_tfdataset(local_download_path, variable_names).unbatch() + ) + return tf.data.Dataset.zip(tuple(datasets)).batch(batch_size=self.batch_size) @classmethod def from_dict(cls, d: dict) -> "CycleGANLoader": From 01bc917a6747220be23788d5a8c4788930066568 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 15:36:18 -0700 Subject: [PATCH 37/55] refactor get_Xy_dataset into more composeable get_Xy_map_fn --- external/fv3fit/fv3fit/pytorch/cyclegan/train.py | 8 ++++---- external/fv3fit/fv3fit/pytorch/graph/train.py | 13 ++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index c63a403acd..065cf87308 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -16,7 +16,7 @@ cast, ) from .network import Generator -from fv3fit.pytorch.graph.train import get_Xy_dataset +from fv3fit.pytorch.graph.train import get_Xy_map_fn from fv3fit._shared.scaler import ( get_standard_scaler_mapping, get_mapping_standard_scale_func, @@ -94,18 +94,18 @@ def train_autoencoder( scalers = get_standard_scaler_mapping(sample_batch) mapping_scale_func = get_mapping_standard_scale_func(scalers) - get_state = curry(get_Xy_dataset)( + get_state = get_Xy_map_fn( state_variables=hyperparameters.state_variables, n_dims=6, # [batch, time, tile, x, y, z] mapping_scale_func=mapping_scale_func, ) if validation_batches is not None: - val_state = get_state(data=validation_batches) + val_state = validation_batches.map(get_state) else: val_state = None - train_state = get_state(data=train_batches) + train_state = train_batches.map(get_state) train_model = build_model( 
hyperparameters.generator, n_state=next(iter(train_state)).shape[-1] diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py b/external/fv3fit/fv3fit/pytorch/graph/train.py index 2b84db8adf..1b6acb6189 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -2,7 +2,6 @@ import numpy as np import dataclasses from fv3fit._shared.training_config import Hyperparameters -from toolz.functoolz import curry from fv3fit.pytorch.predict import PytorchAutoregressor from fv3fit.pytorch.graph.mpg_unet import MPGraphUNetConfig from fv3fit.pytorch.graph.unet import GraphUNetConfig @@ -89,18 +88,19 @@ def train_graph_model( scalers = get_standard_scaler_mapping(sample) mapping_scale_func = get_mapping_standard_scale_func(scalers) - get_state = curry(get_Xy_dataset)( + get_Xy = get_Xy_map_fn( state_variables=hyperparameters.state_variables, n_dims=6, # [batch, time, tile, x, y, z] mapping_scale_func=mapping_scale_func, ) if validation_batches is not None: - val_state = get_state(data=validation_batches).unbatch() + val_state = validation_batches.map(get_Xy).unbatch() else: val_state = None - train_state = get_state(data=train_batches).unbatch() + train_state = train_batches.map(get_Xy).unbatch() + sample = next(iter(train_state)) train_model = build_model( hyperparameters.graph_network, n_state=sample.shape[-1], nx=sample.shape[3], @@ -133,11 +133,10 @@ def build_model(graph_network, n_state: int, nx: int): ) -def get_Xy_dataset( +def get_Xy_map_fn( state_variables: Sequence[str], n_dims: int, mapping_scale_func: Callable[[Mapping[str, np.ndarray]], Mapping[str, np.ndarray]], - data: tf.data.Dataset, ): """ Given a tf.data.Dataset with mappings from variable name to samples @@ -165,4 +164,4 @@ def map_fn(data): data = tf.concat(data, axis=-1) return data - return data.map(map_fn) + return map_fn From 99e52905e2830fab6d6430ce95f6d800b6839ef4 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 15:42:22 
-0700 Subject: [PATCH 38/55] update typing on PytorchAutoregressor to use str keys for scalers --- .../fv3fit/fv3fit/pytorch/cyclegan/train.py | 6 +-- external/fv3fit/fv3fit/pytorch/graph/train.py | 5 +-- external/fv3fit/fv3fit/pytorch/predict.py | 40 +++++++++---------- 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py index 065cf87308..6096d1b4ec 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train.py @@ -8,19 +8,15 @@ from fv3fit._shared import register_training_function from typing import ( - Hashable, List, - Mapping, Optional, Tuple, - cast, ) from .network import Generator from fv3fit.pytorch.graph.train import get_Xy_map_fn from fv3fit._shared.scaler import ( get_standard_scaler_mapping, get_mapping_standard_scale_func, - StandardScaler, ) from toolz import curry import logging @@ -135,7 +131,7 @@ def train_autoencoder( input_variables=hyperparameters.state_variables, output_variables=hyperparameters.state_variables, model=train_model, - scalers=cast(Mapping[Hashable, StandardScaler], scalers), + scalers=scalers, ) return predictor diff --git a/external/fv3fit/fv3fit/pytorch/graph/train.py b/external/fv3fit/fv3fit/pytorch/graph/train.py index 1b6acb6189..2bf663e925 100644 --- a/external/fv3fit/fv3fit/pytorch/graph/train.py +++ b/external/fv3fit/fv3fit/pytorch/graph/train.py @@ -11,20 +11,17 @@ from fv3fit._shared.scaler import ( get_standard_scaler_mapping, get_mapping_standard_scale_func, - StandardScaler, ) from ..system import DEVICE from fv3fit._shared import register_training_function from typing import ( Callable, - Hashable, List, Optional, Sequence, Set, Mapping, - cast, Union, ) from fv3fit.tfdataset import select_keys, ensure_nd, apply_to_mapping @@ -117,7 +114,7 @@ def train_graph_model( predictor = PytorchAutoregressor( state_variables=hyperparameters.state_variables, 
model=train_model, - scalers=cast(Mapping[Hashable, StandardScaler], scalers), + scalers=scalers, ) return predictor diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index b315d6a1c2..b7e046899b 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -1,4 +1,4 @@ -from fv3fit._shared.predictor import Predictor, Reloadable +from fv3fit._shared.predictor import Reloadable, Predictor from .._shared.scaler import StandardScaler import numpy as np import torch @@ -37,7 +37,7 @@ def load(cls: Type[L], f: IO[bytes]) -> L: ... -def dump_mapping(mapping: Mapping[Hashable, StandardScaler], f: IO[bytes]) -> None: +def dump_mapping(mapping: Mapping[str, StandardScaler], f: IO[bytes]) -> None: """ Serialize a mapping to a zip file. """ @@ -47,7 +47,7 @@ def dump_mapping(mapping: Mapping[Hashable, StandardScaler], f: IO[bytes]) -> No value.dump(f_dump) -def load_mapping(cls: Type[L], f: IO[bytes]) -> Mapping[Hashable, L]: +def load_mapping(cls: Type[L], f: IO[bytes]) -> Mapping[str, L]: """ Load a mapping from a zip file. """ @@ -67,7 +67,7 @@ def __init__( input_variables: Iterable[Hashable], output_variables: Iterable[Hashable], model: nn.Module, - scalers: Mapping[Hashable, StandardScaler], + scalers: Mapping[str, StandardScaler], ): """Initialize the predictor Args: @@ -83,7 +83,6 @@ def __init__( def predict(self, X: xr.Dataset) -> xr.Dataset: """ Predict an output xarray dataset from an input xarray dataset. - Note that returned datasets include the initial state of the prediction, where by definition the model will have perfect skill. 
@@ -105,7 +104,7 @@ def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: packed = _pack_to_tensor( ds=X, timesteps=0, - state_variables=self.input_variables, + state_variables=tuple(str(item) for item in self.input_variables), scalers=self.scalers, ) # dimensions are [time, tile, x, y, z], @@ -118,7 +117,7 @@ def unpack_tensor(self, data: torch.Tensor) -> xr.Dataset: data = torch.reshape(data, (-1, 6) + tuple(data.shape[1:])) return _unpack_tensor( data, - varnames=self.output_variables, + varnames=tuple(str(item) for item in self.output_variables), scalers=self.scalers, dims=["time", "tile", "x", "y", "z"], ) @@ -147,9 +146,9 @@ class PytorchAutoregressor(Reloadable): def __init__( self, - state_variables: Iterable[Hashable], + state_variables: Iterable[str], model: nn.Module, - scalers: Mapping[Hashable, StandardScaler], + scalers: Mapping[str, StandardScaler], ): """Initialize the predictor Args: @@ -254,18 +253,15 @@ def get_config(self) -> Mapping[str, Any]: return {"state_variables": self.state_variables} -class PytorchDumpable(Protocol): +class _PytorchDumpable(Protocol): _MODEL_FILENAME: str _SCALERS_FILENAME: str _CONFIG_FILENAME: str - scalers: Mapping[Hashable, StandardScaler] + scalers: Mapping[str, StandardScaler] model: torch.nn.Module def __init__( - self, - model: torch.nn.Module, - scalers: Mapping[Hashable, StandardScaler], - **kwargs, + self, model: torch.nn.Module, scalers: Mapping[str, StandardScaler], **kwargs, ): ... @@ -279,7 +275,7 @@ def get_config(self) -> Mapping[str, Any]: ... 
-def _load_pytorch(cls: Type[PytorchDumpable], path: str): +def _load_pytorch(cls: Type[_PytorchDumpable], path: str): """Load a serialized model from a directory.""" fs = vcm.get_fs(path) model_filename = os.path.join(path, cls._MODEL_FILENAME) @@ -293,7 +289,7 @@ def _load_pytorch(cls: Type[PytorchDumpable], path: str): return obj -def _dump_pytorch(obj: PytorchDumpable, path: str) -> None: +def _dump_pytorch(obj: _PytorchDumpable, path: str) -> None: fs = vcm.get_fs(path) model_filename = os.path.join(path, obj._MODEL_FILENAME) with fs.open(model_filename, "wb") as f: @@ -307,8 +303,8 @@ def _dump_pytorch(obj: PytorchDumpable, path: str) -> None: def _pack_to_tensor( ds: xr.Dataset, timesteps: int, - state_variables: Iterable[Hashable], - scalers: Mapping[Hashable, StandardScaler], + state_variables: Iterable[str], + scalers: Mapping[str, StandardScaler], ) -> torch.Tensor: """ Packs the dataset into a tensor to be used by the pytorch model. @@ -365,8 +361,8 @@ def _pack_to_tensor( def _unpack_tensor( data: torch.Tensor, - varnames: Iterable[Hashable], - scalers: Mapping[Hashable, StandardScaler], + varnames: Iterable[str], + scalers: Mapping[str, StandardScaler], dims: Sequence[Hashable], ) -> xr.Dataset: i_feature = 0 @@ -382,7 +378,7 @@ def _unpack_tensor( else: n_features = 1 var_data = data[..., i_feature] - var_data = scalers[varname].denormalize(var_data) + var_data = scalers[varname].denormalize(var_data.to("cpu").numpy()) data_vars[varname] = xr.DataArray( data=var_data, dims=dims[: len(var_data.shape)] ) From 2b121eb92e99dc25eb09808cc8c470b921ac0bc4 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 15:43:18 -0700 Subject: [PATCH 39/55] add MPS support to fv3fit.pytorch.system for mac M1s --- external/fv3fit/fv3fit/pytorch/system.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/external/fv3fit/fv3fit/pytorch/system.py b/external/fv3fit/fv3fit/pytorch/system.py index fdc2c15b6e..54fc7b1635 100644 --- 
a/external/fv3fit/fv3fit/pytorch/system.py +++ b/external/fv3fit/fv3fit/pytorch/system.py @@ -1,3 +1,14 @@ +import torch.backends import torch +import os -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +if os.environ.get("TORCH_CPU_ONLY", False): + DEVICE = torch.device("cpu") +else: + DEVICE = torch.device( + "cuda:0" + if torch.cuda.is_available() + else "mps" + if torch.backends.mps.is_available() + else "cpu" + ) From 7c662a2eb8757afc1ab1e1ff035cfda871ba51ef Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 15:50:40 -0700 Subject: [PATCH 40/55] fix apparent bug in apply_to_tuple --- .../fv3fit/keras/_models/convolutional.py | 6 ++--- external/fv3fit/fv3fit/tfdataset.py | 10 ++++---- external/fv3fit/tests/test_tfdataset_ops.py | 23 +++++++++++++++++++ 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/external/fv3fit/fv3fit/keras/_models/convolutional.py b/external/fv3fit/fv3fit/keras/_models/convolutional.py index 1dd18398ef..d87658b45e 100644 --- a/external/fv3fit/fv3fit/keras/_models/convolutional.py +++ b/external/fv3fit/fv3fit/keras/_models/convolutional.py @@ -17,7 +17,7 @@ full_standard_normalized_input, ) from fv3fit import tfdataset -from fv3fit.tfdataset import select_keys, ensure_nd, apply_to_mapping, apply_to_tuple +from fv3fit.tfdataset import select_keys, ensure_nd, apply_to_mapping logger = logging.getLogger(__file__) @@ -88,8 +88,8 @@ def map_fn(data): # clipping of inputs happens within the keras model, we don't clip at the # data layer so that the model still takes full-sized inputs when used # in production - x = apply_to_tuple(append_halos_tensor(n_halo))( - select_keys(input_variables, data) + x = select_keys( + input_variables, apply_to_mapping(append_halos_tensor(n_halo))(data) ) y = select_keys(output_variables, clip_function(data)) return x, y diff --git a/external/fv3fit/fv3fit/tfdataset.py b/external/fv3fit/fv3fit/tfdataset.py index 9cf1b22d9c..35dd93e54c 100644 --- 
a/external/fv3fit/fv3fit/tfdataset.py +++ b/external/fv3fit/fv3fit/tfdataset.py @@ -27,11 +27,13 @@ def apply_to_mapping( return {name: tensor_func(tensor) for name, tensor in data.items()} -@curry def apply_to_tuple( - tensor_func: Callable[[T_in], T_out], data: Tuple[T_in, ...] -) -> Tuple[T_out, ...]: - return tuple(tensor_func(tensor) for tensor in data) + tensor_func: Callable[[T_in], T_out], +) -> Callable[[Tuple[T_in, ...]], Tuple[T_out, ...]]: + def wrapped(*data): + return tuple(tensor_func(tensor) for tensor in data) + + return wrapped def sequence_size(seq): diff --git a/external/fv3fit/tests/test_tfdataset_ops.py b/external/fv3fit/tests/test_tfdataset_ops.py index d1efe5087b..5e29baae0e 100644 --- a/external/fv3fit/tests/test_tfdataset_ops.py +++ b/external/fv3fit/tests/test_tfdataset_ops.py @@ -224,3 +224,26 @@ def transform(batch): result = next(tf_ds.as_numpy_iterator()) assert isinstance(result, dict) np.testing.assert_equal(result["a"], batches[0]["a"] * 2) + + +def test_tuple_map(): + """ + External package test demonstrating that for map operations on tuples + of functions, tuple entries are passed as independent arguments + and must be collected with *args. 
+ """ + + def generator(): + for entry in [(1, 1), (2, 2), (3, 3)]: + yield entry + + dataset = tf.data.Dataset.from_generator( + generator, output_types=(tf.int32, tf.int32) + ) + + def map_fn(x, y): + return x * 2, y * 3 + + mapped = dataset.map(map_fn) + out = list(mapped) + assert out == [(2, 3), (4, 6), (6, 9)] From c049b85d35cac084f0cd57bf9b8d48ee5af3c85c Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 16:45:33 -0700 Subject: [PATCH 41/55] add cyclegan to SPECIAL_TRAINING_TYPES --- external/fv3fit/tests/training/test_train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/external/fv3fit/tests/training/test_train.py b/external/fv3fit/tests/training/test_train.py index 65c84653c0..47dd72897a 100644 --- a/external/fv3fit/tests/training/test_train.py +++ b/external/fv3fit/tests/training/test_train.py @@ -43,6 +43,7 @@ "min_max_novelty_detector", "ocsvm_novelty_detector", "autoencoder", + "cyclegan", ] From 8a006e99c9dd99b13435da863b773fca13abfb64 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Wed, 7 Sep 2022 16:57:12 -0700 Subject: [PATCH 42/55] update autoencoder test and training function to channels first, fix test --- .../fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py | 4 ++-- external/fv3fit/fv3fit/pytorch/predict.py | 7 +------ external/fv3fit/tests/training/test_autoencoder.py | 2 +- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py index b6c6678cf5..5f932bd2ba 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py @@ -57,8 +57,8 @@ def define_noisy_input(data: tf.Tensor, stdev=0.1) -> Tuple[tf.Tensor, tf.Tensor def flatten_dims(dataset: tf.data.Dataset) -> tf.data.Dataset: - """Transform [batch, time, tile, x, y, z] to [sample, x, y, z]""" - return dataset.unbatch().unbatch().unbatch() + """Transform 
[batch, time, tile, x, y, z] to [sample, tile, x, y, z]""" + return dataset.unbatch().unbatch() @register_training_function("autoencoder", AutoencoderHyperparameters) diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 439adbc3c8..7d864d2526 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -106,16 +106,11 @@ def pack_to_tensor(self, X: xr.Dataset) -> torch.Tensor: scalers=self.scalers, ) # dimensions are [time, tile, x, y, z], - # we must combine [time, tile] into one sample dimension - reshaped = torch.reshape( - packed, (packed.shape[0] * packed.shape[1],) + tuple(packed.shape[2:]), - ) # torch expects channels before x, y so we have to transpose - transposed = reshaped.permute([0, 3, 1, 2]) + transposed = packed.permute([0, 1, 4, 2, 3]) return transposed def unpack_tensor(self, data: torch.Tensor) -> xr.Dataset: - data = torch.reshape(data, (-1, 6) + tuple(data.shape[1:])) return _unpack_tensor( data.permute([0, 1, 3, 4, 2]), # convert from channels (z) first to last varnames=tuple(str(item) for item in self.output_variables), diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index 12d218308c..73216dcce4 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -85,7 +85,7 @@ def test_autoencoder(tmpdir): generator=fv3fit.pytorch.GeneratorConfig( n_convolutions=1, n_resnet=3, max_filters=32 ), - training_loop=TrainingConfig(n_epoch=5, samples_per_batch=10), + training_loop=TrainingConfig(n_epoch=10, samples_per_batch=2), optimizer_config=fv3fit.pytorch.OptimizerConfig(name="Adam",), noise_amount=0.5, ) From fd02e1031237523131acb498efa401af7ebdd727 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 12 Sep 2022 09:57:25 -0700 Subject: [PATCH 43/55] updated training to closer match paper, model trained but predictor 
still misbehaving --- external/fv3fit/fv3fit/_shared/io.py | 10 +- .../pytorch/cyclegan/cyclegan_trainer.py | 113 +++++++++++++++--- .../fv3fit/pytorch/cyclegan/generator.py | 47 ++++++-- .../fv3fit/pytorch/cyclegan/reloadable.py | 12 +- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 30 +++-- external/fv3fit/fv3fit/pytorch/predict.py | 8 +- external/fv3fit/fv3fit/train.py | 5 +- projects/cyclegan/Makefile | 7 ++ projects/cyclegan/download_data.py | 13 ++ projects/cyclegan/evaluate.py | 35 ++++++ projects/cyclegan/train-data.yaml | 24 ++++ projects/cyclegan/training.yaml | 40 +++++++ 12 files changed, 299 insertions(+), 45 deletions(-) create mode 100644 projects/cyclegan/Makefile create mode 100644 projects/cyclegan/download_data.py create mode 100644 projects/cyclegan/evaluate.py create mode 100644 projects/cyclegan/train-data.yaml create mode 100644 projects/cyclegan/training.yaml diff --git a/external/fv3fit/fv3fit/_shared/io.py b/external/fv3fit/fv3fit/_shared/io.py index 487416afeb..742d026967 100644 --- a/external/fv3fit/fv3fit/_shared/io.py +++ b/external/fv3fit/fv3fit/_shared/io.py @@ -1,4 +1,4 @@ -from typing import MutableMapping, Callable, Type +from typing import MutableMapping, Callable, Type, TypeVar, cast import os import fsspec import warnings @@ -11,6 +11,8 @@ DEPCRECATED_NAMES = {"packed-keras": "007bc80046c29ae3e2a535689b5c68e46cf2c613"} +R = TypeVar("R", bound=Type[Reloadable]) + class _Register: """Class to register new I/O names @@ -19,15 +21,15 @@ class _Register: def __init__(self) -> None: self._model_types: MutableMapping[str, Type[Reloadable]] = {} - def __call__(self, name: str) -> Callable[[Type[Reloadable]], Type[Reloadable]]: + def __call__(self, name: str) -> Callable[[R], R]: if name in self._model_types: raise ValueError( f"{name} is already registered by {self._model_types[name]}." 
) else: - return partial(self._register_class, name=name) + return cast(Callable[[R], R], partial(self._register_class, name=name)) - def _register_class(self, cls: Type[Reloadable], name: str) -> Type[Reloadable]: + def _register_class(self, cls: R, name: str) -> R: self._model_types[name] = cls return cls diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index e07bb85948..fb90d5ccf7 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -1,5 +1,5 @@ import random -from typing import Dict, List, Mapping, Tuple, Optional +from typing import Dict, List, Literal, Mapping, Tuple, Optional import tensorflow as tf from fv3fit._shared.scaler import StandardScaler from .reloadable import CycleGAN, CycleGANModule @@ -12,6 +12,14 @@ from fv3fit.pytorch.system import DEVICE import itertools import numpy as np +import wandb +import io +import PIL + +try: + import matplotlib.pyplot as plt +except ModuleNotFoundError: + plt = None @dataclasses.dataclass @@ -61,10 +69,10 @@ class CycleGANNetworkConfig: discriminator_weight: float = 1.0 def build( - self, n_state: int, n_batch: int, state_variables, scalers + self, n_state: int, nx: int, ny: int, n_batch: int, state_variables, scalers ) -> "CycleGANTrainer": - generator_a_to_b = self.generator.build(n_state) - generator_b_to_a = self.generator.build(n_state) + generator_a_to_b = self.generator.build(n_state, nx=nx, ny=ny) + generator_b_to_a = self.generator.build(n_state, nx=nx, ny=ny) discriminator_a = self.discriminator.build(n_state) discriminator_b = self.discriminator.build(n_state) optimizer_generator = self.generator_optimizer.instance( @@ -75,14 +83,16 @@ def build( optimizer_discriminator = self.discriminator_optimizer.instance( itertools.chain(discriminator_a.parameters(), discriminator_b.parameters()) ) + model = CycleGANModule( + 
generator_a_to_b=generator_a_to_b, + generator_b_to_a=generator_b_to_a, + discriminator_a=discriminator_a, + discriminator_b=discriminator_b, + ).to(DEVICE) + # init_weights(model) return CycleGANTrainer( cycle_gan=CycleGAN( - model=CycleGANModule( - generator_a_to_b=generator_a_to_b, - generator_b_to_a=generator_b_to_a, - discriminator_a=discriminator_a, - discriminator_b=discriminator_b, - ).to(DEVICE), + model=model, state_variables=state_variables, scalers=_merge_scaler_mappings(scalers), ), @@ -99,6 +109,50 @@ def build( ) +def init_weights( + net: torch.nn.Module, + init_type: Literal["normal", "xavier", "kaiming", "orthogonal"] = "normal", + init_gain: float = 0.02, +): + """Initialize network weights. + + Args: + net: network to be initialized + init_type: the name of an initialization method + init_gain: scaling factor for normal, xavier and orthogonal. + + Note: We use 'normal' in the original pix2pix and CycleGAN paper. + But xavier and kaiming might work better for some applications. + Feel free to try yourself. + """ + + def init_func(module): # define the initialization function + classname = module.__class__.__name__ + if hasattr(module, "weight") and classname == "Conv2d": + if init_type == "normal": + torch.nn.init.normal_(module.weight.data, 0.0, init_gain) + elif init_type == "xavier": + torch.nn.init.xavier_normal_(module.weight.data, gain=init_gain) + elif init_type == "kaiming": + torch.nn.init.kaiming_normal_(module.weight.data, a=0, mode="fan_in") + elif init_type == "orthogonal": + torch.nn.init.orthogonal_(module.weight.data, gain=init_gain) + else: + raise NotImplementedError( + "initialization method [%s] is not implemented" % init_type + ) + if hasattr(module, "bias") and module.bias is not None: + torch.nn.init.constant_(module.bias.data, 0.0) + elif classname.find("BatchNorm2d") != -1: + # BatchNorm Layer's weight is not a matrix; + # only normal distribution applies. 
+ torch.nn.init.normal_(module.weight.data, 1.0, init_gain) + torch.nn.init.constant_(module.bias.data, 0.0) + + print("initialize network with %s" % init_type) + net.apply(init_func) # apply the initialization function + + def _merge_scaler_mappings( scaler_tuple: Tuple[Mapping[str, StandardScaler], Mapping[str, StandardScaler]] ) -> Mapping[str, StandardScaler]: @@ -256,6 +310,7 @@ def evaluate_on_dataset( stats_gen_b = StatsCollector(n_dims_keep) real_a: np.ndarray real_b: np.ndarray + reported_plot = False for real_a, real_b in dataset: # for now there is no time-evolution-based loss, so we fold the time # dimension into the sample dimension @@ -267,14 +322,40 @@ def evaluate_on_dataset( ) stats_real_a.observe(real_a) stats_real_b.observe(real_b) - gen_b: torch.Tensor = self.generator_a_to_b( + gen_b: np.ndarray = self.generator_a_to_b( torch.as_tensor(real_a).float().to(DEVICE) - ) - gen_a: torch.Tensor = self.generator_b_to_a( + ).detach().cpu().numpy() + gen_a: np.ndarray = self.generator_b_to_a( torch.as_tensor(real_b).float().to(DEVICE) - ) - stats_gen_a.observe(gen_a.detach().cpu().numpy()) - stats_gen_b.observe(gen_b.detach().cpu().numpy()) + ).detach().cpu().numpy() + stats_gen_a.observe(gen_a) + stats_gen_b.observe(gen_b) + if not reported_plot and plt is not None: + report = {} + for i_tile in range(6): + fig, ax = plt.subplots(2, 2, figsize=(8, 7)) + im = ax[0, 0].pcolormesh(real_a[0, i_tile, 0, :, :]) + plt.colorbar(im, ax=ax[0, 0]) + ax[0, 0].set_title("a_real") + im = ax[1, 0].pcolormesh(real_b[0, i_tile, 0, :, :]) + plt.colorbar(im, ax=ax[1, 0]) + ax[1, 0].set_title("b_real") + im = ax[0, 1].pcolormesh(gen_b[0, i_tile, 0, :, :]) + plt.colorbar(im, ax=ax[0, 1]) + ax[0, 1].set_title("b_gen") + im = ax[1, 1].pcolormesh(gen_a[0, i_tile, 0, :, :]) + plt.colorbar(im, ax=ax[1, 1]) + ax[1, 1].set_title("a_gen") + plt.tight_layout() + buf = io.BytesIO() + plt.savefig(buf, format="png") + plt.close() + buf.seek(0) + report[f"tile_{i_tile}_example"] = 
wandb.Image( + PIL.Image.open(buf), caption=f"Tile {i_tile} Example", + ) + wandb.log(report) + reported_plot = True metrics = { # "r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean), "r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean), diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py index 30947e9ab5..ec13b65838 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py @@ -5,6 +5,7 @@ from .modules import ( ConvBlock, ConvolutionFactory, + FoldTileDimension, single_tile_convolution, relu_activation, ResnetBlock, @@ -40,7 +41,11 @@ class GeneratorConfig: max_filters: int = 256 def build( - self, channels: int, convolution: ConvolutionFactory = single_tile_convolution, + self, + channels: int, + nx: int, + ny: int, + convolution: ConvolutionFactory = single_tile_convolution, ): """ Args: @@ -55,13 +60,26 @@ def build( kernel_size=self.kernel_size, max_filters=self.max_filters, convolution=convolution, + nx=nx, + ny=ny, ) +class GeographicBias(nn.Module): + def __init__(self, channels: int, nx: int, ny: int): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channels, nx, ny)) + + def forward(self, x): + return x + self.bias + + class Generator(nn.Module): def __init__( self, channels: int, + nx: int, + ny: int, n_convolutions: int, n_resnet: int, kernel_size: int, @@ -129,12 +147,14 @@ def up(in_channels: int, out_channels: int): min_filters = int(max_filters / 2 ** n_convolutions) self._first_conv = nn.Sequential( + FoldTileDimension(nn.ReflectionPad2d(3)), convolution( kernel_size=7, in_channels=channels, out_channels=min_filters, - padding="same", + padding=0, ), + FoldTileDimension(nn.InstanceNorm2d(min_filters)), relu_activation()(), ) @@ -146,12 +166,17 @@ def up(in_channels: int, out_channels: int): in_channels=min_filters, ) - self._out_conv = convolution( - kernel_size=7, - 
in_channels=min_filters, - out_channels=channels, - padding="same", + self._out_conv = nn.Sequential( + FoldTileDimension(nn.ReflectionPad2d(3)), + convolution( + kernel_size=7, + in_channels=min_filters, + out_channels=channels, + padding=0, + ), ) + self._input_bias = GeographicBias(channels=channels, nx=nx, ny=ny) + self._output_bias = GeographicBias(channels=channels, nx=nx, ny=ny) def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ @@ -161,10 +186,12 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: Returns: tensor of shape [batch, tile, channels, x, y] """ - x = self._first_conv(inputs) + x = self._input_bias(inputs) + x = self._first_conv(x) x = self._encoder_decoder(x) - outputs: torch.Tensor = self._out_conv(x) - return outputs + x = self._out_conv(x) + x = self._output_bias(x) + return x class SymmetricEncoderDecoder(nn.Module): diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py index aee5d459ff..bc9a18b94f 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py @@ -85,11 +85,15 @@ def load(cls, path: str) -> "CycleGAN": """Load a serialized model from a directory.""" return _load_pytorch(cls, path) + def to(self, device) -> "CycleGAN": + model = self.model.to(device) + return CycleGAN(model, self.scalers, **self.get_config()) + def dump(self, path: str) -> None: _dump_pytorch(self, path) def get_config(self): - return {} + return {"state_variables": self.state_variables} def pack_to_tensor(self, ds: xr.Dataset, domain: str = "a") -> torch.Tensor: """ @@ -116,7 +120,11 @@ def pack_to_tensor(self, ds: xr.Dataset, domain: str = "a") -> torch.Tensor: tensor = _pack_to_tensor( ds=ds, timesteps=0, state_variables=self.state_variables, scalers=scalers, ) - return tensor.permute([0, 1, 4, 2, 3]) + # TODO: this permute order is needed, but it does not seem like it should be. 
+ # when we replace the model with a linear one, the output only matches the + # input if we flip the x and y dimension. + # investigate why this is necessary + return tensor.permute([0, 1, 4, 3, 2]) def unpack_tensor(self, data: torch.Tensor, domain: str = "b") -> xr.Dataset: """ diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index c46c4dcdef..09e1c10d4b 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -1,3 +1,4 @@ +import random from fv3fit._shared.hyperparameters import Hyperparameters import dataclasses import tensorflow as tf @@ -88,10 +89,16 @@ def fit_loop( to the model, should be unbatched and have dimensions [time, tile, z, x, y] """ - train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size).batch( - self.samples_per_batch - ) - train_data = tfds.as_numpy(train_data) + if self.shuffle_buffer_size > 1: + train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size) + train_data = train_data.batch(self.samples_per_batch) + train_data_numpy = tfds.as_numpy(train_data) + train_states = [] + for batch_state in train_data_numpy: + state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) + state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) + train_states.append((state_a, state_b)) + train_example_as_dataset = tfds.as_numpy(train_data.take(1).cache()) if validation_data is not None: if self.validation_batch_size is None: validation_batch_size = sequence_size(validation_data) @@ -102,10 +109,12 @@ def fit_loop( for i in range(1, self.n_epoch + 1): logger.info("starting epoch %d", i) train_losses = [] - for batch_state in train_data: - state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) - state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) + # for batch_state in train_data_numpy: + # state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) 
+ # state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) + for state_a, state_b in train_states: train_losses.append(train_model.train_on_batch(state_a, state_b)) + random.shuffle(train_states) train_loss = { name: np.mean([data[name] for data in train_losses]) for name in train_losses[0] @@ -115,6 +124,8 @@ def fit_loop( if validation_data is not None: val_loss = train_model.evaluate_on_dataset(validation_data) logger.info("val_loss %s", val_loss) + else: + train_model.evaluate_on_dataset(train_example_as_dataset) def apply_to_tuple_mapping(func): @@ -194,8 +205,11 @@ def train_cyclegan( train_state = train_batches.map(get_Xy) + sample: tf.Tensor = next(iter(train_state))[0] train_model = hyperparameters.network.build( - n_state=next(iter(train_state))[0].shape[-1], + nx=sample.shape[-3], + ny=sample.shape[-2], + n_state=sample.shape[-1], n_batch=hyperparameters.training.samples_per_batch, state_variables=hyperparameters.state_variables, scalers=scalers, diff --git a/external/fv3fit/fv3fit/pytorch/predict.py b/external/fv3fit/fv3fit/pytorch/predict.py index 7d864d2526..f363b3ff34 100644 --- a/external/fv3fit/fv3fit/pytorch/predict.py +++ b/external/fv3fit/fv3fit/pytorch/predict.py @@ -276,7 +276,7 @@ def _load_pytorch(cls: Type[_PytorchDumpable], path: str): fs = vcm.get_fs(path) model_filename = os.path.join(path, cls._MODEL_FILENAME) with fs.open(model_filename, "rb") as f: - model = torch.load(f) + model = torch.load(f).to(DEVICE) with fs.open(os.path.join(path, cls._SCALERS_FILENAME), "rb") as f: scalers = load_mapping(StandardScaler, f) with open(os.path.join(path, cls._CONFIG_FILENAME), "r") as f: @@ -319,8 +319,10 @@ def _pack_to_tensor( tensor of shape [window, time, tile, x, y, feature] """ - expected_dims = ("time", "tile", "x", "y", "z") - ds = ds.transpose(*expected_dims) + expected_dims: Tuple[str, ...] 
= ("time", "tile", "x", "y") + if "z" in ds.dims: + expected_dims += ("z",) + ds = ds.transpose(..., *expected_dims) if timesteps > 0: n_times = ds.time.size n_windows = int((n_times - 1) // timesteps) diff --git a/external/fv3fit/fv3fit/train.py b/external/fv3fit/fv3fit/train.py index 11cc193b67..cb26b5d044 100644 --- a/external/fv3fit/fv3fit/train.py +++ b/external/fv3fit/fv3fit/train.py @@ -175,12 +175,13 @@ def main(args, unknown_args=None): if __name__ == "__main__": - logger.setLevel(logging.INFO) + LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper() + logger.setLevel(LOGLEVEL) parser = get_parser() args, unknown_args = parser.parse_known_args() os.makedirs("artifacts", exist_ok=True) logging.basicConfig( - level=logging.INFO, + level=LOGLEVEL, format="%(asctime)s [%(levelname)s] %(filename)s::L%(lineno)d : %(message)s", handlers=[ logging.FileHandler("artifacts/training.log"), diff --git a/projects/cyclegan/Makefile b/projects/cyclegan/Makefile new file mode 100644 index 0000000000..ba275a5de4 --- /dev/null +++ b/projects/cyclegan/Makefile @@ -0,0 +1,7 @@ + + +train: + python3 -m fv3fit.train training.yaml train-data.yaml output + +data: + python3 download_data.py diff --git a/projects/cyclegan/download_data.py b/projects/cyclegan/download_data.py new file mode 100644 index 0000000000..e691788f77 --- /dev/null +++ b/projects/cyclegan/download_data.py @@ -0,0 +1,13 @@ +import xarray as xr + +if __name__ == "__main__": + c48 = xr.open_zarr( + "gs://vcm-ml-experiments/spencerc/2021-05-24/n2f-25km-baseline-unperturbed/fv3gfs_run/atmos_dt_atmos.zarr" # noqa: E501 + ) + c48 = c48.drop_vars([name for name in c48.data_vars if name != "h500"]) + c384 = xr.open_zarr( + "gs://vcm-ml-raw-flexible-retention/2021-01-04-1-year-C384-FV3GFS-simulations/unperturbed/C384-to-C48-diagnostics/atmos_8xdaily_coarse.zarr" # noqa: E501 + ) + c384 = c384.drop_vars([name for name in c384.data_vars if name != "h500"]) + c48.to_zarr("c48_baseline.zarr") + 
c384.to_zarr("c384_baseline.zarr") diff --git a/projects/cyclegan/evaluate.py b/projects/cyclegan/evaluate.py new file mode 100644 index 0000000000..4439642d0f --- /dev/null +++ b/projects/cyclegan/evaluate.py @@ -0,0 +1,35 @@ +import fv3fit +from matplotlib import pyplot as plt +import xarray as xr + +if __name__ == "__main__": + cyclegan: fv3fit.pytorch.CycleGAN = fv3fit.load("output").to("cpu") + c48_real = ( + xr.open_zarr("c48_baseline.zarr") + .rename({"grid_xt": "x", "grid_yt": "y"}) + .isel(time=range(0, 100, 10)) + ) + c384_real = ( + xr.open_zarr("c384_baseline.zarr") + .rename({"grid_xt": "x", "grid_yt": "y"}) + .isel(time=range(0, 100, 10)) + ) + c48_gen = cyclegan.predict(c384_real) + c384_gen = cyclegan.predict(c48_real, reverse=True) + i_tile = 3 + for i_tile in range(1): + for i in range(1): + import pdb + + pdb.set_trace() + fig, ax = plt.subplots(2, 2, figsize=(10, 8)) + c48_real.h500.isel(time=i, tile=i_tile).plot(ax=ax[0, 0]) + ax[0, 0].set_title("c48_real") + c384_real.h500.isel(time=i, tile=i_tile).plot(ax=ax[1, 0]) + ax[1, 0].set_title("c384_real") + c384_gen.h500.isel(time=i, tile=i_tile).plot(ax=ax[0, 1]) + ax[0, 1].set_title("c384_gen") + c48_gen.h500.isel(time=i, tile=i_tile).plot(ax=ax[1, 1]) + ax[1, 1].set_title("c48_gen") + plt.tight_layout() + plt.show() diff --git a/projects/cyclegan/train-data.yaml b/projects/cyclegan/train-data.yaml new file mode 100644 index 0000000000..871d557d33 --- /dev/null +++ b/projects/cyclegan/train-data.yaml @@ -0,0 +1,24 @@ +batch_size: 500 +domain_configs: + - data_path: c48_baseline.zarr + unstacked_dims: + - time + - tile + - grid_xt + - grid_yt + - z + window_size: 1 + default_variable_config: + times: window + n_windows: 500 # 5808 + - data_path: c384_baseline.zarr + unstacked_dims: + - time + - tile + - grid_xt + - grid_yt + - z + window_size: 1 + default_variable_config: + times: window + n_windows: 500 # 5808 diff --git a/projects/cyclegan/training.yaml b/projects/cyclegan/training.yaml new file 
mode 100644 index 0000000000..1cc577a4e1 --- /dev/null +++ b/projects/cyclegan/training.yaml @@ -0,0 +1,40 @@ +model_type: cyclegan +cache: + in_memory: true +hyperparameters: + state_variables: + - h500 + normalization_fit_samples: 50_000 + network: + generator_optimizer: + name: Adam + kwargs: + lr: 0.0002 + discriminator_optimizer: + name: Adam + kwargs: + lr: 0.0002 + generator: + n_convolutions: 2 + n_resnet: 6 + kernel_size: 3 + max_filters: 256 + discriminator: + n_convolutions: 3 + kernel_size: 3 + max_filters: 256 + identity_loss: + loss_type: mae + cycle_loss: + loss_type: mae + gan_loss: + loss_type: mse + identity_weight: 0.5 + cycle_weight: 10.0 + generator_weight: 1.0 + discriminator_weight: 1.0 + training: + n_epoch: 100 + shuffle_buffer_size: 1000 + samples_per_batch: 1 + validation_batch_size: 100 From 37d7230e66e12ebf1721d2c34f0e9e91e94b4ca5 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 12 Sep 2022 10:30:02 -0700 Subject: [PATCH 44/55] switch ReplayBuffer for ImagePool implementation by authors --- .../pytorch/cyclegan/cyclegan_trainer.py | 42 ++----------- .../fv3fit/pytorch/cyclegan/image_pool.py | 59 +++++++++++++++++++ 2 files changed, 63 insertions(+), 38 deletions(-) create mode 100644 external/fv3fit/fv3fit/pytorch/cyclegan/image_pool.py diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index e07bb85948..36a9a27d40 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -1,4 +1,3 @@ -import random from typing import Dict, List, Mapping, Tuple, Optional import tensorflow as tf from fv3fit._shared.scaler import StandardScaler @@ -11,6 +10,7 @@ from fv3fit.pytorch.optimizer import OptimizerConfig from fv3fit.pytorch.system import DEVICE import itertools +from .image_pool import ImagePool import numpy as np @@ -109,41 +109,6 @@ def _merge_scaler_mappings( return 
scalers -class ReplayBuffer: - - # To reduce model oscillation during training, we update the discriminator - # using a history of generated data instead of the most recently generated data - # according to Shrivastava et al. (2017). - - def __init__(self, max_size=50): - if max_size <= 0: - raise ValueError("max_size must be positive") - self.max_size = max_size - self.data = [] - - def push_and_pop(self, data: torch.Tensor) -> torch.autograd.Variable: - """ - Push data into the buffer and return a random sample of the buffer. - - If there are at least max_size elements in the buffer, the returned sample - is removed from the buffer. - """ - to_return = [] - for element in data.data: - element = torch.unsqueeze(element, 0) - if len(self.data) < self.max_size: - self.data.append(element) - to_return.append(element) - else: - if random.uniform(0, 1) > 0.5: - i = random.randint(0, self.max_size - 1) - to_return.append(self.data[i].clone()) - self.data[i] = element - else: - to_return.append(element) - return torch.autograd.Variable(torch.cat(to_return)) - - class StatsCollector: """ Object to track the mean and standard deviation of sampled arrays. @@ -232,8 +197,9 @@ class CycleGANTrainer: def __post_init__(self): self.target_real: Optional[torch.autograd.Variable] = None self.target_fake: Optional[torch.autograd.Variable] = None - self.fake_a_buffer = ReplayBuffer() - self.fake_b_buffer = ReplayBuffer() + # image pool size of 50 used by Zhu et al. 
(2017) + self.fake_a_buffer = ImagePool(50) + self.fake_b_buffer = ImagePool(50) self.generator_a_to_b = self.cycle_gan.generator_a_to_b self.generator_b_to_a = self.cycle_gan.generator_b_to_a self.discriminator_a = self.cycle_gan.discriminator_a diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/image_pool.py b/external/fv3fit/fv3fit/pytorch/cyclegan/image_pool.py new file mode 100644 index 0000000000..246d9db0c5 --- /dev/null +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/image_pool.py @@ -0,0 +1,59 @@ +# flake8: noqa +# Taken from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/eb6ae80412e23c09b4317b04d889f1af27526d2d/util/image_pool.py +# Copyright (c) 2017, Jun-Yan Zhu and Taesung Park under a BSD license + +import random +import torch + + +class ImagePool: + """This class implements an image buffer that stores previously generated images. + This buffer enables us to update discriminators using a history of generated images + rather than the ones produced by the latest generators. + """ + + def __init__(self, pool_size): + """Initialize the ImagePool class + Parameters: + pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created + """ + self.pool_size = pool_size + if self.pool_size > 0: # create an empty pool + self.num_imgs = 0 + self.images = [] + + def query(self, images): + """Return an image from the pool. + Parameters: + images: the latest generated images from the generator + Returns images from the buffer. + By 50/100, the buffer will return input images. + By 50/100, the buffer will return images previously stored in the buffer, + and insert the current images to the buffer. 
+ """ + if self.pool_size == 0: # if the buffer size is 0, do nothing + return images + return_images = [] + for image in images: + image = torch.unsqueeze(image.data, 0) + if ( + self.num_imgs < self.pool_size + ): # if the buffer is not full; keep inserting current images to the buffer + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if ( + p > 0.5 + ): # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer + random_id = random.randint( + 0, self.pool_size - 1 + ) # randint is inclusive + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: # by another 50% chance, the buffer will return the current image + return_images.append(image) + return_images = torch.cat(return_images, 0) # collect all the images and return + return return_images From 97b4ab65f76219dff7f78ebfe41356b93544cc21 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 12 Sep 2022 11:31:08 -0700 Subject: [PATCH 45/55] uncomment validation statistic outputs --- .../fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index 36a9a27d40..b272e80646 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -242,15 +242,15 @@ def evaluate_on_dataset( stats_gen_a.observe(gen_a.detach().cpu().numpy()) stats_gen_b.observe(gen_b.detach().cpu().numpy()) metrics = { - # "r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean), + "r2_mean_b_against_real_a": get_r2(stats_real_a.mean, stats_gen_b.mean), "r2_mean_a": get_r2(stats_real_a.mean, stats_gen_a.mean), - # "bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean), + 
"bias_mean_a": np.mean(stats_real_a.mean - stats_gen_a.mean), "r2_mean_b": get_r2(stats_real_b.mean, stats_gen_b.mean), - # "bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), + "bias_mean_b": np.mean(stats_real_b.mean - stats_gen_b.mean), "r2_std_a": get_r2(stats_real_a.std, stats_gen_a.std), - # "bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), + "bias_std_a": np.mean(stats_real_a.std - stats_gen_a.std), "r2_std_b": get_r2(stats_real_b.std, stats_gen_b.std), - # "bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), + "bias_std_b": np.mean(stats_real_b.std - stats_gen_b.std), } return metrics From bd9ae1eb1fbdd05cecaea58b5b2488dd4e118e96 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 12 Sep 2022 11:58:37 -0700 Subject: [PATCH 46/55] add regtest coverage for cyclegan training --- .../pytorch/cyclegan/cyclegan_trainer.py | 4 +- ...test_cyclegan.test_cyclegan_regression.out | 1 + .../fv3fit/tests/training/test_cyclegan.py | 79 +++++++++++++------ 3 files changed, 59 insertions(+), 25 deletions(-) create mode 100644 external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index b272e80646..7855c06e57 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -335,7 +335,7 @@ def train_on_batch( ) # Fake loss - fake_a = self.fake_a_buffer.push_and_pop(fake_a) + fake_a = self.fake_a_buffer.query(fake_a) pred_a_fake = self.discriminator_a(fake_a.detach()) loss_d_a_fake = ( self.gan_loss(pred_a_fake, self.target_fake) * self.discriminator_weight @@ -348,7 +348,7 @@ def train_on_batch( ) # Fake loss - fake_b = self.fake_b_buffer.push_and_pop(fake_b) + fake_b = self.fake_b_buffer.query(fake_b) pred_b_fake = self.discriminator_b(fake_b.detach()) loss_d_b_fake = ( self.gan_loss(pred_b_fake, 
self.target_fake) * self.discriminator_weight diff --git a/external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out b/external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out new file mode 100644 index 0000000000..6ed9f3f763 --- /dev/null +++ b/external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out @@ -0,0 +1 @@ +[["var_2d", "ff4864715b31d10785262655a308aad8"], ["var_3d", "db01012751eb605db61f0bad64e829d3"]][["var_2d", "8638eb0f0cd834ee1d667c4c74e84ff1"], ["var_3d", "19751999b1ba7f69e4720aae501d14ef"]][["var_2d", "979249b94b538e0bcead00910783544b"], ["var_3d", "0369bcfa8a93c026e650e5591f4cf86b"]][["var_2d", "5a0960aad22de24d278e2ffb1f1d9c44"], ["var_3d", "6d9d7bd4068c4a1f88c6291dcc71d4fa"]] \ No newline at end of file diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index 9cc8a5af90..dcd19de2d2 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -1,3 +1,4 @@ +import json import numpy as np import xarray as xr from typing import Sequence @@ -9,13 +10,13 @@ ) from fv3fit.data import CycleGANLoader, SyntheticWaves, SyntheticNoise import fv3fit.tfdataset -import tensorflow as tf import collections import os import fv3fit.pytorch import fv3fit import matplotlib.pyplot as plt import pytest +import vcm.testing def get_tfdataset(nsamples, nbatch, ntime, nx, nz): @@ -107,7 +108,7 @@ def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): @pytest.mark.skip("test is designed to run manually to visualize results") -def test_cyclegan(tmpdir): +def test_cyclegan_visual(tmpdir): fv3fit.set_random_seed(0) # run the test in a temporary directory to delete artifacts when done os.chdir(tmpdir) @@ -131,9 +132,9 @@ def test_cyclegan(tmpdir): discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( name="Adam", kwargs={"lr": 0.001} ), - # 
identity_weight=0.01, - # cycle_weight=0.3, - # gan_weight=1.0, + identity_weight=0.01, + cycle_weight=10.0, + generator_weight=1.0, discriminator_weight=0.5, ), training=CycleGANTrainingConfig( @@ -177,24 +178,56 @@ def test_cyclegan(tmpdir): plt.show() -def test_tuple_map(): +def test_cyclegan_regression(tmpdir, regtest): """ - External package test demonstrating that for map operations on tuples - of functions, tuple entries are passed as independent arguments - and must be collected with *args. + If this test fails, uncomment and re-run the manual test above to confirm the + model training is still valid. """ - - def generator(): - for entry in [(1, 1), (2, 2), (3, 3)]: - yield entry - - dataset = tf.data.Dataset.from_generator( - generator, output_types=(tf.int32, tf.int32) + fv3fit.set_random_seed(0) + # run the test in a temporary directory to delete artifacts when done + os.chdir(tmpdir) + # need a larger nx, ny for the sample data here since we're training + # on whether we can autoencode sin waves, and need to resolve full cycles + nx = 32 + sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} + state_variables = ["var_3d", "var_2d"] + train_tfdataset = get_tfdataset(nsamples=5, **sizes) + val_tfdataset = get_tfdataset(nsamples=2, **sizes) + hyperparameters = CycleGANHyperparameters( + state_variables=state_variables, + network=CycleGANNetworkConfig( + generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, n_resnet=5, max_filters=128, kernel_size=3 + ), + generator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.001} + ), + discriminator=fv3fit.pytorch.DiscriminatorConfig(kernel_size=3), + discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.001} + ), + identity_weight=0.01, + cycle_weight=10.0, + generator_weight=1.0, + discriminator_weight=0.5, + ), + training=CycleGANTrainingConfig( + n_epoch=2, samples_per_batch=2, validation_batch_size=2 + ), ) - - def map_fn(x, y): - return x * 2, y * 
3 - - mapped = dataset.map(map_fn) - out = list(mapped) - assert out == [(2, 3), (4, 6), (6, 9)] + predictor = train_cyclegan(hyperparameters, train_tfdataset, val_tfdataset) + # for test, need one continuous series so we consistently flip sign + real_a = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: a), dims=["time", "tile", "x", "y", "z"] + ) + real_b = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] + ) + output_a = predictor.predict(real_b, reverse=True) + reconstructed_b = predictor.predict(output_a) + output_b = predictor.predict(real_a) + reconstructed_a = predictor.predict(output_b, reverse=True) + regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_a))) + regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_b))) + regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_b))) + regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_a))) From dc718fa6379a817d5e43046444e686d7ca709483 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 12 Sep 2022 12:25:42 -0700 Subject: [PATCH 47/55] remove regtest from cyclegan test, non-deterministic --- ...test_cyclegan.test_cyclegan_regression.out | 1 - .../fv3fit/tests/training/test_cyclegan.py | 25 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) delete mode 100644 external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out diff --git a/external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out b/external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out deleted file mode 100644 index 6ed9f3f763..0000000000 --- a/external/fv3fit/tests/training/_regtest_outputs/test_cyclegan.test_cyclegan_regression.out +++ /dev/null @@ -1 +0,0 @@ -[["var_2d", "ff4864715b31d10785262655a308aad8"], ["var_3d", "db01012751eb605db61f0bad64e829d3"]][["var_2d", 
"8638eb0f0cd834ee1d667c4c74e84ff1"], ["var_3d", "19751999b1ba7f69e4720aae501d14ef"]][["var_2d", "979249b94b538e0bcead00910783544b"], ["var_3d", "0369bcfa8a93c026e650e5591f4cf86b"]][["var_2d", "5a0960aad22de24d278e2ffb1f1d9c44"], ["var_3d", "6d9d7bd4068c4a1f88c6291dcc71d4fa"]] \ No newline at end of file diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index dcd19de2d2..f25d76e683 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -1,4 +1,3 @@ -import json import numpy as np import xarray as xr from typing import Sequence @@ -16,7 +15,6 @@ import fv3fit import matplotlib.pyplot as plt import pytest -import vcm.testing def get_tfdataset(nsamples, nbatch, ntime, nx, nz): @@ -178,11 +176,7 @@ def test_cyclegan_visual(tmpdir): plt.show() -def test_cyclegan_regression(tmpdir, regtest): - """ - If this test fails, uncomment and re-run the manual test above to confirm the - model training is still valid. 
- """ +def test_cyclegan_runs_without_errors(tmpdir, regtest): fv3fit.set_random_seed(0) # run the test in a temporary directory to delete artifacts when done os.chdir(tmpdir) @@ -224,10 +218,15 @@ def test_cyclegan_regression(tmpdir, regtest): train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] ) output_a = predictor.predict(real_b, reverse=True) - reconstructed_b = predictor.predict(output_a) + reconstructed_b = predictor.predict(output_a) # noqa: F841 output_b = predictor.predict(real_a) - reconstructed_a = predictor.predict(output_b, reverse=True) - regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_a))) - regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_b))) - regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_b))) - regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_a))) + reconstructed_a = predictor.predict(output_b, reverse=True) # noqa: F841 + # We can't use regtest because the output is not deterministic between platforms, + # but you can un-comment this and use local-only (do not commit to git) regtest + # outputs when refactoring the code to ensure you don't change results. 
+ # import json + # import vcm.testing + # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_a))) + # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_b))) + # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_b))) + # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_a))) From 8be0697471945077ee2b7de286de76dce393bf8d Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 16 Sep 2022 09:50:57 -0700 Subject: [PATCH 48/55] state left wip at last session --- .../pytorch/cyclegan/cyclegan_trainer.py | 2 +- .../fv3fit/pytorch/cyclegan/generator.py | 95 +++++++---- .../fv3fit/pytorch/cyclegan/reloadable.py | 21 +-- .../pytorch/cyclegan/train_autoencoder.py | 10 +- external/fv3fit/fv3fit/pytorch/system.py | 5 + external/fv3fit/fv3fit/wandb.py | 25 ++- .../fv3fit/tests/training/test_autoencoder.py | 46 ++++-- .../fv3fit/tests/training/test_cyclegan.py | 70 ++++++++ projects/cyclegan/evaluate.py | 154 +++++++++++++++--- projects/cyclegan/training.yaml | 2 +- 10 files changed, 344 insertions(+), 86 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index 735c609700..0eda5603ed 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -12,7 +12,7 @@ import itertools from .image_pool import ImagePool import numpy as np -import wandb +from fv3fit import wandb import io import PIL diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py index ec13b65838..12d2ef284d 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py @@ -33,12 +33,19 @@ class GeneratorConfig: max_filters: maximum number of filters in any convolutional layer, equal to the number of filters in the final 
strided convolutional layer and in the resnet blocks + use_geographic_bias: if True, include a layer that adds a trainable bias + vector that is a function of x and y to the input and output of the network + disable_convolutions: if True, ignore all layers other than bias (if enabled). + Useful for debugging and for testing the effect of the + geographic bias layer. """ n_convolutions: int = 3 n_resnet: int = 3 kernel_size: int = 3 max_filters: int = 256 + use_geographic_bias: bool = True + disable_convolutions: bool = False def build( self, @@ -62,6 +69,8 @@ def build( convolution=convolution, nx=nx, ny=ny, + use_geographic_bias=self.use_geographic_bias, + disable_convolutions=self.disable_convolutions, ) @@ -84,11 +93,15 @@ def __init__( n_resnet: int, kernel_size: int, max_filters: int, + use_geographic_bias: bool, + disable_convolutions: bool, convolution: ConvolutionFactory = single_tile_convolution, ): """ Args: channels: number of input and output channels + nx: number of grid points in x direction + ny: number of grid points in y direction n_convolutions: number of strided convolutional layers after the initial convolutional layer and before the residual blocks n_resnet: number of residual blocks @@ -97,6 +110,12 @@ def __init__( max_filters: maximum number of filters in any convolutional layer, equal to the number of filters in the final strided convolutional layer and in the resnet blocks + use_geographic_bias: if True, include a layer that adds a trainable bias + vector that is a function of x and y to the input and output + of the network + disable_convolutions: if True, ignore all layers other than bias + (if enabled). Useful for debugging and for testing the effect + of the geographic bias layer. 
convolution: factory for creating all convolutional layers used by the network """ @@ -146,37 +165,46 @@ def up(in_channels: int, out_channels: int): min_filters = int(max_filters / 2 ** n_convolutions) - self._first_conv = nn.Sequential( - FoldTileDimension(nn.ReflectionPad2d(3)), - convolution( - kernel_size=7, - in_channels=channels, - out_channels=min_filters, - padding=0, - ), - FoldTileDimension(nn.InstanceNorm2d(min_filters)), - relu_activation()(), - ) - - self._encoder_decoder = SymmetricEncoderDecoder( - down_factory=down, - up_factory=up, - bottom_factory=resnet, - depth=n_convolutions, - in_channels=min_filters, - ) + if disable_convolutions: + main = nn.Identity() + else: + first_conv = nn.Sequential( + FoldTileDimension(nn.ReflectionPad2d(3)), + convolution( + kernel_size=7, + in_channels=channels, + out_channels=min_filters, + padding=0, + ), + FoldTileDimension(nn.InstanceNorm2d(min_filters)), + relu_activation()(), + ) - self._out_conv = nn.Sequential( - FoldTileDimension(nn.ReflectionPad2d(3)), - convolution( - kernel_size=7, + encoder_decoder = SymmetricEncoderDecoder( + down_factory=down, + up_factory=up, + bottom_factory=resnet, + depth=n_convolutions, in_channels=min_filters, - out_channels=channels, - padding=0, - ), - ) - self._input_bias = GeographicBias(channels=channels, nx=nx, ny=ny) - self._output_bias = GeographicBias(channels=channels, nx=nx, ny=ny) + ) + + out_conv = nn.Sequential( + FoldTileDimension(nn.ReflectionPad2d(3)), + convolution( + kernel_size=7, + in_channels=min_filters, + out_channels=channels, + padding=0, + ), + ) + main = nn.Sequential(first_conv, encoder_decoder, out_conv) + self._main = main + if use_geographic_bias: + self._input_bias = GeographicBias(channels=channels, nx=nx, ny=ny) + self._output_bias = GeographicBias(channels=channels, nx=nx, ny=ny) + else: + self._input_bias = nn.Identity() + self._output_bias = nn.Identity() def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ @@ -187,9 +215,12 @@ def 
forward(self, inputs: torch.Tensor) -> torch.Tensor: tensor of shape [batch, tile, channels, x, y] """ x = self._input_bias(inputs) - x = self._first_conv(x) - x = self._encoder_decoder(x) - x = self._out_conv(x) + if hasattr(self, "_main"): + x = self._main(x) + else: + x = self._first_conv(x) + x = self._encoder_decoder(x) + x = self._out_conv(x) x = self._output_bias(x) return x diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py index bc9a18b94f..34b06bd04e 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/reloadable.py @@ -120,11 +120,7 @@ def pack_to_tensor(self, ds: xr.Dataset, domain: str = "a") -> torch.Tensor: tensor = _pack_to_tensor( ds=ds, timesteps=0, state_variables=self.state_variables, scalers=scalers, ) - # TODO: this permute order is needed, but it does not seem like it should be. - # when we replace the model with a linear one, the output only matches the - # input if we flip the x and y dimension. 
- # investigate why this is necessary - return tensor.permute([0, 1, 4, 3, 2]) + return tensor.permute([0, 1, 4, 2, 3]) def unpack_tensor(self, data: torch.Tensor, domain: str = "b") -> xr.Dataset: """ @@ -166,11 +162,16 @@ def predict(self, X: xr.Dataset, reverse: bool = False) -> xr.Dataset: input_domain, output_domain = "a", "b" tensor = self.pack_to_tensor(X, domain=input_domain) + if reverse: + generator = self.generator_b_to_a + else: + generator = self.generator_a_to_b + n_batch = 100 + outputs = torch.zeros_like(tensor) with torch.no_grad(): - if reverse: - outputs: torch.Tensor = self.generator_b_to_a(tensor) - else: - outputs = self.generator_a_to_b(tensor) - outputs = outputs.reshape(tensor.shape) + for i in range(0, tensor.shape[0], n_batch): + new: torch.Tensor = generator(tensor[i : i + n_batch]) + outputs[i : i + new.shape[0]] = new + # outputs = outputs.reshape(tensor.shape) predicted = self.unpack_tensor(outputs, domain=output_domain) return predicted diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py index 5f932bd2ba..c2fa1591ce 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_autoencoder.py @@ -97,8 +97,12 @@ def train_autoencoder( train_state = train_batches.map(get_state) + sample: tf.Tensor = next(iter(train_state))[0] train_model = build_model( - hyperparameters.generator, n_state=next(iter(train_state)).shape[-1] + hyperparameters.generator, + nx=sample.shape[-3], + ny=sample.shape[-2], + n_state=sample.shape[-1], ) logging.debug("training with model structure: %s", train_model) @@ -138,5 +142,5 @@ def channels_first(data: tf.Tensor) -> tf.Tensor: return tf.transpose(data, perm=[0, 1, 2, 5, 3, 4]) -def build_model(config: GeneratorConfig, n_state: int) -> Generator: - return config.build(channels=n_state).to(DEVICE) +def build_model(config: GeneratorConfig, n_state: int, nx: int, 
ny: int) -> Generator: + return config.build(channels=n_state, nx=nx, ny=ny).to(DEVICE) diff --git a/external/fv3fit/fv3fit/pytorch/system.py b/external/fv3fit/fv3fit/pytorch/system.py index 54fc7b1635..daaa785401 100644 --- a/external/fv3fit/fv3fit/pytorch/system.py +++ b/external/fv3fit/fv3fit/pytorch/system.py @@ -1,6 +1,9 @@ import torch.backends import torch import os +import logging + +logger = logging.getLogger(__name__) if os.environ.get("TORCH_CPU_ONLY", False): DEVICE = torch.device("cpu") @@ -12,3 +15,5 @@ if torch.backends.mps.is_available() else "cpu" ) + +logger.info("using device %s for pytorch", DEVICE.type) diff --git a/external/fv3fit/fv3fit/wandb.py b/external/fv3fit/fv3fit/wandb.py index 27e590d153..43c05d8e02 100644 --- a/external/fv3fit/fv3fit/wandb.py +++ b/external/fv3fit/fv3fit/wandb.py @@ -1,4 +1,5 @@ from collections import defaultdict +import contextlib import dataclasses import logging import os @@ -9,10 +10,28 @@ import plotly.graph_objects as go from plotly.subplots import make_subplots from typing import Any, Dict, List, Mapping, Optional +from wandb import Image # noqa: F401 from .tensorboard import plot_to_image from .keras.jacobian import OutputSensitivity +WANDB_ENABLED = True + + +@contextlib.contextmanager +def disable_wandb(): + global WANDB_ENABLED + WANDB_ENABLED = False + try: + yield + finally: + WANDB_ENABLED = True + + +def log(*args, **kwargs): + if WANDB_ENABLED: + wandb.log(*args, **kwargs) + @dataclasses.dataclass class WandBConfig: @@ -64,7 +83,7 @@ def log_to_table(log_key: str, data: Dict[str, Any], index: Optional[List[Any]] df = pd.DataFrame(data, index=index) table = wandb.Table(dataframe=df) - wandb.log({log_key: table}) + log({log_key: table}) def _plot_profiles(target, prediction, name): @@ -95,7 +114,7 @@ def _plot_profiles(target, prediction, name): plt.title(f"Sample {i+1}: {name}") plt.xlabel(f"{units[name]}") plt.ylabel("Level") - wandb.log({f"{name}_sample_{i}": wandb.Image(plot_to_image(fig))}) + 
log({f"{name}_sample_{i}": wandb.Image(plot_to_image(fig))}) plt.close() @@ -178,4 +197,4 @@ def plot_all_output_sensitivities(jacobians: Mapping[str, OutputSensitivity]): } for out_name, fig in all_plots.items(): - wandb.log({f"jacobian/{out_name}": wandb.Plotly(fig)}) + log({f"jacobian/{out_name}": wandb.Plotly(fig)}) diff --git a/external/fv3fit/tests/training/test_autoencoder.py b/external/fv3fit/tests/training/test_autoencoder.py index 73216dcce4..cd125dabde 100644 --- a/external/fv3fit/tests/training/test_autoencoder.py +++ b/external/fv3fit/tests/training/test_autoencoder.py @@ -4,7 +4,7 @@ from fv3fit.pytorch.cyclegan import AutoencoderHyperparameters, train_autoencoder from fv3fit.pytorch.cyclegan.train_autoencoder import TrainingConfig import pytest -from fv3fit.data.synthetic import SyntheticWaves +from fv3fit.data.synthetic import SyntheticWaves, SyntheticNoise import collections import os import fv3fit.pytorch @@ -38,6 +38,26 @@ def get_synthetic_waves_tfdataset(nsamples, nbatch, ntime, nx, nz): return dataset +def get_noise_tfdataset(nsamples, nbatch, ntime, nx, nz): + """ + Returns a tfdataset of random noise. + + Dataset contains a variable "a" which is vertically-resolved + and "b" which is a scalar. 
+ """ + config = SyntheticNoise( + nsamples=nsamples, + nbatch=nbatch, + ntime=ntime, + nx=nx, + nz=nz, + noise_amplitude=1.0, + scalar_names=["b"], + ) + dataset = config.open_tfdataset(local_download_path=None, variable_names=["a", "b"]) + return dataset + + def tfdataset_to_xr_dataset(tfdataset, dims: Sequence[str]): """ Takes a tfdataset whose samples all have the same shape, and converts @@ -133,7 +153,7 @@ def test_autoencoder_overfit(tmpdir): sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} state_variables = ["a", "b"] # for single-sample overfitting we can use any data, even pure noise - train_tfdataset = get_synthetic_waves_tfdataset(nsamples=1, **sizes) + train_tfdataset = get_noise_tfdataset(nsamples=1, **sizes) train_tfdataset = train_tfdataset.cache() # needed to keep sample identical hyperparameters = AutoencoderHyperparameters( state_variables=state_variables, @@ -147,22 +167,26 @@ def test_autoencoder_overfit(tmpdir): predictor = train_autoencoder( hyperparameters, train_tfdataset, validation_batches=None ) + fv3fit.dump(predictor, str(tmpdir)) + predictor = fv3fit.load(str(tmpdir)) # predict takes xarray datasets, so we have to convert test_xrdataset = tfdataset_to_xr_dataset( train_tfdataset, dims=["time", "tile", "x", "y", "z"] ) + predicted = predictor.predict(test_xrdataset) reference = test_xrdataset # plotting code to uncomment if you'd like to manually check the results: - # import matplotlib.pyplot as plt - # for i in range(6): - # fig, ax = plt.subplots(1, 2) - # vmin = reference["a"][0, i, :, :, 0].values.min() - # vmax = reference["a"][0, i, :, :, 0].values.max() - # ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - # ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - # plt.tight_layout() - # plt.show() + import matplotlib.pyplot as plt + + for i in range(6): + fig, ax = plt.subplots(1, 2) + vmin = reference["a"][0, i, :, :, 0].values.min() + vmax = reference["a"][0, i, :, :, 
0].values.max() + ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + plt.tight_layout() + plt.show() bias = predicted - reference mean_bias: xr.Dataset = bias.mean() rmse: xr.Dataset = (bias ** 2).mean() diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index f25d76e683..1c2052b4aa 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -15,6 +15,7 @@ import fv3fit import matplotlib.pyplot as plt import pytest +import fv3fit.wandb def get_tfdataset(nsamples, nbatch, ntime, nx, nz): @@ -230,3 +231,72 @@ def test_cyclegan_runs_without_errors(tmpdir, regtest): # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_b))) # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_b))) # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_a))) + + +def test_cyclegan_bias_overfit(tmpdir, regtest): + """ + Test that a "cyclegan" with only a bias layer can overfit to one sample + """ + fv3fit.set_random_seed(0) + # run the test in a temporary directory to delete artifacts when done + os.chdir(tmpdir) + # need a larger nx, ny for the sample data here since we're training + # on whether we can autoencode sin waves, and need to resolve full cycles + nx = 32 + sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} + state_variables = ["var_3d", "var_2d"] + train_tfdataset = get_noise_tfdataset(nsamples=1, **sizes) + hyperparameters = CycleGANHyperparameters( + state_variables=state_variables, + network=CycleGANNetworkConfig( + generator=fv3fit.pytorch.GeneratorConfig( + n_convolutions=2, + n_resnet=5, + max_filters=128, + kernel_size=3, + disable_convolutions=True, + use_geographic_bias=True, + ), + generator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.001} + ), + 
discriminator=fv3fit.pytorch.DiscriminatorConfig(kernel_size=3), + discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( + name="Adam", kwargs={"lr": 0.001} + ), + identity_weight=0.0, + cycle_weight=10.0, + generator_weight=1.0, + discriminator_weight=0.1, + ), + training=CycleGANTrainingConfig( + n_epoch=100, samples_per_batch=1, validation_batch_size=2 + ), + ) + with fv3fit.wandb.disable_wandb(): + predictor = train_cyclegan( + hyperparameters, train_tfdataset, validation_batches=train_tfdataset + ) + # for test, need one continuous series so we consistently flip sign + real_a = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: a), dims=["time", "tile", "x", "y", "z"] + ) + real_b = tfdataset_to_xr_dataset( + train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] + ) + output_a = predictor.predict(real_b, reverse=True) + reconstructed_b = predictor.predict(output_a) # noqa: F841 + output_b = predictor.predict(real_a) + reconstructed_a = predictor.predict(output_b, reverse=True) # noqa: F841 + for (real, output), label in ( + ((real_a, reconstructed_a), "reconstructed_a"), + ((real_b, reconstructed_b), "reconstructed_b"), + ((real_a, output_a), "output_a"), + ((real_b, output_b), "output_b"), + ): + bias = output.isel(time=0) - real.isel(time=0) + mean_bias: xr.Dataset = bias.mean() + rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 + for varname in state_variables: + assert np.abs(mean_bias[varname]) < 0.1, label + assert rmse[varname] < 0.1, label diff --git a/projects/cyclegan/evaluate.py b/projects/cyclegan/evaluate.py index 4439642d0f..7257de6107 100644 --- a/projects/cyclegan/evaluate.py +++ b/projects/cyclegan/evaluate.py @@ -1,35 +1,139 @@ +# flake8: noqa + +import random import fv3fit +from fv3fit.pytorch import DEVICE from matplotlib import pyplot as plt import xarray as xr +from vcm.catalog import catalog +import fv3viz +import cartopy.crs as ccrs + if __name__ == "__main__": - cyclegan: fv3fit.pytorch.CycleGAN = 
fv3fit.load("output").to("cpu") - c48_real = ( + random.seed(0) + grid = catalog["grid/c48"].read() + cyclegan: fv3fit.pytorch.CycleGAN = fv3fit.load("output_good").to(DEVICE) + c48_real: xr.Dataset = ( xr.open_zarr("c48_baseline.zarr") .rename({"grid_xt": "x", "grid_yt": "y"}) - .isel(time=range(0, 100, 10)) - ) - c384_real = ( + .isel(time=slice(-2905 * 2, None, 2)) + ).load() + c384_real: xr.Dataset = ( xr.open_zarr("c384_baseline.zarr") .rename({"grid_xt": "x", "grid_yt": "y"}) - .isel(time=range(0, 100, 10)) + .isel(time=slice(-2905 * 2, None, 2)) + ).load() + c384_gen: xr.Dataset = cyclegan.predict(c48_real) + c48_gen: xr.Dataset = cyclegan.predict(c384_real, reverse=True) + + # for _ in range(3): + # i_time = random.randint(0, c48_real.time.size - 1) + # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) + # fv3viz.plot_cube(ds=c48_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 0]) + # ax[0, 0].set_title("c48_real") + # fv3viz.plot_cube(ds=c384_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 0]) + # ax[1, 0].set_title("c384_real") + # fv3viz.plot_cube(ds=c384_gen.isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 1]) + # ax[0, 1].set_title("c384_gen") + # fv3viz.plot_cube(ds=c48_gen.isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 1]) + # ax[1, 1].set_title("c48_gen") + # plt.tight_layout() + # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) + # fv3viz.plot_cube(ds=c48_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 0]) + # ax[0, 0].set_title("c48_real") + # fv3viz.plot_cube(ds=c384_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 0]) + # ax[1, 0].set_title("c384_real") + # fv3viz.plot_cube(ds=(c384_gen - c48_real).isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 1]) + # ax[0, 1].set_title("c384_gen") + # fv3viz.plot_cube(ds=(c48_gen - c384_real).isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 
1]) + # ax[1, 1].set_title("c48_gen") + # plt.tight_layout() + + c48_real_mean = c48_real.mean("time") + c48_gen_mean = c48_gen.mean("time") + c384_real_mean = c384_real.mean("time") + c384_gen_mean = c384_gen.mean("time") + fig, ax = plt.subplots( + 2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()} + ) + fv3viz.plot_cube(ds=c48_real_mean.merge(grid), var_name="h500", ax=ax[0, 0]) + ax[0, 0].set_title("c48_real") + fv3viz.plot_cube(ds=c384_real_mean.merge(grid), var_name="h500", ax=ax[1, 0]) + ax[1, 0].set_title("c384_real") + fv3viz.plot_cube(ds=c384_gen_mean.merge(grid), var_name="h500", ax=ax[0, 1]) + ax[0, 1].set_title("c384_gen") + fv3viz.plot_cube(ds=c48_gen_mean.merge(grid), var_name="h500", ax=ax[1, 1]) + ax[1, 1].set_title("c48_gen") + plt.tight_layout() + + fig, ax = plt.subplots( + 2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()} + ) + fv3viz.plot_cube(ds=c48_real_mean.merge(grid), var_name="h500", ax=ax[0, 0]) + ax[0, 0].set_title("c48_real") + fv3viz.plot_cube( + ds=(c384_real_mean - c48_real_mean).merge(grid), + var_name="h500", + ax=ax[1, 0], + vmin=-70, + vmax=70, ) - c48_gen = cyclegan.predict(c384_real) - c384_gen = cyclegan.predict(c48_real, reverse=True) - i_tile = 3 - for i_tile in range(1): - for i in range(1): - import pdb - - pdb.set_trace() - fig, ax = plt.subplots(2, 2, figsize=(10, 8)) - c48_real.h500.isel(time=i, tile=i_tile).plot(ax=ax[0, 0]) - ax[0, 0].set_title("c48_real") - c384_real.h500.isel(time=i, tile=i_tile).plot(ax=ax[1, 0]) - ax[1, 0].set_title("c384_real") - c384_gen.h500.isel(time=i, tile=i_tile).plot(ax=ax[0, 1]) - ax[0, 1].set_title("c384_gen") - c48_gen.h500.isel(time=i, tile=i_tile).plot(ax=ax[1, 1]) - ax[1, 1].set_title("c48_gen") - plt.tight_layout() - plt.show() + ax[1, 0].set_title("c384_real") + fv3viz.plot_cube( + ds=(c384_gen_mean - c48_real_mean).merge(grid), + var_name="h500", + ax=ax[0, 1], + vmin=-70, + vmax=70, + ) + ax[0, 1].set_title("c384_gen") + fv3viz.plot_cube( + 
ds=(c48_gen_mean - c48_real_mean).merge(grid), + var_name="h500", + ax=ax[1, 1], + vmin=-70, + vmax=70, + ) + ax[1, 1].set_title("c48_gen") + plt.tight_layout() + + mse = (c384_real_mean - c384_gen_mean).var() + var = (c384_real_mean).var() + print(mse) + print(var) + print(1.0 - (mse / var)) + + mse = (c384_real_mean - c48_real_mean).var() + var = (c384_real_mean).var() + print(mse) + print(var) + print(1.0 - (mse / var)) + + # c48_real_std = c48_real.std("time") + # c48_gen_std = c48_gen.std("time") + # c384_real_std = c384_real.std("time") + # c384_gen_std = c384_gen.std("time") + # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) + # fv3viz.plot_cube(ds=c48_real_std.merge(grid), var_name="h500", ax=ax[0, 0]) + # ax[0, 0].set_title("c48_real") + # fv3viz.plot_cube(ds=c384_real_std.merge(grid), var_name="h500", ax=ax[1, 0]) + # ax[1, 0].set_title("c384_real") + # fv3viz.plot_cube(ds=c384_gen_std.merge(grid), var_name="h500", ax=ax[0, 1]) + # ax[0, 1].set_title("c384_gen") + # fv3viz.plot_cube(ds=c48_gen_std.merge(grid), var_name="h500", ax=ax[1, 1]) + # ax[1, 1].set_title("c48_gen") + # plt.tight_layout() + + # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) + # fv3viz.plot_cube(ds=c48_real_std.merge(grid), var_name="h500", ax=ax[0, 0]) + # ax[0, 0].set_title("c48_real") + # fv3viz.plot_cube(ds=(c384_real_std - c48_real_std).merge(grid), var_name="h500", ax=ax[1, 0], vmin=-50, vmax=50) + # ax[1, 0].set_title("c384_real") + # fv3viz.plot_cube(ds=(c384_gen_std - c48_real_std).merge(grid), var_name="h500", ax=ax[0, 1], vmin=-50, vmax=50) + # ax[0, 1].set_title("c384_gen") + # fv3viz.plot_cube(ds=(c48_gen_std - c48_real_std).merge(grid), var_name="h500", ax=ax[1, 1], vmin=-50, vmax=50) + # ax[1, 1].set_title("c48_gen") + # plt.tight_layout() + + plt.show() diff --git a/projects/cyclegan/training.yaml b/projects/cyclegan/training.yaml index 1cc577a4e1..41ed43b1fa 100644 --- 
a/projects/cyclegan/training.yaml +++ b/projects/cyclegan/training.yaml @@ -34,7 +34,7 @@ hyperparameters: generator_weight: 1.0 discriminator_weight: 1.0 training: - n_epoch: 100 + n_epoch: 10 shuffle_buffer_size: 1000 samples_per_batch: 1 validation_batch_size: 100 From d10bed9ec5b2963d8049018c31bd6ec4400071c9 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Fri, 16 Sep 2022 13:16:27 -0700 Subject: [PATCH 49/55] delete non-working test --- .../fv3fit/tests/training/test_cyclegan.py | 69 ------------------- 1 file changed, 69 deletions(-) diff --git a/external/fv3fit/tests/training/test_cyclegan.py b/external/fv3fit/tests/training/test_cyclegan.py index 1c2052b4aa..a05c9c89b5 100644 --- a/external/fv3fit/tests/training/test_cyclegan.py +++ b/external/fv3fit/tests/training/test_cyclegan.py @@ -231,72 +231,3 @@ def test_cyclegan_runs_without_errors(tmpdir, regtest): # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_b))) # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(output_b))) # regtest.write(json.dumps(vcm.testing.checksum_dataarray_mapping(reconstructed_a))) - - -def test_cyclegan_bias_overfit(tmpdir, regtest): - """ - Test that a "cyclegan" with only a bias layer can overfit to one sample - """ - fv3fit.set_random_seed(0) - # run the test in a temporary directory to delete artifacts when done - os.chdir(tmpdir) - # need a larger nx, ny for the sample data here since we're training - # on whether we can autoencode sin waves, and need to resolve full cycles - nx = 32 - sizes = {"nbatch": 1, "ntime": 1, "nx": nx, "nz": 2} - state_variables = ["var_3d", "var_2d"] - train_tfdataset = get_noise_tfdataset(nsamples=1, **sizes) - hyperparameters = CycleGANHyperparameters( - state_variables=state_variables, - network=CycleGANNetworkConfig( - generator=fv3fit.pytorch.GeneratorConfig( - n_convolutions=2, - n_resnet=5, - max_filters=128, - kernel_size=3, - disable_convolutions=True, - use_geographic_bias=True, - ), - 
generator_optimizer=fv3fit.pytorch.OptimizerConfig( - name="Adam", kwargs={"lr": 0.001} - ), - discriminator=fv3fit.pytorch.DiscriminatorConfig(kernel_size=3), - discriminator_optimizer=fv3fit.pytorch.OptimizerConfig( - name="Adam", kwargs={"lr": 0.001} - ), - identity_weight=0.0, - cycle_weight=10.0, - generator_weight=1.0, - discriminator_weight=0.1, - ), - training=CycleGANTrainingConfig( - n_epoch=100, samples_per_batch=1, validation_batch_size=2 - ), - ) - with fv3fit.wandb.disable_wandb(): - predictor = train_cyclegan( - hyperparameters, train_tfdataset, validation_batches=train_tfdataset - ) - # for test, need one continuous series so we consistently flip sign - real_a = tfdataset_to_xr_dataset( - train_tfdataset.map(lambda a, b: a), dims=["time", "tile", "x", "y", "z"] - ) - real_b = tfdataset_to_xr_dataset( - train_tfdataset.map(lambda a, b: b), dims=["time", "tile", "x", "y", "z"] - ) - output_a = predictor.predict(real_b, reverse=True) - reconstructed_b = predictor.predict(output_a) # noqa: F841 - output_b = predictor.predict(real_a) - reconstructed_a = predictor.predict(output_b, reverse=True) # noqa: F841 - for (real, output), label in ( - ((real_a, reconstructed_a), "reconstructed_a"), - ((real_b, reconstructed_b), "reconstructed_b"), - ((real_a, output_a), "output_a"), - ((real_b, output_b), "output_b"), - ): - bias = output.isel(time=0) - real.isel(time=0) - mean_bias: xr.Dataset = bias.mean() - rmse: xr.Dataset = (bias ** 2).mean() ** 0.5 - for varname in state_variables: - assert np.abs(mean_bias[varname]) < 0.1, label - assert rmse[varname] < 0.1, label From 954db8ecd66956610d821096bada5e1830a2811d Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 26 Sep 2022 14:36:00 -0700 Subject: [PATCH 50/55] add halo updates to cyclegan --- .../pytorch/cyclegan/cyclegan_trainer.py | 20 +- .../fv3fit/pytorch/cyclegan/discriminator.py | 6 +- .../fv3fit/pytorch/cyclegan/generator.py | 26 ++- .../fv3fit/fv3fit/pytorch/cyclegan/modules.py | 210 
++++++++++++++++++ 4 files changed, 245 insertions(+), 17 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py index 1efab33f21..1c4f2c4eb8 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/cyclegan_trainer.py @@ -5,6 +5,7 @@ import torch from .generator import GeneratorConfig from .discriminator import DiscriminatorConfig +from .modules import single_tile_convolution, halo_convolution import dataclasses from fv3fit.pytorch.loss import LossConfig from fv3fit.pytorch.optimizer import OptimizerConfig @@ -64,6 +65,7 @@ class CycleGANNetworkConfig: discriminator: "DiscriminatorConfig" = dataclasses.field( default_factory=lambda: DiscriminatorConfig() ) + convolution_type: str = "conv2d" identity_loss: LossConfig = dataclasses.field(default_factory=LossConfig) cycle_loss: LossConfig = dataclasses.field(default_factory=LossConfig) gan_loss: LossConfig = dataclasses.field(default_factory=LossConfig) @@ -75,10 +77,20 @@ class CycleGANNetworkConfig: def build( self, n_state: int, nx: int, ny: int, n_batch: int, state_variables, scalers ) -> "CycleGANTrainer": - generator_a_to_b = self.generator.build(n_state, nx=nx, ny=ny) - generator_b_to_a = self.generator.build(n_state, nx=nx, ny=ny) - discriminator_a = self.discriminator.build(n_state) - discriminator_b = self.discriminator.build(n_state) + if self.convolution_type == "conv2d": + convolution = single_tile_convolution + elif self.convolution_type == "halo_conv2d": + convolution = halo_convolution + else: + raise ValueError(f"convolution_type {self.convolution_type} not supported") + generator_a_to_b = self.generator.build( + n_state, nx=nx, ny=ny, convolution=convolution + ) + generator_b_to_a = self.generator.build( + n_state, nx=nx, ny=ny, convolution=convolution + ) + discriminator_a = self.discriminator.build(n_state, convolution=convolution) + 
discriminator_b = self.discriminator.build(n_state, convolution=convolution) optimizer_generator = self.generator_optimizer.instance( itertools.chain( generator_a_to_b.parameters(), generator_b_to_a.parameters() diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py index 3361382eaf..bbba9bf2f4 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/discriminator.py @@ -92,7 +92,7 @@ def __init__( in_channels=min_filters * 2 ** (i - 1), out_channels=min_filters * 2 ** i, convolution_factory=curry(convolution)( - kernel_size=kernel_size, stride=2, padding=1 + kernel_size=kernel_size, stride=2, padding="same" ), activation_factory=leakyrelu_activation( negative_slope=0.2, inplace=True @@ -103,7 +103,9 @@ def __init__( final_conv = ConvBlock( in_channels=max_filters, out_channels=max_filters, - convolution_factory=curry(convolution)(kernel_size=kernel_size), + convolution_factory=curry(convolution)( + kernel_size=kernel_size, padding="same" + ), activation_factory=leakyrelu_activation(negative_slope=0.2, inplace=True), ) patch_output = convolution( diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py index eafa7a7ade..211fd2007c 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/generator.py @@ -43,6 +43,7 @@ class GeneratorConfig: n_convolutions: int = 3 n_resnet: int = 3 kernel_size: int = 3 + strided_kernel_size: int = 4 max_filters: int = 256 use_geographic_bias: bool = True disable_convolutions: bool = False @@ -57,6 +58,8 @@ def build( """ Args: channels: number of input channels + nx: number of x grid points + ny: number of y grid points convolution: factory for creating all convolutional layers used by the network """ @@ -65,6 +68,7 @@ def build( n_convolutions=self.n_convolutions, 
n_resnet=self.n_resnet, kernel_size=self.kernel_size, + strided_kernel_size=self.strided_kernel_size, max_filters=self.max_filters, convolution=convolution, nx=nx, @@ -96,6 +100,7 @@ def __init__( n_convolutions: int, n_resnet: int, kernel_size: int, + strided_kernel_size: int, max_filters: int, use_geographic_bias: bool, disable_convolutions: bool, @@ -109,8 +114,9 @@ def __init__( n_convolutions: number of strided convolutional layers after the initial convolutional layer and before the residual blocks n_resnet: number of residual blocks - kernel_size: size of convolutional kernels in the strided convolutions - and resnet blocks + kernel_size: size of convolutional kernels in the resnet blocks + strided_kernel_size: size of convolutional kernels in the + strided convolutions max_filters: maximum number of filters in any convolutional layer, equal to the number of filters in the final strided convolutional layer and in the resnet blocks @@ -135,7 +141,7 @@ def resnet(in_channels: int, out_channels: int): ResnetBlock( channels=in_channels, convolution_factory=curry(convolution)( - kernel_size=3, padding="same" + kernel_size=kernel_size, padding="same" ), activation_factory=relu_activation(), ) @@ -148,7 +154,7 @@ def down(in_channels: int, out_channels: int): in_channels=in_channels, out_channels=out_channels, convolution_factory=curry(convolution)( - kernel_size=3, stride=2, padding=1 + kernel_size=strided_kernel_size, stride=2, padding="same" ), activation_factory=relu_activation(), ) @@ -158,10 +164,10 @@ def up(in_channels: int, out_channels: int): in_channels=in_channels, out_channels=out_channels, convolution_factory=curry(convolution)( - kernel_size=kernel_size, + kernel_size=strided_kernel_size, stride=2, - padding=1, - output_padding=1, + padding="same", + output_padding=0, stride_type="transpose", ), activation_factory=relu_activation(), @@ -173,12 +179,11 @@ def up(in_channels: int, out_channels: int): main = nn.Identity() else: first_conv = 
nn.Sequential( - FoldTileDimension(nn.ReflectionPad2d(3)), convolution( kernel_size=7, in_channels=channels, out_channels=min_filters, - padding=0, + padding="same", ), FoldTileDimension(nn.InstanceNorm2d(min_filters)), relu_activation()(), @@ -193,12 +198,11 @@ def up(in_channels: int, out_channels: int): ) out_conv = nn.Sequential( - FoldTileDimension(nn.ReflectionPad2d(3)), convolution( kernel_size=7, in_channels=min_filters, out_channels=channels, - padding=0, + padding="same", ), ) main = nn.Sequential(first_conv, encoder_decoder, out_conv) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py index a0367f5317..0f21a182f7 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py @@ -1,4 +1,5 @@ import logging +import functools from typing import Callable, Literal, Protocol, Union import torch.nn as nn @@ -144,6 +145,215 @@ def single_tile_convolution( return FoldTileDimension(conv) +def halo_convolution( + in_channels: int, + out_channels: int, + kernel_size: int, + padding: Union[str, int] = 0, + output_padding: int = 0, + stride: int = 1, + stride_type: Literal["regular", "transpose"] = "regular", + bias: bool = True, +) -> ConvolutionFactory: + """ + Construct a convolutional layer that appends halo data before applying conv2d. + + Layer takes in and returns tensors of shape [batch, tile, channels, x, y]. 
+ + Args: + kernel_size: size of the convolution kernel + padding: padding to apply to the input, should be an integer or "same" + output_padding: argument used for transpose convolution + stride: stride of the convolution + stride_type: type of stride, one of "regular" or "transpose" + bias: whether to include a bias vector in the produced layers + """ + if padding == "same": + if stride_type == "transpose": + padding = int((kernel_size - 1) // 2 * stride) + else: + padding = int((kernel_size - 1) // 2) + elif isinstance(padding, str): + raise ValueError(f'padding must be integer or "same", got: {padding}') + append = AppendHalos(n_halo=padding) + conv = single_tile_convolution( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=0, + output_padding=output_padding, + stride=stride, + stride_type=stride_type, + bias=bias, + ) + if stride_type == "transpose": + # have to crop halo points from the output, as pytorch has no option to + # only output a subset of the domain for ConvTranspose2d + conv = nn.Sequential( + conv, Crop(n_halo=padding * stride + int(kernel_size - 1) // 2) + ) + return BreakOnOp(nn.Sequential(append, conv)) + + +class Crop(nn.Module): + def __init__(self, n_halo): + super(Crop, self).__init__() + self.n_halo = n_halo + + def forward(self, x): + return x[..., self.n_halo : -self.n_halo, self.n_halo : -self.n_halo] + + +class BreakOnOp(nn.Module): + """ + Module which asserts that the shape of its input does not change. + """ + + def __init__(self, op: nn.Module): + super(BreakOnOp, self).__init__() + self._op = op + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + output: torch.Tensor = self._op(inputs) + # print(inputs.shape, output.shape) + # if output.shape[-1] != inputs.shape[-1]: + # print(self._op) + # import pdb;pdb.set_trace() + return output + + +def cpu_only(method): + """ + Decorator to mark a method as only being supported on the CPU. 
+ """ + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + original_device = args[0].device + args = [arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = { + k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items() + } + return method(self, *args, **kwargs).to(original_device) + + return wrapper + + +class AppendHalos(nn.Module): + + """ + Module which appends horizontal halos to the input tensor. + + Args: + n_halo: size of the halo to append + """ + + def __init__(self, n_halo: int): + super(AppendHalos, self).__init__() + self.n_halo = n_halo + + def extra_repr(self) -> str: + return super().extra_repr() + f"n_halo={self.n_halo}" + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Args: + inputs: tensor of shape [batch, tile, channel, x, y] + """ + corner = torch.zeros_like(inputs[:, 0, :, : self.n_halo, : self.n_halo]) + if self.n_halo > 0: + with_halo = [] + for _ in range(6): + tile = [] + for _ in range(3): + column = [] + for _ in range(3): + column.append([None, None, None]) # row + tile.append(column) + with_halo.append(tile) + + for i_tile in range(6): + with_halo[i_tile][1][1] = inputs[:, i_tile, :, :, :] + with_halo[i_tile][0][0] = corner + with_halo[i_tile][0][2] = corner + with_halo[i_tile][2][0] = corner + with_halo[i_tile][2][2] = corner + # we must make data contiguous after rotating 90 degrees because + # the MPS backend doesn't properly manage strides when concatenating + # arrays + if i_tile % 2 == 0: # even tile + # south edge + with_halo[i_tile][0][1] = torch.rot90( + inputs[ + :, (i_tile - 2) % 6, :, :, -self.n_halo : + ], # write tile 4 to tile 0 + k=-1, # relative rotation of tile 0 with respect to tile 5 + dims=(2, 3), + ).contiguous() + # west edge + with_halo[i_tile][1][0] = inputs[ + :, (i_tile - 1) % 6, :, :, -self.n_halo : + ] # write tile 5 to tile 0 + # east edge + with_halo[i_tile][2][1] = inputs[ + :, + (i_tile + 1) % 6, + :, + : self.n_halo, + :, # 
write tile 1 to tile 0 + ] + # north edge + with_halo[i_tile][1][2] = torch.rot90( + inputs[ + :, (i_tile + 2) % 6, :, : self.n_halo, : + ], # write tile 2 to tile 0 + k=1, # relative rotation of tile 0 with respect to tile 2 + dims=(2, 3), + ).contiguous() + else: # odd tile + # south edge + with_halo[i_tile][0][1] = inputs[ + :, + (i_tile - 1) % 6, + :, + -self.n_halo :, + :, # write tile 0 to tile 1 + ] + # west edge + with_halo[i_tile][1][0] = torch.rot90( + inputs[ + :, (i_tile - 2) % 6, :, -self.n_halo :, : + ], # write tile 5 to tile 1 + k=1, # relative rotation of tile 1 with respect to tile 5 + dims=(2, 3), + ).contiguous() + # east edge + with_halo[i_tile][2][1] = torch.rot90( + inputs[ + :, (i_tile + 2) % 6, :, :, : self.n_halo + ], # write tile 3 to tile 1 + k=-1, # relative rotation of tile 1 with respect to tile 3 + dims=(2, 3), + ).contiguous() + # north edge + with_halo[i_tile][1][2] = inputs[ + :, (i_tile + 1) % 6, :, :, : self.n_halo + ] # write tile 2 to tile 1 + + for i_tile in range(6): + for i_col in range(3): + with_halo[i_tile][i_col] = torch.cat( + tensors=with_halo[i_tile][i_col], dim=-1 + ) + with_halo[i_tile] = torch.cat(tensors=with_halo[i_tile], dim=-2) + with_halo = torch.stack(tensors=with_halo, dim=1) + + else: + with_halo = inputs + + return with_halo + + class ResnetBlock(nn.Module): """ Residual network block as defined in He et al. 
2016, From bf9330c7293c796e24783c4cb08fdf307230fe3c Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 26 Sep 2022 14:38:22 -0700 Subject: [PATCH 51/55] revert unneeded changes from main branch --- .../pytorch/cyclegan/test_autoencoder.py | 22 ++++++++----------- external/fv3fit/fv3fit/train.py | 5 ++--- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/test_autoencoder.py b/external/fv3fit/fv3fit/pytorch/cyclegan/test_autoencoder.py index 9a068f50cb..3f1b1254ff 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/test_autoencoder.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/test_autoencoder.py @@ -167,26 +167,22 @@ def test_autoencoder_overfit(tmpdir): predictor = train_autoencoder( hyperparameters, train_tfdataset, validation_batches=None ) - fv3fit.dump(predictor, str(tmpdir)) - predictor = fv3fit.load(str(tmpdir)) # predict takes xarray datasets, so we have to convert test_xrdataset = tfdataset_to_xr_dataset( train_tfdataset, dims=["time", "tile", "x", "y", "z"] ) - predicted = predictor.predict(test_xrdataset) reference = test_xrdataset # plotting code to uncomment if you'd like to manually check the results: - import matplotlib.pyplot as plt - - for i in range(6): - fig, ax = plt.subplots(1, 2) - vmin = reference["a"][0, i, :, :, 0].values.min() - vmax = reference["a"][0, i, :, :, 0].values.max() - ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) - plt.tight_layout() - plt.show() + # import matplotlib.pyplot as plt + # for i in range(6): + # fig, ax = plt.subplots(1, 2) + # vmin = reference["a"][0, i, :, :, 0].values.min() + # vmax = reference["a"][0, i, :, :, 0].values.max() + # ax[0].imshow(reference["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + # ax[1].imshow(predicted["a"][0, i, :, :, 0].values, vmin=vmin, vmax=vmax) + # plt.tight_layout() + # plt.show() bias = predicted - reference 
mean_bias: xr.Dataset = bias.mean() mse: xr.Dataset = (bias ** 2).mean() diff --git a/external/fv3fit/fv3fit/train.py b/external/fv3fit/fv3fit/train.py index cb26b5d044..11cc193b67 100644 --- a/external/fv3fit/fv3fit/train.py +++ b/external/fv3fit/fv3fit/train.py @@ -175,13 +175,12 @@ def main(args, unknown_args=None): if __name__ == "__main__": - LOGLEVEL = os.environ.get("LOGLEVEL", "INFO").upper() - logger.setLevel(LOGLEVEL) + logger.setLevel(logging.INFO) parser = get_parser() args, unknown_args = parser.parse_known_args() os.makedirs("artifacts", exist_ok=True) logging.basicConfig( - level=LOGLEVEL, + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(filename)s::L%(lineno)d : %(message)s", handlers=[ logging.FileHandler("artifacts/training.log"), From f7cf3e985765983fdf9a3e19c62b78d0ff0a659b Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 26 Sep 2022 15:05:02 -0700 Subject: [PATCH 52/55] add in_memory option for training CycleGAN --- .../fv3fit/pytorch/cyclegan/train_cyclegan.py | 55 +++++++++++++++---- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py index 09e1c10d4b..9e427ac565 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/train_cyclegan.py @@ -67,12 +67,16 @@ class CycleGANTrainingConfig: validation_batch_size: number of samples to use per batch for validation, does not affect training result but allows the use of out-of-sample validation data + in_memory: if True, load the entire dataset into memory as pytorch tensors + before training. Batches will be statically defined but will be shuffled + between epochs. 
""" n_epoch: int = 20 shuffle_buffer_size: int = 10 samples_per_batch: int = 1 validation_batch_size: Optional[int] = None + in_memory: bool = False def fit_loop( self, @@ -93,12 +97,6 @@ def fit_loop( train_data = train_data.shuffle(buffer_size=self.shuffle_buffer_size) train_data = train_data.batch(self.samples_per_batch) train_data_numpy = tfds.as_numpy(train_data) - train_states = [] - for batch_state in train_data_numpy: - state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) - state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) - train_states.append((state_a, state_b)) - train_example_as_dataset = tfds.as_numpy(train_data.take(1).cache()) if validation_data is not None: if self.validation_batch_size is None: validation_batch_size = sequence_size(validation_data) @@ -106,12 +104,49 @@ def fit_loop( validation_batch_size = self.validation_batch_size validation_data = validation_data.batch(validation_batch_size) validation_data = tfds.as_numpy(validation_data) + if self.in_memory: + self._fit_loop_tensor(train_model, train_data_numpy, validation_data) + else: + self._fit_loop_dataset(train_model, train_data_numpy, validation_data) + + def _fit_loop_dataset( + self, + train_model: CycleGANTrainer, + train_data_numpy, + validation_data: Optional[tf.data.Dataset], + ): + for i in range(1, self.n_epoch + 1): + logger.info("starting epoch %d", i) + train_losses = [] + for batch_state in train_data_numpy: + state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) + state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) + train_losses.append(train_model.train_on_batch(state_a, state_b)) + train_loss = { + name: np.mean([data[name] for data in train_losses]) + for name in train_losses[0] + } + logger.info("train_loss: %s", train_loss) + + if validation_data is not None: + val_loss = train_model.evaluate_on_dataset(validation_data) + logger.info("val_loss %s", val_loss) + + def _fit_loop_tensor( + self, + train_model: CycleGANTrainer, + 
train_data_numpy: tf.data.Dataset, + validation_data: Optional[tf.data.Dataset], + ): + train_states = [] + batch_state: Tuple[np.ndarray, np.ndarray] + for batch_state in train_data_numpy: + state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) + state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) + train_states.append((state_a, state_b)) for i in range(1, self.n_epoch + 1): logger.info("starting epoch %d", i) train_losses = [] - # for batch_state in train_data_numpy: - # state_a = torch.as_tensor(batch_state[0]).float().to(DEVICE) - # state_b = torch.as_tensor(batch_state[1]).float().to(DEVICE) for state_a, state_b in train_states: train_losses.append(train_model.train_on_batch(state_a, state_b)) random.shuffle(train_states) @@ -124,8 +159,6 @@ def fit_loop( if validation_data is not None: val_loss = train_model.evaluate_on_dataset(validation_data) logger.info("val_loss %s", val_loss) - else: - train_model.evaluate_on_dataset(train_example_as_dataset) def apply_to_tuple_mapping(func): From f86e2dcc8a95afb0cda6691c2a91768c9af1f570 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 26 Sep 2022 15:05:29 -0700 Subject: [PATCH 53/55] continued attempts to get training working again on mps (not working) --- .../fv3fit/fv3fit/pytorch/cyclegan/modules.py | 40 ++++++++++--------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py index 0f21a182f7..1656019917 100644 --- a/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py +++ b/external/fv3fit/fv3fit/pytorch/cyclegan/modules.py @@ -195,13 +195,33 @@ def halo_convolution( return BreakOnOp(nn.Sequential(append, conv)) +def cpu_only(method): + """ + Decorator to mark a method as only being supported on the CPU. 
+ """ + + @functools.wraps(method) + def wrapper(self, *args, **kwargs): + original_device = args[0].device + args = [arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = { + k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items() + } + return method(self, *args, **kwargs).to(original_device) + + return wrapper + + class Crop(nn.Module): def __init__(self, n_halo): super(Crop, self).__init__() self.n_halo = n_halo + @cpu_only def forward(self, x): - return x[..., self.n_halo : -self.n_halo, self.n_halo : -self.n_halo] + return x[ + ..., self.n_halo : -self.n_halo, self.n_halo : -self.n_halo + ].contiguous() class BreakOnOp(nn.Module): @@ -222,23 +242,6 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return output -def cpu_only(method): - """ - Decorator to mark a method as only being supported on the CPU. - """ - - @functools.wraps(method) - def wrapper(self, *args, **kwargs): - original_device = args[0].device - args = [arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = { - k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in kwargs.items() - } - return method(self, *args, **kwargs).to(original_device) - - return wrapper - - class AppendHalos(nn.Module): """ @@ -255,6 +258,7 @@ def __init__(self, n_halo: int): def extra_repr(self) -> str: return super().extra_repr() + f"n_halo={self.n_halo}" + @cpu_only def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Args: From 5118281e68a7c82b794278d18ab7c0ddeec46cca Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Mon, 26 Sep 2022 15:06:00 -0700 Subject: [PATCH 54/55] updates to project directory, evaluation --- projects/cyclegan/Makefile | 2 +- projects/cyclegan/evaluate.py | 379 ++++++++++++++++++++----- projects/cyclegan/train-data.yaml | 4 +- projects/cyclegan/training.yaml | 5 +- projects/cyclegan/validation-data.yaml | 24 ++ 5 files changed, 343 insertions(+), 71 deletions(-) create mode 100644 
projects/cyclegan/validation-data.yaml diff --git a/projects/cyclegan/Makefile b/projects/cyclegan/Makefile index ba275a5de4..a23e17522f 100644 --- a/projects/cyclegan/Makefile +++ b/projects/cyclegan/Makefile @@ -1,7 +1,7 @@ train: - python3 -m fv3fit.train training.yaml train-data.yaml output + python3 -m fv3fit.train training.yaml train-data.yaml output --validation-data validation-data.yaml data: python3 download_data.py diff --git a/projects/cyclegan/evaluate.py b/projects/cyclegan/evaluate.py index 7257de6107..2dc12fbf47 100644 --- a/projects/cyclegan/evaluate.py +++ b/projects/cyclegan/evaluate.py @@ -8,47 +8,218 @@ from vcm.catalog import catalog import fv3viz import cartopy.crs as ccrs +from fv3net.diagnostics.prognostic_run.views import movies +from toolz import curry + +GRID = catalog["grid/c48"].read() + + +def plot_video(): + pass + + +def plot_once(ax): + i_time = random.randint(0, c48_real.time.size - 1) + fig, ax = plt.subplots( + 2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()} + ) + fv3viz.plot_cube(ds=c48_real, var_name="h500", ax=ax[0, 0]) + ax[0, 0].set_title("c48_real") + fv3viz.plot_cube(ds=c384_real, var_name="h500", ax=ax[1, 0]) + ax[1, 0].set_title("c384_real") + fv3viz.plot_cube(ds=c384_gen, var_name="h500", ax=ax[0, 1]) + ax[0, 1].set_title("c384_gen") + fv3viz.plot_cube(ds=c48_gen, var_name="h500", ax=ax[1, 1]) + ax[1, 1].set_title("c48_gen") + plt.tight_layout() + + +def plot(arg, vmin=4800, vmax=6000): + ds, filename = arg + fig, ax = plt.subplots( + 2, 2, figsize=(10, 5), subplot_kw={"projection": ccrs.Robinson()} + ) + c48_real = ds.isel(resolution=0, type=0) + c384_real = ds.isel(resolution=1, type=0) + c384_gen = ds.isel(resolution=1, type=1) + c48_gen = ds.isel(resolution=0, type=1) + fv3viz.plot_cube(ds=c48_real, var_name="h500", ax=ax[0, 0], vmin=vmin, vmax=vmax) + ax[0, 0].set_title("c48_real") + fv3viz.plot_cube(ds=c384_real, var_name="h500", ax=ax[1, 0], vmin=vmin, vmax=vmax) + ax[1, 
0].set_title("c384_real") + fv3viz.plot_cube(ds=c384_gen, var_name="h500", ax=ax[0, 1], vmin=vmin, vmax=vmax) + ax[0, 1].set_title("c384_gen") + fv3viz.plot_cube(ds=c48_gen, var_name="h500", ax=ax[1, 1], vmin=vmin, vmax=vmax) + ax[1, 1].set_title("c48_gen") + plt.tight_layout() + fig.savefig(filename, dpi=100) + plt.close(fig) + + +def plot_weather(arg, vmin=4800, vmax=6000, vmin_diff=-100, vmax_diff=100): + ds, filename = arg + fig, ax = plt.subplots( + 2, 3, figsize=(12, 6), subplot_kw={"projection": ccrs.Robinson()} + ) + c48_real = ds.isel(resolution=0, type=0) + c384_real = ds.isel(resolution=1, type=0) + c384_gen = ds.isel(resolution=1, type=1) + fv3viz.plot_cube(ds=c384_real, var_name="h500", ax=ax[0, 0], vmin=vmin, vmax=vmax) + ax[0, 0].set_title("c384_real") + fv3viz.plot_cube(ds=c48_real, var_name="h500", ax=ax[0, 1], vmin=vmin, vmax=vmax) + ax[0, 1].set_title("c48_real") + fv3viz.plot_cube(ds=c384_gen, var_name="h500", ax=ax[0, 2], vmin=vmin, vmax=vmax) + ax[0, 2].set_title("c384_gen") + fv3viz.plot_cube( + ds=GRID.merge(c48_real - c384_real, compat="override"), + var_name="h500", + ax=ax[1, 1], + vmin=vmin_diff, + vmax=vmax_diff, + ) + ax[1, 1].set_title("c48_real - c384_real") + fv3viz.plot_cube( + ds=GRID.merge(c384_gen - c384_real, compat="override"), + var_name="h500", + ax=ax[1, 2], + vmin=vmin_diff, + vmax=vmax_diff, + ) + ax[1, 2].set_title("c384_gen - c384_real") + plt.tight_layout() + fig.savefig(filename, dpi=100) + plt.close(fig) if __name__ == "__main__": random.seed(0) - grid = catalog["grid/c48"].read() - cyclegan: fv3fit.pytorch.CycleGAN = fv3fit.load("output_good").to(DEVICE) - c48_real: xr.Dataset = ( - xr.open_zarr("c48_baseline.zarr") - .rename({"grid_xt": "x", "grid_yt": "y"}) - .isel(time=slice(-2905 * 2, None, 2)) - ).load() + cyclegan: fv3fit.pytorch.CycleGAN = fv3fit.load("output").to(DEVICE) c384_real: xr.Dataset = ( - xr.open_zarr("c384_baseline.zarr") - .rename({"grid_xt": "x", "grid_yt": "y"}) - .isel(time=slice(-2905 * 2, 
None, 2)) + xr.open_zarr("c384_baseline.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) + # .isel(time=slice(2904, None)) ).load() + c48_real: xr.Dataset = ( + xr.open_zarr("c48_baseline.zarr").rename({"grid_xt": "x", "grid_yt": "y"}) + # .isel(time=slice(0, len(c384_real.time))) + ).load() + i_start = 2904 + c384_gen: xr.Dataset = cyclegan.predict(c48_real) c48_gen: xr.Dataset = cyclegan.predict(c384_real, reverse=True) - # for _ in range(3): - # i_time = random.randint(0, c48_real.time.size - 1) - # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) - # fv3viz.plot_cube(ds=c48_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 0]) - # ax[0, 0].set_title("c48_real") - # fv3viz.plot_cube(ds=c384_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 0]) - # ax[1, 0].set_title("c384_real") - # fv3viz.plot_cube(ds=c384_gen.isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 1]) - # ax[0, 1].set_title("c384_gen") - # fv3viz.plot_cube(ds=c48_gen.isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 1]) - # ax[1, 1].set_title("c48_gen") - # plt.tight_layout() - # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) - # fv3viz.plot_cube(ds=c48_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 0]) - # ax[0, 0].set_title("c48_real") - # fv3viz.plot_cube(ds=c384_real.isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 0]) - # ax[1, 0].set_title("c384_real") - # fv3viz.plot_cube(ds=(c384_gen - c48_real).isel(time=i_time).merge(grid), var_name="h500", ax=ax[0, 1]) - # ax[0, 1].set_title("c384_gen") - # fv3viz.plot_cube(ds=(c48_gen - c384_real).isel(time=i_time).merge(grid), var_name="h500", ax=ax[1, 1]) - # ax[1, 1].set_title("c48_gen") - # plt.tight_layout() + # ds = xr.concat( + # [ + # xr.concat([c48_real.drop("time"), c384_real.drop("time")], dim="resolution"), + # xr.concat([c48_gen, c384_gen], dim="resolution") + # ], dim="type" + # ).merge(GRID) + + # 
spec = movies.MovieSpec(name="h500", plotting_function=plot, required_variables=["h500"]) + # movies._create_movie(spec, ds.isel(time=range(0, 4*30)), output=".", n_jobs=8) + + # spec = movies.MovieSpec( + # name="h500_weather", + # plotting_function=plot_weather, + # required_variables=["h500"] + # ) + # movies._create_movie(spec, ds.isel(time=range(0, 4*7)), output=".", n_jobs=8) + + # fig, ax = plt.subplots( + # 1, 1, figsize=(5, 3) + # ) + # stderr_baseline = ( + # c48_real["h500"].isel(time=range(0, 4*7)) - c384_real["h500"].isel(time=range(0, 4*7)) + # ).std(dim=["x", "y", "tile"]) + # stderr_gen = ( + # c384_gen["h500"].isel(time=range(0, 4*7)) - c384_real["h500"].isel(time=range(0, 4*7)) + # ).std(dim=["x", "y", "tile"]) + # stderr_baseline.plot(ax=ax, label="baseline") + # stderr_gen.plot(ax=ax, label="generated") + # ax.legend(loc="upper left") + # ax.set_ylabel("h500 standard error vs c384_real") + # plt.tight_layout() + # fig.savefig("h500_weather_stderr.png", dpi=100) + + # fig, ax = plt.subplots( + # 1, 1, figsize=(5, 3) + # ) + # bias_baseline = ( + # c48_real["h500"].isel(time=range(0, 4*7)) - c384_real["h500"].isel(time=range(0, 4*7)) + # ).mean(dim=["x", "y", "tile"]) + # bias_gen = ( + # c384_gen["h500"].isel(time=range(0, 4*7)) - c384_real["h500"].isel(time=range(0, 4*7)) + # ).mean(dim=["x", "y", "tile"]) + # bias_baseline.plot(ax=ax, label="baseline") + # bias_gen.plot(ax=ax, label="generated") + # ax.legend(loc="upper left") + # ax.set_ylabel("h500 bias vs c384_real") + # plt.tight_layout() + # fig.savefig("h500_weather_bias.png", dpi=100) + + fig, ax = plt.subplots(1, 1, figsize=(5, 3)) + plt.hist( + c48_real.h500.values.flatten(), + bins=100, + alpha=0.5, + label="c48_real", + histtype="step", + density=True, + ) + plt.hist( + c384_real.h500.values.flatten(), + bins=100, + alpha=0.5, + label="c384_real", + histtype="step", + density=True, + ) + plt.hist( + c384_gen.h500.values.flatten(), + bins=100, + alpha=0.5, + label="c384_gen", + 
histtype="step", + density=True, + ) + plt.yscale("log") + # plt.hist(c48_gen.h500.values.flatten(), bins=100, alpha=0.5, label="c48_gen") + plt.legend(loc="upper left") + plt.xlabel("h500 (Pa)") + plt.ylabel("probability density") + plt.tight_layout() + fig.savefig("h500_histogram_log.png", dpi=100) + + fig, ax = plt.subplots(1, 1, figsize=(5, 3)) + plt.hist( + c48_real.h500.values.flatten(), + bins=100, + alpha=0.5, + label="c48_real", + histtype="step", + density=True, + ) + plt.hist( + c384_real.h500.values.flatten(), + bins=100, + alpha=0.5, + label="c384_real", + histtype="step", + density=True, + ) + plt.hist( + c384_gen.h500.values.flatten(), + bins=100, + alpha=0.5, + label="c384_gen", + histtype="step", + density=True, + ) + # plt.hist(c48_gen.h500.values.flatten(), bins=100, alpha=0.5, label="c48_gen") + plt.legend(loc="upper left") + plt.xlabel("h500 (Pa)") + plt.ylabel("probability density") + plt.tight_layout() + fig.savefig("h500_histogram.png", dpi=100) c48_real_mean = c48_real.mean("time") c48_gen_mean = c48_gen.mean("time") @@ -57,46 +228,68 @@ fig, ax = plt.subplots( 2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()} ) - fv3viz.plot_cube(ds=c48_real_mean.merge(grid), var_name="h500", ax=ax[0, 0]) + fv3viz.plot_cube( + ds=c48_real_mean.merge(GRID), var_name="h500", ax=ax[0, 0], vmin=5000, vmax=5900 + ) ax[0, 0].set_title("c48_real") - fv3viz.plot_cube(ds=c384_real_mean.merge(grid), var_name="h500", ax=ax[1, 0]) + fv3viz.plot_cube( + ds=c384_real_mean.merge(GRID), + var_name="h500", + ax=ax[1, 0], + vmin=5000, + vmax=5900, + ) ax[1, 0].set_title("c384_real") - fv3viz.plot_cube(ds=c384_gen_mean.merge(grid), var_name="h500", ax=ax[0, 1]) + fv3viz.plot_cube( + ds=c384_gen_mean.merge(GRID), var_name="h500", ax=ax[0, 1], vmin=5000, vmax=5900 + ) ax[0, 1].set_title("c384_gen") - fv3viz.plot_cube(ds=c48_gen_mean.merge(grid), var_name="h500", ax=ax[1, 1]) + fv3viz.plot_cube( + ds=c48_gen_mean.merge(GRID), var_name="h500", ax=ax[1, 1], 
vmin=5000, vmax=5900 + ) ax[1, 1].set_title("c48_gen") plt.tight_layout() + fig.savefig("h500_mean.png", dpi=100) fig, ax = plt.subplots( - 2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()} + 1, 3, figsize=(12, 3), subplot_kw={"projection": ccrs.Robinson()} ) - fv3viz.plot_cube(ds=c48_real_mean.merge(grid), var_name="h500", ax=ax[0, 0]) - ax[0, 0].set_title("c48_real") fv3viz.plot_cube( - ds=(c384_real_mean - c48_real_mean).merge(grid), - var_name="h500", - ax=ax[1, 0], - vmin=-70, - vmax=70, + ds=c384_real_mean.merge(GRID), var_name="h500", ax=ax[0], vmin=4800, vmax=6000, ) - ax[1, 0].set_title("c384_real") + ax[0].set_title("c384_real") fv3viz.plot_cube( - ds=(c384_gen_mean - c48_real_mean).merge(grid), + ds=(c48_real_mean - c384_real_mean).merge(GRID), var_name="h500", - ax=ax[0, 1], - vmin=-70, - vmax=70, + ax=ax[1], + vmin=-100, + vmax=100, ) - ax[0, 1].set_title("c384_gen") + var = ( + (c48_real_mean - c384_real_mean).var(dim=["x", "y", "tile"]).h500.values.item() + ) + mean = ( + (c48_real_mean - c384_real_mean).mean(dim=["x", "y", "tile"]).h500.values.item() + ) + ax[1].set_title("c48_real - c384_real\nvar: {:.2f}\nmean: {:.2f}".format(var, mean)) fv3viz.plot_cube( - ds=(c48_gen_mean - c48_real_mean).merge(grid), + ds=(c384_gen_mean - c384_real_mean).merge(GRID), var_name="h500", - ax=ax[1, 1], - vmin=-70, - vmax=70, + ax=ax[2], + vmin=-100, + vmax=100, + ) + var = ( + (c384_gen_mean - c384_real_mean).var(dim=["x", "y", "tile"]).h500.values.item() + ) + mean = ( + (c384_gen_mean - c384_real_mean).mean(dim=["x", "y", "tile"]).h500.values.item() + ) + ax[2].set_title( + "c384_gen - c384_real\nvar = {:.2f}\nmean = {:.2f}".format(var, mean) ) - ax[1, 1].set_title("c48_gen") plt.tight_layout() + fig.savefig("h500_mean_diff.png", dpi=100) mse = (c384_real_mean - c384_gen_mean).var() var = (c384_real_mean).var() @@ -110,30 +303,82 @@ print(var) print(1.0 - (mse / var)) - # c48_real_std = c48_real.std("time") - # c48_gen_std = c48_gen.std("time") - 
# c384_real_std = c384_real.std("time") - # c384_gen_std = c384_gen.std("time") + c48_real_std = c48_real.std("time").rename({"h500": "h500_std"}) + c48_gen_std = c48_gen.std("time").rename({"h500": "h500_std"}) + c384_real_std = c384_real.std("time").rename({"h500": "h500_std"}) + c384_gen_std = c384_gen.std("time").rename({"h500": "h500_std"}) + + fig, ax = plt.subplots( + 1, 3, figsize=(12, 3), subplot_kw={"projection": ccrs.Robinson()} + ) + fv3viz.plot_cube( + ds=c384_real_std.merge(GRID), var_name="h500_std", ax=ax[0], + ) + ax[0].set_title("c384_real") + fv3viz.plot_cube( + ds=(c48_real_std - c384_real_std).merge(GRID), + var_name="h500_std", + ax=ax[1], + vmin=-60, + vmax=60, + ) + var = ( + (c48_real_std - c384_real_std) + .var(dim=["x", "y", "tile"]) + .h500_std.values.item() + ) + mean = ( + (c48_real_std - c384_real_std) + .mean(dim=["x", "y", "tile"]) + .h500_std.values.item() + ) + ax[1].set_title( + "c48_real - c384_real\nvar = {:.2f}\nmean = {:.2f}".format(var, mean) + ) + fv3viz.plot_cube( + ds=(c384_gen_std - c384_real_std).merge(GRID), + var_name="h500_std", + ax=ax[2], + vmin=-60, + vmax=60, + ) + var = ( + (c384_gen_std - c384_real_std) + .var(dim=["x", "y", "tile"]) + .h500_std.values.item() + ) + mean = ( + (c384_gen_std - c384_real_std) + .mean(dim=["x", "y", "tile"]) + .h500_std.values.item() + ) + ax[2].set_title( + "c384_gen - c384_real\nvar = {:.2f}\nmean = {:.2f}".format(var, mean) + ) + plt.tight_layout() + fig.savefig("h500_std_diff.png", dpi=100) + # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) - # fv3viz.plot_cube(ds=c48_real_std.merge(grid), var_name="h500", ax=ax[0, 0]) + # fv3viz.plot_cube(ds=c48_real_std.merge(GRID), var_name="h500", ax=ax[0, 0]) # ax[0, 0].set_title("c48_real") - # fv3viz.plot_cube(ds=c384_real_std.merge(grid), var_name="h500", ax=ax[1, 0]) + # fv3viz.plot_cube(ds=c384_real_std.merge(GRID), var_name="h500", ax=ax[1, 0]) # ax[1, 0].set_title("c384_real") - # 
fv3viz.plot_cube(ds=c384_gen_std.merge(grid), var_name="h500", ax=ax[0, 1]) + # fv3viz.plot_cube(ds=c384_gen_std.merge(GRID), var_name="h500", ax=ax[0, 1]) # ax[0, 1].set_title("c384_gen") - # fv3viz.plot_cube(ds=c48_gen_std.merge(grid), var_name="h500", ax=ax[1, 1]) + # fv3viz.plot_cube(ds=c48_gen_std.merge(GRID), var_name="h500", ax=ax[1, 1]) # ax[1, 1].set_title("c48_gen") # plt.tight_layout() # fig, ax = plt.subplots(2, 2, figsize=(10, 6), subplot_kw={"projection": ccrs.Robinson()}) - # fv3viz.plot_cube(ds=c48_real_std.merge(grid), var_name="h500", ax=ax[0, 0]) + # fv3viz.plot_cube(ds=c48_real_std.merge(GRID), var_name="h500", ax=ax[0, 0]) # ax[0, 0].set_title("c48_real") - # fv3viz.plot_cube(ds=(c384_real_std - c48_real_std).merge(grid), var_name="h500", ax=ax[1, 0], vmin=-50, vmax=50) + # fv3viz.plot_cube(ds=(c384_real_std - c48_real_std).merge(GRID), var_name="h500", ax=ax[1, 0], vmin=-50, vmax=50) # ax[1, 0].set_title("c384_real") - # fv3viz.plot_cube(ds=(c384_gen_std - c48_real_std).merge(grid), var_name="h500", ax=ax[0, 1], vmin=-50, vmax=50) + # fv3viz.plot_cube(ds=(c384_gen_std - c48_real_std).merge(GRID), var_name="h500", ax=ax[0, 1], vmin=-50, vmax=50) # ax[0, 1].set_title("c384_gen") - # fv3viz.plot_cube(ds=(c48_gen_std - c48_real_std).merge(grid), var_name="h500", ax=ax[1, 1], vmin=-50, vmax=50) + # fv3viz.plot_cube(ds=(c48_gen_std - c48_real_std).merge(GRID), var_name="h500", ax=ax[1, 1], vmin=-50, vmax=50) # ax[1, 1].set_title("c48_gen") # plt.tight_layout() + # fig.savefig("h500_std.png", dpi=100) plt.show() diff --git a/projects/cyclegan/train-data.yaml b/projects/cyclegan/train-data.yaml index 871d557d33..9e2f481d9f 100644 --- a/projects/cyclegan/train-data.yaml +++ b/projects/cyclegan/train-data.yaml @@ -10,7 +10,7 @@ domain_configs: window_size: 1 default_variable_config: times: window - n_windows: 500 # 5808 + n_windows: 20 # 5808 - data_path: c384_baseline.zarr unstacked_dims: - time @@ -21,4 +21,4 @@ domain_configs: window_size: 1 
default_variable_config: times: window - n_windows: 500 # 5808 + n_windows: 20 # 5808 diff --git a/projects/cyclegan/training.yaml b/projects/cyclegan/training.yaml index 41ed43b1fa..99a9dc17e2 100644 --- a/projects/cyclegan/training.yaml +++ b/projects/cyclegan/training.yaml @@ -6,6 +6,7 @@ hyperparameters: - h500 normalization_fit_samples: 50_000 network: + convolution_type: halo_conv2d generator_optimizer: name: Adam kwargs: @@ -18,6 +19,7 @@ hyperparameters: n_convolutions: 2 n_resnet: 6 kernel_size: 3 + strided_kernel_size: 4 max_filters: 256 discriminator: n_convolutions: 3 @@ -34,7 +36,8 @@ hyperparameters: generator_weight: 1.0 discriminator_weight: 1.0 training: - n_epoch: 10 + n_epoch: 20 + in_memory: true shuffle_buffer_size: 1000 samples_per_batch: 1 validation_batch_size: 100 diff --git a/projects/cyclegan/validation-data.yaml b/projects/cyclegan/validation-data.yaml new file mode 100644 index 0000000000..d9190396d0 --- /dev/null +++ b/projects/cyclegan/validation-data.yaml @@ -0,0 +1,24 @@ +batch_size: 500 +domain_configs: + - data_path: c48_baseline.zarr + unstacked_dims: + - time + - tile + - grid_xt + - grid_yt + - z + window_size: 1 + default_variable_config: + times: window + n_windows: 50 # 5808 + - data_path: c384_baseline.zarr + unstacked_dims: + - time + - tile + - grid_xt + - grid_yt + - z + window_size: 1 + default_variable_config: + times: window + n_windows: 50 # 5808 From 682c353668f5460b9f9e7121af30f9f3662c4983 Mon Sep 17 00:00:00 2001 From: Jeremy McGibbon Date: Tue, 27 Sep 2022 12:14:38 -0700 Subject: [PATCH 55/55] use time-interpolated C384 data to match C48 time grid --- projects/cyclegan/download_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/cyclegan/download_data.py b/projects/cyclegan/download_data.py index e691788f77..65ed6f48ed 100644 --- a/projects/cyclegan/download_data.py +++ b/projects/cyclegan/download_data.py @@ -6,7 +6,7 @@ ) c48 = c48.drop_vars([name for name in c48.data_vars if name 
!= "h500"]) c384 = xr.open_zarr( - "gs://vcm-ml-raw-flexible-retention/2021-01-04-1-year-C384-FV3GFS-simulations/unperturbed/C384-to-C48-diagnostics/atmos_8xdaily_coarse.zarr" # noqa: E501 + "gs://vcm-ml-raw-flexible-retention/2021-01-04-1-year-C384-FV3GFS-simulations/unperturbed/C384-to-C48-diagnostics/atmos_8xdaily_coarse_interpolated.zarr" # noqa: E501 ) c384 = c384.drop_vars([name for name in c384.data_vars if name != "h500"]) c48.to_zarr("c48_baseline.zarr")