fix amd EP MoE FP8 issue (#7125)
This commit is contained in:
@@ -33,10 +33,12 @@ from sglang.srt.layers.quantization.base_config import (
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||
from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
is_fp8_fnuz,
|
||||
scaled_fp8_quant,
|
||||
sglang_per_token_group_quant_fp8,
|
||||
sglang_per_token_quant_fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.managers.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
@@ -50,6 +52,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
_is_hip = is_hip()
|
||||
_is_fp8_fnuz = is_fp8_fnuz()
|
||||
|
||||
if _is_hip:
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
@@ -843,6 +846,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
torch.max(layer.w13_weight_scale, dim=1).values,
|
||||
requires_grad=False,
|
||||
)
|
||||
if self.block_quant:
|
||||
# If ROCm, normalize the weights and scales to e4m3fnuz
|
||||
if _is_fp8_fnuz:
|
||||
# activation_scheme: dynamic
|
||||
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.w13_weight,
|
||||
weight_scale=layer.w13_weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=layer.w2_weight,
|
||||
weight_scale=layer.w2_weight_scale_inv,
|
||||
input_scale=None,
|
||||
)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(
|
||||
w13_weight, requires_grad=False
|
||||
)
|
||||
layer.w13_weight_scale_inv = torch.nn.Parameter(
|
||||
w13_weight_scale, requires_grad=False
|
||||
)
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
layer.w2_weight_scale_inv = torch.nn.Parameter(
|
||||
w2_weight_scale, requires_grad=False
|
||||
)
|
||||
layer.w2_input_scale = None
|
||||
return
|
||||
|
||||
def apply(
|
||||
|
||||
Reference in New Issue
Block a user