#!/usr/bin/python3.9
# ruff: noqa: T201

"""Show UCS profiling statistics"""

import argparse
import base64
import hashlib
import hmac
import json
import sys
import time
from functools import partial
from pathlib import Path
from typing import Any, Optional
from urllib.error import HTTPError
from urllib.request import Request, urlopen

parser = argparse.ArgumentParser(
    description="Show UCS profiling statistics",
)
# Every option is a single string value (passed straight to ucs_rpc). Do not use
# action="append" here: argparse would call .append() on the str default and
# crash as soon as the option is given on the command line.
parser.add_argument("-a", dest="url", help="UCS API URL", default="http://127.0.0.1:1234/JSON")
parser.add_argument("-k", dest="key", help="uAuth PSK", default="file:///etc/uauth/psk")
parser.add_argument("-u", dest="user", help="username to execute request as", default="admin")
args = parser.parse_args()

def jwt(key: str = "file:///etc/uauth/psk", user: str = "admin") -> str:
    """Create an HS256-signed JSON Web Token for uAuth.

    :param key: pre-shared key, or ``file://<path>`` to load it from a file
        (trailing whitespace of the file content is stripped).
    :param user: value of the ``sub`` claim.
    :return: compact ``header.payload.signature`` token; the ``iat`` claim is
        the current Unix time in whole seconds.
    """
    file_scheme = "file://"
    secret = Path(key[len(file_scheme):]).read_text().rstrip() if key.startswith(file_scheme) else key

    def encode_segment(raw: bytes) -> bytes:
        # base64url with the "=" padding dropped, as JWT segments require
        return base64.urlsafe_b64encode(raw).rstrip(b"=")

    header = json.dumps({"typ": "JWT", "alg": "HS256"}, separators=(",", ":"), sort_keys=True)
    claims = json.dumps({"sub": user, "iat": int(time.time())}, separators=(",", ":"))
    signing_input = encode_segment(header.encode("utf8")) + b"." + encode_segment(claims.encode("utf8"))
    digest = hmac.new(secret.encode("utf-8"), signing_input, hashlib.sha256).digest()
    return (signing_input + b"." + encode_segment(digest)).decode("utf-8")


class ApiError(Exception):
    """Failure of a UCS API call, as raised by ``ucs_rpc``.

    Attributes:
        message: human readable description (also the ``str()`` of the error).
        code: numeric error code assigned by ``ucs_rpc``.
        data: optional extra details returned by the API, or ``None``.
    """

    def __init__(self, message: str, code: int, data: Optional[dict[str, str]] = None):
        self.message = message
        self.code = code
        self.data = data
        super().__init__(message)


def ucs_rpc(url: str, user: str, key: str, method: str, *args: Any) -> Any:  # noqa: ANN401 - we really don't know the type
    """Call a UCS JSON-RPC method and return the decoded JSON response.

    The request body is ``[args, {}, {}]`` and is authenticated with a uAuth
    JWT built from ``key`` and ``user``.

    :param url: base API URL; must use the ``http://`` or ``https://`` scheme.
    :param user: username the request is executed as.
    :param key: uAuth pre-shared key, or ``file://<path>`` to read it from a file.
    :param method: RPC method name, appended to ``url`` as the last path segment.
    :param args: positional arguments of the RPC method.
    :return: the decoded JSON response body.
    :raises OSError: if ``url`` is not HTTP(S), or if the key file cannot be read.
    :raises ApiError: on an HTTP error response (``code`` is the UCS error code
        capped at 90, or 98 if the error body is not a UCS error), on a success
        response that is not valid JSON (98), or on a connection failure (99).
    """
    url = f'{url.removesuffix("/")}/{method}'
    # Check the scheme first: urlopen() would also accept e.g. file:// URLs, and
    # there is no point reading the key for a URL we are going to reject.
    if not url.startswith(("http://", "https://")):
        message = "API URL must be HTTP or HTTPS"
        raise OSError(message)

    body = json.dumps([args, {}, {}]).encode("utf-8")
    headers = {
        "Authorization": f"uAuth {jwt(key, user)}",
        "Content-Type": "application/json",
    }
    req = Request(url, body, headers)  # noqa: S310 - scheme is checked above
    try:
        with urlopen(req) as response:  # noqa: S310 - scheme is checked above
            # Read through the response object rather than its raw .fp so that
            # chunked transfer encoding is decoded.
            return json.load(response)

    except HTTPError as error:
        try:
            ucs_error = json.load(error.fp)
            message = [f'UCS FAIL: {ucs_error["message"]}']
            error_data = ucs_error.get("data")
            # Code 0 is an input validator error; its data is presumably a
            # field -> problem mapping. It may also be missing (None).
            if ucs_error["code"] == 0 and isinstance(error_data, dict):
                message.extend(f" * {k}: {v}" for k, v in error_data.items())
            code = min(abs(ucs_error["code"]), 90)
        except (ValueError, KeyError, TypeError, AttributeError):
            # Error body missing, not JSON, or not shaped like a UCS error.
            message = f"API FAIL: {error}"
            raise ApiError(message, code=98) from error
        raise ApiError("\n".join(message), code=code, data=error_data) from error

    except OSError as error:
        message = f"HTTP FAIL: {error}"
        raise ApiError(message, code=99) from error

    except ValueError as error:
        # The server answered successfully but the body is not valid JSON.
        message = f"API FAIL: {error}"
        raise ApiError(message, code=98) from error


