import abc
import builtins
import importlib
import inspect
import logging
import pickle
import types
from dataclasses import dataclass
from typing import Any, Callable, Optional

import torch
import torch.fx
from torch._dynamo.precompile_context import PrecompileContext

from . import convert_frame
from .hooks import Hooks


# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


class SerializableCallable(abc.ABC):
    """Abstract interface for compiled callables that round-trip through bytes.

    Subclasses provide a pair of classmethods: one turning an instance into
    bytes and one rebuilding a callable from those bytes.
    """

    @classmethod
    @abc.abstractmethod
    def serialize_compile_artifacts(cls, fn: Any) -> bytes:
        """Encode ``fn`` as bytes."""
        ...

    @classmethod
    @abc.abstractmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild a callable from bytes made by ``serialize_compile_artifacts``."""
        ...


def bind_locals(
    signature: inspect.Signature, *args: Any, **kwargs: Any
) -> dict[str, Any]:
    """Map call arguments onto the parameter names of ``signature``.

    Omitted parameters are filled with their defaults. A ``TypeError`` from
    ``Signature.bind`` propagates when the arguments do not fit.
    """
    binding = signature.bind(*args, **kwargs)
    binding.apply_defaults()
    name_to_value = binding.arguments
    return name_to_value


@dataclass
class CompileArtifacts:
    """Everything needed to rebuild and guard an ahead-of-time compiled function.

    Produced by ``aot_compile_fullgraph`` and consumed by
    ``AOTCompiledFunction``.
    """

    # Signature of the original function; call arguments are bound against it
    # to build the locals that the guards are checked on.
    signature: inspect.Signature
    # Dynamo-transformed bytecode that is run in place of the original code.
    bytecode: types.CodeType
    # Live guard manager. It is None after deserialization, in which case it
    # is rebuilt from ``guards_state``.
    guard_manager: Optional[torch._dynamo.guards.GuardManagerWrapper]
    # Pickled guard state used to rebuild ``guard_manager``.
    guards_state: bytes
    # Maps each alias used by ``bytecode`` to the name of the module to import.
    import_sources: dict[str, str]
    # Global name under which ``compiled_fn`` is exposed to ``bytecode``.
    backend_id: str
    # Backend-compiled callable, exposed to ``bytecode`` as global ``backend_id``.
    compiled_fn: SerializableCallable
    # Code object of the original (untransformed) function; used when
    # rebuilding the guard manager.
    original_code: types.CodeType
    # Closure of the original function (its ``__closure__``), or None.
    closure: Optional[tuple[Any, ...]]


@dataclass
class AOTCompiledFunction:
    """A guarded, serializable callable built from ``CompileArtifacts``.

    On construction it materializes a Python function from the transformed
    bytecode and, when the guard manager is missing (after deserialization),
    rebuilds it from the pickled guard state. Calling the object checks the
    guards against the call arguments before running the compiled function.
    """

    _artifacts: CompileArtifacts

    def guard_check(self, *args: Any, **kwargs: Any) -> bool:
        """Return True if the given call arguments pass the recorded guards."""
        f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
        assert self._artifacts.guard_manager is not None
        return self._artifacts.guard_manager.check(f_locals)

    def __post_init__(self) -> None:
        # Globals for the transformed bytecode: each imported module under the
        # alias the bytecode uses, plus the compiled backend callable under its
        # backend id.
        import_sources = {
            alias: importlib.import_module(module_name)
            for alias, module_name in self._artifacts.import_sources.items()
        }
        f_globals = {
            **import_sources,
            self._artifacts.backend_id: self._artifacts.compiled_fn,
        }
        self.fn = types.FunctionType(
            self._artifacts.bytecode, f_globals, closure=self._artifacts.closure
        )

        # ``serialize`` drops the live guard manager, so rebuild it here from
        # the pickled guards state against the same globals.
        if self._artifacts.guard_manager is None:
            guards_state = pickle.loads(self._artifacts.guards_state)
            self._artifacts.guard_manager = torch._dynamo.guards.CheckFunctionManager(
                self._artifacts.original_code,
                guards_state.output_graph,
                shape_code_parts=guards_state.shape_code_parts,
                runtime_global_scope=f_globals,
            ).guard_manager

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the compiled function after checking its guards.

        Raises:
            RuntimeError: If the arguments fail the guard check; the message
                carries the verbose failure reason.
        """
        assert self._artifacts.guard_manager is not None
        if not self.guard_check(*args, **kwargs):
            f_locals = bind_locals(self._artifacts.signature, *args, **kwargs)
            reason = str(self._artifacts.guard_manager.check_verbose(f_locals))
            raise RuntimeError(f"GuardManager check failed, reason: {reason}")
        return self.fn(*args, **kwargs)

    def save_compiled_function(self, path: str) -> None:
        """Write ``serialize(self)`` to the file at ``path``."""
        with open(path, "wb") as f:
            f.write(type(self).serialize(self))

    @staticmethod
    def _closure_to_values(
        closure: Optional[tuple[Any, ...]],
    ) -> Optional[tuple[Any, ...]]:
        """Replace closure cells by their contents so the closure can be pickled.

        The standard ``pickle`` module cannot pickle cell objects. Every cell is
        assumed to be filled (``aot_compile_fullgraph`` already reads all of them).
        """
        if closure is None:
            return None
        return tuple(cell.cell_contents for cell in closure)

    @staticmethod
    def _values_to_closure(
        values: Optional[tuple[Any, ...]],
    ) -> Optional[tuple[Any, ...]]:
        """Inverse of ``_closure_to_values``: wrap each value in a fresh cell."""
        if values is None:
            return None
        return tuple(types.CellType(value) for value in values)

    @classmethod
    def serialize(cls, fn: "AOTCompiledFunction") -> bytes:
        """Serialize ``fn`` to bytes.

        Code objects go through ``SerializedCode``. The compiled function is
        stored as its type's deserializer plus its serialized bytes. Closure
        cells are stored as plain values. The live guard manager is dropped and
        rebuilt on load.
        """
        from torch._dynamo.package import SerializedCode

        state = fn._artifacts.__dict__.copy()
        state["guard_manager"] = None
        state["bytecode"] = SerializedCode.from_code_object(state["bytecode"])
        compiled_fn = state["compiled_fn"]
        state["compiled_fn"] = (
            type(compiled_fn).deserialize_compile_artifacts,
            type(compiled_fn).serialize_compile_artifacts(compiled_fn),
        )
        state["original_code"] = SerializedCode.from_code_object(state["original_code"])
        # Cells are not picklable; store their contents instead.
        state["closure"] = cls._closure_to_values(state["closure"])
        return pickle.dumps(state)

    @classmethod
    def deserialize(cls, data: bytes) -> "AOTCompiledFunction":
        """Rebuild an ``AOTCompiledFunction`` from bytes produced by ``serialize``.

        Warning: this unpickles ``data``. Only load data from a trusted source.
        """
        from torch._dynamo.package import SerializedCode

        # NOTE(security): pickle.loads can execute arbitrary code; never call
        # this on untrusted input.
        state = pickle.loads(data)
        state["bytecode"] = SerializedCode.to_code_object(state["bytecode"])
        deserializer, compiled_fn_state = state["compiled_fn"]
        state["compiled_fn"] = deserializer(compiled_fn_state)
        state["original_code"] = SerializedCode.to_code_object(state["original_code"])
        # Re-wrap stored closure values into cells for types.FunctionType.
        state["closure"] = cls._values_to_closure(state["closure"])

        artifacts = CompileArtifacts(**state)
        return cls(artifacts)


