Skip to content

Commit

Permalink
Allow using 3rd party AI services that are compatible with OpenAI API…
Browse files Browse the repository at this point in the history
… format in the `openai-gpt` agent (#331)
  • Loading branch information
daxian-dbw authored Jan 21, 2025
1 parent 1c4a2e8 commit d996130
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
27 changes: 21 additions & 6 deletions shell/agents/AIShell.OpenAI.Agent/GPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ internal enum EndpointType
{
AzureOpenAI,
OpenAI,
CompatibleThirdParty,
}

public class GPT
Expand Down Expand Up @@ -56,9 +57,16 @@ public GPT(
bool noDeployment = string.IsNullOrEmpty(Deployment);
Type = noEndpoint && noDeployment
? EndpointType.OpenAI
: !noEndpoint && !noDeployment
? EndpointType.AzureOpenAI
: throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");
: !noEndpoint && noDeployment
? EndpointType.CompatibleThirdParty
: !noEndpoint && !noDeployment
? EndpointType.AzureOpenAI
: throw new InvalidOperationException($"Invalid setting: 'Deployment' key present but 'Endpoint' key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");

if (ModelInfo is null && Type is EndpointType.CompatibleThirdParty)
{
ModelInfo = ModelInfo.ThirdPartyModel;
}
}

/// <summary>
Expand Down Expand Up @@ -142,11 +150,18 @@ private void ShowEndpointInfo(IHost host)
new(label: " Model", m => m.ModelName),
},

EndpointType.OpenAI => new CustomElement<GPT>[]
{
EndpointType.OpenAI =>
[
new(label: " Type", m => m.Type.ToString()),
new(label: " Model", m => m.ModelName),
},
],

EndpointType.CompatibleThirdParty =>
[
new(label: " Type", m => m.Type.ToString()),
new(label: " Endpoint", m => m.Endpoint),
new(label: " Model", m => m.ModelName),
],

_ => throw new UnreachableException(),
};
Expand Down
5 changes: 5 additions & 0 deletions shell/agents/AIShell.OpenAI.Agent/ModelInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ internal class ModelInfo
private static readonly Dictionary<string, ModelInfo> s_modelMap;
private static readonly Dictionary<string, Task<Tokenizer>> s_encodingMap;

// A rough estimate to cover all third-party models.
// - most popular models today support 32K+ context length;
// - use the gpt-4o encoding as an estimate for token count.
internal static readonly ModelInfo ThirdPartyModel = new(32_000, encoding: Gpt4oEncoding);

static ModelInfo()
{
// For reference, see https://platform.openai.com/docs/models and the "Counting tokens" section in
Expand Down
8 changes: 7 additions & 1 deletion shell/agents/AIShell.OpenAI.Agent/Service.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,10 @@ private void RefreshOpenAIClient()
return;
}

EndpointType type = _gptToUse.Type;
string userKey = Utils.ConvertFromSecureString(_gptToUse.Key);

if (_gptToUse.Type is EndpointType.AzureOpenAI)
if (type is EndpointType.AzureOpenAI)
{
// Create a client that targets Azure OpenAI service or Azure API Management service.
var clientOptions = new AzureOpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
Expand Down Expand Up @@ -152,6 +153,11 @@ private void RefreshOpenAIClient()
{
// Create a client that targets the non-Azure OpenAI service.
var clientOptions = new OpenAIClientOptions() { RetryPolicy = new ChatRetryPolicy() };
if (type is EndpointType.CompatibleThirdParty)
{
clientOptions.Endpoint = new(_gptToUse.Endpoint);
}

var aiClient = new OpenAIClient(new ApiKeyCredential(userKey), clientOptions);
_client = aiClient.GetChatClient(_gptToUse.ModelName);
}
Expand Down

0 comments on commit d996130

Please sign in to comment.