Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WiP] Sen4Map inheritance class change to geobench-compatible datamodule and dataset classes #227

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 104 additions & 28 deletions terratorch/datamodules/sen4map.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
from torchvision.transforms.v2 import InterpolationMode
import pickle
import h5py
import torch
from torchgeo.datamodules import GeoDataModule, NonGeoDataModule, BaseDataModule
import kornia.augmentation as K # noqa: N812
from torch.utils.data import DataLoader

from terratorch.datamodules.geobench_data_module import GeobenchDataModule
from terratorch.datasets import HLSBands
from terratorch.datasets import Sen4MapDatasetMonthlyComposites
from kornia.augmentation.container import VideoSequential, ImageSequential, AugmentationSequential


class Sen4MapLucasDataModule(pl.LightningDataModule):
class Sen4MapLucasDataModule(NonGeoDataModule):
def __init__(
self,
batch_size,
Expand All @@ -21,9 +26,19 @@ def __init__(
test_hdf5_keys_path = None,
val_hdf5_path = None,
val_hdf5_keys_path = None,
bands: list[HLSBands | int] = None,
**kwargs
):


super().__init__(
Sen4MapDatasetMonthlyComposites,
batch_size=batch_size,
num_workers=num_workers,
**kwargs,
)
self.aug = AugmentationSequential(K.RandomEqualize3D(), data_keys=None)

#self.aug = AugmentationSequential(K.Normalize(MEANS, STDS), data_keys = None)
self.prepare_data_per_node = False
self._log_hyperparams = None
self.allow_zero_length_dataloader_with_multiple_devices = False
Expand All @@ -40,6 +55,11 @@ def __init__(
self.test_hdf5_keys_path = test_hdf5_keys_path
self.val_hdf5_keys_path = val_hdf5_keys_path

self.bands = bands
#bands = kwargs.get("bands", Sen4MapDatasetMonthlyComposites.all_band_names)
#self.means = torch.tensor([means[b] for b in bands])
#self.stds = torch.tensor([stds[b] for b in bands])

if train_hdf5_path and not train_hdf5_keys_path: print(f"Train dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
if test_hdf5_path and not test_hdf5_keys_path: print(f"Test dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
if val_hdf5_path and not val_hdf5_keys_path: print(f"Val dataset path provided but not the path to the dataset keys. Generating the keys might take a few minutes.")
Expand Down Expand Up @@ -103,53 +123,109 @@ def _load_hdf5_keys_from_path(self, path, fraction=1.0):
return keys[:int(fraction*len(keys))]

def setup(self, stage: str):
if stage == "fit":
if stage in ["fit"]:
train_keys = self._load_hdf5_keys_from_path(self.train_hdf5_keys_path, fraction=self.train_data_fraction)
val_keys = self._load_hdf5_keys_from_path(self.val_hdf5_keys_path, fraction=self.val_data_fraction)

if self.reduce_train_keys:
test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
train_keys = list(set(train_keys) - set(val_keys) - set(test_keys))

train_file = h5py.File(self.train_hdf5_path, 'r')
self.lucasS2_train = Sen4MapDatasetMonthlyComposites(
self.train_dataset = Sen4MapDatasetMonthlyComposites(
train_file,
h5data_keys = train_keys,
resize = self.resize,
resize_to = self.resize_to,
resize_interpolation = self.resize_interpolation,
resize_antialiasing = self.resize_antialiasing,
save_keys_path = self.train_hdf5_keys_save_path,
h5data_keys=train_keys,
input_bands = self.bands,
dataset_bands = Sen4MapDatasetMonthlyComposites.all_band_names,
resize=self.resize,
resize_to=self.resize_to,
resize_interpolation=self.resize_interpolation,
resize_antialiasing=self.resize_antialiasing,
save_keys_path=self.train_hdf5_keys_save_path,
**self.kwargs
)

if stage in ["fit", "validate"]:
val_keys = self._load_hdf5_keys_from_path(self.val_hdf5_keys_path, fraction=self.val_data_fraction)
val_file = h5py.File(self.val_hdf5_path, 'r')
self.lucasS2_val = Sen4MapDatasetMonthlyComposites(
self.val_dataset = Sen4MapDatasetMonthlyComposites(
val_file,
h5data_keys=val_keys,
resize = self.resize,
resize_to = self.resize_to,
resize_interpolation = self.resize_interpolation,
resize_antialiasing = self.resize_antialiasing,
save_keys_path = self.val_hdf5_keys_save_path,
h5data_keys=val_keys,
input_bands = self.bands,
dataset_bands = Sen4MapDatasetMonthlyComposites.all_band_names,
resize=self.resize,
resize_to=self.resize_to,
resize_interpolation=self.resize_interpolation,
resize_antialiasing=self.resize_antialiasing,
save_keys_path=self.val_hdf5_keys_save_path,
**self.kwargs
)

if stage == "test":
test_file = h5py.File(self.test_hdf5_path, 'r')
test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
self.lucasS2_test = Sen4MapDatasetMonthlyComposites(
test_file = h5py.File(self.test_hdf5_path, 'r')
self.test_dataset = Sen4MapDatasetMonthlyComposites(
test_file,
h5data_keys=test_keys,
resize = self.resize,
resize_to = self.resize_to,
resize_interpolation = self.resize_interpolation,
resize_antialiasing = self.resize_antialiasing,
save_keys_path = self.test_hdf5_keys_save_path,
input_bands = self.bands,
dataset_bands = Sen4MapDatasetMonthlyComposites.all_band_names,
resize=self.resize,
resize_to=self.resize_to,
resize_interpolation=self.resize_interpolation,
resize_antialiasing=self.resize_antialiasing,
save_keys_path=self.test_hdf5_keys_save_path,
**self.kwargs
)


# def setup(self, stage: str):
# if stage == "fit":
# train_keys = self._load_hdf5_keys_from_path(self.train_hdf5_keys_path, fraction=self.train_data_fraction)
# val_keys = self._load_hdf5_keys_from_path(self.val_hdf5_keys_path, fraction=self.val_data_fraction)
# if self.reduce_train_keys:
# test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
# train_keys = list(set(train_keys) - set(val_keys) - set(test_keys))
# train_file = h5py.File(self.train_hdf5_path, 'r')
# self.lucasS2_train = Sen4MapDatasetMonthlyComposites(
# train_file,
# h5data_keys = train_keys,
# resize = self.resize,
# resize_to = self.resize_to,
# resize_interpolation = self.resize_interpolation,
# resize_antialiasing = self.resize_antialiasing,
# save_keys_path = self.train_hdf5_keys_save_path,
# **self.kwargs
# )
# val_file = h5py.File(self.val_hdf5_path, 'r')
# self.lucasS2_val = Sen4MapDatasetMonthlyComposites(
# val_file,
# h5data_keys=val_keys,
# resize = self.resize,
# resize_to = self.resize_to,
# resize_interpolation = self.resize_interpolation,
# resize_antialiasing = self.resize_antialiasing,
# save_keys_path = self.val_hdf5_keys_save_path,
# **self.kwargs
# )
# if stage == "test":
# test_file = h5py.File(self.test_hdf5_path, 'r')
# test_keys = self._load_hdf5_keys_from_path(self.test_hdf5_keys_path, fraction=self.test_data_fraction)
# self.lucasS2_test = Sen4MapDatasetMonthlyComposites(
# test_file,
# h5data_keys=test_keys,
# resize = self.resize,
# resize_to = self.resize_to,
# resize_interpolation = self.resize_interpolation,
# resize_antialiasing = self.resize_antialiasing,
# save_keys_path = self.test_hdf5_keys_save_path,
# **self.kwargs
# )

def train_dataloader(self):
return DataLoader(self.lucasS2_train, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.train_shuffle)
return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.train_shuffle)

def val_dataloader(self):
return DataLoader(self.lucasS2_val, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.val_shuffle)
return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.val_shuffle)

def test_dataloader(self):
return DataLoader(self.lucasS2_test, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.test_shuffle)
return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers, prefetch_factor=self.prefetch_factor, shuffle=self.test_shuffle)
27 changes: 26 additions & 1 deletion terratorch/datasets/sen4map.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,40 @@
import numpy as np
import h5py
from torchgeo.datasets import NonGeoDataset

import torch
import torch.nn.functional as F
from kornia.augmentation.container import VideoSequential, ImageSequential, AugmentationSequential
from torch.utils.data import Dataset
from terratorch.datasets.utils import HLSBands

from torchvision.transforms.v2.functional import resize
from torchvision.transforms.v2 import InterpolationMode



import pickle


class Sen4MapDatasetMonthlyComposites(Dataset):


class Sen4MapDatasetMonthlyComposites(NonGeoDataset):
all_band_names = (
"BLUE",
"GREEN",
"RED",
"RED_EDGE_1",
"RED_EDGE_2",
"RED_EDGE_3",
"NIR_BROAD",
"NIR_NARROW",
"SWIR_1",
"SWIR_2",
)

rgb_bands = ("RED", "GREEN", "BLUE")

BAND_SETS = {"all": all_band_names, "rgb": rgb_bands}
# This dictionary maps the LUCAS classes to Land-cover classes.
land_cover_classification_map={'A10':0, 'A11':0, 'A12':0, 'A13':0,
'A20':0, 'A21':0, 'A30':0,
Expand Down Expand Up @@ -106,11 +128,14 @@ def __getitem__(self, index):
# we can call the dataset with an index, e.g. dataset[0]
im = self.h5data[self.h5data_keys[index]]
Image, Label = self.get_data(im)

Image = self.min_max_normalize(Image, [67.0, 122.0, 93.27, 158.5, 160.77, 174.27, 162.27, 149.0, 84.5, 66.27 ],
[2089.0, 2598.45, 3214.5, 3620.45, 4033.61, 4613.0, 4825.45, 4945.72, 5140.84, 4414.45])

Image = Image.clip(0,1)
Label = torch.LongTensor(Label)


if self.input_channels:
Image = Image[self.input_channels, ...]

Expand Down
Loading