From e96973742c326a129da772a115bdeb925643d95a Mon Sep 17 00:00:00 2001
From: kk <43161300+kkHuang-amd@users.noreply.github.com>
Date: Fri, 5 Sep 2025 06:11:22 +0800
Subject: [PATCH] Optimized deepseek-v3/r1 model performance on mxfp4 run (#10008)

Co-authored-by: wunhuang
Co-authored-by: HAI
Co-authored-by: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com>
---
 python/sglang/srt/layers/communicator.py      |  40 ++-
 .../quark/schemes/quark_w4a4_mxfp4.py         |  73 +++--
 .../srt/layers/quantization/quark/utils.py    |  97 +++++++
 .../layers/quantization/rocm_mxfp4_utils.py   |  13 +
 python/sglang/srt/layers/rocm_linear_utils.py |  44 +++
 python/sglang/srt/models/deepseek_v2.py       | 260 +++++++++++++++---
 python/sglang/srt/models/glm4_moe.py          |  11 +-
 python/sglang/srt/utils.py                    |  12 +
 8 files changed, 486 insertions(+), 64 deletions(-)
 create mode 100644 python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py
 create mode 100644 python/sglang/srt/layers/rocm_linear_utils.py

diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 320e87962..fba8d8f18 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -43,8 +43,11 @@ from sglang.srt.layers.moe import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import (
+    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
+    is_hip,
     is_sm90_supported,
     is_sm100_supported,
 )
@@ -52,6 +55,11 @@
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+_is_gfx95_supported = is_gfx95_supported()
+
+if _use_aiter and _is_gfx95_supported:
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

@@ -207,6 +215,7 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        quant_format: str = "",
     ):
         if hidden_states.shape[0] == 0:
             residual = hidden_states
@@ -224,11 +233,34 @@
         else:
             if residual is None:
                 residual = hidden_states
-                hidden_states = self.input_layernorm(hidden_states)
+
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
+                    hidden_states = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        None,
+                    )
+                else:
+                    hidden_states = self.input_layernorm(hidden_states)
             else:
-                hidden_states, residual = self.input_layernorm(
-                    hidden_states, residual
-                )
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in quant_format):
+                    hidden_states, residual = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        residual,
+                    )
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )

         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
diff --git a/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py b/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
index e5fc22797..a0787baaf 100644
--- a/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+++ b/python/sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
@@ 
-8,6 +8,7 @@ import torch.nn.functional as F from aiter.ops.gemm_op_a4w4 import gemm_a4w4 from aiter.ops.shuffle import shuffle_weight from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4 +from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant from aiter.ops.triton.quant import dynamic_mxfp4_quant from aiter.utility import dtypes from aiter.utility.fp4_utils import e8m0_shuffle @@ -38,15 +39,6 @@ class QuarkW4A4MXFP4(QuarkScheme): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: return - # for aiter implement - # wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16)) - # w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0) - - # layer.weight = torch.nn.Parameter(wshuffle, - # requires_grad=False) - # layer.weight_scale = torch.nn.Parameter(w_scales_shuffle, - # requires_grad=False) - def create_weights( self, layer: torch.nn.Module, @@ -93,26 +85,53 @@ class QuarkW4A4MXFP4(QuarkScheme): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + # This path does not have support for bias currently + assert bias is None, "bias is not supported" - out_dtype = x.dtype - # M = x.shape[0] - # N = layer.weight.shape[0] + three_d = False + x_s = None + y = None + if isinstance(x, tuple): + assert len(x) in [ + 2, + 3, + ], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted" + if len(x) == 2: + x, x_s = x + elif len(x) == 3: + x, x_s, y = x - # quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32) - # x, x_scales_shuffle = quant_func(x, shuffle=True) - - # y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype) - - # out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias) - - # return out[:M] - - # triton implement - x_q, x_s = dynamic_mxfp4_quant(x) - y = torch.empty( - x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype + use_fused_quant_gemm = ( + x_s is None and y is not None and layer.weight.shape[0] == y.shape[1] ) - out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y) + if x.dim() == 3: + three_d = True + x = x.view(-1, x.shape[-1]) + output_shape = [*x.shape[:-1], layer.weight.shape[0]] - return out + # use_fused_quant_gemm = true, x_q is a bf16/fp16 num + # x_s is not None = true, x_q is uint8 num + if use_fused_quant_gemm or x_s is not None: + x_q = x + else: + x_q, x_s = dynamic_mxfp4_quant(x) + + if y is None: + y = torch.empty( + x_q.shape[0], + layer.weight.shape[0], + device=x_q.device, + dtype=self.out_dtype, + ) + + if use_fused_quant_gemm: + gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y) + y = y.to(x.dtype) + else: + gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y) + + if three_d: + return y.view(*output_shape) + + return y diff --git a/python/sglang/srt/layers/quantization/quark/utils.py b/python/sglang/srt/layers/quantization/quark/utils.py index 5ea91b5d8..eacbf3ba9 100644 --- a/python/sglang/srt/layers/quantization/quark/utils.py +++ b/python/sglang/srt/layers/quantization/quark/utils.py @@ -5,6 +5,10 @@ from collections.abc import Iterable, Mapping from types import MappingProxyType from typing import Any, Optional +import torch +from aiter.ops.triton.quant import dynamic_mxfp4_quant +from torch import nn + def deep_compare(dict1: Any, dict2: Any) -> bool: if type(dict1) is not type(dict2): @@ -105,3 +109,96 @@ def _is_equal_or_regex_match( elif target == value: return True return 
False
+
+
+# Utility for tensors with more than 2 dims (batched per-row quantization).
+def b_dynamic_mxfp4_quant(x):
+    h, b, d = x.shape
+    x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
+    return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
+
+
+def mxfp4_to_f32(x, is_threed):
+    # Two fp4 values are packed into each uint8, so expand the last dim by 2
+    # and unpack the low/high nibbles.
+    x = x.repeat_interleave(2, dim=-1)
+    if is_threed:
+        x[..., ::2] = x[..., ::2] & 0xF
+        x[..., 1::2] = x[..., 1::2] >> 4
+    else:
+        x[:, ::2] = x[:, ::2] & 0xF
+        x[:, 1::2] = x[:, 1::2] >> 4
+
+    mxfp4_list = [
+        0.0,
+        0.5,
+        1.0,
+        1.5,
+        2.0,
+        3.0,
+        4.0,
+        6.0,
+        -0.0,
+        -0.5,
+        -1.0,
+        -1.5,
+        -2.0,
+        -3.0,
+        -4.0,
+        -6.0,
+    ]
+    mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
+    return mxfp4_in_f32[x.long()]
+
+
+def e8m0_to_f32(x):
+    # Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
+    # e8m0 is a custom 8-bit floating point format with 8 exponent bits and no mantissa,
+    # so a value is simply 2^(exponent - 127), using the same bias as IEEE-754 float32.
+
+    # Convert x to float32 for the computation and apply the bias of 127.
+    x_f32 = 2 ** ((x.to(torch.float32)) - 127)
+
+    # An exponent of 255 (i.e. 2^128) is the special encoding usually reserved for NaN/Inf.
+    # 2^128 overflows float32 to inf, and since this format has no mantissa, map it to NaN.
+    x_f32[x_f32 == float("inf")] = float("nan")
+    return x_f32
+
+
+def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
+    if "mxfp4" in quant_format:
+        # When the dtype is bf16, dynamically quantize the bf16 tensor to uint8:
+        # quantize w_kc (bf16) first to obtain w_kc (uint8) and w_s_kc (uint8),
+        # then repeat the same procedure for w_vc to obtain w_vc (uint8) and w_s_vc (uint8).
+        if w.dtype == torch.bfloat16:
+            w_kc, w_vc = w.unflatten(
+                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+        elif w.dtype == torch.uint8:  # static quant for mxfp4
+            # When the dtype is uint8, w has already been quantized to mxfp4, but it still
+            # has to be split into w_kc and w_vc. The quantized tensor is only half the
+            # original size and each scale covers 32 elements, so transposing it directly
+            # would not be correct. Upcast it to fp32 first, split it into w_kc and w_vc so
+            # the following transposes are correct, and then re-quantize to mxfp4.
+            w = mxfp4_to_f32(w, True).to(torch.bfloat16)
+            w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
+            w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
+            w = w * w_scales
+            w_kc, w_vc = w.unflatten(
+                0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
+            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+            w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
+            w_kc = w_kc.transpose(-2, -1)
+            w_s_kc = w_s_kc.transpose(-2, -1)
+            w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
+            w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
+            w_s_vc = w_s_vc.contiguous().transpose(1, 2)
+
+    return w_kc, w_s_kc, w_vc, w_s_vc
diff --git a/python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py b/python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py
new file mode 100644
index 000000000..4659f76bd
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/rocm_mxfp4_utils.py
@@ -0,0 +1,13 @@
+from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
+    batched_gemm_afp4wfp4_pre_quant,
+)
+from aiter.ops.triton.fused_mxfp4_quant import (
+    fused_flatten_mxfp4_quant,
+    fused_rms_mxfp4_quant,
+)
+
+__all__ = [
+    "fused_rms_mxfp4_quant",
+    "fused_flatten_mxfp4_quant",
+    "batched_gemm_afp4wfp4_pre_quant",
+]
diff --git a/python/sglang/srt/layers/rocm_linear_utils.py b/python/sglang/srt/layers/rocm_linear_utils.py
new file mode 100644
index 000000000..ee7dd1f59
--- /dev/null
+++ b/python/sglang/srt/layers/rocm_linear_utils.py
@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with the zero allocator.
+        if gemm_output_zero_allocator is not None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 147925f88..a2296b569 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -112,6 +112,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
@@ -129,6 +130,22 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )

 if _is_cuda:
     from sgl_kernel import (
@@ -224,10 +241,17 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

+        if gemm_output_zero_allocator is not None and x.shape[0] <= 256:
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -257,7 +281,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -276,6 +300,10 @@ class MoEGate(nn.Module):
         ):
             # router gemm output float32
             logits = dsv3_router_gemm(hidden_states, self.weight)
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)

@@ -439,6 +467,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -452,12 +481,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return 
self.forward_normal( hidden_states, should_allreduce_fusion, use_reduce_scatter, + gemm_output_zero_allocator, ) else: return self.forward_deepep(hidden_states, forward_batch) @@ -467,15 +498,18 @@ class DeepseekV2MoE(nn.Module): hidden_states: torch.Tensor, should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, + gemm_output_zero_allocator: BumpAllocator = None, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() self.alt_stream.wait_stream(current_stream) - shared_output = self._forward_shared_experts(hidden_states) + shared_output = self._forward_shared_experts( + hidden_states, gemm_output_zero_allocator + ) with torch.cuda.stream(self.alt_stream): # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) + router_logits = self.gate(hidden_states, gemm_output_zero_allocator) topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(hidden_states, topk_output) if not _is_cuda: @@ -502,6 +536,7 @@ class DeepseekV2MoE(nn.Module): hidden_states: torch.Tensor, should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, + gemm_output_zero_allocator: BumpAllocator = None, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -509,9 +544,11 @@ class DeepseekV2MoE(nn.Module): return self.forward_cpu(hidden_states, should_allreduce_fusion) if hidden_states.shape[0] > 0: - shared_output = self._forward_shared_experts(hidden_states) + shared_output = self._forward_shared_experts( + hidden_states, gemm_output_zero_allocator + ) # router_logits: (num_tokens, n_experts) - router_logits = self.gate(hidden_states) + router_logits = self.gate(hidden_states, gemm_output_zero_allocator) topk_output = self.topk(hidden_states, router_logits) else: shared_output = None @@ -631,9 +668,13 @@ class DeepseekV2MoE(nn.Module): return final_hidden_states - def _forward_shared_experts(self, hidden_states): + def _forward_shared_experts( + self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None + ): if self.num_fused_shared_experts == 0: - return self.shared_experts(hidden_states) + return self.shared_experts( + hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator + ) else: return None @@ -1097,11 +1138,19 @@ class DeepseekV2AttentionMLA(nn.Module): if self.attn_mha.kv_b_proj is None: self.attn_mha.kv_b_proj = self.kv_b_proj - if hidden_states.shape[0] == 0: - assert ( - not self.o_proj.reduce_results - ), "short-circuiting allreduce will lead to hangs" - return hidden_states, None, forward_batch, None + # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor + if isinstance(hidden_states, tuple): + if hidden_states[0].shape[0] == 0: + assert ( + not self.o_proj.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return hidden_states[0] + else: + if hidden_states.shape[0] == 0: + assert ( + not self.o_proj.reduce_results + ), "short-circuiting allreduce will lead to hangs" + return hidden_states, None, forward_batch, None attn_forward_method = self.dispatch_attn_forward_method(forward_batch) @@ -1225,7 +1274,11 @@ class DeepseekV2AttentionMLA(nn.Module): from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode if self.q_lora_rank is not None: - if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm: + if ( + (not isinstance(hidden_states, tuple)) + and hidden_states.shape[0] <= 16 + and self.use_min_latency_fused_a_gemm + ): 
fused_qkv_a_proj_out = dsv3_fused_a_gemm( hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T ) @@ -1245,8 +1298,18 @@ class DeepseekV2AttentionMLA(nn.Module): k_nope = self.kv_a_layernorm(k_nope) current_stream.wait_stream(self.alt_stream) else: - q = self.q_a_layernorm(q) - k_nope = self.kv_a_layernorm(k_nope) + if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8: + q, k_nope = fused_rms_mxfp4_quant( + q, + self.q_a_layernorm.weight, + self.q_a_layernorm.variance_epsilon, + k_nope, + self.kv_a_layernorm.weight, + self.kv_a_layernorm.variance_epsilon, + ) + else: + q = self.q_a_layernorm(q) + k_nope = self.kv_a_layernorm(k_nope) k_nope = k_nope.unsqueeze(1) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) @@ -1278,10 +1341,27 @@ class DeepseekV2AttentionMLA(nn.Module): q_nope_out = q_nope_out[:, :expected_m, :] elif _is_hip: # TODO(haishaw): add bmm_fp8 to ROCm - q_nope_out = torch.bmm( - q_nope.to(torch.bfloat16).transpose(0, 1), - self.w_kc.to(torch.bfloat16) * self.w_scale, - ) + if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8: + x = q_nope.transpose(0, 1) + q_nope_out = torch.empty( + x.shape[0], + x.shape[1], + self.w_kc.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + self.w_kc.transpose(-2, -1), + self.w_scale_k.transpose(-2, -1), + torch.bfloat16, + q_nope_out, + ) + else: + q_nope_out = torch.bmm( + q_nope.to(torch.bfloat16).transpose(0, 1), + self.w_kc.to(torch.bfloat16) * self.w_scale, + ) elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8( q_nope.transpose(0, 1), @@ -1295,13 +1375,15 @@ class DeepseekV2AttentionMLA(nn.Module): q_nope_out = q_nope_out.transpose(0, 1) - if not self._fuse_rope_for_trtllm_mla(forward_batch): + if not self._fuse_rope_for_trtllm_mla(forward_batch) and ( + not _use_aiter or not _is_gfx95_supported + ): q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) - return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator + return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions def forward_absorb_core( - self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator + self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions ): if ( self.current_attention_backend == "fa3" @@ -1326,8 +1408,23 @@ class DeepseekV2AttentionMLA(nn.Module): **extra_args, ) else: - q = torch.cat([q_nope_out, q_pe], dim=-1) - k = torch.cat([k_nope, k_pe], dim=-1) + if _use_aiter_gfx95: + cos = self.rotary_emb.cos_cache + sin = self.rotary_emb.sin_cache + q, k = fused_qk_rope_cat( + q_nope_out, + q_pe, + k_nope, + k_pe, + positions, + cos, + sin, + self.rotary_emb.is_neox_style, + ) + else: + q = torch.cat([q_nope_out, q_pe], dim=-1) + k = torch.cat([k_nope, k_pe], dim=-1) + attn_output = self.attn_mqa(q, k, k_nope, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) @@ -1352,11 +1449,34 @@ class DeepseekV2AttentionMLA(nn.Module): ) elif _is_hip: # TODO(haishaw): add bmm_fp8 to ROCm - attn_bmm_output = torch.bmm( - attn_output.to(torch.bfloat16).transpose(0, 1), - self.w_vc.to(torch.bfloat16) * self.w_scale, - ) - attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) + if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8: + x = attn_output.transpose(0, 1) + attn_bmm_output = torch.empty( + x.shape[0], + x.shape[1], + self.w_vc.shape[2], + device=x.device, + dtype=torch.bfloat16, + ) + batched_gemm_afp4wfp4_pre_quant( + x, + 
self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1866,10 +1986,21 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )

         hidden_states = self.self_attn(
@@ -1893,8 +2024,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )

         if should_allreduce_fusion:
@@ -2038,6 +2177,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)

+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens

@@ -2057,6 +2227,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )

+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2083,7 +2268,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual, zero_allocator
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )

         if normal_end_layer != self.end_layer:
@@ -2356,6 +2546,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w_kc, w_vc = w.unflatten(
                     0, 
(-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + + if _use_aiter_gfx95 and self.quant_config.get_name() == "quark": + w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = ( + quark_post_load_weights(self_attn, w, "mxfp4") + ) + if not use_deep_gemm_bmm: self_attn.w_kc = bind_or_assign( self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index ab118ad9c..5ae5b0af6 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -153,7 +153,13 @@ class Glm4MoeMLP(nn.Module): ) self.act_fn = SiluAndMul() - def forward(self, x, forward_batch=None, should_allreduce_fusion=False): + def forward( + self, + x, + forward_batch=None, + should_allreduce_fusion=False, + gemm_output_zero_allocator: BumpAllocator = None, + ): if (self.tp_size == 1) and x.shape[0] == 0: return x @@ -501,6 +507,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): hidden_states: torch.Tensor, should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, + gemm_output_zero_allocator: BumpAllocator = None, ) -> torch.Tensor: current_stream = torch.cuda.current_stream() @@ -543,6 +550,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): hidden_states: torch.Tensor, should_allreduce_fusion: bool = False, use_reduce_scatter: bool = False, + gemm_output_zero_allocator: BumpAllocator = None, ) -> torch.Tensor: if hasattr(self, "shared_experts") and use_intel_amx_backend( self.shared_experts.gate_up_proj @@ -666,6 +674,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer): forward_batch: ForwardBatch, residual: Optional[torch.Tensor], zero_allocator: BumpAllocator, + gemm_output_zero_allocator: BumpAllocator = None, ) -> torch.Tensor: hidden_states, residual = self.layer_communicator.prepare_attn( hidden_states, residual, forward_batch diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 6d720df14..cb40266ec 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2900,6 +2900,18 @@ def mxfp_supported(): return False +@lru_cache(maxsize=1) +def is_gfx95_supported(): + """ + Returns whether the current platform supports MX types. + """ + if torch.version.hip: + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + return any(gfx in gcn_arch for gfx in ["gfx95"]) + else: + return False + + # LoRA-related constants and utilities SUPPORTED_LORA_TARGET_MODULES = [ "q_proj",
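
For reference, the dequantization helpers introduced in python/sglang/srt/layers/quantization/quark/utils.py can be exercised on their own. The snippet below is a minimal sketch, not part of the patch: it assumes a ROCm build with aiter installed and a gfx95 GPU, and it assumes dynamic_mxfp4_quant returns packed fp4 values (two per uint8) together with one e8m0 scale per 32 elements, which is the contract quark_post_load_weights relies on.

import torch
from aiter.ops.triton.quant import dynamic_mxfp4_quant

from sglang.srt.layers.quantization.quark.utils import e8m0_to_f32, mxfp4_to_f32

# A bf16 weight-like matrix on the GPU; the last dim is a multiple of the 32-element scale groups.
w = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Quantize: packed fp4 values (uint8) plus e8m0 group scales.
w_q, w_s = dynamic_mxfp4_quant(w)

# Dequantize with the helpers added by this patch, mirroring the static-quant
# path of quark_post_load_weights (value-table lookup times expanded scales).
w_dq = mxfp4_to_f32(w_q, False) * e8m0_to_f32(w_s.repeat_interleave(32, dim=-1))

# The reconstruction should match the original up to mxfp4 quantization error.
print((w_dq.to(torch.bfloat16) - w).abs().max())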