adapt w4a8 marlin deepep dp ep
This commit is contained in:
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Union
|
||||||
|
|
||||||
|
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
@@ -124,7 +125,6 @@ class EPMoE(FusedMoE):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.intermediate_size = intermediate_size
|
self.intermediate_size = intermediate_size
|
||||||
|
|
||||||
if isinstance(quant_config, Fp8Config):
|
if isinstance(quant_config, Fp8Config):
|
||||||
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
self.block_shape = (
|
self.block_shape = (
|
||||||
@@ -135,11 +135,23 @@ class EPMoE(FusedMoE):
|
|||||||
self.use_fp8_w8a8 = True
|
self.use_fp8_w8a8 = True
|
||||||
self.fp8_dtype = torch.float8_e4m3fn
|
self.fp8_dtype = torch.float8_e4m3fn
|
||||||
self.activation_scheme = quant_config.activation_scheme
|
self.activation_scheme = quant_config.activation_scheme
|
||||||
|
self.use_w4a8_marlin = False
|
||||||
|
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
|
||||||
|
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
|
||||||
|
self.block_shape = (
|
||||||
|
self.quant_method.quant_config.weight_block_size
|
||||||
|
if self.use_block_quant
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
self.use_fp8_w8a8 = False
|
||||||
|
self.activation_scheme = None
|
||||||
|
self.use_w4a8_marlin = True
|
||||||
else:
|
else:
|
||||||
self.use_fp8_w8a8 = False
|
self.use_fp8_w8a8 = False
|
||||||
self.use_block_quant = False
|
self.use_block_quant = False
|
||||||
self.block_shape = None
|
self.block_shape = None
|
||||||
self.activation_scheme = None
|
self.activation_scheme = None
|
||||||
|
self.use_w4a8_marlin = False
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||||
@@ -386,11 +398,11 @@ class DeepEPMoE(EPMoE):
|
|||||||
return_recv_hook=True,
|
return_recv_hook=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.deepep_mode.enable_low_latency() and not _is_npu:
|
# if self.deepep_mode.enable_low_latency() and not _is_npu:
|
||||||
# NPU supports low_latency deepep without deepgemm
|
# # NPU supports low_latency deepep without deepgemm
|
||||||
assert (
|
# assert (
|
||||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
# ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
# expert_mask is of size (self.num_local_experts + 1),
|
# expert_mask is of size (self.num_local_experts + 1),
|
||||||
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
|
||||||
@@ -404,23 +416,23 @@ class DeepEPMoE(EPMoE):
|
|||||||
)
|
)
|
||||||
# the last one is invalid rank_id
|
# the last one is invalid rank_id
|
||||||
self.expert_mask[:-1] = 1
|
self.expert_mask[:-1] = 1
|
||||||
elif not _is_npu:
|
# elif not _is_npu:
|
||||||
self.w13_weight_fp8 = (
|
# self.w13_weight_fp8 = (
|
||||||
self.w13_weight,
|
# self.w13_weight,
|
||||||
(
|
# (
|
||||||
self.w13_weight_scale_inv
|
# self.w13_weight_scale_inv
|
||||||
if self.use_block_quant
|
# if self.use_block_quant
|
||||||
else self.w13_weight_scale
|
# else self.w13_weight_scale
|
||||||
),
|
# ),
|
||||||
)
|
# )
|
||||||
self.w2_weight_fp8 = (
|
# self.w2_weight_fp8 = (
|
||||||
self.w2_weight,
|
# self.w2_weight,
|
||||||
(
|
# (
|
||||||
self.w2_weight_scale_inv
|
# self.w2_weight_scale_inv
|
||||||
if self.use_block_quant
|
# if self.use_block_quant
|
||||||
else self.w2_weight_scale
|
# else self.w2_weight_scale
|
||||||
),
|
# ),
|
||||||
)
|
# )
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -466,8 +478,15 @@ class DeepEPMoE(EPMoE):
|
|||||||
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
|
||||||
return self.forward_npu(dispatch_output)
|
return self.forward_npu(dispatch_output)
|
||||||
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
|
||||||
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
#assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
|
||||||
return self.forward_deepgemm_contiguous(dispatch_output)
|
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
|
||||||
|
return self.forward_deepgemm_contiguous(dispatch_output)
|
||||||
|
elif self.use_w4a8_marlin:
|
||||||
|
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Dispatch output is not supported"
|
||||||
|
)
|
||||||
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
|
||||||
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
if get_moe_runner_backend().is_flashinfer_cutedsl():
|
||||||
return self.forward_flashinfer_cutedsl(dispatch_output)
|
return self.forward_flashinfer_cutedsl(dispatch_output)
|
||||||
@@ -526,6 +545,34 @@ class DeepEPMoE(EPMoE):
|
|||||||
expert_mask=self.expert_mask,
|
expert_mask=self.expert_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def forward_deepgemm_w4a8_marlin_contiguous(
|
||||||
|
self,
|
||||||
|
dispatch_output: DeepEPNormalOutput,
|
||||||
|
):
|
||||||
|
hidden_states_int8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
|
||||||
|
dispatch_output
|
||||||
|
)
|
||||||
|
assert self.quant_method is not None
|
||||||
|
assert self.moe_runner_config.activation == "silu"
|
||||||
|
# if num_recv_tokens_per_expert is None:
|
||||||
|
return hidden_states_int8.bfloat16()
|
||||||
|
# expert_output = self.quant_method.apply_ep(
|
||||||
|
# layer=self,
|
||||||
|
# x=dispatch_output,
|
||||||
|
# topk_weights=topk_weights,
|
||||||
|
# topk_ids=topk_idx,
|
||||||
|
# global_num_experts=self.global_num_experts,
|
||||||
|
# expert_map=self.expert_map,
|
||||||
|
# activation=self.activation,
|
||||||
|
# apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||||
|
# use_nn_moe=self.use_nn_moe,
|
||||||
|
# num_local_tokens=dispatch_recv_num_token,
|
||||||
|
# config_select_bs=hidden_states.shape[0],
|
||||||
|
# scales=dispatch_scales if self.use_int8_dispatch else None
|
||||||
|
# # routed_scaling_factor=self.routed_scaling_factor,
|
||||||
|
# )
|
||||||
|
# return expert_output
|
||||||
|
|
||||||
def forward_deepgemm_contiguous(
|
def forward_deepgemm_contiguous(
|
||||||
self,
|
self,
|
||||||
dispatch_output: DeepEPNormalOutput,
|
dispatch_output: DeepEPNormalOutput,
|
||||||
|
|||||||
@@ -431,32 +431,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
|
|||||||
deepep_post_reorder_triton_kernel,
|
deepep_post_reorder_triton_kernel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
#if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM or _use_aiter or _is_npu:
|
||||||
output = hidden_states
|
output = hidden_states
|
||||||
else:
|
# else:
|
||||||
if hidden_states.shape[0] > 0:
|
# if hidden_states.shape[0] > 0:
|
||||||
num_tokens = self.src2dst.shape[0] // self.router_topk
|
# num_tokens = self.src2dst.shape[0] // self.router_topk
|
||||||
output = torch.empty(
|
# output = torch.empty(
|
||||||
(num_tokens, hidden_states.shape[1]),
|
# (num_tokens, hidden_states.shape[1]),
|
||||||
device=hidden_states.device,
|
# device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
# dtype=hidden_states.dtype,
|
||||||
)
|
# )
|
||||||
deepep_post_reorder_triton_kernel[(num_tokens,)](
|
# deepep_post_reorder_triton_kernel[(num_tokens,)](
|
||||||
hidden_states,
|
# hidden_states,
|
||||||
output,
|
# output,
|
||||||
self.src2dst,
|
# self.src2dst,
|
||||||
topk_idx,
|
# topk_idx,
|
||||||
topk_weights,
|
# topk_weights,
|
||||||
self.router_topk,
|
# self.router_topk,
|
||||||
hidden_states.shape[1],
|
# hidden_states.shape[1],
|
||||||
BLOCK_SIZE=512,
|
# BLOCK_SIZE=512,
|
||||||
)
|
# )
|
||||||
else:
|
# else:
|
||||||
output = torch.zeros(
|
# output = torch.zeros(
|
||||||
(0, hidden_states.shape[1]),
|
# (0, hidden_states.shape[1]),
|
||||||
device=hidden_states.device,
|
# device=hidden_states.device,
|
||||||
dtype=hidden_states.dtype,
|
# dtype=hidden_states.dtype,
|
||||||
)
|
# )
|
||||||
previous_event = Buffer.capture() if self.async_finish else None
|
previous_event = Buffer.capture() if self.async_finish else None
|
||||||
return output, previous_event
|
return output, previous_event
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user