import asyncio
from abc import ABC, abstractmethod

import httpx
import toml


class BaseSourceProvider(ABC):
    """Base class for providers that discover the host's public IP addresses.

    Every subclass is recorded in ``_childs`` under its class name, so a
    config section with ``provider = "<ClassName>"`` can select it.
    Configured instances are stored in ``registred``, keyed by the config
    section name.
    """

    # Subclass registry: class name -> class (filled by __init_subclass__).
    _childs = {}
    # Configured instances: config section name -> provider instance.
    registred = {}

    def __init__(self, name, config, ipv4t, ipv6t):
        """Store the provider name, its config mapping and both transports.

        ``ipv4t`` is the transport subclasses use for IPv4 lookups and
        ``ipv6t`` the one used for IPv6 lookups.
        """
        self.name, self.config, self.ipv4t, self.ipv6t = (
            name,
            config,
            ipv4t,
            ipv6t,
        )

    def __str__(self):
        return f"{self.__class__.__name__}: {self.name}"

    async def fetch_all(self) -> "tuple[str, str | None, str | None]":
        """Fetch the IPv4 and IPv6 addresses concurrently.

        Returns ``(name, ipv4, ipv6)``. A family whose fetch raised is
        reported as ``None`` instead of propagating the error.
        """
        results = await asyncio.gather(
            self.fetch_v4(), self.fetch_v6(), return_exceptions=True
        )
        # With return_exceptions=True a cancelled child shows up here as a
        # CancelledError, which is a BaseException but not an Exception.
        # Test against BaseException so it is never passed on as an address.
        return (self.name,) + tuple(
            None if isinstance(outcome, BaseException) else outcome
            for outcome in results
        )

    def __init_subclass__(cls, **kwargs) -> None:
        """Register each subclass under its class name.

        Class keyword arguments are forwarded so that cooperative mixins
        defining their own ``__init_subclass__`` keep working.
        """
        BaseSourceProvider._childs[cls.__name__] = cls
        super().__init_subclass__(**kwargs)

    @classmethod
    def validate_source_config(cls, name, config):
        """Return True when ``config`` names a known provider class.

        ``name`` is accepted for interface symmetry but not inspected.
        """
        return "provider" in config and config["provider"] in cls._childs

    @classmethod
    def register_provider(cls, name, config, ipv4t, ipv6t):
        """Instantiate the configured provider and store it under ``name``.

        Configs that fail ``validate_source_config`` are silently skipped.
        """
        if not cls.validate_source_config(name, config):
            return
        provider = cls._childs[config["provider"]]
        cls.registred[name] = provider(name, config, ipv4t, ipv6t)

    @abstractmethod
    async def fetch_v4(self) -> "str | None":
        """Return the public IPv4 address, or None if it is unknown."""
        ...

    @abstractmethod
    async def fetch_v6(self) -> "str | None":
        """Return the public IPv6 address, or None if it is unknown."""
        ...
class BaseDNSProvider(ABC):
    """Base class for output providers that receive the discovered addresses.

    Every subclass is recorded in ``_childs`` under its class name, so a
    config section with ``provider = "<ClassName>"`` can select it.
    Configured instances are stored in ``registred``, keyed by the config
    section name.
    """

    # Subclass registry: class name -> class (filled by __init_subclass__).
    _childs = {}
    # Configured instances: config section name -> provider instance.
    registred = {}

    def __init__(self, name, config, ipv4t, ipv6t):
        """Store the provider name, its config mapping and both transports."""
        self.name, self.config = name, config
        self.ipv4t, self.ipv6t = ipv4t, ipv6t

    def __init_subclass__(cls, **kwargs) -> None:
        """Register each subclass under its class name.

        Class keyword arguments are forwarded so that cooperative mixins
        defining their own ``__init_subclass__`` keep working.
        """
        BaseDNSProvider._childs[cls.__name__] = cls
        super().__init_subclass__(**kwargs)

    def __str__(self):
        return f"{self.__class__.__name__}: {self.name}"

    def best_client(self, addr_v4, addr_v6):
        """Choose between ``ipv4t`` and ``ipv6t`` given the known addresses.

        ``ipv4t`` is returned only when an IPv4 address is known and no IPv6
        address is. In every other case ``ipv6t`` is returned.
        """
        if addr_v6 is None and addr_v4 is not None:
            return self.ipv4t
        return self.ipv6t

    @classmethod
    def validate_source_config(cls, name, config):
        """Return True when ``config`` names a known provider class.

        ``name`` is accepted for interface symmetry but not inspected.
        """
        return "provider" in config and config["provider"] in cls._childs

    @classmethod
    def register_provider(cls, name, config, ipv4t, ipv6t):
        """Instantiate the configured provider and store it under ``name``.

        Configs that fail ``validate_source_config`` are silently skipped.
        """
        if not cls.validate_source_config(name, config):
            return
        provider = cls._childs[config["provider"]]
        cls.registred[name] = provider(name, config, ipv4t, ipv6t)

    async def set_addrs(self, source_provider, addr_v4, addr_v6):
        """Publish the addresses; delegates to ``set_addrs_imp``."""
        return await self.set_addrs_imp(source_provider, addr_v4, addr_v6)

    @abstractmethod
    async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
        """Subclass hook that actually publishes the addresses."""
        ...
