Compare commits

...

2 Commits

Author SHA1 Message Date
bbfb4f92af
fmt + StateFileFilter 2024-02-20 10:40:23 +03:00
773a4e0107
add filters 2024-02-20 10:15:48 +03:00
13 changed files with 229 additions and 64 deletions

View File

@ -2,6 +2,7 @@ import httpx
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
class BaseSourceProvider(ABC): class BaseSourceProvider(ABC):
_childs = {} _childs = {}
registred = {} registred = {}
@ -46,12 +47,10 @@ class BaseSourceProvider(ABC):
cls.registred[name] = provider(name, config, ipv4t, ipv6t) cls.registred[name] = provider(name, config, ipv4t, ipv6t)
@abstractmethod @abstractmethod
async def fetch_v4(self) -> str: async def fetch_v4(self) -> str: ...
...
@abstractmethod @abstractmethod
async def fetch_v6(self) -> str: async def fetch_v6(self) -> str: ...
...
class BaseOutputProvider(ABC): class BaseOutputProvider(ABC):
@ -94,5 +93,47 @@ class BaseOutputProvider(ABC):
return await self.set_addrs_imp(source_provider, addr_v4, addr_v6) return await self.set_addrs_imp(source_provider, addr_v4, addr_v6)
@abstractmethod @abstractmethod
async def set_addrs_imp(self, source_provider, addr_v4, addr_v6): async def set_addrs_imp(self, source_provider, addr_v4, addr_v6): ...
...
class BaseFilterProvider(ABC):
_childs = {}
registred = {}
def __init__(self, name, config, ipv4t, ipv6t):
self.name, self.config = name, config
self.ipv4t, self.ipv6t = ipv4t, ipv6t
def __init_subclass__(cls) -> None:
BaseFilterProvider._childs[cls.__name__] = cls
return super().__init_subclass__()
def __str__(self):
return f"{self.__class__.__name__}: {self.name}"
def best_client(self, addr_v4, addr_v6):
if addr_v6 is None and addr_v4 is not None:
return self.ipv4t
return self.ipv6t
@classmethod
def validate_source_config(cls, name, config):
if "provider" not in config:
return False
prov_name = config["provider"]
if prov_name not in cls._childs:
return False
return True
@classmethod
def register_provider(cls, name, config, ipv4t, ipv6t):
if not cls.validate_source_config(name, config):
return
provider = cls._childs[config["provider"]]
cls.registred[name] = provider(name, config, ipv4t, ipv6t)
async def check(self, source_provider, addr_v4, addr_v6):
return await self.check_imp(source_provider, addr_v4, addr_v6)
@abstractmethod
async def check_imp(self, source_provider, addr_v4, addr_v6): ...

View File

@ -2,14 +2,15 @@ import httpx
import asyncio import asyncio
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import toml import toml
from .base import BaseSourceProvider, BaseOutputProvider from .base import BaseFilterProvider, BaseSourceProvider, BaseOutputProvider
from . import sources from .plugins import use_plugins
from . import outputs
async def source_task(providers):
async def source_task():
providers = BaseSourceProvider.registred.values()
result = None result = None
is_done = False is_done = False
pending = [asyncio.create_task(p.fetch_all()) for p in providers] pending = [asyncio.create_task(p.fetch_all(), name=p.name) for p in providers]
while not is_done and pending: while not is_done and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done: for x in done:
@ -28,23 +29,42 @@ async def source_task(providers):
return result return result
async def output_task(providers, result): async def filter_task(ip_result):
providers = BaseFilterProvider.registred.values()
result = True
failed = ""
pending = [asyncio.create_task(p.check(*ip_result), name=p.name) for p in providers]
while result and pending:
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
for x in done:
result = x.result()
if not result:
failed = x.get_name()
if pending:
gather = asyncio.gather(*pending)
gather.cancel()
try:
await gather
except asyncio.CancelledError:
pass
if not result:
print("failed filter:", failed)
return result
async def output_task(result):
providers = BaseOutputProvider.registred.values()
await asyncio.gather( await asyncio.gather(
*(asyncio.create_task(p.set_addrs(*result)) for p in providers) *(asyncio.create_task(p.set_addrs(*result), name=p.name) for p in providers)
) )
async def app(ipv4t, ipv6t): async def app(ipv4t, ipv6t):
config = toml.load("settings/config.toml") config = toml.load("settings/config.toml")
for source_name in config["sources"]: use_plugins(config, ipv4t, ipv6t)
BaseSourceProvider.register_provider(
source_name, config["sources"][source_name], ipv4t, ipv6t
)
for output_name in config["outputs"]:
BaseOutputProvider.register_provider(
output_name, config["outputs"][output_name], ipv4t, ipv6t
)
debug = config.get("debug", False) debug = config.get("debug", False)
if debug: if debug:
@ -52,19 +72,28 @@ async def app(ipv4t, ipv6t):
print( print(
f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}" f"source classes: {[*BaseSourceProvider._childs]}, {[*map(str, BaseSourceProvider._childs.values())]}"
) )
print(
f"filter classes: {[*BaseFilterProvider._childs]}, {[*map(str, BaseFilterProvider._childs.values())]}"
)
print( print(
f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}" f"output classes: {[*BaseOutputProvider._childs]}, {[*map(str, BaseOutputProvider._childs.values())]}"
) )
print( print(
f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}" f"source providers: {[*BaseSourceProvider.registred]}, {[*map(str, BaseSourceProvider.registred.values())]}"
) )
print(
f"filter providers: {[*BaseFilterProvider.registred]}, {[*map(str, BaseFilterProvider.registred.values())]}"
)
print( print(
f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}" f"output providers: {[*BaseOutputProvider.registred]}, {[*map(str, BaseOutputProvider.registred.values())]}"
) )
print(config)
result = await source_task(BaseSourceProvider.registred.values()) result = await source_task()
await output_task(BaseOutputProvider.registred.values(), result) if not await filter_task(result):
print("stop by filters")
return
await output_task(result)
print("done")
async def main(): async def main():

