Skip to content

Commit

Permalink
Cohere Client (microsoft#3004)
Browse files Browse the repository at this point in the history
* initial setup for cohere client

* client update

* changes: ClintType added to the utils

* Revert "changes: ClintType added to the utils"

This reverts commit 80d6155.

* Message conversion to Cohere, Parameter handling, cost calculation, streaming, tool calling

* Changed Groq references.

* minor fix

* tests added

* ref fix

* added in the workflows

* Fixed bug on non-streaming text generation

* fix: formatting

* Support Cohere rule for last message not USER when tool_results exist

* Added Cohere to documentation

* fixed client.py merge, removed unnecessary comments in groq.py, updated Cohere documentation, added Groq documentation

* log: ignored params

* update: custom exception added

---------

Co-authored-by: Mark Sze <[email protected]>
Co-authored-by: Mark Sze <[email protected]>
  • Loading branch information
3 people authored Jul 3, 2024
1 parent 8133b7d commit b4a3f26
Show file tree
Hide file tree
Showing 11 changed files with 1,661 additions and 10 deletions.
36 changes: 36 additions & 0 deletions .github/workflows/contrib-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -638,3 +638,39 @@ jobs:
with:
file: ./coverage.xml
flags: unittests

CohereTest:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
with:
lfs: true
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install packages and dependencies for all tests
run: |
python -m pip install --upgrade pip wheel
pip install pytest-cov>=5
- name: Install packages and dependencies for Cohere
run: |
pip install -e .[cohere,test]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |
if [[ ${{ matrix.os }} != ubuntu-latest ]]; then
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
fi
- name: Coverage
run: |
pytest test/oai/test_cohere.py --skip-openai
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
file: ./coverage.xml
flags: unittests
12 changes: 11 additions & 1 deletion autogen/logger/file_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
Expand Down Expand Up @@ -205,7 +206,16 @@ def log_new_wrapper(

def log_new_client(
self,
client: AzureOpenAI | OpenAI | GeminiClient | AnthropicClient | MistralAIClient | TogetherClient | GroqClient,
client: (
AzureOpenAI
| OpenAI
| GeminiClient
| AnthropicClient
| MistralAIClient
| TogetherClient
| GroqClient
| CohereClient
),
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
Expand Down
12 changes: 11 additions & 1 deletion autogen/logger/sqlite_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
if TYPE_CHECKING:
from autogen import Agent, ConversableAgent, OpenAIWrapper
from autogen.oai.anthropic import AnthropicClient
from autogen.oai.cohere import CohereClient
from autogen.oai.gemini import GeminiClient
from autogen.oai.groq import GroqClient
from autogen.oai.mistral import MistralAIClient
Expand Down Expand Up @@ -392,7 +393,16 @@ def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[st

def log_new_client(
self,
client: Union[AzureOpenAI, OpenAI, GeminiClient, AnthropicClient, MistralAIClient, TogetherClient, GroqClient],
client: Union[
AzureOpenAI,
OpenAI,
GeminiClient,
AnthropicClient,
MistralAIClient,
TogetherClient,
GroqClient,
CohereClient,
],
wrapper: OpenAIWrapper,
init_args: Dict[str, Any],
) -> None:
Expand Down
12 changes: 12 additions & 0 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@
except ImportError as e:
groq_import_exception = e

try:
from autogen.oai.cohere import CohereClient

cohere_import_exception: Optional[ImportError] = None
except ImportError as e:
cohere_import_exception = e

logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
Expand Down Expand Up @@ -497,6 +504,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s
raise ImportError("Please install `groq` to use the Groq API.")
client = GroqClient(**openai_config)
self._clients.append(client)
elif api_type is not None and api_type.startswith("cohere"):
if cohere_import_exception:
raise ImportError("Please install `cohere` to use the Groq API.")
client = CohereClient(**openai_config)
self._clients.append(client)
else:
client = OpenAI(**openai_config)
self._clients.append(OpenAIClient(client))
Expand Down
Loading

0 comments on commit b4a3f26

Please sign in to comment.