# ************************************************************
# Copyright © 2003-2025 Acronis International GmbH.
# This source code is distributed under MIT software license.
# ************************************************************

import logging
import sqlite3
import jwt
import json
from base64 import b64decode
from api_client import ApiClient
from uuid import uuid4
from aiohttp import web, ClientSession
from typing import Optional, TypedDict, Literal
from datetime import datetime
from datatypes import *
from constants import JWT_AUD, APPCODE
from server.callbacks import CALLBACKS_MAPPING
from utils import verify

class JWK(TypedDict):
    """Typed shape of one JSON Web Key (RFC 7517) entry of the IdP key set."""
    use: str  # intended key use, e.g. "sig"
    kty: Literal['RSA']  # key type
    kid: str  # key ID; compared with the "kid" of an incoming JWT header
    # JWA algorithm identifiers (RFC 7518). 'RSA256'/'HSA256' are not valid JWA names.
    alg: Literal['RS256', 'HS256']
    n: str  # RSA modulus (base64url-encoded)
    e: str  # RSA public exponent (base64url-encoded)

def _get_mirrored_organization(db: sqlite3.Connection, tenant_id: str) -> Optional[dict]:
    row = db.execute('SELECT id FROM organizations WHERE external_id = ?', (tenant_id,)).fetchone()
    if not row:
        return
    return dict(row)

def _get_authenticated_user(db: sqlite3.Connection, identity: str, password: str) -> Optional[dict]:
    row = db.execute('SELECT id, organization_id, password FROM users WHERE login = ? AND password IS NOT NULL', (identity.lower(),)).fetchone()
    verify(row['password'], password)
    return { 'id': row['id'], 'organization_id': row['organization_id'] }

async def _is_valid_jwt_token(app: web.Application, authorization: str, endpoint_id: str, retry: bool = True) -> bool:
    """Validate the bearer JWT in an ``Authorization`` header for a callback endpoint.

    The token is accepted only if all of the following hold:

    - it uses the ``Bearer`` scheme;
    - its header names a ``kid``/``alg`` pair present in the cached key set ``app['jwks']``;
    - it verifies as RS256 against that key with audience ``JWT_AUD``;
    - it carries an unexpired ``exp`` claim;
    - the ``role`` of its first ``scope`` entry equals ``endpoint_id``.

    When no cached key matches and ``retry`` is true, the key set is re-downloaded from
    ``https://{JWT_AUD}/api/idp/v1/keys`` into ``app['jwks']``. The check then runs once
    more with ``retry=False``.

    :param app: Application holding the cached key set under ``'jwks'`` (a dict with a ``'keys'`` list).
    :param authorization: Raw ``Authorization`` header value.
    :param endpoint_id: Endpoint ID from the callback context.
    :param retry: Whether a key-set refresh is allowed after a key miss.
    :return: ``True`` if every check passes, otherwise ``False`` (the reason is logged).
    :raises aiohttp.ClientResponseError: If the key-set refresh gets an error status.
    """
    import time

    scheme, _, token = authorization.partition(' ')
    token = token.strip()
    # Auth scheme names are case-insensitive; any other scheme cannot carry this JWT.
    if scheme.lower() != 'bearer' or not token:
        logging.error('Authorization header does not contain a Bearer token.')
        return False

    try:
        header = jwt.get_unverified_header(token)
    except jwt.PyJWTError as e:
        # A garbled token is a client error: report "invalid" rather than crash with a 500.
        logging.error(f"Error reading token header: {e}")
        return False

    kid = header.get('kid')
    alg = header.get('alg')
    if kid is None or alg is None:
        logging.error('Token header is missing "kid" or "alg".')
        return False

    jwk = next(
        (key for key in app['jwks']['keys'] if key.get('kid') == kid and key.get('alg') == alg),
        None,
    )

    if not jwk:
        if retry:
            # The IdP may have rotated its signing keys since they were cached: refresh once.
            async with ClientSession() as client:
                async with client.get(f'https://{JWT_AUD}/api/idp/v1/keys') as res:
                    res.raise_for_status()
                    app['jwks'] = await res.json()
            return await _is_valid_jwt_token(app, authorization, endpoint_id, retry=False)

        logging.error(f"Failed to find corresponding JWK: {kid} with algorithm {alg}")
        return False

    try:
        public_key = jwt.algorithms.RSAAlgorithm.from_jwk(jwk)
        # PyJWT verifies signature, audience and expiry; "require" also rejects tokens without "exp".
        decoded = jwt.decode(token, public_key, algorithms=["RS256"], audience=JWT_AUD, options={"require": ["exp"]})
    except jwt.PyJWTError as e:
        logging.error(f"Error decoding token: {e}")
        return False

    # time.time() is true epoch seconds. datetime.utcnow().timestamp() is off by the host's
    # UTC offset, because naive datetimes are interpreted as local time.
    if time.time() > decoded['exp']:
        logging.error("Token is expired.")
        return False

    try:
        role = decoded['scope'][0]['role']
    except (KeyError, IndexError, TypeError):
        logging.error('Token scope does not contain an endpoint role.')
        return False

    if endpoint_id != role:
        logging.error(f"Context endpoint ID does not match the token scope. Received: {role}, expected: {endpoint_id}")
        return False

    return True

