From 773a4e01079b3cdd3cd0062d0b3a84332d9871b2 Mon Sep 17 00:00:00 2001
From: Dmitry <b4tm4n@mail.ru>
Date: Tue, 20 Feb 2024 10:15:48 +0300
Subject: [PATCH] add filters

---
 pddnsc/base.py             | 43 +++++++++++++++++++++++++++++
 pddnsc/cli.py              | 56 +++++++++++++++++++++++++++++++++-----
 pddnsc/filters/__init__.py |  3 ++
 pddnsc/filters/files.py    | 23 ++++++++++++++++
 settings/config.toml       |  5 ++++
 5 files changed, 123 insertions(+), 7 deletions(-)
 create mode 100644 pddnsc/filters/__init__.py
 create mode 100644 pddnsc/filters/files.py

diff --git a/pddnsc/base.py b/pddnsc/base.py
index 03e93e5..f84c448 100644
--- a/pddnsc/base.py
+++ b/pddnsc/base.py
@@ -96,3 +96,46 @@ class BaseOutputProvider(ABC):
     @abstractmethod
     async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
         ...
+
+class BaseFilterProvider(ABC):
+    _childs = {}
+    registred = {}
+
+    def __init__(self, name, config, ipv4t, ipv6t):
+        self.name, self.config = name, config
+        self.ipv4t, self.ipv6t = ipv4t, ipv6t
+
+    def __init_subclass__(cls) -> None:
+        BaseFilterProvider._childs[cls.__name__] = cls
+        return super().__init_subclass__()
+
+    def __str__(self):
+        return f"{self.__class__.__name__}: {self.name}"
+
+    def best_client(self, addr_v4, addr_v6):
+        if addr_v6 is None and addr_v4 is not None:
+            return self.ipv4t
+        return self.ipv6t
+
+    @classmethod
+    def validate_source_config(cls, name, config):
+        if "provider" not in config:
+            return False
+        prov_name = config["provider"]
+        if prov_name not in cls._childs:
+            return False
+        return True
+
+    @classmethod
+    def register_provider(cls, name, config, ipv4t, ipv6t):
+        if not cls.validate_source_config(name, config):
+            return
+        provider = cls._childs[config["provider"]]
+        cls.registred[name] = provider(name, config, ipv4t, ipv6t)
+
+    async def check(self, source_provider, addr_v4, addr_v6):
+        return await self.check_imp(source_provider, addr_v4, addr_v6)
+
+    @abstractmethod
+    async def check_imp(self, source_provider, addr_v4, addr_v6):
+        ...
\ No newline at end of file
diff --git a/pddnsc/cli.py b/pddnsc/cli.py
index 47fa6b3..4e9ba29 100644
--- a/pddnsc/cli.py
+++ b/pddnsc/cli.py
@@ -2,14 +2,16 @@ import httpx
 import asyncio
 from abc import ABC, abstractmethod
 import toml
-from .base import BaseSourceProvider, BaseOutputProvider
+from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider
 from . import sources
 from . import outputs
+from . import filters
 
-async def source_task(providers):
+async def source_task():
+    providers = BaseSourceProvider.registred.values()
     result = None
     is_done = False
-    pending = [asyncio.create_task(p.fetch_all()) for p in providers]
+    pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers]
     while not is_done and pending:
         done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
         for x in done:
@@ -29,10 +31,36 @@ async def source_task(providers):
 
 
 async def output_task(providers, result):
+    providers = BaseOutputProvider.registred.values()
     await asyncio.gather(
-        *(asyncio.create_task(p.set_addrs(*result)) for p in providers)
+        *(asyncio.create_task(p.set_addrs(*result), name=p.name) for p in providers)
     )
 
+async def filter_task(ip_result):
+    providers = BaseFilterProvider.registred.values()
+    result = True
+    failed = ""
+    pending = [asyncio.create_task(p.check(*ip_result), name=p.name) for p in providers]
+    while result and pending:
+        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
+        for x in done:
+            result = x.result()
+            if not result:
+                failed = x.get_name()
+
+    if pending:
+        gather = asyncio.gather(*pending)
+        gather.cancel()
+        try:
+            await gather
+        except asyncio.CancelledError:
+            pass
+    
+    if not result:
+        print("failed filter:", failed)
+
+    return result
+    
 
 async def app(ipv4t, ipv6t):
     config = toml.load("settings/config.toml")
@@ -41,6 +69,11 @@ async def app(ipv4t, ipv6t):
             source_name, config["sources"][source_name], ipv4t, ipv6t
         )
 
+    for filter_name in config["filters"]:
+        BaseFilterProvider.register_provider(
+            filter_name, config["filters"][filter_name], ipv4t, ipv6t
+        )
+
     for output_name in config["outputs"]:
         BaseOutputProvider.register_provider(
             output_name, config["outputs"][output_name], ipv4t, ipv6t
@@ -52,19 +85,28 @@ async def app(ipv4t, ipv6t):
         print(
             f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}"
         )
+        print(
+            f"filter classes: {[*BaseFilterProvider._childs]}, {[*map(str, BaseFilterProvider._childs.values())]}"
+        )
         print(
             f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}"
         )
         print(
             f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}"
         )
+        print(
+            f"filter providers: {[*BaseFilterProvider.registred]}, {[*map(str, BaseFilterProvider.registred.values())]}"
+        )
         print(
             f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}"
         )
-        print(config)
+        #print(config)
 
-    result = await source_task(BaseSourceProvider.registred.values())
-    await output_task(BaseOutputProvider.registred.values(), result)
+    result = await source_task()
+    if not await filter_task(result):
+        print("stop by filters")
+        return
+    await output_task(result)
 
 
 async def main():
diff --git a/pddnsc/filters/__init__.py b/pddnsc/filters/__init__.py
new file mode 100644
index 0000000..1587c10
--- /dev/null
+++ b/pddnsc/filters/__init__.py
@@ -0,0 +1,3 @@
+from pddnsc.plugins import load_plugins
+
+load_plugins(__file__)
diff --git a/pddnsc/filters/files.py b/pddnsc/filters/files.py
new file mode 100644
index 0000000..3680978
--- /dev/null
+++ b/pddnsc/filters/files.py
@@ -0,0 +1,23 @@
+import asyncio
+import hashlib
+import json
+import aiofiles
+from os.path import isfile
+
+from pddnsc.base import BaseFilterProvider
+
+class StateHashFilter(BaseFilterProvider):
+    async def check_imp(self, source_provider, addr_v4, addr_v6):
+        if not isfile(self.config['filepath']):
+            return True
+
+        new_state = {
+            "ipv4": addr_v4 or "",
+            "ipv6": addr_v6 or "",
+        }
+        new_state_str = json.dumps(new_state)
+        new_sha = hashlib.sha256(new_state_str.encode(encoding='utf-8'))
+        async with aiofiles.open(self.config['filepath'], mode='r', encoding='utf-8') as f:
+            old_state_hash = await f.read()
+        
+        return old_state_hash != new_sha.hexdigest()
diff --git a/settings/config.toml b/settings/config.toml
index 35b7969..4d3fd42 100644
--- a/settings/config.toml
+++ b/settings/config.toml
@@ -7,6 +7,11 @@ debug = true
     provider = "FakeSource"
     ipv6 = "fe80::1"
 
+[filters]
+  [filters.state-hash]
+    provider = "StateHashFilter"
+    filepath = "state/hash.txt"
+    
 [outputs]
   [outputs.print]
     provider = "JustPrint"