import re
import datetime
import sys
import pathlib
from typing import NamedTuple

# Matches, from the start of a line, an entry of the form
#   "<YYYY-MM-DD> <HH:MM:SS> <a.b.c.d>:<port> [<name>] Peer Connection Initiated"
# and exposes the named groups "date", "time", "IP" and "name".
# The port is matched but not captured.
# NOTE(review): the name group is \w+, so a name containing characters such
# as "-" or "." will not match and the line is skipped — confirm this is
# intended for the names that appear in the logs.
pattern = re.compile(
    r"(?P<date>\d{4}(-\d{2}){2})\s(?P<time>\d{2}(:\d{2}){2})\s(?P<IP>\d{1,3}(\.\d{1,3}){3}):\d+\s\[(?P<name>\w+)\]\sPeer\sConnection\sInitiated"
)


class ConnectionInfo(NamedTuple):
    """
    Immutable record of where and when one connection was seen.

    Being a NamedTuple, instances compare and hash by value and can be
    unpacked as ``ip, when = info``.

    Attributes:
        ip (str): The IP address recorded for the connection.
        datetime (datetime.datetime): The timestamp recorded for the connection.
    """

    ip: str
    datetime: datetime.datetime


# Type alias: connection name -> list of ConnectionInfo records for that name.
ConnectionMap = dict[str, list[ConnectionInfo]]


class ConnectionEntry(NamedTuple):
    """
    Immutable pairing of a connection name with its connection details.

    Being a NamedTuple, an entry can be unpacked as ``name, info = entry``.

    Attributes:
        name (str): The name of the connection.
        info (ConnectionInfo): The IP address and timestamp of the connection.
    """

    name: str
    info: ConnectionInfo


def log_lines(filepath: pathlib.Path):
    """
    A generator function that reads lines from a log file.

    The file is opened lazily, on the first iteration, using the platform's
    default text encoding. Bytes that cannot be decoded are replaced with
    U+FFFD instead of raising UnicodeDecodeError, so a single corrupt byte
    does not abort the whole run.

    Args:
        filepath (pathlib.Path): The path to the log file.

    Yields:
        str: A line from the log file, including its trailing newline if any.

    Raises:
        OSError: If the file cannot be opened or read (raised on first
            iteration, not when the generator is created).
    """

    # errors="replace" leaves cleanly decodable input unchanged; it only
    # affects input that would otherwise have raised.
    with filepath.open("r", errors="replace") as f:
        yield from f


def filter_log_lines_for_date(lines, date: str):
    """
    Lazily select the log lines that belong to a given date.

    A line is kept when it begins with the date followed by a single space,
    so a date that is merely a prefix of a longer token does not match.

    Args:
        lines (iterable): An iterable of log lines.
        date (str): The date to filter the log lines for.

    Returns:
        iterable: A lazy iterable of the lines that start with the specified date.
    """

    return (entry for entry in lines if entry.startswith(date + " "))


def parse_date_time(date: str, time: str) -> datetime.datetime:
    """
    Combine separate date and time strings into one datetime object.

    Args:
        date (str): The date string, in ``YYYY-MM-DD`` form.
        time (str): The time string, in ``HH:MM:SS`` form.

    Returns:
        datetime.datetime: The parsed (naive) datetime object.

    Raises:
        ValueError: If the combined text does not match the expected format.
    """

    stamp = " ".join((date, time))
    return datetime.datetime.strptime(stamp, "%Y-%m-%d %H:%M:%S")


def parse_connections(lines):
    """
    Lazily turn log lines into connection entries.

    Each line is matched against the module-level ``pattern``; lines that do
    not match are skipped silently.

    Args:
        lines (Iterable[str]): An iterable of log lines.

    Yields:
        ConnectionEntry: A connection entry parsed from a log line.
    """

    for line in lines:
        match = pattern.match(line)
        if match is None:
            continue
        name = match["name"]
        address = match["IP"]
        timestamp = parse_date_time(match["date"], match["time"])
        yield ConnectionEntry(name, ConnectionInfo(address, timestamp))


def get_conn_map(lines):
    """
    Group the connections found in log lines by connection name.

    Args:
        lines (Iterable[str]): An iterable of log lines.

    Returns:
        dict: A map whose keys are connection names and whose values are lists of connection information, in the order the lines were read.
    """

    grouped = {}
    for entry in parse_connections(lines):
        if entry.name not in grouped:
            grouped[entry.name] = []
        grouped[entry.name].append(entry.info)
    return grouped


def find_names_with_multiple_ips(connmap: ConnectionMap) -> ConnectionMap:
    """
    A function that finds connection entries with same cn and multiple IPs.

    For every name seen from at least two distinct IPs, the result holds one
    record per IP: the most recent connection from that IP. Records are listed
    in the order their IP first appears in the input, so the output is the
    same on every run.

    Args:
        connmap (ConnectionMap): A map of connection entries.

    Returns:
        ConnectionMap: A map of connection entries with multiple IPs, where the keys are connection names (common name) and the values are lists of connection information (the latest record for each IP).
    """

    result = {}
    for name, infos in connmap.items():
        # Single pass: remember the latest record per IP. Re-assigning an
        # existing key keeps its original position, so the dict preserves
        # first-seen IP order.
        latest_by_ip = {}
        for info in infos:
            current = latest_by_ip.get(info.ip)
            # Strict ">" keeps the earliest-listed record on timestamp ties.
            if current is None or info.datetime > current.datetime:
                latest_by_ip[info.ip] = info
        if len(latest_by_ip) >= 2:
            result[name] = list(latest_by_ip.values())
    return result


