diff --git a/requirements.txt b/requirements.txt index b707752..4893576 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ blake3==1.0.5 cachetools==6.1.0 cbor2==5.7.0 cloudpickle==3.1.1 -compressed-tensors==0.11.0 +compressed-tensors==0.13.0 diskcache==5.6.3 gguf==0.17.1 mistral_common==1.8.3 diff --git a/vllm_kunlun/models/qwen3_moe.py b/vllm_kunlun/models/qwen3_moe.py index 7a1a36d..d7f8031 100644 --- a/vllm_kunlun/models/qwen3_moe.py +++ b/vllm_kunlun/models/qwen3_moe.py @@ -173,10 +173,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - kunlun_linear_weights = self.gate.get_weights() final_hidden_states = self.experts(hidden_states=hidden_states, - router_logits=router_logits, - linear_weights=kunlun_linear_weights) + router_logits=router_logits) if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( diff --git a/vllm_kunlun/ops/__init__.py b/vllm_kunlun/ops/__init__.py index d6da6ae..0bae89f 100644 --- a/vllm_kunlun/ops/__init__.py +++ b/vllm_kunlun/ops/__init__.py @@ -21,7 +21,8 @@ import vllm_kunlun.ops.quantization.awq import vllm_kunlun.ops.quantization.gptq import vllm_kunlun.ops.vocab_parallel_embedding import vllm_kunlun.ops.linear -import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass -import vllm_kunlun.ops.vocab_parallel_embedding -import vllm_kunlun.ops.quantization.compressed_tensors_moe -import vllm_kunlun.ops.fused_moe.layer \ No newline at end of file +# import vllm_kunlun.ops.quantization.kernels.scaled_mm.cutlass +import vllm_kunlun.ops.fused_moe.layer +import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors +import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe +import vllm_kunlun.ops.quantization.kernels.scaled_mm.kunlun \ No newline at end of file diff --git a/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors.py 
b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000..23b9ba1 --- /dev/null +++ b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,75 @@ +# +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# Author: Tang Shiwen, Li Wei +# Email: tangshiwen@baidu.com, liwei157@baidu.com +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +from vllm.model_executor.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig, + CompressedTensorsLinearMethod, + CompressedTensorsMoEMethod, + CompressedTensorsKVCacheMethod, + CompressedTensorsLinearTransformMethod, + get_linear_transform_schemes, +) +from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase +from vllm_kunlun.ops.fused_moe.layer import FusedMoE + + +def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, +) -> Optional["QuantizeMethodBase"]: + from vllm_kunlun.ops.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + # collect schemes + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + input_tfms, output_tfms = get_linear_transform_schemes( + layer, prefix, self.transform_config, 
self.packed_modules_mapping + ) + + # choose quantization method + quant_method: LinearMethodBase = UnquantizedLinearMethod() + if quant_scheme is not None: + layer.scheme = quant_scheme + quant_method = CompressedTensorsLinearMethod(self) + + # choose transform method + if any((input_tfms, output_tfms)): + return CompressedTensorsLinearTransformMethod.from_schemes( + quant_method, quant_scheme, input_tfms, output_tfms + ) + + else: + return quant_method + + if isinstance(layer, Attention): + return CompressedTensorsKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod.get_moe_method(self, layer) + return None + + +CompressedTensorsConfig.get_quant_method = get_quant_method diff --git a/vllm_kunlun/ops/quantization/compressed_tensors_moe.py b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py similarity index 60% rename from vllm_kunlun/ops/quantization/compressed_tensors_moe.py rename to vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py index 7a73e8b..b0b979e 100644 --- a/vllm_kunlun/ops/quantization/compressed_tensors_moe.py +++ b/vllm_kunlun/ops/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,26 +1,35 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# Author: Li Wei, Tang Shiwen +# Email: liwei157@baidu.com, tangshiwen@baidu.com +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. -import enum -from enum import Enum from typing import Callable, Optional, Union import torch -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import CompressedTensorsW8A8Int8MoEMethod +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( + CompressedTensorsW8A8Int8MoEMethod, +) -def klx_process_weights_after_loading(layer: torch.nn.Module) -> None: - """modify scale -> abs max""" - layer.w13_weight = torch.nn.Parameter(layer.w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(layer.w2_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - layer.w13_weight_scale.data * 127, requires_grad=False - ) - layer.w2_weight_scale = torch.nn.Parameter( - layer.w2_weight_scale.data * 127, requires_grad=False - ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - klx_process_weights_after_loading(layer) + # NOTE: xtorch_ops use max as scale + with torch.no_grad(): + layer.w13_weight_scale.mul_(127.0) + layer.w2_weight_scale.mul_(127.0) + def apply( self, @@ -49,14 +58,10 @@ def apply( global_num_experts, up_gate_size, _ = layer.w13_weight.shape M, N = hidden_states.shape hidden_dim = layer.w2_weight.shape[1] - normed_score = torch.empty(M, - top_k, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - top_k, - dtype=torch.int32, - device=hidden_states.device) + normed_score = torch.empty( + M, top_k, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, top_k, dtype=torch.int32, device=hidden_states.device) num_blocks = 12 block_statistic = torch.zeros( num_blocks, global_num_experts, dtype=torch.int32, device=hidden_states.device @@ -69,7 +74,8 @@ def apply( normed_score=normed_score, topk_index=topk_ids, block_statistic=None, - stable=True) + stable=True, + ) elif 
scoring_func == "sigmoid": torch.ops._C.moe_sigmoid_group_topk_norm( x=router_logits, @@ -82,12 +88,20 @@ def apply( scale=routed_scaling_factor, ) - moe_expand = torch.empty((M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device) # [M, top_k, N], float - expert_m = torch.zeros(global_num_experts, dtype=torch.int32, device=hidden_states.device) # [E] - sorted_tokens_num_lod = torch.zeros(global_num_experts + 1, dtype=torch.int32, device=hidden_states.device) # [E+1] - sorted_tokens_idx = torch.zeros(M * top_k, dtype=torch.int32, device=hidden_states.device) + moe_expand = torch.empty( + (M * top_k, N), dtype=hidden_states.dtype, device=hidden_states.device + ) # [M, top_k, N], float + expert_m = torch.zeros( + global_num_experts, dtype=torch.int32, device=hidden_states.device + ) # [E] + sorted_tokens_num_lod = torch.zeros( + global_num_experts + 1, dtype=torch.int32, device=hidden_states.device + ) # [E+1] + sorted_tokens_idx = torch.zeros( + M * top_k, dtype=torch.int32, device=hidden_states.device + ) - torch.ops._C.gen_block_statistic(topk_ids,block_statistic) + torch.ops._C.gen_block_statistic(topk_ids, block_statistic) torch.ops._C.moe_pre_sorted( x=hidden_states, @@ -96,18 +110,24 @@ def apply( moe_expand=moe_expand, moe_index=sorted_tokens_idx, expert_m=expert_m, - sorted_tokens_num_lod=sorted_tokens_num_lod) + sorted_tokens_num_lod=sorted_tokens_num_lod, + ) - y = torch.empty(M,top_k, - layer.w13_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device) + y = torch.empty( + M, + top_k, + layer.w13_weight.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) moe_expand = moe_expand.view(M * top_k, hidden_dim) x_shape = moe_expand.shape x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) - x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device) + x_scale = torch.empty( + (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device + ) 
torch.ops._C.quant2d(moe_expand, x_q, x_scale, force_sdnn=True) torch.ops._C.moe_fc( @@ -121,22 +141,28 @@ def apply( y=y, topk_ids=topk_ids, # sort_mode=False, - act=None) + act=None, + ) d = y.shape[-1] // 2 - output_shape = (y.shape[:-1] + (d, )) + output_shape = y.shape[:-1] + (d,) out1 = torch.empty(output_shape, dtype=y.dtype, device=y.device) torch.ops._C.silu_and_mul(out1, y) - out = torch.empty(M,top_k, - layer.w2_weight.shape[1], - dtype=hidden_states.dtype, - device=hidden_states.device) + out = torch.empty( + M, + top_k, + layer.w2_weight.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) out1 = out1.reshape(-1, out1.shape[-1]) x_shape = out1.shape x_q = torch.empty(x_shape, dtype=torch.int8, device=moe_expand.device) - x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=moe_expand.device) + x_scale = torch.empty( + (x_shape[0], 1), dtype=torch.float32, device=moe_expand.device + ) torch.ops._C.quant2d(out1, x_q, x_scale, force_sdnn=True) torch.ops._C.moe_fc( @@ -150,9 +176,10 @@ def apply( y=out, topk_ids=topk_ids, # sort_mode=False, - act=None) + act=None, + ) - dequant_scale = torch.ones([M, top_k], dtype = torch.float32, device=out.device) + dequant_scale = torch.ones([M, top_k], dtype=torch.float32, device=out.device) output = torch.empty([M, N], dtype=hidden_states.dtype, device=hidden_states.device) sorted_tokens_idx = sorted_tokens_idx.view(M, top_k) @@ -161,9 +188,12 @@ def apply( moe_index=sorted_tokens_idx, normed_scale=normed_score, dequant_scale=dequant_scale, - y=output + y=output, ) return output -CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = process_weights_after_loading -CompressedTensorsW8A8Int8MoEMethod.apply = apply \ No newline at end of file + +CompressedTensorsW8A8Int8MoEMethod.process_weights_after_loading = ( + process_weights_after_loading +) +CompressedTensorsW8A8Int8MoEMethod.apply = apply diff --git a/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py 
b/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py deleted file mode 100644 index 25a3add..0000000 --- a/vllm_kunlun/ops/quantization/kernels/scaled_mm/cutlass.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Optional - -import torch - -from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ScaledMMLinearLayerConfig -from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import CutlassScaledMMLinearKernel -from vllm.model_executor.layers.quantization.utils import replace_parameter -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise) - -def can_implement_kunlun( - cls, c: ScaledMMLinearLayerConfig=None) -> tuple[bool, Optional[str]]: - return True, None - -def klx_process_weights_after_loading(layer: torch.nn.Module) -> None: - """modify scale -> abs max""" - layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data * 127, requires_grad=False) - -def process_weights_after_loading_kunlun(self, layer: torch.nn.Module) -> None: - # WEIGHT - # Cutlass kernels need transposed weight. - weight = getattr(layer, self.w_q_name) - replace_parameter( - layer, self.w_q_name, - torch.nn.Parameter(weight.t().data, requires_grad=False)) - - # WEIGHT SCALE - # Cutlass kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. 
- is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) - if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, - layer.logical_widths) - replace_parameter( - layer, self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False)) - - # INPUT SCALE - if self.config.is_static_input_scheme: - input_scale = getattr(layer, self.i_s_name) - - if self.config.input_symmetric: - replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(input_scale.max(), requires_grad=False)) - setattr(layer, self.i_zp_name, None) - else: - input_zero_point = getattr(layer, self.i_zp_name) - - # reconstruct the ranges - int8_traits = torch.iinfo(torch.int8) - azps = input_zero_point.to(dtype=torch.int32) - range_max = (input_scale * (int8_traits.max - azps)).max() - range_min = (input_scale * (int8_traits.min - azps)).min() - - scale = (range_max - range_min) / (int8_traits.max - - int8_traits.min) - replace_parameter( - layer, self.i_s_name, - torch.nn.Parameter(scale, requires_grad=False)) - - # AZP loaded as int8 but used as int32 - azp = (int8_traits.min - - range_min / scale).to(dtype=torch.int32) - replace_parameter(layer, self.i_zp_name, - torch.nn.Parameter(azp, requires_grad=False)) - - else: - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - - # azp_adj is the AZP adjustment term, used to account for weights. - # It does not depend on scales or azp, so it is the same for - # static and dynamic quantization. 
- # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md - # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md - if not self.config.input_symmetric: - weight = getattr(layer, self.w_q_name) - azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) - if self.config.is_static_input_scheme: - # cutlass_w8a8 requires azp to be folded into azp_adj - # in the per-tensor case - azp_adj = getattr(layer, self.i_zp_name) * azp_adj - setattr(layer, self.azp_adj_name, - torch.nn.Parameter(azp_adj, requires_grad=False)) - else: - setattr(layer, self.azp_adj_name, None) - - klx_process_weights_after_loading(layer) - -def apply_weights_kunlun(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - x_q, x_scale, out = None, None, None - w_t_shape = layer.weight.T.shape - if isinstance(x, tuple): - x_q, x_scale = x - out = torch.empty((x_q.shape[0], w_t_shape[0]), - dtype=torch.bfloat16, - device=x_q.device) - else: - x_shape = x.shape - x_q = torch.empty(x_shape, dtype=torch.int8, device=x.device) - x_scale = torch.empty((x_shape[0], 1), dtype=torch.float32, device=x.device) - out = torch.empty((x_shape[0], w_t_shape[0]), - dtype=x.dtype, - device=x.device) - torch.ops._C.quant2d(x, x_q, x_scale, force_sdnn=True) - torch.ops._C.gemm_I8_I8_bf16_nt(x_q, x_scale, layer.weight.T.data, layer.weight_scale.data, out) - return out - -CutlassScaledMMLinearKernel.apply_weights = apply_weights_kunlun -CutlassScaledMMLinearKernel.can_implement = can_implement_kunlun -CutlassScaledMMLinearKernel.process_weights_after_loading = process_weights_after_loading_kunlun \ No newline at end of file diff --git a/vllm_kunlun/ops/quantization/kernels/scaled_mm/kunlun.py b/vllm_kunlun/ops/quantization/kernels/scaled_mm/kunlun.py new file mode 100644 index 0000000..24f29d7 --- /dev/null +++ b/vllm_kunlun/ops/quantization/kernels/scaled_mm/kunlun.py @@ -0,0 
+1,109 @@ +# +# Copyright (c) 2025 Baidu, Inc. All Rights Reserved. +# Author: Liwei, Tang Shiwen +# Email: liwei157@baidu.com, tangshiwen@baidu.com +# This file is a part of the vllm-kunlun project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import xspeedgate_ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise, +) +from vllm.platforms import current_platform +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( # noqa: E501 + ScaledMMLinearLayerConfig, + CutlassScaledMMLinearKernel, +) +from vllm.platforms import PlatformEnum +from vllm.model_executor.layers.quantization.kernels.scaled_mm import _POSSIBLE_KERNELS + + +class KunlunScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, Optional[str]]: + + if not current_platform.is_kunlun(): + return False, "KunlunScaledMM requires running on XPU." 
+ + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + # change scale to max for klx ops + with torch.no_grad(): + getattr(layer, self.w_s_name).mul_(127.0) + + def apply_weights( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + w_q, w_s, x_s, x_zp, azp_adj = self._get_weight_params(layer) + symmetric = azp_adj is None + + # scaled_int8_quant supports both dynamic and static quant + # Currently, static is per-tensor and dynamic is per-token + x_q, x_s, x_zp, static = torch.ops._C.scaled_int8_quant( + x=x.contiguous(), + scale=x_s, + azp=x_zp, + symmetric=symmetric, + ) + + if x_zp is not None: # asymmetric + azp = None if static else x_zp + return torch.ops._C.cutlass_scaled_mm_azp( + a=x_q, + b=w_q, + scale_a=x_s, + scale_b=(w_s / 127.0).transpose(0, 1), + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias.to(torch.float32).contiguous() if bias is not None else None, + ) + else: # symmetric + return torch.ops._C.matmul( + x=x_q, + w=w_q.transpose(0, 1), + out_dtype=x.dtype, + x_pc_max=x_s * 127.0 if static else x_s, + w_pc_max=w_s, + bias=bias.to(torch.float32).contiguous() if bias is not None else None, + ) + + # backup option: lower performance + # return torch.ops._C.cutlass_scaled_mm( + # a = x_q, + # b = w_q, + # scale_a=x_s / 127.0 if not static else x_s, + # scale_b=(w_s / 127.0).transpose(0, 1), + # out_dtype=x.dtype, + # bias=bias.to(torch.float32).contiguous() if bias is not None else None, + # ) + + +_POSSIBLE_KERNELS[PlatformEnum.CUDA] = [KunlunScaledMMLinearKernel] + + +print( + f"[vllm_kunlun] ScaledMM kernels: {[k.__name__ for k in _POSSIBLE_KERNELS[PlatformEnum.CUDA]]}" +) diff --git a/vllm_kunlun/vllm_utils_wrapper.py b/vllm_kunlun/vllm_utils_wrapper.py index c323aad..faeb1e0 100644 --- a/vllm_kunlun/vllm_utils_wrapper.py +++ b/vllm_kunlun/vllm_utils_wrapper.py @@ -1,8 +1,8 @@ - """vllm_utils_wrapper.py""" + import
vllm.distributed.parallel_state as parallel_state import vllm.utils as _orig -from typing import (Any, Callable, Optional, Union, get_origin, get_args, List) +from typing import Any, Callable, Optional, Union, get_origin, get_args, List, Tuple from types import SimpleNamespace import torch from torch.library import Library @@ -11,6 +11,7 @@ import typing from torch.library import register_fake import vllm_kunlun._kunlun + def patch_annotations_for_schema(func): """patch_annotations_for_schema""" sig = inspect.signature(func) @@ -36,20 +37,23 @@ def patch_annotations_for_schema(func): func.__signature__ = sig.replace(parameters=new_params) return func + def supports_custom_op() -> bool: """supports_custom_op""" return hasattr(torch.library, "custom_op") + vllm_lib = Library("vllm", "FRAGMENT") # noqa + def direct_register_custom_op( - op_name: str, - op_func: Callable, - mutates_args: Optional[list[str]] = None, - fake_impl: Optional[Callable] = None, - target_lib: Optional[Library] = None, - dispatch_key: str = "CUDA", - tags: tuple[torch.Tag, ...] = (), + op_name: str, + op_func: Callable, + mutates_args: Optional[list[str]] = None, + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, + dispatch_key: str = "CUDA", + tags: tuple[torch.Tag, ...] = (), ): """ `torch.library.custom_op` can have significant overhead because it @@ -68,23 +72,26 @@ def direct_register_custom_op( """ if not supports_custom_op(): from vllm.platforms import current_platform + assert not current_platform.is_cuda_alike(), ( "cuda platform needs torch>=2.4 to support custom op, " "chances are you are using an old version of pytorch " "or a custom build of pytorch. It is recommended to " "use vLLM in a fresh new environment and let it install " - "the required dependencies.") + "the required dependencies." 
+ ) return if mutates_args is None: mutates_args = [] import torch.library + if hasattr(torch.library, "infer_schema"): patched_func = patch_annotations_for_schema(op_func) - schema_str = torch.library.infer_schema(op_func, - mutates_args=mutates_args) + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 import torch._custom_op.impl + schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str, tags=tags) @@ -92,6 +99,7 @@ def direct_register_custom_op( if fake_impl is not None: my_lib._register_fake(op_name, fake_impl) + def vllm_kunlun_weak_ref_tensor(tensor: Any) -> Any: """ Create a weak reference to a tensor. @@ -104,8 +112,9 @@ def vllm_kunlun_weak_ref_tensor(tensor: Any) -> Any: else: return tensor + def vllm_kunlun_weak_ref_tensors( - tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]] + tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]], ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, @@ -119,6 +128,7 @@ def vllm_kunlun_weak_ref_tensors( return tuple(vllm_kunlun_weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") + # import vllm.utils as vu # vu.direct_register_custom_op = direct_register_custom_op @@ -133,11 +143,13 @@ _wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor _wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors import sys + sys.modules["vllm.utils"] = _wrapped _original_all_reduce = parallel_state.GroupCoordinator.all_reduce _original_all_gather = parallel_state.GroupCoordinator.all_gather + def vllm_kunlun_all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """vllm_kunlun_all_reduce""" if self.world_size == 1: @@ -146,35 +158,37 @@ def vllm_kunlun_all_reduce(self, input_: torch.Tensor) -> torch.Tensor: torch.distributed.all_reduce(input_, group=self.device_group) return 
input_ + def vllm_kunlun_all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: """vllm_kunlun_all_reduce""" world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ - assert -input_.dim() <= dim < input_.dim(), ( - f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + assert ( + -input_.dim() <= dim < input_.dim() + ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. dim += input_.dim() input_size = input_.size() # Allocate output tensor. - output_tensor = torch.empty((world_size, ) + input_size, - dtype=input_.dtype, - device=input_.device) + output_tensor = torch.empty( + (world_size,) + input_size, dtype=input_.dtype, device=input_.device + ) # All-gather. - torch.distributed.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=self.device_group + ) # Reshape output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + - input_size[dim + 1:]) + output_tensor = output_tensor.reshape( + input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] + ) return output_tensor + parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather @@ -185,169 +199,180 @@ from vllm import _custom_ops as ops from typing import Optional, List import os + @custom_op("_C::rms_norm", mutates_args=()) def rms_norm( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + @custom_op("_C::fused_add_rms_norm", mutates_args=()) def fused_add_rms_norm( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, residual: 
torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + @custom_op("_C::static_scaled_fp8_quant", mutates_args=()) def static_scaled_fp8_quant( - result : torch.Tensor, - input: torch.Tensor, - scale: torch.Tensor -)->None: + result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor +) -> None: pass + @impl("_C::static_scaled_fp8_quant", "CUDA") def static_scaled_fp8_quant_xpu( - result : torch.Tensor, - input: torch.Tensor, - scale: torch.Tensor -)->None: + result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor +) -> None: pass + @custom_op("_C::dynamic_scaled_fp8_quant", mutates_args=()) def dynamic_scaled_fp8_quant( - result : torch.Tensor, - input: torch.Tensor, - scale: torch.Tensor -)->None: + result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor +) -> None: pass + @impl("_C::dynamic_scaled_fp8_quant", "CUDA") def dynamic_scaled_fp8_quant_xpu( - result : torch.Tensor, - input: torch.Tensor, - scale: torch.Tensor -)->None: + result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor +) -> None: pass + @custom_op("_C::dynamic_per_token_scaled_fp8_quant", mutates_args=()) def dynamic_per_token_scaled_fp8_quant( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, - scale_ub: Optional[torch.Tensor] -)->None: + scale_ub: Optional[torch.Tensor], +) -> None: pass + @impl("_C::dynamic_per_token_scaled_fp8_quant", "CUDA") def dynamic_per_token_scaled_fp8_quant_xpu( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, - scale_ub: Optional[torch.Tensor] -)->None: + scale_ub: Optional[torch.Tensor], +) -> None: pass + @custom_op("_C::rms_norm_static_fp8_quant", mutates_args=()) def rms_norm_static_fp8_quant( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + 
@impl("_C::rms_norm_static_fp8_quant", "CUDA") def rms_norm_static_fp8_quant_xpu( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + @custom_op("_C::fused_add_rms_norm_static_fp8_quant", mutates_args=()) def fused_add_rms_norm_static_fp8_quant( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + @impl("_C::fused_add_rms_norm_static_fp8_quant", "CUDA") def fused_add_rms_norm_static_fp8_quant_xpu( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + @custom_op("_C::rms_norm_dynamic_per_token_quant", mutates_args=()) def rms_norm_dynamic_per_token_quant( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, scale_ub: Optional[torch.Tensor], - residual: Optional[torch.Tensor] -)->None: + residual: Optional[torch.Tensor], +) -> None: pass + @impl("_C::rms_norm_dynamic_per_token_quant", "CUDA") def rms_norm_dynamic_per_token_quant_xpu( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, scale_ub: Optional[torch.Tensor], - residual: Optional[torch.Tensor] -)->None: + residual: Optional[torch.Tensor], +) -> None: pass + @custom_op("_C::silu_and_mul_quant", mutates_args=()) def silu_and_mul_quant( - result : torch.Tensor, + result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + + @impl("_C::silu_and_mul_quant", "CUDA") def silu_and_mul_quant_xpu( - result : torch.Tensor, + result: 
torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, - epsilon: float -)->None: + epsilon: float, +) -> None: pass + import torch import xtorch_ops from torch.library import custom_op, impl + @custom_op("_C::add_rmsnorm", mutates_args=()) def add_rmsnorm( x: torch.Tensor, @@ -364,7 +389,7 @@ def add_rmsnorm( ) -> None: xtorch_ops.add_rmsnorm( x, - y, # 原来写 residual,这里其实是 y + y, # 原来写 residual,这里其实是 y residual_output=residual_output, weight=weight, eps=eps, @@ -415,6 +440,7 @@ def rmsnorm( eps, ) + @impl("_C::rmsnorm", "CUDA") def rmsnorm_cuda( x: torch.Tensor, @@ -434,28 +460,51 @@ def rmsnorm_cuda( eps, ) + import torch -def _fake_rmsnorm(x, weight, output, eps=1e-5, interweave=False, - store_output_before_norm=True, bias=None, - residual_output=None, output_max=None): + +def _fake_rmsnorm( + x, + weight, + output, + eps=1e-5, + interweave=False, + store_output_before_norm=True, + bias=None, + residual_output=None, + output_max=None, +): # 设置 shape/dtype,但不返回值 output.fake_shape = x.shape output.fake_dtype = x.dtype return None + rmsnorm.register_fake(_fake_rmsnorm) -def _fake_add_rmsnorm(x, y, weight, output, eps=1e-5, - interweaved=False, store_output_before_norm=True, - bias=None, smooth=None, residual_output=None, - output_max=None): + +def _fake_add_rmsnorm( + x, + y, + weight, + output, + eps=1e-5, + interweaved=False, + store_output_before_norm=True, + bias=None, + smooth=None, + residual_output=None, + output_max=None, +): output.fake_shape = x.shape output.fake_dtype = x.dtype return None + add_rmsnorm.register_fake(_fake_add_rmsnorm) + @custom_op("_C::split_norm_rope_neox", mutates_args=()) def split_norm_rope_neox( q_emb: torch.Tensor, @@ -472,7 +521,7 @@ def split_norm_rope_neox( kv_head_num: int, head_dim: int, rotary_dim: int, - emb_batch_size: int=1 + emb_batch_size: int = 1, ) -> None: xtorch_ops.split_norm_rope_neox( q_emb, @@ -491,6 +540,7 @@ def split_norm_rope_neox( rotary_dim, ) + 
@impl("_C::split_norm_rope_neox", "CUDA") def split_norm_rope_neox_cuda( q_emb: torch.Tensor, @@ -507,7 +557,7 @@ def split_norm_rope_neox_cuda( kv_head_num: int, head_dim: int, rotary_dim: int, - emb_batch_size: int=1 + emb_batch_size: int = 1, ) -> None: xtorch_ops.split_norm_rope_neox( q_emb, @@ -526,6 +576,7 @@ def split_norm_rope_neox_cuda( rotary_dim, ) + def _fake_split_norm_rope_neox( q_emb: torch.Tensor, k_emb: torch.Tensor, @@ -541,7 +592,8 @@ def _fake_split_norm_rope_neox( kv_head_num: int, head_dim: int, rotary_dim: int, - emb_batch_size: int=1): + emb_batch_size: int = 1, +): q_emb.fake_shape = q_emb.shape q_emb.fake_dtype = q_emb.dtype k_emb.fake_shape = k_emb.shape @@ -550,33 +602,35 @@ def _fake_split_norm_rope_neox( v_out.fake_dtype = v_out.dtype return None + split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox) # register fake op impl here # for torch.dynamo from torch.library import register_fake + if hasattr(torch.ops.custom_ops, "fc_fusion"): + @register_fake("custom_ops::fc_fusion") - def fc_fusion_fake(self: torch.Tensor, - other: torch.Tensor, - bias: Optional[torch.Tensor], - self_trans: bool, - other_trans: bool, - *, - alpha: float=1.0, - beta: float=0.0, - act: int=1, - multi_stream: bool=False, - out: torch.Tensor - ) -> None: + def fc_fusion_fake( + self: torch.Tensor, + other: torch.Tensor, + bias: Optional[torch.Tensor], + self_trans: bool, + other_trans: bool, + *, + alpha: float = 1.0, + beta: float = 0.0, + act: int = 1, + multi_stream: bool = False, + out: torch.Tensor, + ) -> None: pass + @custom_op("_C::silu_and_mul", mutates_args=()) def silu_and_mul( - out: torch.Tensor, - x: torch.Tensor, - axis: int=-1, - turn: bool=True + out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True ) -> None: xtorch_ops.swiglu( x=x, @@ -586,25 +640,21 @@ def silu_and_mul( @impl("_C::silu_and_mul", "CUDA") def silu_and_mul_cuda( - out: torch.Tensor, - x: torch.Tensor, - axis: int=-1, - turn: bool=True + out: 
torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True ) -> None: xtorch_ops.swiglu( x=x, y=out, ) + def _fake_silu_and_mul( - out: torch.Tensor, - x: torch.Tensor, - axis: int=-1, - turn: bool=True): + out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True +): return None -silu_and_mul.register_fake(_fake_silu_and_mul) +silu_and_mul.register_fake(_fake_silu_and_mul) @custom_op("_C::swigluoai_and_mul", mutates_args=()) @@ -613,7 +663,7 @@ def swigluoai_and_mul( alpha: float = 1.702, limit: float = 7.0, axis: int = -1, - turn: bool = True + turn: bool = True, ) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" gate, up = x[..., ::2], x[..., 1::2] @@ -623,13 +673,14 @@ def swigluoai_and_mul( gated_output = (up + 1) * glu return gated_output + @impl("_C::swigluoai_and_mul", "CUDA") def swigluoai_and_mul_cuda( x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0, axis: int = -1, - turn: bool = True + turn: bool = True, ) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" gate, up = x[..., ::2], x[..., 1::2] @@ -639,12 +690,13 @@ def swigluoai_and_mul_cuda( gated_output = (up + 1) * glu return gated_output + def _fake_swigluoai_and_mul( x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0, axis: int = -1, - turn: bool = True + turn: bool = True, ) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" gate, up = x[..., ::2], x[..., 1::2] @@ -654,8 +706,10 @@ def _fake_swigluoai_and_mul( gated_output = (up + 1) * glu return gated_output + swigluoai_and_mul.register_fake(_fake_swigluoai_and_mul) + @custom_op("_C::moe_softmax_topk", mutates_args=()) def moe_softmax_topk( x: torch.Tensor, @@ -663,14 +717,10 @@ def moe_softmax_topk( topk_index: torch.Tensor, block_statistic: torch.Tensor, axis: int = -1, - turn: bool = True + turn: bool = True, ) -> None: - xtorch_ops.moe_softmax_topk( - x, - normed_score, - topk_index, - block_statistic - ) + 
xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) + @impl("_C::moe_softmax_topk", "CUDA") def moe_softmax_topk_cuda( @@ -679,14 +729,10 @@ def moe_softmax_topk_cuda( topk_index: torch.Tensor, block_statistic: torch.Tensor, axis: int = -1, - turn: bool = True + turn: bool = True, ) -> None: - xtorch_ops.moe_softmax_topk( - x, - normed_score, - topk_index, - block_statistic - ) + xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) + def _fake_moe_softmax_topk( x: torch.Tensor, @@ -694,10 +740,11 @@ def _fake_moe_softmax_topk( topk_index: torch.Tensor, block_statistic: torch.Tensor, axis: int = -1, - turn: bool = True + turn: bool = True, ) -> None: return None + moe_softmax_topk.register_fake(_fake_moe_softmax_topk) @@ -762,6 +809,7 @@ def moe_ffn_block_cuda( out=out, ) + def _fake_moe_ffn_block( out: torch.Tensor, x: torch.Tensor, @@ -773,9 +821,11 @@ def _fake_moe_ffn_block( renormalize: bool = True, use_grouped_topk: bool = False, expert_group_num: Optional[int] = 0, - topk_group: Optional[int] = 0,): + topk_group: Optional[int] = 0, +): return None + moe_ffn_block.register_fake(_fake_moe_ffn_block) @@ -794,7 +844,7 @@ def moe_ffn_per_token_block( output: Optional[torch.Tensor] = None, use_expert_parallel: bool = False, ep_size: int = 1, - ep_rank: int = 0 + ep_rank: int = 0, ) -> None: xtorch_ops.moe_ffn_per_token_block( x=x, @@ -812,6 +862,7 @@ def moe_ffn_per_token_block( out=output, ) + @impl("_C::moe_ffn_per_token_block", "CUDA") def moe_ffn_per_token_block_cuda( x: torch.Tensor, @@ -827,7 +878,7 @@ def moe_ffn_per_token_block_cuda( output: Optional[torch.Tensor] = None, use_expert_parallel: bool = False, ep_size: int = 1, - ep_rank: int = 0 + ep_rank: int = 0, ) -> None: xtorch_ops.moe_ffn_per_token_block( x=x, @@ -845,6 +896,7 @@ def moe_ffn_per_token_block_cuda( out=output, ) + def _fake_moe_ffn_per_token_block( x: torch.Tensor, inter_weight: torch.Tensor, @@ -859,12 +911,13 @@ def 
_fake_moe_ffn_per_token_block( output: Optional[torch.Tensor] = None, use_expert_parallel: bool = False, ep_size: int = 1, - ep_rank: int = 0 + ep_rank: int = 0, ) -> None: # Fake implementation can be a no-op or a simple operation if output is not None: output.copy_(x) # Example: simply copy input to output + # Register the fake implementation moe_ffn_per_token_block.register_fake(_fake_moe_ffn_per_token_block) @@ -877,14 +930,16 @@ def rotary_embedding( head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, -) -> None : +) -> None: xtorch_ops.rotary_embedding( positions=positions, query=query, key=key, head_size=head_size, cos_sin_cache=cos_sin_cache, - is_neox=is_neox) + is_neox=is_neox, + ) + @impl("_C::rotary_embedding", "CUDA") def rotary_embedding_cuda( @@ -901,7 +956,8 @@ def rotary_embedding_cuda( key=key, head_size=head_size, cos_sin_cache=cos_sin_cache, - is_neox=is_neox) + is_neox=is_neox, + ) def _fake_rotary_embedding( @@ -911,49 +967,12 @@ def _fake_rotary_embedding( head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, -)-> None: +) -> None: return None rotary_embedding.register_fake(_fake_rotary_embedding) -@custom_op("_C::quant2d", mutates_args=()) -def quant2d( - x: torch.Tensor, - y: torch.Tensor, - max: torch.Tensor, - force_sdnn: bool, -) -> None: - xtorch_ops.quant2d( - x=x, - y=y, - max=max, - force_sdnn=force_sdnn - ) - -@impl("_C::quant2d", "CUDA") -def quant2d_cuda( - x: torch.Tensor, - y: torch.Tensor, - max: torch.Tensor, - force_sdnn: bool, -) -> None: - xtorch_ops.quant2d( - x=x, - y=y, - max=max, - force_sdnn=force_sdnn - ) - -def _fake_quant2d( - x: torch.Tensor, - y: torch.Tensor, - max: torch.Tensor, - force_sdnn: bool, -) -> None: - return None - -quant2d.register_fake(_fake_quant2d) @custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=()) def gemm_I8_I8_bf16_nt( @@ -964,11 +983,10 @@ def gemm_I8_I8_bf16_nt( out: torch.Tensor, ) -> None: xtorch_ops.gemm_I8_I8_bf16_nt( - lhs=(x_q, x_scale), - rhs=(weight, weight_scale), - 
out=out + lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out ) + @impl("_C::gemm_I8_I8_bf16_nt", "CUDA") def gemm_I8_I8_bf16_nt_cuda( x_q: torch.Tensor, @@ -978,11 +996,10 @@ def gemm_I8_I8_bf16_nt_cuda( out: torch.Tensor, ) -> None: xtorch_ops.gemm_I8_I8_bf16_nt( - lhs=(x_q, x_scale), - rhs=(weight, weight_scale), - out=out + lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out ) + def _fake_gemm_I8_I8_bf16_nt( x_q: torch.Tensor, x_scale: torch.Tensor, @@ -992,131 +1009,130 @@ def _fake_gemm_I8_I8_bf16_nt( ) -> None: return None + gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt) + @custom_op("_C::moe_softmax_topk_norm", mutates_args=()) def moe_softmax_topk_norm( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, - stable: bool = True + stable: bool = True, ) -> None: xtorch_ops.moe_softmax_topk_norm( - x, - normed_score, - topk_index, - block_statistic, - stable + x, normed_score, topk_index, block_statistic, stable ) + @impl("_C::moe_softmax_topk_norm", "CUDA") def moe_softmax_topk_norm_cuda( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, - stable: bool = True + stable: bool = True, ) -> None: xtorch_ops.moe_softmax_topk_norm( - x, - normed_score, - topk_index, - block_statistic, - stable + x, normed_score, topk_index, block_statistic, stable ) + def _fake_moe_softmax_topk_norm( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, - stable: bool = True + stable: bool = True, ) -> None: return None + moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm) + @custom_op("_C::gen_block_statistic", mutates_args=()) -def gen_block_statistic( - topk_ids: torch.Tensor, - block_statistic: torch.Tensor -)-> None: - xtorch_ops.gen_block_statistic( - topk_ids,block_statistic - ) +def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None: + 
xtorch_ops.gen_block_statistic(topk_ids, block_statistic) + @impl("_C::gen_block_statistic", "CUDA") def gen_block_statistic_cuda( - topk_ids: torch.Tensor, - block_statistic: torch.Tensor -)-> None: - xtorch_ops.gen_block_statistic( - topk_ids,block_statistic - ) + topk_ids: torch.Tensor, block_statistic: torch.Tensor +) -> None: + xtorch_ops.gen_block_statistic(topk_ids, block_statistic) + def fake_gen_block_statistic( - topk_ids: torch.Tensor, - block_statistic: torch.Tensor -)-> None: + topk_ids: torch.Tensor, block_statistic: torch.Tensor +) -> None: return None + gen_block_statistic.register_fake(fake_gen_block_statistic) + @custom_op("_C::moe_pre_sorted", mutates_args=()) def moe_pre_sorted( - x: torch.Tensor, - topk_index: torch.Tensor, - block_statistic: torch.Tensor, - moe_expand: torch.Tensor, - moe_index: torch.Tensor, - expert_m: torch.Tensor, - sorted_tokens_num_lod: torch.Tensor, - index_have_neg: bool = False -)-> None: + x: torch.Tensor, + topk_index: torch.Tensor, + block_statistic: torch.Tensor, + moe_expand: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + index_have_neg: bool = False, +) -> None: xtorch_ops.moe_pre_sorted( - x, - topk_index, - block_statistic, - moe_expand, - moe_index, - expert_m, - sorted_tokens_num_lod) + x, + topk_index, + block_statistic, + moe_expand, + moe_index, + expert_m, + sorted_tokens_num_lod, + ) + @impl("_C::moe_pre_sorted", "CUDA") def moe_pre_sorted_cuda( - x: torch.Tensor, - topk_index: torch.Tensor, - block_statistic: torch.Tensor, - moe_expand: torch.Tensor, - moe_index: torch.Tensor, - expert_m: torch.Tensor, - sorted_tokens_num_lod: torch.Tensor, - index_have_neg: bool = False -)-> None: + x: torch.Tensor, + topk_index: torch.Tensor, + block_statistic: torch.Tensor, + moe_expand: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + index_have_neg: bool = False, +) -> None: 
xtorch_ops.moe_pre_sorted( - x, - topk_index, - block_statistic, - moe_expand, - moe_index, - expert_m, - sorted_tokens_num_lod) + x, + topk_index, + block_statistic, + moe_expand, + moe_index, + expert_m, + sorted_tokens_num_lod, + ) + def fake_moe_pre_sorted( - x: torch.Tensor, - topk_index: torch.Tensor, - block_statistic: torch.Tensor, - moe_expand: torch.Tensor, - moe_index: torch.Tensor, - expert_m: torch.Tensor, - sorted_tokens_num_lod: torch.Tensor, - index_have_neg: bool = False -)-> None: + x: torch.Tensor, + topk_index: torch.Tensor, + block_statistic: torch.Tensor, + moe_expand: torch.Tensor, + moe_index: torch.Tensor, + expert_m: torch.Tensor, + sorted_tokens_num_lod: torch.Tensor, + index_have_neg: bool = False, +) -> None: return None + moe_pre_sorted.register_fake(fake_moe_pre_sorted) + @custom_op("_C::moe_fc", mutates_args=()) def moe_fc( x: torch.Tensor, @@ -1127,7 +1143,7 @@ def moe_fc( y: torch.Tensor, act: Optional[str] = None, x_perchannel_max: Optional[torch.Tensor] = None, - w_perchannel_max: Optional[torch.Tensor] = None , + w_perchannel_max: Optional[torch.Tensor] = None, topk_ids: Optional[torch.Tensor] = None, topk_w: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -1136,27 +1152,29 @@ def moe_fc( scale_n: Optional[int] = 0, scale_k: Optional[int] = 0, use_pack_int4: Optional[bool] = False, - sort_mode: Optional[bool] = True -)-> None: + sort_mode: Optional[bool] = True, +) -> None: xtorch_ops.moe_fc( - x=x, - weight=weight, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=moe_topk, - y=y, - act=act, - x_perchannel_max=x_perchannel_max, - w_perchannel_max=w_perchannel_max, - topk_ids=topk_ids, - topk_w=topk_w, - bias=bias, - tgemm_type=tgemm_type, - tweight_type=tweight_type, - scale_n=scale_n, - scale_k=scale_k, - use_pack_int4=use_pack_int4, - sort_mode=sort_mode) + x=x, + weight=weight, + sorted_tokens_num_lod=sorted_tokens_num_lod, + 
sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_topk, + y=y, + act=act, + x_perchannel_max=x_perchannel_max, + w_perchannel_max=w_perchannel_max, + topk_ids=topk_ids, + topk_w=topk_w, + bias=bias, + tgemm_type=tgemm_type, + tweight_type=tweight_type, + scale_n=scale_n, + scale_k=scale_k, + use_pack_int4=use_pack_int4, + sort_mode=sort_mode, + ) + @impl("_C::moe_fc", "CUDA") def moe_fc_cuda( @@ -1168,7 +1186,7 @@ def moe_fc_cuda( y: torch.Tensor, act: Optional[str] = None, x_perchannel_max: Optional[torch.Tensor] = None, - w_perchannel_max: Optional[torch.Tensor] = None , + w_perchannel_max: Optional[torch.Tensor] = None, topk_ids: Optional[torch.Tensor] = None, topk_w: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -1177,27 +1195,29 @@ def moe_fc_cuda( scale_n: Optional[int] = 0, scale_k: Optional[int] = 0, use_pack_int4: Optional[bool] = False, - sort_mode: Optional[bool] = True -)-> None: + sort_mode: Optional[bool] = True, +) -> None: xtorch_ops.moe_fc( - x=x, - weight=weight, - sorted_tokens_num_lod=sorted_tokens_num_lod, - sorted_tokens_idx=sorted_tokens_idx, - moe_topk=moe_topk, - y=y, - act=act, - x_perchannel_max=x_perchannel_max, - w_perchannel_max=w_perchannel_max, - topk_ids=topk_ids, - topk_w=topk_w, - bias=bias, - tgemm_type=tgemm_type, - tweight_type=tweight_type, - scale_n=scale_n, - scale_k=scale_k, - use_pack_int4=use_pack_int4, - sort_mode=sort_mode) + x=x, + weight=weight, + sorted_tokens_num_lod=sorted_tokens_num_lod, + sorted_tokens_idx=sorted_tokens_idx, + moe_topk=moe_topk, + y=y, + act=act, + x_perchannel_max=x_perchannel_max, + w_perchannel_max=w_perchannel_max, + topk_ids=topk_ids, + topk_w=topk_w, + bias=bias, + tgemm_type=tgemm_type, + tweight_type=tweight_type, + scale_n=scale_n, + scale_k=scale_k, + use_pack_int4=use_pack_int4, + sort_mode=sort_mode, + ) + def fake_moe_fc( x: torch.Tensor, @@ -1208,7 +1228,7 @@ def fake_moe_fc( y: torch.Tensor, act: Optional[str] = None, x_perchannel_max: 
Optional[torch.Tensor] = None, - w_perchannel_max: Optional[torch.Tensor] = None , + w_perchannel_max: Optional[torch.Tensor] = None, topk_ids: Optional[torch.Tensor] = None, topk_w: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -1217,53 +1237,46 @@ def fake_moe_fc( scale_n: Optional[int] = 0, scale_k: Optional[int] = 0, use_pack_int4: Optional[bool] = False, - sort_mode: Optional[bool] = True -)-> None: + sort_mode: Optional[bool] = True, +) -> None: return None + moe_fc.register_fake(fake_moe_fc) + @custom_op("_C::moe_post", mutates_args=()) def moe_post( - x: torch.Tensor, - moe_index: torch.Tensor, - normed_scale: torch.Tensor, - dequant_scale: torch.Tensor, - y: torch.Tensor -)-> None: - xtorch_ops.moe_post( - x, - moe_index, - normed_scale, - dequant_scale, - y - ) + x: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + dequant_scale: torch.Tensor, + y: torch.Tensor, +) -> None: + xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) + @impl("_C::moe_post", "CUDA") def moe_post_cuda( - x: torch.Tensor, - moe_index: torch.Tensor, - normed_scale: torch.Tensor, - dequant_scale: torch.Tensor, - y: torch.Tensor -)-> None: - xtorch_ops.moe_post( - x, - moe_index, - normed_scale, - dequant_scale, - y) + x: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + dequant_scale: torch.Tensor, + y: torch.Tensor, +) -> None: + xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) def fake_moe_post( - x: torch.Tensor, - moe_index: torch.Tensor, - normed_scale: torch.Tensor, - dequant_scale: torch.Tensor, - y: torch.Tensor -)-> None: + x: torch.Tensor, + moe_index: torch.Tensor, + normed_scale: torch.Tensor, + dequant_scale: torch.Tensor, + y: torch.Tensor, +) -> None: return None + moe_post.register_fake(fake_moe_post) @@ -1276,7 +1289,7 @@ def moe_sigmoid_group_topk_norm( bias: torch.Tensor, scale: float, n_group: int, - topk_group: int + topk_group: int, ) -> None: 
xtorch_ops.moe_sigmoid_group_topk_norm( x=x, @@ -1289,6 +1302,7 @@ def moe_sigmoid_group_topk_norm( scale=scale, ) + @impl("_C::moe_sigmoid_group_topk_norm", "CUDA") def moe_sigmoid_group_topk_norm_cuda( x: torch.Tensor, @@ -1298,7 +1312,7 @@ def moe_sigmoid_group_topk_norm_cuda( bias: torch.Tensor, scale: float, n_group: int, - topk_group: int + topk_group: int, ) -> None: xtorch_ops.moe_sigmoid_group_topk_norm( x=x, @@ -1311,6 +1325,7 @@ def moe_sigmoid_group_topk_norm_cuda( scale=scale, ) + def _fake_moe_sigmoid_group_topk_norm( x: torch.Tensor, topk_index: torch.Tensor, @@ -1319,11 +1334,14 @@ def _fake_moe_sigmoid_group_topk_norm( bias: torch.Tensor, scale: float, n_group: int, - topk_group: int + topk_group: int, ) -> None: return None + moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm) + + ################################################## # --------------- awq_dequantize ----------------- ################################################## @@ -1495,15 +1513,16 @@ def _fake_gptq_shuffle( gptq_shuffle.register_fake(_fake_gptq_shuffle) + ################################################## -# ---------------- concat_and_cache_mla ------------------ +# ------------- concat_and_cache_mla ------------- ################################################## @custom_op("_C::concat_and_cache_mla", mutates_args=()) def concat_and_cache_mla( - kv_c: torch.Tensor, #[num_tokens, kv_lora_rank] - k_pe: torch.Tensor, #[num_tokens, pe_dim] - kv_cache: torch.Tensor, #[num_blocks, block_size, (kv_lora_rank + pe_dim)] - slot_mapping: torch.Tensor, #[num_tokens] or [num_actual_tokens] + kv_c: torch.Tensor, # [num_tokens, kv_lora_rank] + k_pe: torch.Tensor, # [num_tokens, pe_dim] + kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)] + slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens] ) -> None: xtorch_ops.concat_and_cache_mla( kv_c=kv_c, @@ -1512,12 +1531,13 @@ def concat_and_cache_mla( kv_cache=kv_cache, ) + 
@impl("_C::concat_and_cache_mla", "CUDA")
def concat_and_cache_mla_cuda(
    kv_c: torch.Tensor,  # [num_tokens, kv_lora_rank]
    k_pe: torch.Tensor,  # [num_tokens, pe_dim]
    kv_cache: torch.Tensor,  # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
    slot_mapping: torch.Tensor,  # [num_tokens] or [num_actual_tokens]
) -> None:
    """CUDA dispatch for ``_C::concat_and_cache_mla``.

    Writes the latent KV (``kv_c``) and rotary part (``k_pe``) into
    ``kv_cache`` at the slots given by ``slot_mapping`` via the Kunlun
    ``xtorch_ops`` kernel. Mutates ``kv_cache`` in place; returns nothing.
    """
    # NOTE(review): the middle keyword arguments of this call fall between
    # diff hunks in the original and are reconstructed from the op's
    # signature - confirm against the upstream file.
    xtorch_ops.concat_and_cache_mla(
        kv_c=kv_c,
        k_pe=k_pe,
        slot_mapping=slot_mapping,
        kv_cache=kv_cache,
    )


def _fake_concat_and_cache_mla(
    kv_c: torch.Tensor,  # [num_tokens, kv_lora_rank]
    k_pe: torch.Tensor,  # [num_tokens, pe_dim]
    kv_cache: torch.Tensor,  # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
    slot_mapping: torch.Tensor,  # [num_tokens] or [num_actual_tokens]
) -> None:
    # Meta/fake implementation: the real op only mutates kv_cache in place,
    # so there is nothing to compute for tracing.
    return None


concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla)


######################################################
# -------------- scaled_int8_quant -------------------
######################################################
def _scaled_int8_quant_impl(
    x: torch.Tensor,
    scale: Optional[torch.Tensor],
    azp: Optional[torch.Tensor],
    symmetric: bool,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]:
    """Shared body for the custom-op and CUDA registrations below.

    Returns ``(x_q, scale, azp, static)``:
      * ``x_q``   - int8 quantized tensor, same shape as ``x``;
      * ``scale`` - caller-supplied (static) or freshly computed (dynamic)
                    per-token scale of shape ``(num_tokens, 1)``;
      * ``azp``   - asymmetric zero point, ``None`` when ``symmetric``;
      * ``static``- True iff the caller supplied ``scale``.
    """
    static = False
    x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
    if scale is not None:  # static quantization: scale supplied by caller
        static = True
        torch.ops.xspeedgate_ops.static_scaled_int8_quant(x_q, x, scale, azp)
    else:  # dynamic quantization: per-token scale computed here
        scale = torch.empty(
            (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
        )
        azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
        if symmetric:
            # NOTE: For quant2d ops, scale represents max.
            xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
        else:
            torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
                x_q, x.contiguous(), scale, azp
            )
    return x_q, scale, azp, static


@custom_op("_C::scaled_int8_quant", mutates_args=())
def scaled_int8_quant(
    x: torch.Tensor,
    # FIX: `scale` is compared against None in the body, so it must be
    # declared Optional with a None default - the old plain `torch.Tensor`
    # annotation produced a non-optional op schema that rejects the
    # dynamic-quant call path. Same for the return annotation of `azp`.
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]:
    """Quantize ``x`` to int8, statically (scale given) or dynamically."""
    return _scaled_int8_quant_impl(x, scale, azp, symmetric)


@impl("_C::scaled_int8_quant", "CUDA")
def scaled_int8_quant_cuda(
    x: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]:
    """CUDA dispatch for ``_C::scaled_int8_quant`` (identical body)."""
    return _scaled_int8_quant_impl(x, scale, azp, symmetric)


def fake_scaled_int8_quant(
    x: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    azp: Optional[torch.Tensor] = None,
    symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]:
    # Meta implementation: allocate outputs with the same shapes/dtypes as
    # the real op so shape propagation also works for the dynamic path.
    x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
    # FIX: the old fake hard-coded static=False and returned scale=None on
    # the dynamic path, diverging from the real op's output structure.
    static = scale is not None
    if scale is None:
        scale = torch.empty(
            (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
        )
        azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
    return x_q, scale, azp, static


scaled_int8_quant.register_fake(fake_scaled_int8_quant)


######################################################
# ---------------- cutlass_scaled_mm -----------------
######################################################
@custom_op("_C::cutlass_scaled_mm", mutates_args=())
def cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Scaled matmul ``a @ b`` producing ``out_dtype``; kernel writes into a
    freshly allocated ``(a.shape[0], b.shape[1])`` output."""
    out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
    torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out


@impl("_C::cutlass_scaled_mm", "CUDA")
def cutlass_scaled_mm_cuda(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """CUDA dispatch for ``_C::cutlass_scaled_mm`` (identical body)."""
    out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
    torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out


def fake_cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Meta implementation: only the output shape/dtype matter for tracing.
    return torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)


cutlass_scaled_mm.register_fake(fake_cutlass_scaled_mm)
######################################################
# ------------ cutlass_scaled_mm_azp -----------------
######################################################
@custom_op("_C::cutlass_scaled_mm_azp", mutates_args=())
def cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Asymmetric (zero-point) variant of the scaled matmul; allocates a
    ``(a.shape[0], b.shape[1])`` output of ``out_dtype`` and fills it via the
    backend kernel."""
    out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
    torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
        out, a, b, scale_a, scale_b, azp_adj, azp, bias
    )
    return out


@impl("_C::cutlass_scaled_mm_azp", "CUDA")
def cutlass_scaled_mm_azp_cuda(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """CUDA dispatch for ``_C::cutlass_scaled_mm_azp`` (identical body)."""
    out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
    torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
        out, a, b, scale_a, scale_b, azp_adj, azp, bias
    )
    return out


def fake_cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Meta implementation: only the output shape/dtype matter for tracing.
    return torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)


cutlass_scaled_mm_azp.register_fake(fake_cutlass_scaled_mm_azp)


##################################################
# ------------------ matmul ---------------------
##################################################
def _matmul_out_features(w: torch.Tensor, w_trans: bool) -> int:
    """Output width: row count of a transposed weight, else column count."""
    return w.shape[0] if w_trans else w.shape[1]


def _matmul_impl(
    x: torch.Tensor,
    w: torch.Tensor,
    out_dtype: torch.dtype,
    x_trans: bool,
    w_trans: bool,
    alpha: float,
    beta: float,
    bias: Optional[torch.Tensor],
    x_max: Optional[torch.Tensor],
    w_max: Optional[torch.Tensor],
    x_pc_max: Optional[torch.Tensor],
    w_pc_max: Optional[torch.Tensor],
) -> torch.Tensor:
    """Shared body for the matmul custom-op and its CUDA registration."""
    out = torch.empty(
        (x.shape[0], _matmul_out_features(w, w_trans)),
        dtype=out_dtype,
        device=x.device,
    )
    xtorch_ops.matmul(
        x=x.contiguous(),
        w=w.contiguous(),
        out=out,
        x_trans=x_trans,
        w_trans=w_trans,
        alpha=alpha,
        beta=beta,
        bias=bias,
        x_max=x_max,
        w_max=w_max,
        x_pc_max=x_pc_max,
        w_pc_max=w_pc_max,
    )
    return out


@custom_op("_C::matmul", mutates_args=())
def matmul(
    x: torch.Tensor,
    w: torch.Tensor,
    out_dtype: torch.dtype,
    x_trans: bool = False,
    w_trans: bool = True,
    alpha: float = 1.0,
    beta: float = 0.0,
    # FIX: these parameters default to None but were annotated as plain
    # `torch.Tensor`, yielding a non-optional op schema; declare Optional.
    bias: Optional[torch.Tensor] = None,
    x_max: Optional[torch.Tensor] = None,
    w_max: Optional[torch.Tensor] = None,
    x_pc_max: Optional[torch.Tensor] = None,
    w_pc_max: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """General matmul with optional bias and quantization maxima.

    Allocates an ``(x.shape[0], out_features)`` output of ``out_dtype`` and
    delegates to the ``xtorch_ops`` kernel.
    """
    return _matmul_impl(
        x, w, out_dtype, x_trans, w_trans, alpha, beta,
        bias, x_max, w_max, x_pc_max, w_pc_max,
    )


@impl("_C::matmul", "CUDA")
def matmul_cuda(
    x: torch.Tensor,
    w: torch.Tensor,
    out_dtype: torch.dtype,
    x_trans: bool = False,
    w_trans: bool = True,
    alpha: float = 1.0,
    beta: float = 0.0,
    bias: Optional[torch.Tensor] = None,
    x_max: Optional[torch.Tensor] = None,
    w_max: Optional[torch.Tensor] = None,
    x_pc_max: Optional[torch.Tensor] = None,
    w_pc_max: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """CUDA dispatch for ``_C::matmul`` (identical body)."""
    return _matmul_impl(
        x, w, out_dtype, x_trans, w_trans, alpha, beta,
        bias, x_max, w_max, x_pc_max, w_pc_max,
    )


def _fake_matmul(
    x: torch.Tensor,
    w: torch.Tensor,
    out_dtype: torch.dtype,
    x_trans: bool = False,
    w_trans: bool = True,
    alpha: float = 1.0,
    beta: float = 0.0,
    bias: Optional[torch.Tensor] = None,
    x_max: Optional[torch.Tensor] = None,
    w_max: Optional[torch.Tensor] = None,
    x_pc_max: Optional[torch.Tensor] = None,
    w_pc_max: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # FIX: honor w_trans - the old fake always used w.shape[0], which gives
    # the wrong output width whenever w_trans=False.
    return torch.empty(
        (x.shape[0], _matmul_out_features(w, w_trans)),
        dtype=out_dtype,
        device=x.device,
    )


matmul.register_fake(_fake_matmul)


##################################################
# ------------------- quant2d --------------------
##################################################
@custom_op("_C::quant2d", mutates_args=())
def quant2d(
    x: torch.Tensor,
    x_q: torch.Tensor,
    max: torch.Tensor,
    force_sdnn: bool = False,
) -> None:
    """Quantize 2-D ``x`` into ``x_q``, writing per-row maxima into ``max``.

    NOTE(review): the kernel writes into ``x_q``/``max`` although the op is
    registered with ``mutates_args=()`` - matches the file-wide pattern, but
    worth confirming against torch.library semantics.
    """
    xtorch_ops.quant2d(
        x=x,
        y=x_q,
        max=max,
        force_sdnn=force_sdnn,
    )


@impl("_C::quant2d", "CUDA")
def quant2d_cuda(
    x: torch.Tensor,
    x_q: torch.Tensor,
    max: torch.Tensor,
    force_sdnn: bool = False,
) -> None:
    """CUDA dispatch for ``_C::quant2d`` (identical body)."""
    xtorch_ops.quant2d(
        x=x,
        y=x_q,
        max=max,
        force_sdnn=force_sdnn,
    )


def _fake_quant2d(
    x: torch.Tensor,
    x_q: torch.Tensor,
    max: torch.Tensor,
    force_sdnn: bool = False,
) -> None:
    # Meta implementation: the real op mutates its arguments in place.
    return None


quant2d.register_fake(_fake_quant2d)