cli refactor

This commit is contained in:
Dmitry Belyaev 2024-02-20 12:05:11 +03:00
parent 3d2f046dcf
commit 058b9a9cf8
Signed by: b4tman
GPG Key ID: 41A00BF15EA7E5F3
2 changed files with 40 additions and 20 deletions

View File

@ -1,6 +1,13 @@
import httpx import httpx
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import NamedTuple
class IPAddreses(NamedTuple):
source_name: str
ipv4: str
ipv6: str
class BaseSourceProvider(ABC): class BaseSourceProvider(ABC):
@ -18,12 +25,13 @@ class BaseSourceProvider(ABC):
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}: {self.name}" return f"{self.__class__.__name__}: {self.name}"
async def fetch_all(self) -> tuple[str, str, str]: async def fetch_all(self) -> IPAddreses:
results = await asyncio.gather( results = await asyncio.gather(
self.fetch_v4(), self.fetch_v6(), return_exceptions=True self.fetch_v4(), self.fetch_v6(), return_exceptions=True
) )
return (self.name,) + tuple(
None if isinstance(i, Exception) else i for i in results return IPAddreses(
self.name, *("" if isinstance(i, Exception) else i for i in results)
) )
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:

View File

@ -2,20 +2,21 @@ import httpx
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import toml import toml
from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider, IPAddreses
from .plugins import use_plugins from .plugins import use_plugins
from typing import Optional
async def source_task(): async def get_ip_addresses() -> Optional[IPAddreses]:
providers = BaseSourceProvider.registred.values() providers = BaseSourceProvider.registred.values()
result = None ip_addresses = None
is_done = False is_done = False
pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers] pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers]
while not is_done and pending: while not is_done and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done: for x in done:
result = x.result() ip_addresses = x.result()
if len(result) == 3 and result[1] or result[2]: if ip_addresses.ipv4 or ip_addresses.ipv6:
is_done = True is_done = True
break break
@ -26,14 +27,16 @@ async def source_task():
await gather await gather
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
return result return ip_addresses
async def filter_task(ip_result): async def check_ip_addresses(ip_addresses):
providers = BaseFilterProvider.registred.values() providers = BaseFilterProvider.registred.values()
result = True result = True
failed = "" failed = ""
pending = [asyncio.create_task(p.check(*ip_result), name=p.name) for p in providers] pending = [
asyncio.create_task(p.check(*ip_addresses), name=p.name) for p in providers
]
while result and pending: while result and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done: for x in done:
@ -55,17 +58,17 @@ async def filter_task(ip_result):
return result return result
async def output_task(result): async def send_ip_addreses(ip_addresses):
providers = BaseOutputProvider.registred.values() providers = BaseOutputProvider.registred.values()
await asyncio.gather( await asyncio.gather(
*(asyncio.create_task(p.set_addrs(*result), name=p.name) for p in providers) *(
asyncio.create_task(p.set_addrs(*ip_addresses), name=p.name)
for p in providers
)
) )
async def app(config, ipv4t, ipv6t): def print_debug_info(config):
use_plugins(config, ipv4t, ipv6t)
debug = config.get("debug", False) debug = config.get("debug", False)
if debug: if debug:
print("DEBUG info:") print("DEBUG info:")
@ -88,11 +91,20 @@ async def app(config, ipv4t, ipv6t):
f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}" f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}"
) )
result = await source_task()
if not await filter_task(result): async def app(config, ipv4t, ipv6t):
use_plugins(config, ipv4t, ipv6t)
print_debug_info(config)
ip_addreses = await get_ip_addresses()
if ip_addreses is None:
print("no IP addresses")
return
if not await check_ip_addresses(ip_addreses):
print("stop by filters") print("stop by filters")
return return
await output_task(result) await send_ip_addreses(ip_addreses)
print("done") print("done")