-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmisc.py
101 lines (79 loc) · 2.64 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import time
import random
import logging
import torch
import numpy as np
import yaml
from easydict import EasyDict
from logging import Logger
from tqdm.auto import tqdm
class BlackHole:
    """Null object that silently absorbs attribute writes, lookups and calls.

    Any attribute that is not found through normal lookup resolves to the
    instance itself, and calling the instance returns the instance, so
    arbitrary chains such as ``hole.a.b(1).c`` are valid and do nothing.
    """

    def __getattr__(self, name):
        """Return the instance itself for any attribute not found normally."""
        return self

    def __call__(self, *args, **kwargs):
        """Accept any arguments, ignore them, and return the instance."""
        return self

    def __setattr__(self, name, value):
        """Discard the assignment; nothing is ever stored on the instance."""
        return None
def load_config(path):
    """Load a YAML file and wrap its contents in an ``EasyDict``.

    Args:
        path: Path of the YAML file to read.

    Returns:
        An ``EasyDict`` built from the object produced by ``yaml.safe_load``.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (which is what open() uses by default).
    with open(path, 'r', encoding='utf-8') as f:
        return EasyDict(yaml.safe_load(f))
def get_logger(name, log_dir=None):
    """Return a DEBUG-level logger that writes to the console and, optionally, a file.

    The function is safe to call repeatedly with the same ``name``: handlers
    are attached only if an equivalent one is not already present, so log
    records are not emitted multiple times.

    Args:
        name: Name passed to ``logging.getLogger``.
        log_dir: If not ``None``, records are also written to
            ``<log_dir>/log.txt``.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s')

    # logging.getLogger returns the same object for the same name, so only add
    # a console handler when none exists yet. FileHandler subclasses
    # StreamHandler and must therefore be excluded from this check.
    has_console_handler = any(
        isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
        for h in logger.handlers
    )
    if not has_console_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setLevel(logging.DEBUG)
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    if log_dir is not None:
        log_path = os.path.join(log_dir, 'log.txt')
        # FileHandler stores the absolute path in ``baseFilename``.
        target = os.path.abspath(log_path)
        has_file_handler = any(
            isinstance(h, logging.FileHandler)
            and getattr(h, 'baseFilename', None) == target
            for h in logger.handlers
        )
        if not has_file_handler:
            file_handler = logging.FileHandler(log_path)
            file_handler.setLevel(logging.DEBUG)
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    return logger
def get_new_log_dir(root='./logs', prefix='', tag=''):
    """Create and return a fresh, timestamp-named directory under ``root``.

    The directory name is the local time formatted as
    ``YYYY_MM_DD__HH_MM_SS``, preceded by ``<prefix>_`` when ``prefix`` is
    non-empty and followed by ``_<tag>`` when ``tag`` is non-empty.

    Raises:
        FileExistsError: If the directory already exists (``os.makedirs`` is
            called without ``exist_ok``).
    """
    # strftime formats the current local time when no time tuple is given.
    stamp = time.strftime('%Y_%m_%d__%H_%M_%S')
    head = prefix + '_' if prefix != '' else ''
    tail = '_' + tag if tag != '' else ''
    target = os.path.join(root, head + stamp + tail)
    os.makedirs(target)
    return target
def seed_all(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    # The three generators are independent, so the seeding order is irrelevant.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def log_hyperparams(writer, args):
    """Write the attributes of ``args`` as hyperparameters through ``writer``.

    Args:
        writer: Object exposing ``file_writer.add_summary``
            (presumably a TensorBoard ``SummaryWriter`` - confirm with callers).
        args: Object whose ``vars()`` mapping holds the hyperparameters.
    """
    from torch.utils.tensorboard.summary import hparams

    # Strings are passed through unchanged; every other value is logged as its repr.
    hparam_dict = {}
    for key, value in vars(args).items():
        hparam_dict[key] = value if isinstance(value, str) else repr(value)

    # hparams() yields three summaries; each one is written in order.
    first, second, third = hparams(hparam_dict, {})
    for summary in (first, second, third):
        writer.file_writer.add_summary(summary)
def int_tuple(argstr):
    """Parse a comma-separated string such as ``"1,2,3"`` into a tuple of ints."""
    return tuple(int(field) for field in argstr.split(','))
def str_tuple(argstr):
    """Split a comma-separated string into a tuple of its fields."""
    fields = argstr.split(',')
    return tuple(fields)
def unique(x, dim=None):
    """Return the sorted unique elements of ``x`` and the index of each one's first occurrence.

    https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810

    Args:
        x: Input tensor.
        dim: Dimension along which to look for unique slices. When ``None``
            the input is treated as flattened and the returned indices refer
            to positions in the flattened tensor.

    Returns:
        A tuple ``(values, first_index)`` where ``values`` is the result of
        ``torch.unique`` and ``first_index[i]`` is the position (along ``dim``,
        or in the flattened input) of the first occurrence of ``values[i]``.

    e.g.
    unique(tensor([
        [1, 2, 3],
        [1, 2, 4],
        [1, 2, 3],
        [1, 2, 5]
    ]), dim=0)
    => (tensor([[1, 2, 3],
                [1, 2, 4],
                [1, 2, 5]]),
        tensor([0, 1, 3]))
    """
    values, inverse = torch.unique(
        x, sorted=True, return_inverse=True, dim=dim)
    if dim is None:
        # Without ``dim`` the inverse mapping has the shape of ``x``; flatten
        # it so the 1-D scatter below also works for multi-dimensional input.
        inverse = inverse.reshape(-1)
        num_unique = values.numel()
    else:
        num_unique = values.size(dim)
    perm = torch.arange(inverse.size(0), dtype=inverse.dtype,
                        device=inverse.device)
    # Scatter positions in reverse order so that, for duplicated values, the
    # smallest (first) position is the last one written.
    # NOTE(review): this relies on scatter_ applying duplicate indices in
    # order; PyTorch documents that case as non-deterministic - confirm on CUDA.
    inverse, perm = inverse.flip([0]), perm.flip([0])
    first_index = inverse.new_empty(num_unique).scatter_(0, inverse, perm)
    return values, first_index