Skip to content

Commit

Permalink
feat(webserver): Populate the user field in completions requests if t…
Browse files Browse the repository at this point in the history
…he user authenticated (TabbyML#1341)

* Partially complete

* Fix error and populate user field

* Add back comment to refactored code

* Get user email from access token
  • Loading branch information
boxbeam authored Feb 1, 2024
1 parent 1f51a09 commit 2f14062
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 52 deletions.
34 changes: 32 additions & 2 deletions crates/tabby/src/routes/completions.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::sync::Arc;

use axum::{extract::State, Json};
use axum::{extract::State, headers::Header, Json, TypedHeader};
use hyper::StatusCode;
use tabby_webserver::public::USER_HEADER_FIELD_NAME;
use tracing::{instrument, warn};

use crate::services::completion::{CompletionRequest, CompletionResponse, CompletionService};
Expand All @@ -23,8 +24,12 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet
#[instrument(skip(state, request))]
pub async fn completions(
State(state): State<Arc<CompletionService>>,
Json(request): Json<CompletionRequest>,
TypedHeader(MaybeUser(user)): TypedHeader<MaybeUser>,
Json(mut request): Json<CompletionRequest>,
) -> Result<Json<CompletionResponse>, StatusCode> {
if let Some(user) = user {
request.user.replace(user);
}
match state.generate(&request).await {
Ok(resp) => Ok(Json(resp)),
Err(err) => {
Expand All @@ -33,3 +38,28 @@ pub async fn completions(
}
}
}

#[derive(Debug)]
pub struct MaybeUser(Option<String>);

impl Header for MaybeUser {
fn name() -> &'static axum::http::HeaderName {
&USER_HEADER_FIELD_NAME
}

fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
let Some(value) = values.next() else {
return Ok(MaybeUser(None));
};
let str = value.to_str().expect("User email is always a valid string");
Ok(MaybeUser(Some(str.to_string())))
}

fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
todo!()
}
}
2 changes: 1 addition & 1 deletion crates/tabby/src/services/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub struct CompletionRequest {

/// A unique identifier representing your end-user, which can help Tabby to monitor & generating
/// reports.
user: Option<String>,
pub(crate) user: Option<String>,

debug_options: Option<DebugOptions>,

Expand Down
11 changes: 5 additions & 6 deletions ee/tabby-db/src/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,16 +124,15 @@ impl DbConn {
);

let users = sqlx::query_as(&query).fetch_all(&self.pool).await?;

Ok(users)
}

pub async fn verify_auth_token(&self, token: &str) -> bool {
pub async fn verify_auth_token(&self, token: &str) -> Result<String> {
let token = token.to_owned();
let id = query_scalar!("SELECT id FROM users WHERE auth_token = ?", token)
let email = query_scalar!("SELECT email FROM users WHERE auth_token = ?", token)
.fetch_one(&self.pool)
.await;
id.is_ok()
email.map_err(Into::into)
}

pub async fn reset_user_auth_token_by_email(&self, email: &str) -> Result<()> {
Expand Down Expand Up @@ -222,9 +221,9 @@ mod tests {

let user = conn.get_user(id).await.unwrap().unwrap();

assert!(!conn.verify_auth_token("abcd").await);
assert!(conn.verify_auth_token("abcd").await.is_err());

assert!(conn.verify_auth_token(&user.auth_token).await);
assert!(conn.verify_auth_token(&user.auth_token).await.is_ok());

conn.reset_user_auth_token_by_email(&user.email)
.await
Expand Down
5 changes: 5 additions & 0 deletions ee/tabby-webserver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ mod service;
mod ui;

pub mod public {

pub static USER_HEADER_FIELD_NAME: HeaderName = HeaderName::from_static("x-tabby-user");

use axum::http::HeaderName;

pub use super::{
handler::attach_webserver,
/* used by tabby workers (consumer of /hub api) */
Expand Down
2 changes: 1 addition & 1 deletion ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ impl RefreshTokenResponse {

#[derive(Debug, GraphQLObject)]
pub struct VerifyTokenResponse {
claims: JWTPayload,
pub claims: JWTPayload,
}

impl VerifyTokenResponse {
Expand Down
86 changes: 44 additions & 42 deletions ee/tabby-webserver/src/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@ use tabby_db::DbConn;
use tracing::{info, warn};

use self::{cron::run_cron, email::new_email_service};
use crate::schema::{
auth::AuthenticationService,
email::EmailService,
job::JobService,
repository::RepositoryService,
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
ServiceLocator,
use crate::{
public::USER_HEADER_FIELD_NAME,
schema::{
auth::AuthenticationService,
email::EmailService,
job::JobService,
repository::RepositoryService,
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
ServiceLocator,
},
};

struct ServerContext {
Expand Down Expand Up @@ -60,41 +63,32 @@ impl ServerContext {
}
}

async fn authorize_request(&self, request: &Request<Body>) -> bool {
async fn authorize_request(&self, request: &Request<Body>) -> (bool, Option<String>) {
let path = request.uri().path();
if path.starts_with("/v1/") || path.starts_with("/v1beta/") {
let token = {
let authorization = request
.headers()
.get("authorization")
.map(HeaderValue::to_str)
.and_then(Result::ok);

if let Some(authorization) = authorization {
let split = authorization.split_once(' ');
match split {
// Found proper bearer
Some(("Bearer", contents)) => Some(contents),
_ => None,
}
} else {
None
}
};

if let Some(token) = token {
if self.db_conn.verify_access_token(token).await.is_err()
&& !self.db_conn.verify_auth_token(token).await
{
return false;
}
} else {
// Admin system is initialized, but there's no valid token.
return false;
}
if !(path.starts_with("/v1/") || path.starts_with("/v1beta/")) {
return (true, None);
}
let authorization = request
.headers()
.get("authorization")
.map(HeaderValue::to_str)
.and_then(Result::ok);

let token = authorization
.and_then(|s| s.split_once(' '))
.map(|(_bearer, token)| token);

let Some(token) = token else {
// Admin system is initialized, but there is no valid token.
return (false, None);
};
if let Ok(jwt) = self.db_conn.verify_access_token(token).await {
return (true, Some(jwt.claims.sub));
}
match self.db_conn.verify_auth_token(token).await {
Ok(email) => (true, Some(email)),
Err(_) => (false, None),
}

true
}
}

Expand Down Expand Up @@ -147,17 +141,25 @@ impl WorkerService for ServerContext {

async fn dispatch_request(
&self,
request: Request<Body>,
mut request: Request<Body>,
next: Next<Body>,
) -> axum::response::Response {
if !self.authorize_request(&request).await {
let (auth, user) = self.authorize_request(&request).await;
if !auth {
return axum::response::Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Body::empty())
.unwrap()
.into_response();
}

if let Some(user) = user {
request.headers_mut().append(
&USER_HEADER_FIELD_NAME,
HeaderValue::from_str(&user).expect("User must be valid header"),
);
}

let remote_addr = request
.extensions()
.get::<axum::extract::ConnectInfo<SocketAddr>>()
Expand Down

0 comments on commit 2f14062

Please sign in to comment.