Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -12,8 +12,7 @@ from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
    init_fp8_linear_kernel,
)
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEQuantConfig,
@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
    Fp8MoeBackend,
    convert_to_fp8_moe_kernel_format,
    make_fp8_moe_kernel,
    make_fp8_moe_quant_config,
    select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
    NvFp4MoeBackend,
    convert_to_nvfp4_moe_kernel_format,
    is_global_sf_supported_for_nvfp4_backend,
    make_nvfp4_moe_kernel,
@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
    flashinfer_trtllm_fp4_moe,
    flashinfer_trtllm_fp4_routed_moe,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    apply_fi_trtllm_fp8_per_tensor_moe,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
    process_fp8_input_tensor_strategy_moe,
@@ -114,6 +104,8 @@ QUANT_ALGOS = [
    "NVFP4",
    # MXFP8
    "MXFP8",
    # MIXED_PRECISION,
    "MIXED_PRECISION",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]

@@ -181,7 +173,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        # handle kv-cache first so we can focus only on weight quantization thereafter
        if isinstance(layer, Attention):
        if isinstance(layer, (Attention, MLAAttention)):
            return self.KVCacheMethodCls(self)

        # handle exclusion
@@ -235,6 +227,26 @@ class ModelOptQuantConfigBase(QuantizationConfig):

        self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)

    @staticmethod
    def _extract_modelopt_quant_algo(
        hf_quant_cfg: dict[str, Any] | None,
    ) -> str | None:
        """Extract upper-cased quant_algo from a modelopt config.

        Returns the quant_algo string (upper-cased), or None if the config
        is not a modelopt config.
        """
        if hf_quant_cfg is None:
            return None
        if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
            return None
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                return str(quant_config.get("quant_algo", "")).upper()
            return None
        return str(hf_quant_cfg.get("quant_algo", "")).upper()

    @staticmethod
    def get_config_filenames() -> list[str]:
        return ["hf_quant_config.json"]
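For illustration, a minimal sketch (not part of the commit) of how the new _extract_modelopt_quant_algo helper treats the two ModelOpt config shapes it accepts; the dict literals below are hypothetical:

# Legacy hf_quant_config.json shape: quant_algo nested under "quantization".
legacy_cfg = {"quant_method": "modelopt", "quantization": {"quant_algo": "fp8"}}
# config.json quantization_config shape: flat quant_algo key.
flat_cfg = {"quant_method": "modelopt", "quant_algo": "nvfp4"}

ModelOptQuantConfigBase._extract_modelopt_quant_algo(legacy_cfg)  # -> "FP8"
ModelOptQuantConfigBase._extract_modelopt_quant_algo(flat_cfg)  # -> "NVFP4"
ModelOptQuantConfigBase._extract_modelopt_quant_algo({"quant_method": "awq"})  # -> None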
@@ -272,10 +284,20 @@ class ModelOptQuantConfigBase(QuantizationConfig):
            # "exclude_modules" is the key in the legacy hf_quant_config.json
            exclude_modules = quant_config.get("exclude_modules", [])
        else:
            # Compressed-tensors style format:
            # Compressed-tensors style format (config.json quantization_config):
            # {"quant_algo": "...", "quant_method": "modelopt"}
            quant_method = config.get("quant_algo")
            kv_cache_quant_method = config.get("kv_cache_quant_algo")

            # "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string).
            kv_cache_scheme = config.get("kv_cache_scheme")
            if isinstance(kv_cache_scheme, dict) and (
                kv_cache_scheme.get("type") == "float"
                and kv_cache_scheme.get("num_bits") == 8
            ):
                kv_cache_quant_method = "FP8"
            else:
                kv_cache_quant_method = None

            # "ignore" is the key in config.json
            exclude_modules = config.get("ignore", [])
            group_size_raw = config.get("group_size")
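As a sketch of the kv-cache detection added in this hunk: when config.json's quantization_config carries a compressed-tensors style "kv_cache_scheme" dict rather than a "kv_cache_quant_algo" string, an 8-bit float scheme is read as FP8 kv-cache quantization. The values below are hypothetical:

quantization_config = {
    "quant_method": "modelopt",
    "quant_algo": "FP8",
    "kv_cache_scheme": {"type": "float", "num_bits": 8},
    "ignore": ["lm_head"],
}
# With the logic above: kv_cache_quant_method == "FP8" and
# exclude_modules == ["lm_head"] (the "ignore" key in config.json).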
@@ -379,32 +401,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt config should be used based on
        quantization config."""

        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = str(quant_config.get("quant_algo", ""))
                if quant_algo.upper() == "FP8":
                    return "modelopt"
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
            if quant_algo.upper() == "FP8":
                return "modelopt"

        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo is not None and algo == "FP8":
            return "modelopt"
        return None

    @classmethod
@@ -737,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
@@ -745,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
    ) -> mk.FusedMoEExpertsModular:
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
@@ -862,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):

        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_fp8_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                fp8_backend=self.fp8_backend,
                experts_cls=self.experts_cls,
                routing_tables=layer._maybe_init_expert_routing_tables(),
                shared_experts=layer.shared_experts,
            )
        assert self.experts_cls is not None
        self.moe_kernel = make_fp8_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            fp8_backend=self.fp8_backend,
            experts_cls=self.experts_cls,
            routing_tables=layer._maybe_init_expert_routing_tables(),
            shared_experts=layer.shared_experts,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w13 = layer.w13_weight
@@ -904,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
        a1_scale = layer.w13_input_scale
@@ -920,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
            a2_scale=a2_scale,
        )

    @property
    def is_monolithic(self) -> bool:
        return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM

    def apply_monolithic(
        self,
        layer: FusedMoE,
@@ -931,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.is_monolithic
        assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
            )
        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
        assert layer.activation in SUPPORTED_ACTIVATIONS, (
            f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
            f"TRTLLM FP4 MoE, {layer.activation} found instead."
        )
        return apply_fi_trtllm_fp8_per_tensor_moe(
            layer=layer,
            hidden_states=x,
            router_logits=router_logits,
            routing_bias=layer.e_score_correction_bias,
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            top_k=layer.top_k,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
@@ -964,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert not self.is_monolithic

        # TODO(rob): this validation should happen at kernel selection
        # time in the oracle rather than here.
        if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
            assert layer.activation in (
                MoEActivation.SILU,
                MoEActivation.RELU2_NO_MUL,
            ), (
                "Expected activation to be in ('silu', 'relu2_no_mul'),"
                f"but got {layer.activation}"
            )

        assert self.moe_mk is not None
        return self.moe_mk(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
@@ -1031,32 +1003,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt FP4 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = quant_config.get("quant_algo", "")
                if "NVFP4" in quant_algo:
                    return "modelopt_fp4"
        else:
            # Check for compressed-tensors style config with specific
            # quant_algo field
            quant_algo = hf_quant_cfg.get("quant_algo", "")
            if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
                return "modelopt_fp4"

        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo is not None and ("NVFP4" in algo or "FP4" in algo):
            return "modelopt_fp4"
        return None

    @classmethod
@@ -1249,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    ) -> mk.FusedMoEPrepareAndFinalize | None:
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
        )

    def select_gemm_impl(
        self,
        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
        layer: torch.nn.Module,
    ) -> mk.FusedMoEPermuteExpertsUnpermute:
    ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
        raise ValueError(
            f"{self.__class__.__name__} uses the new modular kernel initialization "
            "logic. This function should not be called."
@@ -1434,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
        replace_parameter(layer, "w2_input_scale", a2_scale)

        # Setup modular kernel for TP case and naive DP/EP case.
        # In non-naive DP/EP case, we will create a ModularKernelMethod.
        # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
        # in both cases.
        # Setup modular kernel.
        self.moe_quant_config = self.get_fused_moe_quant_config(layer)
        if self.moe_quant_config:
            assert self.experts_cls is not None
            self.moe_mk = make_nvfp4_moe_kernel(
                moe_quant_config=self.moe_quant_config,
                moe_config=self.moe,
                experts_cls=self.experts_cls,
                shared_experts=layer.shared_experts,
                routing_tables=layer._maybe_init_expert_routing_tables(),
            )

    @property
    def do_post_quant_allgather(self):
        return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM

    def prepare_dp_allgather_tensor(
        self,
        layer: FusedMoE,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Optionally prepare extra tensors to carry through DP allgather/EP."""
        if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
            raise RuntimeError(
                "prepare_dp_allgather_tensor is only supported for "
                "FlashInfer TRTLLM NVFP4 MoE backend."
            )

        import flashinfer

        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
            hidden_states,
            layer.a1_gscale,
            is_sf_swizzled_layout=False,
        assert self.experts_cls is not None
        self.moe_kernel = make_nvfp4_moe_kernel(
            moe_quant_config=self.moe_quant_config,
            moe_config=self.moe,
            experts_cls=self.experts_cls,
            shared_experts=layer.shared_experts,
            routing_tables=layer._maybe_init_expert_routing_tables(),
        )
        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
        return hidden_states_fp4, extra_tensors

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
    def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
        return make_nvfp4_moe_quant_config(
            backend=self.nvfp4_backend,
            w13_scale=layer.w13_weight_scale,
@@ -1493,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
    def supports_eplb(self) -> bool:
        return True

    @property
    def is_monolithic(self) -> bool:
        return (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not self.moe.moe_parallel_config.enable_eplb
        )

    def apply_monolithic(
        self,
        layer: FusedMoE,
@@ -1507,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert self.is_monolithic
        assert (
            self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
            and not layer.enable_eplb
        )

        return flashinfer_trtllm_fp4_moe(
            layer=layer,
            x=x,
            router_logits=router_logits,
            top_k=layer.top_k,
        assert self.moe_kernel is not None
        return self.moe_kernel.apply_monolithic(
            x,
            layer.w13_weight,
            layer.w2_weight,
            router_logits,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            num_expert_group=layer.num_expert_group,
            topk_group=layer.topk_group,
            custom_routing_function=layer.custom_routing_function,
            e_score_correction_bias=layer.e_score_correction_bias,
            routed_scaling_factor=layer.routed_scaling_factor,
        )

    def apply(
@@ -1534,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        assert not self.is_monolithic

        # EPLB path
        if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
            assert layer.enable_eplb
            return flashinfer_trtllm_fp4_routed_moe(
                layer=layer,
                x=x,
                topk_ids=topk_ids,
                topk_weights=topk_weights,
                top_k=layer.top_k,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
            )
        else:
            assert self.moe_mk is not None
            return self.moe_mk(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                shared_experts_input=shared_experts_input,
            )
        assert self.moe_kernel is not None
        return self.moe_kernel.apply(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights,
            topk_ids,
            activation=layer.activation,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            shared_experts_input=shared_experts_input,
        )


ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
@@ -1619,31 +1502,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Detect if this ModelOpt MXFP8 config should be used based on
        quantization config."""
        if hf_quant_cfg is None:
            return None

        # Use the community standard 'quant_method'
        quant_method = hf_quant_cfg.get("quant_method", "").lower()

        # Only proceed if the method is explicitly "modelopt"
        if quant_method != "modelopt":
            return None

        # Look for ModelOpt-specific config structure
        if "quantization" in hf_quant_cfg:
            quant_config = hf_quant_cfg["quantization"]
            if isinstance(quant_config, dict):
                quant_algo = str(quant_config.get("quant_algo", "")).upper()
                if "MXFP8" in quant_algo:
                    return "modelopt_mxfp8"
        else:
            # Check for compressed-tensors style config with specific quant_algo
            quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
            if "MXFP8" in quant_algo:
                return "modelopt_mxfp8"

        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo is not None and "MXFP8" in algo:
            return "modelopt_mxfp8"
        return None

    @classmethod
@@ -1841,3 +1702,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
# Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod


class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
    """Config class for ModelOpt MIXED_PRECISION.

    Supports checkpoints where different layers use different quantization
    algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
    The per-layer algorithm is specified in the ``quantized_layers`` dict
    inside ``config.json``'s ``quantization_config`` (preferred) or the
    legacy ``hf_quant_config.json``.
    """

    def __init__(
        self,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        quantized_layers: dict[str, dict[str, Any]],
        fp8_config: ModelOptFp8Config,
        nvfp4_config: ModelOptNvFp4Config,
    ) -> None:
        super().__init__(exclude_modules)
        self.kv_cache_quant_method = kv_cache_quant_method
        self.quantized_layers = quantized_layers
        self.fp8_config = fp8_config
        self.nvfp4_config = nvfp4_config

    def get_name(self) -> QuantizationMethods:
        return "modelopt_mixed"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
        if algo is not None and algo == "MIXED_PRECISION":
            return "modelopt_mixed"
        return None

    @classmethod
    def _from_config(
        cls,
        *,
        quant_method: str,
        kv_cache_quant_method: str | None,
        exclude_modules: list[str],
        original_config: dict[str, Any],
        group_size: int | None,
        **kwargs: Any,
    ) -> "ModelOptMixedPrecisionConfig":
        if "quantization" in original_config:
            quantized_layers = original_config["quantization"].get(
                "quantized_layers", {}
            )
        else:
            quantized_layers = original_config.get("quantized_layers", {})

        if not quantized_layers:
            raise ValueError(
                "MIXED_PRECISION quant_algo requires a non-empty "
                "'quantized_layers' mapping in the quantization config."
            )

        # Determine group_size from the first NVFP4 entry if not provided.
        if group_size is None:
            for layer_info in quantized_layers.values():
                if layer_info.get("quant_algo", "").upper() == "NVFP4":
                    group_size = layer_info.get("group_size", 16)
                    break
            if group_size is None:
                group_size = 16

        fp8_config = ModelOptFp8Config(
            quant_method="FP8",
            is_checkpoint_fp8_serialized=True,
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=[],
        )
        nvfp4_config = ModelOptNvFp4Config(
            is_checkpoint_nvfp4_serialized=True,
            kv_cache_quant_algo=kv_cache_quant_method,
            exclude_modules=[],
            group_size=group_size,
        )

        return cls(
            kv_cache_quant_method=kv_cache_quant_method,
            exclude_modules=exclude_modules,
            quantized_layers=quantized_layers,
            fp8_config=fp8_config,
            nvfp4_config=nvfp4_config,
        )

    def _resolve_quant_algo(self, prefix: str) -> str | None:
        """Look up the quant_algo for a vLLM-side layer prefix.

        Tries three strategies in order:
        1. Direct lookup in ``quantized_layers``.
        2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
        3. Prefix-based lookup for FusedMoE (any child key starts with
           ``prefix + "."``).

        Returns the upper-cased quant_algo string, or *None* if the prefix
        is not found.
        """
        # 1. Direct lookup
        if prefix in self.quantized_layers:
            return self.quantized_layers[prefix]["quant_algo"].upper()

        # 2. Packed / fused layer lookup
        proj_name = prefix.rsplit(".", 1)[-1]
        if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
            algos: set[str] = set()
            base = prefix.rsplit(".", 1)[0]
            for shard_name in self.packed_modules_mapping[proj_name]:
                shard_prefix = f"{base}.{shard_name}"
                if shard_prefix in self.quantized_layers:
                    algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
            if len(algos) == 1:
                return algos.pop()
            if len(algos) > 1:
                raise ValueError(
                    f"Mixed quant_algo within fused layer {prefix}: "
                    f"{algos}. All shards must use the same quantization."
                )

        # 3. Prefix-based lookup (for FusedMoE / parent modules)
        prefix_dot = prefix + "."
        for key, info in self.quantized_layers.items():
            if key.startswith(prefix_dot):
                return info["quant_algo"].upper()

        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> "QuantizeMethodBase | None":
        """Return quantize-method based on layer."""
        # KV-cache quantization
        if isinstance(layer, Attention):
            if self.kv_cache_quant_method:
                return ModelOptFp8KVCacheMethod(self)
            return None

        # Excluded layers
        if self.is_layer_excluded(prefix):
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None

        quant_algo = self._resolve_quant_algo(prefix)

        if isinstance(layer, LinearBase):
            if quant_algo == "FP8":
                return ModelOptFp8LinearMethod(self.fp8_config)
            if quant_algo == "NVFP4":
                return ModelOptNvFp4LinearMethod(self.nvfp4_config)
            # Layer not in quantized_layers — leave unquantized
            return UnquantizedLinearMethod()

        if isinstance(layer, FusedMoE):
            if quant_algo == "FP8":
                return ModelOptFp8MoEMethod(
                    quant_config=self.fp8_config,
                    moe_config=layer.moe_config,
                )
            if quant_algo == "NVFP4":
                return ModelOptNvFp4FusedMoE(
                    quant_config=self.nvfp4_config,
                    moe_config=layer.moe_config,
                )
            return None

        return None

    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
        super().apply_vllm_mapper(hf_to_vllm_mapper)
        if self.quantized_layers:
            self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)
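To make the MIXED_PRECISION path concrete, here is a hypothetical quantization_config that ModelOptMixedPrecisionConfig is designed to accept; the module names and algorithm choices below are illustrative only, not taken from the commit:

quantization_config = {
    "quant_method": "modelopt",
    "quant_algo": "MIXED_PRECISION",
    "kv_cache_quant_algo": "FP8",
    "quantized_layers": {
        "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
        "model.layers.0.self_attn.k_proj": {"quant_algo": "FP8"},
        "model.layers.0.self_attn.v_proj": {"quant_algo": "FP8"},
        "model.layers.0.mlp.experts.0.gate_proj": {"quant_algo": "NVFP4", "group_size": 16},
    },
}

From such a config, _from_config builds an FP8 sub-config and an NVFP4 sub-config (group_size is taken from the first NVFP4 entry, defaulting to 16), and get_quant_method then dispatches each LinearBase or FusedMoE layer to the matching ModelOptFp8*/ModelOptNvFp4* method based on _resolve_quant_algo.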
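And a sketch of how _resolve_quant_algo would resolve vLLM-side prefixes against the hypothetical mapping above, assuming cfg is the resulting ModelOptMixedPrecisionConfig and its packed_modules_mapping is {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}:

cfg._resolve_quant_algo("model.layers.0.self_attn.q_proj")    # direct hit -> "FP8"
cfg._resolve_quant_algo("model.layers.0.self_attn.qkv_proj")  # fused layer, all shards agree -> "FP8"
cfg._resolve_quant_algo("model.layers.0.mlp.experts")         # a child key starts with the prefix -> "NVFP4"
cfg._resolve_quant_algo("model.layers.0.input_layernorm")     # not listed -> None, layer stays unquantized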