Skip to content

Commit

Permalink
add support for azure oai models #1
Browse files Browse the repository at this point in the history
  • Loading branch information
victordibia committed Sep 9, 2023
1 parent e08de93 commit 4c607d0
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 42 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@ llmx/generators/cache
llmx.egg-info
notebooks/test.ipynb
notebooks/data
notebooks/.env
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

.DS_Store
n

# C extensions
*.so

Expand Down
23 changes: 21 additions & 2 deletions llmx/datamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ def __getitem__(self, key: Union[str, int]) -> Any:
def to_dict(self):
    """Return this object's fields as a plain dict.

    Returns the cached ``_fields_dict`` snapshot built in ``__post_init__``
    (the same object, not a copy), so callers share one mapping.
    """
    snapshot = self._fields_dict
    return snapshot

def __iter__(self):
    """Yield ``(field_name, value)`` pairs, so the object behaves like
    ``dict.items()`` when iterated (e.g. ``dict(obj)`` works)."""
    yield from self._fields_dict.items()


@dataclass
class TextGenerationConfig:
Expand All @@ -38,18 +42,33 @@ def __post_init__(self):
def __getitem__(self, key: Union[str, int]) -> Any:
    """Dict-style field access; returns None for unknown keys
    (``dict.get`` semantics rather than raising ``KeyError``)."""
    fields = self._fields_dict
    return fields.get(key)

def __iter__(self):
    """Iterate over ``(field_name, value)`` pairs of the config,
    mirroring ``dict.items()`` so ``dict(config)`` round-trips."""
    yield from self._fields_dict.items()


@dataclass
class TextGenerationResponse:
    """Response from a text generation request.

    Wraps the generated messages together with the config used and
    provider metadata, and supports dict-style access (``resp["text"]``),
    iteration over ``(field, value)`` pairs, and plain-dict conversion.
    """

    # Generated messages; Message is declared elsewhere in this module.
    text: List[Message]
    # The generation config the request was made with.
    config: Any
    logprobs: Optional[Any] = None  # logprobs if available
    usage: Optional[Any] = None  # usage statistics from the provider
    response: Optional[Any] = None  # full response from the provider

    def __post_init__(self):
        # Snapshot the dataclass fields as a plain dict; the accessors
        # below all read from this cached mapping.
        self._fields_dict = asdict(self)

    def __getitem__(self, key: Union[str, int]) -> Any:
        # dict.get semantics: unknown keys return None, no KeyError.
        return self._fields_dict.get(key)

    def __iter__(self):
        # Yield (field, value) pairs so dict(response) works.
        for key, value in self._fields_dict.items():
            yield key, value

    def to_dict(self):
        """Return the fields as a plain dict (shared snapshot, not a copy)."""
        return self._fields_dict

    def __json__(self):
        # Hook for JSON encoders that honor a __json__ protocol.
        return self._fields_dict
1 change: 1 addition & 0 deletions llmx/generators/text/palm_textgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def generate(
response_text, model=self.model_name
)
},
response=palm_response,
)

