add filters
This commit is contained in:
parent
0dc440e8cf
commit
773a4e0107
@ -96,3 +96,46 @@ class BaseOutputProvider(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
|
async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
class BaseFilterProvider(ABC):
|
||||||
|
_childs = {}
|
||||||
|
registred = {}
|
||||||
|
|
||||||
|
def __init__(self, name, config, ipv4t, ipv6t):
|
||||||
|
self.name, self.config = name, config
|
||||||
|
self.ipv4t, self.ipv6t = ipv4t, ipv6t
|
||||||
|
|
||||||
|
def __init_subclass__(cls) -> None:
|
||||||
|
BaseFilterProvider._childs[cls.__name__] = cls
|
||||||
|
return super().__init_subclass__()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return f"{self.__class__.__name__}: {self.name}"
|
||||||
|
|
||||||
|
def best_client(self, addr_v4, addr_v6):
|
||||||
|
if addr_v6 is None and addr_v4 is not None:
|
||||||
|
return self.ipv4t
|
||||||
|
return self.ipv6t
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_source_config(cls, name, config):
|
||||||
|
if "provider" not in config:
|
||||||
|
return False
|
||||||
|
prov_name = config["provider"]
|
||||||
|
if prov_name not in cls._childs:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_provider(cls, name, config, ipv4t, ipv6t):
|
||||||
|
if not cls.validate_source_config(name, config):
|
||||||
|
return
|
||||||
|
provider = cls._childs[config["provider"]]
|
||||||
|
cls.registred[name] = provider(name, config, ipv4t, ipv6t)
|
||||||
|
|
||||||
|
async def check(self, source_provider, addr_v4, addr_v6):
|
||||||
|
return await self.check_imp(source_provider, addr_v4, addr_v6)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def check_imp(self, source_provider, addr_v4, addr_v6):
|
||||||
|
...
|
@ -2,14 +2,16 @@ import httpx
|
|||||||
import asyncio
|
import asyncio
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import toml
|
import toml
|
||||||
from .base import BaseSourceProvider, BaseOutputProvider
|
from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider
|
||||||
from . import sources
|
from . import sources
|
||||||
from . import outputs
|
from . import outputs
|
||||||
|
from . import filters
|
||||||
|
|
||||||
async def source_task(providers):
|
async def source_task():
|
||||||
|
providers = BaseSourceProvider.registred.values()
|
||||||
result = None
|
result = None
|
||||||
is_done = False
|
is_done = False
|
||||||
pending = [asyncio.create_task(p.fetch_all()) for p in providers]
|
pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers]
|
||||||
while not is_done and pending:
|
while not is_done and pending:
|
||||||
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
||||||
for x in done:
|
for x in done:
|
||||||
@ -29,10 +31,36 @@ async def source_task(providers):
|
|||||||
|
|
||||||
|
|
||||||
async def output_task(providers, result):
|
async def output_task(providers, result):
|
||||||
|
providers = BaseOutputProvider.registred.values()
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
*(asyncio.create_task(p.set_addrs(*result)) for p in providers)
|
*(asyncio.create_task(p.set_addrs(*result), name=p.name) for p in providers)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def filter_task(ip_result):
|
||||||
|
providers = BaseFilterProvider.registred.values()
|
||||||
|
result = True
|
||||||
|
failed = ""
|
||||||
|
pending = [asyncio.create_task(p.check(*ip_result), name=p.name) for p in providers]
|
||||||
|
while result and pending:
|
||||||
|
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
for x in done:
|
||||||
|
result = x.result()
|
||||||
|
if not result:
|
||||||
|
failed = x.get_name()
|
||||||
|
|
||||||
|
if pending:
|
||||||
|
gather = asyncio.gather(*pending)
|
||||||
|
gather.cancel()
|
||||||
|
try:
|
||||||
|
await gather
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
print("failed filter:", failed)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def app(ipv4t, ipv6t):
|
async def app(ipv4t, ipv6t):
|
||||||
config = toml.load("settings/config.toml")
|
config = toml.load("settings/config.toml")
|
||||||
@ -41,6 +69,11 @@ async def app(ipv4t, ipv6t):
|
|||||||
source_name, config["sources"][source_name], ipv4t, ipv6t
|
source_name, config["sources"][source_name], ipv4t, ipv6t
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for filter_name in config["filters"]:
|
||||||
|
BaseFilterProvider.register_provider(
|
||||||
|
filter_name, config["filters"][filter_name], ipv4t, ipv6t
|
||||||
|
)
|
||||||
|
|
||||||
for output_name in config["outputs"]:
|
for output_name in config["outputs"]:
|
||||||
BaseOutputProvider.register_provider(
|
BaseOutputProvider.register_provider(
|
||||||
output_name, config["outputs"][output_name], ipv4t, ipv6t
|
output_name, config["outputs"][output_name], ipv4t, ipv6t
|
||||||
@ -52,19 +85,28 @@ async def app(ipv4t, ipv6t):
|
|||||||
print(
|
print(
|
||||||
f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}"
|
f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}"
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
f"filter classes: {[*BaseFilterProvider._childs]}, {[*map(str, BaseFilterProvider._childs.values())]}"
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}"
|
f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}"
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}"
|
f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}"
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
f"filter providers: {[*BaseFilterProvider.registred]}, {[*map(str, BaseFilterProvider.registred.values())]}"
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}"
|
f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}"
|
||||||
)
|
)
|
||||||
print(config)
|
#print(config)
|
||||||
|
|
||||||
result = await source_task(BaseSourceProvider.registred.values())
|
result = await source_task()
|
||||||
await output_task(BaseOutputProvider.registred.values(), result)
|
if not await filter_task(result):
|
||||||
|
print("stop by filters")
|
||||||
|
return
|
||||||
|
await output_task(result)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
3
pddnsc/filters/__init__.py
Normal file
3
pddnsc/filters/__init__.py
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
from pddnsc.plugins import load_plugins
|
||||||
|
|
||||||
|
load_plugins(__file__)
|
23
pddnsc/filters/files.py
Normal file
23
pddnsc/filters/files.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import aiofiles
|
||||||
|
from os.path import isfile
|
||||||
|
|
||||||
|
from pddnsc.base import BaseFilterProvider
|
||||||
|
|
||||||
|
class StateHashFilter(BaseFilterProvider):
|
||||||
|
async def check_imp(self, source_provider, addr_v4, addr_v6):
|
||||||
|
if not isfile(self.config['filepath']):
|
||||||
|
return True
|
||||||
|
|
||||||
|
new_state = {
|
||||||
|
"ipv4": addr_v4 or "",
|
||||||
|
"ipv6": addr_v6 or "",
|
||||||
|
}
|
||||||
|
new_state_str = json.dumps(new_state)
|
||||||
|
new_sha = hashlib.sha256(new_state_str.encode(encoding='utf-8'))
|
||||||
|
async with aiofiles.open(self.config['filepath'], mode='r', encoding='utf-8') as f:
|
||||||
|
old_state_hash = await f.read()
|
||||||
|
|
||||||
|
return old_state_hash != new_sha.hexdigest()
|
@ -7,6 +7,11 @@ debug = true
|
|||||||
provider = "FakeSource"
|
provider = "FakeSource"
|
||||||
ipv6 = "fe80::1"
|
ipv6 = "fe80::1"
|
||||||
|
|
||||||
|
[filters]
|
||||||
|
[filters.state-hash]
|
||||||
|
provider = "StateHashFilter"
|
||||||
|
filepath = "state/hash.txt"
|
||||||
|
|
||||||
[outputs]
|
[outputs]
|
||||||
[outputs.print]
|
[outputs.print]
|
||||||
provider = "JustPrint"
|
provider = "JustPrint"
|
||||||
|
Loading…
Reference in New Issue
Block a user