Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
@@ -9,6 +9,7 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
import ixformer.inference.functions as ixfops
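# ixformer provides the fused MoE kernels (W4A16 grouped GEMM/GEMV, SiLU-and-mul, output reduce) used below.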
logger = init_logger(__name__)
@@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig):
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
# from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
# if is_layer_skipped(
# prefix,
@@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig):
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
# layer, prefix
# )
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
return moe_quant_method
# moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
# moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
# return moe_quant_method
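# The MoeWNA16 fallback above is disabled; MoE layers always get the Marlin MoE method.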
return AWQMarlinMoEMethod(self, layer.moe_config)
return None
@classmethod
@@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
replace_parameter(layer, "qweight", pad_qweight)
replace_parameter(layer, "qzeros", pad_qzeros)
replace_parameter(layer, "scales", pad_scales)
return
# TODO(gyf) Marlin format is not supported for now.
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
return
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
@@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
# The shared experts' output is passed in as shared_experts_input so it can be fused into the final reduction
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# return fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# getattr(layer, "w13_bias", None),
# getattr(layer, "w2_bias", None),
# layer.w13_scales,
# layer.w2_scales,
# topk_weights,
# topk_ids,
# input_global_scale1=getattr(layer, "w13_input_global_scale", None),
# input_global_scale2=getattr(layer, "w2_input_global_scale", None),
# quant_type_id=self.quant_type.id,
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
# global_num_experts=layer.global_num_experts,
# expert_map=layer.expert_map,
# w1_zeros=layer.w13_qzeros,
# w2_zeros=layer.w2_qzeros,
# workspace=layer.workspace,
# input_dtype=self.input_dtype,
# inplace=not self.moe.disable_inplace,
# )
num_tokens, num_experts = router_logits.shape
assert layer.activation.value == "silu", "Only SiLU activation is supported."
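# expert_map is only populated when experts are sharded across ranks (expert parallelism).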
use_ep = layer.expert_map is not None
attn_metadata = get_forward_context().attn_metadata
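# A batch counts as decode-only when every attention group has decodes and no prefills;
# such batches (without expert parallelism) take the grouped-GEMV fast path further down.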
if attn_metadata:
if isinstance(attn_metadata, dict):
only_decode = not use_ep and all(t.num_decodes > 0 and t.num_prefills == 0 for t in attn_metadata.values())
else:
only_decode = not use_ep and attn_metadata.num_decodes > 0 and attn_metadata.num_prefills == 0
else:
only_decode = False
if use_ep:
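# Each EP rank owns the contiguous expert slice [start_eid, end_eid) of the global expert set.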
start_eid = layer.ep_rank * layer.local_num_experts
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
if layer.apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
num_tokens = topk_ids.shape[0]
num_experts = layer.global_num_experts
if use_ep:
hidden_size = x.shape[1]
(
@@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
dtype=x.dtype,
)
else:
expand_tokens = num_tokens * top_k
expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
@@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
num_experts=num_experts,
)
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
# TODO use kernel
@@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
topk=top_k,
topk=layer.top_k,
src_to_dst=src_to_dst,
)
# w4a16 group gemm 1
# pt_output_1: (expand_tokens, 2n) dtype
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
if only_decode:
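# Decode-only: the fused gate/up projection (w13) runs as a grouped W4A16 GEMV over the expanded tokens.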
pt_output_1 = ixfops.moe_w4a16_group_gemv(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
)
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
topk_weight=topk_weights,
scaling_factor=routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
pt_output_3 = ixfops.moe_w4a16_group_gemv(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, top_k, -1),
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=routed_scaling_factor
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
else:
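# Prefill / mixed batches: grouped W4A16 GEMM, which needs per-expert token counts on the host.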
expert_sizes_cpu = expert_sizes_gpu.cpu()
pt_output_1 = ixfops.moe_w4a16_group_gemm(
input=expand_hidden_states,
weight=layer.w13_qweight,
w_scales=layer.w13_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# act
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
# w4a16 group gemm 2 + reorder
# pt_output_3: (expand_tokens, k) dtype
if use_ep:
pt_output_3 = torch.empty(
(num_tokens * layer.top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
output=pt_output_3,
tokens_per_experts_gpu=expert_sizes_gpu,
)
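# Slots with src_to_dst == -1 carry no locally computed output (their experts are owned by other EP ranks), so the mask excludes them from the reduce.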
reduce_mask = src_to_dst == -1
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
mask=reduce_mask,
)
else:
pt_output_3 = ixfops.moe_w4a16_group_gemm(
input=pt_output_2,
weight=layer.w2_qweight,
w_scales=layer.w2_scales,
quant_type="awq",
tokens_per_experts=expert_sizes_cpu,
w_zeros=layer.w2_qzeros,
group_size=self.quant_config.group_size,
dst_to_src=sorted_token_ids,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
# mul + reduce_sum
# final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
scaling_factor=layer.routed_scaling_factor,
extra_residual=shared_experts_input,
)
return final_hidden_states
# return torch.ops.vllm.fused_marlin_moe(
# x,
# layer.w13_qweight,
# layer.w2_qweight,
# layer.w13_scales,
# layer.w2_scales,
# router_logits,
# topk_weights,
# topk_ids,
# w1_zeros=layer.w13_qzeros,
# w2_zeros=layer.w2_qzeros,
# num_bits=self.quant_config.weight_bits,
# )