forked from BeastByteAI/scikit-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
172 lines (141 loc) · 3.9 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
from typing import Optional
# Names of the environment variables backing every SKLLMConfig setting.
_OPENAI_KEY_VAR = "SKLLM_CONFIG_OPENAI_KEY"
_OPENAI_ORG_VAR = "SKLLM_CONFIG_OPENAI_ORG"
_AZURE_API_BASE_VAR = "SKLLM_CONFIG_AZURE_API_BASE"
_AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
# Unlike the others this has no SKLLM_CONFIG_ prefix; presumably so that an
# already-exported GOOGLE_CLOUD_PROJECT is reused as-is — confirm.
_GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
_GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"

# Azure API version returned when none has been configured.
_DEFAULT_AZURE_API_VERSION = "2023-05-15"


class SKLLMConfig:
    """Process-wide scikit-llm settings stored in environment variables.

    Every setter writes to ``os.environ`` and every getter reads from it, so a
    value set here is visible to all code in the current process, and a value
    already present in the environment is picked up without calling a setter.
    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def set_gpt_key(key: str) -> None:
        """Sets the GPT key.

        Alias of ``set_openai_key``: both write the same environment variable.

        Parameters
        ----------
        key : str
            GPT key.
        """
        os.environ[_OPENAI_KEY_VAR] = key

    @staticmethod
    def set_gpt_org(key: str) -> None:
        """Sets the GPT organization ID.

        Alias of ``set_openai_org``: both write the same environment variable.

        Parameters
        ----------
        key : str
            GPT organization ID.
        """
        os.environ[_OPENAI_ORG_VAR] = key

    @staticmethod
    def set_openai_key(key: str) -> None:
        """Sets the OpenAI key.

        Parameters
        ----------
        key : str
            OpenAI key.
        """
        os.environ[_OPENAI_KEY_VAR] = key

    @staticmethod
    def get_openai_key() -> Optional[str]:
        """Gets the OpenAI key.

        Returns
        -------
        Optional[str]
            OpenAI key, or ``None`` if it has not been set.
        """
        return os.environ.get(_OPENAI_KEY_VAR, None)

    @staticmethod
    def set_openai_org(key: str) -> None:
        """Sets the OpenAI organization ID.

        Parameters
        ----------
        key : str
            OpenAI organization ID.
        """
        os.environ[_OPENAI_ORG_VAR] = key

    @staticmethod
    def get_openai_org() -> str:
        """Gets the OpenAI organization ID.

        Returns
        -------
        str
            OpenAI organization ID, or an empty string if it has not been set.
        """
        return os.environ.get(_OPENAI_ORG_VAR, "")

    @staticmethod
    def get_azure_api_base() -> str:
        """Gets the API base for Azure.

        Returns
        -------
        str
            URL to be used as the base for the Azure API.

        Raises
        ------
        RuntimeError
            If the Azure API base has not been set.
        """
        base = os.environ.get(_AZURE_API_BASE_VAR, None)
        if base is None:
            # There is no sensible default endpoint, so fail loudly.
            raise RuntimeError("Azure API base is not set")
        return base

    @staticmethod
    def set_azure_api_base(base: str) -> None:
        """Sets the API base for Azure.

        Parameters
        ----------
        base : str
            URL to be used as the base for the Azure API.
        """
        os.environ[_AZURE_API_BASE_VAR] = base

    @staticmethod
    def set_azure_api_version(ver: str) -> None:
        """Sets the API version for Azure.

        Parameters
        ----------
        ver : str
            Azure API version.
        """
        os.environ[_AZURE_API_VERSION_VAR] = ver

    @staticmethod
    def get_azure_api_version() -> str:
        """Gets the API version for Azure.

        Returns
        -------
        str
            Azure API version; ``"2023-05-15"`` if none has been set.
        """
        return os.environ.get(_AZURE_API_VERSION_VAR, _DEFAULT_AZURE_API_VERSION)

    @staticmethod
    def get_google_project() -> Optional[str]:
        """Gets the Google Cloud project ID.

        Returns
        -------
        Optional[str]
            Google Cloud project ID, or ``None`` if it has not been set.
        """
        return os.environ.get(_GOOGLE_PROJECT, None)

    @staticmethod
    def set_google_project(project: str) -> None:
        """Sets the Google Cloud project ID.

        Parameters
        ----------
        project : str
            Google Cloud project ID.
        """
        os.environ[_GOOGLE_PROJECT] = project

    @staticmethod
    def set_gpt_url(url: str) -> None:
        """Sets the GPT URL.

        Parameters
        ----------
        url : str
            GPT URL.
        """
        os.environ[_GPT_URL_VAR] = url

    @staticmethod
    def get_gpt_url() -> Optional[str]:
        """Gets the GPT URL.

        Returns
        -------
        Optional[str]
            GPT URL, or ``None`` if it has not been set or has been reset.
        """
        return os.environ.get(_GPT_URL_VAR, None)

    @staticmethod
    def reset_gpt_url() -> None:
        """Resets the GPT URL.

        Removes the setting entirely; calling this when no URL is set is a
        no-op.
        """
        os.environ.pop(_GPT_URL_VAR, None)