From 773a4e01079b3cdd3cd0062d0b3a84332d9871b2 Mon Sep 17 00:00:00 2001 From: Dmitry Date: Tue, 20 Feb 2024 10:15:48 +0300 Subject: [PATCH] add filters --- pddnsc/base.py | 43 +++++++++++++++++++++++++++++ pddnsc/cli.py | 56 +++++++++++++++++++++++++++++++++----- pddnsc/filters/__init__.py | 3 ++ pddnsc/filters/files.py | 23 ++++++++++++++++ settings/config.toml | 5 ++++ 5 files changed, 123 insertions(+), 7 deletions(-) create mode 100644 pddnsc/filters/__init__.py create mode 100644 pddnsc/filters/files.py diff --git a/pddnsc/base.py b/pddnsc/base.py index 03e93e5..f84c448 100644 --- a/pddnsc/base.py +++ b/pddnsc/base.py @@ -96,3 +96,46 @@ class BaseOutputProvider(ABC): @abstractmethod async def set_addrs_imp(self, source_provider, addr_v4, addr_v6): ... + +class BaseFilterProvider(ABC): + _childs = {} + registred = {} + + def __init__(self, name, config, ipv4t, ipv6t): + self.name, self.config = name, config + self.ipv4t, self.ipv6t = ipv4t, ipv6t + + def __init_subclass__(cls) -> None: + BaseFilterProvider._childs[cls.__name__] = cls + return super().__init_subclass__() + + def __str__(self): + return f"{self.__class__.__name__}: {self.name}" + + def best_client(self, addr_v4, addr_v6): + if addr_v6 is None and addr_v4 is not None: + return self.ipv4t + return self.ipv6t + + @classmethod + def validate_source_config(cls, name, config): + if "provider" not in config: + return False + prov_name = config["provider"] + if prov_name not in cls._childs: + return False + return True + + @classmethod + def register_provider(cls, name, config, ipv4t, ipv6t): + if not cls.validate_source_config(name, config): + return + provider = cls._childs[config["provider"]] + cls.registred[name] = provider(name, config, ipv4t, ipv6t) + + async def check(self, source_provider, addr_v4, addr_v6): + return await self.check_imp(source_provider, addr_v4, addr_v6) + + @abstractmethod + async def check_imp(self, source_provider, addr_v4, addr_v6): + ... \ No newline at end of file diff --git a/pddnsc/cli.py b/pddnsc/cli.py index 47fa6b3..4e9ba29 100644 --- a/pddnsc/cli.py +++ b/pddnsc/cli.py @@ -2,14 +2,16 @@ import httpx import asyncio from abc import ABC, abstractmethod import toml -from .base import BaseSourceProvider, BaseOutputProvider +from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider from . import sources from . import outputs +from . import filters -async def source_task(providers): +async def source_task(): + providers = BaseSourceProvider.registred.values() result = None is_done = False - pending = [asyncio.create_task(p.fetch_all()) for p in providers] + pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers] while not is_done and pending: done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for x in done: @@ -29,10 +31,36 @@ async def source_task(providers): async def output_task(providers, result): + providers = BaseOutputProvider.registred.values() await asyncio.gather( - *(asyncio.create_task(p.set_addrs(*result)) for p in providers) + *(asyncio.create_task(p.set_addrs(*result), name=p.name) for p in providers) ) +async def filter_task(ip_result): + providers = BaseFilterProvider.registred.values() + result = True + failed = "" + pending = [asyncio.create_task(p.check(*ip_result), name=p.name) for p in providers] + while result and pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for x in done: + result = x.result() + if not result: + failed = x.get_name() + + if pending: + gather = asyncio.gather(*pending) + gather.cancel() + try: + await gather + except asyncio.CancelledError: + pass + + if not result: + print("failed filter:", failed) + + return result + async def app(ipv4t, ipv6t): config = toml.load("settings/config.toml") @@ -41,6 +69,11 @@ async def app(ipv4t, ipv6t): source_name, config["sources"][source_name], ipv4t, ipv6t ) + for filter_name in config["filters"]: + BaseFilterProvider.register_provider( + filter_name, config["filters"][filter_name], ipv4t, ipv6t + ) + for output_name in config["outputs"]: BaseOutputProvider.register_provider( output_name, config["outputs"][output_name], ipv4t, ipv6t @@ -52,19 +85,28 @@ async def app(ipv4t, ipv6t): print( f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}" ) + print( + f"filter classes: {[*BaseFilterProvider._childs]}, {[*map(str, BaseFilterProvider._childs.values())]}" + ) print( f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}" ) print( f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}" ) + print( + f"filter providers: {[*BaseFilterProvider.registred]}, {[*map(str, BaseFilterProvider.registred.values())]}" + ) print( f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}" ) - print(config) + #print(config) - result = await source_task(BaseSourceProvider.registred.values()) - await output_task(BaseOutputProvider.registred.values(), result) + result = await source_task() + if not await filter_task(result): + print("stop by filters") + return + await output_task(result) async def main(): diff --git a/pddnsc/filters/__init__.py b/pddnsc/filters/__init__.py new file mode 100644 index 0000000..1587c10 --- /dev/null +++ b/pddnsc/filters/__init__.py @@ -0,0 +1,3 @@ +from pddnsc.plugins import load_plugins + +load_plugins(__file__) diff --git a/pddnsc/filters/files.py b/pddnsc/filters/files.py new file mode 100644 index 0000000..3680978 --- /dev/null +++ b/pddnsc/filters/files.py @@ -0,0 +1,23 @@ +import asyncio +import hashlib +import json +import aiofiles +from os.path import isfile + +from pddnsc.base import BaseFilterProvider + +class StateHashFilter(BaseFilterProvider): + async def check_imp(self, source_provider, addr_v4, addr_v6): + if not isfile(self.config['filepath']): + return True + + new_state = { + "ipv4": addr_v4 or "", + "ipv6": addr_v6 or "", + } + new_state_str = json.dumps(new_state) + new_sha = hashlib.sha256(new_state_str.encode(encoding='utf-8')) + async with aiofiles.open(self.config['filepath'], mode='r', encoding='utf-8') as f: + old_state_hash = await f.read() + + return old_state_hash != new_sha.hexdigest() diff --git a/settings/config.toml b/settings/config.toml index 35b7969..4d3fd42 100644 --- a/settings/config.toml +++ b/settings/config.toml @@ -7,6 +7,11 @@ debug = true provider = "FakeSource" ipv6 = "fe80::1" +[filters] + [filters.state-hash] + provider = "StateHashFilter" + filepath = "state/hash.txt" + [outputs] [outputs.print] provider = "JustPrint"