[AMD] Fix Llama 4 FP8 accuracy issues on MI300X (#7699)
This commit is contained in:
@@ -52,7 +52,6 @@ if not (_is_npu or _is_hip):
|
|||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
from aiter import ActivationType, QuantType
|
from aiter import ActivationType, QuantType
|
||||||
from aiter.fused_moe import fused_moe
|
from aiter.fused_moe import fused_moe
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
141
python/sglang/srt/layers/moe/rocm_moe_utils.py
Normal file
141
python/sglang/srt/layers/moe/rocm_moe_utils.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.9.1rc2/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from enum import IntEnum
|
||||||
|
from functools import cache
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import direct_register_custom_op, get_bool_env_var, is_hip
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
|
|
||||||
|
class ActivationMethod(IntEnum):
|
||||||
|
# This allows interfacing with AITER ActivationType enum
|
||||||
|
# without importing the ActivationType enum from AITER globally.
|
||||||
|
SILU = 0
|
||||||
|
GELU = 1
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_asm_moe_tkw1_impl(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
fc1_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
a16: bool = False,
|
||||||
|
per_tensor_quant_scale: Optional[torch.Tensor] = None,
|
||||||
|
expert_mask: Optional[torch.Tensor] = None,
|
||||||
|
activation_method: int = ActivationMethod.SILU.value,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
from aiter import ActivationType
|
||||||
|
from aiter.fused_moe_bf16_asm import asm_moe_tkw1
|
||||||
|
|
||||||
|
activation = ActivationType(activation_method)
|
||||||
|
|
||||||
|
return asm_moe_tkw1(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
fc1_scale=fc1_scale,
|
||||||
|
fc2_scale=fc2_scale,
|
||||||
|
fc1_smooth_scale=fc1_smooth_scale,
|
||||||
|
fc2_smooth_scale=fc2_smooth_scale,
|
||||||
|
a16=a16,
|
||||||
|
per_tensor_quant_scale=per_tensor_quant_scale,
|
||||||
|
expert_mask=expert_mask,
|
||||||
|
activation=activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_asm_moe_tkw1_fake(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
fc1_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc1_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
fc2_smooth_scale: Optional[torch.Tensor] = None,
|
||||||
|
a16: bool = False,
|
||||||
|
per_tensor_quant_scale: Optional[torch.Tensor] = None,
|
||||||
|
expert_mask: Optional[torch.Tensor] = None,
|
||||||
|
activation_method: int = ActivationMethod.SILU.value,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_asm_moe_tkw1",
|
||||||
|
op_func=rocm_aiter_asm_moe_tkw1_impl,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=rocm_aiter_asm_moe_tkw1_fake,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_fused_experts_tkw1(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
activation: str = "silu",
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
use_fp8_w8a8: bool = False,
|
||||||
|
per_channel_quant: bool = False,
|
||||||
|
w1_scale: Optional[torch.Tensor] = None,
|
||||||
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
|
a1_scale: Optional[torch.Tensor] = None,
|
||||||
|
a2_scale: Optional[torch.Tensor] = None,
|
||||||
|
block_shape: Optional[list[int]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
activation_method = (
|
||||||
|
ActivationMethod.SILU if activation == "silu" else ActivationMethod.GELU
|
||||||
|
)
|
||||||
|
# All AITER Fused MoE kernels are expecting the following datatypes
|
||||||
|
topk_weights = topk_weights.to(torch.float32)
|
||||||
|
topk_ids = topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
# w8a8 per-channel quantization
|
||||||
|
if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
|
||||||
|
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
|
||||||
|
# This applies topk_weights on the GEMM output of the first FC layer
|
||||||
|
# rather than the second FC.
|
||||||
|
assert (
|
||||||
|
topk_weights.dim() == 2
|
||||||
|
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||||
|
assert topk_weights.shape[-1] == 1, (
|
||||||
|
"Only support topk=1 when" " `apply_router_weight_on_input` is True"
|
||||||
|
)
|
||||||
|
|
||||||
|
return torch.ops.sglang.rocm_aiter_asm_moe_tkw1(
|
||||||
|
hidden_states,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
fc1_scale=w1_scale,
|
||||||
|
fc2_scale=w2_scale,
|
||||||
|
fc1_smooth_scale=None,
|
||||||
|
fc2_smooth_scale=None,
|
||||||
|
a16=False,
|
||||||
|
per_tensor_quant_scale=None,
|
||||||
|
expert_mask=None,
|
||||||
|
activation_method=activation_method,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert False, "This should not be called."
|
||||||
@@ -19,7 +19,14 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
replace_parameter,
|
replace_parameter,
|
||||||
)
|
)
|
||||||
from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
|
from sglang.srt.utils import (
|
||||||
|
get_bool_env_var,
|
||||||
|
is_cpu,
|
||||||
|
is_cuda,
|
||||||
|
is_hip,
|
||||||
|
is_npu,
|
||||||
|
set_weight_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
@@ -29,6 +36,13 @@ if TYPE_CHECKING:
|
|||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_is_hip = is_hip()
|
||||||
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import vllm
|
import vllm
|
||||||
@@ -265,6 +279,20 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
max_w13_scales, requires_grad=False
|
max_w13_scales, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _use_aiter:
|
||||||
|
with torch.no_grad():
|
||||||
|
# Pre-shuffle weights
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w13_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(
|
||||||
|
shuffle_weight(layer.w2_weight.data, (16, 16)),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -274,20 +302,43 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
||||||
|
|
||||||
return fused_experts(
|
if (
|
||||||
x,
|
_use_aiter
|
||||||
layer.w13_weight,
|
and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
|
||||||
layer.w2_weight,
|
and moe_runner_config.apply_router_weight_on_input
|
||||||
topk_output=topk_output,
|
):
|
||||||
moe_runner_config=moe_runner_config,
|
topk_weights, topk_ids, _ = topk_output
|
||||||
use_fp8_w8a8=True,
|
return rocm_fused_experts_tkw1(
|
||||||
per_channel_quant=self.weight_quant.strategy
|
hidden_states=x,
|
||||||
== QuantizationStrategy.CHANNEL,
|
w1=layer.w13_weight,
|
||||||
w1_scale=layer.w13_weight_scale,
|
w2=layer.w2_weight,
|
||||||
w2_scale=layer.w2_weight_scale,
|
topk_weights=topk_weights,
|
||||||
a1_scale=layer.w13_input_scale,
|
topk_ids=topk_ids,
|
||||||
a2_scale=layer.w2_input_scale,
|
activation=moe_runner_config.activation,
|
||||||
)
|
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
per_channel_quant=self.weight_quant.strategy
|
||||||
|
== QuantizationStrategy.CHANNEL,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return fused_experts(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
topk_output=topk_output,
|
||||||
|
moe_runner_config=moe_runner_config,
|
||||||
|
use_fp8_w8a8=True,
|
||||||
|
per_channel_quant=self.weight_quant.strategy
|
||||||
|
== QuantizationStrategy.CHANNEL,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
a1_scale=layer.w13_input_scale,
|
||||||
|
a2_scale=layer.w2_input_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|||||||
@@ -966,6 +966,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
# ROCm (_use_aiter): using column-wise scaling
|
# ROCm (_use_aiter): using column-wise scaling
|
||||||
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
|
||||||
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
|
||||||
|
|||||||
@@ -2228,7 +2228,10 @@ class ServerArgs:
|
|||||||
# use bf16 for mxfp4 triton kernels
|
# use bf16 for mxfp4 triton kernels
|
||||||
self.dtype = "bfloat16"
|
self.dtype = "bfloat16"
|
||||||
elif "Llama4" in model_arch:
|
elif "Llama4" in model_arch:
|
||||||
assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
|
assert self.attention_backend in {
|
||||||
|
"fa3",
|
||||||
|
"aiter",
|
||||||
|
}, "fa3 or aiter is required for Llama4 model"
|
||||||
elif model_arch in [
|
elif model_arch in [
|
||||||
"Gemma2ForCausalLM",
|
"Gemma2ForCausalLM",
|
||||||
"Gemma3ForCausalLM",
|
"Gemma3ForCausalLM",
|
||||||
|
|||||||
Reference in New Issue
Block a user