cache_request(
Expand Down
63 changes: 23 additions & 40 deletions notebooks/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"from llmx import llm\n",
"from llmx.datamodel import TextGenerationConfig"
"from llmx import llm, TextGenerationConfig"
]
},
{
Expand Down Expand Up @@ -48,7 +47,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Gravity is like a big invisible force that pulls things towards each other. It's what keeps us on the ground and makes things fall down when we drop them. It's kind of like a super strong magnet that pulls everything together.\n"
"Gravity is like a big invisible force that pulls things towards each other. It's what keeps us on the ground and makes things fall down when we drop them. It's like a big hug from the Earth that keeps us close to it.\n"
]
}
],
Expand All @@ -65,10 +64,22 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Gravity is a force that pulls things down to Earth. It's what makes things fall down when you drop them, and it's what keeps the moon in orbit around the Earth. Gravity is a very strong force, but it's also\n"
"ename": "ValueError",
"evalue": "Service account key file is not set. Please set the PALM_SERVICE_ACCOUNT_KEY_FILE environment variable.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mDefaultCredentialsError\u001b[0m Traceback (most recent call last)",
"File \u001b[0;32m~/projects/llmx/llmx/utils.py:79\u001b[0m, in \u001b[0;36mget_gcp_credentials\u001b[0;34m(service_account_key_file, scopes)\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 78\u001b[0m \u001b[39m# Attempt to use Application Default Credentials\u001b[39;00m\n\u001b[0;32m---> 79\u001b[0m credentials, project_id \u001b[39m=\u001b[39m google\u001b[39m.\u001b[39;49mauth\u001b[39m.\u001b[39;49mdefault(scopes\u001b[39m=\u001b[39;49mscopes)\n\u001b[1;32m 80\u001b[0m auth_req \u001b[39m=\u001b[39m google\u001b[39m.\u001b[39mauth\u001b[39m.\u001b[39mtransport\u001b[39m.\u001b[39mrequests\u001b[39m.\u001b[39mRequest()\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/lib/python3.9/site-packages/google/auth/_default.py:692\u001b[0m, in \u001b[0;36mdefault\u001b[0;34m(scopes, request, quota_project_id, default_scopes)\u001b[0m\n\u001b[1;32m 690\u001b[0m \u001b[39mreturn\u001b[39;00m credentials, effective_project_id\n\u001b[0;32m--> 692\u001b[0m \u001b[39mraise\u001b[39;00m exceptions\u001b[39m.\u001b[39mDefaultCredentialsError(_CLOUD_SDK_MISSING_CREDENTIALS)\n",
"\u001b[0;31mDefaultCredentialsError\u001b[0m: Your default credentials were not found. To set up Application Default Credentials, see https://cloud.google.com/docs/authentication/external/set-up-adc for more information.",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/Users/victordibia/projects/llmx/notebooks/tutorial.ipynb Cell 5\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/victordibia/projects/llmx/notebooks/tutorial.ipynb#W4sZmlsZQ%3D%3D?line=0'>1</a>\u001b[0m palm_gen \u001b[39m=\u001b[39m llm(provider\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mpalm\u001b[39;49m\u001b[39m\"\u001b[39;49m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/victordibia/projects/llmx/notebooks/tutorial.ipynb#W4sZmlsZQ%3D%3D?line=1'>2</a>\u001b[0m palm_config \u001b[39m=\u001b[39m TextGenerationConfig(model\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mcodechat-bison\u001b[39m\u001b[39m\"\u001b[39m, temperature\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m, max_tokens\u001b[39m=\u001b[39m\u001b[39m50\u001b[39m, use_cache\u001b[39m=\u001b[39m\u001b[39mTrue\u001b[39;00m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/victordibia/projects/llmx/notebooks/tutorial.ipynb#W4sZmlsZQ%3D%3D?line=2'>3</a>\u001b[0m palm_response \u001b[39m=\u001b[39m palm_gen\u001b[39m.\u001b[39mgenerate(messages, config\u001b[39m=\u001b[39mpalm_config)\n",
"File \u001b[0;32m~/projects/llmx/llmx/generators/text/textgen.py:10\u001b[0m, in \u001b[0;36mllm\u001b[0;34m(provider, **kwargs)\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[39mreturn\u001b[39;00m OpenAITextGenerator(\u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n\u001b[1;32m 9\u001b[0m \u001b[39melif\u001b[39;00m provider\u001b[39m.\u001b[39mlower() \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mpalm\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mor\u001b[39;00m provider\u001b[39m.\u001b[39mlower() \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mgoogle\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[0;32m---> 10\u001b[0m \u001b[39mreturn\u001b[39;00m PalmTextGenerator(provider\u001b[39m=\u001b[39;49mprovider, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 11\u001b[0m \u001b[39melif\u001b[39;00m provider\u001b[39m.\u001b[39mlower() \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mcohere\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[1;32m 12\u001b[0m \u001b[39mreturn\u001b[39;00m CohereTextGenerator(provider\u001b[39m=\u001b[39mprovider, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n",
"File \u001b[0;32m~/projects/llmx/llmx/generators/text/palm_textgen.py:22\u001b[0m, in \u001b[0;36mPalmTextGenerator.__init__\u001b[0;34m(self, palm_key_file, project_id, project_location, provider)\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproject_id \u001b[39m=\u001b[39m project_id\n\u001b[1;32m 21\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mproject_location \u001b[39m=\u001b[39m project_location\n\u001b[0;32m---> 22\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcredentials \u001b[39m=\u001b[39m get_gcp_credentials(palm_key_file)\n\u001b[1;32m 23\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel_list \u001b[39m=\u001b[39m providers[provider][\u001b[39m\"\u001b[39m\u001b[39mmodels\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39mif\u001b[39;00m provider \u001b[39min\u001b[39;00m providers \u001b[39melse\u001b[39;00m {}\n",
"File \u001b[0;32m~/projects/llmx/llmx/utils.py:86\u001b[0m, in \u001b[0;36mget_gcp_credentials\u001b[0;34m(service_account_key_file, scopes)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[39mexcept\u001b[39;00m google\u001b[39m.\u001b[39mauth\u001b[39m.\u001b[39mexceptions\u001b[39m.\u001b[39mDefaultCredentialsError:\n\u001b[1;32m 84\u001b[0m \u001b[39m# Fall back to using service account key\u001b[39;00m\n\u001b[1;32m 85\u001b[0m \u001b[39mif\u001b[39;00m service_account_key_file \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m---> 86\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 87\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mService account key file is not set. Please set the PALM_SERVICE_ACCOUNT_KEY_FILE environment variable.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 88\u001b[0m )\n\u001b[1;32m 89\u001b[0m credentials \u001b[39m=\u001b[39m service_account\u001b[39m.\u001b[39mCredentials\u001b[39m.\u001b[39mfrom_service_account_file(\n\u001b[1;32m 90\u001b[0m service_account_key_file, scopes\u001b[39m=\u001b[39mscopes)\n\u001b[1;32m 91\u001b[0m auth_req \u001b[39m=\u001b[39m google\u001b[39m.\u001b[39mauth\u001b[39m.\u001b[39mtransport\u001b[39m.\u001b[39mrequests\u001b[39m.\u001b[39mRequest()\n",
"\u001b[0;31mValueError\u001b[0m: Service account key file is not set. Please set the PALM_SERVICE_ACCOUNT_KEY_FILE environment variable."
]
}
],
Expand All @@ -81,18 +92,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"calling cohere *************** TextGenerationConfig(n=1, temperature=0.1, max_tokens=4050, top_p=1.0, top_k=50, frequency_penalty=0.0, presence_penalty=0.0, provider='openai', model='command', stop=None, use_cache=True)\n",
"Gravity is a force that pulls things together. It is what makes things fall to the ground and what holds us on the earth. Gravity is a fundamental force of nature that affects everything around us. It is a property of all matter, and it is what makes things heavy. Gravity is also what causes the moon to orbit the earth and the planets to orbit the sun. It is a very important force that plays a big role in our lives.\n"
]
}
],
"outputs": [],
"source": [
"cohere_gen = llm(provider=\"cohere\")\n",
"cohere_config = TextGenerationConfig(model=\"command\", max_tokens=4050, use_cache=True)\n",
Expand All @@ -102,28 +104,9 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/victordibia/.local/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"/home/victordibia/.local/lib/python3.9/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.25.2\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00, 2.75s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hey there, little buddy! *smiling* Gravity is a magic that pulls everything towards each other! *excitedly* Just like how you like to hug your favorite toy or play with your friends, everything in the world has gravity too! *nodding* It's like a big hug that keeps everything close together. *giggles* Even you and me right now, we're being pulled towards each other by gravity! *winks* Isn't that cool? *grinning* So, let's have some fun and see how gravity works, okay? *excitedly*\n"
]
}
],
"outputs": [],
"source": [
"hf_generator = llm(provider=\"hf\", model=\"TheBloke/Llama-2-7b-chat-fp16\", device_map=\"auto\")\n",
"hf_config = TextGenerationConfig(temperature=0, max_tokens=650, use_cache=False)\n",
Expand Down Expand Up @@ -155,7 +138,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.9.12"
},
"orig_nbformat": 4
},
Expand Down

0 comments on commit 4c607d0

Please sign in to comment.