Commit

Merge pull request #121 from MDIL-SNU/dev
refactor: please use get
YutackPark authored Nov 7, 2024
2 parents 7f2a026 + f86c141 commit 32e1357
Showing 5 changed files with 37 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ All notable changes to this project will be documented in this file.
 - Save checkpoint_0.pth (model before any training)
 - `SevenNetGraphDataset._file_to_graph_list` -> `SevenNetGraphDataset.file_to_graph_list`
 - Refactoring `SevenNetGraphDataset`, skips computing statistics if it is loaded, more detailed logging
+- Prefer using `.get` when accessing config dict
 ### Fixed
 - Fix error when loading `SevenNetGraphDataset` with other types of data (ex: extxyz) in one dataset

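For reference, a minimal standalone sketch of the access pattern this changelog entry describes; the keys and defaults below are illustrative, not sevenn's actual config keys.

config = {'optimizer': 'adam'}

# Bracket access raises KeyError when a key is absent.
try:
    scheduler = config['scheduler']
except KeyError:
    scheduler = 'exponentiallr'

# dict.get expresses the same fallback in one line.
scheduler = config.get('scheduler', 'exponentiallr')  # -> 'exponentiallr'
optimizer = config.get('optimizer', 'adam')           # -> 'adam' (key present, default unused)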
15 changes: 8 additions & 7 deletions sevenn/error_recorder.py
@@ -331,16 +331,17 @@ def init_total_loss_metric(config, criteria):

     @staticmethod
     def from_config(config: dict):
-        loss_cls = loss_dict[config[KEY.LOSS].lower()]
-        try:
-            loss_param = config[KEY.LOSS_PARAM]
-        except KeyError:
-            loss_param = {}
+        loss_cls = loss_dict[config.get(KEY.LOSS, 'mse').lower()]
+        loss_param = config.get(KEY.LOSS_PARAM, {})
         criteria = loss_cls(**loss_param)

-        err_config = config[KEY.ERROR_RECORD]
+        err_config = config.get(KEY.ERROR_RECORD, False)
+        if not err_config:
+            raise ValueError(
+                'No error_record config found. Consider util.get_error_recorder'
+            )
         err_config_n = []
-        if not config[KEY.IS_TRAIN_STRESS]:
+        if not config.get(KEY.IS_TRAIN_STRESS, True):
             for err_type, metric_name in err_config:
                 if 'Stress' in err_type:
                     continue
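A self-contained sketch of the refactor in this hunk, assuming stand-in names: optional keys fall back to defaults via .get instead of try/except KeyError, while a required key is still enforced with an explicit check. The KEY_* constants and the loss registry here are hypothetical, not sevenn's definitions.

import torch

# Hypothetical stand-ins for sevenn's KEY constants and loss registry.
KEY_LOSS, KEY_LOSS_PARAM, KEY_ERROR_RECORD = 'loss', 'loss_param', 'error_record'
loss_dict = {'mse': torch.nn.MSELoss, 'huber': torch.nn.HuberLoss}

def criteria_from_config(config: dict):
    # Optional keys: default to 'mse' and {} instead of catching KeyError.
    loss_cls = loss_dict[config.get(KEY_LOSS, 'mse').lower()]
    loss_param = config.get(KEY_LOSS_PARAM, {})
    # Required key: .get(..., False) plus a descriptive error on absence.
    if not config.get(KEY_ERROR_RECORD, False):
        raise ValueError('No error_record config found')
    return loss_cls(**loss_param)

print(criteria_from_config({'error_record': [('Energy', 'RMSE')]}))  # MSELoss()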
4 changes: 2 additions & 2 deletions sevenn/scripts/processing_epoch.py
@@ -32,8 +32,8 @@ def processing_epoch_v2(
     prefix = f'{os.path.abspath(working_dir)}/'

     total_epoch = total_epoch or config[KEY.EPOCH]
-    per_epoch = per_epoch or config[KEY.PER_EPOCH]
-    best_metric = best_metric or config[KEY.BEST_METRIC]
+    per_epoch = per_epoch or config.get(KEY.PER_EPOCH, 10)
+    best_metric = best_metric or config.get(KEY.BEST_METRIC, 'TotalLoss')
     recorder = error_recorder or ErrorRecorder.from_config(config)
     recorders = {k: deepcopy(recorder) for k in loaders}

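A brief sketch of the argument-fallback idiom used here, with an illustrative key name and default: an explicit function argument wins, otherwise the value comes from the config via .get, with a final hard-coded default.

def resolve_per_epoch(per_epoch=None, config=None):
    # Explicit argument takes priority; otherwise read the config,
    # falling back to 10 when the key is missing entirely.
    config = config or {}
    return per_epoch or config.get('per_epoch', 10)

print(resolve_per_epoch(5, {'per_epoch': 20}))      # 5  (argument wins)
print(resolve_per_epoch(None, {'per_epoch': 20}))   # 20 (config value)
print(resolve_per_epoch())                          # 10 (built-in default)

One side effect of `or` is that falsy values such as 0 also trigger the fallback; the defaults involved in this hunk are non-zero, so that trade-off appears acceptable.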
18 changes: 17 additions & 1 deletion sevenn/train/graph_dataset.py
@@ -172,8 +172,24 @@ def __init__(
             processed_name += '.pt'
         self._processed_names = [
             processed_name,  # {root}/sevenn_data/{name}.pt
-            processed_name.replace('pt', 'yaml'),
+            processed_name.replace('.pt', '.yaml'),
         ]
+
+        root = root or './'
+        _pdir = os.path.join(root, 'sevenn_data')
+        _pt = os.path.join(_pdir, self._processed_names[0])
+        if not os.path.exists(_pt) and len(self._files) == 0:
+            raise ValueError((
+                f'{_pt} not found and no files to process. '
+                + 'If you copied only .pt file, please copy '
+                + 'whole sevenn_data dir without changing its name.'
+                + ' They all work together.'
+            ))
+
+        _yam = os.path.join(_pdir, self._processed_names[1])
+        if not os.path.exists(_yam) and len(self._files) == 0:
+            raise ValueError(f'{_yam} not found and no files to process')

         self.process_num_cores = process_num_cores
         self.process_kwargs = process_kwargs

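A standalone sketch of the early validation this hunk adds, with illustrative paths and names: before processing, confirm that the companion files of a processed dataset exist under sevenn_data and raise a descriptive error otherwise.

import os

def check_processed_dataset(root: str, name: str, files: list) -> None:
    # A processed dataset is expected as companion files that belong together:
    # {root}/sevenn_data/{name}.pt and {root}/sevenn_data/{name}.yaml
    pdir = os.path.join(root or './', 'sevenn_data')
    for fname in (f'{name}.pt', f'{name}.yaml'):
        path = os.path.join(pdir, fname)
        if not os.path.exists(path) and len(files) == 0:
            raise ValueError(
                f'{path} not found and no files to process. '
                'If only the .pt file was copied, copy the whole '
                'sevenn_data directory; its files work together.'
            )

# check_processed_dataset('./my_dataset_root', 'graph', files=[])  # raises if files are missing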
16 changes: 9 additions & 7 deletions sevenn/train/trainer.py
@@ -82,13 +82,15 @@ def from_config(model: torch.nn.Module, config: Dict[str, Any]) -> 'Trainer':
         trainer = Trainer(
             model,
             loss_functions=get_loss_functions_from_config(config),
-            optimizer_cls=optim_dict[config[KEY.OPTIMIZER].lower()],
-            optimizer_args=config[KEY.OPTIM_PARAM],
-            scheduler_cls=scheduler_dict[config[KEY.SCHEDULER].lower()],
-            scheduler_args=config[KEY.SCHEDULER_PARAM],
-            device=config[KEY.DEVICE],
-            distributed=config[KEY.IS_DDP],
-            distributed_backend=config[KEY.DDP_BACKEND]
+            optimizer_cls=optim_dict[config.get(KEY.OPTIMIZER, 'adam').lower()],
+            optimizer_args=config.get(KEY.OPTIM_PARAM, {}),
+            scheduler_cls=scheduler_dict[
+                config.get(KEY.SCHEDULER, 'exponentiallr').lower()
+            ],
+            scheduler_args=config.get(KEY.SCHEDULER_PARAM, {}),
+            device=config.get(KEY.DEVICE, 'auto'),
+            distributed=config.get(KEY.IS_DDP, False),
+            distributed_backend=config.get(KEY.DDP_BACKEND, 'nccl'),
         )
         return trainer

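A minimal sketch of the registry-lookup pattern this hunk switches to .get defaults for: component names come from the config (or a default name) and are resolved through a dispatch dict. The registries below are hypothetical stand-ins, not sevenn's actual optim_dict and scheduler_dict.

import torch

# Hypothetical registries keyed by lower-case names.
optim_dict = {'adam': torch.optim.Adam, 'sgd': torch.optim.SGD}
scheduler_dict = {'exponentiallr': torch.optim.lr_scheduler.ExponentialLR}

def build_optim_and_sched(model: torch.nn.Module, config: dict):
    optim_cls = optim_dict[config.get('optimizer', 'adam').lower()]
    optimizer = optim_cls(model.parameters(), **config.get('optim_param', {}))
    sched_cls = scheduler_dict[config.get('scheduler', 'exponentiallr').lower()]
    # ExponentialLR requires gamma, so the illustrative default supplies one.
    scheduler = sched_cls(optimizer, **config.get('scheduler_param', {'gamma': 0.99}))
    return optimizer, scheduler

opt, sched = build_optim_and_sched(torch.nn.Linear(4, 1), {})
print(type(opt).__name__, type(sched).__name__)  # Adam ExponentialLR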
