use anyhow::Context; use bollard::container::ListContainersOptions; use bollard::Docker; use std::cmp::min; use tokio::sync::mpsc; use std::collections::HashMap; use std::default::Default; use tokio::time::{timeout, Instant}; use flexi_logger::{AdaptiveFormat, Logger, LoggerHandle}; use std::sync::Arc; use std::time::Duration; use clap::Parser; use parse_duration::parse as parse_duration; #[derive(Parser, Debug)] #[clap(author, version, about, long_about = None)] #[clap(propagate_version = true)] struct Cli { /// target label #[arg(short, long, default_value = "auto-restart.unhealthy")] pub label: String, /// check interval #[arg(short, long, default_value = "10s", value_parser = parse_duration)] pub interval: Duration, /// unhealthy status timeout #[arg(short, long, default_value = "35s", value_parser = parse_duration)] pub unhealthy_timeout: Duration, } fn get_query_options(args: &Cli) -> ListContainersOptions { let mut filters = HashMap::new(); //filters.insert("status", vec!["running"]); filters.insert("label".into(), vec![args.label.clone()]); filters.insert("health".into(), vec!["unhealthy".into()]); ListContainersOptions { filters, ..Default::default() } } fn create_logger() -> anyhow::Result { let logger = Logger::try_with_str("info") .context("default logging level invalid")? .format(flexi_logger::detailed_format) .adaptive_format_for_stdout(AdaptiveFormat::Detailed) .log_to_stdout() .start() .context("can't start logger"); log_panics::init(); logger } async fn query_containers( connection: &Docker, query_options: ListContainersOptions, ) -> anyhow::Result> where T: std::cmp::Eq + std::hash::Hash + serde::ser::Serialize, std::string::String: From, { Ok(connection .list_containers(Some(query_options)) .await? .into_iter() .map(|container| container.id.unwrap()) .collect()) } async fn restart_container(connection: &Docker, container_name: &str) -> anyhow::Result<()> { connection.restart_container(container_name, None).await?; Ok(()) } type Containers = Vec; async fn query_task( connection: Arc, interval: Duration, tx: mpsc::Sender, mut shutdown_rx: mpsc::Receiver<()>, args: Cli, ) { let query_options = get_query_options(&args); let mut query_time = Duration::new(0, 0); log::debug!("query_task -> start recv"); while (timeout(interval - query_time, shutdown_rx.recv()).await).is_err() { let start = Instant::now(); let containers = query_containers(&connection, query_options.clone()) .await .unwrap_or_default(); let end = Instant::now(); query_time = min(end - start, interval - Duration::from_millis(1)); let res = tx.send(containers).await; if res.is_err() { break; } } } async fn filter_task( unhealthy_timeout: Duration, mut in_rx: mpsc::Receiver, out_tx: mpsc::Sender, ) { let mut unhealthy_time: Option> = Some(HashMap::new()); while let Some(containers) = in_rx.recv().await { let now = Instant::now(); log::info!("filter -> found unhealthy: {}", containers.len()); let prev_times = unhealthy_time.take().unwrap(); let mut new_times: HashMap = prev_times .into_iter() .filter(|(k, _)| containers.contains(k)) .collect(); for container_id in containers { new_times.entry(container_id).or_insert_with(|| now); } let containers: Vec = new_times .iter() .filter(|(_, &time)| (now - time) > unhealthy_timeout) .map(|(id, _)| id.clone()) .collect(); let _ = unhealthy_time.replace(new_times); log::info!("filter -> filtered unhealthy: {}", containers.len()); if containers.is_empty() { continue; } let res = out_tx.send(containers).await; if res.is_err() { break; } } } async fn restart_task(connection: Arc, mut rx: mpsc::Receiver) { log::debug!("restart task start"); while let Some(containers) = rx.recv().await { log::info!("restart -> found: {}", containers.len()); for container_id in containers { log::warn!("restart -> container: {}...", &container_id); let res = restart_container(&connection, container_id.as_str()).await; match res { Ok(_) => log::info!("ok"), Err(e) => log::error!("error: \n{e:?}"), } } } } fn shutdown_control(shutdown: Option>) { let mut shutdown = shutdown; let res = ctrlc::set_handler(move || { log::info!("recieved Ctrl-C"); shutdown.take(); }); if res.is_ok() { log::info!("Press Ctrl-C to stop"); } } #[tokio::main] async fn main() -> anyhow::Result<()> { let logger = create_logger()?; let cli = Cli::parse(); let connection = Arc::new(Docker::connect_with_defaults()?); let _ = connection .as_ref() .ping() .await .context("ping on docker connection")?; let query_connection = connection.clone(); let restart_connection = connection.clone(); let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1); let (query_tx, filter_rx) = mpsc::channel::(1); let (filter_tx, restart_rx) = mpsc::channel::(1); let interval = cli.interval; let unhealthy_timeout = cli.unhealthy_timeout; shutdown_control(Some(shutdown_tx)); tokio::try_join!( tokio::spawn(query_task( query_connection, interval, query_tx, shutdown_rx, cli )), tokio::spawn(filter_task(unhealthy_timeout, filter_rx, filter_tx)), tokio::spawn(restart_task(restart_connection, restart_rx)) )?; drop(logger); Ok(()) }