diff --git a/model/callbacks.py b/model/callbacks.py new file mode 100644 index 0000000..9ad1a29 --- /dev/null +++ b/model/callbacks.py @@ -0,0 +1,89 @@ +import os + +import mlflow +import numpy as np + + +class ModelCheckPoint: + + def __init__(self, file, mf_logger=None, save_best=True, monitor='val_loss', mode='min'): + self.file = file + save_dir = os.path.dirname(self.file) + if not os.path.isdir(save_dir): + os.makedirs(save_dir) + self.mf_logger = mf_logger + self.save_best = save_best + self.monitor = monitor + self.mode = mode + init_values = {'min': np.inf, 'max': -np.inf} + self.best_score = init_values[mode] + + def __call__(self, model, history): + val_score = history[self.monitor] + check_point = self.file.format(**history) + + if not self.save_best: + self.save_model(model, check_point) + elif self._best(val_score, self.best_score): + self.best_score = val_score + self.save_model(model, check_point) + + def _best(self, val, best): + if self.mode == 'min': + return val <= best + else: + return val >= best + + def save_model(self, model, file_name): + if self.mf_logger is not None: + self.mf_logger.log_model(model, "torch_model") + model.save(file_name) + + +class TrainHistory: + + def __init__(self, file): + self.file = file + if os.path.isfile(self.file): + with open(self.file, 'a') as f: + f.write('\n') + + def __call__(self, model, history): + with open(self.file, 'a+') as f: + f.write(str(history) + '\n') + + +class MlflowLogger: + + def __init__(self, experiment_name: str, model_params: dict, run_name=None): + self.experiment_name = experiment_name + self.run_name = run_name + self.model_params = model_params + self._set_env() + self.run_id = self._get_run_id() + + def __call__(self, model, history): + with mlflow.start_run(run_id=self.run_id): + mlflow.log_metrics(history, step=history['epoch']) + + def __eq__(self, other): + return "MLFlow" == other + + def _get_run_id(self): + with mlflow.start_run(run_name=self.run_name) as mlflow_run: + mlflow.log_params(self.model_params) + run_id = mlflow_run.info.run_id + return run_id + + def _set_env(self): + if os.getenv('MLFLOW_TRACKING_URI') is None: + raise ValueError("Environment variable MLFLOW_TRACKING_URI is not exist") + + mlflow.set_experiment(self.experiment_name) + + def log_model(self, model, name): + with mlflow.start_run(run_id=self.run_id): + mlflow.pytorch.log_model( + pytorch_model=model, + artifact_path=name + ) \ No newline at end of file diff --git a/requirement.txt b/requirement.txt index 3932004..0116206 100644 --- a/requirement.txt +++ b/requirement.txt @@ -1,19 +1,37 @@ +alembic==1.4.1 appnope==0.1.2 backcall==0.2.0 certifi==2021.10.8 +charset-normalizer==2.0.10 +click==8.0.3 +cloudpickle==2.0.0 cycler==0.11.0 +databricks-cli==0.16.2 debugpy==1.5.1 decorator==5.1.0 +docker==5.0.3 entrypoints==0.3 +Flask==2.0.2 fonttools==4.28.5 +gitdb==4.0.9 +GitPython==3.1.26 +greenlet==1.1.2 +gunicorn==20.1.0 +idna==3.3 +importlib-metadata==4.10.0 ipykernel==6.6.0 ipython==7.30.1 +itsdangerous==2.0.1 jedi==0.18.1 +Jinja2==3.0.3 jupyter-client==7.1.0 jupyter-core==4.9.1 kiwisolver==1.3.2 +Mako==1.1.6 +MarkupSafe==2.0.1 matplotlib==3.5.1 matplotlib-inline==0.1.3 +mlflow==1.22.0 nest-asyncio==1.5.4 numpy==1.21.4 packaging==21.3 @@ -22,16 +40,31 @@ parso==0.8.3 pexpect==4.8.0 pickleshare==0.7.5 Pillow==8.4.0 +prometheus-client==0.12.0 +prometheus-flask-exporter==0.18.7 prompt-toolkit==3.0.23 +protobuf==3.19.2 ptyprocess==0.7.0 Pygments==2.10.0 pyparsing==3.0.6 python-dateutil==2.8.2 +python-editor==1.0.4 pytz==2021.3 +PyYAML==6.0 pyzmq==22.3.0 +querystring-parser==1.2.4 +requests==2.27.1 six==1.16.0 +smmap==5.0.0 +SQLAlchemy==1.4.29 +sqlparse==0.4.2 +tabulate==0.8.9 torch==1.10.0 tornado==6.1 traitlets==5.1.1 typing_extensions==4.0.1 +urllib3==1.26.8 wcwidth==0.2.5 +websocket-client==1.2.3 +Werkzeug==2.0.2 +zipp==3.7.0