import collections
import logging
from collections import defaultdict
from typing import Any, Callable, Optional

import torch
import torch.distributed as dist
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode
from torch._inductor.runtime.runtime_utils import dynamo_timed
from torch._logging import trace_structured
from torch.fx.experimental.proxy_tensor import make_fx
from torch.utils._ordered_set import OrderedSet


logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def bucket_cap_mb_by_bucket_idx_default(bucket_id: int) -> float:
    """
    Determine the size of a bucket based on its ID.

    Args:
    bucket_id (int): The ID of the bucket.

    Returns:
    float: The size of the bucket.
    """
    return 2000.0
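
# A custom cap callable can give earlier buckets a different (e.g. smaller) cap,
# since the first collective is usually exposed. Illustrative sketch only; the
# callable name and `gm` below are hypothetical:
#
#     def custom_bucket_cap_mb(bucket_id: int) -> float:
#         return 250.0 if bucket_id == 0 else 2000.0
#
#     bucket_all_gather(gm, bucket_cap_mb_by_bucket_idx=custom_bucket_cap_mb)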


def bucket_all_gather(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
    mode: Optional[str] = None,
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    ag_buckets = bucket_all_gather_by_mb(gm, bucket_cap_mb_by_bucket_idx)
    if len(ag_buckets) == 0:
        return
    merge_all_gather(gm, ag_buckets, mode)


def bucket_reduce_scatter(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Optional[Callable[[int], float]] = None,
    mode: Optional[str] = None,
) -> None:
    if bucket_cap_mb_by_bucket_idx is None:
        from torch._inductor.fx_passes.bucketing import (
            bucket_cap_mb_by_bucket_idx_default,
        )

        bucket_cap_mb_by_bucket_idx = bucket_cap_mb_by_bucket_idx_default
    rs_buckets = bucket_reduce_scatter_by_mb(gm, bucket_cap_mb_by_bucket_idx)
    if len(rs_buckets) == 0:
        return
    merge_reduce_scatter(gm, rs_buckets, mode)


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
    )


def is_reduce_scatter_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.reduce_scatter_tensor.default
    )


def is_wait_tensor(node: torch.fx.Node) -> bool:
    return (
        node.op == "call_function"
        and node.target == torch.ops._c10d_functional.wait_tensor.default
    )


def is_wait_tensor_from_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    return is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0])  # type: ignore[arg-type]


def collect_node_descendants(
    graph: torch.fx.Graph,
) -> dict[torch.fx.Node, OrderedSet[torch.fx.Node]]:
    """
    Collects the descendants of each node in the graph.

    Args:
        graph (torch.fx.Graph): The graph to collect descendants from.

    Returns:
        dict[torch.fx.Node, OrderedSet[torch.fx.Node]]: A dictionary mapping each node to its descendants.
    """
    node_descendants: dict[torch.fx.Node, OrderedSet[torch.fx.Node]] = (
        collections.defaultdict(OrderedSet)
    )
    outdegree: dict[torch.fx.Node, int] = collections.defaultdict(int)
    queue: list[torch.fx.Node] = []

    for node in graph.nodes:
        n_outdegree = len(node.users)
        if n_outdegree == 0:
            queue.append(node)
        else:
            outdegree[node] = n_outdegree

    while queue:
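        # Invariant: a node is dequeued only after all of its users have been
        # processed, so its descendant set is complete before it is propagated
        # to its inputs.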
        node = queue.pop()
        for input_node in node.all_input_nodes:
            node_descendants[input_node] |= node_descendants[node]
            node_descendants[input_node].add(node)
            outdegree[input_node] -= 1

            if outdegree[input_node] == 0:
                queue.append(input_node)

    return node_descendants


def greedy_bucket_collective_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_node: Callable[[torch.fx.Node], bool],
    node_group_key: Callable[[torch.fx.Node], Any],
    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
    """
    Bucketing adjacent collectives with equal node_group_key.
    We can not bucket non adjacent collectives,
    as this will effectively change the order of collectives.
    Reordering can lead to different order on different ranks.
    """
    g = gm.graph
    found_candidates = False
    for node in g.nodes:
        if filter_node(node):
            found_candidates = True
            break
    if not found_candidates:
        return []

    # TODO: Pearce-Kelly algorithm for detecting cycles
    node_descendants = collect_node_descendants(gm.graph)

    nodes_groups: list[list[torch.fx.Node]] = []
    cur_group: list[torch.fx.Node] = []
    cur_group_key = None
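    # First pass: walk the graph in order and collect maximal runs of adjacent
    # collectives (identified via their wait nodes) that share the same group key.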

    for node in g.nodes:
        if is_wait_tensor(node) and filter_node(node.args[0]):
            if (filter_wait_node is None) or filter_wait_node(node):
                coll_node = node.args[0]
                group_key = node_group_key(coll_node)
                if group_key == cur_group_key:
                    cur_group.append(coll_node)
                else:
                    if len(cur_group) > 1:
                        nodes_groups.append(cur_group)
                    cur_group = [coll_node]
                    cur_group_key = group_key

    if len(cur_group) > 1:
        nodes_groups.append(cur_group)

    buckets: list[list[torch.fx.Node]] = []
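    # Second pass: greedily split each run into buckets bounded by the byte cap,
    # skipping any collective that already depends on a node in the current
    # bucket, since fusing those would create a cycle.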
    for nodes in nodes_groups:
        cur_bucket: list[torch.fx.Node] = []
        cur_bucket_descendants: OrderedSet[torch.fx.Node] = OrderedSet()
        cur_bucket_size_bytes: int = 0
        cur_bucket_id: int = 0
        bucket_size_bytes = int(
            bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
        )
        for node in nodes:
            if node in cur_bucket_descendants:
                # node depends on a node already in the current bucket;
                # horizontally fusing (bucketing) them would create a cycle
                continue
            assert "val" in node.meta
            n_val = node.meta["val"]
            out_size_bytes = n_val.numel() * n_val.element_size()
            n_input_val = node.all_input_nodes[0].meta["val"]
            in_size_bytes = n_input_val.numel() * n_input_val.element_size()
            size_bytes = max(out_size_bytes, in_size_bytes)
            if cur_bucket_size_bytes + size_bytes > bucket_size_bytes and cur_bucket:
                # Current bucket is full, start a new bucket
                if len(cur_bucket) > 1:
                    buckets.append(cur_bucket)
                cur_bucket = []
                cur_bucket_size_bytes = 0
                cur_bucket_id += 1
                # Refresh the cap for the new bucket index
                bucket_size_bytes = int(
                    bucket_cap_mb_by_bucket_idx(cur_bucket_id) * 1024 * 1024
                )
                cur_bucket_descendants = OrderedSet()
            cur_bucket_size_bytes += size_bytes
            cur_bucket.append(node)
            cur_bucket_descendants |= node_descendants[node]
        if len(cur_bucket) > 1:
            buckets.append(cur_bucket)
    return buckets


def bucket_all_gather_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
    """
    Identifies all all_gather nodes and groups them into buckets,
    based on size limit `bucket_cap_mb_by_bucket_idx`.

    Args:
        gm (torch.fx.GraphModule): GraphModule where to bucket all_gathers.
        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable that returns the bucket cap
            in megabytes for a given bucket index.  Making the cap depend on the index allows
            the first bucket(s) to use a different size, since the first all_gather is usually
            exposed.  `bucket_cap_mb_by_bucket_idx_default` is the default and illustrates the
            expected interface.
        filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
            only all_gather nodes whose wait node satisfies `filter_wait_node` will be bucketed.

    Returns:
        list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of all_gather nodes.
    """

    def _ag_group_key(node: torch.fx.Node) -> tuple[str, torch.dtype]:
        _, group_size, group_name = node.args
        dtype = node.meta["val"].dtype
        assert isinstance(group_name, str)
        return (group_name, dtype)

    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_all_gather_into_tensor,
        _ag_group_key,
        filter_wait_node,
    )


def bucket_reduce_scatter_by_mb(
    gm: torch.fx.GraphModule,
    bucket_cap_mb_by_bucket_idx: Callable[[int], float],
    filter_wait_node: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> list[list[torch.fx.Node]]:
    """
    Identifies all reduce_scatter nodes and groups them into buckets,
    based on size limit `bucket_cap_mb_by_bucket_idx`.

    Args:
        gm (torch.fx.GraphModule): GraphModule where to bucket reduce_scatters.
        bucket_cap_mb_by_bucket_idx (Callable[[int], float]): Callable that returns the bucket cap
            in megabytes for a given bucket index, allowing different caps for different buckets.
        filter_wait_node (Optional[Callable[[torch.fx.Node], bool]]): If specified,
            only reduce_scatter nodes whose wait node satisfies `filter_wait_node` will be bucketed.

    Returns:
        list[list[torch.fx.Node]]: List of buckets, where each bucket is a list of reduce_scatter nodes.
    """

    def _rs_group_key(node: torch.fx.Node) -> tuple[str, str, torch.dtype]:
        _, reduce_op, group_size, group_name = node.args
        dtype = node.meta["val"].dtype
        assert isinstance(group_name, str)
        assert isinstance(reduce_op, str)
        return (group_name, reduce_op, dtype)

    return greedy_bucket_collective_by_mb(
        gm,
        bucket_cap_mb_by_bucket_idx,
        is_reduce_scatter_tensor,
        _rs_group_key,
        filter_wait_node,
    )


@torch.library.custom_op("bucketing::_pre_bucket_reduce_scatter", mutates_args={})
def _pre_bucket_reduce_scatter(
    rs_ins: list[torch.Tensor],
    group_size: int,
) -> torch.Tensor:
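    # Flatten each input to (group_size, -1) and concatenate along dim 1 so the
    # combined buffer is rank-major: every rank's shard of every input is contiguous.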
    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]
    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()
    return new_rs_in


def _pre_bucket_reduce_scatter_fake(
    rs_ins: list[torch.Tensor],
    group_size: int,
) -> torch.Tensor:
    out_numel = sum(rs_in.numel() for rs_in in rs_ins)
    return torch.empty((out_numel,), device=rs_ins[0].device, dtype=rs_ins[0].dtype)


_pre_bucket_reduce_scatter.register_fake(_pre_bucket_reduce_scatter_fake)


def reduce_scatter_merge_fn_to_trace_custom_ops(
    rs_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
    new_out_numels = [x.numel() // group_size for x in rs_ins]

    new_rs_in = torch.ops.bucketing._pre_bucket_reduce_scatter(rs_ins, group_size)

    # TODO - either use torch.cat or make sure inductor foreach codegen
    # fires more reliably
    new_rs_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.reduce_scatter_tensor.default(
            new_rs_in, reduce_op, group_size, group_name
        )
    )
    new_out_flat = new_rs_out.split(new_out_numels, 0)
    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
    return new_outs


def reduce_scatter_merge_fn_to_trace(
    rs_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    reduce_op: str,
    reduce_dtype: torch.dtype,  # type: ignore[name-defined]
    device: torch.device,  # type: ignore[name-defined]
) -> list[torch.Tensor]:  # type: ignore[no-untyped-def]
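    # Each input is viewed as (group_size, -1) and concatenated along dim 1, so the
    # flattened buffer is laid out rank-major: rank r's shard of every input is
    # contiguous. reduce_scatter on this buffer hands each rank exactly its shards
    # back to back, which are then split and reshaped into the per-input outputs.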
    rs_ins_flattened = [x.view(group_size, -1) for x in rs_ins]

    new_out_sizes = [(x.shape[0] // group_size,) + x.shape[1:] for x in rs_ins]
    new_out_numels = [x.numel() // group_size for x in rs_ins]

    new_rs_in = torch.cat(rs_ins_flattened, dim=1).flatten()

    new_rs_out = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.reduce_scatter_tensor.default(
            new_rs_in, reduce_op, group_size, group_name
        )
    )
    new_out_flat = new_rs_out.split(new_out_numels, 0)
    new_outs = [x.view(s) for x, s in zip(new_out_flat, new_out_sizes)]
    return new_outs


@torch.library.custom_op("bucketing::_pre_bucket_all_gather", mutates_args={})
def _pre_bucket_all_gather(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    rank: int,
) -> torch.Tensor:
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
    return new_ag_out


def _pre_bucket_all_gather_fake(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    rank: int,
) -> torch.Tensor:
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    return new_ag_out


_pre_bucket_all_gather.register_fake(_pre_bucket_all_gather_fake)


def all_gather_merge_fn_to_trace_custom_ops(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    rank: int,
) -> list[torch.Tensor]:
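    # Same layout as all_gather_merge_fn_to_trace below, but the allocation and
    # copy-in are wrapped in the bucketing::_pre_bucket_all_gather custom op above.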
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    new_ag_out = torch.ops.bucketing._pre_bucket_all_gather(
        ag_ins, group_size, group_name, dtype, rank
    )
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    outs_reshaped = [
        o.reshape((shape[0] * group_size,) + shape[1:])
        for o, shape in zip(outs, ins_sizes)
    ]
    return outs_reshaped


def all_gather_merge_fn_to_trace(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    rank: int,
) -> list[torch.Tensor]:
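    # Allocate one flat output covering all ranks, copy this rank's inputs into its
    # slice of that buffer, and all_gather in place. The gathered buffer is then
    # viewed as (group_size, -1), split per input along dim 1, and reshaped so each
    # output matches the original input all_gathered along dim 0.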
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    new_ag_in = new_ag_out.narrow(0, ag_input_numel * rank, ag_input_numel)
    foreach_copy_dsts = torch.split(new_ag_in, ins_split_sizes)
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    torch._foreach_copy_(foreach_copy_dsts, ag_ins_flattened)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    outs_reshaped = [
        o.reshape((shape[0] * group_size,) + shape[1:])
        for o, shape in zip(outs, ins_sizes)
    ]
    return outs_reshaped


def all_gather_merge_fn_to_trace_functional(
    ag_ins: list[torch.Tensor],
    group_size: int,
    group_name: str,
    dtype: torch.dtype,  # type: ignore[name-defined]
    rank: int,
    use_fsdp_ag_copy_in: bool = False,
) -> list[torch.Tensor]:
    # Implementation that is functional in the graph; when `use_fsdp_ag_copy_in`
    # is True, the copy-in uses the custom op torch.ops.fsdp.all_gather_copy_in.
    ins_sizes = [ag_in.shape for ag_in in ag_ins]
    ins_split_sizes = [ag_in.numel() for ag_in in ag_ins]
    ag_input_numel = sum(ins_split_sizes)
    device = ag_ins[0].device
    new_ag_out = torch.empty(ag_input_numel * group_size, dtype=dtype, device=device)
    ag_ins_flattened = [ag_in.reshape(-1) for ag_in in ag_ins]
    if use_fsdp_ag_copy_in:
        new_ag_in, new_ag_out = torch.ops.fsdp.all_gather_copy_in(
            ag_ins_flattened, new_ag_out, ins_split_sizes, ag_input_numel, rank
        )
    else:
        new_ag_in = torch.cat(ag_ins_flattened, dim=0)
    wait_tensor = torch.ops.c10d_functional.wait_tensor(
        torch.ops._c10d_functional.all_gather_into_tensor_out.default(
            new_ag_in, group_size, group_name, out=new_ag_out
        )
    )
    new_ag_out_reshaped = wait_tensor.reshape(group_size, -1)
    outs = torch.split_with_sizes(
        new_ag_out_reshaped,
        ins_split_sizes,
        dim=1,
    )
    outs_reshaped = [
        o.reshape((shape[0] * group_size,) + shape[1:])
        for o, shape in zip(outs, ins_sizes)
    ]
    return outs_reshaped


def _trace(fn, inps) -> torch.fx.GraphModule:  # type: ignore[no-untyped-def]
    with dynamo_timed("fx.bucketing._trace", log_pt2_compile_event=True):
        fake_mode = detect_fake_mode(inps)
        assert fake_mode is not None
        with fake_mode, enable_python_dispatcher():
            out = make_fx(fn)(*inps)
            for node in out.graph.find_nodes(
                op="call_function", target=torch.ops.aten.detach.default
            ):
                node.replace_all_uses_with(node.args[0])
                out.graph.erase_node(node)
            return out


def _insert_fn_trace_before_node(  # type: ignore[no-untyped-def]
    g: torch.fx.Graph,
    fn_to_trace,
    inps,
    insert_before_node: torch.fx.Node,
    g_fn_inps: list[torch.fx.Node],
    g_fn_outs: list[torch.fx.Node],
) -> dict[torch.fx.Node, torch.fx.Node]:  # type: ignore[no-untyped-def]
    """
    Helper function that traces :attr:`fn_to_trace` with inputs
    :attr:`inps`.
    The resulting function graph is inserted before :attr:`insert_before_node`,
    using :attr:`g_fn_inps` nodes of the original graph as its inputs;
    the function graph's outputs replace :attr:`g_fn_outs` in the original graph.
    Returns a mapping from each node in :attr:`g_fn_outs` to its replacement.
    """
    with dynamo_timed(
        "fx.bucketing._insert_fn_trace_before_node", log_pt2_compile_event=True
    ):
        fn_gm = _trace(
            fn_to_trace,
            inps,
        )
        fn_g = fn_gm.graph
        fn_g_ins = fn_g.find_nodes(op="placeholder")
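        # Map the traced graph's placeholders onto the caller-provided input
        # nodes, then copy the remaining nodes into g before insert_before_node.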
        env = {fn_g_ins[idx]: g_fn_inps[idx] for idx in range(len(g_fn_inps))}
        g_fn_new_outs: list[torch.fx.Node] = []
        with g.inserting_before(insert_before_node):
            for _n in fn_g.nodes:
                if _n.op == "placeholder":
                    continue
                _new_n = g.node_copy(_n, lambda x: env[x])
                env[_n] = _new_n
                if _n.op == "output":
                    g_fn_new_outs = _new_n.args[0]  # type: ignore[assignment]
                    g.erase_node(_new_n)
        replacements = {  # noqa: C416
            orig_out: new_out for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs)
        }
        for orig_out, new_out in zip(g_fn_outs, g_fn_new_outs):
            orig_out.replace_all_uses_with(new_out)
        return replacements


def merge_reduce_scatter(
    gm: torch.fx.GraphModule,
    rs_buckets: list[list[torch.fx.Node]],
    mode: Optional[str] = None,
) -> None:
    """
    Merges specified buckets of reduce_scatter to joint reduce_scatter.
    """
    with dynamo_timed("fx.bucketing.merge_reduce_scatter", log_pt2_compile_event=True):
        rs_merge_fn = reduce_scatter_merge_fn_to_trace
        if mode and "custom_ops" in mode:
            rs_merge_fn = reduce_scatter_merge_fn_to_trace_custom_ops
        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_bucketing_passes_reduce_scatter_buckets",
                "encoding": "string",
            },
            payload_fn=lambda: str(rs_buckets),
        )
        n_buckets = len(rs_buckets)
        g = gm.graph
        rs_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
        rs_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]

        for bucket_idx, rs_nodes in enumerate(rs_buckets):
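            # Record the reduce_scatter inputs and their wait nodes for this
            # bucket, checking that every collective in it agrees on reduce op,
            # group, dtype, and device.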
            rs0 = rs_nodes[0]
            rs0_val = rs0.meta["val"]
            _, reduce_op, group_size, group_name = rs0.args
            reduce_dtype = rs0_val.dtype
            device = rs0_val.device
            for n in rs_nodes:
                rs_val = n.meta["val"]
                assert (
                    n.args[1] == reduce_op
                    and n.args[2] == group_size
                    and n.args[3] == group_name
                    and rs_val.device == device
                    and rs_val.dtype == reduce_dtype
                )
                assert len(n.users) == 1
                wait_n = next(iter(n.users))
                rs_ins[bucket_idx].append(n.args[0])  # type: ignore[arg-type]
                rs_waits[bucket_idx].append(wait_n)

        for bucket_idx in range(n_buckets):
            _rs_ins = rs_ins[bucket_idx]
            _rs_waits = rs_waits[bucket_idx]
            _rs_ns = rs_buckets[bucket_idx]

            rs0 = _rs_ns[0]
            rs0_val = rs0.meta["val"]
            _, reduce_op, group_size, group_name = rs0.args
            reduce_dtype = rs0_val.dtype
            device = rs0_val.device

            replacements = _insert_fn_trace_before_node(
                g,
                rs_merge_fn,
                (
                    pytree.tree_map(lambda node: node.meta["val"], _rs_ins),
                    group_size,
                    group_name,
                    reduce_op,
                    reduce_dtype,
                    device,
                ),
                _rs_ns[-1].next,
                _rs_ins,
                _rs_waits,
            )
            # [Note: Replacement in bucketing passes]
            # After bucketing, the nodes in _rs_waits are replaced by the output
            # nodes of the traced fn_to_trace graph that was just inserted into g.
            # rs_ins and rs_waits were collected before any replacement, so the
            # inputs recorded for later buckets may still reference the old wait
            # nodes; apply the replacements so they point at the new outputs.

            def _replace(x: torch.fx.Node) -> torch.fx.Node:
                return replacements.get(x, x)

            for j in range(bucket_idx + 1, n_buckets):
                rs_ins[j] = pytree.tree_map(_replace, rs_ins[j])

            for rs_n, wait_n in zip(_rs_ns, _rs_waits):
                g.erase_node(wait_n)
                g.erase_node(rs_n)


def merge_all_gather(
    gm: torch.fx.GraphModule,
    ag_buckets: list[list[torch.fx.Node]],
    mode: Optional[str] = None,
) -> None:  # type: ignore[union-attr]
    """
    Merges specified buckets of all_gather to joint all_gather.
    """
    with dynamo_timed("fx.bucketing.merge_all_gather", log_pt2_compile_event=True):
        from torch.distributed.distributed_c10d import _resolve_process_group

        ag_merge_fn = all_gather_merge_fn_to_trace
        if mode and "custom_ops" in mode:
            ag_merge_fn = all_gather_merge_fn_to_trace_custom_ops

        trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "fx_bucketing_passes_all_gather_buckets",
                "encoding": "string",
            },
            payload_fn=lambda: str(ag_buckets),
        )
        n_buckets = len(ag_buckets)

        ag_node_to_pre_nodes = defaultdict(list)

        ag_ins: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
        ag_waits: list[list[torch.fx.Node]] = [[] for _ in range(n_buckets)]
        for bucket_idx, ag_bucket in enumerate(ag_buckets):
            _, group_size, group_name = ag_bucket[0].args
            assert isinstance(group_name, str)
            dtype = ag_bucket[0].meta["val"].dtype

            for ag_node in ag_bucket:
                assert len(ag_node.users) == 1, (
                    f"Expect only one user for {ag_node}, but got {ag_node.users}"
                )
                wait_node = next(iter(ag_node.users))
                assert (
                    ag_node.args[1] == group_size
                    and ag_node.args[2] == group_name
                    and ag_node.meta["val"].dtype == dtype
                )
                ag_node_in = ag_node.args[0]
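                # If the all_gather input is a single-use dtype cast
                # (prims.convert_element_type), peel the cast off: the merged
                # copy-in writes into a staging buffer of the gathered dtype, so
                # the cast happens during the copy and the cast node can be
                # erased together with the all_gather below.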
                if (
                    ag_node_in.op == "call_function"  # type: ignore[union-attr]
                    and ag_node_in.target  # type: ignore[union-attr]
                    == torch.ops.prims.convert_element_type.default  # type: ignore[union-attr]
                    and len(ag_node_in.users) == 1  # type: ignore[union-attr]
                ):
                    ag_node_to_pre_nodes[ag_node].append(ag_node_in)
                    ag_node_in = ag_node_in.args[0]  # type: ignore[union-attr]

                ag_ins[bucket_idx].append(ag_node_in)  # type: ignore[union-attr, arg-type]
                ag_waits[bucket_idx].append(wait_node)

        g = gm.graph

        for bucket_idx in range(n_buckets):
            _ag_ins = ag_ins[bucket_idx]
            _ag_waits = ag_waits[bucket_idx]
            _ag_ns = ag_buckets[bucket_idx]

            ag0 = _ag_ns[0]
            ag0_val = ag0.meta["val"]
            _, group_size, group_name = ag0.args
            dtype = ag0_val.dtype
            assert isinstance(group_name, str)

            rank: int = dist.get_rank(_resolve_process_group(group_name))

            replacements = _insert_fn_trace_before_node(
                g,
                ag_merge_fn,
                (
                    pytree.tree_map(lambda node: node.meta["val"], _ag_ins),
                    group_size,
                    group_name,
                    dtype,
                    rank,
                ),
                ag0.next,
                _ag_ins,
                _ag_waits,
            )

            # See Note: [Replacement in bucketing passes]
            def _replace(x: torch.fx.Node) -> torch.fx.Node:
                return replacements.get(x, x)

            for j in range(bucket_idx + 1, n_buckets):
                ag_ins[j] = pytree.tree_map(_replace, ag_ins[j])

            # Erase the old nodes, users first: wait, then all_gather, then any
            # peeled-off convert_element_type nodes (in reverse order).
            for ag_n, wait_n in zip(ag_buckets[bucket_idx], _ag_waits):
                g.erase_node(wait_n)
                g.erase_node(ag_n)
                for n in reversed(ag_node_to_pre_nodes[ag_n]):
                    g.erase_node(n)  # type: ignore[arg-type]
