Skip to content

Commit

Permalink
feat(webserver): implement is_admin_initialized graphql api (TabbyML#929
Browse files Browse the repository at this point in the history
)

* feat(webserver): implement is_admin_initialized graphql api

* refactor

* add unit test

* [autofix.ci] apply automated fixes

* renaming

* refactor: server -> locator

* fix unused

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
wsxiaoys and autofix-ci[bot] authored Dec 1, 2023
1 parent 5c52a71 commit 88e5187
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 31 deletions.
2 changes: 1 addition & 1 deletion ee/tabby-webserver/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use axum::{
use hyper::Body;
use juniper_axum::{graphiql, graphql, playground};
use schema::{
worker::{RegisterWorkerError, Worker, WorkerKind, WorkerService},
worker::{RegisterWorkerError, Worker, WorkerKind},
Schema, ServiceLocator,
};
use service::create_service_locator;
Expand Down
1 change: 1 addition & 0 deletions ee/tabby-webserver/src/schema/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ pub trait AuthenticationService: Send + Sync {
async fn token_auth(&self, email: String, password: String) -> FieldResult<TokenAuthResponse>;
async fn refresh_token(&self, refresh_token: String) -> FieldResult<RefreshTokenResponse>;
async fn verify_token(&self, access_token: String) -> FieldResult<VerifyTokenResponse>;
async fn is_admin_initialized(&self) -> FieldResult<bool>;
}

#[cfg(test)]
Expand Down
22 changes: 13 additions & 9 deletions ee/tabby-webserver/src/schema/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ pub trait ServiceLocator: Send + Sync {

pub struct Context {
claims: Option<auth::Claims>,
server: Arc<dyn ServiceLocator>,
locator: Arc<dyn ServiceLocator>,
}

impl FromAuth<Arc<dyn ServiceLocator>> for Context {
fn build(server: Arc<dyn ServiceLocator>, bearer: Option<String>) -> Self {
fn build(locator: Arc<dyn ServiceLocator>, bearer: Option<String>) -> Self {
let claims = bearer.and_then(|token| validate_jwt(&token).ok());
Self { claims, server }
Self { claims, locator }
}
}

Expand All @@ -46,13 +46,17 @@ pub struct Query;
#[graphql_object(context = Context)]
impl Query {
async fn workers(ctx: &Context) -> Vec<Worker> {
ctx.server.worker().list_workers().await
ctx.locator.worker().list_workers().await
}

async fn registration_token(ctx: &Context) -> FieldResult<String> {
let token = ctx.server.worker().read_registration_token().await?;
let token = ctx.locator.worker().read_registration_token().await?;
Ok(token)
}

async fn is_admin_initialized(ctx: &Context) -> FieldResult<bool> {
ctx.locator.auth().is_admin_initialized().await
}
}

#[derive(Default)]
Expand All @@ -63,7 +67,7 @@ impl Mutation {
async fn reset_registration_token(ctx: &Context) -> FieldResult<String> {
if let Some(claims) = &ctx.claims {
if claims.user_info().is_admin() {
let reg_token = ctx.server.worker().reset_registration_token().await?;
let reg_token = ctx.locator.worker().reset_registration_token().await?;
return Ok(reg_token);
}
}
Expand All @@ -79,7 +83,7 @@ impl Mutation {
password1: String,
password2: String,
) -> FieldResult<RegisterResponse> {
ctx.server
ctx.locator
.auth()
.register(email, password1, password2)
.await
Expand All @@ -90,11 +94,11 @@ impl Mutation {
email: String,
password: String,
) -> FieldResult<TokenAuthResponse> {
ctx.server.auth().token_auth(email, password).await
ctx.locator.auth().token_auth(email, password).await
}

async fn verify_token(ctx: &Context, token: String) -> FieldResult<VerifyTokenResponse> {
ctx.server.auth().verify_token(token).await
ctx.locator.auth().verify_token(token).await
}
}

Expand Down
5 changes: 5 additions & 0 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ impl AuthenticationService for DbConn {
let resp = VerifyTokenResponse::new(claims);
Ok(resp)
}

async fn is_admin_initialized(&self) -> FieldResult<bool> {
let admin = self.list_admin_users().await?;
Ok(!admin.is_empty())
}
}

fn password_hash(raw: &str) -> password_hash::Result<String> {
Expand Down
80 changes: 59 additions & 21 deletions ee/tabby-webserver/src/service/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{path::PathBuf, sync::Arc};

use anyhow::Result;
use lazy_static::lazy_static;
use rusqlite::{params, OptionalExtension};
use rusqlite::{params, OptionalExtension, Row};
use rusqlite_migration::{AsyncMigrations, M};
use tabby_common::path::tabby_root;
use tokio_rusqlite::Connection;
Expand Down Expand Up @@ -47,6 +47,25 @@ pub struct User {
pub is_admin: bool,
}

impl User {
fn select(clause: &str) -> String {
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE "#
.to_owned()
+ clause
}

fn from_row(row: &Row<'_>) -> std::result::Result<User, rusqlite::Error> {
Ok(User {
id: row.get(0)?,
email: row.get(1)?,
password_encrypted: row.get(2)?,
is_admin: row.get(3)?,
created_at: row.get(4)?,
updated_at: row.get(5)?,
})
}
}

async fn db_path() -> Result<PathBuf> {
let db_dir = tabby_root().join("ee");
tokio::fs::create_dir_all(db_dir.clone()).await?;
Expand Down Expand Up @@ -156,35 +175,51 @@ impl DbConn {
.conn
.call(move |c| {
c.query_row(
r#"SELECT id, email, password_encrypted, is_admin, created_at, updated_at FROM users WHERE email = ?"#,
User::select("email = ?").as_str(),
params![email],
|row| {
Ok(User {
id: row.get(0)?,
email: row.get(1)?,
password_encrypted: row.get(2)?,
is_admin: row.get(3)?,
created_at: row.get(4)?,
updated_at: row.get(5)?,
})
},
).optional()
User::from_row,
)
.optional()
})
.await?;

Ok(user)
}

pub async fn list_admin_users(&self) -> Result<Vec<User>> {
let users = self
.conn
.call(move |c| {
let mut stmt = c.prepare(&User::select("is_admin"))?;
let user_iter = stmt.query_map([], User::from_row)?;
Ok(user_iter.filter_map(|x| x.ok()).collect::<Vec<_>>())
})
.await?;

Ok(users)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::schema::auth::AuthenticationService;

async fn new_in_memory() -> Result<DbConn> {
let conn = Connection::open_in_memory().await?;
DbConn::init_db(conn).await
}

async fn create_admin_user(conn: &DbConn) -> String {
let email = "[email protected]";
let passwd = "123456";
let is_admin = true;
conn.create_user(email.to_string(), passwd.to_string(), is_admin)
.await
.unwrap();
email.to_owned()
}

#[tokio::test]
async fn migrations_test() {
assert!(MIGRATIONS.validate().await.is_ok());
Expand Down Expand Up @@ -212,14 +247,8 @@ mod tests {
async fn test_create_user() {
let conn = new_in_memory().await.unwrap();

let email = "[email protected]";
let passwd = "123456";
let is_admin = true;
conn.create_user(email.to_string(), passwd.to_string(), is_admin)
.await
.unwrap();

let user = conn.get_user_by_email(email).await.unwrap().unwrap();
let email = create_admin_user(&conn).await;
let user = conn.get_user_by_email(&email).await.unwrap().unwrap();
assert_eq!(user.id, 1);
}

Expand All @@ -232,4 +261,13 @@ mod tests {

assert!(user.is_none());
}

#[tokio::test]
async fn test_is_admin_initialized() {
let conn = new_in_memory().await.unwrap();

assert!(!conn.is_admin_initialized().await.unwrap());
create_admin_user(&conn).await;
assert!(conn.is_admin_initialized().await.unwrap());
}
}

0 comments on commit 88e5187

Please sign in to comment.