def find_fast_repeats(
    connmap: ConnectionMap, threshold: datetime.timedelta, min_repeats: int
) -> ConnectionMap:
    """
    A function that finds connection entries with fast repeats.

    For each name the connections are sorted by time, and every pair of
    consecutive connections that share an IP and are at most ``threshold``
    apart is collected. Each connection is collected once per occurrence, so
    several connections from the same IP with the same timestamp all count
    (they are not merged into one).

    Args:
        connmap (ConnectionMap): A map of connection entries.
        threshold (datetime.timedelta): The maximum time difference between two connection entries to be considered a repeat.
        min_repeats (int): The minimum number of collected connections for a name to be included in the result.

    Returns:
        ConnectionMap: A map of connection entries with fast repeats, where the keys are connection names and the values are lists of connection information in chronological order.
    """

    result = {}
    for name, infos in connmap.items():
        ordered = sorted(infos, key=lambda info: info.datetime)
        hits = []
        # Position (in ``ordered``) of the last record appended to ``hits``.
        # Tracking position rather than value avoids both an O(n) membership
        # scan and the accidental merging of equal-valued records.
        last_added = -1
        for idx, (prev, curr) in enumerate(zip(ordered, ordered[1:])):
            if prev.ip != curr.ip or curr.datetime - prev.datetime > threshold:
                continue
            if last_added != idx:
                # ``prev`` was not already appended as the previous pair's ``curr``.
                hits.append(prev)
            hits.append(curr)
            last_added = idx + 1
        # Names without any fast pair are never reported, whatever min_repeats is.
        if hits and len(hits) >= min_repeats:
            result[name] = hits
    return result


def print_multiple_ips(connmap: ConnectionMap):
    """
    Print a report of names that connected from more than one IP.

    Nothing is printed for an empty map. Otherwise a heading is printed,
    followed by each name and its records in chronological order, numbered
    from 01.

    Args:
        connmap (ConnectionMap): A map of connection entries.
    """

    if not connmap:
        return
    print("Multiple IPs:")
    for name, records in connmap.items():
        print(f"- {name}:")
        chronological = sorted(records, key=lambda record: record.datetime)
        for position, record in enumerate(chronological, start=1):
            print(f"  {position:02}. {record.ip}: {record.datetime}")


def print_fast_repeats(connmap: ConnectionMap, limit_for_one=10):
    """
    A function that prints connection entries with fast repeats.

    Nothing is printed for an empty map. Otherwise a heading is printed,
    followed by each name and its records, newest first, numbered from 1.

    Args:
        connmap (ConnectionMap): A map of connection entries.
        limit_for_one (int, optional): The maximum number of repeats to print for each connection entry. Zero or a negative value prints no records (only the name line). Defaults to 10.
    """

    if not connmap:
        return
    print("Fast repeats:")
    # Clamp so a negative limit cannot turn into a "drop the last k" slice.
    shown = max(limit_for_one, 0)
    for name, infos in connmap.items():
        print(f"- {name}:")
        newest_first = sorted(infos, key=lambda info: info.datetime, reverse=True)
        # Slicing before the loop applies the limit even when it is 0.
        for n, info in enumerate(newest_first[:shown], 1):
            print(f"  {n:2}. {info.ip}: {info.datetime}")


def main():
    """
    Command-line entry point: report suspicious connections in today's log lines.

    Reads the log file named by the first command-line argument, keeps the
    lines dated today, and prints names seen from multiple IPs and names with
    at least 10 connections involved in repeats no more than 3 minutes apart.

    Exits with status 1 and a message on stderr when no log file is given or
    the file cannot be read.
    """

    if len(sys.argv) < 2:
        # sys.exit(str) writes the message to stderr and exits with status 1.
        # The bare ``exit`` name is injected by the ``site`` module and is not
        # guaranteed to exist (e.g. under ``python -S``).
        sys.exit("Error: please specify a log file")

    date = datetime.date.today().strftime("%Y-%m-%d")
    log_file = pathlib.Path(sys.argv[1])
    lines = log_lines(log_file)
    lines = filter_log_lines_for_date(lines, date)
    try:
        # The file is opened lazily, so I/O errors surface here, while the
        # line generator is being consumed.
        connmap = get_conn_map(lines)
    except OSError as err:
        sys.exit(f"Error: cannot read {log_file}: {err}")
    multiple_ips = find_names_with_multiple_ips(connmap)
    fast_repeats = find_fast_repeats(connmap, datetime.timedelta(minutes=3), 10)
    print_multiple_ips(multiple_ips)
    print_fast_repeats(fast_repeats)


# Run the report only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()