Skip to content

Commit

Permalink
Added command arg parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
SudoDios committed Sep 23, 2024
1 parent d262c1a commit 645d766
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 28 deletions.
178 changes: 178 additions & 0 deletions src/cmd.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
use clap::{value_parser, Arg, ArgAction, Command};

#[derive(Debug)]
pub struct Cmd {
pub download_ipdb : bool,
pub server_config_path : Option<String>,
pub bind_address : Option<String>,
pub listen_port : Option<u16>,
pub base_url : Option<String>,
pub ipinfo_api_key : Option<String>,
pub speed_test_dir : Option<String>,
pub stats_password : Option<String>,
pub database_type : Option<String>,
pub database_hostname : Option<String>,
pub database_name : Option<String>,
pub database_username : Option<String>,
pub database_password : Option<String>,
pub database_file : Option<String>,
pub enable_tls : Option<bool>,
pub tls_cert_file : Option<String>,
pub tls_key_file : Option<String>,
}

const PKG_VERSION: &str = env!("CARGO_PKG_VERSION");
const PKG_NAME: &str = env!("CARGO_PKG_NAME");
const PKG_AUTHORS: &str = env!("CARGO_PKG_AUTHORS");
const PKG_DESCRIPTION: &str = env!("CARGO_PKG_DESCRIPTION");

impl Cmd {

pub fn parse_args() -> Self {
let args = Command::new(PKG_NAME)
.version(PKG_VERSION)
.author(PKG_AUTHORS)
.about(PKG_DESCRIPTION)
.arg(
Arg::new("server-config-path")
.short('c')
.long("config")
)
.arg(
Arg::new("update-ipdb")
.long("update-ipdb")
.help("Download or update IPInfo country asn database")
.action(ArgAction::SetTrue)
)
.arg(
Arg::new("bind-address")
.short('b')
.long("bind-address")
.help("Bind IP address")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("listen-port")
.short('p')
.long("listen-port")
.help("Listening port")
.value_parser(value_parser!(u16))
)
.arg(
Arg::new("base-url")
.long("base-url")
.help("Specify base api url /{base_url}/routes")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("ipinfo-api-key")
.long("ipinfo-api-key")
.help("Specify the ipinfo API key")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("speed-test-dir")
.long("speed-test-dir")
.help("Specify the directory of speedtest web frontend")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("stats-password")
.long("stats-password")
.help("Specify the password for logging into statistics page")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("database-type")
.long("database-type")
.help("Specify the database type : mysql, postgres, sqlite, memory")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("database-hostname")
.long("database-hostname")
.help("Specify the database connection hostname")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("database-name")
.long("database-name")
.help("Specify the database name")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("database-username")
.long("database-username")
.help("Specify the database authentication username")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("database-password")
.long("database-password")
.help("Specify the database authentication password")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("database-file")
.long("database-file")
.help("Specify the database file path (for sqlite database type)")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("enable-tls")
.long("enable-tls")
.help("Enable and use TLS server")
.value_parser(value_parser!(bool))
)
.arg(
Arg::new("tls-cert-file")
.long("tls-cert-file")
.help("Specify the certificate file path")
.value_parser(value_parser!(String))
)
.arg(
Arg::new("tls-key-file")
.long("tls-key-file")
.help("Specify the key file path")
.value_parser(value_parser!(String))
)
.get_matches();
let download_ipdb = args.get_flag("update-ipdb");
let server_config_path : Option<String> = args.get_one::<String>("server-config-path").map(|s| s.to_owned());
let bind_address : Option<String> = args.get_one::<String>("bind-address").map(|s| s.to_owned());
let listen_port : Option<u16> = args.get_one::<u16>("listen-port").map(|s| s.to_owned());
let base_url : Option<String> = args.get_one::<String>("base-url").map(|s| s.to_owned());
let ipinfo_api_key : Option<String> = args.get_one::<String>("ipinfo-api-key").map(|s| s.to_owned());
let speed_test_dir : Option<String> = args.get_one::<String>("speed-test-dir").map(|s| s.to_owned());
let stats_password : Option<String> = args.get_one::<String>("stats-password").map(|s| s.to_owned());
let database_type : Option<String> = args.get_one::<String>("database-type").map(|s| s.to_owned());
let database_hostname : Option<String> = args.get_one::<String>("database-hostname").map(|s| s.to_owned());
let database_name : Option<String> = args.get_one::<String>("database-name").map(|s| s.to_owned());
let database_username : Option<String> = args.get_one::<String>("database-username").map(|s| s.to_owned());
let database_password : Option<String> = args.get_one::<String>("database-password").map(|s| s.to_owned());
let database_file : Option<String> = args.get_one::<String>("database-file").map(|s| s.to_owned());
let enable_tls : Option<bool> = args.get_one::<bool>("enable-tls").map(|s| s.to_owned());
let tls_cert_file : Option<String> = args.get_one::<String>("tls-cert-file").map(|s| s.to_owned());
let tls_key_file : Option<String> = args.get_one::<String>("tls-key-file").map(|s| s.to_owned());
Cmd {
download_ipdb,
server_config_path,
bind_address,
listen_port,
base_url,
ipinfo_api_key,
speed_test_dir,
stats_password,
database_type,
database_hostname,
database_name,
database_username,
database_password,
database_file,
enable_tls,
tls_cert_file,
tls_key_file,
}
}

}
50 changes: 43 additions & 7 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,31 @@ use serde::Deserialize;
use serde_json::Value;
use tokio::runtime::{Builder, Runtime};
use std::io::Write;
use crate::cmd::Cmd;
use crate::config::time::current_formatted_time;

