Skip to content

Commit

Permalink
Remove hyperparameter logging from TensorBoardCallback (jankrepl#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
jankrepl authored Dec 6, 2020
1 parent eb3b02d commit b56e63d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
7 changes: 1 addition & 6 deletions deepdow/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,14 +521,13 @@ class TensorBoardCallback(Callback):
Currently supports:
- images (evolution of predicted weights over time)
- histograms (activations of input and outputs of all layers)
- hyperparamters
- scalars (logged metrics)
Parameters
----------
log_dir : None or str or pathlib.Path
Represent the folder where to checkpoints will be saved. If None then using
`cwd/runs/CURRENT_DATETIME_HOSTNAME`. Else the exact path.
the current working directory. Else the exact path.
ts : datetime.datetime or None
If ``datetime.datetime``, then only logging specific sample corresponding to provided timestamp.
Expand Down Expand Up @@ -605,7 +604,6 @@ def on_batch_end(self, metadata):
def on_epoch_end(self, metadata):
"""Log images, metrics and hyperparamters."""
epoch = metadata.get('epoch')
n_epochs = metadata.get('n_epochs')

# create weight image
master_df = pd.concat(self.weights).sort_index()
Expand All @@ -622,9 +620,6 @@ def on_epoch_end(self, metadata):
for metric_name, metric_value in metrics.items():
self.writer.add_scalar(metric_name, metric_value, global_step=epoch)

if epoch == n_epochs - 1:
self.writer.add_hparams(self.run.hparams, metrics)

except KeyError:
pass

Expand Down
1 change: 1 addition & 0 deletions tests/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def test_attributes_after_construction(self, dataloader_dummy, additional_kwargs
assert dataloader_dummy is run.train_dataloader
assert isinstance(run.metrics, dict)
assert isinstance(run.val_dataloaders, dict)
assert isinstance(run.hparams, dict)

def test_launch(self, dataloader_dummy):
network = DummyNet(n_channels=dataloader_dummy.dataset.X.shape[1])
Expand Down

0 comments on commit b56e63d

Please sign in to comment.