async def index(_: web.Request) -> web.Response:
    """Serve a plain-text greeting that points clients at the callback endpoint."""
    greeting = 'Hello there! Send POST requests to the /callback endpoint.'
    return web.Response(content_type='text/plain', text=greeting)


async def callback_handler_mirroring(request: web.Request) -> web.Response:
    """Handle a platform callback for the organization-mirroring flow.

    Steps:

    1. Require a JSON body.
    2. Read ``request_id`` and ``context.callback_id`` / ``endpoint_id`` / ``tenant_id``.
    3. Reject callbacks not registered in ``CALLBACKS_MAPPING``.
    4. Validate the ``Authorization`` bearer token against the endpoint.
    5. Run the callback inside a database transaction. It receives the ID of the local
       organization mirrored from ``tenant_id``, or ``None`` if there is none.

    :return: The callback's response, or an error response:

        - 400 for a non-JSON or malformed request or an unknown callback;
        - 401 for a missing or invalid token;
        - 500 if the callback fails.
    """
    if not request.content_type.startswith('application/json'):
        logging.error('Received non-JSON request')
        return web.Response(status=400)

    response_id = str(uuid4())
    try:
        # Parsing inside the try turns an invalid JSON body into a 400 instead of a 500.
        data: CallbackRequest = await request.json()
        logging.info(f'Received data {data}')
        request_id = data['request_id']
        context = data['context']
        callback_id = context['callback_id']
        if callback_id not in CALLBACKS_MAPPING:
            logging.error(f'Callback {callback_id} not found.')
            return web.json_response(status=400, data={'request_id': request_id, 'response_id': response_id, 'message': 'Callback not found.'})
        endpoint_id = context['endpoint_id']
        tenant_id = context['tenant_id']
    except Exception as e:
        logging.error('Received malformed callback request.')
        logging.exception(e)
        return web.json_response(status=400, data={'response_id': response_id, 'message': 'Received malformed callback request.'})

    authorization = request.headers.get('Authorization')
    if authorization is None:
        return web.json_response(status=401, data={'request_id': request_id, 'response_id': response_id, 'message': 'Missing authorization token.'})
    if not await _is_valid_jwt_token(request.app, authorization, endpoint_id):
        return web.json_response(status=401, data={'request_id': request_id, 'response_id': response_id, 'message': 'Failed to validate authorization token.'})

    payload = data.get('payload', {})
    try:
        # Let failures escape the ``with`` block. Assuming app['db'] is a sqlite3.Connection,
        # its context manager then rolls back; catching inside would commit partial writes.
        with request.app['db'] as conn:
            row = _get_mirrored_organization(conn, tenant_id)
            organization_id = row['id'] if row else None
            res = CALLBACKS_MAPPING[callback_id](conn, organization_id, request_id, response_id, context, payload)
        logging.info(f'Response data: {res.body}')
    except Exception as e:
        logging.error('Failed to make proper response.')
        logging.exception(e)
        return web.json_response(status=500, data={'request_id': request_id, 'response_id': response_id, 'message': 'Failed to make proper response.'})
    return res


