Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)

Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
tarinkk
2025-03-27 20:09:35 -04:00
committed by GitHub
parent 98a2cfa9b2
commit 7f19e083c1
10 changed files with 238 additions and 47 deletions

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, List
import torch
import triton
@@ -249,3 +249,14 @@ def dp_scatter(
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def tp_reduce_scatter(
output: torch.Tensor,
input_list: List[torch.Tensor],
):
return get_attention_tp_group().reduce_scatter(output, input_list)
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)