Skip to content

API lambda adjustments #1666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,20 @@
using System.Text.Json.Serialization;
using Elastic.Documentation.Api.Core.AskAi;
using Elastic.Documentation.Api.Infrastructure.Gcp;
using Microsoft.Extensions.Options;

namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;

public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, IOptionsSnapshot<LlmGatewayOptions> options) : IAskAiGateway<Stream>
public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway<Stream>
{
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
{
var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest);
var requestBody = JsonSerializer.Serialize(llmGatewayRequest, LlmGatewayContext.Default.LlmGatewayRequest);
var request = new HttpRequestMessage(HttpMethod.Post, options.Value.FunctionUrl)
var request = new HttpRequestMessage(HttpMethod.Post, options.FunctionUrl)
{
Content = new StringContent(requestBody, Encoding.UTF8, "application/json")
};
var authToken = await tokenProvider.GenerateIdTokenAsync(ctx);
var authToken = await tokenProvider.GenerateIdTokenAsync(options.ServiceAccount, options.TargetAudience, ctx);
request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", authToken);
request.Headers.Add("User-Agent", "elastic-docs-proxy/1.0");
request.Headers.Add("Accept", "text/event-stream");
Expand All @@ -44,7 +43,7 @@ public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
PlatformContext: new PlatformContext("support_portal", "support_assistant", []),
Input:
[
new ChatInput("system", AskAiRequest.SystemPrompt),
new ChatInput("user", AskAiRequest.SystemPrompt),
new ChatInput("user", request.Message)
],
ThreadId: request.ThreadId ?? "elastic-docs-" + Guid.NewGuid()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Licensed to Elasticsearch B.V under one or more agreements.
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information

using Elastic.Documentation.Api.Infrastructure.Aws;

namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;

public class LlmGatewayOptions
{
public LlmGatewayOptions(IParameterProvider parameterProvider)
{
ServiceAccount = parameterProvider.GetParam("llm-gateway-service-account").GetAwaiter().GetResult();
FunctionUrl = parameterProvider.GetParam("llm-gateway-function-url").GetAwaiter().GetResult();
var uri = new Uri(FunctionUrl);
TargetAudience = $"{uri.Scheme}://{uri.Host}";
}

public string ServiceAccount { get; }
public string FunctionUrl { get; }
public string TargetAudience { get; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

namespace Elastic.Documentation.Api.Infrastructure.Aws;

public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, ILogger<LambdaExtensionParameterProvider> logger) : IParameterProvider
public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, AppEnvironment appEnvironment, ILogger<LambdaExtensionParameterProvider> logger) : IParameterProvider
{
public const string HttpClientName = "AwsParametersAndSecretsLambdaExtensionClient";
private readonly HttpClient _httpClient = httpClientFactory.CreateClient(HttpClientName);
Expand All @@ -18,8 +18,10 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
{
try
{
logger.LogInformation("Retrieving parameter '{Name}' from Lambda Extension (SSM Parameter Store).", name);
var response = await _httpClient.GetFromJsonAsync<ParameterResponse>($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(name)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx);
var prefix = $"/elastic-docs-v3/{appEnvironment.Current.ToStringFast(true)}/";
var prefixedName = prefix + name.TrimStart('/');
logger.LogInformation("Retrieving parameter '{Name}' from Lambda Extension (SSM Parameter Store).", prefixedName);
var response = await _httpClient.GetFromJsonAsync<ParameterResponse>($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(prefixedName)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx);
return response?.Parameter?.Value ?? throw new InvalidOperationException($"Parameter value for '{name}' is null.");
}
catch (HttpRequestException httpEx)
Expand All @@ -42,23 +44,23 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc

internal sealed class ParameterResponse
{
public Parameter? Parameter { get; set; }
public required Parameter Parameter { get; set; }
}

internal sealed class Parameter
{
public string? Arn { get; set; }
public string? Name { get; set; }
public string? Type { get; set; }
public string? Value { get; set; }
public string? Version { get; set; }
[JsonPropertyName("ARN")]
public required string Arn { get; set; }
public required string Name { get; set; }
public required string Type { get; set; }
public required string Value { get; set; }
public required int Version { get; set; }
public string? Selector { get; set; }
public string? LastModifiedDate { get; set; }
public string? LastModifiedUser { get; set; }
public string? DataType { get; set; }
public DateTime LastModifiedDate { get; set; }
public required string DataType { get; set; }
}


[JsonSerializable(typeof(ParameterResponse))]
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)]
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.Unspecified)]
internal sealed partial class AwsJsonContext : JsonSerializerContext;
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
{
switch (name)
{
case "/elastic-docs-v3/dev/llm-gateway-service-account":
case "llm-gateway-service-account":
{
const string envName = "LLM_GATEWAY_SERVICE_ACCOUNT_KEY_PATH";
var serviceAccountKeyPath = Environment.GetEnvironmentVariable(envName);
Expand All @@ -21,7 +21,7 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
var serviceAccountKey = await File.ReadAllTextAsync(serviceAccountKeyPath, ctx);
return serviceAccountKey;
}
case "/elastic-docs-v3/dev/llm-gateway-function-url":
case "llm-gateway-function-url":
{
const string envName = "LLM_GATEWAY_FUNCTION_URL";
var value = Environment.GetEnvironmentVariable(envName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,48 @@
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information

using System.Collections.Concurrent;
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using Microsoft.Extensions.Options;

namespace Elastic.Documentation.Api.Infrastructure.Gcp;

// This is a custom implementation to create an ID token for GCP.
// Because Google.Api.Auth.OAuth2 is not compatible with AOT
public class GcpIdTokenProvider(HttpClient httpClient, IOptionsSnapshot<LlmGatewayOptions> options)
public class GcpIdTokenProvider(HttpClient httpClient)
{
public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = default)
// Cache tokens by target audience to avoid regenerating them on every request
private static readonly ConcurrentDictionary<string, CachedToken> TokenCache = new();

private sealed record CachedToken(string Token, DateTimeOffset ExpiresAt);

public async Task<string> GenerateIdTokenAsync(string serviceAccount, string targetAudience, Cancel cancellationToken = default)
{
// Check if we have a valid cached token
if (TokenCache.TryGetValue(targetAudience, out var cachedToken) &&
cachedToken.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(1)) // Refresh 1 minute before expiry
return cachedToken.Token;

// Read and parse service account key file using System.Text.Json source generation (AOT compatible)
var serviceAccount = JsonSerializer.Deserialize(options.Value.ServiceAccount, GcpJsonContext.Default.ServiceAccountKey);
var serviceAccountJson = JsonSerializer.Deserialize(serviceAccount, GcpJsonContext.Default.ServiceAccountKey);

// Create JWT header
var header = new JwtHeader("RS256", "JWT", serviceAccount.PrivateKeyId);
var header = new JwtHeader("RS256", "JWT", serviceAccountJson.PrivateKeyId);
var headerJson = JsonSerializer.Serialize(header, JwtHeaderJsonContext.Default.JwtHeader);
var headerBase64 = Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson));

// Create JWT payload
var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
var now = DateTimeOffset.UtcNow;
var expirationTime = now.AddHours(1);
var payload = new JwtPayload(
serviceAccount.ClientEmail,
serviceAccount.ClientEmail,
serviceAccountJson.ClientEmail,
serviceAccountJson.ClientEmail,
"https://oauth2.googleapis.com/token",
now,
now + 300, // 5 minutes
options.Value.TargetAudience
now.ToUnixTimeSeconds(),
expirationTime.ToUnixTimeSeconds(),
targetAudience
);

var payloadJson = JsonSerializer.Serialize(payload, GcpJsonContext.Default.JwtPayload);
Expand All @@ -43,7 +54,7 @@ public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = defaul
var messageBytes = Encoding.UTF8.GetBytes(message);

// Parse the private key (removing PEM headers/footers and decoding)
var privateKeyPem = serviceAccount.PrivateKey
var privateKeyPem = serviceAccountJson.PrivateKey
.Replace("-----BEGIN PRIVATE KEY-----", "")
.Replace("-----END PRIVATE KEY-----", "")
.Replace("\n", "")
Expand All @@ -59,7 +70,14 @@ public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = defaul
var jwt = $"{message}.{signatureBase64}";

// Exchange JWT for ID token
return await ExchangeJwtForIdToken(jwt, options.Value.TargetAudience, cancellationToken);
var idToken = await ExchangeJwtForIdToken(jwt, targetAudience, cancellationToken);

var expiresAt = expirationTime.Subtract(TimeSpan.FromMinutes(1));
_ = TokenCache.AddOrUpdate(targetAudience,
new CachedToken(idToken, expiresAt),
(_, _) => new CachedToken(idToken, expiresAt));

return idToken;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,17 @@
namespace Elastic.Documentation.Api.Infrastructure;

[EnumExtensions]
public enum AppEnvironment
public enum AppEnv
{
[Display(Name = "dev")] Dev,
[Display(Name = "staging")] Staging,
[Display(Name = "edge")] Edge,
[Display(Name = "prod")] Prod
}

public class LlmGatewayOptions
public class AppEnvironment
{
public string ServiceAccount { get; set; } = string.Empty;
public string FunctionUrl { get; set; } = string.Empty;
public string TargetAudience { get; set; } = string.Empty;
public AppEnv Current { get; init; }
}

public static class ServicesExtension
Expand All @@ -41,42 +39,44 @@ public static class ServicesExtension

public static void AddElasticDocsApiUsecases(this IServiceCollection services, string? appEnvironment)
{
if (AppEnvironmentExtensions.TryParse(appEnvironment, out var parsedEnvironment, true))
if (AppEnvExtensions.TryParse(appEnvironment, out var parsedEnvironment, true))
{
AddElasticDocsApiUsecases(services, parsedEnvironment);
}
else
{
var logger = GetLogger(services);
logger?.LogWarning("Unable to parse environment {AppEnvironment} into AppEnvironment. Using default AppEnvironment.Dev", appEnvironment);
AddElasticDocsApiUsecases(services, AppEnvironment.Dev);
AddElasticDocsApiUsecases(services, AppEnv.Dev);
}
}


private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnvironment appEnvironment)
private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnv appEnv)
{
_ = services.ConfigureHttpJsonOptions(options =>
{
options.SerializerOptions.TypeInfoResolverChain.Insert(0, ApiJsonContext.Default);
});
_ = services.AddHttpClient();
AddParameterProvider(services, appEnvironment);
AddAskAiUsecase(services, appEnvironment);
// Register AppEnvironment as a singleton for dependency injection
_ = services.AddSingleton(new AppEnvironment { Current = appEnv });
AddParameterProvider(services, appEnv);
AddAskAiUsecase(services, appEnv);
}

// https://docs.aws.amazon.com/systems-manager/latest/userguide/ps-integration-lambda-extensions.html
private static void AddParameterProvider(IServiceCollection services, AppEnvironment appEnvironment)
// https://docs.aws.amazon.com/systems -manager/latest/userguide/ps-integration-lambda-extensions.html
private static void AddParameterProvider(IServiceCollection services, AppEnv appEnv)
{
var logger = GetLogger(services);

switch (appEnvironment)
switch (appEnv)
{
case AppEnvironment.Prod:
case AppEnvironment.Staging:
case AppEnvironment.Edge:
case AppEnv.Prod:
case AppEnv.Staging:
case AppEnv.Edge:
{
logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnvironment);
logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnv);
_ = services.AddHttpClient(LambdaExtensionParameterProvider.HttpClientName, client =>
{
client.BaseAddress = new Uri("http://localhost:2773");
Expand All @@ -85,39 +85,27 @@ private static void AddParameterProvider(IServiceCollection services, AppEnviron
_ = services.AddSingleton<IParameterProvider, LambdaExtensionParameterProvider>();
break;
}
case AppEnvironment.Dev:
case AppEnv.Dev:
{
logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnvironment);
logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnv);
_ = services.AddSingleton<IParameterProvider, LocalParameterProvider>();
break;
}
default:
{
throw new ArgumentOutOfRangeException(nameof(appEnvironment), appEnvironment,
throw new ArgumentOutOfRangeException(nameof(appEnv), appEnv,
"Unsupported environment for parameter provider.");
}
}
}

private static void AddAskAiUsecase(IServiceCollection services, AppEnvironment appEnvironment)
private static void AddAskAiUsecase(IServiceCollection services, AppEnv appEnv)
{
var logger = GetLogger(services);
logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnvironment);

_ = services.Configure<LlmGatewayOptions>(options =>
{
var serviceProvider = services.BuildServiceProvider();
var parameterProvider = serviceProvider.GetRequiredService<IParameterProvider>();
var appEnvString = appEnvironment.ToStringFast(true);

options.ServiceAccount = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-service-account").GetAwaiter().GetResult();
options.FunctionUrl = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-function-url").GetAwaiter().GetResult();

var functionUri = new Uri(options.FunctionUrl);
options.TargetAudience = $"{functionUri.Scheme}://{functionUri.Host}";
});
_ = services.AddScoped<GcpIdTokenProvider>();
_ = services.AddScoped<IAskAiGateway<Stream>, LlmGatewayAskAiGateway>();
logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnv);
_ = services.AddSingleton<GcpIdTokenProvider>();
_ = services.AddSingleton<IAskAiGateway<Stream>, LlmGatewayAskAiGateway>();
_ = services.AddScoped<LlmGatewayOptions>();
_ = services.AddScoped<AskAiUsecase>();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
<IsPublishable>true</IsPublishable>
<PublishAot>true</PublishAot>
<PublishTrimmed>true</PublishTrimmed>

