diff --git a/vllm_kunlun/models/mimo_v2_flash.py b/vllm_kunlun/models/mimo_v2_flash.py index 033e1be..10d2f0b 100644 --- a/vllm_kunlun/models/mimo_v2_flash.py +++ b/vllm_kunlun/models/mimo_v2_flash.py @@ -21,7 +21,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, ) from vllm.logger import init_logger -from vllm_kunlun.ops.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, diff --git a/vllm_kunlun/models/qwen3_moe.py b/vllm_kunlun/models/qwen3_moe.py index d7f8031..68cf9b6 100644 --- a/vllm_kunlun/models/qwen3_moe.py +++ b/vllm_kunlun/models/qwen3_moe.py @@ -38,7 +38,7 @@ from vllm.distributed import (get_ep_group, get_pp_group, tensor_model_parallel_all_gather) from vllm.logger import init_logger from vllm_kunlun.ops.activation import SiluAndMul -from vllm_kunlun.ops.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, diff --git a/vllm_kunlun/models/qwen3_next.py b/vllm_kunlun/models/qwen3_next.py index fdcc2d8..a665b18 100644 --- a/vllm_kunlun/models/qwen3_next.py +++ b/vllm_kunlun/models/qwen3_next.py @@ -27,7 +27,7 @@ from vllm.logger import init_logger from vllm_kunlun.ops.fla import (fused_recurrent_gated_delta_rule, torch_chunk_gated_delta_rule, chunk_gated_delta_rule) from vllm.model_executor.layers.fla.ops import ( RMSNormGated) -from vllm_kunlun.ops.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE # yapf conflicts with isort for this block # yapf: disable from vllm.model_executor.layers.layernorm import ( diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py index b6fb2be..51d0c84 100644 --- a/vllm_kunlun/ops/fused_moe/layer.py +++ b/vllm_kunlun/ops/fused_moe/layer.py @@ -1,17 +1,35 @@ -"""layer.py""" +# +# Copyright (c) 2026 Baidu, Inc. All Rights Reserved. +# +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from contextlib import nullcontext -from typing import Callable, Optional, Union, get_args +from typing import Callable, Optional import torch from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( should_ignore_layer, ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod +from vllm.model_executor.layers.fused_moe.layer import ( + UnquantizedFusedMoEMethod, + FusedMoE, +) -def apply( + +class KunlunUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): + def apply( self, layer: torch.nn.Module, x: torch.Tensor, @@ -37,43 +55,47 @@ def apply( """apply""" if enable_eplb: raise NotImplementedError( - "EPLB not supported for `UnquantizedFusedMoEMethod` yet.") - + "EPLB not supported for `UnquantizedFusedMoEMethod` yet." + ) + """forward_kunlun""" from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops + if self.moe.use_ep: - return ops.fused_moe_ep(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - self.moe.ep_rank, - top_k, - renormalize=renormalize, - inplace=True, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group) + return ops.fused_moe_ep( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + self.moe.ep_rank, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) else: - return ops.fused_moe(x, - layer.w13_weight, - layer.w2_weight, - router_logits, - self.moe.ep_rank, - top_k, - renormalize=renormalize, - inplace=True, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - w1_bias=getattr(layer, 'w13_bias', None), - w2_bias=getattr(layer, 'w2_bias', None), - ) + return ops.fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + self.moe.ep_rank, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + w1_bias=getattr(layer, "w13_bias", None), + w2_bias=getattr(layer, "w2_bias", None), + ) -UnquantizedFusedMoEMethod.apply = apply -class VllmFusedMoE(FusedMoE): +class KunlunFusedMoE(FusedMoE): def __init__( self, num_experts: int, # Global number of experts @@ -131,7 +153,8 @@ class VllmFusedMoE(FusedMoE): has_bias=has_bias, is_sequence_parallel=is_sequence_parallel, zero_expert_num=zero_expert_num, - zero_expert_type=zero_expert_type) + zero_expert_type=zero_expert_type, + ) self.has_bias = has_bias self.register_parameter("w13_bias", None) self.register_parameter("w2_bias", None) @@ -143,7 +166,7 @@ class VllmFusedMoE(FusedMoE): fused_mapping=self.quant_config.packed_modules_mapping, ) ): - self.quant_method = UnquantizedFusedMoEMethod(self.moe_config) + self.quant_method = KunlunUnquantizedFusedMoEMethod(self.moe_config) moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, @@ -154,4 +177,17 @@ class VllmFusedMoE(FusedMoE): self.quant_method.create_weights(layer=self, **moe_quant_params) -FusedMoE = VllmFusedMoE +# monkey patch +from vllm.model_executor.layers.fused_moe import layer + +layer.UnquantizedFusedMoEMethod = KunlunUnquantizedFusedMoEMethod +layer.FusedMoE = KunlunFusedMoE + +print( + "[Monkey Patch Applied] >>> from vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod \ + --> vllm_kunlun.ops.fused_moe.layer.KunlunUnquantizedFusedMoEMethod" +) +print( + "[Monkey Patch Applied] >>> from vllm.model_executor.layers.fused_moe.layer.FusedMoE \ + --> vllm_kunlun.ops.fused_moe.layer.KunlunFusedMoE" +) diff --git a/vllm_kunlun/ops/quantization/awq.py b/vllm_kunlun/ops/quantization/awq.py index e7b65bc..242fbb4 100644 --- a/vllm_kunlun/ops/quantization/awq.py +++ b/vllm_kunlun/ops/quantization/awq.py @@ -17,112 +17,119 @@ # limitations under the License. import torch - +from vllm.logger import init_logger from typing import Optional from vllm.model_executor.layers.quantization.awq import AWQLinearMethod +logger = init_logger(__name__) + +class KunlunAWQLinearMethod(AWQLinearMethod): + def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4): + """Convert AWQ-packed int4 weights to Kunlun XPU format. + Input: packed[N, K], dtype=int32, saved as AWQ order + Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order + """ + N, K = packed.shape + self.align_type = 1 if K % 8 == 0 else 0 + assert num_bits == 4, "Only int4 supported now" + shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32) + + if self.align_type == 0: # NORMAL MODE + # Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7] + unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8] + + # Reverse AWQ order and convert to KUNLUN order + AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3] + # [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6] + unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8] + + # Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1] + packed_kunlun = (unpacked_kunlun << shifts).sum( + dim=-1, dtype=torch.int32 + ) # [N, K] + elif self.align_type == 1: # FAST MODEL + # Unpack AWQ order + unpacked_awq = ( + packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts + ) & 0xF # [N, K//8, 8, 8] + + # Reverse AWQ order and convert to KUNLUN order + AWQ_TO_KUNLUN_ORDER_FAST = [ + 32, 0, 36, 4, 33, 1, 37, 5, + 34, 2, 38, 6, 35, 3, 39, 7, + 40, 8, 44, 12, 41, 9, 45, 13, + 42, 10, 46, 14, 43, 11, 47, 15, + 48, 16, 52, 20, 49, 17, 53, 21, + 50, 18, 54, 22, 51, 19, 55, 23, + 56, 24, 60, 28, 57, 25, 61, 29, + 58, 26, 62, 30, 59, 27, 63, 31 + ] + unpacked_awq = unpacked_awq.reshape(N, K // 8, 64) + unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64] + + # Pack to int32 + unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8) + packed_kunlun = ( + (unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K) + ) # [N, K] + else: + raise NotImplementedError + + return packed_kunlun -def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4): - """Convert AWQ-packed int4 weights to Kunlun XPU format. - Input: packed[N, K], dtype=int32, saved as AWQ order - Output: packed_reordered[N, K], dtype=int32, saved as Kunlun order - """ - N, K = packed.shape - self.align_type = 1 if K % 8 == 0 else 0 - assert num_bits == 4, "Only int4 supported now" - shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32) - - if self.align_type == 0: # NORMAL MODE - # Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7] - unpacked_awq = (packed.unsqueeze(-1) >> shifts) & 0xF # [N, K, 8] - - # Reverse AWQ order and convert to KUNLUN order - AWQ_TO_KUNLUN_ORDER_NORMAL = [4, 0, 5, 1, 6, 2, 7, 3] - # [0,2,4,6,1,3,5,7] --> [1, 0, 3, 2, 5, 4, 7, 6] - unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] # [N, K, 8] - - # Pack to int32, order[6, 7, 4, 5, 2, 3, 0, 1] - packed_kunlun = (unpacked_kunlun << shifts).sum( - dim=-1, dtype=torch.int32 - ) # [N, K] - elif self.align_type == 1: # FAST MODEL - # Unpack AWQ order - unpacked_awq = ( - packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts - ) & 0xF # [N, K//8, 8, 8] - - # Reverse AWQ order and convert to KUNLUN order - AWQ_TO_KUNLUN_ORDER_FAST = [ - 32, 0, 36, 4, 33, 1, 37, 5, - 34, 2, 38, 6, 35, 3, 39, 7, - 40, 8, 44, 12, 41, 9, 45, 13, - 42, 10, 46, 14, 43, 11, 47, 15, - 48, 16, 52, 20, 49, 17, 53, 21, - 50, 18, 54, 22, 51, 19, 55, 23, - 56, 24, 60, 28, 57, 25, 61, 29, - 58, 26, 62, 30, 59, 27, 63, 31 - ] - unpacked_awq = unpacked_awq.reshape(N, K // 8, 64) - unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64] - - # Pack to int32 - unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8) - packed_kunlun = ( - (unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K) - ) # [N, K] - else: - raise NotImplementedError - - return packed_kunlun - - -def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - layer.qweight = torch.nn.Parameter( - ( - self.repack_int4_for_kunlun(layer.qweight.data) - if layer.qweight.data.dtype == torch.int32 - else layer.qweight.data - ), - requires_grad=False, - ) - layer.qzeros = torch.nn.Parameter( - ( - self.repack_int4_for_kunlun(layer.qzeros.data) - if layer.qzeros.data.dtype == torch.int32 - else layer.qzeros.data - ), - requires_grad=False, - ) - layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) - - -def apply( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None -) -> torch.Tensor: - qweight = layer.qweight - scales = layer.scales - qzeros = layer.qzeros - pack_factor = self.quant_config.pack_factor - out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) - reshaped_x = x.reshape(-1, x.shape[-1]) - - # num_tokens >= threshold - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = torch.ops._C.awq_dequantize( - qweight, scales, qzeros, quant_type=0, align_type=self.align_type + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + logger.warning_once(f"Repacking INT4 for XPU ...") + layer.qweight = torch.nn.Parameter( + ( + self.repack_int4_for_kunlun(layer.qweight.data) + if layer.qweight.data.dtype == torch.int32 + else layer.qweight.data + ), + requires_grad=False, ) - out = torch.matmul(reshaped_x, out) - else: - out = torch.ops._C.awq_gemm( - reshaped_x, qweight, scales, qzeros, align_type=self.align_type + layer.qzeros = torch.nn.Parameter( + ( + self.repack_int4_for_kunlun(layer.qzeros.data) + if layer.qzeros.data.dtype == torch.int32 + else layer.qzeros.data + ), + requires_grad=False, ) - if bias is not None: - out.add_(bias) - return out.reshape(out_shape) + layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) -AWQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun -AWQLinearMethod.process_weights_after_loading = process_weights_after_loading -AWQLinearMethod.apply = apply + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros + pack_factor = self.quant_config.pack_factor + out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,) + reshaped_x = x.reshape(-1, x.shape[-1]) + + # num_tokens >= threshold + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = torch.ops._C.awq_dequantize( + qweight, scales, qzeros, quant_type=0, align_type=self.align_type + ) + out = torch.matmul(reshaped_x, out) + else: + out = torch.ops._C.awq_gemm( + reshaped_x, qweight, scales, qzeros, align_type=self.align_type + ) + if bias is not None: + out.add_(bias) + return out.reshape(out_shape) + + +# monkey patch +from vllm.model_executor.layers.quantization import awq + +awq.AWQLinearMethod = KunlunAWQLinearMethod +print( + "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQLinearMethod \ + --> vllm_kunlun.ops.quantization.awq.KunlunAWQLinearMethod" +) diff --git a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py index b0b979e..76452a1 100644 --- a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py @@ -24,176 +24,190 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso ) -def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # NOTE: xtorch_ops use max as scale - with torch.no_grad(): - layer.w13_weight_scale.mul_(127.0) - layer.w2_weight_scale.mul_(127.0) +class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod): + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # NOTE: xtorch_ops use max as scale + with torch.no_grad(): + layer.w13_weight_scale.mul_(127.0) + layer.w2_weight_scale.mul_(127.0) -def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - hidden_states = x - global_num_experts, up_gate_size, _ = layer.w13_weight.shape - M, N = hidden_states.shape - hidden_dim = layer.w2_weight.shape[1] - normed_score = torch.empty( - M, top_k, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, top_k, dtype=torch.int32, device=hidden_states.device) - num_blocks = 12 - block_statistic = torch.zeros( - num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device - ) - - router_logits = router_logits.float() - if scoring_func == "softmax": - torch.ops._C.moe_softmax_topk_norm( - x=router_logits, - normed_score=normed_score, - topk_index=topk_ids, - block_statistic=None, - stable=True, + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + hidden_states = x + global_num_experts, up_gate_size, _ = layer.w13_weight.shape + M, N = hidden_states.shape + hidden_dim = layer.w2_weight.shape[1] + normed_score = torch.empty( + M, top_k, dtype=torch.float32, device=hidden_states.device ) - elif scoring_func == "sigmoid": - torch.ops._C.moe_sigmoid_group_topk_norm( - x=router_logits, - norm_score=normed_score, - topk_index=topk_ids, - block_static=block_statistic, - bias=e_score_correction_bias, - n_group=num_expert_group, - topk_group=topk_group, - scale=routed_scaling_factor, + topk_ids = torch.empty(M, top_k, dtype=torch.int32, device=hidden_states.device) + num_blocks = 12 + block_statistic = torch.zeros( + num_blocks, + global_num_experts, + dtype=torch.int32, + device=hidden_states.device, ) - moe_expand = torch.empty( - (M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device - ) # [M, top_k, N], float - expert_m = torch.zeros( - global_num_experts, dtype=torch.int32, device=hidden_states.device - ) # [E] - sorted_tokens_num_lod = torch.zeros( - global_num_experts + 1, dtype=torch.int32, device=hidden_states.device - ) # [E+1] - sorted_tokens_idx = torch.zeros( - M * top_k, dtype=torch.int32, device=hidden_states.device - ) + router_logits = router_logits.float() + if scoring_func == "softmax": + torch.ops._C.moe_softmax_topk_norm( + x=router_logits, + normed_score=normed_score, + topk_index=topk_ids, + block_statistic=None, + stable=True, + ) + elif scoring_func == "sigmoid": + torch.ops._C.moe_sigmoid_group_topk_norm( + x=router_logits, + norm_score=normed_score, + topk_index=topk_ids, + block_static=block_statistic, + bias=e_score_correction_bias, + n_group=num_expert_group, + topk_group=topk_group, + scale=routed_scaling_factor, + ) - torch.ops._C.gen_block_statistic(topk_ids, block_statistic) + moe_expand = torch.empty( + (M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device + ) # [M, top_k, N], float + expert_m = torch.zeros( + global_num_experts, dtype=torch.int32, device=hidden_states.device + ) # [E] + sorted_tokens_num_lod = torch.zeros( + global_num_experts + 1, dtype=torch.int32, device=hidden_states.device + ) # [E+1] + sorted_tokens_idx = torch.zeros( + M * top_k, dtype=torch.int32, device=hidden_states.device + ) - torch.ops._C.moe_pre_sorted( - x=hidden_states, - topk_index=topk_ids, - block_statistic=block_statistic, - moe_expand=moe_expand, - moe_index=sorted_tokens_idx, - expert_m=expert_m, - sorted_tokens_num_lod=sorted_tokens_num_lod, - ) + torch.ops._C.gen_block_statistic(topk_ids, block_statistic) - y = torch.empty( - M, - top_k, - layer.w13_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + torch.ops._C.moe_pre_sorted( + x=hidden_states, + topk_index=topk_ids, + block_statistic=block_statistic, + moe_expand=moe_expand, + moe_index=sorted_tokens_idx, + expert_m=expert_m, + sorted_tokens_num_lod=sorted_tokens_num_lod, + ) - moe_expand = moe_expand.view(M * top_k, hidden_dim) + y = torch.empty( + M, + top_k, + layer.w13_weight.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) - x_shape = moe_expand.shape - x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) - x_scale = torch.empty( - (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device - ) - torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True) + moe_expand = moe_expand.view(M * top_k, hidden_dim) - torch.ops._C.moe_fc( - x=x_q, - x_perchannel_max=x_scale, - weight=layer.w13_weight, - w_perchannel_max=layer.w13_weight_scale, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=top_k, - y=y, - topk_ids=topk_ids, - # sort_mode=False, - act=None, - ) + x_shape = moe_expand.shape + x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) + x_scale = torch.empty( + (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device + ) + torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True) - d = y.shape[-1] // 2 - output_shape = y.shape[:-1] + (d,) - out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) - torch.ops._C.silu_and_mul(out1, y) + torch.ops._C.moe_fc( + x=x_q, + x_perchannel_max=x_scale, + weight=layer.w13_weight, + w_perchannel_max=layer.w13_weight_scale, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=top_k, + y=y, + topk_ids=topk_ids, + # sort_mode=False, + act=None, + ) - out = torch.empty( - M, - top_k, - layer.w2_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device, - ) + d = y.shape[-1] // 2 + output_shape = y.shape[:-1] + (d,) + out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) + torch.ops._C.silu_and_mul(out1, y) - out1 = out1.reshape(-1, out1.shape[-1]) - x_shape = out1.shape - x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) - x_scale = torch.empty( - (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device - ) - torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) + out = torch.empty( + M, + top_k, + layer.w2_weight.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) - torch.ops._C.moe_fc( - x=x_q, - x_perchannel_max=x_scale, - weight=layer.w2_weight, - w_perchannel_max=layer.w2_weight_scale, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=top_k, - y=out, - topk_ids=topk_ids, - # sort_mode=False, - act=None, - ) + out1 = out1.reshape(-1, out1.shape[-1]) + x_shape = out1.shape + x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) + x_scale = torch.empty( + (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device + ) + torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) - dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device) - output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) - sorted_tokens_idx = sorted_tokens_idx.view(M, top_k) + torch.ops._C.moe_fc( + x=x_q, + x_perchannel_max=x_scale, + weight=layer.w2_weight, + w_perchannel_max=layer.w2_weight_scale, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=top_k, + y=out, + topk_ids=topk_ids, + # sort_mode=False, + act=None, + ) - torch.ops._C.moe_post( - x=out, - moe_index=sorted_tokens_idx, - normed_scale=normed_score, - dequant_scale=dequant_scale, - y=output, - ) - return output + dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device) + output = torch.empty( + [M, N], dtype=hidden_states.dtype, device=hidden_states.device + ) + sorted_tokens_idx = sorted_tokens_idx.view(M, top_k) + + torch.ops._C.moe_post( + x=out, + moe_index=sorted_tokens_idx, + normed_scale=normed_score, + dequant_scale=dequant_scale, + y=output, + ) + return output -CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = ( - process_weights_after_loading +# monkey patch +from vllm.model_executor.layers.quantization.compressed_tensors import ( + compressed_tensors_moe, +) + +compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod = ( + KunlunCompressedTensorsW8A8Int8MoEMethod +) +print( + "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe.CompressedTensorsW8A8Int8MoEMethod \ + --> vllm_kunlun.ops.quantization.compressed_tensors_moe.py:KunlunCompressedTensorsW8A8Int8MoEMethod" ) -CompressedTensorsW8A8Int8MoEMethod.apply = apply diff --git a/vllm_kunlun/ops/quantization/gptq.py b/vllm_kunlun/ops/quantization/gptq.py index e7fdba7..62fa084 100644 --- a/vllm_kunlun/ops/quantization/gptq.py +++ b/vllm_kunlun/ops/quantization/gptq.py @@ -17,92 +17,99 @@ # limitations under the License. import torch - -from torch.nn.parameter import Parameter from typing import Optional +from torch.nn.parameter import Parameter +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod, ExllamaState +logger = init_logger(__name__) + +class KunlunGPTQLinearMethod(GPTQLinearMethod): + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + logger.warning_once(f"Repacking INT4 for XPU ...") + layer.qzeros = Parameter( + self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits) + if self.quant_config.weight_bits == 4 else layer.qzeros.data, + requires_grad=False + ) + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) + + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if layer.exllama_state == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) + else: + layer.g_idx.data = torch.empty((0, ), + dtype=torch.int, + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + + # No need shuffle on xpu + # ops.gptq_shuffle(layer.qweight, layer.g_idx, + # self.quant_config.weight_bits) -def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # for torch.compile - layer.qzeros = Parameter( - self.repack_int4_for_kunlun(layer.qzeros.data, self.quant_config.weight_bits) - if self.quant_config.weight_bits == 4 else layer.qzeros.data, - requires_grad=False - ) - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) - layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) - layer.scales = Parameter(layer.scales.data, requires_grad=False) + def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4): + N, K = packed.shape + assert num_bits == 4, "Only int4 supported now" + shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32) - # exllama needs to shuffle the weight after the weight is loaded - # here we do the shuffle on first forward pass - if layer.exllama_state == ExllamaState.UNINITIALIZED: - if self.quant_config.desc_act: - layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) - else: - layer.g_idx.data = torch.empty((0, ), - dtype=torch.int, - device=layer.g_idx.device) - layer.exllama_state = ExllamaState.READY + # Unpack int32 to int4 values + unpacked_gptq = ( + packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts + ) & 0xF # [N, K//8, 8, 8] - # No need shuffle on xpu - # ops.gptq_shuffle(layer.qweight, layer.g_idx, - # self.quant_config.weight_bits) + # Convert to KUNLUN order + GPTQ_TO_KUNLUN_ORDER_FAST = [ + 32, 0, 33, 1, 34, 2, 35, 3, + 36, 4, 37, 5, 38, 6, 39, 7, + 40, 8, 41, 9, 42, 10, 43, 11, + 44, 12, 45, 13, 46, 14, 47, 15, + 48, 16, 49, 17, 50, 18, 51, 19, + 52, 20, 53, 21, 54, 22, 55, 23, + 56, 24, 57, 25, 58, 26, 59, 27, + 60, 28, 61, 29, 62, 30, 63, 31, + ] + unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64) + unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64] + + # Pack to int32 + unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8) + packed_kunlun = ( + (unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K) + ) # [N, K] + + return packed_kunlun -def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4): - N, K = packed.shape - assert num_bits == 4, "Only int4 supported now" - shifts = torch.arange(0, 32, num_bits, device=packed.device, dtype=torch.int32) + def apply( + self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) - # Unpack int32 to int4 values - unpacked_gptq = ( - packed.view(N, K // 8, 8).unsqueeze(-1) >> shifts - ) & 0xF # [N, K//8, 8, 8] + output = torch.ops.xspeedgate_ops.gptq_gemm( + reshaped_x, + layer.qweight, + layer.qzeros, + layer.scales, + layer.g_idx, + layer.exllama_state == ExllamaState.READY, + self.quant_config.weight_bits, + ) + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) + - # Convert to KUNLUN order - GPTQ_TO_KUNLUN_ORDER_FAST = [ - 32, 0, 33, 1, 34, 2, 35, 3, - 36, 4, 37, 5, 38, 6, 39, 7, - 40, 8, 41, 9, 42, 10, 43, 11, - 44, 12, 45, 13, 46, 14, 47, 15, - 48, 16, 49, 17, 50, 18, 51, 19, - 52, 20, 53, 21, 54, 22, 55, 23, - 56, 24, 57, 25, 58, 26, 59, 27, - 60, 28, 61, 29, 62, 30, 63, 31, - ] - unpacked_gptq = unpacked_gptq.reshape(N, K // 8, 64) - unpacked_kunlun = unpacked_gptq[..., GPTQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64] +# monkey patch +from vllm.model_executor.layers.quantization import gptq - # Pack to int32 - unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8) - packed_kunlun = ( - (unpacked_kunlun << shifts).sum(dim=-1, dtype=torch.int32).reshape(N, K) - ) # [N, K] - - return packed_kunlun - - -def apply( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None -) -> torch.Tensor: - out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) - reshaped_x = x.reshape(-1, x.shape[-1]) - - output = torch.ops.xspeedgate_ops.gptq_gemm( - reshaped_x, - layer.qweight, - layer.qzeros, - layer.scales, - layer.g_idx, - layer.exllama_state == ExllamaState.READY, - self.quant_config.weight_bits, - ) - if bias is not None: - output.add_(bias) - return output.reshape(out_shape) - - -GPTQLinearMethod.repack_int4_for_kunlun = repack_int4_for_kunlun -GPTQLinearMethod.process_weights_after_loading = process_weights_after_loading -GPTQLinearMethod.apply = apply +gptq.GPTQLinearMethod = KunlunGPTQLinearMethod +print( + "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.gptq.GPTQLinearMethod \ + --> vllm_kunlun.ops.quantization.gptq.KunlunGPTQLinearMethod" +) \ No newline at end of file diff --git a/vllm_kunlun/ops/quantization/kernels/kunlun_scale_mm.py b/vllm_kunlun/ops/quantization/kernels/kunlun_scale_mm.py index 34d8c94..8876973 100644 --- a/vllm_kunlun/ops/quantization/kernels/kunlun_scale_mm.py +++ b/vllm_kunlun/ops/quantization/kernels/kunlun_scale_mm.py @@ -21,7 +21,6 @@ from typing import Optional import torch import xspeedgate_ops from vllm.platforms import current_platform, PlatformEnum -from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( convert_to_channelwise, ) @@ -100,9 +99,12 @@ class KunlunScaledMMLinearKernel(CutlassScaledMMLinearKernel): # ) +# monkey patch _POSSIBLE_KERNELS[PlatformEnum.CUDA] = [KunlunScaledMMLinearKernel] +from vllm.model_executor.layers.quantization.kernels.scaled_mm import cutlass - +cutlass.CutlassScaledMMLinearKernel = KunlunScaledMMLinearKernel print( - f"[vllm_kunlun] ScaledMM kernels: {[k.__name__ for k in _POSSIBLE_KERNELS[PlatformEnum.CUDA]]}" + "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass.CutlassScaledMMLinearKernel \ + --> vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm.KunlunScaledMMLinearKernel" )