ucs = partial(ucs_rpc, args.url, args.user, args.key)

try:
    stats = ucs("system.multimeter.stats")
except ApiError as error:
    # ApiError.code is a non-negative number chosen per failure kind (UCS error
    # code capped at 90, 98 for unparsable errors, 99 for connection failures);
    # use it as the exit status. It can be 0 for input validator errors, which
    # must not be reported as success.
    print(error, file=sys.stderr)
    sys.exit(error.code or 1)
except OSError as error:
    # Raised before any request is sent, e.g. a non-HTTP(S) API URL or an
    # unreadable key file.
    sys.exit(f"FAIL: {error}")


def key_value(data: dict[str, Any], mapping: dict[str, Any]) -> list[str]:
    """Format selected entries of ``data`` as aligned ``title: value`` lines.

    :param data: statistics dict. The integer part of its numeric values sets
        a common value width, so all numbers line up on the decimal point.
    :param mapping: display title -> either a key of ``data``, or a
        ``(key, retype, unit)`` tuple. ``retype`` converts the value before
        formatting, and ``unit`` is appended after it.
    :return: one line per ``mapping`` entry, in mapping order; titles are
        right-aligned, ints shown as integers, anything else with 3 decimals.
    """
    # default=0 keeps an empty mapping or non-numeric data from raising ValueError
    longest_title = max((len(title) for title in mapping), default=0)
    longest_value = max(
        (len(str(int(d))) for d in data.values() if isinstance(d, (int, float))),
        default=0,
    )
    out = []
    for title, spec in mapping.items():
        if isinstance(spec, tuple):
            key, retype, unit = spec
            value = retype(data[key])
        else:
            key, unit = spec, ""
            value = data[key]
        # ints get the integer width; floats get that width plus ".xxx"
        value_format = f"{longest_value}d" if isinstance(value, int) else f"{longest_value + 4}.3f"
        out.append(f"{title:>{longest_title}}: {value:{value_format}}{unit}")
    return out


def table(data: list[list[Any]], header: list[tuple[str, type, str]]) -> list[str]:
    """Render rows as text columns under a centered title line.

    :param data: rows; ``row[i]`` is the value for column ``i`` of ``header``.
    :param header: one ``(title, retype, unit)`` per column. ``retype`` selects
        the formatting: ``int`` -> right-aligned integer, ``float`` ->
        right-aligned with 3 decimals, anything else -> left-aligned text.
        ``unit`` is appended after every value of the column.
    :return: the title line followed by one line per row.
    """

    def render(value: object, retype: type) -> str:
        """Format one cell value without any padding."""
        if retype is int:
            return f"{value:d}"
        if retype is float:
            return f"{value:.3f}"
        return f"{value}"

    cells = [[render(row[i], retype) for i, (_title, retype, _unit) in enumerate(header)] for row in data]

    # Size each column by its widest *rendered* value (decimals included), so
    # large floats cannot spill into the neighbouring column.
    widths = [
        max([len(title), *(len(row[i]) for row in cells)])
        for i, (title, _retype, _unit) in enumerate(header)
    ]

    out = [" ".join(f"{title:^{width + len(unit)}}" for (title, _retype, unit), width in zip(header, widths))]
    for row in cells:
        line = []
        for text, (_title, retype, unit), width in zip(row, header, widths):
            padded = text.rjust(width) if retype in (int, float) else text.ljust(width)
            line.append(f"{padded}{unit}")
        out.append(" ".join(line))
    return out