async def callback_handler_mapping(request: web.Request) -> web.Response:
    """Handle a platform callback for the account-mapping flow.

    Steps:

    1. Require a JSON body.
    2. Read ``request_id`` and ``context.callback_id`` / ``endpoint_id``.
    3. Reject callbacks not registered in ``CALLBACKS_MAPPING``.
    4. Validate the ``Authorization`` bearer token against the endpoint.
    5. Authenticate a local user from the ``X-CyberApp-Auth`` header, which is
       base64 of ``"<login>:<JSON object with a 'password' key>"``.
    6. Run the callback inside a database transaction for that user's organization.

    :return: The callback's response, or an error response:

        - 400 for a non-JSON or malformed request or an unknown callback;
        - 401 for a missing or invalid token or failed user authentication;
        - 500 if the callback fails.
    """
    if not request.content_type.startswith('application/json'):
        logging.error('Received non-JSON request')
        return web.Response(status=400)

    response_id = str(uuid4())
    try:
        # Parsing inside the try turns an invalid JSON body into a 400 instead of a 500.
        data: CallbackRequest = await request.json()
        logging.info(f'Received data {data}')
        request_id = data['request_id']
        callback_id = data['context']['callback_id']
        if callback_id not in CALLBACKS_MAPPING:
            logging.error(f'Callback {callback_id} not found.')
            return web.json_response(status=400, data={'request_id': request_id, 'response_id': response_id, 'message': 'Callback not found.'})
        endpoint_id = data['context']['endpoint_id']
    except Exception as e:
        logging.error('Received malformed callback request.')
        logging.exception(e)
        return web.json_response(status=400, data={'response_id': response_id, 'message': 'Received malformed callback request.'})

    authorization = request.headers.get('Authorization')
    if authorization is None:
        return web.json_response(status=401, data={'request_id': request_id, 'response_id': response_id, 'message': 'Missing authorization token.'})
    if not await _is_valid_jwt_token(request.app, authorization, endpoint_id):
        return web.json_response(status=401, data={'request_id': request_id, 'response_id': response_id, 'message': 'Failed to validate authorization token.'})

    try:
        raw_creds = b64decode(request.headers['X-CyberApp-Auth']).decode()
        # Split on the first ':' only; the JSON part may itself contain colons.
        identity, sep, raw_credentials = raw_creds.partition(':')
        if not sep:
            raise ValueError('X-CyberApp-Auth value has no ":" separator.')
        credentials = json.loads(raw_credentials)
        # NOTE: the X-CyberApp-Extra header (base64 JSON) is currently not used.
        with request.app['db'] as conn:
            user = _get_authenticated_user(conn, identity, credentials['password'])
    except Exception as e:
        logging.error(f'Failed to authenticate user. Reason: {e}')
        logging.exception(e)
        return web.json_response(status=401, data={'request_id': request_id, 'response_id': response_id, 'message': 'Failed to authenticate user.'})

    payload = data.get('payload', {})
    try:
        # Let failures escape the ``with`` block. Assuming app['db'] is a sqlite3.Connection,
        # its context manager then rolls back; catching inside would commit partial writes.
        with request.app['db'] as conn:
            res = CALLBACKS_MAPPING[callback_id](conn, user['organization_id'], request_id, response_id, data['context'], payload)
        logging.info(f'Response data: {res.body}')
    except Exception as e:
        # The reason is logged only; echoing exception text to the client leaks internals.
        logging.error(f'Failed to make proper response. Reason: {e}')
        logging.exception(e)
        return web.json_response(status=500, data={'request_id': request_id, 'response_id': response_id, 'message': 'Failed to make proper response.'})
    return res


async def ott_login(request: web.Request) -> web.Response:
    """Redirect to a one-time-token login URL for a service user.

    Requires ``tenant_id`` and ``external_id`` in the query string and answers 400
    without them. The application's API client is authenticated first; on failure the
    response is 401. The client is then asked for a login location for ``APPCODE``,
    which is returned as a 302 redirect.
    """
    api: ApiClient = request.app['client']

    query = request.query
    has_required_params = 'tenant_id' in query and 'external_id' in query
    if not has_required_params:
        return web.Response(status=400, text='Missing tenant_id or external_id in the query string.')

    try:
        await api.authenticate()
    except Exception as err:
        logging.info('Failed to authenticate')
        logging.exception(err)
        return web.Response(status=401)

    redirect_to = await api.service_user_login(APPCODE, query['tenant_id'], query['external_id'])
    return web.Response(status=302, headers={'Location': redirect_to})
