diff --git a/python/fate_flow/utils/service_utils.py b/python/fate_flow/utils/service_utils.py index c99547eee4..419c396e0f 100644 --- a/python/fate_flow/utils/service_utils.py +++ b/python/fate_flow/utils/service_utils.py @@ -72,14 +72,13 @@ def get_from_registry(cls, service_name): raise Exception('loading servings node failed from zookeeper: {}'.format(e)) @classmethod - def register(cls, party_model_id=None, model_version=None): + def register(cls, zk=None, party_model_id=None, model_version=None): if not get_base_config('use_registry', False): return - - zk = ServiceUtils.get_zk() - zk.start() - atexit.register(zk.stop) - + if not zk: + zk = ServiceUtils.get_zk() + zk.start() + atexit.register(zk.stop) model_transfer_url = 'http://{}:{}{}'.format(IP, HTTP_PORT, FATE_FLOW_MODEL_TRANSFER_ENDPOINT) if party_model_id is not None and model_version is not None: model_transfer_url += '/{}/{}'.format(party_model_id.replace('#', '~'), model_version) @@ -97,9 +96,11 @@ def register(cls, party_model_id=None, model_version=None): def register_models(cls, models): if not get_base_config('use_registry', False): return - + zk = ServiceUtils.get_zk() + zk.start() + atexit.register(zk.stop) for model in models: - cls.register(model.f_party_model_id, model.f_model_version) + cls.register(zk, model.f_party_model_id, model.f_model_version) @classmethod def register_service(cls, service_config):