Let EP prefill support new DeepGEMM (#7310)
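This change lets the expert-parallel (EP) prefill path run on the new DeepGEMM interface by threading the `deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0` flag through the MoE kernels: the per-group FP8 scale buffer is allocated in the packed UE8M0 int32 layout when the flag is set, the flag is passed on to the scatter and per-token-group quantization calls, and the `tma_align_input_scale` passes now run only for the legacy float32 scale layout.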
@@ -46,6 +46,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import (
     DeepEPMode,
+    ceil_div,
     dispose_tensor,
     get_bool_env_var,
     is_hip,
@@ -1370,10 +1371,19 @@ class DeepEPMoE(EPMoE):
                 device=hidden_states_fp8.device,
                 dtype=hidden_states_fp8.dtype,
             ),
-            torch.empty(
-                (all_tokens, K // 128),
-                device=hidden_states_fp8.device,
-                dtype=torch.float32,
+            (
+                # TODO check whether need `zeros`
+                torch.zeros(
+                    (ceil_div(K // 128, 4), all_tokens),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.int,
+                ).transpose(0, 1)
+                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+                else torch.empty(
+                    (all_tokens, K // 128),
+                    device=hidden_states_fp8.device,
+                    dtype=torch.float32,
+                )
             ),
         ]
         m_indices = torch.empty(
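The branch above swaps the scale buffer's shape and dtype: with `DEEPGEMM_SCALE_UE8M0` set, each per-128-channel scale is a power of two stored as an 8-bit exponent, and the `ceil_div(K // 128, 4)` int32 rows suggest four such exponents packed per int32. Below is a minimal shape-only sketch under that assumption (the actual packing happens inside the quant/scatter kernels); the numbers are illustrative:

```python
import torch

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

all_tokens, K = 512, 7168
num_groups = K // 128                  # one scale per 128-channel group
packed_rows = ceil_div(num_groups, 4)  # assumed: 4 UE8M0 exponent bytes per int32

# UE8M0 path: int32 buffer allocated group-major, viewed token-major.
ue8m0_scales = torch.zeros((packed_rows, all_tokens), dtype=torch.int).transpose(0, 1)
# Legacy path: one float32 scale per (token, group).
fp32_scales = torch.empty((all_tokens, num_groups), dtype=torch.float32)

print(ue8m0_scales.shape)  # torch.Size([512, 14]) -- column-major storage
print(fp32_scales.shape)   # torch.Size([512, 56])
```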
@@ -1399,6 +1409,7 @@ class DeepEPMoE(EPMoE):
             input_tensor[1],
             m_indices,
             output_index,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         dispose_tensor(hidden_states_fp8)
 
@@ -1407,7 +1418,8 @@ class DeepEPMoE(EPMoE):
             device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
-        input_tensor[1] = tma_align_input_scale(input_tensor[1])
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            input_tensor[1] = tma_align_input_scale(input_tensor[1])
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
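Both grouped GEMM calls consume `m_indices`, which in DeepGEMM's contiguous grouped layout tags each row of the flattened activation with the expert that owns it. A hedged sketch of how such an index vector relates to per-expert token counts (names and numbers are illustrative, not sglang's internals):

```python
import torch

# Rows routed to each of 4 experts; expert 1 receives none.
tokens_per_expert = torch.tensor([3, 0, 5, 2])

# m_indices[i] = expert owning row i of the contiguous activation buffer.
m_indices = torch.repeat_interleave(
    torch.arange(len(tokens_per_expert)), tokens_per_expert
)
print(m_indices)  # tensor([0, 0, 0, 2, 2, 2, 2, 2, 3, 3])
```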
@@ -1428,10 +1440,15 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
-            down_input, scale_block_size
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
         )
         del down_input
-        down_input_scale = tma_align_input_scale(down_input_scale)
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
             (down_input_fp8, down_input_scale),
             self.w2_weight_fp8,
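The second file touched is the DeepEP token dispatcher: its normal (prefill) dispatch path gains the same UE8M0-aware quantization flags before the FP8 communication.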
@@ -246,7 +246,13 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx = topk_idx.to(torch.int64)
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             # TODO hard code 128 block quant,use fp8 communication
-            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
+            hidden_states = sglang_per_token_group_quant_fp8(
+                hidden_states,
+                128,
+                column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
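For reference, a self-contained sketch of plain per-token-group FP8 quantization with a 128-wide group, the operation whose sglang implementation gains the `column_major_scales` / `scale_tma_aligned` / `scale_ue8m0` switches above. This illustrates only the baseline technique, not sglang's kernel; under this PR the flags change how the returned scale tensor is laid out and encoded:

```python
import torch

def per_token_group_quant_fp8_sketch(x: torch.Tensor, group_size: int = 128):
    """Quantize each (token, group) block of x to FP8 E4M3 with one scale per block."""
    M, K = x.shape
    g = x.view(M, K // group_size, group_size)
    amax = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / 448.0  # 448 = max normal value of float8_e4m3fn
    q = (g / scale).to(torch.float8_e4m3fn)
    return q.view(M, K), scale.squeeze(-1)  # scales: (M, K // group_size), float32

x = torch.randn(4, 256)
q, s = per_token_group_quant_fp8_sketch(x)
print(q.dtype, s.shape)  # torch.float8_e4m3fn torch.Size([4, 2])
```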