[refactor]update Kunlun classes with monkey patch (#122)

Signed-off-by: Li Wei <liwei.109@outlook.com>
This commit is contained in:
Li Wei
2026-01-19 20:24:19 +08:00
committed by GitHub
parent 2512259944
commit 8f56cbf3ed
8 changed files with 444 additions and 378 deletions

View File

@@ -21,7 +21,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm_kunlun.ops.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, MergedColumnParallelLinear,

View File

@@ -38,7 +38,7 @@ from vllm.distributed import (get_ep_group, get_pp_group,
tensor_model_parallel_all_gather) tensor_model_parallel_all_gather)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm_kunlun.ops.activation import SiluAndMul from vllm_kunlun.ops.activation import SiluAndMul
from vllm_kunlun.ops.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,

View File

@@ -27,7 +27,7 @@ from vllm.logger import init_logger
from vllm_kunlun.ops.fla import (fused_recurrent_gated_delta_rule, torch_chunk_gated_delta_rule, chunk_gated_delta_rule) from vllm_kunlun.ops.fla import (fused_recurrent_gated_delta_rule, torch_chunk_gated_delta_rule, chunk_gated_delta_rule)
from vllm.model_executor.layers.fla.ops import ( from vllm.model_executor.layers.fla.ops import (
RMSNormGated) RMSNormGated)
from vllm_kunlun.ops.fused_moe.layer import FusedMoE from vllm.model_executor.layers.fused_moe.layer import FusedMoE
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (

View File

@@ -1,17 +1,35 @@
"""layer.py""" #
# Copyright (c) 2026 Baidu, Inc. All Rights Reserved.
#
# This file is a part of the vllm-kunlun project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext from typing import Callable, Optional
from typing import Callable, Optional, Union, get_args
import torch import torch
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer, should_ignore_layer,
) )
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod UnquantizedFusedMoEMethod,
FusedMoE,
)
def apply(
class KunlunUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
@@ -37,43 +55,47 @@ def apply(
"""apply""" """apply"""
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `UnquantizedFusedMoEMethod` yet.") "EPLB not supported for `UnquantizedFusedMoEMethod` yet."
)
"""forward_kunlun""" """forward_kunlun"""
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
if self.moe.use_ep: if self.moe.use_ep:
return ops.fused_moe_ep(x, return ops.fused_moe_ep(
layer.w13_weight, x,
layer.w2_weight, layer.w13_weight,
router_logits, layer.w2_weight,
self.moe.ep_rank, router_logits,
top_k, self.moe.ep_rank,
renormalize=renormalize, top_k,
inplace=True, renormalize=renormalize,
use_grouped_topk=use_grouped_topk, inplace=True,
num_expert_group=num_expert_group, use_grouped_topk=use_grouped_topk,
topk_group=topk_group) num_expert_group=num_expert_group,
topk_group=topk_group,
)
else: else:
return ops.fused_moe(x, return ops.fused_moe(
layer.w13_weight, x,
layer.w2_weight, layer.w13_weight,
router_logits, layer.w2_weight,
self.moe.ep_rank, router_logits,
top_k, self.moe.ep_rank,
renormalize=renormalize, top_k,
inplace=True, renormalize=renormalize,
use_grouped_topk=use_grouped_topk, inplace=True,
num_expert_group=num_expert_group, use_grouped_topk=use_grouped_topk,
topk_group=topk_group, num_expert_group=num_expert_group,
scoring_func=scoring_func, topk_group=topk_group,
e_score_correction_bias=e_score_correction_bias, scoring_func=scoring_func,
w1_bias=getattr(layer, 'w13_bias', None), e_score_correction_bias=e_score_correction_bias,
w2_bias=getattr(layer, 'w2_bias', None), w1_bias=getattr(layer, "w13_bias", None),
) w2_bias=getattr(layer, "w2_bias", None),
)
UnquantizedFusedMoEMethod.apply = apply
class VllmFusedMoE(FusedMoE): class KunlunFusedMoE(FusedMoE):
def __init__( def __init__(
self, self,
num_experts: int, # Global number of experts num_experts: int, # Global number of experts
@@ -131,7 +153,8 @@ class VllmFusedMoE(FusedMoE):
has_bias=has_bias, has_bias=has_bias,
is_sequence_parallel=is_sequence_parallel, is_sequence_parallel=is_sequence_parallel,
zero_expert_num=zero_expert_num, zero_expert_num=zero_expert_num,
zero_expert_type=zero_expert_type) zero_expert_type=zero_expert_type,
)
self.has_bias = has_bias self.has_bias = has_bias
self.register_parameter("w13_bias", None) self.register_parameter("w13_bias", None)
self.register_parameter("w2_bias", None) self.register_parameter("w2_bias", None)
@@ -143,7 +166,7 @@ class VllmFusedMoE(FusedMoE):
fused_mapping=self.quant_config.packed_modules_mapping, fused_mapping=self.quant_config.packed_modules_mapping,
) )
): ):
self.quant_method = UnquantizedFusedMoEMethod(self.moe_config) self.quant_method = KunlunUnquantizedFusedMoEMethod(self.moe_config)
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
"hidden_size": hidden_size, "hidden_size": hidden_size,
@@ -154,4 +177,17 @@ class VllmFusedMoE(FusedMoE):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
FusedMoE = VllmFusedMoE # monkey patch
from vllm.model_executor.layers.fused_moe import layer
layer.UnquantizedFusedMoEMethod = KunlunUnquantizedFusedMoEMethod
layer.FusedMoE = KunlunFusedMoE
print(
"[Monkey Patch Applied] >>> from vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod \
--> vllm_kunlun.ops.fused_moe.layer.KunlunUnquantizedFusedMoEMethod"
)
print(
"[Monkey Patch Applied] >>> from vllm.model_executor.layers.fused_moe.layer.FusedMoE \
--> vllm_kunlun.ops.fused_moe.layer.KunlunFusedMoE"
)

