Skip to content

Commit

Permalink
callback 함수 추가,
Browse files Browse the repository at this point in the history
mlflow 라이브러리 설치
  • Loading branch information
cksruf91 committed Jan 11, 2022
1 parent c1662b3 commit 4f83e23
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
89 changes: 89 additions & 0 deletions model/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os

import mlflow
import numpy as np


class ModelCheckPoint:

def __init__(self, file, mf_logger=None, save_best=True, monitor='val_loss', mode='min'):
self.file = file
save_dir = os.path.dirname(self.file)
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
self.mf_logger = mf_logger
self.save_best = save_best
self.monitor = monitor
self.mode = mode
init_values = {'min': np.inf, 'max': -np.inf}
self.best_score = init_values[mode]

def __call__(self, model, history):
val_score = history[self.monitor]
check_point = self.file.format(**history)

if not self.save_best:
self.save_model(model, check_point)
elif self._best(val_score, self.best_score):
self.best_score = val_score
self.save_model(model, check_point)

def _best(self, val, best):
if self.mode == 'min':
return val <= best
else:
return val >= best

def save_model(self, model, file_name):
if self.mf_logger is not None:
self.mf_logger.log_model(model, "torch_model")
model.save(file_name)


class TrainHistory:

def __init__(self, file):
self.file = file
if os.path.isfile(self.file):
with open(self.file, 'a') as f:
f.write('\n')

def __call__(self, model, history):
with open(self.file, 'a+') as f:
f.write(str(history) + '\n')


class MlflowLogger:

def __init__(self, experiment_name: str, model_params: dict, run_name=None):
self.experiment_name = experiment_name
self.run_name = run_name
self.model_params = model_params
self._set_env()
self.run_id = self._get_run_id()

def __call__(self, model, history):
with mlflow.start_run(run_id=self.run_id):
mlflow.log_metrics(history, step=history['epoch'])

def __eq__(self, other):
return "MLFlow" == other

def _get_run_id(self):
with mlflow.start_run(run_name=self.run_name) as mlflow_run:
mlflow.log_params(self.model_params)
run_id = mlflow_run.info.run_id
return run_id

def _set_env(self):
if os.getenv('MLFLOW_TRACKING_URI') is None:
raise ValueError("Environment variable MLFLOW_TRACKING_URI is not exist")

mlflow.set_experiment(self.experiment_name)

def log_model(self, model, name):
with mlflow.start_run(run_id=self.run_id):
mlflow.pytorch.log_model(
pytorch_model=model,
artifact_path=name
)
33 changes: 33 additions & 0 deletions requirement.txt
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
alembic==1.4.1
appnope==0.1.2
backcall==0.2.0
certifi==2021.10.8
charset-normalizer==2.0.10
click==8.0.3
cloudpickle==2.0.0
cycler==0.11.0
databricks-cli==0.16.2
debugpy==1.5.1
decorator==5.1.0
docker==5.0.3
entrypoints==0.3
Flask==2.0.2
fonttools==4.28.5
gitdb==4.0.9
GitPython==3.1.26
greenlet==1.1.2
gunicorn==20.1.0
idna==3.3
importlib-metadata==4.10.0
ipykernel==6.6.0
ipython==7.30.1
itsdangerous==2.0.1
jedi==0.18.1
Jinja2==3.0.3
jupyter-client==7.1.0
jupyter-core==4.9.1
kiwisolver==1.3.2
Mako==1.1.6
MarkupSafe==2.0.1
matplotlib==3.5.1
matplotlib-inline==0.1.3
mlflow==1.22.0
nest-asyncio==1.5.4
numpy==1.21.4
packaging==21.3
Expand All @@ -22,16 +40,31 @@ parso==0.8.3
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.4.0
prometheus-client==0.12.0
prometheus-flask-exporter==0.18.7
prompt-toolkit==3.0.23
protobuf==3.19.2
ptyprocess==0.7.0
Pygments==2.10.0
pyparsing==3.0.6
python-dateutil==2.8.2
python-editor==1.0.4
pytz==2021.3
PyYAML==6.0
pyzmq==22.3.0
querystring-parser==1.2.4
requests==2.27.1
six==1.16.0
smmap==5.0.0
SQLAlchemy==1.4.29
sqlparse==0.4.2
tabulate==0.8.9
torch==1.10.0
tornado==6.1
traitlets==5.1.1
typing_extensions==4.0.1
urllib3==1.26.8
wcwidth==0.2.5
websocket-client==1.2.3
Werkzeug==2.0.2
zipp==3.7.0

0 comments on commit 4f83e23

Please sign in to comment.