import asyncio
import logging
import uuid

import structlog
from asgiref.sync import iscoroutinefunction, markcoroutinefunction
from django.core.exceptions import PermissionDenied
from django.http import Http404, StreamingHttpResponse
from asgiref import sync

from .. import signals
from ..app_settings import app_settings

logger = structlog.getLogger(__name__)


def get_request_header(request, header_key, meta_key):
    """Look up a request header, falling back to ``request.META``.

    :param request: the incoming request object
    :param header_key: key looked up in ``request.headers`` when that
        attribute exists
    :param meta_key: key looked up in ``request.META`` when the request has
        no ``headers`` attribute
    :return: the stored value, or ``None`` when the key is absent
    """
    if not hasattr(request, "headers"):
        # No ``headers`` mapping on this request: read the META entry instead.
        environ = request.META
        return environ.get(meta_key)

    headers = request.headers
    return headers.get(header_key)


def sync_streaming_content_wrapper(streaming_content, context):
    """Re-yield ``streaming_content`` while logging the lifecycle of the stream.

    ``context`` is bound to ``structlog``'s context variables for as long as
    the generator is being consumed, so every event logged here carries it.

    Events emitted:

    * ``streaming_started`` when iteration begins,
    * ``streaming_finished`` once the content is exhausted,
    * ``streaming_failed`` (with traceback) when the content raises.

    :param streaming_content: a synchronous iterable of response chunks
    :param context: mapping of key/values to bind while streaming
    :raises Exception: whatever ``streaming_content`` raised, after logging it
    """
    with structlog.contextvars.bound_contextvars(**context):
        logger.info("streaming_started")
        try:
            for chunk in streaming_content:
                yield chunk
        except Exception:
            logger.exception("streaming_failed")
            # Propagate the failure: swallowing it would end the generator
            # normally and leave the client with a silently truncated body.
            raise
        else:
            logger.info("streaming_finished")


async def async_streaming_content_wrapper(streaming_content, context):
    """Re-yield async ``streaming_content`` while logging the stream's lifecycle.

    ``context`` is bound to ``structlog``'s context variables for as long as
    the async generator is being consumed, so every event logged here carries
    it.

    Events emitted:

    * ``streaming_started`` when iteration begins,
    * ``streaming_finished`` once the content is exhausted,
    * ``streaming_cancelled`` (warning) on ``asyncio.CancelledError``,
    * ``streaming_failed`` (with traceback) when the content raises.

    :param streaming_content: an asynchronous iterable of response chunks
    :param context: mapping of key/values to bind while streaming
    :raises asyncio.CancelledError: re-raised after logging the cancellation
    :raises Exception: whatever ``streaming_content`` raised, after logging it
    """
    with structlog.contextvars.bound_contextvars(**context):
        logger.info("streaming_started")
        try:
            async for chunk in streaming_content:
                yield chunk
        except asyncio.CancelledError:
            logger.warning("streaming_cancelled")
            raise
        except Exception:
            logger.exception("streaming_failed")
            # Propagate the failure, consistently with the cancellation branch:
            # swallowing it would end the generator normally and leave the
            # client with a silently truncated body.
            raise
        else:
            logger.info("streaming_finished")