def profile_api(stats: dict[str, Any], top: int = 10) -> list[tuple[str, float, int, float]]:
    """Build table rows for the most frequently called API methods.

    :param stats: API statistics; ``stats["profile"]`` maps a method name to a
        ``(call count, total time, slow call count)`` sequence.
    :param top: how many methods to keep.
    :return: ``(method, total time, count, average time)`` rows ordered by total
        time, descending. The average is 0.0 for a method with no calls.
    """
    # NOTE(review): methods are *selected* by call count and only then ordered
    # by total time, so an expensive but rarely called method can be left out.
    most_used = sorted(stats["profile"].items(), key=lambda item: item[1][0], reverse=True)[:top]
    most_used.sort(key=lambda item: item[1][1], reverse=True)
    return [
        (method, total_time, count, total_time / count if count else 0.0)
        for method, (count, total_time, _slow_count) in most_used
    ]


def profile_callback(
    stats: dict[str, Any], top: int = 10, *, details: bool = True,
) -> list[tuple[str, float, int, float]]:
    """Build table rows for the events whose callbacks took the most time.

    :param stats: callback statistics; ``stats["profile"]`` maps an event name
        to a dict of callback name -> ``(calls, total time, slow call count)``.
    :param top: how many events to keep, ranked by total callback time.
    :param details: also emit one indented row per callback after its event.
    :return: event rows ``(event, total time, occurrences, time per
        occurrence)``, each optionally followed by callback rows
        ``("  callback", total time, -slow call count, time per call)``.
        Averages are 0.0 when there were no calls.
    """
    profile = stats["profile"]
    events = []
    for event, callbacks in profile.items():
        total_time = sum(callback_time for _calls, callback_time, _slow in callbacks.values())
        # Presumably every callback of an event runs once per occurrence, so the
        # highest per-callback call count is taken as the number of occurrences
        # (independent of callback order, and robust to late registration).
        occurrences = max((calls for calls, _time, _slow in callbacks.values()), default=0)
        average = total_time / occurrences if occurrences else 0.0
        events.append((event, total_time, occurrences, average))

    events.sort(key=lambda row: row[1], reverse=True)

    data = []
    for row in events[:top]:
        data.append(row)
        if details:
            for callback, (calls, callback_time, slow_count) in profile[row[0]].items():
                per_call = callback_time / calls if calls else 0.0
                # slow count is negated so it is distinguishable in the Count column
                data.append((f"  {callback}", callback_time, -slow_count, per_call))
    return data


# Number of rows shown in each "Top N" table below.
top = 10

print(f'{" SQL ":=^80}')
print("\n".join(key_value(stats["sql"], {
    "Total queries": "count",
    "SELECT count": "select",
    "INSERT count": "insert",
    "UPDATE count": "update",
    "DELETE count": "delete",
    "Slow query count": "slow_count",
    "Slowest query last minute": ("slowest_1", float, " s"),
    "Slowest query last 5 minutes": ("slowest_5", float, " s"),
    "Slowest query last 15 minutes": ("slowest_15", float, " s"),
})))
print()

print(f'{" API ":=^80}')
print("\n".join(key_value(stats["api"], {
    "Total calls": "count",
    "Slow call count": "slow_count",
    "Slowest call last minute": ("slowest_1", float, " s"),
    "Slowest call last 5 minutes": ("slowest_5", float, " s"),
    "Slowest call last 15 minutes": ("slowest_15", float, " s"),
})))
data = profile_api(stats["api"], top)
title = f"Top {top} methods which took most time to execute"
print()
print(f"{title:-^80}")
print("\n".join(table(data, (
    ("Method", str, ""),
    ("Total time", float, " s"),
    ("Count", int, ""),
    ("Avg. time", float, " s"),
))))
print()

print(f'{" Callbacks ":=^80}')
print("\n".join(key_value(stats["callback"], {
    "Total events": "count",
    "Slow call count": "slow_count",
    "Slowest call last minute": ("slowest_1", float, " s"),
    "Slowest call last 5 minutes": ("slowest_5", float, " s"),
    "Slowest call last 15 minutes": ("slowest_15", float, " s"),
})))
data = profile_callback(stats["callback"], top)
title = f"Top {top} callbacks which took most time to execute"
print()
print(f"{title:-^80}")
print("\n".join(table(data, (
    ("Event", str, ""),
    ("Total time", float, " s"),
    ("Count", int, ""),
    ("Avg. time", float, " s"),
))))
print()
print("Count column contains total number of event occurrences and")
print("number of slow calls for callback (denoted by minus sign).")
