From a1fcc5125cee3b67fd97c3f8be595420fb9d68ad Mon Sep 17 00:00:00 2001 From: Dmitry Date: Wed, 23 Oct 2024 15:31:52 +0300 Subject: [PATCH] peazyrsa-srv version --- Cargo.lock | 6 +- Cargo.toml | 8 ++- src/certs.rs | 96 +++++----------------------- src/common.rs | 135 ++++++++++++++++++--------------------- src/main.rs | 14 +--- src/openssl/external.rs | 138 ---------------------------------------- src/openssl/internal.rs | 100 +++++++++++++---------------- src/openssl/mod.rs | 1 - src/vars.rs | 96 ---------------------------- 9 files changed, 136 insertions(+), 458 deletions(-) delete mode 100644 src/openssl/external.rs delete mode 100644 src/vars.rs diff --git a/Cargo.lock b/Cargo.lock index cd406b3..7e9a2e3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,9 +92,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.90" +version = "1.0.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37bf3594c4c988a53154954629820791dde498571819ae4ca50ca811e060cc95" +checksum = "c042108f3ed77fd83760a5fd79b53be043192bb3b9dba91d8c574c0ada7850c8" [[package]] name = "async-stream" @@ -570,7 +570,7 @@ dependencies = [ ] [[package]] -name = "peazyrsa" +name = "peazyrsa-srv" version = "0.1.0" dependencies = [ "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 15ee7a9..e591757 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,10 +1,14 @@ [package] -name = "peazyrsa" +name = "peazyrsa-srv" version = "0.1.0" edition = "2021" +[[bin]] +name = "peazysrv" +path = "src/main.rs" + [dependencies] -anyhow = "1.0.90" +anyhow = "1.0.91" async-stream = "0.3.6" chrono = "0.4.38" clap = { version = "4.5.20", features = ["derive"] } diff --git a/src/certs.rs b/src/certs.rs index d70bb9f..081d258 100644 --- a/src/certs.rs +++ b/src/certs.rs @@ -1,21 +1,17 @@ -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; use std::{path::PathBuf, sync::Arc}; -use crate::common::{is_file_exist, read_file, write_file, AppConfig, OpenSSLProviderArg, VarsMap}; +use crate::common::AppConfig; use crate::crypto_provider::ICryptoProvider; -use crate::openssl::{external::OpenSSLExternalProvider, internal::OpenSSLInternalProvider}; +use crate::openssl::internal::OpenSSLInternalProvider; pub(crate) struct Certs where T: ICryptoProvider, { - pub(crate) encoding: String, - pub(crate) ca_file: PathBuf, pub(crate) key_file: PathBuf, pub(crate) cert_file: PathBuf, - pub(crate) config_file: PathBuf, - pub(crate) template_file: PathBuf, pub(crate) provider: Arc, } @@ -26,24 +22,15 @@ where pub(crate) fn new(cfg: &AppConfig, provider: T) -> Self { let base_dir = PathBuf::from(&cfg.base_directory); let keys_dir = base_dir.clone().join(cfg.keys_subdir.clone()); - let config_dir = base_dir.clone().join(cfg.config_subdir.clone()); let name = cfg.name.clone(); Certs { - encoding: cfg.encoding.clone(), - ca_file: keys_dir.join(cfg.ca_filename.clone()), key_file: keys_dir.join(format!("{}.key", &name)), cert_file: keys_dir.join(format!("{}.crt", &name)), - config_file: config_dir.join(format!("{}.ovpn", &name)), - template_file: base_dir.clone().join(cfg.template_file.clone()), provider: Arc::new(provider), } } - async fn is_config_exists(&self) -> bool { - is_file_exist(&self.config_file).await - } - pub(crate) async fn request(&self) -> Result<()> { self.provider.request().await } @@ -51,72 +38,23 @@ where pub(crate) async fn sign(&self) -> Result<()> { self.provider.sign().await } - - pub(crate) async fn build_client_config(&self) -> Result { - if self.is_config_exists().await { - return Ok(false); - } - - self.request().await.context("req error")?; - self.sign().await.context("sign error")?; - - let (template_file, ca_file, cert_file, key_file) = ( - self.template_file.clone(), - self.ca_file.clone(), - self.cert_file.clone(), - self.key_file.clone(), - ); - let enc = self.encoding.clone(); - let (enc1, enc2, enc3, enc4) = (enc.clone(), enc.clone(), enc.clone(), enc.clone()); - - if let (Ok(Ok(template)), Ok(Ok(ca)), Ok(Ok(cert)), Ok(Ok(key))) = tokio::join!( - tokio::spawn(read_file(template_file, enc1)), - tokio::spawn(read_file(ca_file, enc2)), - tokio::spawn(read_file(cert_file, enc3)), - tokio::spawn(read_file(key_file, enc4)) - ) { - let text = template - .replace("{{ca}}", ca.trim()) - .replace("{{cert}}", cert.trim()) - .replace("{{key}}", key.trim()); - - write_file(&self.config_file, text, &self.encoding).await?; - - Ok(true) - } else { - Err(anyhow!("files read error")) - } - } } -pub async fn build_client_config(config: &AppConfig, vars: VarsMap) -> Result<()> { - let result_file: PathBuf; - let created: bool; +pub async fn build_cert(config: &AppConfig) -> Result<()> { + let certs = Certs::new(config, OpenSSLInternalProvider::from_cfg(config)); + certs.request().await?; + certs.sign().await?; - if let OpenSSLProviderArg::ExternalBin(_) = config.openssl { - let certs = Certs::new(config, OpenSSLExternalProvider::from_cfg(config, vars)); - created = certs - .build_client_config() - .await - .context("external openssl error")?; - result_file = certs.config_file; - } else { - let certs = Certs::new(config, OpenSSLInternalProvider::from_cfg(config, vars)); - created = certs - .build_client_config() - .await - .context("internal openssl error")?; - result_file = certs.config_file; - } - - let result_file = result_file + let key_file = certs + .key_file .to_str() - .ok_or(anyhow!("result_file PathBuf to str convert error"))?; + .ok_or(anyhow!("key_file PathBuf to str convert error"))?; - if created { - println!("created: {result_file}"); - Ok(()) - } else { - Err(anyhow!("file exists: {result_file}")) - } + let cert_file = certs + .cert_file + .to_str() + .ok_or(anyhow!("req_file PathBuf to str convert error"))?; + + println!("created: \n key: {key_file},\n cert: {cert_file}"); + Ok(()) } diff --git a/src/common.rs b/src/common.rs index 0165c18..c70a473 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,26 +1,16 @@ -use anyhow::{anyhow, Context, Result}; -use async_stream::stream; +use anyhow::{anyhow, Result}; use clap::Parser; -use encoding::{label::encoding_from_whatwg_label, EncoderTrap}; +use encoding::label::encoding_from_whatwg_label; use std::{ - collections::BTreeMap, fmt::Display, path::{Path, PathBuf}, str::FromStr, }; -use tokio::{ - fs::{self, File}, - io::{AsyncBufReadExt, BufReader}, -}; - -use futures_core::stream::Stream; - -pub(crate) type VarsMap = BTreeMap; +use tokio::fs; #[derive(Debug, Clone, PartialEq)] pub enum OpenSSLProviderArg { Internal, - ExternalBin(String), } impl FromStr for OpenSSLProviderArg { @@ -28,7 +18,7 @@ impl FromStr for OpenSSLProviderArg { fn from_str(s: &str) -> Result { match s.to_ascii_lowercase().as_str() { "internal" => Ok(OpenSSLProviderArg::Internal), - x => Ok(OpenSSLProviderArg::ExternalBin(x.to_string())), + &_ => todo!(), } } } @@ -36,8 +26,7 @@ impl FromStr for OpenSSLProviderArg { impl Display for OpenSSLProviderArg { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - OpenSSLProviderArg::ExternalBin(x) => write!(f, "{}", x), - _ => write!(f, "internal"), + OpenSSLProviderArg::Internal => write!(f, "internal"), } } } @@ -61,13 +50,9 @@ pub(crate) struct Args { pub(crate) encoding: Option, /// keys subdir - #[arg(long, default_value = "keys")] + #[arg(long, default_value = ".")] pub(crate) keys_dir: String, - /// config subdir - #[arg(long, default_value = "config")] - pub(crate) config_dir: String, - /// valid days #[arg(long, default_value = "3650")] pub(crate) days: u32, @@ -76,25 +61,53 @@ pub(crate) struct Args { #[arg(long, short, default_value = "internal")] pub(crate) openssl: OpenSSLProviderArg, - /// template file - #[arg(long, default_value = "template.ovpn")] - pub(crate) template_file: String, + /// dns name + #[arg(short = 'n', long)] + pub(crate) dns: Vec, + + /// IP address + #[arg(short, long)] + pub(crate) ip: Vec, + + /// Country + #[arg(long)] + pub(crate) country: Option, + + /// Province + #[arg(long)] + pub(crate) province: Option, + + /// Organization + #[arg(long)] + pub(crate) org: Option, + + /// Organization Unit + #[arg(long)] + pub(crate) ou: Option, + + /// Key size + #[arg(long, default_value = "2048")] + pub(crate) key_size: u32, } pub(crate) struct AppConfig { pub(crate) encoding: String, pub(crate) req_days: u32, pub(crate) keys_subdir: String, - pub(crate) config_subdir: String, - pub(crate) template_file: String, - pub(crate) openssl_default_cnf: String, - pub(crate) openssl_cnf_env: String, pub(crate) ca_filename: String, pub(crate) default_email_domain: String, + #[allow(unused)] pub(crate) openssl: OpenSSLProviderArg, pub(crate) base_directory: String, pub(crate) email: String, pub(crate) name: String, + pub(crate) dns: Vec, + pub(crate) ip: Vec, + pub(crate) country: Option, + pub(crate) province: Option, + pub(crate) org: Option, + pub(crate) ou: Option, + pub(crate) key_size: u32, } impl Default for AppConfig { @@ -103,16 +116,19 @@ impl Default for AppConfig { encoding: "cp866".into(), req_days: 30650, keys_subdir: "keys".into(), - config_subdir: "config".into(), - template_file: "template.ovpn".into(), - openssl_default_cnf: "openssl-1.0.0.cnf".into(), - openssl_cnf_env: "KEY_CONFIG".into(), ca_filename: "ca.crt".into(), default_email_domain: "example.com".into(), openssl: OpenSSLProviderArg::Internal, base_directory: ".".into(), email: "name@example.com".into(), name: "user".into(), + dns: Vec::new(), + ip: Vec::new(), + country: None, + province: None, + org: None, + ou: None, + key_size: 2048, } } } @@ -138,10 +154,16 @@ impl From<&Args> for AppConfig { }; let name = args.name.clone(); let openssl = args.openssl.clone(); - let template_file = args.template_file.clone(); let req_days = args.days; let keys_subdir = args.keys_dir.clone(); - let config_subdir = args.config_dir.clone(); + let (dns, ip, country, province, org, ou) = ( + args.dns.clone(), + args.ip.clone(), + args.country.clone(), + args.province.clone(), + args.org.clone(), + args.ou.clone(), + ); Self { base_directory, @@ -149,10 +171,15 @@ impl From<&Args> for AppConfig { encoding, name, openssl, - template_file, req_days, keys_subdir, - config_subdir, + dns, + ip, + country, + province, + org, + ou, + key_size: args.key_size, ..defaults } } @@ -187,39 +214,3 @@ where enc.decode(&bytes, encoding::DecoderTrap::Ignore) .map_err(|_| anyhow!("could not read file")) } - -pub(crate) async fn write_file(filepath: &PathBuf, text: String, encoding: &str) -> Result<()> { - if encoding == "utf8" { - return Ok(fs::write(filepath, text).await?); - } - - let enc = encoding_from_whatwg_label(encoding).ok_or(anyhow!("encoding not found"))?; - let mut bytes = Vec::new(); - enc.encode_to(&text, EncoderTrap::Ignore, &mut bytes) - .map_err(|_| anyhow!("can't encode"))?; - - fs::write(filepath, bytes).await.context("can't write file") -} - -pub(crate) async fn read_file_by_lines( - filepath: &PathBuf, - encoding: &str, -) -> Result>> { - Ok(if encoding == "utf8" { - let f = File::open(filepath).await?; - let reader = BufReader::new(f); - let mut lines = reader.lines(); - Box::new(stream! { - while let Ok(Some(line)) = lines.next_line().await { - yield line - } - }) - } else { - let text = read_file(filepath, encoding).await?; - Box::new(stream! { - for line in text.lines() { - yield line.to_string() - } - }) - }) -} diff --git a/src/main.rs b/src/main.rs index 2c135bd..7920939 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,26 +1,18 @@ -use anyhow::{anyhow, Result}; +use anyhow::Result; use clap::Parser; mod certs; mod common; mod crypto_provider; mod openssl; -mod vars; -use crate::certs::build_client_config; +use crate::certs::build_cert; use crate::common::{AppConfig, Args}; -use crate::vars::VarsFile; #[tokio::main(flavor = "current_thread")] async fn main() -> Result<()> { let args = Args::parse(); let config = AppConfig::from(&args); - let mut vars = VarsFile::from_config(&config).await?; - vars.parse().await?; - println!("found vars: {}", vars.filepath.to_str().expect("fff")); - println!("loaded: {:#?}", &vars.vars); - - let vars = vars.vars.ok_or(anyhow!("no vars loaded"))?; - build_client_config(&config, vars).await + build_cert(&config).await } diff --git a/src/openssl/external.rs b/src/openssl/external.rs deleted file mode 100644 index b54da1b..0000000 --- a/src/openssl/external.rs +++ /dev/null @@ -1,138 +0,0 @@ -use anyhow::{anyhow, Result}; -use std::path::PathBuf; - -use tokio::process::Command; - -use crate::common::{is_file_exist, AppConfig, VarsMap}; -use crate::crypto_provider::ICryptoProvider; - -pub(crate) struct OpenSSLExternalProvider { - vars: VarsMap, - base_dir: PathBuf, - openssl_cnf: PathBuf, - openssl: String, - ca_file: PathBuf, - req_file: PathBuf, - key_file: PathBuf, - cert_file: PathBuf, - req_days: u32, -} - -impl OpenSSLExternalProvider { - async fn is_ca_exists(&self) -> bool { - is_file_exist(&self.ca_file).await - } - - async fn is_cert_exists(&self) -> bool { - is_file_exist(&self.cert_file).await - } - - async fn is_req_exists(&self) -> bool { - is_file_exist(&self.req_file).await - } - - pub(crate) fn from_cfg(cfg: &AppConfig, vars: VarsMap) -> Self { - let base_dir = PathBuf::from(&cfg.base_directory); - let keys_dir = base_dir.clone().join(cfg.keys_subdir.clone()); - let name = cfg.name.clone(); - let mut vars = vars; - - vars.insert("KEY_CN".into(), name.clone()); - vars.insert("KEY_NAME".into(), name.clone()); - vars.insert("KEY_EMAIL".into(), cfg.email.clone()); - - let ca_file = keys_dir.join(cfg.ca_filename.clone()); - let req_file = keys_dir.join(format!("{}.csr", &name)); - let key_file = keys_dir.join(format!("{}.key", &name)); - let cert_file = keys_dir.join(format!("{}.crt", &name)); - let openssl_cnf = base_dir.clone().join( - std::env::var(cfg.openssl_cnf_env.clone()).unwrap_or(cfg.openssl_default_cnf.clone()), - ); - - Self { - vars, - base_dir, - openssl_cnf, - openssl: cfg.openssl.to_string(), - ca_file, - req_file, - key_file, - cert_file, - req_days: cfg.req_days, - } - } -} - -impl ICryptoProvider for OpenSSLExternalProvider { - async fn request(&self) -> Result<()> { - if self.is_req_exists().await { - return Ok(()); - } - - if !self.is_ca_exists().await { - return Err(anyhow!( - "ca file not found: {}", - &self.ca_file.to_str().unwrap() - )); - } - - let status = Command::new(&self.openssl) - .args([ - "req", - "-nodes", - "-new", - "-keyout", - self.key_file.to_str().unwrap(), - "-out", - self.req_file.to_str().unwrap(), - "-config", - self.openssl_cnf.to_str().unwrap(), - "-batch", - ]) - .current_dir(&self.base_dir) - .envs(&self.vars) - .status() - .await?; - - match status.success() { - true => Ok(()), - false => Err(anyhow!("openssl req execution failed")), - } - } - - async fn sign(&self) -> Result<()> { - if self.is_cert_exists().await { - return Ok(()); - } - - if !self.is_ca_exists().await { - return Err(anyhow!( - "ca file not found: {}", - &self.ca_file.to_str().unwrap() - )); - } - - let status = Command::new(&self.openssl) - .args([ - "ca", - "-days", - format!("{}", self.req_days).as_str(), - "-out", - self.cert_file.to_str().unwrap(), - "-in", - self.req_file.to_str().unwrap(), - "-config", - self.openssl_cnf.to_str().unwrap(), - "-batch", - ]) - .current_dir(&self.base_dir) - .envs(&self.vars) - .status() - .await?; - - match status.success() { - true => Ok(()), - false => Err(anyhow!("ssl ca execution failed")), - } - } -} diff --git a/src/openssl/internal.rs b/src/openssl/internal.rs index 02f2418..6740338 100644 --- a/src/openssl/internal.rs +++ b/src/openssl/internal.rs @@ -16,28 +16,12 @@ use std::path::{Path, PathBuf}; use tokio::fs; use crate::{ - common::{is_file_exist, read_file, AppConfig, VarsMap}, + common::{is_file_exist, read_file, AppConfig}, crypto_provider::ICryptoProvider, }; -use lazy_static::lazy_static; -use std::collections::HashMap; - use chrono::{Datelike, Days, Timelike, Utc}; -lazy_static! { - static ref KEYMAP: HashMap<&'static str, &'static str> = { - let mut m = HashMap::new(); - m.insert("C", "KEY_COUNTRY"); - m.insert("ST", "KEY_PROVINCE"); - m.insert("O", "KEY_ORG"); - m.insert("OU", "KEY_OU"); - m.insert("CN", "KEY_CN"); - m.insert("name", "KEY_NAME"); - m - }; -} - trait ToPemX { fn to_pem_x(&self) -> Result>; } @@ -92,11 +76,6 @@ fn get_time_str_x509(days: u32) -> Result { } pub(crate) struct OpenSSLInternalProvider { - vars: VarsMap, - #[allow(unused)] - base_dir: PathBuf, - #[allow(unused)] - openssl_cnf: PathBuf, ca_file: PathBuf, ca_key_file: PathBuf, req_file: PathBuf, @@ -105,6 +84,14 @@ pub(crate) struct OpenSSLInternalProvider { req_days: u32, key_size: u32, encoding: String, + name: String, + email: String, + dns: Vec, + ip: Vec, + country: Option, + province: Option, + org: Option, + ou: Option, } impl OpenSSLInternalProvider { @@ -120,43 +107,36 @@ impl OpenSSLInternalProvider { is_file_exist(&self.req_file).await } - pub(crate) fn from_cfg(cfg: &AppConfig, vars: VarsMap) -> Self { + pub(crate) fn from_cfg(cfg: &AppConfig) -> Self { let base_dir = PathBuf::from(&cfg.base_directory); let keys_dir = base_dir.clone().join(cfg.keys_subdir.clone()); let name = cfg.name.clone(); - let mut vars = vars; - - vars.insert("KEY_CN".into(), name.clone()); - vars.insert("KEY_NAME".into(), name.clone()); - vars.insert("KEY_EMAIL".into(), cfg.email.clone()); let ca_file = keys_dir.join(cfg.ca_filename.clone()); let ca_key_file = ca_file.with_extension("key"); let req_file = keys_dir.join(format!("{}.csr", &name)); let key_file = keys_dir.join(format!("{}.key", &name)); let cert_file = keys_dir.join(format!("{}.crt", &name)); - let openssl_cnf = base_dir.clone().join( - std::env::var(cfg.openssl_cnf_env.clone()).unwrap_or(cfg.openssl_default_cnf.clone()), - ); - - let default_key_size = "2048".to_string(); - let key_size_s = vars.get("KEY_SIZE").unwrap_or(&default_key_size); - let key_size: u32 = key_size_s.parse().unwrap(); let encoding = cfg.encoding.clone(); Self { - vars, - base_dir, - openssl_cnf, ca_file, ca_key_file, req_file, key_file, cert_file, req_days: cfg.req_days, - key_size, + key_size: cfg.key_size, encoding, + name, + email: cfg.email.clone(), + dns: cfg.dns.clone(), + ip: cfg.ip.clone(), + country: cfg.country.clone(), + province: cfg.province.clone(), + org: cfg.org.clone(), + ou: cfg.ou.clone(), } } @@ -204,32 +184,40 @@ impl OpenSSLInternalProvider { fn build_x509_name(&self) -> Result { let mut name_builder = X509NameBuilder::new().context("Failed to create X509 name builder")?; - for (&key, &var) in KEYMAP.iter() { - let value = self - .vars - .get(var) - .ok_or(anyhow!("variable not set: {}", var))?; - name_builder.append_entry_by_text(key, value).unwrap(); + name_builder.append_entry_by_text("name", &self.name)?; + name_builder.append_entry_by_text("CN", &self.name)?; + if let Some(country) = self.country.clone() { + name_builder.append_entry_by_text("C", &country)?; + } + if let Some(province) = self.province.clone() { + name_builder.append_entry_by_text("ST", &province)?; + } + if let Some(org) = self.org.clone() { + name_builder.append_entry_by_text("O", &org)?; + } + if let Some(ou) = self.ou.clone() { + name_builder.append_entry_by_text("OU", &ou)?; } Ok(name_builder.build()) } fn gen_x509_extensions( + &self, context: &openssl::x509::X509v3Context, - vars: &VarsMap, ) -> Result> { let key_usage = KeyUsage::new() - .key_agreement() .digital_signature() + .data_encipherment() .build()?; - let key_extended_ext = ExtendedKeyUsage::new().client_auth().build()?; + let key_extended_ext = ExtendedKeyUsage::new().server_auth().build()?; let mut san_extension = SubjectAlternativeName::new(); - if let Some(name) = vars.get("KEY_NAME") { - san_extension.dns(name); + san_extension.email(self.email.as_str()); + for dns in self.dns.iter() { + san_extension.dns(dns); } - if let Some(email) = vars.get("KEY_EMAIL") { - san_extension.email(email); + for ip in self.ip.iter() { + san_extension.ip(ip); } let san_ext = san_extension.build(context).context("build san")?; @@ -237,11 +225,11 @@ impl OpenSSLInternalProvider { } fn gen_x509_extensions_stack( + &self, context: &openssl::x509::X509v3Context, - vars: &VarsMap, ) -> Result> { let mut stack = Stack::new()?; - for extension in Self::gen_x509_extensions(context, vars)?.into_iter() { + for extension in self.gen_x509_extensions(context)?.into_iter() { stack.push(extension).context("push ext")?; } Ok(stack) @@ -274,7 +262,7 @@ impl ICryptoProvider for OpenSSLInternalProvider { .set_subject_name(&name) .context("set subject name")?; let context = csr_builder.x509v3_context(Some(&conf)); - let extensions = Self::gen_x509_extensions_stack(&context, &self.vars)?; + let extensions = self.gen_x509_extensions_stack(&context)?; csr_builder.add_extensions(&extensions)?; csr_builder.sign(&pkey, MessageDigest::sha512())?; let csr = csr_builder.build(); @@ -331,7 +319,7 @@ impl ICryptoProvider for OpenSSLInternalProvider { .context("set_subject_name")?; let context = builder.x509v3_context(Some(&ca_cert), None); - for extension in Self::gen_x509_extensions(&context, &self.vars)? { + for extension in self.gen_x509_extensions(&context)? { builder.append_extension(extension).context("append ext")?; } diff --git a/src/openssl/mod.rs b/src/openssl/mod.rs index d29baed..02f0bee 100644 --- a/src/openssl/mod.rs +++ b/src/openssl/mod.rs @@ -1,2 +1 @@ -pub(crate) mod external; pub(crate) mod internal; diff --git a/src/vars.rs b/src/vars.rs deleted file mode 100644 index 3874267..0000000 --- a/src/vars.rs +++ /dev/null @@ -1,96 +0,0 @@ -use anyhow::{anyhow, Context, Result}; -use regex::Regex; -use std::{path::PathBuf, pin::Pin}; -use tokio::pin; - -use futures_util::stream::StreamExt; - -use crate::common::{read_file_by_lines, AppConfig, VarsMap}; - -pub(crate) struct VarsFile { - pub(crate) filepath: PathBuf, - pub(crate) vars: Option, - pub(crate) encoding: String, -} - -impl VarsFile { - async fn from_file(filepath: &PathBuf, encoding: String) -> Result { - let metadata = tokio::fs::metadata(&filepath).await.context(format!( - "file not found {}", - filepath.to_str().expect("str") - ))?; - if !metadata.is_file() { - Err(anyhow!("{} is not a file", filepath.to_str().expect("str")))? - } - Ok(VarsFile { - filepath: filepath.to_path_buf(), - vars: None, - encoding, - }) - } - - async fn from_dir(dir: PathBuf, encoding: String) -> Result { - let filepath = dir.join("vars"); - let err_context = format!( - "vars or vars.bat file not found in {}", - dir.to_str().expect("str") - ); - - match Self::from_file(&filepath, encoding.clone()).await { - Ok(res) => Ok(res), - Err(_) => Self::from_file(&filepath.with_extension("bat"), encoding.clone()) - .await - .map_err(|e| e.context(err_context)), - } - } - - pub(crate) async fn from_config(config: &AppConfig) -> Result { - Self::from_dir( - PathBuf::from(&config.base_directory), - config.encoding.clone(), - ) - .await - } - - pub(crate) async fn parse(&mut self) -> Result<()> { - let mut result = VarsMap::new(); - let lines = read_file_by_lines(&self.filepath, &self.encoding).await?; - let lines = Pin::from(lines); - pin!(lines); - - let re_v2 = - Regex::new(r#"^(export|set)\s\b(?P[\w\d_]+)\b=\s?"?(?P[^\#]+?)"?$"#) - .context("regex v2")?; - let re_v3 = Regex::new(r"^set_var\s(?P[\w\d_]+)\s+(?P[^\#]+?)$") - .context("regex v3")?; - - while let Some(line) = lines.next().await { - if let Some(caps) = re_v2.captures(line.as_str()) { - result.insert(caps["key"].to_string(), caps["value"].to_string()); - continue; - } - - if let Some(caps) = re_v3.captures(line.as_str()) { - result.insert(caps["key"].to_string(), caps["value"].to_string()); - }; - } - - self.vars = Some(result); - Ok(()) - } - - #[allow(dead_code)] - fn apply(&self) -> Result<()> { - if let Some(vars) = self.vars.clone() { - for (key, value) in vars.iter() { - unsafe { - std::env::set_var(key, value); - } - } - } else { - Err(anyhow!("vars not parsed"))? - } - - Ok(()) - } -}