From 645d76637549028a8b9a7fa7678e68a7f8f90798 Mon Sep 17 00:00:00 2001 From: Sudo Dios Date: Mon, 23 Sep 2024 13:20:48 +0330 Subject: [PATCH] Added command arg parameters --- src/cmd.rs | 178 ++++++++++++++++++++++++++++++++++++++++++++++ src/config/mod.rs | 50 +++++++++++-- src/main.rs | 26 ++----- 3 files changed, 226 insertions(+), 28 deletions(-) create mode 100644 src/cmd.rs diff --git a/src/cmd.rs b/src/cmd.rs new file mode 100644 index 0000000..d5e512c --- /dev/null +++ b/src/cmd.rs @@ -0,0 +1,178 @@ +use clap::{value_parser, Arg, ArgAction, Command}; + +#[derive(Debug)] +pub struct Cmd { + pub download_ipdb : bool, + pub server_config_path : Option, + pub bind_address : Option, + pub listen_port : Option, + pub base_url : Option, + pub ipinfo_api_key : Option, + pub speed_test_dir : Option, + pub stats_password : Option, + pub database_type : Option, + pub database_hostname : Option, + pub database_name : Option, + pub database_username : Option, + pub database_password : Option, + pub database_file : Option, + pub enable_tls : Option, + pub tls_cert_file : Option, + pub tls_key_file : Option, +} + +const PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); +const PKG_NAME: &str = env!("CARGO_PKG_NAME"); +const PKG_AUTHORS: &str = env!("CARGO_PKG_AUTHORS"); +const PKG_DESCRIPTION: &str = env!("CARGO_PKG_DESCRIPTION"); + +impl Cmd { + + pub fn parse_args() -> Self { + let args = Command::new(PKG_NAME) + .version(PKG_VERSION) + .author(PKG_AUTHORS) + .about(PKG_DESCRIPTION) + .arg( + Arg::new("server-config-path") + .short('c') + .long("config") + ) + .arg( + Arg::new("update-ipdb") + .long("update-ipdb") + .help("Download or update IPInfo country asn database") + .action(ArgAction::SetTrue) + ) + .arg( + Arg::new("bind-address") + .short('b') + .long("bind-address") + .help("Bind IP address") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("listen-port") + .short('p') + .long("listen-port") + .help("Listening port") + .value_parser(value_parser!(u16)) + ) + .arg( + Arg::new("base-url") + .long("base-url") + .help("Specify base api url /{base_url}/routes") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("ipinfo-api-key") + .long("ipinfo-api-key") + .help("Specify the ipinfo API key") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("speed-test-dir") + .long("speed-test-dir") + .help("Specify the directory of speedtest web frontend") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("stats-password") + .long("stats-password") + .help("Specify the password for logging into statistics page") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("database-type") + .long("database-type") + .help("Specify the database type : mysql, postgres, sqlite, memory") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("database-hostname") + .long("database-hostname") + .help("Specify the database connection hostname") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("database-name") + .long("database-name") + .help("Specify the database name") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("database-username") + .long("database-username") + .help("Specify the database authentication username") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("database-password") + .long("database-password") + .help("Specify the database authentication password") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("database-file") + .long("database-file") + .help("Specify the database file path (for sqlite database type)") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("enable-tls") + .long("enable-tls") + .help("Enable and use TLS server") + .value_parser(value_parser!(bool)) + ) + .arg( + Arg::new("tls-cert-file") + .long("tls-cert-file") + .help("Specify the certificate file path") + .value_parser(value_parser!(String)) + ) + .arg( + Arg::new("tls-key-file") + .long("tls-key-file") + .help("Specify the key file path") + .value_parser(value_parser!(String)) + ) + .get_matches(); + let download_ipdb = args.get_flag("update-ipdb"); + let server_config_path : Option = args.get_one::("server-config-path").map(|s| s.to_owned()); + let bind_address : Option = args.get_one::("bind-address").map(|s| s.to_owned()); + let listen_port : Option = args.get_one::("listen-port").map(|s| s.to_owned()); + let base_url : Option = args.get_one::("base-url").map(|s| s.to_owned()); + let ipinfo_api_key : Option = args.get_one::("ipinfo-api-key").map(|s| s.to_owned()); + let speed_test_dir : Option = args.get_one::("speed-test-dir").map(|s| s.to_owned()); + let stats_password : Option = args.get_one::("stats-password").map(|s| s.to_owned()); + let database_type : Option = args.get_one::("database-type").map(|s| s.to_owned()); + let database_hostname : Option = args.get_one::("database-hostname").map(|s| s.to_owned()); + let database_name : Option = args.get_one::("database-name").map(|s| s.to_owned()); + let database_username : Option = args.get_one::("database-username").map(|s| s.to_owned()); + let database_password : Option = args.get_one::("database-password").map(|s| s.to_owned()); + let database_file : Option = args.get_one::("database-file").map(|s| s.to_owned()); + let enable_tls : Option = args.get_one::("enable-tls").map(|s| s.to_owned()); + let tls_cert_file : Option = args.get_one::("tls-cert-file").map(|s| s.to_owned()); + let tls_key_file : Option = args.get_one::("tls-key-file").map(|s| s.to_owned()); + Cmd { + download_ipdb, + server_config_path, + bind_address, + listen_port, + base_url, + ipinfo_api_key, + speed_test_dir, + stats_password, + database_type, + database_hostname, + database_name, + database_username, + database_password, + database_file, + enable_tls, + tls_cert_file, + tls_key_file, + } + } + +} \ No newline at end of file diff --git a/src/config/mod.rs b/src/config/mod.rs index d41a27d..de91e2c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -10,10 +10,31 @@ use serde::Deserialize; use serde_json::Value; use tokio::runtime::{Builder, Runtime}; use std::io::Write; +use crate::cmd::Cmd; use crate::config::time::current_formatted_time; pub mod time; +trait SetIfSome { + fn set_if_some(&mut self, option: Option); +} + +impl SetIfSome for T { + fn set_if_some(&mut self, option: Option) { + if let Some(value) = option { + *self = value; + } + } +} + +impl SetIfSome for Option { + fn set_if_some(&mut self, option: Option) { + if let Some(value) = option { + *self = Some(value); + } + } +} + #[derive(Deserialize, Debug)] pub struct ServerConfig { pub bind_address : String, @@ -75,7 +96,7 @@ pub fn init_runtime () -> std::io::Result { } } -pub fn init_configs (config_path : Option<&String>) -> std::io::Result<()> { +pub fn init_configs (cmd : Cmd) -> std::io::Result<()> { //init logger env_logger::builder() .format(|buf,rec| { @@ -85,12 +106,12 @@ pub fn init_configs (config_path : Option<&String>) -> std::io::Result<()> { .filter_level(LevelFilter::Info).init(); println!("{HEAD_ART}"); //find server configs - match config_path { + match cmd.server_config_path.clone() { Some(config_path) => { - let config = open_config_file(config_path); + let config = open_config_file(&config_path); match config { Ok(config) => { - initialize(config)?; + initialize(config,cmd)?; info!("Configs initialized file : {}",config_path); Ok(()) } @@ -104,14 +125,14 @@ pub fn init_configs (config_path : Option<&String>) -> std::io::Result<()> { match config { // open config from current dir Ok(config) => { - initialize(config)?; + initialize(config,cmd)?; info!("Configs initialized file : configs.toml"); Ok(()) } // set default config Err(e) => { let config = ServerConfig::default(); - initialize(config)?; + initialize(config,cmd)?; info!("Configs initialized with defaults"); trace!("Load config default path error : {}",e); Ok(()) @@ -161,9 +182,24 @@ fn generate_routes(base_url : &str) { ROUTES.get_or_init(|| routes); } -fn initialize (mut config: ServerConfig) -> std::io::Result<()> { +fn initialize (mut config: ServerConfig,cmd : Cmd) -> std::io::Result<()> { //server config config.base_url = validate_base_url_path(&config.base_url); + config.bind_address.set_if_some(cmd.bind_address); + config.listen_port.set_if_some(cmd.listen_port); + config.base_url.set_if_some(cmd.base_url); + config.ipinfo_api_key.set_if_some(cmd.ipinfo_api_key); + config.speed_test_dir.set_if_some(cmd.speed_test_dir); + config.stats_password.set_if_some(cmd.stats_password); + config.database_type.set_if_some(cmd.database_type); + config.database_hostname.set_if_some(cmd.database_hostname); + config.database_name.set_if_some(cmd.database_name); + config.database_username.set_if_some(cmd.database_username); + config.database_password .set_if_some(cmd.database_password); + config.database_file.set_if_some(cmd.database_file); + config.enable_tls.set_if_some(cmd.enable_tls); + config.tls_cert_file.set_if_some(cmd.tls_cert_file); + config.tls_key_file.set_if_some(cmd.tls_key_file); generate_routes(&config.base_url); if !config.speed_test_dir.is_empty() { if check_speed_test_dir(&config.speed_test_dir) { diff --git a/src/main.rs b/src/main.rs index 60dfc22..e94cce5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ #![forbid(unsafe_code)] -use clap::{Arg, ArgAction, Command}; use log::error; +use crate::cmd::Cmd; use crate::http::http_server::HttpServer; mod http; @@ -9,35 +9,19 @@ mod results; mod database; mod ip; mod config; - -const PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); -const PKG_NAME: &str = env!("CARGO_PKG_NAME"); -const PKG_AUTHORS: &str = env!("CARGO_PKG_AUTHORS"); -const PKG_DESCRIPTION: &str = env!("CARGO_PKG_DESCRIPTION"); +mod cmd; fn main() -> std::io::Result<()> { //parse args - let args = Command::new(PKG_NAME) - .version(PKG_VERSION) - .author(PKG_AUTHORS) - .about(PKG_DESCRIPTION) - .arg(Arg::new("server_config_path").short('c').long("config")) - .arg(Arg::new("update-ipdb") - .long("update-ipdb") - .help("Download or update IPInfo country asn database") - .action(ArgAction::SetTrue)) - .get_matches(); + let cmd = Cmd::parse_args(); - if args.get_flag("update-ipdb") { + if cmd.download_ipdb { ip::update_ipdb("https://raw.githubusercontent.com/librespeed/speedtest-rust/master/country_asn.mmdb", "country_asn.mmdb"); return Ok(()) } - //get config path - let config_path = args.get_one::("server_config_path"); - //init configs & statics - if let Err(e) = config::init_configs(config_path) { + if let Err(e) = config::init_configs(cmd) { error!("{}",e.to_string()); std::process::exit(1) }