forked from BeastByteAI/scikit-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
123 lines (101 loc) · 2.89 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from typing import Optional
# Environment variables backing each SKLLMConfig setting. The OpenAI and
# Azure entries use an ``SKLLM_CONFIG_`` prefix; the Google entry uses the
# unprefixed ``GOOGLE_CLOUD_PROJECT`` name.
_OPENAI_KEY_VAR: str = "SKLLM_CONFIG_OPENAI_KEY"
_OPENAI_ORG_VAR: str = "SKLLM_CONFIG_OPENAI_ORG"
_AZURE_API_BASE_VAR: str = "SKLLM_CONFIG_AZURE_API_BASE"
_AZURE_API_VERSION_VAR: str = "SKLLM_CONFIG_AZURE_API_VERSION"
_GOOGLE_PROJECT: str = "GOOGLE_CLOUD_PROJECT"


class SKLLMConfig:
    """Static accessors for SKLLM settings kept in environment variables.

    Nothing is stored on the class itself: every setter writes to
    ``os.environ`` of the current process and every getter reads it back,
    so a value set here is visible to any other code in the process that
    reads the same variable.
    """

    # -- OpenAI -------------------------------------------------------------

    @staticmethod
    def set_openai_key(key: str) -> None:
        """Record the OpenAI API key.

        Parameters
        ----------
        key : str
            API key, written to ``SKLLM_CONFIG_OPENAI_KEY``.

        Raises
        ------
        TypeError
            If ``key`` is not a ``str`` (rejected by ``os.environ``).
        """
        os.environ[_OPENAI_KEY_VAR] = key

    @staticmethod
    def get_openai_key() -> Optional[str]:
        """Look up the recorded OpenAI API key.

        Returns
        -------
        Optional[str]
            The key, or ``None`` when ``SKLLM_CONFIG_OPENAI_KEY`` is unset.
        """
        return os.getenv(_OPENAI_KEY_VAR)

    @staticmethod
    def set_openai_org(key: str) -> None:
        """Record the OpenAI organization ID.

        Parameters
        ----------
        key : str
            Organization ID, written to ``SKLLM_CONFIG_OPENAI_ORG``. The
            parameter is called ``key`` for backward compatibility with
            keyword callers, even though it holds an organization ID.

        Raises
        ------
        TypeError
            If ``key`` is not a ``str`` (rejected by ``os.environ``).
        """
        os.environ[_OPENAI_ORG_VAR] = key

    @staticmethod
    def get_openai_org() -> str:
        """Look up the recorded OpenAI organization ID.

        Returns
        -------
        str
            The organization ID, or an empty string when
            ``SKLLM_CONFIG_OPENAI_ORG`` is unset.
        """
        return os.getenv(_OPENAI_ORG_VAR, "")

    # -- Azure --------------------------------------------------------------

    @staticmethod
    def set_azure_api_base(base: str) -> None:
        """Record the base URL of the Azure API.

        Parameters
        ----------
        base : str
            Base URL, written to ``SKLLM_CONFIG_AZURE_API_BASE``.

        Raises
        ------
        TypeError
            If ``base`` is not a ``str`` (rejected by ``os.environ``).
        """
        os.environ[_AZURE_API_BASE_VAR] = base

    @staticmethod
    def get_azure_api_base() -> str:
        """Look up the base URL of the Azure API.

        Returns
        -------
        str
            The value of ``SKLLM_CONFIG_AZURE_API_BASE``.

        Raises
        ------
        RuntimeError
            If ``SKLLM_CONFIG_AZURE_API_BASE`` is unset; unlike the other
            getters there is no fallback value.
        """
        base = os.getenv(_AZURE_API_BASE_VAR)
        if base is not None:
            return base
        raise RuntimeError("Azure API base is not set")

    @staticmethod
    def set_azure_api_version(ver: str) -> None:
        """Record the Azure API version.

        Parameters
        ----------
        ver : str
            API version string, written to ``SKLLM_CONFIG_AZURE_API_VERSION``.

        Raises
        ------
        TypeError
            If ``ver`` is not a ``str`` (rejected by ``os.environ``).
        """
        os.environ[_AZURE_API_VERSION_VAR] = ver

    @staticmethod
    def get_azure_api_version() -> str:
        """Look up the Azure API version.

        Returns
        -------
        str
            The value of ``SKLLM_CONFIG_AZURE_API_VERSION``, or the built-in
            fallback ``"2023-05-15"`` when the variable is unset.
        """
        return os.getenv(_AZURE_API_VERSION_VAR, "2023-05-15")

    # -- Google Cloud -------------------------------------------------------

    @staticmethod
    def set_google_project(project: str) -> None:
        """Record the Google Cloud project ID.

        Parameters
        ----------
        project : str
            Project ID, written to ``GOOGLE_CLOUD_PROJECT``.

        Raises
        ------
        TypeError
            If ``project`` is not a ``str`` (rejected by ``os.environ``).
        """
        os.environ[_GOOGLE_PROJECT] = project

    @staticmethod
    def get_google_project() -> Optional[str]:
        """Look up the Google Cloud project ID.

        Returns
        -------
        Optional[str]
            The project ID, or ``None`` when ``GOOGLE_CLOUD_PROJECT`` is unset.
        """
        return os.getenv(_GOOGLE_PROJECT)