diff --git a/Cargo.lock b/Cargo.lock index f64052f..303fc90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,9 +92,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.74" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c6f84b74db2535ebae81eede2f39b947dcbf01d093ae5f791e5dd414a1bf289" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" [[package]] name = "async-stream" @@ -658,6 +658,7 @@ dependencies = [ name = "socks5ws" version = "0.1.0" dependencies = [ + "anyhow", "async-stream", "clap", "ctrlc", diff --git a/Cargo.toml b/Cargo.toml index 7e8d716..4904b5f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +anyhow = "1.0.75" async-stream = "0.3.3" clap = { version = "4.3.21", features = ["derive"] } ctrlc = "3.2.3" diff --git a/src/config.rs b/src/config.rs index f9b71ef..b2246ea 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,6 +3,7 @@ use std::{ path::{Path, PathBuf}, }; +use anyhow::{Context, Result}; use serde_derive::{Deserialize, Serialize}; #[derive(Clone, Serialize, Deserialize, Debug)] @@ -65,25 +66,24 @@ impl Default for Config { impl Config { const FILENAME: &'static str = "config.toml"; - fn read>(path: P) -> Result { - let data = fs::read_to_string(path).map_err(|e| format!("can't read config: {:?}", e))?; - toml::from_str(&data).map_err(|e| format!("can't parse config: {:?}", e)) + fn read>(path: P) -> Result { + let data = fs::read_to_string(path).context("can't read config")?; + toml::from_str(&data).context("can't parse config") } - fn write>(&self, path: P) -> Result<(), String> { - let data = toml::to_string_pretty(&self) - .map_err(|e| format!("can't serialize config: {:?}", e))?; - fs::write(path, data).map_err(|e| format!("can't write config: {:?}", e)) + fn write>(&self, path: P) -> Result<()> { + let data = toml::to_string_pretty(&self).context("can't serialize config")?; + fs::write(path, data).context("can't write config") } - fn file_location() -> Result { + fn file_location() -> Result { let res = env::current_exe() - .map_err(|e| format!("can't get current exe path: {:?}", e))? + .context("can't get current exe path")? .with_file_name(Config::FILENAME); Ok(res) } pub fn get() -> Self { let path = Config::file_location(); if let Err(e) = path { - log::error!("Error: {e}, using default config"); + log::error!(r#"Error: "{e}", using default config"#); return Config::default(); } @@ -91,7 +91,7 @@ impl Config { let cfg = Config::read(path); match cfg { Err(e) => { - log::error!("Error: {e}, using default config"); + log::error!(r#"Error: "{e}", using default config"#); Config::default() } Ok(cfg) => cfg, @@ -106,9 +106,9 @@ impl Config { let res = self.write(&path); if let Err(e) = res { - log::error!("save error: {e}"); + log::error!("save error: {}", &e); } else { - log::info!("config saved to: {}", path.to_str().unwrap()); + log::info!(r#"config saved to: "{}""#, path.to_str().unwrap()); } } } diff --git a/src/main.rs b/src/main.rs index ebf8d30..5802356 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,6 +9,7 @@ use flexi_logger::{ AdaptiveFormat, Age, Cleanup, Criterion, Duplicate, FileSpec, Logger, LoggerHandle, Naming, }; +use anyhow::{Context, Result}; use clap::{Parser, Subcommand}; use tokio_util::sync::CancellationToken; @@ -39,15 +40,15 @@ struct Cli { command: Command, } -fn create_logger() -> LoggerHandle { +fn create_logger() -> Result { Logger::try_with_str("info") - .expect("default logging level invalid") + .context("default logging level invalid")? .log_to_file( FileSpec::default().directory( std::env::current_exe() - .expect("can't get current exe path") + .context("can't get current exe path")? .parent() - .expect("can't get parent folder"), + .context("can't get parent folder")?, ), ) .rotate( @@ -62,13 +63,18 @@ fn create_logger() -> LoggerHandle { .write_mode(flexi_logger::WriteMode::Async) .start_with_specfile( std::env::current_exe() - .unwrap() + .context("can't get current exe path")? .with_file_name("logspec.toml"), ) - .expect("can't start logger") + .context("can't start logger") } -fn server_foreground() { +fn save_default_config() -> Result<()> { + Config::default().save(); + Ok(()) +} + +fn server_foreground() -> Result<()> { let control_token = CancellationToken::new(); let server_token = control_token.child_token(); @@ -81,12 +87,14 @@ fn server_foreground() { log::info!("Press Ctrl-C to stop server"); } - server::server_executor(Config::get(), server_token).unwrap(); + server::server_executor(Config::get(), server_token)?; + + Ok(()) } -fn main() { +fn main() -> Result<()> { let args = Cli::parse(); - let logger = create_logger(); + let logger = create_logger()?; let res = match args.command { Command::Install => service::install(), @@ -94,19 +102,15 @@ fn main() { Command::Run => service::run(), Command::Start => service::start(), Command::Stop => service::stop(), - Command::SaveConfig => { - Config::default().save(); - Ok(()) - } - Command::Serve => { - server_foreground(); - Ok(()) - } + Command::SaveConfig => save_default_config(), + Command::Serve => server_foreground(), }; - if let Err(e) = res { - log::error!("{:?} error: {:#?}", args.command, e); + if let Err(e) = &res { + log::error!("{:?} -> error: {:?}", args.command, e); } + res?; drop(logger); + Ok(()) } diff --git a/src/server.rs b/src/server.rs index 13b96b7..82aa9e6 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,9 +1,7 @@ -use fast_socks5::{ - server::{SimpleUserPassword, Socks5Server, Socks5Socket}, - Result, -}; +use anyhow::{anyhow, Result}; +use fast_socks5::server::{SimpleUserPassword, Socks5Server, Socks5Socket}; use std::future::Future; -use tokio::io::{self, AsyncRead, AsyncWrite}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::select; use tokio::task; use tokio_stream::{Stream, StreamExt}; @@ -14,14 +12,14 @@ use async_stream::stream; use crate::config::Config; use crate::config::PasswordAuth; -pub fn server_executor(cfg: Config, token: CancellationToken) -> io::Result<()> { +pub fn server_executor(cfg: Config, token: CancellationToken) -> Result<()> { tokio::runtime::Builder::new_multi_thread() .enable_all() .build()? .block_on(async { spawn_socks5_server(cfg, token).await }) } -pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> io::Result<()> { +pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> Result<()> { let mut server_config = fast_socks5::server::Config::default(); server_config.set_request_timeout(cfg.request_timeout); server_config.set_skip_auth(cfg.skip_auth); @@ -51,7 +49,7 @@ pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> io::R spawn_and_log_error(socket.upgrade_to_socks5(), child_token); } Err(err) => { - log::error!("accept error = {:?}", err); + log::error!("accept error: {}", err); } } } @@ -93,7 +91,7 @@ where fn spawn_and_log_error(future: F, token: CancellationToken) -> task::JoinHandle<()> where - F: Future>> + Send + 'static, + F: Future>> + Send + 'static, T: AsyncRead + AsyncWrite + Unpin, { tokio::spawn(async move { @@ -101,10 +99,10 @@ where biased; _ = token.cancelled() => { - Err("Client connection canceled".to_string()) + Err(anyhow!("Client connection canceled")) } res = future => { - res.map_err(|e| e.to_string()) + res.map_err(anyhow::Error::new) } }; if let Err(e) = result { diff --git a/src/service.rs b/src/service.rs index e785e5c..f2c81e3 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,5 +1,6 @@ use tokio_util::sync::CancellationToken; +use anyhow::{anyhow, Result}; use std::{ffi::OsString, thread, time::Duration}; use windows_service::{ define_windows_service, @@ -10,7 +11,6 @@ use windows_service::{ service_control_handler::{self, ServiceControlHandlerResult}, service_dispatcher, service_manager::{ServiceManager, ServiceManagerAccess}, - Result, }; use crate::config::Config; @@ -21,21 +21,6 @@ const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; const SERVICE_DISPLAY: &str = "socks5ws proxy"; const SERVICE_DESCRIPTION: &str = "SOCKS5 proxy windows service"; -trait ErrorToString { - type Output; - fn str_err(self) -> std::result::Result; -} - -impl ErrorToString for std::result::Result -where - E: std::error::Error, -{ - type Output = T; - fn str_err(self) -> std::result::Result { - self.map_err(|e| e.to_string()) - } -} - trait ServiceStatusEx { fn running() -> ServiceStatus; fn stopped() -> ServiceStatus; @@ -71,11 +56,11 @@ impl ServiceStatusEx for ServiceStatus { } } -pub fn install() -> windows_service::Result<()> { +pub fn install() -> Result<()> { let manager_access = ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE; let service_manager = ServiceManager::local_computer(None::<&str>, manager_access)?; - let service_binary_path = ::std::env::current_exe().unwrap(); + let service_binary_path = std::env::current_exe()?; let service_info = ServiceInfo { name: SERVICE_NAME.into(), @@ -95,7 +80,7 @@ pub fn install() -> windows_service::Result<()> { Ok(()) } -pub fn uninstall() -> windows_service::Result<()> { +pub fn uninstall() -> Result<()> { let manager_access = ServiceManagerAccess::CONNECT; let service_manager = ServiceManager::local_computer(None::<&str>, manager_access)?; @@ -115,7 +100,7 @@ pub fn uninstall() -> windows_service::Result<()> { Ok(()) } -pub fn stop() -> windows_service::Result<()> { +pub fn stop() -> Result<()> { let manager_access = ServiceManagerAccess::CONNECT; let service_manager = ServiceManager::local_computer(None::<&str>, manager_access)?; @@ -130,7 +115,7 @@ pub fn stop() -> windows_service::Result<()> { Ok(()) } -pub fn start() -> windows_service::Result<()> { +pub fn start() -> Result<()> { let manager_access = ServiceManagerAccess::CONNECT; let service_manager = ServiceManager::local_computer(None::<&str>, manager_access)?; @@ -149,7 +134,9 @@ pub fn run() -> Result<()> { // Register generated `ffi_service_main` with the system and start the service, blocking // this thread until the service is stopped. log::info!("service run"); - service_dispatcher::start(SERVICE_NAME, ffi_service_main) + service_dispatcher::start(SERVICE_NAME, ffi_service_main)?; + + Ok(()) } // Generate the windows service boilerplate. @@ -163,11 +150,11 @@ define_windows_service!(ffi_service_main, my_service_main); // output to file if needed. pub fn my_service_main(_arguments: Vec) { if let Err(e) = run_service() { - log::error!("error: {:#?}", e); + log::error!("error: {}", e); } } -pub fn run_service() -> std::result::Result<(), String> { +pub fn run_service() -> Result<()> { // Create a cancellation token to be able to cancell server let control_token = CancellationToken::new(); let server_token = control_token.child_token(); @@ -192,12 +179,10 @@ pub fn run_service() -> std::result::Result<(), String> { // Register system service event handler. // The returned status handle should be used to report service status changes to the system. - let status_handle = service_control_handler::register(SERVICE_NAME, event_handler).str_err()?; + let status_handle = service_control_handler::register(SERVICE_NAME, event_handler)?; // Tell the system that service is running - status_handle - .set_service_status(ServiceStatus::running()) - .str_err()?; + status_handle.set_service_status(ServiceStatus::running())?; let cfg = Config::get(); log::info!("start with config: {:#?}", cfg); @@ -206,26 +191,22 @@ pub fn run_service() -> std::result::Result<(), String> { log::info!("server thread stoped"); + // join() => Err(), when thread panic if let Err(e) = result { log::error!("server panic: {:#?}", e); - status_handle - .set_service_status(ServiceStatus::stopped_with_error(1)) - .str_err()?; - return Err("server panic".into()); + status_handle.set_service_status(ServiceStatus::stopped_with_error(1))?; + return Err(anyhow!("server panic")); } + // join() => Ok(Err()), when server executor error if let Err(e) = result.unwrap() { log::error!("server error: {:#?}", e); - status_handle - .set_service_status(ServiceStatus::stopped_with_error(2)) - .str_err()?; - return Err("server error".into()); + status_handle.set_service_status(ServiceStatus::stopped_with_error(2))?; + return Err(anyhow!("server error")); } // Tell the system that service has stopped. - status_handle - .set_service_status(ServiceStatus::stopped()) - .str_err()?; + status_handle.set_service_status(ServiceStatus::stopped())?; log::info!("service stoped"); Ok(())