View File

@@ -17,112 +17,119 @@
# limitations under the License. # limitations under the License.
import torch import torch
from vllm.logger import init_logger
from typing import Optional from typing import Optional
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
logger = init_logger(__name__)
class KunlunAWQLinearMethod(AWQLinearMethod):
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
"""Convert AWQ-packed int4 weights to Kunlun XPU format.
Input: packed[N, K], dtype=int32, saved as AWQ order
Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order
"""
N, K = packed.shape
self.align_type = 1 if K % 8 == 0 else 0
assert num_bits == 4, "Only int4 supported now"
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
if self.align_type == 0: # NORMAL MODE
# Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7]
unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
# [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6]
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8]
# Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1]
packed_kunlun = (unpacked_kunlun << shifts).sum(
dim=-1, dtype=torch.int32
) # [N, K]
elif self.align_type == 1: # FAST MODEL
# Unpack AWQ order
unpacked_awq = (
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
) & 0xF # [N, K//8, 8, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_FAST = [
32, 0, 36, 4, 33, 1, 37, 5,
34, 2, 38, 6, 35, 3, 39, 7,
40, 8, 44, 12, 41, 9, 45, 13,
42, 10, 46, 14, 43, 11, 47, 15,
48, 16, 52, 20, 49, 17, 53, 21,
50, 18, 54, 22, 51, 19, 55, 23,
56, 24, 60, 28, 57, 25, 61, 29,
58, 26, 62, 30, 59, 27, 63, 31
]
unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
packed_kunlun = (
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
) # [N, K]
else:
raise NotImplementedError
return packed_kunlun
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4): def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""Convert AWQ-packed int4 weights to Kunlun XPU format. logger.warning_once(f"Repacking INT4 for XPU ...")
Input: packed[N, K], dtype=int32, saved as AWQ order layer.qweight = torch.nn.Parameter(
Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order (
""" self.repack_int4_for_kunlun(layer.qweight.data)
N, K = packed.shape if layer.qweight.data.dtype == torch.int32
self.align_type = 1 if K % 8 == 0 else 0 else layer.qweight.data
assert num_bits == 4, "Only int4 supported now" ),
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32) requires_grad=False,
if self.align_type == 0: # NORMAL MODE
# Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7]
unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3]
# [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6]
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8]
# Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1]
packed_kunlun = (unpacked_kunlun << shifts).sum(
dim=-1, dtype=torch.int32
) # [N, K]
elif self.align_type == 1: # FAST MODEL
# Unpack AWQ order
unpacked_awq = (
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
) & 0xF # [N, K//8, 8, 8]
# Reverse AWQ order and convert to KUNLUN order
AWQ_TO_KUNLUN_ORDER_FAST = [
32, 0, 36, 4, 33, 1, 37, 5,
34, 2, 38, 6, 35, 3, 39, 7,
40, 8, 44, 12, 41, 9, 45, 13,
42, 10, 46, 14, 43, 11, 47, 15,
48, 16, 52, 20, 49, 17, 53, 21,
50, 18, 54, 22, 51, 19, 55, 23,
56, 24, 60, 28, 57, 25, 61, 29,
58, 26, 62, 30, 59, 27, 63, 31
]
unpacked_awq = unpacked_awq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
packed_kunlun = (
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
) # [N, K]
else:
raise NotImplementedError
return packed_kunlun
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.qweight = torch.nn.Parameter(
(
self.repack_int4_for_kunlun(layer.qweight.data)
if layer.qweight.data.dtype == torch.int32
else layer.qweight.data
),
requires_grad=False,
)
layer.qzeros = torch.nn.Parameter(
(
self.repack_int4_for_kunlun(layer.qzeros.data)
if layer.qzeros.data.dtype == torch.int32
else layer.qzeros.data
),
requires_grad=False,
)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = torch.ops._C.awq_dequantize(
qweight, scales, qzeros, quant_type=0, align_type=self.align_type
) )
out = torch.matmul(reshaped_x, out) layer.qzeros = torch.nn.Parameter(
else: (
out = torch.ops._C.awq_gemm( self.repack_int4_for_kunlun(layer.qzeros.data)
reshaped_x, qweight, scales, qzeros, align_type=self.align_type if layer.qzeros.data.dtype == torch.int32
else layer.qzeros.data
),
requires_grad=False,
) )
if bias is not None: layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
out.add_(bias)
return out.reshape(out_shape)
AWQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun def apply(
AWQLinearMethod.process_weights_after_loading = process_weights_after_loading self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
AWQLinearMethod.apply = apply ) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = torch.ops._C.awq_dequantize(
qweight, scales, qzeros, quant_type=0, align_type=self.align_type
)
out = torch.matmul(reshaped_x, out)
else:
out = torch.ops._C.awq_gemm(
reshaped_x, qweight, scales, qzeros, align_type=self.align_type
)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
# monkey patch
from vllm.model_executor.layers.quantization import awq
awq.AWQLinearMethod = KunlunAWQLinearMethod
print(
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQLinearMethod \
--> vllm_kunlun.ops.quantization.awq.KunlunAWQLinearMethod"
)

View File

@@ -24,176 +24,190 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
# NOTE: xtorch_ops use max as scale
with torch.no_grad():
layer.w13_weight_scale.mul_(127.0)
layer.w2_weight_scale.mul_(127.0)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# NOTE: xtorch_ops use max as scale
with torch.no_grad():
layer.w13_weight_scale.mul_(127.0)
layer.w2_weight_scale.mul_(127.0)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0, routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False, enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
hidden_states = x hidden_states = x
global_num_experts, up_gate_size, _ = layer.w13_weight.shape global_num_experts, up_gate_size, _ = layer.w13_weight.shape
M, N = hidden_states.shape M, N = hidden_states.shape
hidden_dim = layer.w2_weight.shape[1] hidden_dim = layer.w2_weight.shape[1]
normed_score = torch.empty( normed_score = torch.empty(
M, top_k, dtype=torch.float32, device=hidden_states.device M, top_k, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, top_k, dtype=torch.int32, device=hidden_states.device)
num_blocks = 12
block_statistic = torch.zeros(
num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device
)
router_logits = router_logits.float()
if scoring_func == "softmax":
torch.ops._C.moe_softmax_topk_norm(
x=router_logits,
normed_score=normed_score,
topk_index=topk_ids,
block_statistic=None,
stable=True,
) )
elif scoring_func == "sigmoid": topk_ids = torch.empty(M, top_k, dtype=torch.int32, device=hidden_states.device)
torch.ops._C.moe_sigmoid_group_topk_norm( num_blocks = 12
x=router_logits, block_statistic = torch.zeros(
norm_score=normed_score, num_blocks,
topk_index=topk_ids, global_num_experts,
block_static=block_statistic, dtype=torch.int32,
bias=e_score_correction_bias, device=hidden_states.device,
n_group=num_expert_group,
topk_group=topk_group,
scale=routed_scaling_factor,
) )
moe_expand = torch.empty( router_logits = router_logits.float()
(M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device if scoring_func == "softmax":
) # [M, top_k, N], float torch.ops._C.moe_softmax_topk_norm(
expert_m = torch.zeros( x=router_logits,
global_num_experts, dtype=torch.int32, device=hidden_states.device normed_score=normed_score,
) # [E] topk_index=topk_ids,
sorted_tokens_num_lod = torch.zeros( block_statistic=None,
global_num_experts + 1, dtype=torch.int32, device=hidden_states.device stable=True,
) # [E+1] )
sorted_tokens_idx = torch.zeros( elif scoring_func == "sigmoid":
M * top_k, dtype=torch.int32, device=hidden_states.device torch.ops._C.moe_sigmoid_group_topk_norm(
) x=router_logits,
norm_score=normed_score,
topk_index=topk_ids,
block_static=block_statistic,
bias=e_score_correction_bias,
n_group=num_expert_group,
topk_group=topk_group,
scale=routed_scaling_factor,
)
torch.ops._C.gen_block_statistic(topk_ids, block_statistic) moe_expand = torch.empty(
(M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device
) # [M, top_k, N], float
expert_m = torch.zeros(
global_num_experts, dtype=torch.int32, device=hidden_states.device
) # [E]
sorted_tokens_num_lod = torch.zeros(
global_num_experts + 1, dtype=torch.int32, device=hidden_states.device
) # [E+1]
sorted_tokens_idx = torch.zeros(
M * top_k, dtype=torch.int32, device=hidden_states.device
)
torch.ops._C.moe_pre_sorted( torch.ops._C.gen_block_statistic(topk_ids, block_statistic)
x=hidden_states,
topk_index=topk_ids,
block_statistic=block_statistic,
moe_expand=moe_expand,
moe_index=sorted_tokens_idx,
expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod,
)
y = torch.empty( torch.ops._C.moe_pre_sorted(
M, x=hidden_states,
top_k, topk_index=topk_ids,
layer.w13_weight.shape[1], block_statistic=block_statistic,
dtype=hidden_states.dtype, moe_expand=moe_expand,
device=hidden_states.device, moe_index=sorted_tokens_idx,
) expert_m=expert_m,
sorted_tokens_num_lod=sorted_tokens_num_lod,
)
moe_expand = moe_expand.view(M * top_k, hidden_dim) y = torch.empty(
M,
top_k,
layer.w13_weight.shape[1],
dtype=hidden_states.dtype,
device=hidden_states.device,
)
x_shape = moe_expand.shape moe_expand = moe_expand.view(M * top_k, hidden_dim)
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
x_scale = torch.empty(
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
)
torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
torch.ops._C.moe_fc( x_shape = moe_expand.shape
x=x_q, x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
x_perchannel_max=x_scale, x_scale = torch.empty(
weight=layer.w13_weight, (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
w_perchannel_max=layer.w13_weight_scale, )
sorted_tokens_num_lod=sorted_tokens_num_lod, torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True)
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y,
topk_ids=topk_ids,
# sort_mode=False,
act=None,
)
d = y.shape[-1] // 2 torch.ops._C.moe_fc(
output_shape = y.shape[:-1] + (d,) x=x_q,
out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) x_perchannel_max=x_scale,
torch.ops._C.silu_and_mul(out1, y) weight=layer.w13_weight,
w_perchannel_max=layer.w13_weight_scale,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=y,
topk_ids=topk_ids,
# sort_mode=False,
act=None,
)
out = torch.empty( d = y.shape[-1] // 2
M, output_shape = y.shape[:-1] + (d,)
top_k, out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device)
layer.w2_weight.shape[1], torch.ops._C.silu_and_mul(out1, y)
dtype=hidden_states.dtype,
device=hidden_states.device,
)
out1 = out1.reshape(-1, out1.shape[-1]) out = torch.empty(
x_shape = out1.shape M,
x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) top_k,
x_scale = torch.empty( layer.w2_weight.shape[1],
(x_shape[0], 1), dtype=torch.float32, device=moe_expand.device dtype=hidden_states.dtype,
) device=hidden_states.device,
torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) )
torch.ops._C.moe_fc( out1 = out1.reshape(-1, out1.shape[-1])
x=x_q, x_shape = out1.shape
x_perchannel_max=x_scale, x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device)
weight=layer.w2_weight, x_scale = torch.empty(
w_perchannel_max=layer.w2_weight_scale, (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device
sorted_tokens_num_lod=sorted_tokens_num_lod, )
sorted_tokens_idx=sorted_tokens_idx, torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True)
moe_topk=top_k,
y=out,
topk_ids=topk_ids,
# sort_mode=False,
act=None,
)
dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device) torch.ops._C.moe_fc(
output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) x=x_q,
sorted_tokens_idx = sorted_tokens_idx.view(M, top_k) x_perchannel_max=x_scale,
weight=layer.w2_weight,
w_perchannel_max=layer.w2_weight_scale,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=top_k,
y=out,
topk_ids=topk_ids,
# sort_mode=False,
act=None,
)
torch.ops._C.moe_post( dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device)
x=out, output = torch.empty(
moe_index=sorted_tokens_idx, [M, N], dtype=hidden_states.dtype, device=hidden_states.device
normed_scale=normed_score, )
dequant_scale=dequant_scale, sorted_tokens_idx = sorted_tokens_idx.view(M, top_k)
y=output,
) torch.ops._C.moe_post(
return output x=out,
moe_index=sorted_tokens_idx,
normed_scale=normed_score,
dequant_scale=dequant_scale,
y=output,
)
return output
CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = ( # monkey patch
process_weights_after_loading from vllm.model_executor.layers.quantization.compressed_tensors import (
compressed_tensors_moe,
)
compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod = (
KunlunCompressedTensorsW8A8Int8MoEMethod
)
print(
"[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod \
--> vllm_kunlun.ops.quantization.compressed_tensors_moe.py:KunlunCompressedTensorsW8A8Int8MoEMethod"
) )
CompressedTensorsW8A8Int8MoEMethod.apply = apply

View File

@@ -17,92 +17,99 @@
# limitations under the License. # limitations under the License.
import torch import torch
from torch.nn.parameter import Parameter
from typing import Optional from typing import Optional
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState
logger = init_logger(__name__)
class KunlunGPTQLinearMethod(GPTQLinearMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
logger.warning_once(f"Repacking INT4 for XPU ...")
layer.qzeros = Parameter(
self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits)
if self.quant_config.weight_bits == 4 else layer.qzeros.data,
requires_grad=False
)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
# No need shuffle on xpu
# ops.gptq_shuffle(layer.qweight, layer.g_idx,
# self.quant_config.weight_bits)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4):
# for torch.compile N, K = packed.shape
layer.qzeros = Parameter( assert num_bits == 4, "Only int4 supported now"
self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits) shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32)
if self.quant_config.weight_bits == 4 else layer.qzeros.data,
requires_grad=False
)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded # Unpack int32 to int4 values
# here we do the shuffle on first forward pass unpacked_gptq = (
if layer.exllama_state == ExllamaState.UNINITIALIZED: packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts
if self.quant_config.desc_act: ) & 0xF # [N, K//8, 8, 8]
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
# No need shuffle on xpu # Convert to KUNLUN order
# ops.gptq_shuffle(layer.qweight, layer.g_idx, GPTQ_TO_KUNLUN_ORDER_FAST = [
# self.quant_config.weight_bits) 32, 0, 33, 1, 34, 2, 35, 3,
36, 4, 37, 5, 38, 6, 39, 7,
40, 8, 41, 9, 42, 10, 43, 11,
44, 12, 45, 13, 46, 14, 47, 15,
48, 16, 49, 17, 50, 18, 51, 19,
52, 20, 53, 21, 54, 22, 55, 23,
56, 24, 57, 25, 58, 26, 59, 27,
60, 28, 61, 29, 62, 30, 63, 31,
]
unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8)
packed_kunlun = (
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K)
) # [N, K]
return packed_kunlun
def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4): def apply(
N, K = packed.shape self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
assert num_bits == 4, "Only int4 supported now" ) -> torch.Tensor:
shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32) out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# Unpack int32 to int4 values output = torch.ops.xspeedgate_ops.gptq_gemm(
unpacked_gptq = ( reshaped_x,
packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts layer.qweight,
) & 0xF # [N, K//8, 8, 8] layer.qzeros,
layer.scales,
layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
# Convert to KUNLUN order # monkey patch
GPTQ_TO_KUNLUN_ORDER_FAST = [ from vllm.model_executor.layers.quantization import gptq
32, 0, 33, 1, 34, 2, 35, 3,
36, 4, 37, 5, 38, 6, 39, 7,
40, 8, 41, 9, 42, 10, 43, 11,
44, 12, 45, 13, 46, 14, 47, 15,
48, 16, 49, 17, 50, 18, 51, 19,
52, 20, 53, 21, 54, 22, 55, 23,
56, 24, 57, 25, 58, 26, 59, 27,
60, 28, 61, 29, 62, 30, 63, 31,
]
unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64)
unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64]
# Pack to int32 gptq.GPTQLinearMethod = KunlunGPTQLinearMethod
unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8) print(
packed_kunlun = ( "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.gptq.GPTQLinearMethod \
(unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K) --> vllm_kunlun.ops.quantization.gptq.KunlunGPTQLinearMethod"
) # [N, K] )
return packed_kunlun
def apply(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
output = torch.ops.xspeedgate_ops.gptq_gemm(
reshaped_x,
layer.qweight,
layer.qzeros,
layer.scales,
layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
GPTQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun
GPTQLinearMethod.process_weights_after_loading = process_weights_after_loading
GPTQLinearMethod.apply = apply

View File

@@ -21,7 +21,6 @@ from typing import Optional
import torch import torch
import xspeedgate_ops import xspeedgate_ops
from vllm.platforms import current_platform, PlatformEnum from vllm.platforms import current_platform, PlatformEnum
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, convert_to_channelwise,
) )
@@ -100,9 +99,12 @@ class KunlunScaledMMLinearKernel(CutlassScaledMMLinearKernel):
# ) # )
# monkey patch
_POSSIBLE_KERNELS[PlatformEnum.CUDA] = [KunlunScaledMMLinearKernel] _POSSIBLE_KERNELS[PlatformEnum.CUDA] = [KunlunScaledMMLinearKernel]
from vllm.model_executor.layers.quantization.kernels.scaled_mm import cutlass
cutlass.CutlassScaledMMLinearKernel = KunlunScaledMMLinearKernel
print( print(
f"[vllm_kunlun] ScaledMM kernels: {[k.__name__ for k in _POSSIBLE_KERNELS[PlatformEnum.CUDA]]}" "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass.CutlassScaledMMLinearKernel \
--> vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm.KunlunScaledMMLinearKernel"
) )