add filters

This commit is contained in:
Dmitry Belyaev 2024-02-20 10:15:48 +03:00
parent 0dc440e8cf
commit 773a4e0107
Signed by: b4tman
GPG Key ID: 41A00BF15EA7E5F3
5 changed files with 123 additions and 7 deletions

View File

@ -96,3 +96,46 @@ class BaseOutputProvider(ABC):
@abstractmethod @abstractmethod
async def set_addrs_imp(self, source_provider, addr_v4, addr_v6): async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
... ...
class BaseFilterProvider(ABC):
_childs = {}
registred = {}
def __init__(self, name, config, ipv4t, ipv6t):
self.name, self.config = name, config
self.ipv4t, self.ipv6t = ipv4t, ipv6t
def __init_subclass__(cls) -> None:
BaseFilterProvider._childs[cls.__name__] = cls
return super().__init_subclass__()
def __str__(self):
return f"{self.__class__.__name__}: {self.name}"
def best_client(self, addr_v4, addr_v6):
if addr_v6 is None and addr_v4 is not None:
return self.ipv4t
return self.ipv6t
@classmethod
def validate_source_config(cls, name, config):
if "provider" not in config:
return False
prov_name = config["provider"]
if prov_name not in cls._childs:
return False
return True
@classmethod
def register_provider(cls, name, config, ipv4t, ipv6t):
if not cls.validate_source_config(name, config):
return
provider = cls._childs[config["provider"]]
cls.registred[name] = provider(name, config, ipv4t, ipv6t)
async def check(self, source_provider, addr_v4, addr_v6):
return await self.check_imp(source_provider, addr_v4, addr_v6)
@abstractmethod
async def check_imp(self, source_provider, addr_v4, addr_v6):
...

View File

@ -2,14 +2,16 @@ import httpx
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import toml import toml
from .base import BaseSourceProvider, BaseOutputProvider from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider
from . import sources from . import sources
from . import outputs from . import outputs
from . import filters
async def source_task(providers): async def source_task():
providers = BaseSourceProvider.registred.values()
result = None result = None
is_done = False is_done = False
pending = [asyncio.create_task(p.fetch_all()) for p in providers] pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers]
while not is_done and pending: while not is_done and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done: for x in done:
@ -29,10 +31,36 @@ async def source_task(providers):
async def output_task(providers, result): async def output_task(providers, result):
providers = BaseOutputProvider.registred.values()
await asyncio.gather( await asyncio.gather(
*(asyncio.create_task(p.set_addrs(*result)) for p in providers) *(asyncio.create_task(p.set_addrs(*result), name=p.name) for p in providers)
) )
async def filter_task(ip_result):
providers = BaseFilterProvider.registred.values()
result = True
failed = ""
pending = [asyncio.create_task(p.check(*ip_result), name=p.name) for p in providers]
while result and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done:
result = x.result()
if not result:
failed = x.get_name()
if pending:
gather = asyncio.gather(*pending)
gather.cancel()
try:
await gather
except asyncio.CancelledError:
pass
if not result:
print("failed filter:", failed)
return result
async def app(ipv4t, ipv6t): async def app(ipv4t, ipv6t):
config = toml.load("settings/config.toml") config = toml.load("settings/config.toml")
@ -41,6 +69,11 @@ async def app(ipv4t, ipv6t):
source_name, config["sources"][source_name], ipv4t, ipv6t source_name, config["sources"][source_name], ipv4t, ipv6t
) )
for filter_name in config["filters"]:
BaseFilterProvider.register_provider(
filter_name, config["filters"][filter_name], ipv4t, ipv6t
)
for output_name in config["outputs"]: for output_name in config["outputs"]:
BaseOutputProvider.register_provider( BaseOutputProvider.register_provider(
output_name, config["outputs"][output_name], ipv4t, ipv6t output_name, config["outputs"][output_name], ipv4t, ipv6t
@ -52,19 +85,28 @@ async def app(ipv4t, ipv6t):
print( print(
f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}" f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}"
) )
print(
f"filter classes: {[*BaseFilterProvider._childs]}, {[*map(str, BaseFilterProvider._childs.values())]}"
)
print( print(
f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}" f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}"
) )
print( print(
f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}" f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}"
) )
print(
f"filter providers: {[*BaseFilterProvider.registred]}, {[*map(str, BaseFilterProvider.registred.values())]}"
)
print( print(
f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}" f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}"
) )
print(config) #print(config)
result = await source_task(BaseSourceProvider.registred.values()) result = await source_task()
await output_task(BaseOutputProvider.registred.values(), result) if not await filter_task(result):
print("stop by filters")
return
await output_task(result)
async def main(): async def main():

View File

@ -0,0 +1,3 @@
from pddnsc.plugins import load_plugins
load_plugins(__file__)

23
pddnsc/filters/files.py Normal file
View File

@ -0,0 +1,23 @@
import asyncio
import hashlib
import json
import aiofiles
from os.path import isfile
from pddnsc.base import BaseFilterProvider
class StateHashFilter(BaseFilterProvider):
async def check_imp(self, source_provider, addr_v4, addr_v6):
if not isfile(self.config['filepath']):
return True
new_state = {
"ipv4": addr_v4 or "",
"ipv6": addr_v6 or "",
}
new_state_str = json.dumps(new_state)
new_sha = hashlib.sha256(new_state_str.encode(encoding='utf-8'))
async with aiofiles.open(self.config['filepath'], mode='r', encoding='utf-8') as f:
old_state_hash = await f.read()
return old_state_hash != new_sha.hexdigest()

View File

@ -7,6 +7,11 @@ debug = true
provider = "FakeSource" provider = "FakeSource"
ipv6 = "fe80::1" ipv6 = "fe80::1"
[filters]
[filters.state-hash]
provider = "StateHashFilter"
filepath = "state/hash.txt"
[outputs] [outputs]
[outputs.print] [outputs.print]
provider = "JustPrint" provider = "JustPrint"