diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index 01fdf686a..18ac91464 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -52,7 +52,6 @@ if not (_is_npu or _is_hip):
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.ops.shuffle import shuffle_weight

 logger = logging.getLogger(__name__)

diff --git a/python/sglang/srt/layers/moe/rocm_moe_utils.py b/python/sglang/srt/layers/moe/rocm_moe_utils.py
new file mode 100644
index 000000000..5fe2de1e5
--- /dev/null
+++ b/python/sglang/srt/layers/moe/rocm_moe_utils.py
@@ -0,0 +1,141 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1rc2/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from enum import IntEnum
+from functools import cache
+from typing import Optional
+
+import torch
+
+from sglang.srt.utils import direct_register_custom_op, get_bool_env_var, is_hip
+
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+
+class ActivationMethod(IntEnum):
+    # This allows interfacing with AITER ActivationType enum
+    # without importing the ActivationType enum from AITER globally.
+    SILU = 0
+    GELU = 1
+
+
+def rocm_aiter_asm_moe_tkw1_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    fc1_scale: Optional[torch.Tensor] = None,
+    fc2_scale: Optional[torch.Tensor] = None,
+    fc1_smooth_scale: Optional[torch.Tensor] = None,
+    fc2_smooth_scale: Optional[torch.Tensor] = None,
+    a16: bool = False,
+    per_tensor_quant_scale: Optional[torch.Tensor] = None,
+    expert_mask: Optional[torch.Tensor] = None,
+    activation_method: int = ActivationMethod.SILU.value,
+) -> torch.Tensor:
+
+    from aiter import ActivationType
+    from aiter.fused_moe_bf16_asm import asm_moe_tkw1
+
+    activation = ActivationType(activation_method)
+
+    return asm_moe_tkw1(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        fc1_scale=fc1_scale,
+        fc2_scale=fc2_scale,
+        fc1_smooth_scale=fc1_smooth_scale,
+        fc2_smooth_scale=fc2_smooth_scale,
+        a16=a16,
+        per_tensor_quant_scale=per_tensor_quant_scale,
+        expert_mask=expert_mask,
+        activation=activation,
+    )
+
+
+def rocm_aiter_asm_moe_tkw1_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    fc1_scale: Optional[torch.Tensor] = None,
+    fc2_scale: Optional[torch.Tensor] = None,
+    fc1_smooth_scale: Optional[torch.Tensor] = None,
+    fc2_smooth_scale: Optional[torch.Tensor] = None,
+    a16: bool = False,
+    per_tensor_quant_scale: Optional[torch.Tensor] = None,
+    expert_mask: Optional[torch.Tensor] = None,
+    activation_method: int = ActivationMethod.SILU.value,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+if _use_aiter:
+
+    direct_register_custom_op(
+        op_name="rocm_aiter_asm_moe_tkw1",
+        op_func=rocm_aiter_asm_moe_tkw1_impl,
+        mutates_args=[],
+        fake_impl=rocm_aiter_asm_moe_tkw1_fake,
+    )
+
+
+def rocm_fused_experts_tkw1(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+) -> torch.Tensor:
+
+    activation_method = (
+        ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
+    )
+    # All AITER Fused MoE kernels are expecting the following datatypes
+    topk_weights = topk_weights.to(torch.float32)
+    topk_ids = topk_ids.to(torch.int32)
+
+    # w8a8 per-channel quantization
+    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
+        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
+        # This applies topk_weights on the GEMM output of the first FC layer
+        # rather than the second FC.
+        assert (
+            topk_weights.dim() == 2
+        ), "`topk_weights` should be in shape (num_tokens, topk)"
+        assert topk_weights.shape[-1] == 1, (
+            "Only support topk=1 when" " `apply_router_weight_on_input` is True"
+        )
+
+        return torch.ops.sglang.rocm_aiter_asm_moe_tkw1(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            fc1_scale=w1_scale,
+            fc2_scale=w2_scale,
+            fc1_smooth_scale=None,
+            fc2_smooth_scale=None,
+            a16=False,
+            per_tensor_quant_scale=None,
+            expert_mask=None,
+            activation_method=activation_method,
+        )
+    else:
+        assert False, "This should not be called."
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index c10515107..320a7ba87 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -19,7 +19,14 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    set_weight_attrs,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -29,6 +36,13 @@ if TYPE_CHECKING:
         CompressedTensorsConfig,
     )

+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
+    from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1

 try:
     import vllm
@@ -265,6 +279,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 max_w13_scales, requires_grad=False
             )

+        if _use_aiter:
+            with torch.no_grad():
+                # Pre-shuffle weights
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -274,20 +302,43 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
    ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts

-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
-            use_fp8_w8a8=True,
-            per_channel_quant=self.weight_quant.strategy
-            == QuantizationStrategy.CHANNEL,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-        )
+        if (
+            _use_aiter
+            and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and moe_runner_config.apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids, _ = topk_output
+            return rocm_fused_experts_tkw1(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=moe_runner_config.activation,
+                apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+        else:
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )


 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 0192da7ef..6a199c8f1 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -966,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     requires_grad=False,
                 )
                 torch.cuda.empty_cache()
+                # ROCm (_use_aiter): using column-wise scaling
                 layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
                 layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 32f0caa38..fcdaa263e 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -2228,7 +2228,10 @@ class ServerArgs:
             # use bf16 for mxfp4 triton kernels
             self.dtype = "bfloat16"
         elif "Llama4" in model_arch:
-            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+            assert self.attention_backend in {
+                "fa3",
+                "aiter",
+            }, "fa3 or aiter is required for Llama4 model"
         elif model_arch in [
             "Gemma2ForCausalLM",
             "Gemma3ForCausalLM",
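
For reference, a minimal usage sketch of the new rocm_fused_experts_tkw1 helper introduced above. This is not part of the diff: it assumes a ROCm build with AITER installed and SGLANG_USE_AITER=1, and all sizes, dtypes, and scale layouts below are illustrative placeholders chosen only to satisfy the asserts on the tkw1 path (FP8 per-channel weights, topk=1, router weight applied on the input); real tensors come from the quantized checkpoint and are pre-shuffled in process_weights_after_loading.

import torch

from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1

# Placeholder sizes; real values come from the model config.
num_tokens, hidden_size, intermediate_size, num_experts = 4, 1024, 2048, 8

hidden_states = torch.randn(
    num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda"
)

# Illustrative FP8 per-channel weights and scales; in the PR the weights are
# pre-shuffled with shuffle_weight(..., (16, 16)) before this call.
w13 = torch.randn(
    num_experts, 2 * intermediate_size, hidden_size, device="cuda"
).to(torch.float8_e4m3fnuz)
w2 = torch.randn(
    num_experts, hidden_size, intermediate_size, device="cuda"
).to(torch.float8_e4m3fnuz)
w13_scale = torch.ones(num_experts, 2 * intermediate_size, 1, device="cuda")
w2_scale = torch.ones(num_experts, hidden_size, 1, device="cuda")

# Llama4-style routing: topk=1 with the router weight applied on the input,
# which is the condition that selects the tkw1 kernel.
topk_weights = torch.ones(num_tokens, 1, dtype=torch.float32, device="cuda")
topk_ids = torch.randint(
    0, num_experts, (num_tokens, 1), dtype=torch.int32, device="cuda"
)

out = rocm_fused_experts_tkw1(
    hidden_states,
    w13,
    w2,
    topk_weights,
    topk_ids,
    activation="silu",
    apply_router_weight_on_input=True,  # required to reach the tkw1 kernel
    use_fp8_w8a8=True,
    per_channel_quant=True,
    w1_scale=w13_scale,
    w2_scale=w2_scale,
)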