diff --git a/CHANGELOG.md b/CHANGELOG.md
index d485ae6..a642b59 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
 - Save checkpoint_0.pth (model before any training)
 - `SevenNetGraphDataset._file_to_graph_list` -> `SevenNetGraphDataset.file_to_graph_list`
 - Refactoring `SevenNetGraphDataset`, skips computing statistics if it is loaded, more detailed logging
+- Prefer using `.get` when accessing config dicts

 ### Fixed
 - Fix error when loading `SevenNetGraphDataset` with other types of data (ex: extxyz) in one dataset
diff --git a/sevenn/error_recorder.py b/sevenn/error_recorder.py
index 96c342d..9c27bf6 100644
--- a/sevenn/error_recorder.py
+++ b/sevenn/error_recorder.py
@@ -331,16 +331,17 @@ def init_total_loss_metric(config, criteria):

     @staticmethod
     def from_config(config: dict):
-        loss_cls = loss_dict[config[KEY.LOSS].lower()]
-        try:
-            loss_param = config[KEY.LOSS_PARAM]
-        except KeyError:
-            loss_param = {}
+        loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()]
+        loss_param = config.get(KEY.LOSS_PARAM, {})
         criteria = loss_cls(**loss_param)

-        err_config = config[KEY.ERROR_RECORD]
+        err_config = config.get(KEY.ERROR_RECORD, False)
+        if not err_config:
+            raise ValueError(
+                'No error_record config found. Consider util.get_error_recorder'
+            )
         err_config_n = []
-        if not config[KEY.IS_TRAIN_STRESS]:
+        if not config.get(KEY.IS_TRAIN_STRESS, True):
             for err_type, metric_name in err_config:
                 if 'Stress' in err_type:
                     continue
diff --git a/sevenn/scripts/processing_epoch.py b/sevenn/scripts/processing_epoch.py
index aabcb27..d2987bf 100644
--- a/sevenn/scripts/processing_epoch.py
+++ b/sevenn/scripts/processing_epoch.py
@@ -32,8 +32,8 @@ def processing_epoch_v2(
     prefix = f'{os.path.abspath(working_dir)}/'

     total_epoch = total_epoch or config[KEY.EPOCH]
-    per_epoch = per_epoch or config[KEY.PER_EPOCH]
-    best_metric = best_metric or config[KEY.BEST_METRIC]
+    per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10)
+    best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss')
     recorder = error_recorder or ErrorRecorder.from_config(config)
     recorders = {k: deepcopy(recorder) for k in loaders}
diff --git a/sevenn/train/graph_dataset.py b/sevenn/train/graph_dataset.py
index 4f4d8ab..8ccd1b4 100644
--- a/sevenn/train/graph_dataset.py
+++ b/sevenn/train/graph_dataset.py
@@ -172,8 +172,24 @@ def __init__(
         processed_name += '.pt'
         self._processed_names = [
             processed_name,  # {root}/sevenn_data/{name}.pt
-            processed_name.replace('pt', 'yaml'),
+            processed_name.replace('.pt', '.yaml'),
         ]
+
+        root = root or './'
+        _pdir = os.path.join(root, 'sevenn_data')
+        _pt = os.path.join(_pdir, self._processed_names[0])
+        if not os.path.exists(_pt) and len(self._files) == 0:
+            raise ValueError((
+                f'{_pt} not found and no files to process. '
+                + 'If you copied only .pt file, please copy '
+                + 'whole sevenn_data dir without changing its name.'
+                + ' They all work together.'
+            ))
+
+        _yam = os.path.join(_pdir, self._processed_names[1])
+        if not os.path.exists(_yam) and len(self._files) == 0:
+            raise ValueError(f'{_yam} not found and no files to process')
+
         self.process_num_cores = process_num_cores
         self.process_kwargs = process_kwargs
diff --git a/sevenn/train/trainer.py b/sevenn/train/trainer.py
index 7caf1a0..eb26b60 100644
--- a/sevenn/train/trainer.py
+++ b/sevenn/train/trainer.py
@@ -82,13 +82,15 @@ def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer':
         trainer = Trainer(
             model,
             loss_functions=get_loss_functions_from_config(config),
-            optimizer_cls=optim_dict[config[KEY.OPTIMIZER].lower()],
-            optimizer_args=config[KEY.OPTIM_PARAM],
-            scheduler_cls=scheduler_dict[config[KEY.SCHEDULER].lower()],
-            scheduler_args=config[KEY.SCHEDULER_PARAM],
-            device=config[KEY.DEVICE],
-            distributed=config[KEY.IS_DDP],
-            distributed_backend=config[KEY.DDP_BACKEND]
+            optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()],
+            optimizer_args=config.get(KEY.OPTIM_PARAM, {}),
+            scheduler_cls=scheduler_dict[
+                config.get(KEY.SCHEDULER, 'exponentiallr').lower()
+            ],
+            scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}),
+            device=config.get(KEY.DEVICE, 'auto'),
+            distributed=config.get(KEY.IS_DDP, False),
+            distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'),
         )
         return trainer
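
The change common to all of these hunks is replacing hard `config[KEY.X]` lookups with `config.get(KEY.X, default)`, so configs that omit optional keys fall back to a sensible default instead of raising `KeyError`. A minimal sketch of that pattern, using plain string keys as hypothetical stand-ins for the project's `KEY.*` constants (not the actual config schema):

    # Illustrative only: a user config that omits most optional keys.
    config = {'optimizer': 'sgd'}

    optimizer = config.get('optimizer', 'adam')    # key present -> 'sgd'
    optim_param = config.get('optim_param', {})    # key missing -> {} instead of KeyError
    is_ddp = config.get('is_ddp', False)           # key missing -> False

    # The old style, config['optim_param'], would raise KeyError('optim_param') here.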