diff --git a/docs/source/user_guide/feature_guide/quantization.md b/docs/source/user_guide/feature_guide/quantization.md index be5b793..9d90df7 100644 --- a/docs/source/user_guide/feature_guide/quantization.md +++ b/docs/source/user_guide/feature_guide/quantization.md @@ -31,7 +31,7 @@ Like vLLM, we now support quantization methods such as compressed-tensors, AWQ, ✅ ✅ ✅ - WIP + ✅ ✅ WIP diff --git a/vllm_kunlun/ops/__init__.py b/vllm_kunlun/ops/__init__.py index fa5f0cc..c5d2991 100644 --- a/vllm_kunlun/ops/__init__.py +++ b/vllm_kunlun/ops/__init__.py @@ -19,6 +19,7 @@ import vllm_kunlun.ops.rotary_embedding import vllm_kunlun.ops.layernorm import vllm_kunlun.ops.quantization.awq import vllm_kunlun.ops.quantization.gptq +import vllm_kunlun.ops.quantization.moe_wna16 import vllm_kunlun.ops.vocab_parallel_embedding import vllm_kunlun.ops.linear import vllm_kunlun.ops.fused_moe.layer diff --git a/vllm_kunlun/ops/fused_moe/layer.py b/vllm_kunlun/ops/fused_moe/layer.py index 51d0c84..df50cef 100644 --- a/vllm_kunlun/ops/fused_moe/layer.py +++ b/vllm_kunlun/ops/fused_moe/layer.py @@ -162,7 +162,7 @@ class KunlunFusedMoE(FusedMoE): if (self.quant_config is None) or ( should_ignore_layer( prefix, - ignore=self.quant_config.ignore, + ignore=getattr(self.quant_config, "ignore", tuple()), fused_mapping=self.quant_config.packed_modules_mapping, ) ): diff --git a/vllm_kunlun/ops/quantization/awq.py b/vllm_kunlun/ops/quantization/awq.py index 242fbb4..40455e6 100644 --- a/vllm_kunlun/ops/quantization/awq.py +++ b/vllm_kunlun/ops/quantization/awq.py @@ -1,6 +1,6 @@ # # Copyright (c) 2025 Baidu, Inc. All Rights Reserved. -# Author: Li Wei, Pan Xiakai, You Zeyu +# Author: Li Wei, Pan Xiakai, You Zeyu, Tang Shiwen # Email: liwei157@baidu.com # This file is a part of the vllm-kunlun project. # @@ -16,13 +16,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union + import torch + from vllm.logger import init_logger -from typing import Optional -from vllm.model_executor.layers.quantization.awq import AWQLinearMethod +from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod +from vllm.model_executor.layers.quantization.awq import ( + AWQLinearMethod, + AWQConfig, + is_layer_skipped_awq, +) +from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config + logger = init_logger(__name__) + class KunlunAWQLinearMethod(AWQLinearMethod): + def repack_int4_for_kunlun(self, packed: torch.Tensor, num_bits: int = 4): """Convert AWQ-packed int4 weights to Kunlun XPU format. Input: packed[N, K], dtype=int32, saved as AWQ order @@ -64,7 +76,9 @@ class KunlunAWQLinearMethod(AWQLinearMethod): 58, 26, 62, 30, 59, 27, 63, 31 ] unpacked_awq = unpacked_awq.reshape(N, K // 8, 64) - unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] # [N, K//8, 64] + unpacked_kunlun = unpacked_awq[ + ..., AWQ_TO_KUNLUN_ORDER_FAST + ] # [N, K//8, 64] # Pack to int32 unpacked_kunlun = unpacked_kunlun.reshape(N, K // 8, 8, 8) @@ -76,7 +90,6 @@ class KunlunAWQLinearMethod(AWQLinearMethod): return packed_kunlun - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: logger.warning_once(f"Repacking INT4 for XPU ...") layer.qweight = torch.nn.Parameter( @@ -97,9 +110,11 @@ class KunlunAWQLinearMethod(AWQLinearMethod): ) layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) - def apply( - self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: qweight = layer.qweight scales = layer.scales @@ -125,11 +140,42 @@ class KunlunAWQLinearMethod(AWQLinearMethod): return out.reshape(out_shape) +class KunlunAWQConfig(AWQConfig): + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]: # type: ignore + if isinstance(layer, LinearBase): + if is_layer_skipped_awq(prefix, self.modules_to_not_convert): + return UnquantizedLinearMethod() + return KunlunAWQLinearMethod(self) + elif isinstance(layer, FusedMoE): + logger.warning_once( + f"Layer '{prefix}' is not supported by AWQMoeMarlin. " + "Falling back to Moe WNA16 kernels." + ) + config = { + "quant_method": "awq", + "bits": self.weight_bits, + "group_size": self.group_size, + "zero_point": self.zero_point, + "lm_head": False, + } + return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix) + + return None + + # monkey patch from vllm.model_executor.layers.quantization import awq awq.AWQLinearMethod = KunlunAWQLinearMethod +awq.AWQConfig = KunlunAWQConfig print( "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQLinearMethod \ --> vllm_kunlun.ops.quantization.awq.KunlunAWQLinearMethod" ) +print( + "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.awq.AWQConfig \ + --> vllm_kunlun.ops.quantization.awq.KunlunAWQConfig" +) diff --git a/vllm_kunlun/ops/quantization/kernels/quant_ops.py b/vllm_kunlun/ops/quantization/kernels/quant_ops.py new file mode 100644 index 0000000..00e1312 --- /dev/null +++ b/vllm_kunlun/ops/quantization/kernels/quant_ops.py @@ -0,0 +1,68 @@ +# +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# Author: Tang Shiwen +# Email: tangshiwen@baidu.com +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def dequant_int4( + qweight: torch.Tensor, + scale: torch.Tensor, + zp: torch.Tensor, + int4_signed: bool = False, + use_mode_fast: bool = False, +) -> torch.Tensor: + + fpweight = torch.empty( + ( + qweight.shape[0], + qweight.shape[2], + scale.shape[1], + ), + dtype=scale.dtype, + device=qweight.device, + ) + + qweight_t = qweight.transpose(1, 2).contiguous() + qscale_t = scale.transpose(1, 2).contiguous() * 15.0 + + zp_t = zp.transpose(1, 2).contiguous() + zp_unpack = torch.stack((zp_t & 0xF, (zp_t >> 4) & 0xF), dim=-1) + zp_fp = ( + zp_unpack.reshape( + zp_unpack.shape[0], + zp_unpack.shape[1], + zp_unpack.shape[2] * zp_unpack.shape[3], + ) + .contiguous() + .to(scale.dtype) + - 8.0 + ) + + group_m = qweight_t.shape[-2] // qscale_t.shape[-2] + + torch.ops._C.dequant_int4( + x=qweight_t, + scale=qscale_t, + zero=zp_fp, + y=fpweight, + group_m=group_m, + int4_signed=int4_signed, + use_mode_fast=use_mode_fast, + ) + + return fpweight.transpose(1, 2).contiguous() diff --git a/vllm_kunlun/ops/quantization/moe_wna16.py b/vllm_kunlun/ops/quantization/moe_wna16.py new file mode 100644 index 0000000..f0ce115 --- /dev/null +++ b/vllm_kunlun/ops/quantization/moe_wna16.py @@ -0,0 +1,298 @@ +# +# Copyright (c) 2026 Baidu, Inc. All Rights Reserved. +# Author: Tang Shiwen, Li Wei +# Email: tangshiwen@baidu.com, liwei157@baidu.com +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from typing import Optional, Callable, Union + +from vllm.distributed import get_tp_group +from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Method +from vllm.model_executor.utils import set_weight_attrs + +from vllm_kunlun.ops.quantization.kernels.quant_ops import dequant_int4 +from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops + + +def convert_awq_tensor_for_kunlun( + packed: torch.Tensor, + tensor_type: str, + num_bits: int = 4, + align_type: int = 0, +): + """ + Convert AWQ-packed int4 weights to Kunlun XPU format. + Input: packed[N, K], dtype=int32, saved as AWQ order + Output: + weight: packed_reordered[N, K*4], dtype=int8, saved as Kunlun order + zeros: zeros_reordered[N, K*8], dtype=float16 + """ + N, K = packed.shape + assert num_bits == 4, "Only int4 supported now" + shifts_from_int32 = torch.arange( + 0, 32, num_bits, device=packed.device, dtype=torch.int32 + ) + shifts_back_int8 = torch.arange( + 0, 8, num_bits, device=packed.device, dtype=torch.int32 + ) + + if tensor_type == "qweight": # pack weight + + if align_type == 0: # normal mode + # Unpack AWQ order:[0, 2, 4, 6, 1, 3, 5, 7] + unpacked_awq = (packed.unsqueeze(-1) >> shifts_from_int32) & 0xF + AWQ_TO_KUNLUN_ORDER_NORMAL = [0, 4, 1, 5, 2, 6, 3, 7] + unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_NORMAL] + shifts_back_int8 = shifts_back_int8.repeat(4) + + elif align_type == 1: # fast mode + # Unpack AWQ order: [0, 2, 4, ..., 123, 125, 127] + unpacked_awq = ( + packed.view(N, K // 16, 16).unsqueeze(-1) >> shifts_from_int32 + ) & 0xF + unpacked_awq = unpacked_awq.reshape(N, K // 16, 128) + # Reverse AWQ order and convert to KUNLUN order + AWQ_TO_KUNLUN_ORDER_FAST = [ + j + 8 * i + for i in range(8) + for j in [0, 64, 4, 68, 1, 65, 5, 69, 2, 66, 6, 70, 3, 67, 7, 71] + ] + unpacked_kunlun = unpacked_awq[..., AWQ_TO_KUNLUN_ORDER_FAST] + shifts_back_int8 = shifts_back_int8.repeat(64) + + else: + raise NotImplementedError + + # Pack to int8, order[1, 0] + packed_kunlun = ( + (unpacked_kunlun << shifts_back_int8) + .view(*unpacked_kunlun.shape[:-1], -1, 2) + .sum(dim=-1) + .to(torch.int8) + .reshape(N, -1) + ) + + elif tensor_type == "qzeros": # pack zero points + unpacked_awq = (packed.unsqueeze(-1) >> shifts_from_int32) & 0xF + AWQ_TO_NORMAL_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + unpacked_kunlun = unpacked_awq[..., AWQ_TO_NORMAL_ORDER] + shifts_back_int8 = shifts_back_int8.repeat(4) + packed_kunlun = ( + (unpacked_kunlun << shifts_back_int8) + .view(*unpacked_kunlun.shape[:-1], -1, 2) + .sum(dim=-1) + .to(torch.uint8) + .reshape(N, -1) + ) + + else: + raise NotImplementedError() + + return packed_kunlun.T.contiguous() + + +class KunlunMoeWNA16Method(MoeWNA16Method): + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + super().create_weights( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + params_dtype, + **extra_weight_attrs, + ) + + wrapped_weight_loader = type(self).get_weight_loader( + layer, extra_weight_attrs["weight_loader"] + ) + extra_weight_attrs["weight_loader"] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 + * intermediate_size_per_partition + // self.quant_config.bit8_pack_factor, + hidden_size, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.bit8_pack_factor, + intermediate_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + @staticmethod + def get_weight_loader(layer, weight_loader): + + def patched_moe_wna16_weight_loader( + param, loaded_weight, weight_name, shard_id, expert_id, return_success=False + ): + + if "g_idx" in weight_name: + return False if return_success else None + if not layer.quant_config.has_zp and "qzeros" in weight_name: + return False if return_success else None + + device = get_tp_group().device + loaded_weight = loaded_weight.to(device) + + orig_method = layer.quant_config.linear_quant_method + + if layer.quant_config.linear_quant_method == "awq": + assert layer.quant_config.weight_bits == 4 + + if "weight" in weight_name: + + # TODO(hack): Temporary workaround for a packing conflict between + # dequant_int4 and tensor-parallel (TP) sharding. When align_type=1, + # the weights cannot be packed correctly after TP slicing, leading + # to invalid packed values. This should be revisited once the + # sharding/packing logic is refactored. + layer.align_type = 0 + + loaded_weight = convert_awq_tensor_for_kunlun( + packed=loaded_weight, + tensor_type="qweight", + align_type=layer.align_type, + ) + elif "zeros" in weight_name: + loaded_weight = convert_awq_tensor_for_kunlun( + packed=loaded_weight, tensor_type="qzeros", align_type=0 + ) + else: + loaded_weight = loaded_weight.T + + layer.quant_config.linear_quant_method = "_patched_awq" + + try: + return MoeWNA16Method.get_weight_loader(layer, weight_loader)( + param, + loaded_weight, + weight_name, + shard_id, + expert_id, + return_success=return_success, + ) + finally: + layer.quant_config.linear_quant_method = orig_method + + return patched_moe_wna16_weight_loader + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: Optional[torch.Tensor] = None, + logical_to_physical_map: Optional[torch.Tensor] = None, + logical_replica_count: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + + w13_weight = dequant_int4( + qweight=layer.w13_qweight, + scale=self.moe_quant_config.w1_scale, + zp=self.moe_quant_config.w1_zp, + int4_signed=False, + use_mode_fast=layer.align_type, + ) + + w2_weight = dequant_int4( + qweight=layer.w2_qweight, + scale=self.moe_quant_config.w2_scale, + zp=self.moe_quant_config.w2_zp, + int4_signed=False, + use_mode_fast=layer.align_type, + ) + + if self.moe.use_ep: + return ops.fused_moe_ep( + x, + w13_weight, + w2_weight, + router_logits, + self.moe.ep_rank, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + else: + return ops.fused_moe( + x, + w13_weight, + w2_weight, + router_logits, + self.moe.ep_rank, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + w1_bias=getattr(layer, "w13_bias", None), + w2_bias=getattr(layer, "w2_bias", None), + ) + + +from vllm.model_executor.layers.quantization import moe_wna16 + +moe_wna16.MoeWNA16Method = KunlunMoeWNA16Method +print( + "[Monkey Patch Applied] >>> vllm.model_executor.layers.quantization.moe_wna16.MoeWNA16Method \ + --> vllm_kunlun.ops.quantization.moe_wna16.KunlunMoeWNA16Method" +) diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index 0fb258f..80ff691 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -12,6 +12,7 @@ from torch.library import register_fake import vllm_kunlun._kunlun import vllm.envs as envs + def patch_annotations_for_schema(func): """patch_annotations_for_schema""" sig = inspect.signature(func) @@ -128,7 +129,10 @@ def vllm_kunlun_weak_ref_tensors( return tuple(vllm_kunlun_weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") -vllm_port=envs.VLLM_PORT + +vllm_port = envs.VLLM_PORT + + def _get_open_port() -> int: global vllm_port try: @@ -142,6 +146,7 @@ def _get_open_port() -> int: s.bind(("", 0)) return s.getsockname()[1] + _wrapped = SimpleNamespace(**_orig.__dict__) _wrapped.direct_register_custom_op = direct_register_custom_op _wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor @@ -1897,33 +1902,35 @@ def apply_repetition_penalties_( logits: torch.Tensor, prompt_mask: torch.Tensor, output_mask: torch.Tensor, - repetition_penalties: torch.Tensor + repetition_penalties: torch.Tensor, ) -> None: repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( - 1, logits.size(1)) + 1, logits.size(1) + ) # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - penalties = torch.where(prompt_mask | output_mask, repetition_penalties, - 1.0) + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0) # If logits are positive, divide by penalty, otherwise multiply by penalty. scaling = torch.where(logits > 0, 1.0 / penalties, penalties) logits *= scaling + @impl("_C::apply_repetition_penalties_", "CUDA") def apply_repetition_penalties_( logits: torch.Tensor, prompt_mask: torch.Tensor, output_mask: torch.Tensor, - repetition_penalties: torch.Tensor + repetition_penalties: torch.Tensor, ) -> None: repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( - 1, logits.size(1)) + 1, logits.size(1) + ) # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - penalties = torch.where(prompt_mask | output_mask, repetition_penalties, - 1.0) + penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0) # If logits are positive, divide by penalty, otherwise multiply by penalty. scaling = torch.where(logits > 0, 1.0 / penalties, penalties) logits *= scaling - + + ################################################## # --------------- I8_mqa_logits ----------------- ################################################## @@ -1937,10 +1944,10 @@ def I8_mqa_logits( logits: torch.Tensor, clean_logits: bool, max_seq_q: Optional[int] = 0, - max_seq_k: Optional[int] = 0, + max_seq_k: Optional[int] = 0, is_causal: Optional[bool] = False, - use_xfa_boost: Optional[bool] = False, - ) -> None: + use_xfa_boost: Optional[bool] = False, +) -> None: xtorch_ops.I8_mqa_logits( q=q, fused_kv_cache=fused_kv_cache, @@ -1956,6 +1963,7 @@ def I8_mqa_logits( ) return None + @impl("_C::I8_mqa_logits", "CUDA") def I8_mqa_logits_cuda( q: torch.Tensor, @@ -1966,10 +1974,10 @@ def I8_mqa_logits_cuda( logits: torch.Tensor, clean_logits: bool, max_seq_q: Optional[int] = 0, - max_seq_k: Optional[int] = 0, + max_seq_k: Optional[int] = 0, is_causal: Optional[bool] = False, - use_xfa_boost: Optional[bool] = False, - ) -> None: + use_xfa_boost: Optional[bool] = False, +) -> None: xtorch_ops.I8_mqa_logits( q=q, fused_kv_cache=fused_kv_cache, @@ -1985,6 +1993,7 @@ def I8_mqa_logits_cuda( ) return None + def _fake_I8_mqa_logits( q: torch.Tensor, fused_kv_cache: List[torch.Tensor], @@ -1994,14 +2003,16 @@ def _fake_I8_mqa_logits( logits: torch.Tensor, clean_logits: bool, max_seq_q: Optional[int] = 0, - max_seq_k: Optional[int] = 0, + max_seq_k: Optional[int] = 0, is_causal: Optional[bool] = False, - use_xfa_boost: Optional[bool] = False, - ) -> None: + use_xfa_boost: Optional[bool] = False, +) -> None: return None + I8_mqa_logits.register_fake(_fake_I8_mqa_logits) + ################################################## # ------------- I8_paged_mqa_logits -------------- ################################################## @@ -2015,7 +2026,8 @@ def I8_paged_mqa_logits( max_context_len: int, clean_logits: bool, out: torch.Tensor, - use_xfa_boost: Optional[bool] = False) -> None: + use_xfa_boost: Optional[bool] = False, +) -> None: xtorch_ops.I8_paged_mqa_logits( q=q, fused_kv_cache=fused_kv_cache, @@ -2025,9 +2037,11 @@ def I8_paged_mqa_logits( max_context_len=max_context_len, clean_logits=clean_logits, out=out, - use_xfa_boost=use_xfa_boost) + use_xfa_boost=use_xfa_boost, + ) return None + @impl("_C::I8_paged_mqa_logits", "CUDA") def I8_paged_mqa_logits_cuda( q: torch.Tensor, @@ -2038,7 +2052,8 @@ def I8_paged_mqa_logits_cuda( max_context_len: int, clean_logits: bool, out: torch.Tensor, - use_xfa_boost: Optional[bool] = False) -> None: + use_xfa_boost: Optional[bool] = False, +) -> None: xtorch_ops.I8_paged_mqa_logits( q=q, fused_kv_cache=fused_kv_cache, @@ -2048,42 +2063,48 @@ def I8_paged_mqa_logits_cuda( max_context_len=max_context_len, clean_logits=clean_logits, out=out, - use_xfa_boost=use_xfa_boost) + use_xfa_boost=use_xfa_boost, + ) return None + def _fake_I8_paged_mqa_logits( - q: torch.Tensor, - fused_kv_cache: List[torch.Tensor], - weights: torch.Tensor, - context_lens: List[torch.Tensor], - block_table: torch.Tensor, - max_context_len: int, - clean_logits: bool, - out: torch.Tensor, - use_xfa_boost: Optional[bool] = False) -> None: + q: torch.Tensor, + fused_kv_cache: List[torch.Tensor], + weights: torch.Tensor, + context_lens: List[torch.Tensor], + block_table: torch.Tensor, + max_context_len: int, + clean_logits: bool, + out: torch.Tensor, + use_xfa_boost: Optional[bool] = False, +) -> None: return None + I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits) + ################################################## # ----------- sparse_prefill_fwd_opt ------------- ################################################## @custom_op("_C::sparse_prefill_fwd_opt", mutates_args=()) def sparse_prefill_fwd_opt( - q: torch.Tensor, - kv: torch.Tensor, - indices: torch.Tensor, - out: torch.Tensor, - max_logits: torch.Tensor, - lse: torch.Tensor, - sm_scale: float, - qlod_cpu: Optional[torch.Tensor] = None, - qlod_xpu: Optional[torch.Tensor] = None, - kvlod_cpu: Optional[torch.Tensor] = None, - kvlod_xpu: Optional[torch.Tensor] = None, - d_v: Optional[int] = -1, - is_causal: Optional[bool] = True, - use_xfa_boost: Optional[bool] = False) -> None: + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + out: torch.Tensor, + max_logits: torch.Tensor, + lse: torch.Tensor, + sm_scale: float, + qlod_cpu: Optional[torch.Tensor] = None, + qlod_xpu: Optional[torch.Tensor] = None, + kvlod_cpu: Optional[torch.Tensor] = None, + kvlod_xpu: Optional[torch.Tensor] = None, + d_v: Optional[int] = -1, + is_causal: Optional[bool] = True, + use_xfa_boost: Optional[bool] = False, +) -> None: xtorch_ops.sparse_prefill_fwd_opt( q=q, kv=kv, @@ -2098,25 +2119,28 @@ def sparse_prefill_fwd_opt( kvlod_xpu=kvlod_xpu, d_v=d_v, is_causal=is_causal, - use_xfa_boost=use_xfa_boost) + use_xfa_boost=use_xfa_boost, + ) return None + @impl("_C::sparse_prefill_fwd_opt", "CUDA") def sparse_prefill_fwd_opt_cuda( - q: torch.Tensor, - kv: torch.Tensor, - indices: torch.Tensor, - out: torch.Tensor, - max_logits: torch.Tensor, - lse: torch.Tensor, - sm_scale: float, - qlod_cpu: Optional[torch.Tensor] = None, - qlod_xpu: Optional[torch.Tensor] = None, - kvlod_cpu: Optional[torch.Tensor] = None, - kvlod_xpu: Optional[torch.Tensor] = None, - d_v: Optional[int] = -1, - is_causal: Optional[bool] = True, - use_xfa_boost: Optional[bool] = False) -> None: + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + out: torch.Tensor, + max_logits: torch.Tensor, + lse: torch.Tensor, + sm_scale: float, + qlod_cpu: Optional[torch.Tensor] = None, + qlod_xpu: Optional[torch.Tensor] = None, + kvlod_cpu: Optional[torch.Tensor] = None, + kvlod_xpu: Optional[torch.Tensor] = None, + d_v: Optional[int] = -1, + is_causal: Optional[bool] = True, + use_xfa_boost: Optional[bool] = False, +) -> None: xtorch_ops.sparse_prefill_fwd_opt( q=q, kv=kv, @@ -2131,46 +2155,52 @@ def sparse_prefill_fwd_opt_cuda( kvlod_xpu=kvlod_xpu, d_v=d_v, is_causal=is_causal, - use_xfa_boost=use_xfa_boost) + use_xfa_boost=use_xfa_boost, + ) return None + def _fake_sparse_prefill_fwd_opt( - q: torch.Tensor, - kv: torch.Tensor, - indices: torch.Tensor, - out: torch.Tensor, - max_logits: torch.Tensor, - lse: torch.Tensor, - sm_scale: float, - qlod_cpu: Optional[torch.Tensor] = None, - qlod_xpu: Optional[torch.Tensor] = None, - kvlod_cpu: Optional[torch.Tensor] = None, - kvlod_xpu: Optional[torch.Tensor] = None, - d_v: Optional[int] = -1, - is_causal: Optional[bool] = True, - use_xfa_boost: Optional[bool] = False) -> None: + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + out: torch.Tensor, + max_logits: torch.Tensor, + lse: torch.Tensor, + sm_scale: float, + qlod_cpu: Optional[torch.Tensor] = None, + qlod_xpu: Optional[torch.Tensor] = None, + kvlod_cpu: Optional[torch.Tensor] = None, + kvlod_xpu: Optional[torch.Tensor] = None, + d_v: Optional[int] = -1, + is_causal: Optional[bool] = True, + use_xfa_boost: Optional[bool] = False, +) -> None: return None + sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt) + ################################################## # ------------------ fwd_kvcache_mla ------------- ################################################## @custom_op("_C::fwd_kvcache_mla", mutates_args=()) def fwd_kvcache_mla( - q_c: torch.Tensor, - kv_cache: torch.Tensor, - indices: torch.Tensor, - kv_lod_cpu: torch.Tensor, - out: torch.Tensor, - max_logits: torch.Tensor, - p_sums: torch.Tensor, - softmax_scale: float, - max_seq_kv: int, - q_r: Optional[torch.Tensor] = None, - pe_cache: Optional[torch.Tensor] = None, - use_xfa_boost: Optional[bool] = False, - kv_lod_xpu: Optional[torch.Tensor] = None) -> None: + q_c: torch.Tensor, + kv_cache: torch.Tensor, + indices: torch.Tensor, + kv_lod_cpu: torch.Tensor, + out: torch.Tensor, + max_logits: torch.Tensor, + p_sums: torch.Tensor, + softmax_scale: float, + max_seq_kv: int, + q_r: Optional[torch.Tensor] = None, + pe_cache: Optional[torch.Tensor] = None, + use_xfa_boost: Optional[bool] = False, + kv_lod_xpu: Optional[torch.Tensor] = None, +) -> None: xtorch_ops.fwd_kvcache_mla( q_c=q_c, kv_cache=kv_cache, @@ -2184,24 +2214,27 @@ def fwd_kvcache_mla( q_r=q_r, pe_cache=pe_cache, use_xfa_boost=use_xfa_boost, - kv_lod_xpu=kv_lod_xpu) + kv_lod_xpu=kv_lod_xpu, + ) return None + @impl("_C::fwd_kvcache_mla", "CUDA") def fwd_kvcache_mla_cuda( - q_c: torch.Tensor, - kv_cache: torch.Tensor, - indices: torch.Tensor, - kv_lod_cpu: torch.Tensor, - out: torch.Tensor, - max_logits: torch.Tensor, - p_sums: torch.Tensor, - softmax_scale: float, - max_seq_kv: int, - q_r: Optional[torch.Tensor] = None, - pe_cache: Optional[torch.Tensor] = None, - use_xfa_boost: Optional[bool] = False, - kv_lod_xpu: Optional[torch.Tensor] = None) -> None: + q_c: torch.Tensor, + kv_cache: torch.Tensor, + indices: torch.Tensor, + kv_lod_cpu: torch.Tensor, + out: torch.Tensor, + max_logits: torch.Tensor, + p_sums: torch.Tensor, + softmax_scale: float, + max_seq_kv: int, + q_r: Optional[torch.Tensor] = None, + pe_cache: Optional[torch.Tensor] = None, + use_xfa_boost: Optional[bool] = False, + kv_lod_xpu: Optional[torch.Tensor] = None, +) -> None: xtorch_ops.fwd_kvcache_mla( q_c=q_c, kv_cache=kv_cache, @@ -2215,27 +2248,94 @@ def fwd_kvcache_mla_cuda( q_r=q_r, pe_cache=pe_cache, use_xfa_boost=use_xfa_boost, - kv_lod_xpu=kv_lod_xpu) + kv_lod_xpu=kv_lod_xpu, + ) return None + def _fake_fwd_kvcache_mla( - q_c: torch.Tensor, - kv_cache: torch.Tensor, - indices: torch.Tensor, - kv_lod_cpu: torch.Tensor, - out: torch.Tensor, - max_logits: torch.Tensor, - p_sums: torch.Tensor, - softmax_scale: float, - max_seq_kv: int, - q_r: Optional[torch.Tensor] = None, - pe_cache: Optional[torch.Tensor] = None, - use_xfa_boost: Optional[bool] = False, - kv_lod_xpu: Optional[torch.Tensor] = None) -> None: + q_c: torch.Tensor, + kv_cache: torch.Tensor, + indices: torch.Tensor, + kv_lod_cpu: torch.Tensor, + out: torch.Tensor, + max_logits: torch.Tensor, + p_sums: torch.Tensor, + softmax_scale: float, + max_seq_kv: int, + q_r: Optional[torch.Tensor] = None, + pe_cache: Optional[torch.Tensor] = None, + use_xfa_boost: Optional[bool] = False, + kv_lod_xpu: Optional[torch.Tensor] = None, +) -> None: return None + fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla) + +################################################## +# --------------- dequant_int4 ----------------- +################################################## +@custom_op("_C::dequant_int4", mutates_args=()) +def dequant_int4( + x: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor, + y: torch.Tensor, + group_m: int, + int4_signed: bool = True, + use_mode_fast: bool = False, +) -> None: + xtorch_ops.dequant_int4( + x=x, + scale=scale, + zero=zero, + y=y, + group_m=group_m, + int4_signed=int4_signed, + use_mode_fast=use_mode_fast, + ) + return None + + +@impl("_C::dequant_int4", "CUDA") +def dequant_int4_cuda( + x: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor, + y: torch.Tensor, + group_m: int, + int4_signed: bool = True, + use_mode_fast: bool = False, +) -> None: + xtorch_ops.dequant_int4( + x=x, + scale=scale, + zero=zero, + y=y, + group_m=group_m, + int4_signed=int4_signed, + use_mode_fast=use_mode_fast, + ) + return None + + +def _fake_dequant_int4( + x: torch.Tensor, + scale: torch.Tensor, + zero: torch.Tensor, + y: torch.Tensor, + group_m: int, + int4_signed: bool = True, + use_mode_fast: bool = False, +) -> None: + return None + + +dequant_int4.register_fake(_fake_dequant_int4) + + ################################################## # ------------------ fast_topkv2 ------------- ##################################################