<EnableSdkContainerSupport>true</EnableSdkContainerSupport>
<TrimmerSingleWarn>false</TrimmerSingleWarn>
<DockerDefaultTargetOS>Linux</DockerDefaultTargetOS>
<EnableRequestDelegateGenerator>true</EnableRequestDelegateGenerator>
<InterceptorsPreviewNamespaces>$(InterceptorsPreviewNamespaces);Microsoft.AspNetCore.Http.Generated</InterceptorsPreviewNamespaces>

<RootNamespace>Elastic.Documentation.Api.Lambda</RootNamespace>
</PropertyGroup>

<ItemGroup>
Expand Down
11 changes: 7 additions & 4 deletions src/api/Elastic.Documentation.Api.Lambda/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,26 @@
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
// See the LICENSE file in the project root for more information

using System.Text.Json;
using System.Text.Json.Serialization;
using Amazon.Lambda.APIGatewayEvents;
using Amazon.Lambda.Serialization.SystemTextJson;
using Elastic.Documentation.Api.Core.AskAi;
using Elastic.Documentation.Api.Infrastructure;

var builder = WebApplication.CreateSlimBuilder(args);

builder.Services.AddAWSLambdaHosting(LambdaEventSource.RestApi, new SourceGeneratorLambdaJsonSerializer<LambdaJsonSerializerContext>());
builder.Services.AddElasticDocsApiUsecases(Environment.GetEnvironmentVariable("APP_ENVIRONMENT"));
builder.Services.AddElasticDocsApiUsecases(Environment.GetEnvironmentVariable("ENVIRONMENT"));

var app = builder.Build();

var v1 = app.MapGroup("/v1");
var v1 = app.MapGroup("/docs/_api/v1");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a few reserved namespaces in the markdown build :). I have /docs/api for the OpenAPI too.

v1.MapElasticDocsApiEndpoints();

app.Run();

[JsonSerializable(typeof(APIGatewayHttpApiV2ProxyRequest), GenerationMode = JsonSourceGenerationMode.Metadata)]
[JsonSerializable(typeof(APIGatewayHttpApiV2ProxyResponse), GenerationMode = JsonSourceGenerationMode.Default)]
[JsonSerializable(typeof(APIGatewayProxyRequest))]
[JsonSerializable(typeof(APIGatewayProxyResponse))]
[JsonSerializable(typeof(AskAiRequest))]
internal sealed partial class LambdaJsonSerializerContext : JsonSerializerContext;
Loading