Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
@@ -9,6 +9,7 @@ from torch.nn import Parameter
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
@@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
import ixformer.inference.functions as ixfops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
elif isinstance(layer, FusedMoE):
|
||||
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
# from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
|
||||
|
||||
# if is_layer_skipped(
|
||||
# prefix,
|
||||
@@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
# layer, prefix
|
||||
# )
|
||||
moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
# moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
# moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
# return moe_quant_method
|
||||
return AWQMarlinMoEMethod(self, layer.moe_config)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
replace_parameter(layer, "qweight", pad_qweight)
|
||||
replace_parameter(layer, "qzeros", pad_qzeros)
|
||||
replace_parameter(layer, "scales", pad_scales)
|
||||
return
|
||||
|
||||
# TODO(gyf) Marlin format is not support for now..
|
||||
device = layer.qweight.device
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
return
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
@@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
# Assign the value of shared_experts_output to variable shared_experts_input for fusion
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
# return fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# getattr(layer, "w13_bias", None),
|
||||
# getattr(layer, "w2_bias", None),
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
# input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
# quant_type_id=self.quant_type.id,
|
||||
# apply_router_weight_on_input=layer.apply_router_weight_on_input,
|
||||
# global_num_experts=layer.global_num_experts,
|
||||
# expert_map=layer.expert_map,
|
||||
# w1_zeros=layer.w13_qzeros,
|
||||
# w2_zeros=layer.w2_qzeros,
|
||||
# workspace=layer.workspace,
|
||||
# input_dtype=self.input_dtype,
|
||||
# inplace=not self.moe.disable_inplace,
|
||||
# )
|
||||
|
||||
num_tokens, num_experts = router_logits.shape
|
||||
assert layer.activation.value == "silu", "Only SiLU activation is supported."
|
||||
use_ep = layer.expert_map is not None
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata:
|
||||
if isinstance(attn_metadata, dict):
|
||||
only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values())))
|
||||
else:
|
||||
only_decode = use_ep == False and attn_metadata.num_decodes > 0 and attn_metadata.num_prefills == 0
|
||||
else:
|
||||
only_decode = False
|
||||
|
||||
if use_ep:
|
||||
start_eid = layer.ep_rank * layer.local_num_experts
|
||||
end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
|
||||
if layer.apply_router_weight_on_input:
|
||||
raise NotImplementedError(
|
||||
"Apply router weight on input is not supported for"
|
||||
"fused Marlin MoE method.")
|
||||
|
||||
num_tokens = topk_ids.shape[0]
|
||||
num_experts = layer.global_num_experts
|
||||
if use_ep:
|
||||
hidden_size = x.shape[1]
|
||||
(
|
||||
@@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
dtype=x.dtype,
|
||||
)
|
||||
else:
|
||||
expand_tokens = num_tokens * top_k
|
||||
expand_tokens = num_tokens * layer.top_k
|
||||
(
|
||||
src_to_dst,
|
||||
sorted_token_ids,
|
||||
@@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
topk_ids=topk_ids,
|
||||
num_experts=num_experts,
|
||||
)
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
|
||||
# expand + reorder
|
||||
# TODO use kernel
|
||||
@@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
hidden_states=x,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_tokens=expand_tokens,
|
||||
topk=top_k,
|
||||
topk=layer.top_k,
|
||||
src_to_dst=src_to_dst,
|
||||
)
|
||||
|
||||
# w4a16 group gemm 1
|
||||
# pt_output_1: (expand_tokens, 2n) dtype
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
if only_decode:
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemv(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemv(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, top_k, -1),
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=routed_scaling_factor
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
|
||||
else:
|
||||
expert_sizes_cpu = expert_sizes_gpu.cpu()
|
||||
pt_output_1 = ixfops.moe_w4a16_group_gemm(
|
||||
input=expand_hidden_states,
|
||||
weight=layer.w13_qweight,
|
||||
w_scales=layer.w13_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w13_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=None,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# act
|
||||
pt_output_2 = ixfops.silu_and_mul(pt_output_1)
|
||||
|
||||
# w4a16 group gemm 2 + reorder
|
||||
# pt_output_3: (expand_tokens, k) dtype
|
||||
if use_ep:
|
||||
pt_output_3 = torch.empty(
|
||||
(num_tokens * layer.top_k, hidden_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
)
|
||||
|
||||
ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
output=pt_output_3,
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
reduce_mask = src_to_dst == -1
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
mask=reduce_mask,
|
||||
)
|
||||
else:
|
||||
pt_output_3 = ixfops.moe_w4a16_group_gemm(
|
||||
input=pt_output_2,
|
||||
weight=layer.w2_qweight,
|
||||
w_scales=layer.w2_scales,
|
||||
quant_type="awq",
|
||||
tokens_per_experts=expert_sizes_cpu,
|
||||
w_zeros=layer.w2_qzeros,
|
||||
group_size=self.quant_config.group_size,
|
||||
dst_to_src=sorted_token_ids,
|
||||
format="NN",
|
||||
tokens_per_experts_gpu=expert_sizes_gpu,
|
||||
)
|
||||
|
||||
# mul + reduce_sum
|
||||
# final_hidden_states: (num_tokens, k)
|
||||
final_hidden_states = ixfops.moe_output_reduce_sum(
|
||||
input=pt_output_3.view(num_tokens, layer.top_k, -1),
|
||||
topk_weight=topk_weights,
|
||||
scaling_factor=layer.routed_scaling_factor,
|
||||
extra_residual=shared_experts_input,
|
||||
)
|
||||
return final_hidden_states
|
||||
# return torch.ops.vllm.fused_marlin_moe(
|
||||
# x,
|
||||
# layer.w13_qweight,
|
||||
# layer.w2_qweight,
|
||||
# layer.w13_scales,
|
||||
# layer.w2_scales,
|
||||
# router_logits,
|
||||
# topk_weights,
|
||||
# topk_ids,
|
||||
# w1_zeros=layer.w13_qzeros,
|
||||
# w2_zeros=layer.w2_qzeros,
|
||||
# num_bits=self.quant_config.weight_bits,
|
||||
# )
|
||||
|
||||
Reference in New Issue
Block a user