Commit
bugfix
Signed-off-by: Jat <[email protected]>
jat001 committed Jun 17, 2021
1 parent 0731ef2 commit 5092306
Showing 8 changed files with 188 additions and 81 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ __pycache__
 .project
 *.prefs
 _build
+venv
 
 # excluded paths
 /arch/core/target/
21 changes: 12 additions & 9 deletions python/fate_flow/db/db_models.py
@@ -18,7 +18,6 @@
 import os
 import sys
 
-import __main__
 from peewee import (CharField, IntegerField, BigIntegerField,
                     TextField, CompositeKey, BigAutoField, BooleanField)
 from playhouse.apsw_ext import APSWDatabase
@@ -63,14 +62,18 @@ def __init__(self):
             raise Exception('can not init database')
 
 
-MAIN_FILE_PATH = os.path.realpath(__main__.__file__)
-if MAIN_FILE_PATH.endswith('fate_flow_server.py') or \
-        MAIN_FILE_PATH.endswith('task_executor.py') or \
-        MAIN_FILE_PATH.find("/unittest/__main__.py"):
-    DB = BaseDataBase().database_connection
-else:
-    # Initialize the database only when the server is started.
-    DB = None
+# Initialize the database only when the server is started.
+DB = None
+for frame in inspect.stack():
+    filename = frame.filename
+    if filename.startswith('<'):
+        continue
+    filename = os.path.abspath(os.path.realpath(frame.filename))
+    if filename.endswith('fate_flow_server.py') or \
+            filename.endswith('task_executor.py') or \
+            filename.find('/unittest/') >= 0:
+        DB = BaseDataBase().database_connection
+        break
 
 
 def close_connection():
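For context on the db_models.py change above: the old guard resolved `__main__.__file__`, which breaks when fate_flow is imported by another process, and its `MAIN_FILE_PATH.find("/unittest/__main__.py")` test returned -1 ("not found"), which is truthy, so the condition was effectively always true. The new code walks the call stack and only opens a connection when a known entry point appears in it. A minimal standalone sketch of that stack-walking check follows; `should_init_db` and its default markers are illustrative names, not part of the repository.

import inspect
import os


def should_init_db(markers=('fate_flow_server.py', 'task_executor.py', '/unittest/')):
    """Return True if any real frame in the call stack comes from a known entry point."""
    for frame in inspect.stack():
        filename = frame.filename
        if filename.startswith('<'):  # skip REPL / frozen importlib frames that have no real path
            continue
        filename = os.path.abspath(os.path.realpath(filename))
        if (filename.endswith(markers[0])
                or filename.endswith(markers[1])
                or markers[2] in filename):
            return True
    return False


if __name__ == '__main__':
    print(should_init_db())  # False when run directly, since no marker matches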
12 changes: 6 additions & 6 deletions python/fate_flow/pipelined_model/migrate_model.py
@@ -34,18 +34,18 @@ def gen_model_file_path(model_id, model_version):
 
 def compare_roles(request_conf_roles: dict, run_time_conf_roles: dict):
     if request_conf_roles.keys() == run_time_conf_roles.keys():
-        varify_format = True
-        varify_equality = True
+        verify_format = True
+        verify_equality = True
         for key in request_conf_roles.keys():
-            varify_format = varify_format and (len(request_conf_roles[key]) == len(run_time_conf_roles[key])) and (isinstance(request_conf_roles[key], list))
+            verify_format = verify_format and (len(request_conf_roles[key]) == len(run_time_conf_roles[key])) and (isinstance(request_conf_roles[key], list))
             request_conf_roles_set = set(str(item) for item in request_conf_roles[key])
             run_time_conf_roles_set = set(str(item) for item in run_time_conf_roles[key])
-            varify_equality = varify_equality and (request_conf_roles_set == run_time_conf_roles_set)
-        if not varify_format:
+            verify_equality = verify_equality and (request_conf_roles_set == run_time_conf_roles_set)
+        if not verify_format:
             raise Exception("The structure of roles data of local configuration is different from "
                             "model runtime configuration's. Migration aborting.")
         else:
-            return varify_equality
+            return verify_equality
     raise Exception("The structure of roles data of local configuration is different from "
                     "model runtime configuration's. Migration aborting.")
 
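The migrate_model.py change is purely a spelling fix (varify -> verify); the logic is untouched. As a reminder of what compare_roles checks, here is a small hypothetical usage sketch — the role dictionaries are invented for illustration, not taken from a real job configuration.

from fate_flow.pipelined_model.migrate_model import compare_roles

# Same roles, same party ids in each role -> structurally valid and equal.
assert compare_roles({"guest": [9999], "host": [10000]},
                     {"guest": [9999], "host": [10000]}) is True

# Same structure but a different host party id -> valid format, not equal.
assert compare_roles({"guest": [9999], "host": [10000]},
                     {"guest": [9999], "host": [10001]}) is False

# Different keys (or differing list lengths) raise an Exception instead of returning.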
6 changes: 3 additions & 3 deletions python/fate_flow/pipelined_model/mysql_model_storage.py
@@ -105,10 +105,10 @@ def restore(self, model_id: str, model_version: str, store_address: dict):
             if not model_archive_data:
                 raise Exception("Restore model {} {} from mysql failed: {}".format(
                     model_id, model_version, "can not get model archive data"))
-            with open(model.archive_model_file_path(), "wb") as fw:
+            with open(model.archive_model_file_path, "wb") as fw:
                 fw.write(model_archive_data)
-            model.unpack_model(model.archive_model_file_path())
-            LOGGER.info("Restore model to {} from mysql successfully".format(model.archive_model_file_path()))
+            model.unpack_model(model.archive_model_file_path)
+            LOGGER.info("Restore model to {} from mysql successfully".format(model.archive_model_file_path))
             self.close_connection()
         except Exception as e:
             LOGGER.exception(e)
101 changes: 67 additions & 34 deletions python/fate_flow/pipelined_model/pipelined_model.py
@@ -21,13 +21,22 @@
 from ruamel import yaml
 from copy import deepcopy
 from filelock import FileLock
+import hashlib
 
 from os.path import join, getsize
 from fate_arch.common import file_utils
 from fate_arch.protobuf.python import default_empty_fill_pb2
 from fate_flow.settings import stat_logger, TEMP_DIRECTORY
 
 
+def local_cache_required(method):
+    def magic(self, *args, **kwargs):
+        if not self.exists():
+            raise FileNotFoundError(f'Can not found {self.model_id} {self.model_version} model local cache')
+        return method(self, *args, **kwargs)
+    return magic
+
+
 class PipelinedModel(object):
     def __init__(self, model_id, model_version):
         """
@@ -45,16 +54,18 @@ def __init__(self, model_id, model_version):
         self.default_archive_format = "zip"
         self.lock = FileLock(os.path.join(self.model_path, ".lock"))
 
+    def __deepcopy__(self, memo):
+        return self
+
     def create_pipelined_model(self):
-        if os.path.exists(self.model_path):
-            raise Exception("Model creation failed because it has already been created, model cache path is {}".format(
-                self.model_path
-            ))
-        os.makedirs(self.model_path, exist_ok=False)
+        if self.exists():
+            raise FileExistsError("Model creation failed because it has already been created, model cache path is {}".
+                                  format(self.model_path))
+        os.makedirs(self.model_path)
 
         with self.lock:
             for path in [self.variables_index_path, self.variables_data_path]:
-                os.makedirs(path, exist_ok=False)
+                os.makedirs(path)
             shutil.copytree(os.path.join(file_utils.get_python_base_directory(), "federatedml", "protobuf", "proto"), self.define_proto_path)
             with open(self.define_meta_path, "x", encoding="utf-8") as fw:
                 yaml.dump({"describe": "This is the model definition meta"}, fw, Dumper=yaml.RoundTripDumper)
@@ -92,23 +103,24 @@ def read_component_model(self, component_name, model_alias):
                                                                     buffer_object_serialized_string=buffer_object_serialized_string)
         return model_buffers
 
+    @local_cache_required
     def collect_models(self, in_bytes=False, b64encode=True):
         model_buffers = {}
         with open(self.define_meta_path, "r", encoding="utf-8") as fr:
             define_index = yaml.safe_load(fr)
-            for component_name in define_index.get("model_proto", {}).keys():
-                for model_alias, model_proto_index in define_index["model_proto"][component_name].items():
-                    component_model_storage_path = os.path.join(self.variables_data_path, component_name, model_alias)
-                    for model_name, buffer_name in model_proto_index.items():
-                        with open(os.path.join(component_model_storage_path, model_name), "rb") as fr:
-                            buffer_object_serialized_string = fr.read()
-                            if not in_bytes:
-                                model_buffers[model_name] = self.parse_proto_object(buffer_name=buffer_name,
-                                                                                    buffer_object_serialized_string=buffer_object_serialized_string)
-                            else:
-                                if b64encode:
-                                    buffer_object_serialized_string = base64.b64encode(buffer_object_serialized_string).decode()
-                                model_buffers["{}.{}:{}".format(component_name, model_alias, model_name)] = buffer_object_serialized_string
+        for component_name in define_index.get("model_proto", {}).keys():
+            for model_alias, model_proto_index in define_index["model_proto"][component_name].items():
+                component_model_storage_path = os.path.join(self.variables_data_path, component_name, model_alias)
+                for model_name, buffer_name in model_proto_index.items():
+                    with open(os.path.join(component_model_storage_path, model_name), "rb") as fr:
+                        buffer_object_serialized_string = fr.read()
+                        if not in_bytes:
+                            model_buffers[model_name] = self.parse_proto_object(buffer_name=buffer_name,
+                                                                                buffer_object_serialized_string=buffer_object_serialized_string)
+                        else:
+                            if b64encode:
+                                buffer_object_serialized_string = base64.b64encode(buffer_object_serialized_string).decode()
+                            model_buffers["{}.{}:{}".format(component_name, model_alias, model_name)] = buffer_object_serialized_string
         return model_buffers
 
     def set_model_path(self):
@@ -127,21 +139,38 @@ def save_pipeline(self, pipelined_buffer_object):
         with self.lock, open(os.path.join(self.model_path, "pipeline.pb"), "wb") as fw:
             fw.write(buffer_object_serialized_string)
 
+    @local_cache_required
     def packaging_model(self):
-        if not self.exists():
-            raise Exception("Can not found {} {} model local cache".format(self.model_id, self.model_version))
-        archive_file_path = shutil.make_archive(base_name=self.archive_model_base_path(), format=self.default_archive_format, root_dir=self.model_path)
-        stat_logger.info("Make model {} {} archive on {} successfully".format(self.model_id,
-                                                                              self.model_version,
-                                                                              archive_file_path))
+        archive_file_path = shutil.make_archive(base_name=self.archive_model_base_path, format=self.default_archive_format, root_dir=self.model_path)
+
+        with open(archive_file_path, 'rb') as f:
+            sha1 = hashlib.sha1(f.read()).hexdigest()
+        with open(archive_file_path + '.sha1', 'w', encoding='utf8') as f:
+            f.write(sha1)
+
+        stat_logger.info("Make model {} {} archive on {} successfully. sha1: {}".format(
+            self.model_id, self.model_version, archive_file_path, sha1))
         return archive_file_path
 
     def unpack_model(self, archive_file_path: str):
-        if os.path.exists(self.model_path):
-            raise Exception("Model {} {} local cache already existed".format(self.model_id, self.model_version))
-        shutil.unpack_archive(archive_file_path, self.model_path)
+        if self.exists():
+            raise FileExistsError("Model {} {} local cache already existed".format(self.model_id, self.model_version))
+
+        if os.path.isfile(archive_file_path + '.sha1'):
+            with open(archive_file_path + '.sha1', encoding='utf8') as f:
+                sha1_orig = f.read().strip()
+            with open(archive_file_path, 'rb') as f:
+                sha1 = hashlib.sha1(f.read()).hexdigest()
+            if sha1 != sha1_orig:
+                raise ValueError('Hash not match. path: {} expected: {} actual: {}'.format(
+                    archive_file_path, sha1_orig, sha1))
+
+        os.makedirs(self.model_path)
+        with self.lock:
+            shutil.unpack_archive(archive_file_path, self.model_path)
         stat_logger.info("Unpack model archive to {}".format(self.model_path))
 
+    @local_cache_required
     def update_component_meta(self, component_name, component_module_name, model_alias, model_proto_index):
         """
         update meta info yaml
@@ -170,18 +199,20 @@ def update_component_meta(self, component_name, component_module_name, model_ali
             yaml.dump(define_index, f, Dumper=yaml.RoundTripDumper)
             f.truncate()
 
+    @local_cache_required
     def get_model_proto_index(self, component_name, model_alias):
         with open(self.define_meta_path, "r", encoding="utf-8") as fr:
             define_index = yaml.safe_load(fr)
-            return define_index.get("model_proto", {}).get(component_name, {}).get(model_alias, {})
+        return define_index.get("model_proto", {}).get(component_name, {}).get(model_alias, {})
 
+    @local_cache_required
     def get_component_define(self, component_name=None):
         with open(self.define_meta_path, "r", encoding="utf-8") as fr:
             define_index = yaml.safe_load(fr)
-            if component_name:
-                return define_index.get("component_define", {}).get(component_name, {})
-            else:
-                return define_index.get("component_define", {})
+
+        if component_name is not None:
+            return define_index.get("component_define", {}).get(component_name, {})
+        return define_index.get("component_define", {})
 
     def parse_proto_object(self, buffer_name, buffer_object_serialized_string):
         try:
@@ -221,11 +252,13 @@ def get_proto_buffer_class(cls, buffer_name):
         else:
             return None
 
+    @property
     def archive_model_base_path(self):
         return os.path.join(TEMP_DIRECTORY, "{}_{}".format(self.model_id, self.model_version))
 
+    @property
     def archive_model_file_path(self):
-        return "{}.{}".format(self.archive_model_base_path(), self.default_archive_format)
+        return "{}.{}".format(self.archive_model_base_path, self.default_archive_format)
 
     def calculate_model_file_size(self):
         size = 0
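The pipelined_model.py changes above introduce two reusable patterns: a local_cache_required decorator that rejects calls when the on-disk model cache is missing, and a sha1 sidecar file written next to the model archive in packaging_model and verified in unpack_model before extraction. A self-contained sketch of that checksum round-trip follows; write_sha1_sidecar and verify_sha1_sidecar are illustrative helper names, not functions from this codebase.

import hashlib
import os


def write_sha1_sidecar(archive_file_path):
    # Hash the finished archive and store the hex digest in '<archive>.sha1',
    # mirroring what packaging_model() does after shutil.make_archive().
    with open(archive_file_path, 'rb') as f:
        sha1 = hashlib.sha1(f.read()).hexdigest()
    with open(archive_file_path + '.sha1', 'w', encoding='utf8') as f:
        f.write(sha1)
    return sha1


def verify_sha1_sidecar(archive_file_path):
    # Mirror unpack_model(): archives without a sidecar are accepted as-is,
    # but a present sidecar must match the recomputed digest.
    sidecar = archive_file_path + '.sha1'
    if not os.path.isfile(sidecar):
        return
    with open(sidecar, encoding='utf8') as f:
        expected = f.read().strip()
    with open(archive_file_path, 'rb') as f:
        actual = hashlib.sha1(f.read()).hexdigest()
    if actual != expected:
        raise ValueError('Hash not match. path: {} expected: {} actual: {}'.format(
            archive_file_path, expected, actual))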
8 changes: 4 additions & 4 deletions python/fate_flow/pipelined_model/redis_model_storage.py
@@ -68,11 +68,11 @@ def restore(self, model_id: str, model_version: str, store_address: dict):
             if not model_archive_data:
                 raise Exception("Restore model {} {} to redis failed: {}".format(
                     model_id, model_version, "can not found model archive data"))
-            with open(model.archive_model_file_path(), "wb") as fw:
+            with open(model.archive_model_file_path, "wb") as fw:
                 fw.write(model_archive_data)
-            model.unpack_model(model.archive_model_file_path())
-            LOGGER.info("Restore model to {} from redis successfully using key {}".format(model.archive_model_file_path(),
-                                                                                           redis_store_key))
+            model.unpack_model(model.archive_model_file_path)
+            LOGGER.info("Restore model to {} from redis successfully using key {}".format(
+                model.archive_model_file_path, redis_store_key))
         except Exception as e:
             LOGGER.exception(e)
             raise Exception("Restore model {} {} from redis failed".format(model_id, model_version))
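The mysql and redis storage edits are the same mechanical fix: archive_model_base_path and archive_model_file_path became properties (see the @property decorators added in pipelined_model.py), so their call sites drop the parentheses. A toy illustration of the pattern, using a stand-in class rather than PipelinedModel:

import os

TEMP_DIRECTORY = "/tmp"  # stand-in for fate_flow.settings.TEMP_DIRECTORY


class ArchivePaths:
    def __init__(self, model_id, model_version):
        self.model_id = model_id
        self.model_version = model_version
        self.default_archive_format = "zip"

    @property
    def archive_model_base_path(self):
        return os.path.join(TEMP_DIRECTORY, "{}_{}".format(self.model_id, self.model_version))

    @property
    def archive_model_file_path(self):
        # Properties are read as attributes, so callers write .archive_model_file_path, not ()
        return "{}.{}".format(self.archive_model_base_path, self.default_archive_format)


print(ArchivePaths("model_id", "model_version").archive_model_file_path)  # /tmp/model_id_model_version.zip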
2 changes: 0 additions & 2 deletions python/fate_flow/settings.py
@@ -22,13 +22,11 @@
 from fate_arch.common import file_utils, log, EngineType
 from fate_flow.entity.runtime_config import RuntimeConfig
 from fate_arch.common.conf_utils import get_base_config
-import __main__
 
 
 # Server
 API_VERSION = "v1"
 FATEFLOW_SERVICE_NAME = "fateflow"
-MAIN_MODULE = os.path.relpath(__main__.__file__)
 SERVER_MODULE = "fate_flow_server.py"
 TEMP_DIRECTORY = os.path.join(file_utils.get_project_base_directory(), "temp", "fate_flow")
 HEADERS = {