improve federated api
jarviszeng-zjc committed Jan 14, 2021
1 parent 244857d commit 1f50d7b
Showing 7 changed files with 85 additions and 61 deletions.
7 changes: 1 addition & 6 deletions c/proxy/lua/router.lua
@@ -20,13 +20,8 @@ local _M = {
local ngx = ngx
local route_table = require "route_table"

local function get_request_dest()
local headers = ngx.req.get_headers()
return headers
end

local function routing()
local request_headers = get_request_dest()
local request_headers = ngx.req.get_headers()
local dest_env = request_headers["dest-party-id"]
if dest_env == nil then
dest_env = request_headers["dest-env"]
9 changes: 6 additions & 3 deletions conf/service_conf.yaml
@@ -6,10 +6,12 @@ fateflow:
host: 127.0.0.1
http_port: 9380
grpc_port: 9360
# support rollsite or nginx as a coordinate proxy, rollsite recommended in the fate on eggroll, nginx is recommended in the fate on spark
# support rollsite or nginx as the coordination proxy; rollsite is recommended for FATE on EggRoll, nginx for FATE on Spark
# format(proxy: rollsite) means rollsite uses the rollsite configuration under fate_on_eggroll, and nginx uses the nginx configuration under fate_on_spark
# you can customize the config in the format (proxy:\n name: rollsite \n host: xx \n port: xx)
# you can also customize the config in the format (proxy:\n name: nginx \n host: xx \n port: xx)
proxy: rollsite
# can be set to default/http/grpc; rollsite defaults to grpc and does not support http, nginx defaults to http and also supports grpc
protocol: default
fateboard:
host: 127.0.0.1
port: 8080
@@ -48,7 +50,8 @@ fate_on_spark:
route_table:
nginx:
host: 127.0.0.1
port: 9310
http_port: 9300
grpc_port: 9310
model_store_address:
storage: mysql
name: model
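
The inline proxy format mentioned in the comments above is terse; the sketch below is not part of this commit and shows, with illustrative host values, the two accepted shapes of the fateflow proxy setting as they look once the YAML is loaded, together with the protocol key that decides which port is used.

# Shape 1: a service name; fate_flow reuses the rollsite block under
# fate_on_eggroll or the nginx block under fate_on_spark.
fateflow_config_by_name = {
    "proxy": "nginx",        # or "rollsite"
    "protocol": "default",   # default -> grpc for rollsite, http for nginx
}

# Shape 2: an inline mapping pointing fate_flow at a custom coordination proxy.
# get_federated_proxy_address() reads "<protocol>_port" keys from this mapping,
# so provide http_port/grpc_port rather than a single port.
fateflow_config_inline = {
    "proxy": {
        "name": "nginx",
        "host": "192.0.2.10",   # illustrative
        "http_port": 9300,
        "grpc_port": 9310,
    },
    "protocol": "grpc",
}
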
2 changes: 1 addition & 1 deletion python/fate_arch/common/__init__.py
@@ -1 +1 @@
from fate_arch.common._types import WorkMode, Backend, Party, FederatedMode, FederatedCommunicationType, EngineType, CoordinateProxyService
from fate_arch.common._types import WorkMode, Backend, Party, FederatedMode, FederatedCommunicationType, EngineType, CoordinationProxyService
2 changes: 1 addition & 1 deletion python/fate_arch/common/_types.py
@@ -30,7 +30,7 @@ class EngineType(object):
FEDERATION = "federation"


class CoordinateProxyService(object):
class CoordinationProxyService(object):
rollsite = "rollsite"
nginx = 'nginx'

5 changes: 3 additions & 2 deletions python/fate_flow/settings.py
@@ -33,7 +33,8 @@
TEMP_DIRECTORY = os.path.join(file_utils.get_project_base_directory(), "temp", "fate_flow")
HEADERS = {
"Content-Type": "application/json",
"Connection": "close"
"Connection": "close",
"service": FATEFLOW_SERVICE_NAME
}
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
GRPC_SERVER_MAX_WORKERS = None
@@ -79,7 +80,7 @@
}

# Scheduling
DEFAULT_GRPC_OVERALL_TIMEOUT = 30 * 1000 # ms
DEFAULT_REMOTE_REQUEST_TIMEOUT = 30 * 1000 # ms
DEFAULT_FEDERATED_COMMAND_TRYS = 3
JOB_DEFAULT_TIMEOUT = 3 * 24 * 60 * 60
JOB_START_TIMEOUT = 60 * 1000 # ms
97 changes: 68 additions & 29 deletions python/fate_flow/utils/api_utils.py
@@ -25,7 +25,8 @@
from fate_arch.common.log import audit_logger, schedule_logger
from fate_arch.common import FederatedMode
from fate_arch.common import conf_utils
from fate_flow.settings import DEFAULT_GRPC_OVERALL_TIMEOUT, CHECK_NODES_IDENTITY,\
from fate_arch.common import CoordinationProxyService
from fate_flow.settings import DEFAULT_REMOTE_REQUEST_TIMEOUT, CHECK_NODES_IDENTITY,\
FATE_MANAGER_GET_NODE_INFO_ENDPOINT, HEADERS, API_VERSION, stat_logger
from fate_flow.utils.grpc_utils import wrap_grpc_packet, get_command_federation_channel, gen_routing_metadata, \
forward_grpc_packet
@@ -57,20 +58,79 @@ def error_response(response_code, retmsg):


def federated_api(job_id, method, endpoint, src_party_id, dest_party_id, src_role, json_body, federated_mode, api_version=API_VERSION,
overall_timeout=DEFAULT_GRPC_OVERALL_TIMEOUT):
overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT):
if int(dest_party_id) == 0:
federated_mode = FederatedMode.SINGLE
if federated_mode == FederatedMode.SINGLE:
return local_api(job_id=job_id, method=method, endpoint=endpoint, json_body=json_body, api_version=api_version)
elif federated_mode == FederatedMode.MULTIPLE:
return remote_api(job_id=job_id, method=method, endpoint=endpoint, src_party_id=src_party_id, src_role=src_role,
dest_party_id=dest_party_id, json_body=json_body, api_version=api_version, overall_timeout=overall_timeout)
host, port, protocol = get_federated_proxy_address()
if protocol == "http":
return federated_coordination_on_http(job_id=job_id, method=method, host=host,
port=port, endpoint=endpoint, src_party_id=src_party_id, src_role=src_role,
dest_party_id=dest_party_id, json_body=json_body, api_version=api_version, overall_timeout=overall_timeout)
else:
return federated_coordination_on_grpc(job_id=job_id, method=method, host=host,
port=port, endpoint=endpoint, src_party_id=src_party_id, src_role=src_role,
dest_party_id=dest_party_id, json_body=json_body, api_version=api_version, overall_timeout=overall_timeout)
else:
raise Exception('{} work mode is not supported'.format(federated_mode))


def remote_api(job_id, method, endpoint, src_party_id, dest_party_id, src_role, json_body, api_version=API_VERSION,
overall_timeout=DEFAULT_GRPC_OVERALL_TIMEOUT, try_times=3):
def local_api(job_id, method, endpoint, json_body, api_version=API_VERSION, try_times=3):
return federated_coordination_on_http(job_id=job_id, method=method, host=RuntimeConfig.JOB_SERVER_HOST,
port=RuntimeConfig.HTTP_PORT, endpoint=endpoint, src_party_id="", src_role="",
dest_party_id="", json_body=json_body, api_version=api_version, try_times=try_times)


def get_federated_proxy_address():
proxy_config = get_base_config("fateflow", {}).get("proxy", None)
protocol_config = get_base_config("fateflow", {}).get("protocol", "default")
if isinstance(proxy_config, str):
if proxy_config == CoordinationProxyService.rollsite:
proxy_address = get_base_config("fate_on_eggroll", {}).get(proxy_config)
return proxy_address["host"], proxy_address["port"], "grpc"
elif proxy_config == CoordinationProxyService.nginx:
proxy_address = get_base_config("fate_on_spark", {}).get(proxy_config)
protocol = "http" if protocol_config == "default" else protocol_config
return proxy_address["host"], proxy_address[f"{protocol}_port"], protocol
else:
raise RuntimeError(f"can not support coordinate proxy {proxy_config}")
elif isinstance(proxy_config, dict):
proxy_address = proxy_config
protocol = "http" if protocol_config == "default" else protocol_config
return proxy_address["host"], proxy_address[f"{protocol}_port"], protocol
else:
raise RuntimeError(f"can not support coordinate proxy config {proxy_config}")


def federated_coordination_on_http(job_id, method, host, port, endpoint, src_party_id, src_role, dest_party_id, json_body, api_version=API_VERSION, overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT, try_times=3):
endpoint = f"/{api_version}{endpoint}"
exception = None
for t in range(try_times):
try:
url = "http://{}:{}{}".format(host, port, endpoint)
audit_logger(job_id).info('remote http api request: {}'.format(url))
action = getattr(requests, method.lower(), None)
headers = HEADERS.copy()
headers["dest-party-id"] = str(dest_party_id)
headers["src-party-id"] = str(src_party_id)
headers["src-role"] = str(src_role)
http_response = action(url=url, data=json_dumps(json_body), headers=headers)
audit_logger(job_id).info(http_response.text)
response = http_response.json()
audit_logger(job_id).info('remote http api response: {} {}'.format(endpoint, response))
return response
except Exception as e:
exception = e
schedule_logger(job_id).warning(f"remote http request {endpoint} error, sleep and try again")
time.sleep(2 * (t+1))
else:
raise Exception('remote http request error: {}'.format(exception))
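
The party headers set above are what the updated c/proxy/lua/router.lua reads to pick a destination, using dest-party-id with dest-env as a fallback. Below is a sketch of the headers a proxied HTTP coordination request would carry; it is not part of the commit, the party ids and role are made up, and the service value assumes FATEFLOW_SERVICE_NAME is "fateflow".

# Illustrative request headers; the first three entries come from
# fate_flow.settings.HEADERS as changed in this commit.
coordination_headers = {
    "Content-Type": "application/json",
    "Connection": "close",
    "service": "fateflow",      # FATEFLOW_SERVICE_NAME (assumed value)
    "dest-party-id": "10000",   # router.lua routes on this, falling back to dest-env
    "src-party-id": "9999",
    "src-role": "guest",
}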


def federated_coordination_on_grpc(job_id, method, host, port, endpoint, src_party_id, src_role, dest_party_id, json_body, api_version=API_VERSION,
overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT, try_times=3):
endpoint = f"/{api_version}{endpoint}"
json_body['src_role'] = src_role
json_body['src_party_id'] = src_party_id
@@ -82,7 +142,7 @@ def remote_api(job_id, method, endpoint, src_party_id, dest_party_id, src_role,
exception = None
for t in range(try_times):
try:
channel, stub = get_command_federation_channel()
channel, stub = get_command_federation_channel(host, port)
_return, _call = stub.unaryCall.with_call(_packet, metadata=_routing_metadata, timeout=(overall_timeout/1000))
audit_logger(job_id).info("grpc api response: {}".format(_return))
channel.close()
@@ -111,7 +171,7 @@ def proxy_api(role, _job_id, request_config):
dest_party_id = request_config.get('header').get('dest_party_id')
json_body = request_config.get('body')
_packet = forward_grpc_packet(json_body, method, endpoint, src_party_id, dest_party_id, job_id=job_id, role=role,
overall_timeout=DEFAULT_GRPC_OVERALL_TIMEOUT)
overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT)
_routing_metadata = gen_routing_metadata(src_party_id=src_party_id, dest_party_id=dest_party_id)

channel, stub = get_command_federation_channel()
@@ -121,27 +181,6 @@ def proxy_api(role, _job_id, request_config):
return json_body


def local_api(job_id, method, endpoint, json_body, api_version=API_VERSION, try_times=3):
endpoint = f"/{api_version}{endpoint}"
exception = None
for t in range(try_times):
try:
url = "http://{}:{}{}".format(RuntimeConfig.JOB_SERVER_HOST, RuntimeConfig.HTTP_PORT, endpoint)
audit_logger(job_id).info('local api request: {}'.format(url))
action = getattr(requests, method.lower(), None)
http_response = action(url=url, data=json_dumps(json_body), headers=HEADERS)
audit_logger(job_id).info(http_response.text)
response = http_response.json()
audit_logger(job_id).info('local api response: {} {}'.format(endpoint, response))
return response
except Exception as e:
exception = e
schedule_logger(job_id).warning(f"local request {endpoint} error, sleep and try again")
time.sleep(2 * (t+1))
else:
raise Exception('local request error: {}'.format(exception))


def forward_api(role, request_config):
endpoint = request_config.get('header', {}).get('endpoint')
ip = get_base_config(role, {}).get("host", "127.0.0.1")
24 changes: 5 additions & 19 deletions python/fate_flow/utils/grpc_utils.py
@@ -16,32 +16,18 @@
import requests

from fate_arch.common.log import audit_logger
from fate_arch.common import CoordinateProxyService
from fate_flow.utils.proto_compatibility import basic_meta_pb2
from fate_flow.utils.proto_compatibility import proxy_pb2, proxy_pb2_grpc
import grpc

from fate_flow.settings import FATEFLOW_SERVICE_NAME, IP, GRPC_PORT, HEADERS, DEFAULT_GRPC_OVERALL_TIMEOUT, stat_logger
from fate_flow.settings import FATEFLOW_SERVICE_NAME, IP, GRPC_PORT, HEADERS, DEFAULT_REMOTE_REQUEST_TIMEOUT
from fate_flow.entity.runtime_config import RuntimeConfig
from fate_flow.utils.node_check_utils import nodes_check
from fate_arch.common.conf_utils import get_base_config
from fate_arch.common.base_utils import json_dumps, json_loads


def get_command_federation_channel():
proxy_config = get_base_config("fateflow", {}).get("proxy", None)
if isinstance(proxy_config, str):
if proxy_config == CoordinateProxyService.rollsite:
address = get_base_config("fate_on_eggroll", {}).get(proxy_config)
elif proxy_config == CoordinateProxyService.nginx:
address = get_base_config("fate_on_spark", {}).get(proxy_config)
else:
raise RuntimeError(f"can not support coordinate proxy {proxy_config}")
elif isinstance(proxy_config, dict):
address = proxy_config
else:
raise RuntimeError(f"can not support coordinate proxy config {proxy_config}")
channel = grpc.insecure_channel('{}:{}'.format(address.get("host"), address.get("port")))
def get_command_federation_channel(host, port):
channel = grpc.insecure_channel(f"{host}:{port}")
stub = proxy_pb2_grpc.DataTransferServiceStub(channel)
return channel, stub
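
After this change the caller resolves the coordination proxy first (via get_federated_proxy_address() in api_utils.py) and passes the host and port in, instead of this helper reading the service configuration itself. A usage sketch, not part of the commit; the address is illustrative and it assumes the fate_flow proto stubs are importable.

# Resolve the proxy elsewhere, then open a short-lived channel to it.
host, port = "127.0.0.1", 9310   # e.g. the grpc_port of the nginx proxy
channel, stub = get_command_federation_channel(host, port)
# ... stub.unaryCall.with_call(packet, metadata=routing_metadata, timeout=...) as in api_utils.py
channel.close()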

@@ -57,7 +43,7 @@ def gen_routing_metadata(src_party_id, dest_party_id):
return routing_head


def wrap_grpc_packet(json_body, http_method, url, src_party_id, dst_party_id, job_id=None, overall_timeout=DEFAULT_GRPC_OVERALL_TIMEOUT):
def wrap_grpc_packet(json_body, http_method, url, src_party_id, dst_party_id, job_id=None, overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT):
_src_end_point = basic_meta_pb2.Endpoint(ip=IP, port=GRPC_PORT)
_src = proxy_pb2.Topic(name=job_id, partyId="{}".format(src_party_id), role=FATEFLOW_SERVICE_NAME, callback=_src_end_point)
_dst = proxy_pb2.Topic(name=job_id, partyId="{}".format(dst_party_id), role=FATEFLOW_SERVICE_NAME, callback=None)
@@ -115,7 +101,7 @@ def unaryCall(self, _request, context):


def forward_grpc_packet(_json_body, _method, _url, _src_party_id, _dst_party_id, role, job_id=None,
overall_timeout=DEFAULT_GRPC_OVERALL_TIMEOUT):
overall_timeout=DEFAULT_REMOTE_REQUEST_TIMEOUT):
_src_end_point = basic_meta_pb2.Endpoint(ip=IP, port=GRPC_PORT)
_src = proxy_pb2.Topic(name=job_id, partyId="{}".format(_src_party_id), role=FATEFLOW_SERVICE_NAME, callback=_src_end_point)
_dst = proxy_pb2.Topic(name=job_id, partyId="{}".format(_dst_party_id), role=role, callback=None)