Skip to content

Commit

Permalink
feat: server multi models support (langgenius#799)
Browse files Browse the repository at this point in the history
  • Loading branch information
takatost authored Aug 11, 2023
1 parent d8b712b commit 5fa2161
Show file tree
Hide file tree
Showing 213 changed files with 10,535 additions and 2,558 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/check_no_chinese_comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):

def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']

for root, _, files in os.walk("."):
for file in files:
Expand Down
26 changes: 26 additions & 0 deletions api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
NOTION_CLIENT_SECRET=you-client-secret
NOTION_CLIENT_ID=you-client-id
NOTION_INTERNAL_SECRET=you-internal-secret

# Hosted Model Credentials
HOSTED_OPENAI_ENABLED=false
HOSTED_OPENAI_API_KEY=
HOSTED_OPENAI_API_BASE=
HOSTED_OPENAI_API_ORGANIZATION=
HOSTED_OPENAI_QUOTA_LIMIT=200
HOSTED_OPENAI_PAID_ENABLED=false
HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
HOSTED_OPENAI_PAID_INCREASE_QUOTA=1

HOSTED_AZURE_OPENAI_ENABLED=false
HOSTED_AZURE_OPENAI_API_KEY=
HOSTED_AZURE_OPENAI_API_BASE=
HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200

HOSTED_ANTHROPIC_ENABLED=false
HOSTED_ANTHROPIC_API_BASE=
HOSTED_ANTHROPIC_API_KEY=
HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
HOSTED_ANTHROPIC_PAID_ENABLED=false
HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1

STRIPE_API_KEY=
STRIPE_WEBHOOK_SECRET=
19 changes: 17 additions & 2 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import flask_login
from flask_cors import CORS

from core.model_providers.providers import hosted
from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail
ext_database, ext_storage, ext_mail, ext_stripe
from extensions.ext_database import db
from extensions.ext_login import login_manager

Expand Down Expand Up @@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
register_blueprints(app)
register_commands(app)

core.init_app(app)
hosted.init_app(app)

return app

Expand All @@ -88,6 +89,7 @@ def initialize_extensions(app):
ext_login.init_app(app)
ext_mail.init_app(app)
ext_sentry.init_app(app)
ext_stripe.init_app(app)


def _create_tenant_for_account(account):
Expand Down Expand Up @@ -246,5 +248,18 @@ def threads():
}


@app.route('/db-pool-stat')
def pool_stat():
    """Report the current state of the database connection pool.

    Returns:
        dict: the pool size, the checked-in, checked-out and overflow
        connection counts, the pool timeout and the recycle time, all
        read from ``db.engine.pool``.
    """
    pool = db.engine.pool

    stats = {}
    stats['pool_size'] = pool.size()
    stats['checked_in_connections'] = pool.checkedin()
    stats['checked_out_connections'] = pool.checkedout()
    stats['overflow_connections'] = pool.overflow()
    stats['connection_timeout'] = pool.timeout()
    # The recycle time is read from the pool's private ``_recycle`` attribute.
    stats['recycle_time'] = pool._recycle
    return stats


if __name__ == '__main__':
app.run(host='0.0.0.0', port=5001)
37 changes: 24 additions & 13 deletions api/commands.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
import logging
import math
import random
import string
import time
Expand All @@ -9,18 +9,18 @@
from werkzeug.exceptions import NotFound

from core.index.index import IndexBuilder
from core.model_providers.providers.hosted import hosted_model_providers
from libs.password import password_pattern, valid_password, hash_password
from libs.helper import email as email_validate
from extensions.ext_database import db
from libs.rsa import generate_key_pair
from models.account import InvitationCode, Tenant
from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
from models.dataset import Dataset, DatasetQuery, Document
from models.model import Account
import secrets
import base64

from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
from models.provider import Provider, ProviderType, ProviderQuotaType


@click.command('reset-password', help='Reset the account password.')
Expand Down Expand Up @@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():

@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
if not hosted_model_providers.anthropic:
click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
return

click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0