View File

@ -0,0 +1,3 @@
from pddnsc.loaders import load_plugins
load_plugins(__file__)

55
pddnsc/filters/files.py Normal file
View File

@ -0,0 +1,55 @@
import asyncio
import hashlib
import json
import aiofiles
from os.path import isfile
from pddnsc.base import BaseFilterProvider
class StateHashFilter(BaseFilterProvider):
async def check_imp(self, source_provider, addr_v4, addr_v6):
if not isfile(self.config["filepath"]):
return True
new_state = {
"ipv4": addr_v4 or "",
"ipv6": addr_v6 or "",
}
new_state_str = json.dumps(new_state)
new_sha = hashlib.sha256(new_state_str.encode(encoding="utf-8"))
async with aiofiles.open(
self.config["filepath"], mode="r", encoding="utf-8"
) as f:
old_state_hash = await f.read()
return old_state_hash != new_sha.hexdigest()
class StateFileFilter(BaseFilterProvider):
async def check_imp(self, source_provider, addr_v4, addr_v6):
if not isfile(self.config["filepath"]):
return True
new_state = {
"ipv4": addr_v4 or "",
"ipv6": addr_v6 or "",
}
async with aiofiles.open(
self.config["filepath"], mode="r", encoding="utf-8"
) as f:
old_state = json.loads(await f.read())
result = True
if "check_ipv4" not in self.config and "check_ipv4" not in self.config:
return new_state != old_state
if self.config.get("check_ipv4", False):
result = result and new_state["ipv4"] != old_state["ipv4"]
if self.config.get("check_ipv6", False):
result = result and new_state["ipv6"] != old_state["ipv6"]
return result

25
pddnsc/loaders.py Normal file
View File

@ -0,0 +1,25 @@
import os
import traceback
from importlib import util
def load_module(path):
name = os.path.split(path)[-1]
spec = util.spec_from_file_location(name, path)
module = util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def load_plugins(init_filepath):
dirpath = os.path.dirname(os.path.abspath(init_filepath))
for fname in os.listdir(dirpath):
if (
not fname.startswith(".")
and not fname.startswith("__")
and fname.endswith(".py")
):
try:
load_module(os.path.join(dirpath, fname))
except Exception:
traceback.print_exc()

View File

@ -1,3 +1,3 @@
from pddnsc.plugins import load_plugins from pddnsc.loaders import load_plugins
load_plugins(__file__) load_plugins(__file__)

View File

@ -2,6 +2,7 @@ import asyncio
from pddnsc.base import BaseOutputProvider from pddnsc.base import BaseOutputProvider
class JustPrint(BaseOutputProvider): class JustPrint(BaseOutputProvider):
async def set_addrs_imp(self, source_provider, addr_v4, addr_v6): async def set_addrs_imp(self, source_provider, addr_v4, addr_v6):
print(f">> {self.name}") print(f">> {self.name}")

View File

@ -13,7 +13,9 @@ class StateFile(BaseOutputProvider):
"ipv6": addr_v6 or "", "ipv6": addr_v6 or "",
} }
state_str = json.dumps(state) state_str = json.dumps(state)
async with aiofiles.open(self.config['filepath'], mode='w', encoding='utf-8') as f: async with aiofiles.open(
self.config["filepath"], mode="w", encoding="utf-8"
) as f:
await f.write(state_str) await f.write(state_str)
@ -24,6 +26,8 @@ class StateHashFile(BaseOutputProvider):
"ipv6": addr_v6 or "", "ipv6": addr_v6 or "",
} }
state_str = json.dumps(state) state_str = json.dumps(state)
sha = hashlib.sha256(state_str.encode(encoding='utf-8')) sha = hashlib.sha256(state_str.encode(encoding="utf-8"))
async with aiofiles.open(self.config['filepath'], mode='w', encoding='utf-8') as f: async with aiofiles.open(
self.config["filepath"], mode="w", encoding="utf-8"
) as f:
await f.write(sha.hexdigest()) await f.write(sha.hexdigest())

