Skip to content

Commit

Permalink
add parameter mapping with vertex ai
Browse files Browse the repository at this point in the history
  • Loading branch information
ishaan-jaff committed May 20, 2024
1 parent 2c25bfa commit 518db13
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
13 changes: 13 additions & 0 deletions docs/my-website/docs/providers/vertex.md
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,19 @@ response = await litellm.aimage_generation(
)
```

**Generating multiple images**

Use the `n` parameter to pass how many images you want generated
```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
```

## Extra

### Using `GOOGLE_APPLICATION_CREDENTIALS`
Expand Down
10 changes: 8 additions & 2 deletions litellm/llms/vertex_httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,21 @@ async def aimage_generation(
{
"prompt": "a cat"
}
]
],
"parameters": {
"sampleCount": 1
}
} \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header = self._ensure_access_token()
optional_params = optional_params or {
"sampleCount": 1
} # default optional params

request_data = {
"instances": [{"prompt": prompt}],
"parameters": {"sampleCount": 1},
"parameters": optional_params,
}

request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
Expand Down
1 change: 1 addition & 0 deletions litellm/tests/test_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ async def test_aimage_generation_vertex_ai():
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
assert response.data is not None
assert len(response.data) > 0
Expand Down
8 changes: 8 additions & 0 deletions litellm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4946,6 +4946,14 @@ def _check_valid_arg(supported_params):
width, height = size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)
elif custom_llm_provider == "vertex_ai":
supported_params = ["n"]
"""
All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
"""
_check_valid_arg(supported_params=supported_params)
if n is not None:
optional_params["sampleCount"] = int(n)

for k in passed_params.keys():
if k not in default_params.keys():
Expand Down

0 comments on commit 518db13

Please sign in to comment.