class DummySource(BaseSourceProvider):
    """Source that never yields an address; useful to exercise slow sources.

    Each fetch waits 20 seconds in total and then returns ``None``.
    """

    async def fetch_v4(self) -> "str | None":
        # No HTTP client is opened here. It would go unused, and closing an
        # httpx client also closes the transport it was given, which is the
        # transport shared with every other provider.
        await asyncio.sleep(10)
        await asyncio.sleep(10)
        return None

    async def fetch_v6(self) -> "str | None":
        await asyncio.sleep(10)
        await asyncio.sleep(10)
        return None


class FakeSource(BaseSourceProvider):
    """Source returning fixed addresses from its config after a short delay.

    Reads ``config["ipv4"]`` / ``config["ipv6"]``, defaulting to loopback.
    """

    async def fetch_v4(self) -> "str | None":
        # No HTTP client: it would be unused and would close the shared
        # transport on exit.
        await asyncio.sleep(3.3)
        return self.config.get("ipv4", "127.0.0.1")

    async def fetch_v6(self) -> "str | None":
        await asyncio.sleep(4.4)
        return self.config.get("ipv6", "::1")


class IPIFYSource(BaseSourceProvider):
    """Source querying the ipify JSON API once per address family."""

    _V4_URL = "https://api4.ipify.org/?format=json"
    _V6_URL = "https://api6.ipify.org/?format=json"

    async def _fetch_ip(self, transport, url) -> "str | None":
        """GET ``url`` over ``transport`` and return the JSON ``"ip"`` field.

        Returns None on a non-200 status or a non-object JSON body. Network
        and JSON decoding errors propagate to the caller.
        """
        # NOTE(review): exiting this ``async with`` closes ``transport``,
        # which is shared with the other providers. Confirm this is intended.
        async with httpx.AsyncClient(transport=transport) as client:
            response = await client.get(url)
            if response.status_code != httpx.codes.OK:
                return None
            data = response.json()
        return data.get("ip") if isinstance(data, dict) else None

    async def fetch_v4(self) -> "str | None":
        return await self._fetch_ip(self.ipv4t, self._V4_URL)

    async def fetch_v6(self) -> "str | None":
        return await self._fetch_ip(self.ipv6t, self._V6_URL)


class JustPrint(BaseDNSProvider):
    """Output provider that only prints the addresses it receives."""

    async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
        print(f">> {self.name}")
        print(f"addresses from: {source_provider}")
        print(f"IPv4: {addr_v4}")
        print(f"IPv6: {addr_v6}")
        await asyncio.sleep(2.2, result=None)


async def source_task(providers):
    """Race all source providers and return the first useful result.

    Each provider's ``fetch_all()`` runs as its own task. The first result
    carrying at least one address wins, and the remaining tasks are
    cancelled. If no provider finds an address, the last completed result is
    returned. If there are no providers at all, the result is None.
    """
    pending = {asyncio.create_task(p.fetch_all()) for p in providers}
    result = None
    try:
        while pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                result = task.result()
                # Parenthesised on purpose: "found an IPv4 OR an IPv6 address".
                if len(result) == 3 and (result[1] or result[2]):
                    return result
    finally:
        # Runs on success, on error and on our own cancellation, so no fetch
        # task is ever left running in the background.
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
    return result


async def output_task(providers, result):
    """Hand ``result`` to every output provider concurrently.

    ``result`` is a ``(source_name, ipv4, ipv6)`` tuple as produced by
    ``source_task``. Nothing is done when it is None.
    """
    if result is None:
        return
    await asyncio.gather(*(p.set_addrs(*result) for p in providers))


async def app(ipv4t, ipv6t, config_path="settings/config.toml"):
    """Load the config, register providers, race sources, publish the winner.

    ``config_path`` defaults to the historical location so existing callers
    are unaffected. Missing ``sources``/``outputs`` sections are treated as
    empty.
    """
    config = toml.load(config_path)

    for source_name, source_config in config.get("sources", {}).items():
        BaseSourceProvider.register_provider(
            source_name, source_config, ipv4t, ipv6t
        )
    for output_name, output_config in config.get("outputs", {}).items():
        BaseDNSProvider.register_provider(
            output_name, output_config, ipv4t, ipv6t
        )

    if config.get("debug", False):
        print("DEBUG info:")
        print(
            f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}"
        )
        print(
            f"output classes: {[*BaseDNSProvider._childs]}, {[*map(str, BaseDNSProvider._childs.values())]}"
        )
        print(
            f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}"
        )
        print(
            f"output providers: {[*BaseDNSProvider.registred]}, {[*map(str, BaseDNSProvider.registred.values())]}"
        )
        print(config)

    result = await source_task(BaseSourceProvider.registred.values())
    if result is None:
        # No source provider is registered; unpacking None would crash.
        print("no source providers registered; nothing to publish")
        return
    await output_task(BaseDNSProvider.registred.values(), result)


async def main():
    """Create one IPv4-bound and one IPv6-bound transport and run the app."""
    async with httpx.AsyncHTTPTransport(
        local_address="0.0.0.0",
    ) as ipv4t, httpx.AsyncHTTPTransport(local_address="::") as ipv6t:
        await app(ipv4t, ipv6t)


if __name__ == "__main__":
    asyncio.run(main())