
Commit

Format
rubencart committed Oct 25, 2024
1 parent d5cfd03 commit 11ba42f
Showing 2 changed files with 32 additions and 54 deletions.
29 changes: 10 additions & 19 deletions sssl/config.py
@@ -22,9 +22,7 @@ class FinetuneConfig(Tap):
     lr: float = 1e-4
     backbone_lr: float = 1e-5
     optimizer: str = "adam"
-    lr_schedule: Literal[
-        "linear_with_warmup", "reduce_on_plateau"
-    ] = "reduce_on_plateau"  # or None
+    lr_schedule: Literal["linear_with_warmup", "reduce_on_plateau"] = "reduce_on_plateau"  # or None
     lr_schedule_warmup_epochs: int = 0
     lr_schedule_monitor: str = "val_maj_vote_f1_macro"
     lr_schedule_mode: str = "max"
@@ -51,17 +49,17 @@ class FinetuneConfig(Tap):
     percentage_of_training_data: Literal[100, 70, 50, 20, 5, 1] = 100
     binarize_ipc: bool = False
     n_steps_in_future: int = 0
-    temporally_separated: bool = True
+    temporally_separated: bool = False


 class PretrainConfig(Tap):
     wandb_project_name: str = "SSSL"

     lr: float = 1e-4
     optimizer: str = "adam"
-    lr_schedule: Literal[
-        "linear_with_warmup", "reduce_on_plateau"
-    ] = None  # 'reduce_on_plateau' # or None
+    lr_schedule: Literal["linear_with_warmup", "reduce_on_plateau"] = (
+        None  # 'reduce_on_plateau' # or None
+    )
     lr_schedule_warmup_epochs: int = 0
     lr_schedule_monitor: str = "val_loss"
     lr_schedule_mode: str = "min"
@@ -141,9 +139,7 @@ class Config(Tap):
     downstr_splits_path: str = "downstr_splits_incl_small.json"
     valtest_wo_neighbors: str = "to_exclude.json"
     fixed_random_order_path: str = "fixed_random_order.json"
-    ipc_scores_csv_path: str = (
-        "data/predicting_food_crises_data_somalia_from2013-05-01.csv"
-    )
+    ipc_scores_csv_path: str = "data/predicting_food_crises_data_somalia_from2013-05-01.csv"
     future_ipc_shp: List[str] = [
         "data/SO_202006/SO_202006_CS.shp",
         "data/SO_202010/SO_202010_CS.shp",
@@ -217,14 +213,9 @@ def process_args(self, process=True) -> None:
             self.process()

     def process(self):
-        self.feature_size = (
-            1000 if "resnet" in self.cnn_type else self.conv4_feature_size
-        )
-
-        if (
-            self.pretrain.loss_type == "sssl"
-            and self.pretrain.augmentations == "rel_reasoning"
-        ):
+        self.feature_size = 1000 if "resnet" in self.cnn_type else self.conv4_feature_size
+
+        if self.pretrain.loss_type == "sssl" and self.pretrain.augmentations == "rel_reasoning":
             logger.info("Disabling pin_memory for rel_reasoning augmentations...")
             self.pin_memory = False

@@ -238,7 +229,7 @@ def to_dict(self):
         dct.update({"tile2vec": self.tile2vec.as_dict()})
         dct.update({"pretrain": self.pretrain.as_dict()})
         return {
-            k: (v if not isinstance(v, List) else str(v))
+            k: v if not isinstance(v, List) else str(v)
             for (k, v) in dct.items()
             if not isinstance(v, MethodType)
         }
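For orientation, the config classes touched above are Tap classes (typed-argument-parser), so the collapsed one-line annotations parse exactly as before. Below is a minimal, hypothetical usage sketch; it is not part of this commit, reuses only a few field names from the diff, and assumes the `tap` package (with `Literal` support) is installed.

```python
# Hypothetical sketch only: a stripped-down stand-in for FinetuneConfig,
# not the project's real config class.
from typing import Literal

from tap import Tap


class TinyFinetuneConfig(Tap):
    lr: float = 1e-4
    lr_schedule: Literal["linear_with_warmup", "reduce_on_plateau"] = "reduce_on_plateau"
    temporally_separated: bool = False


if __name__ == "__main__":
    # e.g. python tiny_config.py --lr 3e-5 --lr_schedule linear_with_warmup
    cfg = TinyFinetuneConfig().parse_args()
    print(cfg.as_dict())  # Tap configs expose their fields as a plain dict
```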
57 changes: 22 additions & 35 deletions sssl/data/ipc.py
@@ -49,17 +49,15 @@ def __init__(self, samples: List[Sample], transform=None, to_pt=True):
         _, h, w = shapes.max(axis=0)

         # bs x ch x h x w
-        self.tiles = np.stack(
-            [
-                np.pad(
-                    a.array,
-                    ((0, 0), (0, h - a.array.shape[1]), (0, w - a.array.shape[2])),
-                    mode="constant",
-                    constant_values=0.0,
-                )
-                for a in tiles
-            ]
-        )
+        self.tiles = np.stack([
+            np.pad(
+                a.array,
+                ((0, 0), (0, h - a.array.shape[1]), (0, w - a.array.shape[2])),
+                mode="constant",
+                constant_values=0.0,
+            )
+            for a in tiles
+        ])
         self.tiles = np.expand_dims(self.tiles, axis=1)
         if to_pt:
             self.tiles = torch.from_numpy(self.tiles)
@@ -109,9 +107,7 @@ def __init__(
         with open(os.path.join(cfg.indices_dir, cfg.path_to_h5_virtual_idx)) as f:
             self.h5_idx = json.load(f)

-    def zones_4_split(
-        self, split: str, split_dict: Dict, temporally_separated: bool
-    ) -> List[str]:
+    def zones_4_split(self, split: str, split_dict: Dict, temporally_separated: bool) -> List[str]:
         if temporally_separated:
             return (
                 split_dict["ood_regions"]
@@ -134,9 +130,7 @@ class IPCScoreDataset(Landsat8Dataset):
 class IPCScoreDataset(Landsat8Dataset):
     def __init__(self, cfg: Config, split: str = "train"):
         super().__init__(cfg, split=split)
-        self.f: IPCLandsat8Files = IPCLandsat8Files(
-            cfg, split, cfg.finetune.temporally_separated
-        )
+        self.f: IPCLandsat8Files = IPCLandsat8Files(cfg, split, cfg.finetune.temporally_separated)
         # so workers initialize correct h5 dataset
         self.h5_split_name = "downstream"

@@ -152,30 +146,25 @@ def __init__(self, cfg: Config, split: str = "train"):
             self.df.fews_ipc = utils.binarize_ipcs(self.df.fews_ipc)
         all_ipcs = self.df.fews_ipc.unique()
         all_ipcs.sort()
-        self.f.all_ipcs = all_ipcs.astype(np.int32)
+        self.f.all_ipcs = all_ipcs.astype(int)
         self.f.ipc_dict = {ipc: i for (i, ipc) in enumerate(self.f.all_ipcs)}

         self.regions = self.f.regions
         self.zone_ids = [self.f.admin_dict[z] for z in self.regions]

         logger.info("Filtering downstream boxes for %s" % split)
-        self.boxes = [
-            b for z in tqdm(self.regions) for b in self.f.zone2box2p[z].keys()
-        ]
+        self.boxes = [b for z in tqdm(self.regions) for b in self.f.zone2box2p[z].keys()]
         fut = cfg.finetune.n_steps_in_future
         self.temp_sep = cfg.finetune.temporally_separated
         if self.temp_sep:
             date_idcs_dict = {
-                "train": list(
-                    range(min(3 - fut, 2), len(self.f.all_end_dates) - max(fut + 1, 2))
-                ),
+                "train": list(range(min(3 - fut, 2), len(self.f.all_end_dates) - max(fut + 1, 2))),
                 "val": [len(self.f.all_end_dates) - max(fut + 1, 2)],
                 "test": [len(self.f.all_end_dates) - max(fut, 1)],
                 "ood": [len(self.f.all_end_dates) - max(fut, 1)],
             }
             date_dict = {
-                sp: [self.f.all_end_dates[di] for di in idcs]
-                for sp, idcs in date_idcs_dict.items()
+                sp: [self.f.all_end_dates[di] for di in idcs] for sp, idcs in date_idcs_dict.items()
             }

         logger.info("Filtering downstream paths for %s" % split)
@@ -184,9 +173,7 @@ def __init__(self, cfg: Config, split: str = "train"):
             for z in tqdm(self.regions)
             for plist in self.f.zone2box2p[z].values()
             for p in (
-                utils.filter_paths_by_date(plist, date_dict[split])
-                if self.temp_sep
-                else plist
+                utils.filter_paths_by_date(plist, date_dict[split]) if self.temp_sep else plist
             )
         ]
         logger.info(
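To make the temporally separated split arithmetic in the hunks above concrete, here is a small worked example with an invented list of 10 end dates and `fut = 0`; the values are illustrative only and not taken from the dataset.

```python
# Illustrative only: mirrors the date_idcs_dict arithmetic above with made-up inputs.
all_end_dates = [f"2023-{m:02d}-01" for m in range(1, 11)]  # 10 invented dates
fut = 0  # n_steps_in_future

date_idcs_dict = {
    "train": list(range(min(3 - fut, 2), len(all_end_dates) - max(fut + 1, 2))),
    "val": [len(all_end_dates) - max(fut + 1, 2)],
    "test": [len(all_end_dates) - max(fut, 1)],
    "ood": [len(all_end_dates) - max(fut, 1)],
}
print(date_idcs_dict["train"])  # [2, 3, 4, 5, 6, 7]
print(date_idcs_dict["val"])    # [8]
print(date_idcs_dict["test"])   # [9] -> the last available end date
```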
@@ -226,7 +213,7 @@ def __init__(self, cfg: Config, split: str = "train"):
             )
         ]
         self.ipc_2_zd = defaultdict(list)
-        for (z, d, ipc) in self.zone_time_combos:
+        for z, d, ipc in self.zone_time_combos:
             self.ipc_2_zd[ipc].append((z, d))

         self.log_ipc_distributions()
@@ -235,7 +222,7 @@ def __init__(self, cfg: Config, split: str = "train"):
     def log_ipc_distributions(self):
         per_date = {}
         per_zone = {}
-        for (z, d, ipc) in self.zone_time_combos:
+        for z, d, ipc in self.zone_time_combos:
             per_date.setdefault(d, []).append(ipc)
             per_zone.setdefault(z, []).append(ipc)
         per_date = {
@@ -246,9 +233,9 @@ def log_ipc_distributions(self):
             self.f.all_admins[z]: np.bincount(ipcs, minlength=4) / len(ipcs)
             for z, ipcs in per_zone.items()
         }
-        overall = np.bincount(
-            [ipc for (z, d, ipc) in self.zone_time_combos], minlength=4
-        ) / len(self.zone_time_combos)
+        overall = np.bincount([ipc for (z, d, ipc) in self.zone_time_combos], minlength=4) / len(
+            self.zone_time_combos
+        )
         logger.info("IPC distribution overall: %s" % pprint.pformat(overall))
         logger.info("IPC distribution per date: \n%s" % pprint.pformat(per_date))
         logger.info("IPC distribution per zone: \n%s" % pprint.pformat(per_zone))
@@ -283,7 +270,7 @@ def __getitem__(self, item: int) -> Sample:

     def sample_from_ipc_class(self, ipc: int, shuffle=False) -> Sample:
         paths = []
-        for (z, d) in (
+        for z, d in (
             random.sample(self.ipc_2_zd[ipc], k=len(self.ipc_2_zd[ipc]))
             if shuffle
             else self.ipc_2_zd[ipc]
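Finally, a standalone sketch of the pad-and-stack pattern reformatted in the first hunk of this file: tiles with unequal height and width are zero-padded to the per-batch maximum before `np.stack`. The shapes below are invented for illustration and are not repository data.

```python
# Illustrative only: same padding pattern as above, with toy arrays instead of Sample tiles.
import numpy as np

tiles = [np.ones((3, 4, 5)), np.ones((3, 6, 2))]  # ch x h x w, unequal h and w
shapes = np.stack([t.shape for t in tiles])
_, h, w = shapes.max(axis=0)  # per-batch maximum height and width

batch = np.stack([
    np.pad(
        t,
        ((0, 0), (0, h - t.shape[1]), (0, w - t.shape[2])),
        mode="constant",
        constant_values=0.0,
    )
    for t in tiles
])
print(batch.shape)  # (2, 3, 6, 5)
```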
