Support (1 <= dp < tp) in the dp attention in DeepEP (#4770)
Co-authored-by: Cheng Wan <cwan39@gatech.edu>
This commit is contained in:
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import functools
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -249,3 +249,14 @@ def dp_scatter(
|
||||
memcpy_triton(
|
||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||
)
|
||||
|
||||
|
||||
def tp_reduce_scatter(
    output: torch.Tensor,
    input_list: List[torch.Tensor],
):
    """Reduce-scatter ``input_list`` across the attention TP group.

    Each rank contributes the tensors in ``input_list``; the reduced shard
    for this rank is written into ``output``. Thin wrapper around the
    attention tensor-parallel process group's ``reduce_scatter``.
    """
    # NOTE(review): semantics (reduce op, ordering) come from the group's
    # reduce_scatter implementation — not visible here.
    attn_tp_group = get_attention_tp_group()
    return attn_tp_group.reduce_scatter(output, input_list)
|
||||
|
||||
|
||||
def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
    """All-gather ``input_`` from every rank of the attention TP group.

    The gathered per-rank tensors are written into ``output_list``. Thin
    wrapper around the attention tensor-parallel process group's
    ``all_gather``.
    """
    attn_tp_group = get_attention_tp_group()
    return attn_tp_group.all_gather(input_, tensor_list=output_list)
|
||||
|
||||
Reference in New Issue
Block a user