Optimized deepseek-v3/r1 model performance on mxfp4 run (#10008)
Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com>
This commit is contained in:
@@ -112,6 +112,7 @@ from sglang.srt.utils import (
|
||||
is_cpu,
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_gfx95_supported,
|
||||
is_hip,
|
||||
is_non_idle_and_non_empty,
|
||||
is_npu,
|
||||
@@ -129,6 +130,22 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
_is_cpu = is_cpu()
|
||||
_device_sm = get_device_sm()
|
||||
_is_gfx95_supported = is_gfx95_supported()
|
||||
|
||||
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
|
||||
|
||||
if _use_aiter_gfx95:
|
||||
from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
|
||||
from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
|
||||
batched_gemm_afp4wfp4_pre_quant,
|
||||
fused_flatten_mxfp4_quant,
|
||||
fused_rms_mxfp4_quant,
|
||||
)
|
||||
from sglang.srt.layers.rocm_linear_utils import (
|
||||
aiter_dsv3_router_gemm,
|
||||
fused_qk_rope_cat,
|
||||
get_dsv3_gemm_output_zero_allocator_size,
|
||||
)
|
||||
|
||||
if _is_cuda:
|
||||
from sgl_kernel import (
|
||||
@@ -224,10 +241,17 @@ class DeepseekV2MLP(nn.Module):
|
||||
forward_batch=None,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
gemm_output_zero_allocator: BumpAllocator = None,
|
||||
):
|
||||
if (self.tp_size == 1) and x.shape[0] == 0:
|
||||
return x
|
||||
|
||||
if gemm_output_zero_allocator != None and x.shape[0] <= 256:
|
||||
y = gemm_output_zero_allocator.allocate(
|
||||
x.shape[0] * self.gate_up_proj.output_size_per_partition
|
||||
).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
|
||||
x = (x, None, y)
|
||||
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(
|
||||
@@ -257,7 +281,7 @@ class MoEGate(nn.Module):
|
||||
if _is_cpu and _is_cpu_amx_available:
|
||||
self.quant_method = PackWeightMethod(weight_names=["weight"])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
|
||||
if use_intel_amx_backend(self):
|
||||
return torch.ops.sgl_kernel.weight_packed_linear(
|
||||
hidden_states,
|
||||
@@ -276,6 +300,10 @@ class MoEGate(nn.Module):
|
||||
):
|
||||
# router gemm output float32
|
||||
logits = dsv3_router_gemm(hidden_states, self.weight)
|
||||
elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
|
||||
logits = aiter_dsv3_router_gemm(
|
||||
hidden_states, self.weight, gemm_output_zero_allocator
|
||||
)
|
||||
else:
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
|
||||
@@ -439,6 +467,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
forward_batch: Optional[ForwardBatch] = None,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
gemm_output_zero_allocator: BumpAllocator = None,
|
||||
) -> torch.Tensor:
|
||||
if not self._enable_deepep_moe:
|
||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||
@@ -452,12 +481,14 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states,
|
||||
should_allreduce_fusion,
|
||||
use_reduce_scatter,
|
||||
gemm_output_zero_allocator,
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(
|
||||
hidden_states,
|
||||
should_allreduce_fusion,
|
||||
use_reduce_scatter,
|
||||
gemm_output_zero_allocator,
|
||||
)
|
||||
else:
|
||||
return self.forward_deepep(hidden_states, forward_batch)
|
||||
@@ -467,15 +498,18 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
gemm_output_zero_allocator: BumpAllocator = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
current_stream = torch.cuda.current_stream()
|
||||
self.alt_stream.wait_stream(current_stream)
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
shared_output = self._forward_shared_experts(
|
||||
hidden_states, gemm_output_zero_allocator
|
||||
)
|
||||
|
||||
with torch.cuda.stream(self.alt_stream):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||
if not _is_cuda:
|
||||
@@ -502,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
gemm_output_zero_allocator: BumpAllocator = None,
|
||||
) -> torch.Tensor:
|
||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||
self.shared_experts.gate_up_proj
|
||||
@@ -509,9 +544,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
return self.forward_cpu(hidden_states, should_allreduce_fusion)
|
||||
|
||||
if hidden_states.shape[0] > 0:
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
shared_output = self._forward_shared_experts(
|
||||
hidden_states, gemm_output_zero_allocator
|
||||
)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
|
||||
topk_output = self.topk(hidden_states, router_logits)
|
||||
else:
|
||||
shared_output = None
|
||||
@@ -631,9 +668,13 @@ class DeepseekV2MoE(nn.Module):
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
def _forward_shared_experts(self, hidden_states):
|
||||
def _forward_shared_experts(
|
||||
self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
|
||||
):
|
||||
if self.num_fused_shared_experts == 0:
|
||||
return self.shared_experts(hidden_states)
|
||||
return self.shared_experts(
|
||||
hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -1097,11 +1138,19 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
if self.attn_mha.kv_b_proj is None:
|
||||
self.attn_mha.kv_b_proj = self.kv_b_proj
|
||||
|
||||
if hidden_states.shape[0] == 0:
|
||||
assert (
|
||||
not self.o_proj.reduce_results
|
||||
), "short-circuiting allreduce will lead to hangs"
|
||||
return hidden_states, None, forward_batch, None
|
||||
# When hidden_states is a tuple of tensors, the tuple includes the quantized weight and the scale tensor.
|
||||
if isinstance(hidden_states, tuple):
|
||||
if hidden_states[0].shape[0] == 0:
|
||||
assert (
|
||||
not self.o_proj.reduce_results
|
||||
), "short-circuiting allreduce will lead to hangs"
|
||||
return hidden_states[0]
|
||||
else:
|
||||
if hidden_states.shape[0] == 0:
|
||||
assert (
|
||||
not self.o_proj.reduce_results
|
||||
), "short-circuiting allreduce will lead to hangs"
|
||||
return hidden_states, None, forward_batch, None
|
||||
|
||||
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
|
||||
|
||||
@@ -1225,7 +1274,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||
|
||||
if self.q_lora_rank is not None:
|
||||
if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
|
||||
if (
|
||||
(not isinstance(hidden_states, tuple))
|
||||
and hidden_states.shape[0] <= 16
|
||||
and self.use_min_latency_fused_a_gemm
|
||||
):
|
||||
fused_qkv_a_proj_out = dsv3_fused_a_gemm(
|
||||
hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
|
||||
)
|
||||
@@ -1245,8 +1298,18 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
k_nope = self.kv_a_layernorm(k_nope)
|
||||
current_stream.wait_stream(self.alt_stream)
|
||||
else:
|
||||
q = self.q_a_layernorm(q)
|
||||
k_nope = self.kv_a_layernorm(k_nope)
|
||||
if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
|
||||
q, k_nope = fused_rms_mxfp4_quant(
|
||||
q,
|
||||
self.q_a_layernorm.weight,
|
||||
self.q_a_layernorm.variance_epsilon,
|
||||
k_nope,
|
||||
self.kv_a_layernorm.weight,
|
||||
self.kv_a_layernorm.variance_epsilon,
|
||||
)
|
||||
else:
|
||||
q = self.q_a_layernorm(q)
|
||||
k_nope = self.kv_a_layernorm(k_nope)
|
||||
|
||||
k_nope = k_nope.unsqueeze(1)
|
||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||
@@ -1278,10 +1341,27 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope_out[:, :expected_m, :]
|
||||
elif _is_hip:
|
||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||
q_nope_out = torch.bmm(
|
||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
|
||||
x = q_nope.transpose(0, 1)
|
||||
q_nope_out = torch.empty(
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
self.w_kc.shape[2],
|
||||
device=x.device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
batched_gemm_afp4wfp4_pre_quant(
|
||||
x,
|
||||
self.w_kc.transpose(-2, -1),
|
||||
self.w_scale_k.transpose(-2, -1),
|
||||
torch.bfloat16,
|
||||
q_nope_out,
|
||||
)
|
||||
else:
|
||||
q_nope_out = torch.bmm(
|
||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||
q_nope.transpose(0, 1),
|
||||
@@ -1295,13 +1375,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
|
||||
q_nope_out = q_nope_out.transpose(0, 1)
|
||||
|
||||
if not self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
|
||||
not _use_aiter or not _is_gfx95_supported
|
||||
):
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
|
||||
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
||||
|
||||
def forward_absorb_core(
|
||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
||||
):
|
||||
if (
|
||||
self.current_attention_backend == "fa3"
|
||||
@@ -1326,8 +1408,23 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
**extra_args,
|
||||
)
|
||||
else:
|
||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
if _use_aiter_gfx95:
|
||||
cos = self.rotary_emb.cos_cache
|
||||
sin = self.rotary_emb.sin_cache
|
||||
q, k = fused_qk_rope_cat(
|
||||
q_nope_out,
|
||||
q_pe,
|
||||
k_nope,
|
||||
k_pe,
|
||||
positions,
|
||||
cos,
|
||||
sin,
|
||||
self.rotary_emb.is_neox_style,
|
||||
)
|
||||
else:
|
||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||
|
||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||
|
||||
@@ -1352,11 +1449,34 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
)
|
||||
elif _is_hip:
|
||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||
attn_bmm_output = torch.bmm(
|
||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||
if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
|
||||
x = attn_output.transpose(0, 1)
|
||||
attn_bmm_output = torch.empty(
|
||||
x.shape[0],
|
||||
x.shape[1],
|
||||
self.w_vc.shape[2],
|
||||
device=x.device,
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
batched_gemm_afp4wfp4_pre_quant(
|
||||
x,
|
||||
self.w_vc.transpose(-2, -1),
|
||||
self.w_scale_v.transpose(-2, -1),
|
||||
torch.bfloat16,
|
||||
attn_bmm_output,
|
||||
)
|
||||
else:
|
||||
attn_bmm_output = torch.bmm(
|
||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
||||
)
|
||||
|
||||
if self.o_proj.weight.dtype == torch.uint8:
|
||||
attn_bmm_output = attn_bmm_output.transpose(0, 1)
|
||||
attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
|
||||
else:
|
||||
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||
|
||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||
attn_output.transpose(0, 1),
|
||||
@@ -1866,10 +1986,21 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
zero_allocator: BumpAllocator,
|
||||
gemm_output_zero_allocator: BumpAllocator = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
quant_format = (
|
||||
"mxfp4"
|
||||
if _is_gfx95_supported
|
||||
and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
|
||||
else ""
|
||||
)
|
||||
|
||||
hidden_states, residual = self.layer_communicator.prepare_attn(
|
||||
hidden_states, residual, forward_batch
|
||||
hidden_states,
|
||||
residual,
|
||||
forward_batch,
|
||||
quant_format,
|
||||
)
|
||||
|
||||
hidden_states = self.self_attn(
|
||||
@@ -1893,8 +2024,16 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
|
||||
forward_batch
|
||||
)
|
||||
|
||||
if isinstance(self.mlp, DeepseekV2MLP):
|
||||
gemm_output_zero_allocator = None
|
||||
|
||||
hidden_states = self.mlp(
|
||||
hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
|
||||
hidden_states,
|
||||
forward_batch,
|
||||
should_allreduce_fusion,
|
||||
use_reduce_scatter,
|
||||
gemm_output_zero_allocator,
|
||||
)
|
||||
|
||||
if should_allreduce_fusion:
|
||||
@@ -2038,6 +2177,37 @@ class DeepseekV2Model(nn.Module):
|
||||
else:
|
||||
self.norm = PPMissingLayer(return_tuple=True)
|
||||
|
||||
self.gemm_output_zero_allocator_size = 0
|
||||
if (
|
||||
_use_aiter_gfx95
|
||||
and config.n_routed_experts == 256
|
||||
and self.embed_tokens.embedding_dim == 7168
|
||||
):
|
||||
num_moe_layers = sum(
|
||||
[
|
||||
1
|
||||
for i in range(len(self.layers))
|
||||
if isinstance(self.layers[i].mlp, DeepseekV2MoE)
|
||||
]
|
||||
)
|
||||
|
||||
allocate_size = 0
|
||||
for i in range(len(self.layers)):
|
||||
if isinstance(self.layers[i].mlp, DeepseekV2MoE):
|
||||
allocate_size = self.layers[
|
||||
i
|
||||
].mlp.shared_experts.gate_up_proj.output_size_per_partition
|
||||
break
|
||||
|
||||
self.gemm_output_zero_allocator_size = (
|
||||
get_dsv3_gemm_output_zero_allocator_size(
|
||||
config.n_routed_experts,
|
||||
num_moe_layers,
|
||||
allocate_size,
|
||||
self.embed_tokens.embedding_dim,
|
||||
)
|
||||
)
|
||||
|
||||
def get_input_embeddings(self) -> torch.Tensor:
|
||||
return self.embed_tokens
|
||||
|
||||
@@ -2057,6 +2227,21 @@ class DeepseekV2Model(nn.Module):
|
||||
device=device,
|
||||
)
|
||||
|
||||
has_gemm_output_zero_allocator = hasattr(
|
||||
self, "gemm_output_zero_allocator_size"
|
||||
)
|
||||
|
||||
gemm_output_zero_allocator = (
|
||||
BumpAllocator(
|
||||
buffer_size=self.gemm_output_zero_allocator_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
if has_gemm_output_zero_allocator
|
||||
and self.gemm_output_zero_allocator_size > 0
|
||||
else None
|
||||
)
|
||||
|
||||
if self.pp_group.is_first_rank:
|
||||
if input_embeds is None:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
@@ -2083,7 +2268,12 @@ class DeepseekV2Model(nn.Module):
|
||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||
positions,
|
||||
hidden_states,
|
||||
forward_batch,
|
||||
residual,
|
||||
zero_allocator,
|
||||
gemm_output_zero_allocator,
|
||||
)
|
||||
|
||||
if normal_end_layer != self.end_layer:
|
||||
@@ -2356,6 +2546,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
w_kc, w_vc = w.unflatten(
|
||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||
|
||||
if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
|
||||
w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
|
||||
quark_post_load_weights(self_attn, w, "mxfp4")
|
||||
)
|
||||
|
||||
if not use_deep_gemm_bmm:
|
||||
self_attn.w_kc = bind_or_assign(
|
||||
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||
|
||||
Reference in New Issue
Block a user