#!/usr/bin/python3.9
# ruff: noqa: T201

"""Set UCS configuration parameter"""

import argparse
import base64
import hashlib
import hmac
import json
import sys
import time
from functools import partial
from pathlib import Path
from typing import Any, Optional
from urllib.error import HTTPError
from urllib.request import Request, urlopen

parser = argparse.ArgumentParser(
    description="Set UCS configuration parameter",
    epilog="Either positional argument val or -r must be specified.",
)
# val and -r are alternative sources for the new value. The group is not
# marked required because key "?" (list parameters) needs neither; presence
# is checked explicitly after parsing instead.
group = parser.add_mutually_exclusive_group(required=False)
parser.add_argument("key", help='parameter to change, use "?" to list parameters')
group.add_argument("val", help="value to set", nargs="?")
group.add_argument("-r", dest="from_file", help="read value from file")
# -a/-k/-u each take a single string: ucs_rpc() expects plain str values.
# (action="append" with a str default crashed as soon as the option was
# given, because argparse tried to call .append on the copied str default.)
parser.add_argument("-a", dest="url", help="UCS API URL", default="http://127.0.0.1:1234/JSON")
parser.add_argument("-f", dest="force", action="store_true", help="don't ask before change")
parser.add_argument("-k", dest="psk", help="uAuth PSK", default="file:///etc/uauth/psk")
parser.add_argument("-q", dest="quiet", action="store_true", help="don't print success responses")
parser.add_argument("-u", dest="user", help="username to execute request as", default="admin")
args = parser.parse_args()

# Enforce the epilog's rule: setting a parameter needs a value from somewhere;
# otherwise None would silently be sent as the new value.
if args.key != "?" and args.val is None and args.from_file is None:
    parser.error("either positional argument val or -r must be specified")

def jwt(key: str = "file:///etc/uauth/psk", user: str = "admin") -> str:
    """Make JWT for uAuth.

    Builds an HS256 (HMAC-SHA256) signed token whose payload carries the
    subject *user* and an ``iat`` claim set to the current Unix time in
    whole seconds.

    :param key: the pre-shared key itself, or ``file://<path>`` to read it
        from that file (trailing whitespace is stripped)
    :param user: value placed in the ``sub`` claim
    :return: compact ``header.payload.signature`` token as text
    """
    scheme = "file://"
    secret = key
    if secret.startswith(scheme):
        secret = Path(secret[len(scheme):]).read_text().rstrip()

    def encode_segment(raw: bytes) -> bytes:
        # JWT segments are base64url with the "=" padding dropped.
        return base64.urlsafe_b64encode(raw).rstrip(b"=")

    header_bytes = json.dumps(
        {"typ": "JWT", "alg": "HS256"}, separators=(",", ":"), sort_keys=True
    ).encode("utf8")
    claims_bytes = json.dumps(
        {"sub": user, "iat": int(time.time())}, separators=(",", ":")
    ).encode("utf8")

    unsigned = encode_segment(header_bytes) + b"." + encode_segment(claims_bytes)
    digest = hmac.new(secret.encode("utf-8"), unsigned, hashlib.sha256).digest()
    return (unsigned + b"." + encode_segment(digest)).decode("utf-8")


class ApiError(Exception):
    """Failure reported by, or while talking to, the UCS API.

    :ivar message: human-readable error text; also what ``str(error)`` shows
    :ivar code: numeric error code
    :ivar data: optional extra details from the API response (may be None)
    """

    def __init__(self, message: str, code: int, data: Optional[dict[str, str]] = None):
        self.message, self.code, self.data = message, code, data
        super().__init__(message)


