-
Notifications
You must be signed in to change notification settings - Fork 69
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
25 changed files
with
567 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,103 @@ | ||
from sqlmodel import Session | ||
|
||
from fastapi import Depends, Request | ||
|
||
from fastapi.security import APIKeyHeader | ||
from config.config import settings | ||
from app.exceptions.exception import AuthenticationError, AuthorizationError, ResourceNotFoundError | ||
from app.providers import database | ||
from app.models.token_relation import RelationType, TokenRelationQuery | ||
from app.models.token import Token | ||
from app.services.token.token_relation import TokenRelationService | ||
from app.services.token.token import TokenService | ||
|
||
|
||
def get_session(): | ||
with Session(database.engine) as session: | ||
yield session | ||
|
||
|
||
class OAuth2Bearer(APIKeyHeader): | ||
""" | ||
it use to fetch token from header | ||
""" | ||
|
||
def __init__( | ||
self, *, name: str, scheme_name: str | None = None, description: str | None = None, auto_error: bool = True | ||
): | ||
super().__init__(name=name, scheme_name=scheme_name, description=description, auto_error=auto_error) | ||
|
||
async def __call__(self, request: Request) -> str: | ||
authorization_header_value = request.headers.get(self.model.name) | ||
if authorization_header_value: | ||
scheme, _, param = authorization_header_value.partition(" ") | ||
if scheme.lower() == "bearer" and param.strip() != "": | ||
return param.strip() | ||
return None | ||
|
||
|
||
oauth_token = OAuth2Bearer(name="Authorization") | ||
|
||
|
||
async def verify_admin_token(token=Depends(oauth_token)) -> Token: | ||
""" | ||
admin token authentication | ||
""" | ||
if token is None: | ||
raise AuthenticationError() | ||
if settings.AUTH_ADMIN_TOKEN != token: | ||
raise AuthorizationError() | ||
|
||
|
||
async def get_token(session=Depends(get_session), token=Depends(oauth_token)) -> Token: | ||
""" | ||
get token info | ||
""" | ||
if token and token != "": | ||
try: | ||
return TokenService.get_token(session=session, token=token) | ||
except ResourceNotFoundError: | ||
pass | ||
return None | ||
|
||
|
||
async def verfiy_token(token: Token = Depends(get_token)): | ||
if token is None: | ||
raise AuthenticationError() | ||
|
||
|
||
async def get_token_id(token: Token = Depends(get_token)): | ||
""" | ||
Return token_id, which can be considered as user information. | ||
""" | ||
return token.id if token is not None else None | ||
|
||
|
||
def get_param(name: str): | ||
""" | ||
extract param from Request | ||
""" | ||
|
||
async def get_param_from_request(request: Request): | ||
if name in request.path_params: | ||
return request.path_params[name] | ||
if name in request.query_params: | ||
return request.query_params[name] | ||
body = await request.json() | ||
if name in body: | ||
return body[name] | ||
|
||
return get_param_from_request | ||
|
||
|
||
def verify_token_relation(relation_type: RelationType, name: str): | ||
async def verify_authorization( | ||
session=Depends(get_session), token_id=Depends(get_token_id), relation_id=Depends(get_param(name)) | ||
): | ||
if token_id and relation_id: | ||
verify = TokenRelationQuery(token_id=token_id, relation_type=relation_type, relation_id=relation_id) | ||
if TokenRelationService.verify_relation(session=session, verify=verify): | ||
return | ||
raise AuthorizationError() | ||
|
||
return verify_authorization |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,14 @@ | ||
from fastapi import APIRouter | ||
from app.api.v1 import assistant, assistant_file, thread, message, files, runs | ||
from app.api.v1 import assistant, assistant_file, thread, message, files, runs, token | ||
|
||
api_router = APIRouter(prefix="/v1") | ||
|
||
api_router.include_router(assistant.router, prefix="/assistants", tags=["assistants"]) | ||
api_router.include_router(assistant_file.router, prefix="/assistants", tags=["assistants"]) | ||
api_router.include_router(thread.router, prefix="/threads", tags=["threads"]) | ||
api_router.include_router(message.router, prefix="/threads", tags=["messages"]) | ||
api_router.include_router(runs.router, prefix="/threads", tags=["runs"]) | ||
api_router.include_router(files.router, prefix="/files", tags=["files"]) | ||
|
||
def router_init(): | ||
api_router.include_router(assistant.router, prefix="/assistants", tags=["assistants"]) | ||
api_router.include_router(assistant_file.router, prefix="/assistants", tags=["assistants"]) | ||
api_router.include_router(thread.router, prefix="/threads", tags=["threads"]) | ||
api_router.include_router(message.router, prefix="/threads", tags=["messages"]) | ||
api_router.include_router(runs.router, prefix="/threads", tags=["runs"]) | ||
api_router.include_router(files.router, prefix="/files", tags=["files"]) | ||
api_router.include_router(token.router, prefix="/tokens", tags=["tokens"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
from sqlmodel import Session, select | ||
from fastapi import APIRouter, Depends | ||
|
||
from app.api.deps import get_session, verify_admin_token | ||
from app.libs.paginate import CommonPage, cursor_page | ||
from app.models.token import Token, TokenCreate, TokenUpdate | ||
from app.services.token.token import TokenService | ||
|
||
router = APIRouter() | ||
|
||
|
||
@router.get("/list", response_model=CommonPage[Token], dependencies=[Depends(verify_admin_token)]) | ||
def list_tokens(*, session: Session = Depends(get_session)): | ||
""" | ||
Returns a list of tokens. | ||
""" | ||
statement = select(Token) | ||
return cursor_page(statement, session) | ||
|
||
|
||
@router.post("", response_model=Token, dependencies=[Depends(verify_admin_token)]) | ||
def create_token(*, session: Session = Depends(get_session), body: TokenCreate) -> Token: | ||
""" | ||
Create a token with a llm url & token. | ||
""" | ||
return TokenService.create_token(session=session, body=body) | ||
|
||
|
||
@router.get("", response_model=Token, dependencies=[Depends(verify_admin_token)]) | ||
def get_token(*, session: Session = Depends(get_session), token: str) -> Token: | ||
""" | ||
Retrieves a token. | ||
""" | ||
return TokenService.get_token(session=session, token=token) | ||
|
||
|
||
@router.get("/refresh_token", response_model=Token, dependencies=[Depends(verify_admin_token)]) | ||
def refresh_token(*, session: Session = Depends(get_session), token: str) -> Token: | ||
return TokenService.refresh_token(session=session, token=token) | ||
|
||
|
||
@router.post("/modify_token", response_model=Token, dependencies=[Depends(verify_admin_token)]) | ||
def modify_token(*, session: Session = Depends(get_session), update: TokenUpdate) -> Token: | ||
return TokenService.modify_token(session=session, update=update) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.