page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
providers = db.session.query(Provider).filter(
Provider.provider_name == 'anthropic',
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
except NotFound:
break

page += 1
for tenant in tenants:
for provider in providers:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
original_quota_limit = provider.quota_limit
new_quota_limit = hosted_model_providers.anthropic.quota_limit
division = math.ceil(new_quota_limit / 1000)

provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
else original_quota_limit * division
provider.quota_used = division * provider.quota_used
db.session.commit()

count += 1
except Exception as e:
click.echo(click.style(
Expand Down
52 changes: 39 additions & 13 deletions api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
'SESSION_USE_SIGNER': 'True',
'DEPLOY_ENV': 'PRODUCTION',
'SQLALCHEMY_POOL_SIZE': 30,
'SQLALCHEMY_POOL_RECYCLE': 3600,
'SQLALCHEMY_ECHO': 'False',
'SENTRY_TRACES_SAMPLE_RATE': 1.0,
'SENTRY_PROFILES_SAMPLE_RATE': 1.0,
Expand All @@ -50,9 +51,16 @@
'PDF_PREVIEW': 'True',
'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'HOSTED_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_OPENAI_ENABLED': 'False',
'HOSTED_OPENAI_PAID_ENABLED': 'False',
'HOSTED_OPENAI_PAID_INCREASE_QUOTA': 1,
'HOSTED_AZURE_OPENAI_ENABLED': 'False',
'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
'HOSTED_ANTHROPIC_QUOTA_LIMIT': 1000000,
'HOSTED_ANTHROPIC_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_ENABLED': 'False',
'HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA': 1,
'TENANT_DOCUMENT_COUNT': 100,
'CLEAN_DAY_SETTING': 30
}
Expand Down Expand Up @@ -182,7 +190,10 @@ def __init__(self):
}

self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))}
self.SQLALCHEMY_ENGINE_OPTIONS = {
'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
}

self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')

Expand All @@ -194,20 +205,35 @@ def __init__(self):
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')

# hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')

self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT')
self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID')
self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA'))

self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE')
self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT')

self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED')
self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE')
self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY')
self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT')
self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED')
self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID')
self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA')

self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

# Defaults to False, i.e. provider config validation stays enabled.
# Set it to True to skip validation for compatibility with certain OpenAPI providers.
self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION')

# For temp use only
# set default LLM provider, default is 'openai', support `azure_openai`
self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER')

# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
Expand Down
5 changes: 4 additions & 1 deletion api/controllers/console/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source

# Import workspace controllers
from .workspace import workspace, members, model_providers, account, tool_providers
from .workspace import workspace, members, providers, model_providers, account, tool_providers, models

# Import explore controllers
from .explore import installed_app, recommended_app, completion, conversation, message, parameter, saved_message, audio

# Import universal chat controllers
from .universal_chat import chat, conversation, message, parameter, audio

# Import webhook controllers
from .webhook import stripe
26 changes: 21 additions & 5 deletions api/controllers/console/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
import json
from datetime import datetime

import flask
from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
from werkzeug.exceptions import Unauthorized, Forbidden
from werkzeug.exceptions import Forbidden

from constants.model_template import model_templates, demo_model_templates
from controllers.console import api
from controllers.console.app.error import AppNotFoundError
from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from libs.helper import TimestampField
from extensions.ext_database import db
Expand Down Expand Up @@ -126,9 +127,9 @@ def post(self):
if args['model_config'] is not None:
# validate config
model_configuration = AppModelConfigService.validate_configuration(
tenant_id=current_user.current_tenant_id,
account=current_user,
config=args['model_config'],
mode=args['mode']
config=args['model_config']
)

app = App(
Expand Down Expand Up @@ -164,6 +165,21 @@ def post(self):
app = App(**model_config_template['app'])
app_model_config = AppModelConfig(**model_config_template['model_config'])

default_model = ModelFactory.get_default_model(
tenant_id=current_user.current_tenant_id,
model_type=ModelType.TEXT_GENERATION
)

if default_model:
model_dict = app_model_config.model_dict
model_dict['provider'] = default_model.provider_name
model_dict['name'] = default_model.model_name
app_model_config.model = json.dumps(model_dict)
else:
raise ProviderNotInitializeError(
f"No Text Generation Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")

app.name = args['name']
app.mode = args['mode']
app.icon = args['icon']
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/app/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from flask_restful import Resource
from services.audio_service import AudioService
Expand Down
12 changes: 9 additions & 3 deletions api/controllers/console/app/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.conversation_message_task import PubHandler
from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value
from flask_restful import Resource, reqparse
Expand All @@ -41,8 +41,11 @@ def post(self, app_id):
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
args = parser.parse_args()

streaming = args['response_mode'] != 'blocking'

account = flask_login.current_user

try:
Expand All @@ -51,7 +54,7 @@ def post(self, app_id):
user=account,
args=args,
from_source='console',
streaming=True,
streaming=streaming,
is_model_config_override=True
)

Expand Down Expand Up @@ -111,8 +114,11 @@ def post(self, app_id):
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
args = parser.parse_args()

streaming = args['response_mode'] != 'blocking'

account = flask_login.current_user

try:
Expand All @@ -121,7 +127,7 @@ def post(self, app_id):
user=account,
args=args,
from_source='console',
streaming=True,
streaming=streaming,
is_model_config_override=True
)

Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/app/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.generator.llm_generator import LLMGenerator
from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError


Expand Down
2 changes: 1 addition & 1 deletion api/controllers/console/app/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from libs.helper import uuid_value, TimestampField
from libs.infinite_scroll_pagination import InfiniteScrollPagination
Expand Down
Loading

0 comments on commit 5fa2161

Please sign in to comment.