class BundledAOTAutogradSerializableCallable(SerializableCallable):
    """
    Represents a serializable callable generated by compile_fx.
    This class wraps around the compiled function generated by AOTAutograd.
    Attribute lookups that fail on the wrapper are forwarded to the wrapped
    compiled function.

    TODO: Instead of using PrecompileContext to grab it from AOTAutograd,
    this object should be what's *returned* by aot_module_simplified.
    We'll do that refactor in a later PR.
    """

    def __init__(self, artifact: Any) -> None:
        """
        Takes in a BundledAOTAutogradCacheArtifact, which is the serialized form
        of a compiled function generated by AOTAutograd.
        """

        # Live callable materialized from the artifact.
        self.compiled_fn = artifact.after_deserialization()
        # Serialized payload, returned verbatim by serialize_compile_artifacts.
        self.data = artifact.content

    def __getattr__(self, attr: str) -> Any:
        """Forward unknown attribute lookups to the wrapped compiled function.

        Python calls ``__getattr__`` only after normal lookup has failed, so
        probing ``self`` for ``attr`` here (e.g. via ``hasattr``) would re-enter
        this method and recurse forever. ``compiled_fn`` is read from the
        instance ``__dict__`` directly. An instance created via ``__new__``
        without ``__init__`` then raises ``AttributeError`` instead of recursing.
        """
        try:
            compiled_fn = self.__dict__["compiled_fn"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(compiled_fn, attr)

    @classmethod
    def from_backend_id(
        cls, backend_id: str
    ) -> "BundledAOTAutogradSerializableCallable":
        """
        Takes in a backend_id, and returns a BundledAOTAutogradSerializableCallable
        that wraps around the compiled function generated by AOTAutograd.

        Raises:
            RuntimeError: If PrecompileContext has no artifact for ``backend_id``.
        """
        artifact = PrecompileContext.serialize_artifact_by_key(backend_id)
        if artifact is None:
            raise RuntimeError("No artifact found for backend_id: " + backend_id)
        return cls(artifact)

    @classmethod
    def serialize_compile_artifacts(
        cls, fn: "BundledAOTAutogradSerializableCallable"
    ) -> bytes:
        """Return the serialized payload captured at construction time."""
        return fn.data

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> Any:
        """Rebuild the wrapper from the payload produced by serialize_compile_artifacts."""
        from torch._functorch._aot_autograd.autograd_cache import (
            BundledAOTAutogradCacheArtifact,
        )

        # The key in the artifact is not important here since we're not populating a cache,
        # we just want to grab the callable back out of the serialized entry
        artifact = BundledAOTAutogradCacheArtifact("", data)
        return cls(artifact)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped compiled function."""
        return self.compiled_fn(*args, **kwargs)


def aot_compile_fullgraph(
    model: Any,
    example_inputs: tuple[tuple[Any, ...], dict[str, Any]],
    hooks: Hooks,
    backend: Callable[[torch.fx.GraphModule, list[torch.Tensor]], SerializableCallable],
) -> AOTCompiledFunction:
    """Trace ``model`` with Dynamo in fullgraph mode and compile it ahead of time.

    Args:
        model: A plain Python function or a bound method. For a bound method,
            the bound ``__self__`` is prepended to the positional example args.
        example_inputs: ``(args, kwargs)`` used to trace and compile ``model``.
        hooks: Dynamo hooks. If it has no ``guard_filter_fn``, a default filter
            is used that drops global guards and guard types that cannot be
            serialized. The caller's ``hooks`` object is not modified.
        backend: Compiler applied to the captured graph. Its result must be a
            ``SerializableCallable``. For the Inductor wrapper, the result is
            instead fetched from ``PrecompileContext``.

    Returns:
        An ``AOTCompiledFunction`` bundling the transformed bytecode, guards
        and compiled backend callable.

    Raises:
        RuntimeError: If ``model`` is neither a function nor a bound method,
            or if the compiled result is not a ``SerializableCallable``.
    """
    import copy

    from torch._dynamo.guards import CheckFunctionManager
    from torch._dynamo.utils import dynamo_timed, get_metrics_context
    from torch._guards import compile_context, CompileContext, TracingContext

    args, kwargs = example_inputs
    # inspect.ismethod (rather than hasattr(model, "__self__")) so that builtin
    # methods, which have __self__ but no __func__, hit the RuntimeError below.
    if inspect.ismethod(model):
        fn = model.__func__
        args = (model.__self__,) + args
    elif inspect.isfunction(model):
        fn = model
    else:
        raise RuntimeError(f"Unsupported model code type {model}")

    signature = inspect.signature(fn)
    f_locals = bind_locals(signature, *args, **kwargs)
    if fn.__code__.co_freevars or fn.__closure__:
        assert len(fn.__closure__) == len(fn.__code__.co_freevars)
        # Expose closure variables to the tracer alongside the bound arguments.
        f_locals.update(
            {
                name: cell.cell_contents
                for name, cell in zip(fn.__code__.co_freevars, fn.__closure__)
            }
        )

    with (
        compile_context(CompileContext(convert_frame.get_compile_id({}))),
        get_metrics_context(),
        dynamo_timed("fullgraph_capture"),
    ):
        capture_output = convert_frame.fullgraph_capture(
            convert_frame.FrameInfo(
                fn.__code__,
                fn.__globals__,
                f_locals,
                builtins.__dict__,
                closure=fn.__closure__ or (),  # type: ignore[arg-type]
            )
        )
        dynamo_output = capture_output.dynamo_output

        if not hooks.guard_filter_fn:
            from torch._dynamo.types import GuardFilterEntry

            def new_guard_filter_fn(
                guard_entries: list[GuardFilterEntry],
            ) -> list[bool]:
                # Keep a guard only if it is not on a global and its type
                # supports serialization.
                return [
                    not (
                        g.is_global
                        or g.guard_type
                        in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
                    )
                    for g in guard_entries
                ]

            # Install the default filter on a copy: the caller's hooks object
            # may be reused elsewhere, where this filter must not apply.
            hooks = copy.copy(hooks)
            hooks.guard_filter_fn = new_guard_filter_fn

        check_fn = dynamo_output.build_guards(
            fn.__code__, hooks=hooks, save=True, strict_error=True
        )

        assert check_fn.guards_state is not None

    backend_input = capture_output.backend_input
    backend_input.graph_module._backend_id = backend_input.backend_id  # type: ignore[assignment]
    output_graph = dynamo_output.tracer_output.output_graph
    assert output_graph is not None
    import_sources = output_graph.import_sources
    with (
        torch._guards.tracing(TracingContext(backend_input.fake_mode)),
        torch._functorch.config.patch("bundled_autograd_cache", True),
    ):
        compiled_fn = backend(backend_input.graph_module, backend_input.example_inputs)

    # If Inductor backend is used, grab the compiled_fn from PrecompileContext
    # TODO: this should be replaced once we make the backend return the SerializableCallable directly.
    if isinstance(backend, torch._TorchCompileInductorWrapper):
        compiled_fn = BundledAOTAutogradSerializableCallable.from_backend_id(
            backend_input.backend_id
        )

    if not isinstance(compiled_fn, SerializableCallable):
        compiler_fn = getattr(backend, "compiler_fn", backend)
        raise RuntimeError(
            f"Compiled function type {type(compiled_fn)} (produced "
            + f"from backend {compiler_fn}) does not implement SerializableCallable."
        )

    artifacts = CompileArtifacts(
        signature=signature,
        bytecode=dynamo_output.bytecode,
        guard_manager=check_fn.guard_manager,
        guards_state=check_fn.guards_state,
        import_sources=import_sources,
        backend_id=backend_input.backend_id,
        compiled_fn=compiled_fn,
        original_code=fn.__code__,
        closure=fn.__closure__,
    )
    aot_compiled_fn = AOTCompiledFunction(_artifacts=artifacts)
    return aot_compiled_fn
