#!/usr/bin/python3.9
# ruff: noqa: T201

"""Export specified group config to tar.gz file"""

import argparse
import base64
import hashlib
import hmac
import json
import tarfile
import time
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
from urllib.error import HTTPError
from urllib.request import Request, urlopen

# Command-line interface. Every option takes a single value, so the default
# "store" action is used: with action="append" and a string default argparse
# calls .append() on that string and crashes as soon as the option is supplied.
parser = argparse.ArgumentParser(description="Export part of UCS configuration")
parser.add_argument("group_id", help="branch ID to export")
parser.add_argument("-a", dest="url", help="UCS API URL", default="http://127.0.0.1:1234/JSON")
parser.add_argument("-k", dest="key", help="uAuth PSK", default="file:///etc/uauth/psk")
parser.add_argument("-u", dest="user", help="username to execute request as", default="admin")
args = parser.parse_args()

def jwt(key: str = "file:///etc/uauth/psk", user: str = "admin") -> str:
    """Build a compact HS256-signed JWT for uAuth.

    ``key`` is either the pre-shared secret itself or a ``file://`` URL naming
    a file that holds it (trailing whitespace read from the file is dropped).
    ``user`` becomes the ``sub`` claim and ``iat`` is the current Unix time in
    whole seconds.
    """

    def b64url(raw: bytes) -> bytes:
        # URL-safe base64 without the trailing "=" padding
        return base64.urlsafe_b64encode(raw).rstrip(b"=")

    if key.startswith("file://"):
        key = Path(key.removeprefix("file://")).read_text().rstrip()

    header = json.dumps({"typ": "JWT", "alg": "HS256"}, separators=(",", ":"), sort_keys=True)
    claims = json.dumps({"sub": user, "iat": int(time.time())}, separators=(",", ":"))

    # The signature covers "<header>.<claims>" in their encoded form
    signing_input = b64url(header.encode("utf8")) + b"." + b64url(claims.encode("utf8"))
    mac = hmac.new(key.encode("utf-8"), signing_input, hashlib.sha256).digest()
    return (signing_input + b"." + b64url(mac)).decode("utf-8")


class ApiError(Exception):
    """Failure of a UCS API call.

    ``message`` is the human-readable text (also used as the exception text),
    ``code`` is a numeric error code and ``data`` is an optional mapping with
    further details about the error.
    """

    def __init__(self, message: str, code: int, data: Optional[dict[str, str]] = None):
        super().__init__(message)
        self.message, self.code, self.data = message, code, data


def ucs_rpc(url: str, user: str, key: str, method: str, *args: Any) -> Any:  # noqa: ANN401 - we really don't know the type
    """Do UCS-JSON-RPC request

    POSTs ``[args, {}, {}]`` as JSON to ``<url>/<method>``, authorised with a
    uAuth JWT made from ``key`` and ``user``, and returns the decoded JSON
    response.

    Raises:
        OSError: the resulting URL is neither HTTP nor HTTPS.
        ApiError: the request failed. ``code`` is 99 for transport errors, 98
            when a response body cannot be interpreted, and otherwise the
            absolute value of the code reported by UCS, capped at 90.
    """
    url = f'{url.removesuffix("/")}/{method}'
    # Validate first, so a bad URL is reported before the key is read and a token is signed
    if not url.startswith(("http://", "https://")):
        message = "API URL must be HTTP or HTTPS"
        raise OSError(message)

    data = json.dumps([args, {}, {}]).encode("utf-8")
    headers = {
        "Authorization": f"uAuth {jwt(key, user)}",
        "Content-Type": "application/json",
    }

    req = Request(url, data, headers)  # noqa: S310 - wtf, we check it above
    try:
        with urlopen(req) as response:  # noqa: S310 - wtf, we check it above
            return json.load(response.fp)

    except HTTPError as error:
        try:
            ucs_error = json.load(error.fp)
            lines = [f'UCS FAIL: {ucs_error["message"]}']
            details = ucs_error.get("data")
            # Input validator errors (code 0) list the offending fields in "data";
            # it may be absent, so only expand it when it really is a mapping.
            if ucs_error["code"] == 0 and isinstance(details, dict):
                lines.extend(f" * {k}: {v}" for k, v in details.items())
            code = min(abs(ucs_error["code"]), 90)

        except (ValueError, KeyError, TypeError, AttributeError):
            # The error body is not the expected JSON object: report the HTTP error itself
            message = f"API FAIL: {error}"
            raise ApiError(message, code=98) from error

        raise ApiError("\n".join(lines), code=code, data=details) from error

    except OSError as error:
        message = f"HTTP FAIL: {error}"
        raise ApiError(message, code=99) from error

    except ValueError as error:
        # A successful response whose body is not valid JSON
        message = f"API FAIL: {error}"
        raise ApiError(message, code=98) from error