def ucs_rpc(url: str, user: str, key: str, method: str, *args: Any) -> Any:  # noqa: ANN401 - we really don't know the type
    """Do UCS-JSON-RPC request and return the decoded JSON response.

    The request is sent with a body to ``<url>/<method>`` (urllib therefore
    uses POST), with a JSON body of ``[args, {}, {}]`` and a uAuth JWT for
    *user* signed with *key*. (The meaning of the two empty objects is not
    visible here.)

    :param url: API base URL; must start with ``http://`` or ``https://``
    :param user: username the request is executed as
    :param key: uAuth PSK, or ``file://`` path to a file containing it
    :param method: RPC method name, appended to the URL path
    :param args: positional arguments of the RPC method
    :raises OSError: if *url* is not an HTTP(S) URL, or the PSK file cannot
        be read
    :raises ApiError: if the request fails. ``code`` is the UCS error code
        (absolute value, capped at 90), 98 for a response that cannot be
        parsed, or 99 for a transport-level failure
    """
    # Validate first, before doing any other work (e.g. reading the PSK file).
    if not url.startswith(("http://", "https://")):
        message = "API URL must be HTTP or HTTPS"
        raise OSError(message)

    endpoint = f'{url.removesuffix("/")}/{method}'
    body = json.dumps([args, {}, {}]).encode("utf-8")
    headers = {
        "Authorization": f"uAuth {jwt(key, user)}",
        "Content-Type": "application/json",
    }
    req = Request(endpoint, body, headers)  # noqa: S310 - scheme checked above

    try:
        with urlopen(req) as response:  # noqa: S310 - scheme checked above
            # Read through the HTTPResponse itself, not its raw .fp socket
            # file, so http.client decodes chunked transfer-encoding and
            # honours Content-Length.
            return json.load(response)

    except HTTPError as error:
        try:
            ucs_error = json.load(error)
            lines = [f'UCS FAIL: {ucs_error["message"]}']
            details = ucs_error.get("data")
            # code 0 is an input validator error; data presumably maps field
            # names to messages. Guard against it being absent or not a dict.
            if ucs_error["code"] == 0 and isinstance(details, dict):
                lines.extend(f" * {k}: {v}" for k, v in details.items())
            code = min(abs(ucs_error["code"]), 90)
        except (ValueError, KeyError, TypeError, AttributeError):
            # Body is not JSON, not a JSON object, or lacks message/code.
            message = f"API FAIL: {error}"
            raise ApiError(message, code=98) from error
        raise ApiError("\n".join(lines), code=code, data=details) from error

    except OSError as error:
        message = f"HTTP FAIL: {error}"
        raise ApiError(message, code=99) from error

    except ValueError as error:
        # e.g. a 2xx response whose body is not valid JSON / not decodable.
        message = f"API FAIL: {error}"
        raise ApiError(message, code=98) from error

# ucs(method, *params): RPC call bound to the command-line URL, user and PSK.
ucs = partial(ucs_rpc, args.url, args.user, args.psk)

if args.key == "?":
    # List known parameters (key + description truncated to 80 chars).
    try:
        parameters = ucs("system.describeStruct", "config.parameters.parameters")
    except (ApiError, OSError) as error:
        # OSError: bad URL scheme or unreadable PSK file (raised before any request).
        print(error, file=sys.stderr)
        sys.exit(2)
    for parameter in sorted(parameters, key=lambda p: p["key"]):
        key = parameter["key"]
        desc = parameter["desc"]
        print(f"{key:<30} {desc:.80}")
    sys.exit(0)


# Obtain the value first so an unreadable -r file fails before we prompt.
if args.from_file:
    try:
        val = Path(args.from_file).read_text()
    except (OSError, UnicodeDecodeError) as error:
        print(error, file=sys.stderr)
        # errno can be None (and decode errors have none at all); never
        # report a failed read with exit status 0.
        sys.exit(getattr(error, "errno", None) or 1)
else:
    val = args.val

if not args.force:
    try:
        confirm = input(f"Are you sure to set {args.key} (y/N): ")
    except EOFError:
        # No interactive input available (e.g. stdin closed): treat as "no".
        confirm = ""
    if confirm not in ("y", "Y"):
        sys.exit(1)

try:
    message = ucs("config.parameters.set", args.key, val)
except (ApiError, OSError) as error:
    print(error, file=sys.stderr)
    sys.exit(2)

if not args.quiet:
    print(message)