pub mod time;

trait SetIfSome<T> {
fn set_if_some(&mut self, option: Option<T>);
}

impl<T> SetIfSome<T> for T {
fn set_if_some(&mut self, option: Option<T>) {
if let Some(value) = option {
*self = value;
}
}
}

impl<T> SetIfSome<T> for Option<T> {
fn set_if_some(&mut self, option: Option<T>) {
if let Some(value) = option {
*self = Some(value);
}
}
}

#[derive(Deserialize, Debug)]
pub struct ServerConfig {
pub bind_address : String,
Expand Down Expand Up @@ -75,7 +96,7 @@ pub fn init_runtime () -> std::io::Result<Runtime> {
}
}

pub fn init_configs (config_path : Option<&String>) -> std::io::Result<()> {
pub fn init_configs (cmd : Cmd) -> std::io::Result<()> {
//init logger
env_logger::builder()
.format(|buf,rec| {
Expand All @@ -85,12 +106,12 @@ pub fn init_configs (config_path : Option<&String>) -> std::io::Result<()> {
.filter_level(LevelFilter::Info).init();
println!("{HEAD_ART}");
//find server configs
match config_path {
match cmd.server_config_path.clone() {
Some(config_path) => {
let config = open_config_file(config_path);
let config = open_config_file(&config_path);
match config {
Ok(config) => {
initialize(config)?;
initialize(config,cmd)?;
info!("Configs initialized file : {}",config_path);
Ok(())
}
Expand All @@ -104,14 +125,14 @@ pub fn init_configs (config_path : Option<&String>) -> std::io::Result<()> {
match config {
// open config from current dir
Ok(config) => {
initialize(config)?;
initialize(config,cmd)?;
info!("Configs initialized file : configs.toml");
Ok(())
}
// set default config
Err(e) => {
let config = ServerConfig::default();
initialize(config)?;
initialize(config,cmd)?;
info!("Configs initialized with defaults");
trace!("Load config default path error : {}",e);
Ok(())
Expand Down Expand Up @@ -161,9 +182,24 @@ fn generate_routes(base_url : &str) {
ROUTES.get_or_init(|| routes);
}

fn initialize (mut config: ServerConfig) -> std::io::Result<()> {
fn initialize (mut config: ServerConfig,cmd : Cmd) -> std::io::Result<()> {
//server config
config.base_url = validate_base_url_path(&config.base_url);
config.bind_address.set_if_some(cmd.bind_address);
config.listen_port.set_if_some(cmd.listen_port);
config.base_url.set_if_some(cmd.base_url);
config.ipinfo_api_key.set_if_some(cmd.ipinfo_api_key);
config.speed_test_dir.set_if_some(cmd.speed_test_dir);
config.stats_password.set_if_some(cmd.stats_password);
config.database_type.set_if_some(cmd.database_type);
config.database_hostname.set_if_some(cmd.database_hostname);
config.database_name.set_if_some(cmd.database_name);
config.database_username.set_if_some(cmd.database_username);
config.database_password .set_if_some(cmd.database_password);
config.database_file.set_if_some(cmd.database_file);
config.enable_tls.set_if_some(cmd.enable_tls);
config.tls_cert_file.set_if_some(cmd.tls_cert_file);
config.tls_key_file.set_if_some(cmd.tls_key_file);
generate_routes(&config.base_url);
if !config.speed_test_dir.is_empty() {
if check_speed_test_dir(&config.speed_test_dir) {
Expand Down
26 changes: 5 additions & 21 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,27 @@
#![forbid(unsafe_code)]

use clap::{Arg, ArgAction, Command};
use log::error;
use crate::cmd::Cmd;
use crate::http::http_server::HttpServer;

mod http;
mod results;
mod database;
mod ip;
mod config;

const PKG_VERSION: &str = env!("CARGO_PKG_VERSION");
const PKG_NAME: &str = env!("CARGO_PKG_NAME");
const PKG_AUTHORS: &str = env!("CARGO_PKG_AUTHORS");
const PKG_DESCRIPTION: &str = env!("CARGO_PKG_DESCRIPTION");
mod cmd;

fn main() -> std::io::Result<()> {
//parse args
let args = Command::new(PKG_NAME)
.version(PKG_VERSION)
.author(PKG_AUTHORS)
.about(PKG_DESCRIPTION)
.arg(Arg::new("server_config_path").short('c').long("config"))
.arg(Arg::new("update-ipdb")
.long("update-ipdb")
.help("Download or update IPInfo country asn database")
.action(ArgAction::SetTrue))
.get_matches();
let cmd = Cmd::parse_args();

if args.get_flag("update-ipdb") {
if cmd.download_ipdb {
ip::update_ipdb("https://raw.githubusercontent.com/librespeed/speedtest-rust/master/country_asn.mmdb", "country_asn.mmdb");
return Ok(())
}

//get config path
let config_path = args.get_one::<String>("server_config_path");

//init configs & statics
if let Err(e) = config::init_configs(config_path) {
if let Err(e) = config::init_configs(cmd) {
error!("{}",e.to_string());
std::process::exit(1)
}
Expand Down

0 comments on commit 645d766

Please sign in to comment.