[AMD] add aiter fused moe in DeepEP path (#7268)
This commit is contained in:
@@ -54,10 +54,16 @@ from sglang.srt.utils import (
|
|||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
_is_fp8_fnuz = is_fp8_fnuz()
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
from aiter import ActivationType, QuantType
|
||||||
|
from aiter.fused_moe import fused_moe
|
||||||
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -1046,6 +1052,15 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
|||||||
w2_weight_scale, requires_grad=False
|
w2_weight_scale, requires_grad=False
|
||||||
)
|
)
|
||||||
layer.w2_input_scale = None
|
layer.w2_input_scale = None
|
||||||
|
if _use_aiter:
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
@@ -1117,18 +1132,36 @@ class DeepEPMoE(EPMoE):
|
|||||||
assert (
|
assert (
|
||||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||||
self.w13_weight_fp8 = (
|
if _use_aiter:
|
||||||
self.w13_weight,
|
# expert_mask is of size (self.num_experts_per_partition + 1),
|
||||||
(
|
# the extra 1 is for the invalid rank_id (in the original DeepEP the invalid rank_id is -1, but aiter does not allow -1, so we use a mask to make those ids invalid)
|
||||||
self.w13_weight_scale_inv
|
# for instance, if we have 4 experts on this rank, we would have an expert_mask like:
|
||||||
if self.use_block_quant
|
# self.expert_mask = [1, 1, 1, 1, 0]
|
||||||
else self.w13_weight_scale
|
# idx from 0-3 is valid and will be processed, while idx == 4 will be masked out
|
||||||
),
|
self.expert_mask = torch.zeros(
|
||||||
)
|
(self.num_experts_per_partition + 1),
|
||||||
self.w2_weight_fp8 = (
|
device=torch.cuda.current_device(),
|
||||||
self.w2_weight,
|
dtype=torch.int,
|
||||||
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
|
)
|
||||||
)
|
# the last one is invalid rank_id
|
||||||
|
self.expert_mask[:-1] = 1
|
||||||
|
else:
|
||||||
|
self.w13_weight_fp8 = (
|
||||||
|
self.w13_weight,
|
||||||
|
(
|
||||||
|
self.w13_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w13_weight_scale
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.w2_weight_fp8 = (
|
||||||
|
self.w2_weight,
|
||||||
|
(
|
||||||
|
self.w2_weight_scale_inv
|
||||||
|
if self.use_block_quant
|
||||||
|
else self.w2_weight_scale
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -1142,6 +1175,9 @@ class DeepEPMoE(EPMoE):
|
|||||||
num_recv_tokens_per_expert: List[int],
|
num_recv_tokens_per_expert: List[int],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
):
|
):
|
||||||
|
if _use_aiter:
|
||||||
|
# in forward_aiter, we skip token permutation and unpermutation, which are fused inside the aiter kernel
|
||||||
|
return self.forward_aiter(hidden_states, topk_idx, topk_weights)
|
||||||
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
|
||||||
if resolved_deepep_mode == DeepEPMode.normal:
|
if resolved_deepep_mode == DeepEPMode.normal:
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
||||||
@@ -1274,6 +1310,37 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
return down_output
|
return down_output
|
||||||
|
|
||||||
|
def forward_aiter(
    self,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
):
    """Run the expert computation through aiter's ``fused_moe`` kernel.

    Args:
        hidden_states: Input activations. When its first dimension is 0 the
            tensor is returned as-is and no kernel is launched.
        topk_idx: Selected expert ids. Entries equal to ``-1`` mark invalid
            slots (DeepEP convention) and are remapped before the kernel call.
            This tensor is never modified in place.
        topk_weights: Routing weights, forwarded unchanged to ``fused_moe``.

    Returns:
        The result of ``fused_moe``, or ``hidden_states`` itself for an
        empty batch.
    """
    if hidden_states.shape[0] == 0:
        return hidden_states

    # In original DeepEP, idx == -1 means "invalid, do not process".
    # aiter does not accept -1, so invalid entries are redirected to the extra
    # slot `num_experts_per_partition`, which `self.expert_mask` disables.
    #
    # `copy=True` is required: `Tensor.to(dtype)` returns the very same tensor
    # when the dtype already matches, and the in-place write below would then
    # corrupt the caller's `topk_idx`.
    topk_idx_copy = topk_idx.to(torch.int32, copy=True)
    topk_idx_copy[topk_idx_copy == -1] = self.num_experts_per_partition

    # Any activation other than "silu" is mapped to Gelu.
    activation = (
        ActivationType.Silu if self.activation == "silu" else ActivationType.Gelu
    )

    # NOTE(review): the block-quant scales (`*_scale_inv`) are always passed
    # together with QuantType.per_128x128 — presumably this path assumes
    # block quantization; confirm against the layer setup.
    return fused_moe(
        hidden_states,
        self.w13_weight,
        self.w2_weight,
        topk_weights,
        topk_idx_copy,
        w1_scale=self.w13_weight_scale_inv,
        w2_scale=self.w2_weight_scale_inv,
        quant_type=QuantType.per_128x128,
        activation=activation,
        expert_mask=self.expert_mask,
    )
|
||||||
|
|
||||||
def forward_deepgemm_contiguous(
|
def forward_deepgemm_contiguous(
|
||||||
self,
|
self,
|
||||||
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
|||||||
@@ -6,7 +6,13 @@ from sglang.srt.managers.expert_distribution import (
|
|||||||
get_global_expert_distribution_recorder,
|
get_global_expert_distribution_recorder,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.utils import DeepEPMode, get_int_env_var, load_json_config
|
from sglang.srt.utils import (
|
||||||
|
DeepEPMode,
|
||||||
|
get_bool_env_var,
|
||||||
|
get_int_env_var,
|
||||||
|
is_hip,
|
||||||
|
load_json_config,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from deep_ep import Buffer, Config
|
from deep_ep import Buffer, Config
|
||||||
@@ -32,6 +38,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -376,6 +384,15 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
|
||||||
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
|
||||||
"""
|
"""
|
||||||
|
if _use_aiter:
|
||||||
|
# skip permutation here, as aiter fused_moe performs it inside the fused kernel
|
||||||
|
reorder_topk_ids = torch.empty(
|
||||||
|
(0,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
seg_indptr = torch.zeros(
|
||||||
|
(self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
|
||||||
|
)
|
||||||
|
return reorder_topk_ids, seg_indptr, hidden_states
|
||||||
|
|
||||||
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
|
||||||
topk_idx, self.num_experts
|
topk_idx, self.num_experts
|
||||||
@@ -409,7 +426,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
topk_idx: torch.Tensor,
|
topk_idx: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
):
|
):
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter:
|
||||||
output = hidden_states
|
output = hidden_states
|
||||||
else:
|
else:
|
||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user