Compare commits

...

3 Commits

Author SHA1 Message Date
058b9a9cf8
cli refactor 2024-02-20 12:05:11 +03:00
3d2f046dcf
refactor state hash output+filter 2024-02-20 11:11:00 +03:00
b7097540a2
add proxy conf 2024-02-20 10:58:04 +03:00
4 changed files with 48 additions and 33 deletions

View File

@ -1,6 +1,13 @@
import httpx import httpx
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import NamedTuple
class IPAddreses(NamedTuple):
source_name: str
ipv4: str
ipv6: str
class BaseSourceProvider(ABC): class BaseSourceProvider(ABC):
@ -18,12 +25,13 @@ class BaseSourceProvider(ABC):
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}: {self.name}" return f"{self.__class__.__name__}: {self.name}"
async def fetch_all(self) -> tuple[str, str, str]: async def fetch_all(self) -> IPAddreses:
results = await asyncio.gather( results = await asyncio.gather(
self.fetch_v4(), self.fetch_v6(), return_exceptions=True self.fetch_v4(), self.fetch_v6(), return_exceptions=True
) )
return (self.name,) + tuple(
None if isinstance(i, Exception) else i for i in results return IPAddreses(
self.name, *("" if isinstance(i, Exception) else i for i in results)
) )
def __init_subclass__(cls) -> None: def __init_subclass__(cls) -> None:

View File

@ -2,20 +2,21 @@ import httpx
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import toml import toml
from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider, IPAddreses
from .plugins import use_plugins from .plugins import use_plugins
from typing import Optional
async def source_task(): async def get_ip_addresses() -> Optional[IPAddreses]:
providers = BaseSourceProvider.registred.values() providers = BaseSourceProvider.registred.values()
result = None ip_addresses = None
is_done = False is_done = False
pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers] pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers]
while not is_done and pending: while not is_done and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done: for x in done:
result = x.result() ip_addresses = x.result()
if len(result) == 3 and result[1] or result[2]: if ip_addresses.ipv4 or ip_addresses.ipv6:
is_done = True is_done = True
break break
@ -26,14 +27,16 @@ async def source_task():
await gather await gather
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
return result return ip_addresses
async def filter_task(ip_result): async def check_ip_addresses(ip_addresses):
providers = BaseFilterProvider.registred.values() providers = BaseFilterProvider.registred.values()
result = True result = True
failed = "" failed = ""
pending = [asyncio.create_task(p.check(*ip_result), name=p.name) for p in providers] pending = [
asyncio.create_task(p.check(*ip_addresses), name=p.name) for p in providers
]
while result and pending: while result and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done: for x in done:
@ -55,17 +58,17 @@ async def filter_task(ip_result):
return result return result
async def output_task(result): async def send_ip_addreses(ip_addresses):
providers = BaseOutputProvider.registred.values() providers = BaseOutputProvider.registred.values()
await asyncio.gather( await asyncio.gather(
*(asyncio.create_task(p.set_addrs(*result), name=p.name) for p in providers) *(
asyncio.create_task(p.set_addrs(*ip_addresses), name=p.name)
for p in providers
)
) )
async def app(ipv4t, ipv6t): def print_debug_info(config):
config = toml.load("settings/config.toml")
use_plugins(config, ipv4t, ipv6t)
debug = config.get("debug", False) debug = config.get("debug", False)
if debug: if debug:
print("DEBUG info:") print("DEBUG info:")
@ -88,19 +91,31 @@ async def app(ipv4t, ipv6t):
f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}" f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}"
) )
result = await source_task()
if not await filter_task(result): async def app(config, ipv4t, ipv6t):
use_plugins(config, ipv4t, ipv6t)
print_debug_info(config)
ip_addreses = await get_ip_addresses()
if ip_addreses is None:
print("no IP addresses")
return
if not await check_ip_addresses(ip_addreses):
print("stop by filters") print("stop by filters")
return return
await output_task(result) await send_ip_addreses(ip_addreses)
print("done") print("done")
async def main(): async def main():
config = toml.load("settings/config.toml")
async with httpx.AsyncHTTPTransport( async with httpx.AsyncHTTPTransport(
local_address="0.0.0.0", local_address="0.0.0.0", proxy=config.get("proxy_v4")
) as ipv4t, httpx.AsyncHTTPTransport(local_address="::") as ipv6t: ) as ipv4t, httpx.AsyncHTTPTransport(
await app(ipv4t, ipv6t) local_address="::", proxy=config.get("proxy_v6")
) as ipv6t:
await app(config, ipv4t, ipv6t)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -12,11 +12,7 @@ class StateHashFilter(BaseFilterProvider):
if not isfile(self.config["filepath"]): if not isfile(self.config["filepath"]):
return True return True
new_state = { new_state_str = (addr_v4 or "") + (addr_v6 or "")
"ipv4": addr_v4 or "",
"ipv6": addr_v6 or "",
}
new_state_str = json.dumps(new_state)
new_sha = hashlib.sha256(new_state_str.encode(encoding="utf-8")) new_sha = hashlib.sha256(new_state_str.encode(encoding="utf-8"))
async with aiofiles.open( async with aiofiles.open(
self.config["filepath"], mode="r", encoding="utf-8" self.config["filepath"], mode="r", encoding="utf-8"

View File

@ -21,11 +21,7 @@ class StateFile(BaseOutputProvider):
class StateHashFile(BaseOutputProvider): class StateHashFile(BaseOutputProvider):
async def set_addrs_imp(self, source_provider, addr_v4, addr_v6): async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
state = { state_str = (addr_v4 or "") + (addr_v6 or "")
"ipv4": addr_v4 or "",
"ipv6": addr_v6 or "",
}
state_str = json.dumps(state)
sha = hashlib.sha256(state_str.encode(encoding="utf-8")) sha = hashlib.sha256(state_str.encode(encoding="utf-8"))
async with aiofiles.open( async with aiofiles.open(
self.config["filepath"], mode="w", encoding="utf-8" self.config["filepath"], mode="w", encoding="utf-8"