Support a global scale in addition to per-expert scales for CuteDSL MoE (#10270)
@@ -39,7 +39,7 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
@@ -74,6 +74,10 @@ except ImportError:
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
 
+CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
+)
+
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
 
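For reference, a minimal sketch of how this flag behaves; the body of get_bool_env_var is assumed here from its call signature, and the real helper in sglang.srt.utils may differ in detail:

import os

# Assumed behavior of sglang.srt.utils.get_bool_env_var: parse an
# environment variable as a boolean, falling back to the given default.
def get_bool_env_var(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("true", "1")

# Defaults to True: CuteDSL MoE uses one scalar input scale for all experts
# unless the user exports SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE=false.
CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
    "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
)

Setting SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE=false before launching restores the per-expert reduction shown in the hunk below.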
@@ -1190,7 +1194,19 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             w13_input_scale = layer.w13_input_scale.max().to(torch.float32)
             w2_input_scale = layer.w2_input_scale.max().to(torch.float32)
         elif self.enable_flashinfer_cutedsl_moe:
-            w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
+            # Using one input scale for all experts is mathematically different from
+            # the default per-expert input scale, so a flag allows switching for testing.
+            if CUTEDSL_MOE_SCALAR_INPUT_SCALE:
+                w13_input_scale = (
+                    layer.w13_input_scale.max()
+                    .to(torch.float32)
+                    .repeat(layer.w13_input_scale.shape[0])
+                )
+            else:
+                w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(
+                    torch.float32
+                )
 
             w2_input_scale = layer.w2_input_scale
 
             def _slice_scale(w):
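To make the difference between the two reductions concrete, here is a standalone sketch; the tensor shape and values are illustrative only, not taken from a real checkpoint:

import torch

# Illustrative per-expert input scales, e.g. shape (num_experts, num_shards);
# real values come from the quantized checkpoint.
w13_input_scale = torch.tensor([[0.5, 0.7],
                                [1.1, 0.9],
                                [0.2, 0.3]])

# Flag on: one global maximum, repeated so every expert shares the same scale.
global_scale = (
    w13_input_scale.max()
    .to(torch.float32)
    .repeat(w13_input_scale.shape[0])
)
# -> tensor([1.1000, 1.1000, 1.1000])

# Flag off: each expert keeps its own maximum.
per_expert_scale = w13_input_scale.max(dim=1).values.to(torch.float32)
# -> tensor([0.7000, 1.1000, 0.3000])

Experts whose activations sit well below the global maximum can lose quantization resolution under the shared scale, which is why the per-expert path stays available behind the flag for comparison.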