Skip to content

Commit

Permalink
add wandb logger
Browse files Browse the repository at this point in the history
  • Loading branch information
kkoutini committed Apr 12, 2023
1 parent 9956391 commit a3ea2a8
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 331 deletions.
27 changes: 27 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,30 @@ environment_builds.yml
# Output
lightning_logs/*
audioset_hdf5s/*
.vscode/settings.json
wandb/latest-run
wandb/run-20230412_111735-7leazcov/run-7leazcov.wandb
wandb/run-20230412_111735-7leazcov/files/conda-environment.yaml
wandb/run-20230412_111735-7leazcov/files/config.yaml
wandb/run-20230412_111735-7leazcov/files/diff.patch
wandb/run-20230412_111735-7leazcov/files/requirements.txt
wandb/run-20230412_111735-7leazcov/files/wandb-metadata.json
wandb/run-20230412_111735-7leazcov/files/wandb-summary.json
wandb/run-20230412_111735-7leazcov/files/code/ex_openmic.py
wandb/run-20230412_113812-d6v0j0ob/run-d6v0j0ob.wandb
wandb/run-20230412_113812-d6v0j0ob/files/conda-environment.yaml
wandb/run-20230412_113812-d6v0j0ob/files/config.yaml
wandb/run-20230412_113812-d6v0j0ob/files/diff.patch
wandb/run-20230412_113812-d6v0j0ob/files/requirements.txt
wandb/run-20230412_113812-d6v0j0ob/files/wandb-metadata.json
wandb/run-20230412_113812-d6v0j0ob/files/wandb-summary.json
wandb/run-20230412_113812-d6v0j0ob/files/code/ex_openmic.py
wandb/run-20230412_113812-d6v0j0ob/files/passt_openmic/d6v0j0ob/checkpoints/epoch=9-step=49.ckpt
wandb/run-20230412_114037-rus5aih1/run-rus5aih1.wandb
wandb/run-20230412_114037-rus5aih1/files/conda-environment.yaml
wandb/run-20230412_114037-rus5aih1/files/config.yaml
wandb/run-20230412_114037-rus5aih1/files/diff.patch
wandb/run-20230412_114037-rus5aih1/files/requirements.txt
wandb/run-20230412_114037-rus5aih1/files/wandb-metadata.json
wandb/run-20230412_114037-rus5aih1/files/wandb-summary.json
wandb/run-20230412_114037-rus5aih1/files/code/ex_openmic.py
22 changes: 8 additions & 14 deletions ba3l/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ba3l.ingredients.datasets import Datasets
from ba3l.ingredients.models import Models, Model
from ba3l.plutils.lr_monitor import LearningRateMonitor
from ba3l.trainer import Trainer
#from ba3l.trainer import Trainer
from ba3l.util.sacred_logger import SacredLogger
from ba3l.plutils.progress_bar import ProgressBar
from sacred import Experiment as Sacred_Experiment, Ingredient
Expand All @@ -31,11 +31,13 @@ def config_recursive_apply(conf, fn):
fn(k,v)


def get_loggers(expr, use_tensorboard_logger=False):
sacred_logger = SacredLogger(expr)
loggers = [sacred_logger]
def get_loggers(use_tensorboard_logger=False, use_sacred_logger=False):
loggers = []
if use_sacred_logger:
loggers.append( SacredLogger(expr))
if use_tensorboard_logger:
loggers.append(pl_loggers.TensorBoardLogger(sacred_logger.name))

return loggers


Expand All @@ -52,7 +54,6 @@ def __init__(
name: Optional[str] = None,
ingredients: Sequence[Ingredient] = (),
datasets: Optional[Ingredient] = None,
trainer: Optional[Ingredient] = None,
models: Optional[Ingredient] = None,
interactive: bool = False,
base_dir: Optional[PathType] = None,
Expand Down Expand Up @@ -99,12 +100,9 @@ def __init__(
if datasets is None:
datasets = Datasets.get_instance()
self.datasets = datasets
if trainer is None:
trainer = Trainer.get_instance(datasets=datasets, models=models)
self.trainer = trainer
if ingredients is None:
ingredients = []
ingredients = list(ingredients) + [models, datasets, trainer]
ingredients = list(ingredients) + [models, datasets]
caller_globals = inspect.stack()[1][0].f_globals
self.last_default_configuration_position = 0
super().__init__(
Expand All @@ -117,16 +115,12 @@ def __init__(
save_git_info=save_git_info,
caller_globals=caller_globals
)
self.trainer.command(get_loggers, static_args={"expr": self})
self.trainer.command(get_callbacks, static_args={"expr": self})
# filling out Default config


def get_run_identifier(self):
return str(self.current_run.db_identifier) \
+ "_" + str(self.current_run._id)

def get_trainer(self, *args, **kw):
return self.trainer.get_trainer(*args, **kw)

def get_dataloaders(self, filter={}):
results = {}
Expand Down
2 changes: 2 additions & 0 deletions ba3l/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,6 @@ def __init__(self, experiment):
self.config = DefaultMunch.fromDict(experiment.current_run.config)
for key,model in experiment.current_run.config['models'].items():
setattr(self, key, experiment.current_run.get_command_function("models."+key+"."+model['instance_cmd'])())
self.save_hyperparameters(self.config)


Loading

0 comments on commit a3ea2a8

Please sign in to comment.