Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
rlmwang authored Dec 21, 2020
1 parent 28214b8 commit 6716bc9
Showing 1 changed file with 183 additions and 0 deletions.
183 changes: 183 additions & 0 deletions torch_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# %% [code]
import torch, numpy as np

from time import time
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# %% [markdown]
# # Slicing


# %% [code]
def gather_slice(inputs, dim, index, length):
"""
Gather a slice (index : index + length) along dimension dim.
"""
shape = torch.ones(len(inputs.shape))
shape[dim] = -1

slice = torch.arange(0, length, device=inputs.device)
slice = slice.view(tuple(int(s) for s in shape))

index = index.unsqueeze(dim)
index = index + slice
index = index.to(torch.int64)

return torch.gather(inputs, dim, index)


# %% [code]
def centered_slice(centers, length, lower_bound=0, upper_bound=None):
"""
Create slices of equal length centered at the centers,
contained within the lower and upper bounds.
"""
start = centers - length / 2

if lower_bound is not None:
start = np.maximum(start, 0)

end = start + length

if upper_bound is not None:
end = np.minimum(end, upper_bound)
start = end - length

return start, end

# %% [markdown]
# # Training


# %% [code]
def set_parameter_requires_grad(model, value):
for param in model.parameters():
param.requires_grad = value


# %% [code]
def evaluate(metrics, output, labels):
"""
Default model evaluation.
"""
scores = {}

with torch.no_grad():
output = output.detach()

for key in metrics:
try:
scores[key] = metrics[key](output, labels).item()
except:
scores[key] = metrics[key](output.cpu(), labels.cpu())

return scores


# %% [code]
class Checkpoint():
def __init__(self, phase, metric):
super().__init__()
self.phase = phase
self.metric = metric
self.reset()

def step(self, epoch, model):
loss = epoch['score'][self.phase][self.metric][-1]

if self.best is None or loss < self.best:
self.state_dict = model.state_dict()
self.best = loss

print('Checkpoint.', flush=True)

def load_state_dict(self, model):
model.load_state_dict(self.state_dict)

def reset(self):
self.best = None
self.state_dict = None

def __repr__(self):
return f'Best {self.phase} {self.metric}: {self.best:.4f}'


# %% [code]
def epochs(start, stop):
since = time()

epoch = {
'start': start,
'stop': stop,
'curr': None,
'score': {},
}

for epoch['curr'] in range(start, stop):

yield epoch

for phase in epoch['score']:
s = epoch['score'][phase]
score_str = [f'{m}: {s[m][-1]:.4f}' for m in s]
score_str = ', '.join(score_str).strip(', ')

print(f'{phase} - {score_str}', flush=True)

elapsed = time() - since
m, s = elapsed // 60, elapsed % 60

print(flush=True)
print(f'Training complete in {m:.0f}m {s:.0f}s')


# %% [code]
def steps(epoch, dataloader, with_tqdm=True):
desc = f'Epoch {epoch["curr"] + 1}/{epoch["stop"]}'

count_agg, score_agg, score = 0, {}, {}
loader = tqdm(dataloader, desc=desc) if with_tqdm else dataloader

for data in loader:
try:
count = data.shape[0]
except:
count = data[0].shape[0]
count_agg += count

yield score, data

for p in score:
if p not in score_agg:
score_agg[p] = {}
for m in score[p]:
try:
s = score[p][m].item()
except:
s = score[p][m]
if m not in score_agg[p]:
score_agg[p][m] = s
else:
score_agg[p][m] += s * count

for p in score_agg:
if p not in epoch['score']:
epoch['score'][p] = {}
for m in score_agg[p]:
s = score_agg[p][m] / count_agg
if m not in epoch['score'][p]:
epoch['score'][p][m] = [s]
else:
epoch['score'][p][m].append(s)


# %% [code]
def folds(splits):
for k, split in enumerate(splits):
print(12 * '-')
print(f' Fold {k:02d}:')
print(12 * '-')

yield k, split

0 comments on commit 6716bc9

Please sign in to comment.