Upgrade to vllm 0.17.0 corex v4.1 overlay
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
        self.kv_cache_group = kv_cache_group
        self.kv_cache_config = kv_cache_config
        self.pack_method = pack_method
        self.dynamic_mxfp4_quant = False

    def maybe_update_config(self, model_name: str, revision: str | None = None):
        self.hf_config = get_config(
            model=model_name,
            trust_remote_code=False,  # or get from model_config if available
            revision=revision,
            config_format="auto",
        )

        quant_config = getattr(self.hf_config, "quantization_config", None)
        if quant_config is not None:
            quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
            model_type = self.hf_config.model_type
            if quant_dtype == "fp4" and model_type == "deepseek_v3":
                self.dynamic_mxfp4_quant = True

    def get_linear_method(self) -> "QuarkLinearMethod":
        return QuarkLinearMethod(self)
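Note (not part of the diff): the new maybe_update_config hook enables dynamic MXFP4 only when the checkpoint's HF quantization_config reports an fp4 global weight dtype on a deepseek_v3 model. A minimal standalone sketch of that check, using a hypothetical config dict:

def wants_dynamic_mxfp4(quantization_config, model_type: str) -> bool:
    # Mirrors the check in maybe_update_config above.
    if quantization_config is None:
        return False
    quant_dtype = quantization_config["global_quant_config"]["weight"]["dtype"]
    return quant_dtype == "fp4" and model_type == "deepseek_v3"

example_cfg = {"global_quant_config": {"weight": {"dtype": "fp4"}}}  # hypothetical
assert wants_dynamic_mxfp4(example_cfg, "deepseek_v3")
assert not wants_dynamic_mxfp4(example_cfg, "llama")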
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
        if should_ignore_layer(
            prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
        ):
            return UnquantizedLinearMethod()
        if (
            "self_attn" not in prefix  # only quantize attention projections
            or not getattr(self, "dynamic_mxfp4_quant", False)
            or not isinstance(layer, LinearBase)  # Ignore other methods
        ):
            return UnquantizedLinearMethod()

        scheme = self.get_scheme(
            layer=layer,
            layer_name=prefix,
            dynamic_mxfp4_quant=True,
        )
        layer.scheme = scheme
        return QuarkLinearMethod(self)
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            layer.scheme = scheme
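Note (not part of the diff): with dynamic MXFP4 enabled, get_quant_method now routes only attention-projection linear layers to QuarkLinearMethod; everything else falls back to UnquantizedLinearMethod. A sketch of that routing predicate with hypothetical module prefixes:

def routes_to_dynamic_mxfp4(prefix: str, dynamic_mxfp4_quant: bool, is_linear_base: bool) -> bool:
    # Inverse of the early-return condition above: quantize only self_attn
    # projections when dynamic MXFP4 is on and the module is a LinearBase.
    return "self_attn" in prefix and dynamic_mxfp4_quant and is_linear_base

assert routes_to_dynamic_mxfp4("model.layers.0.self_attn.q_proj", True, True)
assert not routes_to_dynamic_mxfp4("model.layers.0.mlp.gate_proj", True, True)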
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
        )
        return global_quant_config

    def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
    def _get_scheme_from_config(
        self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
    ) -> "QuarkScheme":
        if config.get("output_tensors") or config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with output_tensors "
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
                input_symmetric=input_config.get("symmetric"),
            )
        elif self._is_w_ocp_mx_a_x(weight_config, input_config):
            return QuarkOCP_MX(weight_config, input_config)
            return QuarkOCP_MX(
                weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
            )

        raise NotImplementedError(
            "No quark compatible scheme was found. "
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
            f"Input config: {input_config}"
        )

    def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
    def get_scheme(
        self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
    ) -> "QuarkScheme":
        layer_quant_config = self._find_matched_config(layer_name, layer)

        # Find the quant_scheme
        scheme = self._get_scheme_from_config(layer_quant_config)
        scheme = self._get_scheme_from_config(
            layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
        )
        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())
@@ -5,8 +5,8 @@ from typing import Any

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
    OCP_MX_BLOCK_SIZE,
    OCP_MX_Scheme,
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up

logger = init_logger(__name__)

__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
__all__ = [
    "QuarkMoEMethod",
    "QuarkOCP_MX_MoEMethod",
    "QuarkOCP_MX_MoEMethod_OSS",
]


class QuarkMoEMethod(FusedMoEMethodBase):
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
                "output_tensors and bias "
                "quantized are not supported"
            )

        weight_config = layer_quant_config.get("weight")
        input_config = layer_quant_config.get("input_tensors")

        if quant_config._is_fp8_w4a8(weight_config, input_config):
            return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
        elif quant_config._is_fp8_w8a8(weight_config, input_config):
            return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
        elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
            return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
            emulate = not current_platform.supports_mx() or not (
                rocm_aiter_ops.is_fused_moe_enabled()
            )
            if (
                input_config.get("dtype") == "fp8_e4m3"
                and not input_config.get("is_dynamic")
                and not emulate
            ):
                return QuarkOCP_MX_MoEMethod_OSS(
                    weight_config, input_config, module.moe_config
                )
            else:
                return QuarkOCP_MX_MoEMethod(
                    weight_config, input_config, module.moe_config
                )
        else:
            raise RuntimeError("Unsupported FusedMoe scheme")
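Note (not part of the diff): for OCP-MX layer configs, the MoE dispatcher now splits into two paths: statically scaled fp8_e4m3 activations on a non-emulated platform go to the new QuarkOCP_MX_MoEMethod_OSS, everything else keeps QuarkOCP_MX_MoEMethod. A sketch of that selection with hypothetical inputs:

def pick_ocp_mx_moe_method(input_dtype: str, is_dynamic: bool, emulate: bool) -> str:
    # Mirrors the branch added above; "emulate" stands for
    # "MX not supported natively, or aiter fused MoE disabled".
    if input_dtype == "fp8_e4m3" and not is_dynamic and not emulate:
        return "QuarkOCP_MX_MoEMethod_OSS"
    return "QuarkOCP_MX_MoEMethod"

assert pick_ocp_mx_moe_method("fp8_e4m3", False, False) == "QuarkOCP_MX_MoEMethod_OSS"
assert pick_ocp_mx_moe_method("fp8_e4m3", True, False) == "QuarkOCP_MX_MoEMethod"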
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
            get_current_vllm_config().model_config.hf_config, "model_type", None
        )

        self._emulate = (
        self.emulate = (
            not current_platform.supports_mx()
            or not self.ocp_mx_scheme.startswith("w_mxfp4")
        ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)

        self.emulate = True if self.model_type == "gpt_oss" else self._emulate

        if self.emulate:
            logger.warning_once(
                f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
        )

        params_dtype = torch.uint8
        self.intermediate_size_per_partition = intermediate_size_per_partition
        if self.model_type == "gpt_oss":
            if current_platform.is_rocm():
                intermediate_size_per_partition_after_pad = round_up(
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
        else:
            intermediate_size_per_partition_after_pad = intermediate_size_per_partition

        self.unpadded_hidden_size = extra_weight_attrs.get(
            "unpadded_hidden_size", hidden_size
        )

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
        shared_experts_input: torch.Tensor | None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if not self.emulate:
            if (
                self.model_type == "gpt_oss"
                and self.mxfp4_backend == Mxfp4Backend.TRITON
            ):
                raise NotImplementedError(
                    "Triton kernel implemented fused MoE for GPT_OSS model "
                    "in Quark(MoE) format is not integrated or provided yet."
                )
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                rocm_aiter_fused_experts,
            )

            else:
                from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                    rocm_aiter_fused_experts,
                )

                return rocm_aiter_fused_experts(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    activation=layer.activation,
                    quant_config=self.moe_quant_config,
                    expert_map=layer.expert_map,
                )
            return rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                quant_config=self.moe_quant_config,
                expert_map=layer.expert_map,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
            expert_map=layer.expert_map,
            quant_config=self.moe_quant_config,
        )


class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        super().__init__(weight_config, input_config, moe)

    def process_weights_after_loading(self, layer):
        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig

        w13_bias = layer.w13_bias.to(torch.float32)
        w2_bias = layer.w2_bias.to(torch.float32)

        layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
        layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)

        # FIXME: num_warps needs to be adjusted based on batch size;
        # this only applies to batched mode.
        if self.moe.use_ep:
            num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
        else:
            num_warps = 8

        w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
            layer.w13_weight, layer.w13_weight_scale, num_warps
        )
        w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
            layer.w2_weight, layer.w2_weight_scale, num_warps
        )

        self.w13_weight_triton_tensor = w13_weight
        self.w2_weight_triton_tensor = w2_weight

        # Delete the original weights to save memory on a single GPU.
        del layer.w13_weight
        del layer.w2_weight
        layer.w13_weight = None
        layer.w2_weight = None
        torch.cuda.empty_cache()

        if self.static_input_scales:
            if layer.w13_input_scale is None or layer.w2_input_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                layer.w2_input_scale
            ):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )

            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max().to(torch.float32), requires_grad=False
            )
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max().to(torch.float32), requires_grad=False
            )

        from triton_kernels.numerics import InFlexData

        lhs_data13 = InFlexData(scale=layer.w13_input_scale)
        lhs_data2 = InFlexData(scale=layer.w2_input_scale)

        self.w13_precision_config = PrecisionConfig(
            weight_scale=w13_scale,
            flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
        )

        self.w2_precision_config = PrecisionConfig(
            weight_scale=w2_scale,
            flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
        )

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        return mxfp4_w4a8_moe_quant_config(
            w1_scale=self.w13_precision_config,
            w2_scale=self.w2_precision_config,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            w1_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            block_shape=None,
        )

    @property
    def is_monolithic(self) -> bool:
        return True

    def apply_monolithic(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        expert_map: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if layer.enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet."
            )

        from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (  # noqa: E501
            triton_kernel_moe_forward,
        )

        return triton_kernel_moe_forward(
            hidden_states=x,
            w1=self.w13_weight_triton_tensor,
            w2=self.w2_weight_triton_tensor,
            gating_output=router_logits,
            topk=layer.top_k,
            renormalize=layer.renormalize,
            global_num_experts=layer.global_num_experts,
            expert_map=expert_map,
            quant_config=self.moe_quant_config,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            unpadded_N_w1=self.intermediate_size_per_partition * 2,
            unpadded_K_w1=self.unpadded_hidden_size,
            unpadded_N_w2=self.unpadded_hidden_size,
            unpadded_K_w2=self.intermediate_size_per_partition,
        )
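Note (not part of the diff): in the OSS path, process_weights_after_loading swizzles the MXFP4 weights for triton_kernels' matmul_ogs and picks num_warps from the expert-parallel DP chunk size. A sketch of just that num_warps choice (the 512 threshold comes from the code above; the chunk sizes are hypothetical):

def pick_num_warps(use_ep: bool, moe_dp_chunk_size: int) -> int:
    # Mirrors the selection above: smaller per-rank batches use fewer warps.
    if use_ep:
        return 4 if moe_dp_chunk_size <= 512 else 8
    return 8

assert pick_num_warps(True, 256) == 4
assert pick_num_warps(True, 2048) == 8
assert pick_num_warps(False, 256) == 8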
@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
    OCP_MX_BLOCK_SIZE,
    OCP_MX_Scheme,
)
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.model_executor.parameter import (
    GroupQuantScaleParameter,
    ModelWeightParameter,
    PackedvLLMParameter,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

from .quark_scheme import QuarkScheme
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):

class QuarkOCP_MX(QuarkScheme):
    def __init__(
        self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
        self,
        weight_quant_spec: dict[str, Any],
        input_quant_spec: dict[str, Any],
        dynamic_mxfp4_quant: bool = False,
    ):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec

        self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
        self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
        self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
                layer.weight_scale.data, requires_grad=False
            )
        else:
            if self.rocm_use_aiter_fp4_asm_gemm:
            if self.dynamic_mxfp4_quant:
                w_q, w_s = dynamic_mxfp4_quant(layer.weight)
                layer.weight_scale = torch.nn.Parameter(
                    w_s.T.contiguous(), requires_grad=False
                )
                layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
            elif self.rocm_use_aiter_fp4_asm_gemm:
                # shuffle weight scale
                weight_scale_shuffle = layer.weight_scale.data
                sm, sn = weight_scale_shuffle.shape
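Note (not part of the diff): the new dynamic_mxfp4_quant branch quantizes the loaded high-precision weight to MXFP4 at load time. Conceptually, OCP MX groups every 32 consecutive values under one shared power-of-two scale; the sketch below only illustrates that grouping and scale choice. It is not the aiter kernel, which additionally packs two fp4 values per uint8 and encodes scales as E8M0.

import torch

OCP_MX_BLOCK_SIZE = 32

def mx_group_scales(weight: torch.Tensor) -> torch.Tensor:
    # One power-of-two scale per 32-element group, chosen so the scaled group
    # max fits fp4's largest magnitude (6.0). Illustrative only.
    out_features, in_features = weight.shape
    groups = weight.reshape(out_features, in_features // OCP_MX_BLOCK_SIZE, OCP_MX_BLOCK_SIZE)
    max_abs = groups.abs().amax(dim=-1).clamp(min=1e-30)
    return torch.exp2(torch.ceil(torch.log2(max_abs / 6.0)))

w = torch.randn(8, 64, dtype=torch.float32)
print(mx_group_scales(w).shape)  # torch.Size([8, 2]) -> one scale per 32-element group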
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
        weight_loader: Callable,
        **kwargs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        if self.dynamic_mxfp4_quant:
            weight = ModelWeightParameter(
                data=torch.empty(
                    sum(output_partition_sizes),
                    input_size_per_partition,
                    dtype=params_dtype,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )

        # WEIGHT
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)
            layer.register_parameter("weight", weight)
            set_weight_attrs(weight, kwargs)
        else:
            output_size_per_partition = sum(output_partition_sizes)
            layer.logical_widths = output_partition_sizes

        # WEIGHT SCALE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)
            # WEIGHT
            weight = PackedvLLMParameter(
                data=torch.empty(
                    output_size_per_partition,
                    self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                packed_dim=1,
                packed_factor=self.packed_factor,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight", weight)

            # WEIGHT SCALE
            weight_scale = GroupQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    input_size_per_partition // OCP_MX_BLOCK_SIZE,
                    dtype=torch.uint8,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,
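Note (not part of the diff): in the non-dynamic branch above (pre-quantized checkpoints), fp4 weights are stored packed two values per uint8 and scales are one uint8 per 32-element group, so a [out, in] linear keeps weight as [out, in // 2] and weight_scale as [out, in // 32]. A small shape sketch, assuming that packing and using a hypothetical 4096 x 14336 projection:

OCP_MX_BLOCK_SIZE = 32
PACKED_FACTOR = 2  # two fp4 values per uint8

def mxfp4_param_shapes(output_size: int, input_size: int):
    weight_shape = (output_size, input_size // PACKED_FACTOR)            # torch.uint8
    weight_scale_shape = (output_size, input_size // OCP_MX_BLOCK_SIZE)  # torch.uint8
    return weight_shape, weight_scale_shape

print(mxfp4_param_shapes(4096, 14336))  # ((4096, 7168), (4096, 448))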