refactor service + incoming stream

This commit is contained in:
Dmitry Belyaev 2022-09-28 15:01:48 +03:00
parent 60b6912b2c
commit b2f81ccc5b
Signed by: b4tman
GPG Key ID: 41A00BF15EA7E5F3
2 changed files with 123 additions and 38 deletions

View File

@ -6,23 +6,32 @@ use std::future::Future;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::select;
use tokio::task;
use tokio_stream::StreamExt;
use tokio_stream::{Stream, StreamExt};
use tokio_util::sync::CancellationToken;
use std::sync::mpsc::Sender;
use async_stream::stream;
use crate::config::Config;
use crate::config::PasswordAuth;
pub fn server_executor(cfg: Config, token: CancellationToken) {
pub fn server_executor(
cfg: Config,
token: CancellationToken,
shutdown_tx: Sender<()>,
) -> std::io::Result<()> {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.build()?
.block_on(async {
spawn_socks5_server(cfg, token).await.unwrap();
let result = spawn_socks5_server(cfg, token).await;
shutdown_tx.send(()).unwrap();
result
})
}
pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> Result<()> {
pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> std::io::Result<()> {
let mut server_config = fast_socks5::server::Config::default();
server_config.set_request_timeout(cfg.request_timeout);
server_config.set_skip_auth(cfg.skip_auth);
@ -40,11 +49,12 @@ pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> Resul
let mut listener = Socks5Server::bind(&cfg.listen_addr).await?;
listener.set_config(server_config);
let mut incoming = listener.incoming();
let incoming = stream_with_cancellation(listener.incoming(), &token);
tokio::pin!(incoming);
log::info!("Listen for socks connections @ {}", &cfg.listen_addr);
while let Some(socket_res) = check_cancelled(incoming.next(), token.child_token()).await {
while let Some(socket_res) = incoming.next().await {
match socket_res {
Ok(socket) => {
let child_token = token.child_token();
@ -59,14 +69,31 @@ pub async fn spawn_socks5_server(cfg: Config, token: CancellationToken) -> Resul
Ok(())
}
async fn check_cancelled<F, R>(future: F, token: CancellationToken) -> Option<R>
fn stream_with_cancellation<'a, S>(
mut inner: S,
token: &'a CancellationToken,
) -> impl Stream<Item = <S as Stream>::Item> + 'a
where
F: Future<Output = Option<R>>,
S: StreamExt + Unpin + 'a,
{
stream! {
while let Some(res) = check_cancelled(inner.next(), token, None).await {
yield res;
}
}
}
async fn check_cancelled<F, R>(future: F, token: &CancellationToken, default: R) -> R
where
F: Future<Output = R>,
{
select! {
biased;
_ = token.cancelled() => {
log::error!("accept canceled");
None
default
}
res = future => {
res
@ -81,11 +108,13 @@ where
{
tokio::spawn(async move {
let result = select! {
biased;
_ = token.cancelled() => {
Err("Client connection canceled".to_string())
}
res = future => {
res.map_err(|e| format!("{:#}", &e))
res.map_err(|e| e.to_string())
}
};
if let Err(e) = result {

View File

@ -21,6 +21,56 @@ const SERVICE_TYPE: ServiceType = ServiceType::OWN_PROCESS;
const SERVICE_DISPLAY: &str = "socks5ws proxy";
const SERVICE_DESCRIPTION: &str = "SOCKS5 proxy windows service";
trait ErrorToString {
type Output;
fn str_err(self) -> std::result::Result<Self::Output, String>;
}
impl<T, E> ErrorToString for std::result::Result<T, E>
where
E: std::error::Error,
{
type Output = T;
fn str_err(self) -> std::result::Result<Self::Output, String> {
self.map_err(|e| e.to_string())
}
}
trait ServiceStatusEx {
fn running() -> ServiceStatus;
fn stopped() -> ServiceStatus;
fn stopped_with_error(code: u32) -> ServiceStatus;
}
impl ServiceStatusEx for ServiceStatus {
fn running() -> ServiceStatus {
ServiceStatus {
service_type: SERVICE_TYPE,
current_state: ServiceState::Running,
controls_accepted: ServiceControlAccept::STOP,
exit_code: ServiceExitCode::Win32(0),
checkpoint: 0,
wait_hint: Duration::default(),
process_id: None,
}
}
fn stopped() -> ServiceStatus {
ServiceStatus {
current_state: ServiceState::Stopped,
controls_accepted: ServiceControlAccept::empty(),
..Self::running()
}
}
fn stopped_with_error(code: u32) -> ServiceStatus {
ServiceStatus {
exit_code: ServiceExitCode::ServiceSpecific(code),
..Self::stopped()
}
}
}
pub fn install() -> windows_service::Result<()> {
let manager_access = ServiceManagerAccess::CONNECT | ServiceManagerAccess::CREATE_SERVICE;
let service_manager = ServiceManager::local_computer(None::<&str>, manager_access)?;
@ -113,14 +163,16 @@ define_windows_service!(ffi_service_main, my_service_main);
// output to file if needed.
pub fn my_service_main(_arguments: Vec<OsString>) {
if let Err(e) = run_service() {
log::error!("{:#?}", e);
log::error!("error: {:#?}", e);
}
}
pub fn run_service() -> Result<()> {
pub fn run_service() -> std::result::Result<(), String> {
// Create a channel to be able to poll a stop event from the service worker loop.
let (shutdown_tx, shutdown_rx) = mpsc::channel();
let shutdown_tx1 = shutdown_tx.clone();
// Define system service event handler that will be receiving service events.
let event_handler = move |control_event| -> ServiceControlHandlerResult {
match control_event {
@ -130,8 +182,8 @@ pub fn run_service() -> Result<()> {
// Handle stop
ServiceControl::Stop => {
log::debug!("Stop signal from system");
shutdown_tx.send(()).unwrap();
log::debug!("stop signal from system");
shutdown_tx1.send(()).unwrap();
ServiceControlHandlerResult::NoError
}
@ -141,43 +193,47 @@ pub fn run_service() -> Result<()> {
// Register system service event handler.
// The returned status handle should be used to report service status changes to the system.
let status_handle = service_control_handler::register(SERVICE_NAME, event_handler)?;
let status_handle = service_control_handler::register(SERVICE_NAME, event_handler).str_err()?;
// Tell the system that service is running
status_handle.set_service_status(ServiceStatus {
service_type: SERVICE_TYPE,
current_state: ServiceState::Running,
controls_accepted: ServiceControlAccept::STOP,
exit_code: ServiceExitCode::Win32(0),
checkpoint: 0,
wait_hint: Duration::default(),
process_id: None,
})?;
status_handle
.set_service_status(ServiceStatus::running())
.str_err()?;
let cfg = Config::get();
log::info!("start with config: {:#?}", cfg);
let token = CancellationToken::new();
let child_token = token.child_token();
let server_handle = std::thread::spawn(move || server_executor(cfg, child_token));
let server_handle = std::thread::spawn(move || server_executor(cfg, child_token, shutdown_tx));
shutdown_rx.recv().unwrap(); // wait for shutdown signal
log::info!("service stop");
// stop server
token.cancel();
server_handle.join().unwrap();
let result = server_handle.join();
if let Err(e) = result {
log::error!("server panic: {:#?}", e);
status_handle
.set_service_status(ServiceStatus::stopped_with_error(1))
.str_err()?;
return Err("server panic".into());
}
let result = result.unwrap();
if let Err(e) = result {
log::error!("server error: {:#?}", e);
status_handle
.set_service_status(ServiceStatus::stopped_with_error(2))
.str_err()?;
return Err("server error".into());
}
// Tell the system that service has stopped.
status_handle.set_service_status(ServiceStatus {
service_type: SERVICE_TYPE,
current_state: ServiceState::Stopped,
controls_accepted: ServiceControlAccept::empty(),
exit_code: ServiceExitCode::Win32(0),
checkpoint: 0,
wait_hint: Duration::default(),
process_id: None,
})?;
status_handle
.set_service_status(ServiceStatus::stopped())
.str_err()?;
Ok(())
}