diff --git a/data_copilot/main.py b/data_copilot/main.py index e26abff..4036f0e 100644 --- a/data_copilot/main.py +++ b/data_copilot/main.py @@ -20,6 +20,7 @@ class BACKENDS(enum.Enum): STANDARD_BACKEND = BACKENDS.SQL +BACKEND_HOST = "" def get_envs(): @@ -27,9 +28,9 @@ def get_envs(): "JWT_SECRET_KEY": subprocess.check_output("openssl rand -hex 32", shell=True) .decode("utf-8") .strip(), - "BACKEND_HOST": "localhost:8000/api", "DB_CONNECTION_STRING": "sqlite:///data_copilot.db", "CELERY_BROKER_URL": "redis://localhost:6378/0", + "BACKEND_HOST": BACKEND_HOST, "OPENAI_API_KEY": os.environ.get("OPENAI_API_KEY"), "STORAGE_BACKEND": "volume://shared-fs/data", "ENVIRONMENT": "DEVELOPMENT", @@ -43,7 +44,7 @@ def get_envs(): return env -def start_backend_server(log_level="INFO"): +def start_backend_server(log_level="INFO", port=8000): if log_level.lower() not in ( "critical", "error", @@ -60,7 +61,7 @@ def start_backend_server(log_level="INFO"): "--host", "0.0.0.0", "--port", - "8000", + str(port), "--log-level", log_level.lower(), ], @@ -72,7 +73,7 @@ def start_backend_server(log_level="INFO"): return backend_process -def start_redis(log_level="INFO"): +def start_redis(log_level="INFO", port=6378): if log_level.lower() not in ("debug", "verbose", "notice", "warning"): log_level = "warning" @@ -80,7 +81,7 @@ def start_redis(log_level="INFO"): [ "redis-server", "--port", - "6378", + str(port), "--appendonly", "yes", "--loglevel", @@ -93,7 +94,9 @@ def start_redis(log_level="INFO"): return redis_process -def start_frontend(log_level="INFO"): +def start_frontend( + log_level="INFO", port=8080, backend_host="http://localhost:8000/api" +): def _start_frontend(): # Set up a logger for this thread, with the specified log level logger = logging.getLogger("Frontend") @@ -122,7 +125,7 @@ def send_assets(path): def catch_all(path): return send_from_directory(app.static_folder, "index.html") - app.run(host="0.0.0.0", port=8080) + app.run(host="0.0.0.0", port=port) frontend_process = threading.Thread(target=_start_frontend, daemon=True) frontend_process.start() @@ -221,8 +224,34 @@ def main(): type=click.Choice([b.value for b in BACKENDS]), help=f"The backend to use for computation. Defaults to {STANDARD_BACKEND.value}", ) -def run(log_level, backend): - # check_free_ports() +@click.option( + "--backend-port", + default=8000, + type=int, + help="The port to use for the backend.", +) +@click.option( + "--redis-port", + default=6378, + type=int, + help="The port to use for redis.", +) +@click.option( + "--frontend-port", + default=8080, + type=int, + help="The port to use for the frontend.", +) +@click.option( + "--backend-host", + default="http://localhost:8000/api", + type=str, + help="The host for the backend used by the frontend.", +) +def run(log_level, backend, backend_port, redis_port, frontend_port, backend_host): + global BACKEND_HOST + BACKEND_HOST = backend_host + check_free_ports(ports=[backend_port, redis_port, frontend_port]) load_dotenv(".env") @@ -237,9 +266,9 @@ def run(log_level, backend): load_dotenv(".env") worker_process = start_worker(log_level) - redis_process = start_redis(log_level) - backend_process = start_backend_server(log_level) - start_frontend(log_level) + redis_process = start_redis(log_level, redis_port) + backend_process = start_backend_server(log_level, backend_port) + start_frontend(log_level, frontend_port, backend_host) create_subprocess_logger( worker=worker_process,