Skip to content

Commit

Permalink
update AAD login
Browse files Browse the repository at this point in the history
  • Loading branch information
Mac0q committed Mar 6, 2024
1 parent 6bcf1a6 commit 1d28b1f
Show file tree
Hide file tree
Showing 8 changed files with 307 additions and 33 deletions.
14 changes: 14 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Ignore login file
*.bin

# Ignore Jupyter Notebook checkpoints
.ipynb_checkpoints
/test/*
Expand All @@ -6,3 +9,14 @@
__pycache__/
**/__pycache__/
*.pyc

# Ignore the config file
ufo/config/config.yaml
*.yaml.test

# Ignore the helper files
ufo/rag/app_docs/*
learner/records.json
vectordb/*


10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,16 @@ OPENAI_API_KEY: "YOUR_API_KEY" # Set the value to the openai key for the llm mo
OPENAI_API_MODEL: "GPTV_MODEL_NAME" # The only OpenAI model by now that accepts visual input
```

#### AAD application auth
```
API_TYPE: "azure_ad"
OPENAI_API_BASE: "YOUR_ENDPOINT" # The AAD API address. Format: https://{your-resource-name}.azure-api.net/
API_VERSION: "API-VERSION" # For GPT4-visual, the value is usually "2023-12-01-preview"
OPENAI_API_MODEL: "GPTV_MODEL_NAME" # The only OpenAI model by now that accepts visual input
AAD_TENANT_ID: "YOUR_TENANT_ID" #Set the value to your tenant id for the llm model
AAD_API_SCOPE: "YOUR_SCOPE" #Set the value to your scope for the llm model
AAD_API_SCOPE_BASE: "YOUR_SCOPE_BASE" #Set the value to your scope base for the llm model, whose format is API://YOUR_SCOPE_BASE
```

### 🎉 Step 3: Start UFO

Expand Down
1 change: 0 additions & 1 deletion ufo/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from . import ufo

if __name__ == "__main__":
Expand Down
15 changes: 11 additions & 4 deletions ufo/config/config.yaml → ufo/config/config.yaml.template
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
version: 0.1
version: 0.2

API_TYPE: "openai" # The API type, "openai" for the OpenAI API, "aoai" for the AOAI API.
OPENAI_API_BASE: "YOUR_ENDPOINT" # The the OpenAI API endpoint, "https://api.openai.com/v1/chat/completions" for the OpenAI API.
API_TYPE: "azure_ad" # The API type, "openai" for the OpenAI API, "aoai" for the AOAI API, 'azure_ad' for the ad authority of the AOAI API.
OPENAI_API_BASE: "YOUR_ENDPOINT" # The OpenAI API endpoint, "https://api.openai.com/v1/chat/completions" for the OpenAI API. For the AAD API address, the format is: https://{your-resource-name}.azure-api.net/
OPENAI_API_KEY: "YOUR_API_KEY" # The OpenAI API key
OPENAI_API_MODEL: "gpt-4-vision-preview" # The only OpenAI model by now that accepts visual input
API_VERSION: "API_VERSION" # For GPT4-visual, the value is usually "2023-12-01-preview"
OPENAI_API_MODEL: "gpt-4-visual-preview" # The only OpenAI model by now that accepts visual input
CONTROL_BACKEND: "uia" # The backend for control action
MAX_TOKENS: 2000 # The max token limit for the response completion
MAX_RETRY: 3 # The max retry limit for the response completion
Expand Down Expand Up @@ -35,6 +36,12 @@ REQUEST_TIMEOUT: 250 # The call timeout for the GPT-V model
APP_SELECTION_PROMPT: "ufo/prompts/base/app_selection.yaml" # The prompt for the app selection
ACTION_SELECTION_PROMPT: "ufo/prompts/base/action_selection.yaml" # The prompt for the action selection
INPUT_TEXT_API: "type_keys" # The input text API
###For AAD
AAD_TENANT_ID: "YOUR_TENANT_ID" # Set the value to your tenant id for the llm model
AAD_API_SCOPE: "YOUR_SCOPE" # Set the value to your scope for the llm model
AAD_API_SCOPE_BASE: "YOUR_SCOPE_BASE" # Set the value to your scope base for the llm model, whose format is API://YOUR_SCOPE_BASE





Expand Down
215 changes: 215 additions & 0 deletions ufo/llm/azure_ad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@

import datetime
from typing import Literal, Optional
from ..config.config import load_config

configs = load_config()
available_models = Literal[ #only GPT4V could be used
"gpt-4-visual-preview",
]

def get_openai_token(
    token_cache_file: str = 'apim-token-cache.bin',
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
) -> str:
    '''
    Acquire an Azure AD access token for your organization's API scope.

    Tries, in order: client-credential flow (when client_id/client_secret are
    given), silent acquisition from cached accounts, redeeming a cached
    refresh token, and finally an interactive device-code flow.

    Parameters
    ----------
    token_cache_file : str, optional
        path to the token cache file, by default 'apim-token-cache.bin' in the current directory
    client_id : Optional[str], optional
        client id for AAD app, by default None
    client_secret : Optional[str], optional
        client secret for AAD app, by default None
    Returns
    -------
    str
        access token for your own organization

    Raises
    ------
    Exception
        if no flow yields an access token.
    '''
    import msal
    import os

    cache = msal.SerializableTokenCache()

    def save_cache():
        # Persist the MSAL cache only when MSAL reports it actually changed.
        if token_cache_file is not None and cache.has_state_changed:
            with open(token_cache_file, "w") as cache_file:
                cache_file.write(cache.serialize())
    if os.path.exists(token_cache_file):
        # Close the handle deterministically instead of leaking it.
        with open(token_cache_file, "r") as cache_file:
            cache.deserialize(cache_file.read())

    authority = "https://login.microsoftonline.com/" + configs["AAD_TENANT_ID"]
    api_scope_base = "api://" + configs["AAD_API_SCOPE_BASE"]

    if client_id is not None and client_secret is not None:
        # Confidential client (service-to-service) flow.
        app = msal.ConfidentialClientApplication(
            client_id=client_id,
            client_credential=client_secret,
            authority=authority,
            token_cache=cache
        )
        result = app.acquire_token_for_client(
            scopes=[
                api_scope_base + "/.default",
            ])
        if "access_token" in result:
            return result['access_token']
        else:
            print(result.get("error"))
            print(result.get("error_description"))
            raise Exception(
                "Authentication failed for acquiring AAD token for your organization")

    scopes = [api_scope_base + "/" + configs["AAD_API_SCOPE"]]
    # NOTE(review): the scope base is passed as the public client id here —
    # presumably the app registration id doubles as the scope base; confirm.
    app = msal.PublicClientApplication(
        configs["AAD_API_SCOPE_BASE"],
        authority=authority,
        token_cache=cache
    )
    result = None
    # 1) Try silent acquisition for every cached account.
    for account in app.get_accounts():
        try:
            result = app.acquire_token_silent(scopes, account=account)
            if result is not None and "access_token" in result:
                save_cache()
                return result['access_token']
            result = None
        except Exception:
            continue

    # 2) Try redeeming a cached refresh token for each cached account.
    accounts_in_cache = cache.find(msal.TokenCache.CredentialType.ACCOUNT)
    for account in accounts_in_cache:
        try:
            # Fixed: msal.CredentialType does not exist; the enum lives on
            # msal.TokenCache, so the old lookup always raised (and was
            # silently swallowed), disabling this fallback entirely.
            refresh_token = cache.find(
                msal.TokenCache.CredentialType.REFRESH_TOKEN,
                query={
                    "home_account_id": account["home_account_id"]
                })[0]
            result = app.acquire_token_by_refresh_token(
                refresh_token["secret"], scopes=scopes)
            if result is not None and "access_token" in result:
                save_cache()
                return result['access_token']
            result = None
        except Exception:
            pass

    # 3) Fall back to the interactive device-code flow.
    if result is None:
        print("no token available from cache, acquiring token from AAD")
        # The pattern to acquire a token looks like this.
        flow = app.initiate_device_flow(scopes=scopes)
        print(flow['message'])
        result = app.acquire_token_by_device_flow(flow=flow)
        if result is not None and "access_token" in result:
            save_cache()
            return result['access_token']
        else:
            print(result.get("error"))
            print(result.get("error_description"))
            raise Exception(
                "Authentication failed for acquiring AAD token for your organization")



def auto_refresh_token(
    token_cache_file: str = 'apim-token-cache.bin',
    interval: datetime.timedelta = datetime.timedelta(minutes=15),
    on_token_update: callable = None,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
) -> callable:
    """
    helper function for auto refreshing token from your organization

    Acquires a token immediately (raising on failure), then refreshes it on a
    background daemon thread every `interval`.

    Parameters
    ----------
    token_cache_file : str, optional
        path to the token cache file, by default 'apim-token-cache.bin' in the current directory
    interval : datetime.timedelta, optional
        interval for refreshing token, by default 15 minutes
    on_token_update : callable, optional
        callback function to be called when token is updated, by default None. In the callback function, you can get token from openai.api_key
    Returns
    -------
    callable
        a callable function that can be used to stop the auto refresh thread

    Raises
    ------
    Exception
        if the initial token acquisition fails.
    """

    import threading

    def update_token():
        import openai

        # "azure_ad" is our own marker; the openai SDK only knows "azure".
        openai.api_type = "azure" if configs["API_TYPE"] == "azure_ad" else configs["API_TYPE"]
        openai.base_url = configs["OPENAI_API_BASE"]
        openai.api_version = configs["API_VERSION"]
        openai.api_key = get_openai_token(
            token_cache_file=token_cache_file,
            client_id=client_id,
            client_secret=client_secret,
        )

        if on_token_update is not None:
            on_token_update()

    # Fixed: threading.Thread has no stop() method, so the previously
    # returned stop() raised AttributeError. Use an Event to signal
    # shutdown; Event.wait doubles as an interruptible sleep.
    stop_event = threading.Event()

    def refresh_token_thread():
        while not stop_event.is_set():
            try:
                update_token()
            except Exception as e:
                print("failed to acquire token from AAD for your organization", e)
            stop_event.wait(interval.total_seconds())

    # Fail fast: surface an auth problem to the caller before starting
    # the background refresher.
    try:
        update_token()
    except Exception as e:
        raise Exception(
            "failed to acquire token from AAD for your organization", e)

    thread = threading.Thread(target=refresh_token_thread, daemon=True)
    thread.start()

    def stop():
        stop_event.set()

    return stop

def get_chat_completion(
    model: available_models = None,
    client_id: Optional[str] = None,
    client_secret: Optional[str] = None,
    *args,
    **kwargs
):
    """
    helper function for getting chat completion from your organization

    Resolves the deployment name from `model` (preferred) or the legacy
    `engine` keyword, authenticates against the endpoint with an AAD token,
    and forwards the remaining arguments to the chat-completions API.
    """
    import openai

    # Pop `engine` so it is never forwarded; it only serves as a fallback
    # for the deployment name when `model` is not given.
    legacy_engine = kwargs.pop("engine", None)
    model_name = model if model is not None else legacy_engine

    aad_token = get_openai_token(
        client_id=client_id,
        client_secret=client_secret,
    )
    client = openai.AzureOpenAI(
        api_version=configs["API_VERSION"],
        azure_endpoint=configs["OPENAI_API_BASE"],
        azure_ad_token=aad_token,
    )

    return client.chat.completions.create(
        model=model_name,
        *args,
        **kwargs
    )
65 changes: 41 additions & 24 deletions ufo/llm/llm_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import time
from ..config.config import load_config
from ..utils import print_with_color
from .azure_ad import get_chat_completion


configs = load_config()

Expand All @@ -21,34 +23,50 @@ def get_gptv_completion(messages, headers):
max_retry: The maximum number of retries.
return: The response of the request.
"""

payload = {
"messages": messages,
"temperature": configs["TEMPERATURE"],
"max_tokens": configs["MAX_TOKENS"],
"top_p": configs["TOP_P"],
"model": configs["OPENAI_API_MODEL"]
}

aad = configs['API_TYPE'].lower() == 'azure_ad'
if not aad:
payload = {
"messages": messages,
"temperature": configs["TEMPERATURE"],
"max_tokens": configs["MAX_TOKENS"],
"top_p": configs["TOP_P"],
"model": configs["OPENAI_API_MODEL"]
}

for _ in range(configs["MAX_RETRY"]):
try:
response = requests.post(configs["OPENAI_API_BASE"], headers=headers, json=payload)
response_json = response.json()
response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code


if "choices" not in response_json:
print_with_color(f"GPT Error: No Reply", "red")
continue

if "error" not in response_json:
usage = response_json.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
if not aad :
response = requests.post(configs["OPENAI_API_BASE"], headers=headers, json=payload)

response_json = response.json()
response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code

cost = prompt_tokens / 1000 * 0.01 + completion_tokens / 1000 * 0.03

if "choices" not in response_json:
print_with_color(f"GPT Error: No Reply", "red")
continue

if "error" not in response_json:
usage = response_json.get("usage", {})
prompt_tokens = usage.get("prompt_tokens", 0)
completion_tokens = usage.get("completion_tokens", 0)
else:
response = get_chat_completion(
engine=configs["OPENAI_API_MODEL"],
messages = messages,
max_tokens = configs["MAX_TOKENS"],
temperature = configs["TEMPERATURE"],
top_p = configs["TOP_P"],
)

if "error" not in response:
usage = response.usage
prompt_tokens = usage.prompt_tokens
completion_tokens = usage.completion_tokens
response_json = response

cost = prompt_tokens / 1000 * 0.01 + completion_tokens / 1000 * 0.03

return response_json, cost
except requests.RequestException as e:
print_with_color(f"Error making API request: {e}", "red")
Expand All @@ -59,4 +77,3 @@ def get_gptv_completion(messages, headers):
_
time.sleep(3)
continue

Loading

0 comments on commit 1d28b1f

Please sign in to comment.