-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
183 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
# %% [code] | ||
import torch, numpy as np | ||
|
||
from time import time | ||
from tqdm import tqdm | ||
from torch.utils.data import Dataset, DataLoader | ||
|
||
# %% [markdown] | ||
# # Slicing | ||
|
||
|
||
# %% [code] | ||
def gather_slice(inputs, dim, index, length): | ||
""" | ||
Gather a slice (index : index + length) along dimension dim. | ||
""" | ||
shape = torch.ones(len(inputs.shape)) | ||
shape[dim] = -1 | ||
|
||
slice = torch.arange(0, length, device=inputs.device) | ||
slice = slice.view(tuple(int(s) for s in shape)) | ||
|
||
index = index.unsqueeze(dim) | ||
index = index + slice | ||
index = index.to(torch.int64) | ||
|
||
return torch.gather(inputs, dim, index) | ||
|
||
|
||
# %% [code] | ||
def centered_slice(centers, length, lower_bound=0, upper_bound=None): | ||
""" | ||
Create slices of equal length centered at the centers, | ||
contained within the lower and upper bounds. | ||
""" | ||
start = centers - length / 2 | ||
|
||
if lower_bound is not None: | ||
start = np.maximum(start, 0) | ||
|
||
end = start + length | ||
|
||
if upper_bound is not None: | ||
end = np.minimum(end, upper_bound) | ||
start = end - length | ||
|
||
return start, end | ||
|
||
# %% [markdown] | ||
# # Training | ||
|
||
|
||
# %% [code] | ||
def set_parameter_requires_grad(model, value): | ||
for param in model.parameters(): | ||
param.requires_grad = value | ||
|
||
|
||
# %% [code] | ||
def evaluate(metrics, output, labels): | ||
""" | ||
Default model evaluation. | ||
""" | ||
scores = {} | ||
|
||
with torch.no_grad(): | ||
output = output.detach() | ||
|
||
for key in metrics: | ||
try: | ||
scores[key] = metrics[key](output, labels).item() | ||
except: | ||
scores[key] = metrics[key](output.cpu(), labels.cpu()) | ||
|
||
return scores | ||
|
||
|
||
# %% [code] | ||
class Checkpoint(): | ||
def __init__(self, phase, metric): | ||
super().__init__() | ||
self.phase = phase | ||
self.metric = metric | ||
self.reset() | ||
|
||
def step(self, epoch, model): | ||
loss = epoch['score'][self.phase][self.metric][-1] | ||
|
||
if self.best is None or loss < self.best: | ||
self.state_dict = model.state_dict() | ||
self.best = loss | ||
|
||
print('Checkpoint.', flush=True) | ||
|
||
def load_state_dict(self, model): | ||
model.load_state_dict(self.state_dict) | ||
|
||
def reset(self): | ||
self.best = None | ||
self.state_dict = None | ||
|
||
def __repr__(self): | ||
return f'Best {self.phase} {self.metric}: {self.best:.4f}' | ||
|
||
|
||
# %% [code] | ||
def epochs(start, stop): | ||
since = time() | ||
|
||
epoch = { | ||
'start': start, | ||
'stop': stop, | ||
'curr': None, | ||
'score': {}, | ||
} | ||
|
||
for epoch['curr'] in range(start, stop): | ||
|
||
yield epoch | ||
|
||
for phase in epoch['score']: | ||
s = epoch['score'][phase] | ||
score_str = [f'{m}: {s[m][-1]:.4f}' for m in s] | ||
score_str = ', '.join(score_str).strip(', ') | ||
|
||
print(f'{phase} - {score_str}', flush=True) | ||
|
||
elapsed = time() - since | ||
m, s = elapsed // 60, elapsed % 60 | ||
|
||
print(flush=True) | ||
print(f'Training complete in {m:.0f}m {s:.0f}s') | ||
|
||
|
||
# %% [code] | ||
def steps(epoch, dataloader, with_tqdm=True): | ||
desc = f'Epoch {epoch["curr"] + 1}/{epoch["stop"]}' | ||
|
||
count_agg, score_agg, score = 0, {}, {} | ||
loader = tqdm(dataloader, desc=desc) if with_tqdm else dataloader | ||
|
||
for data in loader: | ||
try: | ||
count = data.shape[0] | ||
except: | ||
count = data[0].shape[0] | ||
count_agg += count | ||
|
||
yield score, data | ||
|
||
for p in score: | ||
if p not in score_agg: | ||
score_agg[p] = {} | ||
for m in score[p]: | ||
try: | ||
s = score[p][m].item() | ||
except: | ||
s = score[p][m] | ||
if m not in score_agg[p]: | ||
score_agg[p][m] = s | ||
else: | ||
score_agg[p][m] += s * count | ||
|
||
for p in score_agg: | ||
if p not in epoch['score']: | ||
epoch['score'][p] = {} | ||
for m in score_agg[p]: | ||
s = score_agg[p][m] / count_agg | ||
if m not in epoch['score'][p]: | ||
epoch['score'][p][m] = [s] | ||
else: | ||
epoch['score'][p][m].append(s) | ||
|
||
|
||
# %% [code] | ||
def folds(splits): | ||
for k, split in enumerate(splits): | ||
print(12 * '-') | ||
print(f' Fold {k:02d}:') | ||
print(12 * '-') | ||
|
||
yield k, split | ||
|