
Commit

Updating xnas implementation to incorporate per example gradients. In progress.
Debadeepta Dey committed Jun 9, 2020
1 parent 110e545 commit 2727507
Showing 3 changed files with 36 additions and 17 deletions.
32 changes: 23 additions & 9 deletions archai/algos/xnas/xnas_arch_trainer.py
@@ -47,7 +47,8 @@ def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
         lossfn = ml_utils.get_lossfn(self._conf_w_lossfn).to(self.get_device())
 
         conf = get_conf()
-        num_val_examples = len(val_dl) * conf['nas']['search']['loader']['train_batch']
+        self._train_batch = conf['nas']['search']['loader']['train_batch']
+        num_val_examples = len(val_dl) * self._train_batch
         num_cells = conf['nas']['search']['model_desc']['n_cells']
         num_reduction_cells = conf['nas']['search']['model_desc']['n_reductions']
         num_normal_cells = num_cells - num_reduction_cells
@@ -58,13 +59,13 @@ def pre_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
         assert num_normal_cells > 0
         assert num_primitives > 0
 
-        normal_cell_effective_t = num_val_examples * self._epochs * num_normal_cells
-        reduction_cell_effective_t = num_val_examples * self._epochs * num_reduction_cells
+        self._normal_cell_effective_t = num_val_examples * self._epochs * num_normal_cells
+        self._reduction_cell_effective_t = num_val_examples * self._epochs * num_reduction_cells
 
-        self._normal_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (normal_cell_effective_t * self._grad_clip * self._grad_clip))
-        self._reduction_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (reduction_cell_effective_t * self._grad_clip * self._grad_clip))
+        self._normal_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (self._normal_cell_effective_t * self._grad_clip * self._grad_clip))
+        self._reduction_cell_lr = ma.sqrt(2 * ma.log(num_primitives) / (self._reduction_cell_effective_t * self._grad_clip * self._grad_clip))
 
-        self._xnas_optim = _XnasOptimizer(self._normal_cell_lr, self._reduction_cell_lr, self.model, lossfn)
+        self._xnas_optim = _XnasOptimizer(self._normal_cell_lr, self._reduction_cell_lr, self._normal_cell_effective_t, self._reduction_cell_effective_t, self._train_batch, self.model, lossfn)
 
     @overrides
     def post_fit(self, train_dl: DataLoader, val_dl: Optional[DataLoader]) -> None:
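
For reference, the two learning rates set above follow the exponentiated-gradient schedule eta = sqrt(2 * ln(K) / (T * G^2)), where K is the number of candidate primitives, T the effective horizon (validation examples * epochs * cells of that type, as computed just before), and G the gradient-norm bound (grad_clip). A minimal standalone sketch of the same arithmetic (hypothetical helper, not part of this commit; the example numbers are purely illustrative):

import math

def xnas_lr(num_primitives: int, effective_t: float, grad_clip: float) -> float:
    # eta = sqrt(2 * ln(K) / (T * G^2)), mirroring the expression in pre_fit above
    return math.sqrt(2 * math.log(num_primitives) / (effective_t * grad_clip * grad_clip))

# Illustrative call: 8 candidate primitives, horizon of 1e6 steps, grad clip of 5.0
print(xnas_lr(8, 1e6, 5.0))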
@@ -100,8 +101,11 @@ def pre_step(self, x: Tensor, y: Tensor) -> None:
             self.get_device(), non_blocking=True)
 
         # update alphas
+
+        self._multi_optim.zero_grad()
         self._xnas_optim.step(x, y, x_val, y_val,
-                              self.epoch(), self._epochs, self._grad_clip)
+                              self.epoch(), self._epochs, self._grad_clip, self._multi_optim)
+        self._multi_optim.zero_grad()
 
     @overrides
     def update_checkpoint(self, checkpoint: CheckPoint) -> None:
@@ -110,9 +114,13 @@ def update_checkpoint(self, checkpoint: CheckPoint) -> None:
 
 class _XnasOptimizer:
     def __init__(self, ncell_lr: float, rcell_lr: float,
+                 ncell_effective_t: float, rcell_effective_t: float, train_batch: int,
                  model: Model, lossfn: _Loss) -> None:
         self._ncell_lr = ncell_lr
         self._rcell_lr = rcell_lr
+        self._ncell_effective_t = ncell_effective_t
+        self._rcell_effective_t = rcell_effective_t
+        self._train_batch = train_batch
 
         self._lossfn = lossfn
         self._model = model # main model with respect to w and alpha
@@ -122,7 +130,7 @@ def _get_loss(model, lossfn, x, y):
         logits, *_ = model(x) # might also return aux tower logits
         return lossfn(logits, y)
 
-    def step(self, x_train: Tensor, y_train: Tensor, x_valid: Tensor, y_valid: Tensor, epoch: int, epochs: int, grad_clip: float) -> None:
+    def step(self, x_train: Tensor, y_train: Tensor, x_valid: Tensor, y_valid: Tensor, epoch: int, epochs: int, grad_clip: float, optim) -> None:
         # put model in train mode just to be safe
         self._model.train()
 
@@ -132,14 +140,20 @@ def step(self, x_train: Tensor, y_train: Tensor, x_valid: Tensor, y_valid: Tenso
         # compute gradients
         loss.backward()
 
+        # do grad clip
+        self._apex.clip_grad(grad_clip, model, optim)
+
         # for each op in the model update alphas
         for cell in self._model.cells:
             if cell.desc.cell_type == CellType.Reduction:
                 lr = self._rcell_lr
+                T = self._rcell_effective_t
             elif cell.desc.cell_type == CellType.Regular:
                 lr = self._ncell_lr
+                T = self._ncell_effective_t
             else:
                 raise NotImplementedError
 
+            t = epoch * self._train_batch
             for op in cell.ops():
-                op.update_alphas(lr, epoch, epochs, grad_clip)
+                op.update_alphas(lr, t, T, grad_clip)
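
As a quick sanity check on the bookkeeping in step(): t counts roughly how far along the effective horizon the search is, while T is the per-cell-type horizon computed in pre_fit. A small worked example with made-up inputs (only train_batch = 64 comes from the yaml below; the epoch count, validation-loader length, and cell count are assumptions, not values from this commit):

train_batch = 64           # from confs/algos/xnas.yaml below
epochs = 50                # assumed
len_val_dl = 390           # assumed number of validation batches
num_normal_cells = 6       # assumed

num_val_examples = len_val_dl * train_batch              # as in pre_fit
T_normal = num_val_examples * epochs * num_normal_cells  # effective horizon for normal cells
t_epoch_10 = 10 * train_batch                            # t passed to update_alphas at epoch 10
print(num_val_examples, T_normal, t_epoch_10)            # 24960 7488000 640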
19 changes: 11 additions & 8 deletions archai/algos/xnas/xnas_op.py
@@ -64,21 +64,24 @@ def __init__(self, op_desc:OpDesc, arch_params:Optional[ArchParams],
         # any previous child modules
         self._setup_arch_params(arch_params)
 
-    def update_alphas(self, eta:float, epoch:int, epochs:int, grad_clip:float):
+    def update_alphas(self, eta:float, current_t:int, total_t:int, grad_clip:float):
         grad_flat = torch.flatten(self._grad)
         rewards = torch.tensor([-torch.dot(grad_flat, torch.flatten(activ)) for activ in self._activs])
         exprewards = torch.exp(eta * rewards).cuda()
         # NOTE: Will this remain registered?
         self._alphas[0] = torch.mul(self._alphas[0], exprewards)
 
         # weak learner eviction
-        theta = max(self._alphas[0]) * ma.exp(-2 * eta * grad_clip * (epochs - epoch))
-        assert len(self._ops) == self._alphas[0].shape[0]
-        to_keep_mask = self._alphas[0] >= theta
-        num_ops_kept = torch.sum(to_keep_mask).item()
-        assert num_ops_kept > 0
-        # zero out the weights which are evicted
-        self._alphas[0] = torch.mul(self._alphas[0], to_keep_mask)
+        conf = get_conf()
+        to_evict = conf['nas']['search']['xnas']['to_evict']
+        if to_evict:
+            theta = max(self._alphas[0]) * ma.exp(-2 * eta * grad_clip * (total_t - current_t))
+            assert len(self._ops) == self._alphas[0].shape[0]
+            to_keep_mask = self._alphas[0] >= theta
+            num_ops_kept = torch.sum(to_keep_mask).item()
+            assert num_ops_kept > 0
+            # zero out the weights which are evicted
+            self._alphas[0] = torch.mul(self._alphas[0], to_keep_mask)
 
         # save some debugging info
         expdir = get_expdir()
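
Taken together, update_alphas is a multiplicative-weights step followed by an optional eviction of ops whose alpha can no longer catch up to the current leader within the remaining total_t - current_t steps. A self-contained sketch of the same logic in plain PyTorch (illustrative names, no archai dependencies; a sketch, not the module's API):

import math
import torch

def xnas_update(alphas: torch.Tensor, grad: torch.Tensor, activs: list,
                eta: float, grad_clip: float, t: int, T: int,
                to_evict: bool = True) -> torch.Tensor:
    # Exponentiated-gradient step: reward each op by the negative inner product
    # of the incoming gradient with that op's activation, then scale its alpha.
    grad_flat = grad.flatten()
    rewards = torch.stack([-torch.dot(grad_flat, a.flatten()) for a in activs])
    alphas = alphas * torch.exp(eta * rewards)

    if to_evict:
        # Weak-learner eviction: zero out ops whose alpha is below the threshold
        # max_alpha * exp(-2 * eta * grad_clip * (T - t)).
        theta = alphas.max() * math.exp(-2 * eta * grad_clip * (T - t))
        keep = alphas >= theta
        assert keep.any()
        alphas = alphas * keep
    return alphas

# Example: 4 candidate ops, random gradient/activations (illustrative only)
alphas = torch.ones(4)
grad = torch.randn(8)
activs = [torch.randn(8) for _ in range(4)]
alphas = xnas_update(alphas, grad, activs, eta=0.01, grad_clip=5.0, t=10, T=1000)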
2 changes: 2 additions & 0 deletions confs/algos/xnas.yaml
@@ -3,6 +3,8 @@ __include__: "darts.yaml" # defaults are loaded from this file
 # XNAS's parameters
 nas:
   search:
+    xnas:
+      to_evict: True
     loader:
       train_batch: 64
     trainer:
