#!/usr/bin/python3.9

"""Call UCS API method"""

import argparse
import base64
import hashlib
import hmac
import json
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, Optional
from urllib.error import HTTPError
from urllib.request import Request, urlopen


def jwt(key: str = "file:///etc/uauth/psk", user: str = "admin") -> str:
    """Build an HS256-signed JWT for the uAuth ``Authorization`` header.

    Args:
        key: Signing secret. A value starting with ``file://`` is treated as a
            path, and the secret is that file's text with trailing whitespace
            stripped.
        user: Value of the ``sub`` claim.

    Returns:
        Compact ``header.payload.signature`` token. Each part is base64url
        encoded without ``=`` padding. The payload also carries ``iat``, the
        current Unix time in whole seconds.
    """

    def b64url(raw: bytes) -> bytes:
        # JWT segments are base64url with the padding removed.
        return base64.urlsafe_b64encode(raw).replace(b"=", b"")

    scheme = "file://"
    if key.startswith(scheme):
        key = Path(key[len(scheme):]).read_text().rstrip()

    # Compact separators; the header is key-sorted, the payload keeps insertion order.
    header_json = json.dumps({"typ": "JWT", "alg": "HS256"}, separators=(",", ":"), sort_keys=True)
    claims_json = json.dumps({"sub": user, "iat": int(time.time())}, separators=(",", ":"))

    signing_input = b64url(header_json.encode("utf8")) + b"." + b64url(claims_json.encode("utf8"))
    mac = hmac.new(key.encode("utf-8"), signing_input, hashlib.sha256).digest()
    return (signing_input + b"." + b64url(mac)).decode("utf-8")


class ApiError(Exception):
    """Failure reported by, or while talking to, the UCS API.

    Attributes:
        message: Human-readable description; also the exception's ``str()``.
        code: Status the command-line tool exits with for this error.
        data: Optional extra details from the API (e.g. per-field messages of
            an input-validation error), or None.
    """

    def __init__(self, message: str, code: int, data: Optional[dict[str, str]] = None):
        super().__init__(message)
        self.message, self.code, self.data = message, code, data


def _api_error_from_http(error: HTTPError) -> "ApiError":
    """Convert an HTTP error response from the UCS API into an ApiError.

    A JSON body with ``message`` and ``code`` is reported as ``UCS FAIL``.
    Anything else (non-JSON body, missing keys, unexpected types) is reported
    as ``API FAIL`` with code 98.
    """
    try:
        # Read via the HTTPError object itself; it stays readable even when
        # error.fp is None.
        ucs_error = json.load(error)
        lines = [f'UCS FAIL: {ucs_error["message"]}']
        data = ucs_error.get("data")
        ucs_code = ucs_error["code"]
        # Code 0 is an input validator error: list its data as " * key: value".
        if ucs_code == 0 and isinstance(data, dict):
            lines.extend(f" * {k}: {v}" for k, v in data.items())
        # The code becomes the process exit status. Clamp it to 1..90 so it
        # stays below the 98-100 statuses used for transport failures, and so
        # a validation error (0) is not reported as success.
        exit_code = min(abs(ucs_code), 90) or 1
    except (ValueError, KeyError, TypeError):
        return ApiError(f"API FAIL: {error}", code=98)
    return ApiError("\n".join(lines), code=exit_code, data=data)


def ucs_rpc(url: str, user: str, key: str, method: str, *args: Any) -> Any:  # noqa: ANN401 - we really don't know the type
    """Do a UCS-JSON-RPC request and return the decoded result.

    The method name is appended to *url* as a path segment. The positional
    *args* are POSTed as the JSON body ``[args, {}, {}]``, with a uAuth token
    built by ``jwt(key, user)`` in the ``Authorization`` header.

    Args:
        url: Base API URL; must be ``http://`` or ``https://``.
        user: User the request is executed as.
        key: uAuth PSK, or a ``file://`` path to it.
        method: API method name.
        *args: Positional method arguments; must be JSON-serializable.

    Returns:
        The JSON-decoded response body.

    Raises:
        OSError: if the URL is not HTTP(S).
        ApiError: on an HTTP error status (UCS error, or code 98 when the error
            body is unusable) or a transport failure (code 99).
    """
    url = f'{url.removesuffix("/")}/{method}'
    # Validate first, so a bad URL fails without reading the PSK.
    if not url.startswith(("http://", "https://")):
        message = "API URL must be HTTP or HTTPS"
        raise OSError(message)

    data = json.dumps([args, {}, {}]).encode("utf-8")
    headers = {
        "Authorization": f"uAuth {jwt(key, user)}",
        "Content-Type": "application/json",
    }
    req = Request(url, data, headers)  # noqa: S310 - scheme is checked above
    try:
        with urlopen(req) as response:  # noqa: S310 - scheme is checked above
            # Read through the response object, not response.fp: the raw
            # socket file bypasses chunked transfer decoding.
            return json.load(response)
    except HTTPError as error:
        # HTTPError is an OSError subclass, so it has to be handled first.
        raise _api_error_from_http(error) from error
    except OSError as error:
        message = f"HTTP FAIL: {error}"
        raise ApiError(message, code=99) from error


parser = argparse.ArgumentParser(
    description="Call UCS API method",
    epilog="If -j argument is used all arguments are considered to be JSON encoded.",
)
_add = parser.add_argument

# Positional arguments: the API method, then any number of its arguments.
_add("method", help="method to execute")
_add("arg", nargs="*", help="arguments")

# Options; registration order is the order --help lists them in.
_add("-a", dest="url", default="http://127.0.0.1:1234/JSON", help="UCS API URL")
# NOTE(review): args.display is never read below, so -d looks like a no-op — confirm or remove.
_add("-d", dest="display", action="store_true", help="display method output even for successful requests")
_add("-f", dest="force", action="store_true", help="don't ask before executing")
_add("-j", dest="json", action="store_true", help="arguments are JSON encoded")
_add("-k", dest="key", default="file:///etc/uauth/psk", help="uAuth PSK")
_add("-p", dest="pretty", action="store_true", help="pretty print JSON output")
_add("-q", dest="quiet", action="store_true", help="don't print success responses")
_add("-Q", dest="quiet_errors", action="store_true", help="don't print error responses")
_add(
    "-s",
    dest="save",
    action="append",
    help="filename to save result in case output is file (can be specified multiple times)",
)
_add("-u", dest="user", default="admin", help="username to execute request as")

args = parser.parse_args()

if args.json:
    # With -j every positional argument is a JSON document. Decode each one
    # separately so a failure can name its 1-based position.
    decoded_args = []
    for position, raw in enumerate(args.arg, start=1):
        try:
            decoded_args.append(json.loads(raw))
        except json.decoder.JSONDecodeError as error:  # noqa: PERF203 - we need to check arguments separately to report index of failed one
            sys.stderr.write(f"Invalid argument {position}: {error}\n")
            sys.exit(1)
    args.arg = decoded_args

if not args.force:
    # Show the call as method(user, arg1, ...). Joining all parts avoids a
    # dangling ", " when the method gets no arguments.
    call_args = ", ".join([args.user, *(repr(arg) for arg in args.arg)])
    try:
        confirm = input(f"Are you sure to execute {args.method}({call_args}) (y/N): ")
    except EOFError:
        # stdin closed (e.g. a non-interactive run without -f): treat as "no"
        # instead of dying with a traceback.
        confirm = ""
    if confirm not in ("y", "Y"):
        sys.exit(1)

try:
    data = ucs_rpc(args.url, args.user, args.key, args.method, *args.arg)
except ApiError as api_error:
    # The API layer already formatted the message and chose the exit status.
    if not args.quiet_errors:
        print(api_error.message, file=sys.stderr)
    sys.exit(api_error.code)
except Exception as unexpected:  # noqa: BLE001 - we would like to show nice error to user in case of any error
    # Anything else (non-HTTP URL, unreadable PSK, non-JSON reply, ...) is
    # reported as a single line instead of a traceback.
    if not args.quiet_errors:
        print(f"FAIL: {unexpected}", file=sys.stderr)
    sys.exit(100)

def _save_blob(blob: bytes, label: str) -> None:
    """Write one decoded base64 value to the next -s filename, or to a temp file.

    Filenames given with -s are consumed in command-line order. When they run
    out, the value goes to a kept temporary file whose path is reported on
    stderr (unless -q), since the user has no other way to find it.
    """
    if args.save:
        with open(args.save.pop(0), "wb") as file:
            file.write(blob)
        return
    with tempfile.NamedTemporaryFile(delete=False) as file:
        file.write(blob)
    if not args.quiet:
        sys.stderr.write(f"{label} saved to {file.name}\n")


if args.save:
    # Ask the API for the method's typing (argument types, return type) to
    # find out which parts of the result are base64 blobs to write to files.
    try:
        _arg_types, return_type = ucs_rpc(args.url, args.user, args.key, "system.methodTyping", args.method)
    except ApiError as error:
        # Report the error the same way as the main call, not as a traceback.
        if not args.quiet_errors:
            sys.stderr.write(f"{error.message}\n")
        sys.exit(error.code)

    if return_type["type"] == "base64":
        _save_blob(base64.b64decode(data), "result")
        sys.exit(0)

    if return_type["type"] == "struct":
        for item_name, item_type in return_type["items"].items():
            if item_type["type"] == "base64":
                # Do not rebind `data`: it must remain the whole struct so
                # later members can still be looked up.
                _save_blob(base64.b64decode(data[item_name]), item_name)
        # NOTE(review): non-base64 struct members are neither saved nor printed.
        sys.exit(0)

if not args.quiet:
    # -p switches from compact single-line JSON to a 2-space indented dump.
    print(json.dumps(data, indent=2 if args.pretty else None))

sys.exit(0)
