[AMD] Fix Llama 4 FP8 accuracy issues on MI300X (#7699)
@@ -52,7 +52,6 @@ if not (_is_npu or _is_hip):
if _use_aiter:
    from aiter import ActivationType, QuantType
    from aiter.fused_moe import fused_moe
    from aiter.ops.shuffle import shuffle_weight

logger = logging.getLogger(__name__)
python/sglang/srt/layers/moe/rocm_moe_utils.py (new file, 141 lines)
@@ -0,0 +1,141 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1rc2/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import IntEnum
from functools import cache
from typing import Optional

import torch

from sglang.srt.utils import direct_register_custom_op, get_bool_env_var, is_hip

_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip


class ActivationMethod(IntEnum):
    # This allows interfacing with AITER ActivationType enum
    # without importing the ActivationType enum from AITER globally.
    SILU = 0
    GELU = 1

def rocm_aiter_asm_moe_tkw1_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    fc1_smooth_scale: Optional[torch.Tensor] = None,
    fc2_smooth_scale: Optional[torch.Tensor] = None,
    a16: bool = False,
    per_tensor_quant_scale: Optional[torch.Tensor] = None,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:

    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import asm_moe_tkw1

    activation = ActivationType(activation_method)

    return asm_moe_tkw1(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        fc1_scale=fc1_scale,
        fc2_scale=fc2_scale,
        fc1_smooth_scale=fc1_smooth_scale,
        fc2_smooth_scale=fc2_smooth_scale,
        a16=a16,
        per_tensor_quant_scale=per_tensor_quant_scale,
        expert_mask=expert_mask,
        activation=activation,
    )

def rocm_aiter_asm_moe_tkw1_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    fc1_scale: Optional[torch.Tensor] = None,
    fc2_scale: Optional[torch.Tensor] = None,
    fc1_smooth_scale: Optional[torch.Tensor] = None,
    fc2_smooth_scale: Optional[torch.Tensor] = None,
    a16: bool = False,
    per_tensor_quant_scale: Optional[torch.Tensor] = None,
    expert_mask: Optional[torch.Tensor] = None,
    activation_method: int = ActivationMethod.SILU.value,
) -> torch.Tensor:
    return torch.empty_like(hidden_states)


if _use_aiter:

    direct_register_custom_op(
        op_name="rocm_aiter_asm_moe_tkw1",
        op_func=rocm_aiter_asm_moe_tkw1_impl,
        mutates_args=[],
        fake_impl=rocm_aiter_asm_moe_tkw1_fake,
    )

def rocm_fused_experts_tkw1(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_fp8_w8a8: bool = False,
    per_channel_quant: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:

    activation_method = (
        ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
    )
    # All AITER Fused MoE kernels are expecting the following datatypes
    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)

    # w8a8 per-channel quantization
    if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
        # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
        # This applies topk_weights on the GEMM output of the first FC layer
        # rather than the second FC.
        assert (
            topk_weights.dim() == 2
        ), "`topk_weights` should be in shape (num_tokens, topk)"
        assert topk_weights.shape[-1] == 1, (
            "Only support topk=1 when" " `apply_router_weight_on_input` is True"
        )

        return torch.ops.sglang.rocm_aiter_asm_moe_tkw1(
            hidden_states,
            w1,
            w2,
            topk_weights,
            topk_ids,
            fc1_scale=w1_scale,
            fc2_scale=w2_scale,
            fc1_smooth_scale=None,
            fc2_smooth_scale=None,
            a16=False,
            per_tensor_quant_scale=None,
            expert_mask=None,
            activation_method=activation_method,
        )
    else:
        assert False, "This should not be called."
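Taken together, rocm_moe_utils.py registers the AITER tkw1 assembly kernel as a torch custom op (with a shape-only fake implementation so the op can be traced or compiled without launching the ROCm kernel) and exposes `rocm_fused_experts_tkw1`, which only handles the w8a8 per-channel case with `apply_router_weight_on_input` and asserts otherwise. Below is a minimal sketch of the shape and dtype preconditions that path enforces; the tensor names and sizes are hypothetical and chosen only for illustration:

import torch

# Hypothetical routing tensors for 8 tokens with topk=1 -- the only
# configuration the tkw1 path accepts when apply_router_weight_on_input=True.
num_tokens, num_experts = 8, 16
topk_weights = torch.rand(num_tokens, 1, dtype=torch.float32)
topk_ids = torch.randint(0, num_experts, (num_tokens, 1), dtype=torch.int32)

# Mirrors the guards in rocm_fused_experts_tkw1 before it calls
# torch.ops.sglang.rocm_aiter_asm_moe_tkw1.
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
assert topk_weights.shape[-1] == 1, "only topk=1 is supported on this path"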
@@ -19,7 +19,14 @@ from sglang.srt.layers.quantization.utils import (
    per_tensor_dequantize,
    replace_parameter,
)
-from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    set_weight_attrs,
+)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -29,6 +36,13 @@ if TYPE_CHECKING:
        CompressedTensorsConfig,
    )

+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
+    from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1

try:
    import vllm
@@ -265,6 +279,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                max_w13_scales, requires_grad=False
            )
+
+        if _use_aiter:
+            with torch.no_grad():
+                # Pre-shuffle weights
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()

    def apply(
        self,
        layer: torch.nn.Module,
@@ -274,20 +302,43 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton import fused_experts

-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
-            use_fp8_w8a8=True,
-            per_channel_quant=self.weight_quant.strategy
-            == QuantizationStrategy.CHANNEL,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-        )
+        if (
+            _use_aiter
+            and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and moe_runner_config.apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids, _ = topk_output
+            return rocm_fused_experts_tkw1(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=moe_runner_config.activation,
+                apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+        else:
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )


class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -966,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                requires_grad=False,
            )
            torch.cuda.empty_cache()

+            # ROCm (_use_aiter): using column-wise scaling
            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
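The added comment documents the two lines that follow: on ROCm with AITER, the per-tensor FP8 weight scale is folded into the column-wise scale (`*_weight_scale1`), leaving a single column-wise factor for the kernel to apply. A minimal numeric sketch of that folding, with hypothetical per-expert shapes standing in for the real `layer` tensors:

import torch

# Hypothetical per-expert scales: one scalar scale per expert plus one
# column-wise scale per expert row, as assumed for the AITER path.
num_experts, out_features = 4, 8
weight_scale = torch.rand(num_experts)                 # per-tensor scale
weight_scale1 = torch.rand(num_experts, out_features)  # column-wise scale

# Folding mirrors `layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)`.
weight_scale1 *= weight_scale.unsqueeze(-1)
print(weight_scale1.shape)  # torch.Size([4, 8])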
@@ -2228,7 +2228,10 @@ class ServerArgs:
            # use bf16 for mxfp4 triton kernels
            self.dtype = "bfloat16"
        elif "Llama4" in model_arch:
-            assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
+            assert self.attention_backend in {
+                "fa3",
+                "aiter",
+            }, "fa3 or aiter is required for Llama4 model"
        elif model_arch in [
            "Gemma2ForCausalLM",
            "Gemma3ForCausalLM",