# API shortcut with the command-line URL, user and key already bound: call as ucs(method, *args)
ucs = partial(ucs_rpc, args.url, args.user, args.key)


def get_used_timeperiods(extension: dict[str, Any]) -> set[int]:
    """Collect the timeperiod IDs an extension refers to.

    "timeroute" extensions contribute the ``period_id`` of each route, "queue"
    extensions their ``working_hours_period_id``; empty (falsy) IDs are
    ignored. Any other extension type gives an empty set.
    """
    settings = extension["settings"]
    kind = extension["type"]

    if kind == "timeroute":
        return {route["period_id"] for route in settings["routes"] if route["period_id"]}

    if kind == "queue":
        period_id = settings["working_hours_period_id"]
        return {period_id} if period_id else set()

    return set()

def get_used_soundfiles(extension: dict[str, Any]) -> set[str]:
    """Collect the soundfile names an extension refers to.

    A soundfile setting may name several files joined with "&"; each name is
    returned separately. Only "playback", "ivr", "disa" and "queue" extensions
    are inspected, any other type gives an empty set.
    """
    found: set[str] = set()

    def collect(spec: Any) -> None:
        # Empty or unset settings contribute nothing
        if spec:
            found.update(spec.split("&"))

    settings = extension["settings"]
    kind = extension["type"]

    if kind == "playback":
        collect(settings["file"])

    elif kind in ("ivr", "disa"):
        collect(settings["announce"])

    elif kind == "queue":
        collect(settings["announce"])
        prompts = settings["announcment"]  # spelled this way in the settings
        for prompt in (
            "queue-youarenext",
            "queue-thereare",
            "queue-callswaiting",
            "queue-holdtime",
            "queue-minute",
            "queue-minutes",
            "queue-seconds",
            "queue-thankyou",
            "queue-reporthold",
            "periodic-announce",
        ):
            collect(prompts[prompt])
        # "longer-wait" may be None instead of a list
        for step in prompts["longer-wait"] or ():
            collect(step["announce"])

    return found


def get_config_objects(object_type: str, needed_ids: set[int], id_name: str = "id") -> list[dict[str, Any]]:
    """Dynamically get configuration from UCS

    Lists ``config.<object_type>`` and fetches every listed object whose
    ``id_name`` field is in ``needed_ids``. API failures are printed and the
    affected objects are left out, so the result may be incomplete.
    """
    if not needed_ids:
        # Nothing was requested: the result is always empty, so skip the API round trip
        return []

    try:
        listing = ucs(f"config.{object_type}.list")
    except ApiError as error:
        print(f"Unable to list {object_type}: {error}")
        return []

    needed = []
    for obj in listing:
        obj_id = obj[id_name]
        if obj_id not in needed_ids:
            continue

        try:
            needed.append(ucs(f"config.{object_type}.get", obj_id))
        except ApiError as error:
            print(f"Unable to get {object_type} ID {obj_id}: {error}")

    return needed


# Walk the group tree breadth-first from the requested group, collecting the
# groups and extensions to export together with the settings sets, timeperiods
# and soundfiles they refer to. Anything that cannot be fetched is reported and skipped.
groups = []
extensions = []