class RequestMiddleware:
    """Middleware that automatically binds request metadata to ``structlog``'s logger context.

    Enable it by listing it in the Django settings:

    >>> MIDDLEWARE = [
    ...     # ...
    ...     'django_structlog.middlewares.RequestMiddleware',
    ... ]

    """

    sync_capable = True
    async_capable = True

    def __init__(self, get_response):
        """Keep ``get_response`` and mark the instance as a coroutine function when it is one."""
        self.get_response = get_response
        if iscoroutinefunction(get_response):
            markcoroutinefunction(self)

    def __call__(self, request):
        """Handle ``request`` synchronously, or return the ``__acall__`` coroutine in async mode."""
        if not iscoroutinefunction(self):
            self.prepare(request)
            response = self.get_response(request)
            self.handle_response(request, response)
            return response
        return self.__acall__(request)

    async def __acall__(self, request):
        """Async counterpart of ``__call__``; logs ``request_cancelled`` when cancelled."""
        prepare_async = sync.sync_to_async(self.prepare)
        await prepare_async(request)
        try:
            response = await self.get_response(request)
        except asyncio.CancelledError:
            logger.warning("request_cancelled")
            raise
        finish_async = sync.sync_to_async(self.handle_response)
        await finish_async(request, response)
        return response

    def handle_response(self, request, response):
        """Log the outcome of ``request`` and clear the bound context variables.

        When ``process_exception`` flagged the request, only the
        ``update_failure_response`` signal is sent. Otherwise
        ``request_finished`` is logged at a level derived from the status
        code, and a streaming response gets its content wrapped so the
        streaming itself is logged with the captured context.
        """
        if hasattr(request, "_raised_exception"):
            raised = request._raised_exception
            del request._raised_exception
            signals.update_failure_response.send(
                sender=self.__class__,
                request=request,
                response=response,
                logger=logger,
                exception=raised,
            )
        else:
            self.bind_user_id(request)
            # Snapshot the context now; it is re-bound later while streaming.
            bound_context = structlog.contextvars.get_merged_contextvars(logger)

            finished_kwargs = {
                "code": response.status_code,
                "request": self.format_request(request),
            }
            signals.bind_extra_request_finished_metadata.send(
                sender=self.__class__,
                request=request,
                logger=logger,
                response=response,
                log_kwargs=finished_kwargs,
            )

            # Read the status only after the signal has been sent.
            status = response.status_code
            if status < 400:
                level = logging.INFO
            elif status < 500:
                level = app_settings.STATUS_4XX_LOG_LEVEL
            else:
                level = logging.ERROR
            logger.log(level, "request_finished", **finished_kwargs)

            if isinstance(response, StreamingHttpResponse):
                content = response.streaming_content
                try:
                    iter(content)
                except TypeError:
                    # Not a synchronous iterable: use the async wrapper.
                    wrap = async_streaming_content_wrapper
                else:
                    wrap = sync_streaming_content_wrapper
                response.streaming_content = wrap(content, bound_context)

        structlog.contextvars.clear_contextvars()

    def prepare(self, request):
        """Bind request id, user id, correlation id and ip, then log ``request_started``."""
        from ipware import get_client_ip

        incoming_request_id = get_request_header(
            request, "x-request-id", "HTTP_X_REQUEST_ID"
        )
        # Generate an id when the client did not supply one.
        request_id = incoming_request_id or str(uuid.uuid4())
        correlation_id = get_request_header(
            request, "x-correlation-id", "HTTP_X_CORRELATION_ID"
        )

        structlog.contextvars.bind_contextvars(request_id=request_id)
        self.bind_user_id(request)
        if correlation_id:
            structlog.contextvars.bind_contextvars(correlation_id=correlation_id)

        client_ip, _routable = get_client_ip(request)
        structlog.contextvars.bind_contextvars(ip=client_ip)

        started_kwargs = dict(
            request=self.format_request(request),
            user_agent=request.META.get("HTTP_USER_AGENT"),
        )
        signals.bind_extra_request_metadata.send(
            sender=self.__class__,
            request=request,
            logger=logger,
            log_kwargs=started_kwargs,
        )
        logger.info("request_started", **started_kwargs)

    @staticmethod
    def format_request(request):
        """Return ``"<method> <full path>"`` for ``request``."""
        method = request.method
        full_path = request.get_full_path()
        return f"{method} {full_path}"

    @staticmethod
    def bind_user_id(request):
        """Bind ``user_id`` to the context when the request has a user and a field is configured.

        The bound value is ``None`` when the user lacks the configured
        attribute; ``uuid.UUID`` values are converted to ``str``.
        """
        field_name = app_settings.USER_ID_FIELD
        if not hasattr(request, "user") or request.user is None or not field_name:
            return

        if hasattr(request.user, field_name):
            user_id = getattr(request.user, field_name)
            if isinstance(user_id, uuid.UUID):
                user_id = str(user_id)
        else:
            user_id = None
        structlog.contextvars.bind_contextvars(user_id=user_id)

    def process_exception(self, request, exception):
        """Log ``request_failed`` for ``exception`` and flag the request as failed.

        ``Http404`` and ``PermissionDenied`` are ignored here.
        """
        if isinstance(exception, (Http404, PermissionDenied)):
            # Leave the request unflagged and log nothing here, so that the
            # regular `request_finished` message is still emitted.
            return

        request._raised_exception = exception
        self.bind_user_id(request)
        failed_kwargs = {
            "code": 500,
            "request": self.format_request(request),
        }
        signals.bind_extra_request_failed_metadata.send(
            sender=self.__class__,
            request=request,
            logger=logger,
            exception=exception,
            log_kwargs=failed_kwargs,
        )
        logger.exception("request_failed", **failed_kwargs)
