# mypy: allow-untyped-defs
import importlib

import torch


# Open a fragment of the "export" operator namespace so that ops can be added
# to it from this file. The noqa suppresses the TOR901 lint raised on this line.
lib = torch.library.Library("export", "FRAGMENT")  # noqa: TOR901

# Declare the schema of `export::access_subclass_inner_tensor`: it takes a
# tensor plus the name of an attribute (as a string) and returns a tensor.
# The Python implementations for this schema are registered separately.
lib.define(
    "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor"
)


@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd")
# When running under torch.inference_mode(), we seem to skip the Autograd key,
# so we should desugar this op as soon as we start tracing to post-dispatch.
@torch.library.impl(lib, "access_subclass_inner_tensor", "Python")
def _access_subclass_inner_tensor(
    src_subclass_tensor: torch.Tensor, attr: str
) -> torch.Tensor:
    """
    Return the tensor stored under attribute ``attr`` of a wrapper subclass tensor.

    The same function is registered as the implementation of
    ``access_subclass_inner_tensor`` for both the "Autograd" and the "Python"
    dispatch keys.

    Args:
        src_subclass_tensor: Tensor that must satisfy
            ``is_traceable_wrapper_subclass``.
        attr: Name of the attribute to read from ``src_subclass_tensor``.

    Returns:
        The attribute value, which is guaranteed to be a ``torch.Tensor``.

    Raises:
        AssertionError: If ``src_subclass_tensor`` is not a traceable wrapper
            subclass.
        RuntimeError: If the attribute is missing or is not a tensor.
    """
    from torch.utils._python_dispatch import is_traceable_wrapper_subclass

    # Raise explicitly instead of using `assert`: assert statements are removed
    # under `python -O`, which would let a plain tensor fall through and return
    # an unrelated tensor attribute (e.g. "data") instead of failing.
    if not is_traceable_wrapper_subclass(src_subclass_tensor):
        raise AssertionError(
            "Expected a traceable wrapper subclass tensor, "
            f"got {type(src_subclass_tensor)}"
        )

    # A missing attribute yields None, which is rejected by the isinstance
    # check below together with any other non-tensor value.
    val = getattr(src_subclass_tensor, attr, None)
    if not isinstance(val, torch.Tensor):
        raise RuntimeError(
            f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}"
        )
    return val


def _call_custom_autograd_function_in_pre_dispatch(function_cls_name, *args, **kwargs):
    """
    Import a custom autograd function by string name and call it.

    This is pretty bad because there is no schema for the call.

    Ideally we should automatically wrap custom autograd functions with a custom op, but
    that is too much work because we need to schematize custom autograd functions. For now,
    we just hackily put it in the IR.

    Args:
        function_cls_name: Dotted name of the form ``"package.module.ClassName"``.
            Everything before the last dot is imported as a module; the part
            after it is looked up as an attribute of that module.
        *args: Positional arguments forwarded unchanged to ``apply``.
        **kwargs: Keyword arguments forwarded unchanged to ``apply``.

    Returns:
        Whatever ``function_cls.apply(*args, **kwargs)`` returns.

    Raises:
        ValueError: If ``function_cls_name`` contains no ``"."`` separator.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the class name.
        AssertionError: If the resolved object has no ``apply`` attribute.
    """
    # Split on the last dot only: the module path itself may contain dots.
    module_name, sep, class_name = function_cls_name.rpartition(".")
    if not sep:
        raise ValueError(
            "Expected a fully qualified name of the form 'module.ClassName', "
            f"got {function_cls_name!r}"
        )

    # Import the module and get the class
    module = importlib.import_module(module_name)
    function_cls = getattr(module, class_name)

    # Explicit raise rather than `assert`, so the check survives `python -O`.
    if not hasattr(function_cls, "apply"):
        raise AssertionError(
            f"{function_cls_name} does not define an 'apply' attribute"
        )
    return function_cls.apply(*args, **kwargs)