needed_setting_sets = set()
needed_timeperiods = set()
needed_soundfiles = set()
branch = [args.group_id]
while branch:
    group_id = branch.pop(0)
    try:
        group = ucs("config.groups.get", group_id)
    except ApiError as error:
        print(f"Unable to get group ID {group_id}: {error}")
        continue

    # These fields are left out of the export; tolerate objects that lack them
    for field in ("created", "updated", "deleted"):
        group.pop(field, None)
    groups.append(group)
    if group["settings_id"]:
        needed_setting_sets.add(group["settings_id"])

    try:
        children = ucs("structure.childs", group_id)
    except ApiError as error:
        print(f"Unable to get group ID {group_id} children: {error}")
        continue

    for child in children:
        if child["type"] == "group":
            # Sub-groups are queued and processed like the starting group
            branch.append(child["id"])
            continue

        try:
            extension = ucs("config.extensions.get", child["id"])
        except ApiError as error:
            print(f"Unable to get group ID {group_id} extension ID {child['id']}: {error}")
            continue

        for field in ("created", "updated", "deleted"):
            extension.pop(field, None)
        extensions.append(extension)
        if extension["settings_id"]:
            needed_setting_sets.add(extension["settings_id"])
        needed_timeperiods.update(get_used_timeperiods(extension))
        needed_soundfiles.update(get_used_soundfiles(extension))

# Index the soundfiles UCS has in every directory the export refers to, keyed by
# "/<directory>/<name without extension>" ("/<name>" for the top directory).
# A leading "/" is dropped before the directory is taken, so "/dir/name" and
# "dir/name" both list "/dir" and end up under the same key.
paths = {path.removeprefix("/").rpartition("/")[0] for path in needed_soundfiles}
ucs_soundfiles = {}
for path in paths:
    try:
        files = ucs("config.soundfiles.list", f"/{path}")
    except ApiError as error:
        print(f"Unable to get soundfiles at {path}: {error}")
        continue

    # The top directory is stored as "" so that keys never start with "//"
    prefix = f"/{path}" if path else ""

    for file in files:
        # Skip directories and entries without a group_id
        # NOTE(review): presumably those are files not owned by any group - confirm
        if file["directory"] or file.get("group_id") is None:
            continue

        filename = file["name"]
        stem, dot, _ = filename.rpartition(".")
        key = stem if dot else filename  # name without its last extension
        ucs_soundfiles[f"{prefix}/{key}"] = prefix, filename

# Write the archive: every soundfile found goes under "soundfiles/", followed by
# config.json describing the groups, extensions and the objects they refer to.
with tarfile.open("config.tar.gz", "w:gz") as tar:
    soundfiles = []
    # Normalise to a leading "/" and de-duplicate first: "name" and "/name" are the
    # same file and must be fetched and archived only once. Sorting keeps the
    # order of archive members stable between runs.
    wanted = sorted({path if path.startswith("/") else f"/{path}" for path in needed_soundfiles})
    for path in wanted:
        file_path, file_name = ucs_soundfiles.get(path, (None, None))
        if not file_name:
            print(f"Unable to determine soundfile for {path}")
            continue

        filename = f"{file_path}/{file_name}"
        try:
            with_content = True
            soundfile = ucs("config.soundfiles.get", filename, with_content)
        except ApiError as error:
            print(f"Unable to get soundfile {filename}: {error}")
            continue

        soundfile["path"] = file_path
        soundfiles.append(soundfile)

        # The audio is stored as a separate archive member, not inside config.json
        content = base64.b64decode(soundfile.pop("content"))
        soundfile_wav = tarfile.TarInfo(f"soundfiles{filename}")
        soundfile_wav.size = len(content)
        tar.addfile(soundfile_wav, BytesIO(content))

    config = json.dumps({
        "groups": groups,
        "extensions": extensions,
        "settings": get_config_objects("settings", needed_setting_sets),
        "timeperiods": get_config_objects("periods", needed_timeperiods),
        "soundfiles": soundfiles,
    }, indent=2).encode("ascii")
    config_json = tarfile.TarInfo("config.json")
    config_json.size = len(config)
    tar.addfile(config_json, BytesIO(config))
