forked from jnsahaj/lumen
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclaude.rs
91 lines (83 loc) · 2.72 KB
/
claude.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
use super::{AIProvider, ProviderError};
use crate::ai_prompt::AIPrompt;
use async_trait::async_trait;
use reqwest::StatusCode;
use serde_json::{json, Value};
/// Connection settings used by the Claude provider to reach the Anthropic Messages API.
///
/// Only `Clone` is derived; if `Debug` is ever added, consider redacting `api_key`
/// so the secret cannot leak into logs.
#[derive(Clone)]
pub struct ClaudeConfig {
    /// Secret key sent with every request.
    api_key: String,
    /// Model identifier placed in the request's `model` field.
    model: String,
    /// Full URL that requests are POSTed to.
    api_base_url: String,
}

impl ClaudeConfig {
    /// Model used when the caller does not request a specific one.
    // NOTE(review): this is a dated snapshot id — confirm it is still served by the API.
    const DEFAULT_MODEL: &str = "claude-3-5-sonnet-20241022";

    /// Complete Messages endpoint URL (not just the host).
    const MESSAGES_URL: &str = "https://api.anthropic.com/v1/messages";

    /// Builds a config from an API key and an optional model override.
    ///
    /// When `model` is `None`, the built-in default model id is used. The endpoint is
    /// always the public Messages API URL.
    pub fn new(api_key: String, model: Option<String>) -> Self {
        let model = model.unwrap_or_else(|| Self::DEFAULT_MODEL.to_owned());
        Self {
            api_base_url: Self::MESSAGES_URL.to_owned(),
            model,
            api_key,
        }
    }
}
pub struct ClaudeProvider {
client: reqwest::Client,
config: ClaudeConfig,
}
impl ClaudeProvider {
pub fn new(client: reqwest::Client, config: ClaudeConfig) -> Self {
Self { client, config }
}
async fn complete(&self, prompt: AIPrompt) -> Result<String, ProviderError> {
let payload = json!({
"model": self.config.model,
"max_tokens": 4096,
"messages": [
{
"role": "system",
"content": prompt.system_prompt
},
{
"role": "user",
"content": prompt.user_prompt
}
]
});
let response = self
.client
.post(&self.config.api_base_url)
.header("x-api-key", &self.config.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.json(&payload)
.send()
.await?;
let status = response.status();
match status {
StatusCode::OK => {
let response_json: Value = response.json().await?;
let content = response_json
.get("content")
.and_then(|content| content.get(0))
.and_then(|message| message.get("text"))
.and_then(|text| text.as_str())
.ok_or(ProviderError::NoCompletionChoice)?;
Ok(content.to_string())
}
_ => {
let error_json: Value = response.json().await?;
let error_message = error_json
.get("error")
.and_then(|error| error.get("message"))
.and_then(|msg| msg.as_str())
.ok_or(ProviderError::UnexpectedResponse)?
.into();
Err(ProviderError::APIError(status, error_message))
}
}
}
}
#[async_trait]
impl AIProvider for ClaudeProvider {
async fn complete(&self, prompt: AIPrompt) -> Result<String, ProviderError> {
self.complete(prompt).await
}
}