diff --git a/src/server.rs b/src/server.rs index ef57612..cd1377b 100644 --- a/src/server.rs +++ b/src/server.rs @@ -6,23 +6,32 @@ use std::future::Future; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::select; use tokio::task; -use tokio_stream::StreamExt; +use tokio_stream::{Stream, StreamExt}; use tokio_util::sync::CancellationToken; +use std::sync::mpsc::Sender; + +use async_stream::stream; + use crate::config::Config; use crate::config::PasswordAuth; -pub fn server_executor(cfg: Config, token: CancellationToken) { +pub fn server_executor( + cfg: Config, + token: CancellationToken, + shutdown_tx: Sender<()>, +) -> std::io::Result<()> { tokio::runtime::Builder::new_multi_thread() .enable_all() - .build() - .unwrap() + .build()? .block_on(async { - spawn_socks5_server(cfg, token).await.unwrap(); + let result = spawn_socks5_server(cfg, token).await; + shutdown_tx.send(()).unwrap(); + result }) } -pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> Result<()> { +pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> std::io::Result<()> { let mut server_config = fast_socks5::server::Config::default(); server_config.set_request_timeout(cfg.request_timeout); server_config.set_skip_auth(cfg.skip_auth); @@ -40,11 +49,12 @@ pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> Resul let mut listener = Socks5Server::bind(&cfg.listen_addr).await?; listener.set_config(server_config); - let mut incoming = listener.incoming(); + let incoming = stream_with_cancellation(listener.incoming(), &token); + tokio::pin!(incoming); log::info!("Listen for socks connections @ {}", &cfg.listen_addr); - while let Some(socket_res) = check_cancelled(incoming.next(), token.child_token()).await { + while let Some(socket_res) = incoming.next().await { match socket_res { Ok(socket) => { let child_token = token.child_token(); @@ -59,14 +69,31 @@ pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> Resul Ok(()) } -async fn check_cancelled(future: F, token: CancellationToken) -> Option +fn stream_with_cancellation<'a, S>( + mut inner: S, + token: &'a CancellationToken, +) -> impl Stream::Item> + 'a where - F: Future>, + S: StreamExt + Unpin + 'a, +{ + stream! { + while let Some(res) = check_cancelled(inner.next(), token, None).await { + yield res; + } + } +} + +async fn check_cancelled(future: F, token: &CancellationToken, default: R) -> R +where + F: Future, { select! { + biased; + _ = token.cancelled() => { log::error!("accept canceled"); - None + + default } res = future => { res @@ -81,11 +108,13 @@ where { tokio::spawn(async move { let result = select! { + biased; + _ = token.cancelled() => { Err("Client connection canceled".to_string()) } res = future => { - res.map_err(|e| format!("{:#}", &e)) + res.map_err(|e| e.to_string()) } }; if let Err(e) = result { diff --git a/src/service.rs b/src/service.rs index 571e07f..41c9529 100644 --- a/src/service.rs +++ b/src/service.rs @@ -21,6 +21,56 @@ const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS; const SERVICE_DISPLAY: &str = "socks5ws proxy"; const SERVICE_DESCRIPTION: &str = "SOCKS5 proxy windows service"; +trait ErrorToString { + type Output; + fn str_err(self) -> std::result::Result; +} + +impl ErrorToString for std::result::Result +where + E: std::error::Error, +{ + type Output = T; + fn str_err(self) -> std::result::Result { + self.map_err(|e| e.to_string()) + } +} + +trait ServiceStatusEx { + fn running() -> ServiceStatus; + fn stopped() -> ServiceStatus; + fn stopped_with_error(code: u32) -> ServiceStatus; +} + +impl ServiceStatusEx for ServiceStatus { + fn running() -> ServiceStatus { + ServiceStatus { + service_type: SERVICE_TYPE, + current_state: ServiceState::Running, + controls_accepted: ServiceControlAccept::STOP, + exit_code: ServiceExitCode::Win32(0), + checkpoint: 0, + wait_hint: Duration::default(), + process_id: None, + } + } + + fn stopped() -> ServiceStatus { + ServiceStatus { + current_state: ServiceState::Stopped, + controls_accepted: ServiceControlAccept::empty(), + ..Self::running() + } + } + + fn stopped_with_error(code: u32) -> ServiceStatus { + ServiceStatus { + exit_code: ServiceExitCode::ServiceSpecific(code), + ..Self::stopped() + } + } +} + pub fn install() -> windows_service::Result<()> { let manager_access = ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE; let service_manager = ServiceManager::local_computer(None::<&str>, manager_access)?; @@ -113,14 +163,16 @@ define_windows_service!(ffi_service_main, my_service_main); // output to file if needed. pub fn my_service_main(_arguments: Vec) { if let Err(e) = run_service() { - log::error!("{:#?}", e); + log::error!("error: {:#?}", e); } } -pub fn run_service() -> Result<()> { +pub fn run_service() -> std::result::Result<(), String> { // Create a channel to be able to poll a stop event from the service worker loop. let (shutdown_tx, shutdown_rx) = mpsc::channel(); + let shutdown_tx1 = shutdown_tx.clone(); + // Define system service event handler that will be receiving service events. let event_handler = move |control_event| -> ServiceControlHandlerResult { match control_event { @@ -130,8 +182,8 @@ pub fn run_service() -> Result<()> { // Handle stop ServiceControl::Stop => { - log::debug!("Stop signal from system"); - shutdown_tx.send(()).unwrap(); + log::debug!("stop signal from system"); + shutdown_tx1.send(()).unwrap(); ServiceControlHandlerResult::NoError } @@ -141,43 +193,47 @@ pub fn run_service() -> Result<()> { // Register system service event handler. // The returned status handle should be used to report service status changes to the system. - let status_handle = service_control_handler::register(SERVICE_NAME, event_handler)?; + let status_handle = service_control_handler::register(SERVICE_NAME, event_handler).str_err()?; // Tell the system that service is running - status_handle.set_service_status(ServiceStatus { - service_type: SERVICE_TYPE, - current_state: ServiceState::Running, - controls_accepted: ServiceControlAccept::STOP, - exit_code: ServiceExitCode::Win32(0), - checkpoint: 0, - wait_hint: Duration::default(), - process_id: None, - })?; + status_handle + .set_service_status(ServiceStatus::running()) + .str_err()?; let cfg = Config::get(); log::info!("start with config: {:#?}", cfg); let token = CancellationToken::new(); let child_token = token.child_token(); - let server_handle = std::thread::spawn(move || server_executor(cfg, child_token)); + let server_handle = std::thread::spawn(move || server_executor(cfg, child_token, shutdown_tx)); shutdown_rx.recv().unwrap(); // wait for shutdown signal log::info!("service stop"); // stop server token.cancel(); - server_handle.join().unwrap(); + + let result = server_handle.join(); + if let Err(e) = result { + log::error!("server panic: {:#?}", e); + status_handle + .set_service_status(ServiceStatus::stopped_with_error(1)) + .str_err()?; + return Err("server panic".into()); + } + + let result = result.unwrap(); + if let Err(e) = result { + log::error!("server error: {:#?}", e); + status_handle + .set_service_status(ServiceStatus::stopped_with_error(2)) + .str_err()?; + return Err("server error".into()); + } // Tell the system that service has stopped. - status_handle.set_service_status(ServiceStatus { - service_type: SERVICE_TYPE, - current_state: ServiceState::Stopped, - controls_accepted: ServiceControlAccept::empty(), - exit_code: ServiceExitCode::Win32(0), - checkpoint: 0, - wait_hint: Duration::default(), - process_id: None, - })?; - + status_handle + .set_service_status(ServiceStatus::stopped()) + .str_err()?; Ok(()) }