fix amd EP MoE FP8 issue (#7125)
This commit is contained in:
@@ -33,10 +33,12 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||||
|
is_fp8_fnuz,
|
||||||
scaled_fp8_quant,
|
scaled_fp8_quant,
|
||||||
sglang_per_token_group_quant_fp8,
|
sglang_per_token_group_quant_fp8,
|
||||||
sglang_per_token_quant_fp8,
|
sglang_per_token_quant_fp8,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
@@ -50,6 +52,7 @@ from sglang.srt.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_fp8_fnuz = is_fp8_fnuz()
|
||||||
|
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
from vllm._custom_ops import scaled_fp8_quant
|
||||||
@@ -843,6 +846,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
|||||||
torch.max(layer.w13_weight_scale, dim=1).values,
|
torch.max(layer.w13_weight_scale, dim=1).values,
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
|
if self.block_quant:
|
||||||
|
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||||
|
if _is_fp8_fnuz:
|
||||||
|
# activation_scheme: dynamic
|
||||||
|
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=layer.w13_weight,
|
||||||
|
weight_scale=layer.w13_weight_scale_inv,
|
||||||
|
input_scale=None,
|
||||||
|
)
|
||||||
|
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight=layer.w2_weight,
|
||||||
|
weight_scale=layer.w2_weight_scale_inv,
|
||||||
|
input_scale=None,
|
||||||
|
)
|
||||||
|
# Reset the parameter
|
||||||
|
layer.w13_weight = torch.nn.Parameter(
|
||||||
|
w13_weight, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
||||||
|
w13_weight_scale, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w13_input_scale = None
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||||
|
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
||||||
|
w2_weight_scale, requires_grad=False
|
||||||
|
)
|
||||||
|
layer.w2_input_scale = None
|
||||||
return
|
return
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
|
|||||||
Reference in New Issue
Block a user