View File

@ -1,22 +1,21 @@
import os from .base import BaseSourceProvider, BaseFilterProvider, BaseOutputProvider
import traceback from . import sources
from importlib import util from . import outputs
from . import filters
def load_module(path): def use_plugins(config, ipv4t, ipv6t):
name = os.path.split(path)[-1] for source_name in config["sources"]:
spec = util.spec_from_file_location(name, path) BaseSourceProvider.register_provider(
module = util.module_from_spec(spec) source_name, config["sources"][source_name], ipv4t, ipv6t
spec.loader.exec_module(module) )
return module
for filter_name in config["filters"]:
BaseFilterProvider.register_provider(
filter_name, config["filters"][filter_name], ipv4t, ipv6t
)
def load_plugins(init_filepath): for output_name in config["outputs"]:
dirpath = os.path.dirname(os.path.abspath(init_filepath)) BaseOutputProvider.register_provider(
for fname in os.listdir(dirpath): output_name, config["outputs"][output_name], ipv4t, ipv6t
if not fname.startswith('.') and \ )
not fname.startswith('__') and fname.endswith('.py'):
try:
load_module(os.path.join(dirpath, fname))
except Exception:
traceback.print_exc()

View File

@ -1,3 +1,3 @@
from pddnsc.plugins import load_plugins from pddnsc.loaders import load_plugins
load_plugins(__file__) load_plugins(__file__)

View File

@ -3,29 +3,26 @@ import asyncio
from pddnsc.base import BaseSourceProvider from pddnsc.base import BaseSourceProvider
class DummySource(BaseSourceProvider): class DummySource(BaseSourceProvider):
async def fetch_v4(self) -> str: async def fetch_v4(self) -> str:
async with httpx.AsyncClient(transport=self.ipv4t) as client: result = await asyncio.sleep(self.config.get("delay", 1), result=None)
result = await asyncio.sleep(10, result=None)
result = await asyncio.sleep(10, result=None)
return result return result
async def fetch_v6(self) -> str: async def fetch_v6(self) -> str:
async with httpx.AsyncClient(transport=self.ipv6t) as client: result = await asyncio.sleep(self.config.get("delay", 1), result=None)
result = await asyncio.sleep(10, result=None)
result = await asyncio.sleep(10, result=None)
return result return result
class FakeSource(BaseSourceProvider): class FakeSource(BaseSourceProvider):
async def fetch_v4(self) -> str: async def fetch_v4(self) -> str:
async with httpx.AsyncClient(transport=self.ipv4t) as client:
result = await asyncio.sleep( result = await asyncio.sleep(
3.3, result=self.config.get("ipv4", "127.0.0.1") self.config.get("delay", 1), result=self.config.get("ipv4", "127.0.0.1")
) )
return result return result
async def fetch_v6(self) -> str: async def fetch_v6(self) -> str:
async with httpx.AsyncClient(transport=self.ipv6t) as client: result = await asyncio.sleep(
result = await asyncio.sleep(4.4, result=self.config.get("ipv6", "::1")) self.config.get("delay", 1), result=self.config.get("ipv6", "::1")
)
return result return result

View File

@ -2,6 +2,7 @@ import httpx
from pddnsc.base import BaseSourceProvider from pddnsc.base import BaseSourceProvider
class IPIFYSource(BaseSourceProvider): class IPIFYSource(BaseSourceProvider):
async def fetch_v4(self) -> str: async def fetch_v4(self) -> str:
async with httpx.AsyncClient(transport=self.ipv4t) as client: async with httpx.AsyncClient(transport=self.ipv4t) as client:

View File

@ -1,16 +1,26 @@
debug = true debug = true
[sources] [sources]
[sources.test1-src] [sources.ipfy]
provider = "IPIFYSource" provider = "IPIFYSource"
[sources.test2-src] [sources.fake]
provider = "FakeSource" provider = "FakeSource"
delay = 10
ipv6 = "fe80::1" ipv6 = "fe80::1"
[filters]
[filters.state-file]
provider = "StateFileFilter"
filepath = "state/state.json"
check_ipv4 = true
[filters.state-hash]
provider = "StateHashFilter"
filepath = "state/hash.txt"
[outputs] [outputs]
[outputs.print] [outputs.print]
provider = "JustPrint" provider = "JustPrint"
[outputs.file] [outputs.state-file]
provider = "StateFile" provider = "StateFile"
filepath = "state/state.json" filepath = "state/state.json"
[outputs.hash-file] [outputs.hash-file]