"""vllm_utils_wrapper.py""" import vllm.distributed.parallel_state as parallel_state import vllm.utils as _orig from typing import Any, Callable, Optional, Union, get_origin, get_args, List, Tuple from types import SimpleNamespace import torch from torch.library import Library import inspect import typing from torch.library import register_fake import vllm_kunlun._kunlun def patch_annotations_for_schema(func): """patch_annotations_for_schema""" sig = inspect.signature(func) new_params = [] for name, param in sig.parameters.items(): ann = param.annotation if get_origin(ann) is typing.Union and type(None) in get_args(ann): inner_type = [a for a in get_args(ann) if a is not type(None)][0] if get_origin(inner_type) is list: # Optional[list[int]] inner_args = get_args(inner_type) new_ann = Optional[List[inner_args[0] if inner_args else typing.Any]] param = param.replace(annotation=new_ann) elif get_origin(ann) is list: args = get_args(ann) new_ann = List[args[0] if args else typing.Any] param = param.replace(annotation=new_ann) new_params.append(param) func.__signature__ = sig.replace(parameters=new_params) return func def supports_custom_op() -> bool: """supports_custom_op""" return hasattr(torch.library, "custom_op") vllm_lib = Library("vllm", "FRAGMENT") # noqa def direct_register_custom_op( op_name: str, op_func: Callable, mutates_args: Optional[list[str]] = None, fake_impl: Optional[Callable] = None, target_lib: Optional[Library] = None, dispatch_key: str = "CUDA", tags: tuple[torch.Tag, ...] = (), ): """ `torch.library.custom_op` can have significant overhead because it needs to consider complicated dispatching logic. This function directly registers a custom op and dispatches it to the CUDA backend. See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 for more details. By default, the custom op is registered to the vLLM library. If you want to register it to a different library, you can pass the library object to the `target_lib` argument. IMPORTANT: the lifetime of the operator is tied to the lifetime of the library object. If you want to bind the operator to a different library, make sure the library object is alive when the operator is used. """ if not supports_custom_op(): from vllm.platforms import current_platform assert not current_platform.is_cuda_alike(), ( "cuda platform needs torch>=2.4 to support custom op, " "chances are you are using an old version of pytorch " "or a custom build of pytorch. It is recommended to " "use vLLM in a fresh new environment and let it install " "the required dependencies." ) return if mutates_args is None: mutates_args = [] import torch.library if hasattr(torch.library, "infer_schema"): patched_func = patch_annotations_for_schema(op_func) schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 import torch._custom_op.impl schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str, tags=tags) my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) if fake_impl is not None: my_lib._register_fake(op_name, fake_impl) def vllm_kunlun_weak_ref_tensor(tensor: Any) -> Any: """ Create a weak reference to a tensor. The new tensor will share the same data as the original tensor, but will not keep the original tensor alive. """ # return tensor if isinstance(tensor, torch.Tensor): return torch.ops._kunlun.weak_ref_tensor(tensor) else: return tensor def vllm_kunlun_weak_ref_tensors( tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]], ) -> Union[torch.Tensor, list[Any], tuple[Any], Any]: """ Convenience function to create weak references to tensors, for single tensor, list of tensors or tuple of tensors. """ if isinstance(tensors, torch.Tensor): return vllm_kunlun_weak_ref_tensor(tensors) if isinstance(tensors, list): return [vllm_kunlun_weak_ref_tensor(t) for t in tensors] if isinstance(tensors, tuple): return tuple(vllm_kunlun_weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") # import vllm.utils as vu # vu.direct_register_custom_op = direct_register_custom_op # import vllm.utils as vu # vu.direct_register_custom_op = direct_register_custom_op _wrapped = SimpleNamespace(**_orig.__dict__) _wrapped.direct_register_custom_op = direct_register_custom_op _wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor _wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors import sys sys.modules["vllm.utils"] = _wrapped _original_all_reduce = parallel_state.GroupCoordinator.all_reduce _original_all_gather = parallel_state.GroupCoordinator.all_gather def vllm_kunlun_all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """vllm_kunlun_all_reduce""" if self.world_size == 1: return input_ torch.distributed.all_reduce(input_, group=self.device_group) return input_ def vllm_kunlun_all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: """vllm_kunlun_all_reduce""" world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert ( -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. dim += input_.dim() input_size = input_.size() # Allocate output tensor. output_tensor = torch.empty( (world_size,) + input_size, dtype=input_.dtype, device=input_.device ) # All-gather. torch.distributed.all_gather_into_tensor( output_tensor, input_, group=self.device_group ) # Reshape output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape( input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] ) return output_tensor parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather from torch.library import custom_op, impl import torch from vllm import _custom_ops as ops from typing import Optional, List import os @custom_op("_C::rms_norm", mutates_args=()) def rms_norm( result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass @custom_op("_C::fused_add_rms_norm", mutates_args=()) def fused_add_rms_norm( result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass @custom_op("_C::static_scaled_fp8_quant", mutates_args=()) def static_scaled_fp8_quant( result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor ) -> None: pass @impl("_C::static_scaled_fp8_quant", "CUDA") def static_scaled_fp8_quant_xpu( result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor ) -> None: pass @custom_op("_C::dynamic_scaled_fp8_quant", mutates_args=()) def dynamic_scaled_fp8_quant( result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor ) -> None: pass @impl("_C::dynamic_scaled_fp8_quant", "CUDA") def dynamic_scaled_fp8_quant_xpu( result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor ) -> None: pass @custom_op("_C::dynamic_per_token_scaled_fp8_quant", mutates_args=()) def dynamic_per_token_scaled_fp8_quant( result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, scale_ub: Optional[torch.Tensor], ) -> None: pass @impl("_C::dynamic_per_token_scaled_fp8_quant", "CUDA") def dynamic_per_token_scaled_fp8_quant_xpu( result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor, scale_ub: Optional[torch.Tensor], ) -> None: pass @custom_op("_C::rms_norm_static_fp8_quant", mutates_args=()) def rms_norm_static_fp8_quant( result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass @impl("_C::rms_norm_static_fp8_quant", "CUDA") def rms_norm_static_fp8_quant_xpu( result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass @custom_op("_C::fused_add_rms_norm_static_fp8_quant", mutates_args=()) def fused_add_rms_norm_static_fp8_quant( result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass @impl("_C::fused_add_rms_norm_static_fp8_quant", "CUDA") def fused_add_rms_norm_static_fp8_quant_xpu( result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass @custom_op("_C::rms_norm_dynamic_per_token_quant", mutates_args=()) def rms_norm_dynamic_per_token_quant( result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, scale_ub: Optional[torch.Tensor], residual: Optional[torch.Tensor], ) -> None: pass @impl("_C::rms_norm_dynamic_per_token_quant", "CUDA") def rms_norm_dynamic_per_token_quant_xpu( result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, scale_ub: Optional[torch.Tensor], residual: Optional[torch.Tensor], ) -> None: pass @custom_op("_C::silu_and_mul_quant", mutates_args=()) def silu_and_mul_quant( result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass @impl("_C::silu_and_mul_quant", "CUDA") def silu_and_mul_quant_xpu( result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor, epsilon: float, ) -> None: pass import torch import xtorch_ops from torch.library import custom_op, impl @custom_op("_C::add_rmsnorm", mutates_args=()) def add_rmsnorm( x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, output: torch.Tensor, eps: float = 1e-5, interweaved: bool = False, store_output_before_norm: bool = True, bias: torch.Tensor = None, smooth: torch.Tensor = None, residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: xtorch_ops.add_rmsnorm( x, y, # 原来写 residual,这里其实是 y residual_output=residual_output, weight=weight, eps=eps, output=output, ) @impl("_C::add_rmsnorm", "CUDA") def add_rmsnorm_cuda( x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor, output: torch.Tensor, eps: float = 1e-5, interweaved: bool = False, store_output_before_norm: bool = True, bias: torch.Tensor = None, smooth: torch.Tensor = None, residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: xtorch_ops.add_rmsnorm( x, y, residual_output=residual_output, weight=weight, eps=eps, output=output, ) @custom_op("_C::rmsnorm", mutates_args=()) def rmsnorm( x: torch.Tensor, weight: torch.Tensor, output: torch.Tensor, eps: float = 1e-5, interweave: bool = False, store_output_before_norm: bool = True, bias: torch.Tensor = None, residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: xtorch_ops.rmsnorm( x, weight, output, eps, ) @impl("_C::rmsnorm", "CUDA") def rmsnorm_cuda( x: torch.Tensor, weight: torch.Tensor, output: torch.Tensor, eps: float = 1e-5, interweave: bool = False, store_output_before_norm: bool = True, bias: torch.Tensor = None, residual_output: torch.Tensor = None, output_max: torch.Tensor = None, ) -> None: xtorch_ops.rmsnorm( x, weight, output, eps, ) import torch def _fake_rmsnorm( x, weight, output, eps=1e-5, interweave=False, store_output_before_norm=True, bias=None, residual_output=None, output_max=None, ): # 设置 shape/dtype,但不返回值 output.fake_shape = x.shape output.fake_dtype = x.dtype return None rmsnorm.register_fake(_fake_rmsnorm) def _fake_add_rmsnorm( x, y, weight, output, eps=1e-5, interweaved=False, store_output_before_norm=True, bias=None, smooth=None, residual_output=None, output_max=None, ): output.fake_shape = x.shape output.fake_dtype = x.dtype return None add_rmsnorm.register_fake(_fake_add_rmsnorm) @custom_op("_C::split_norm_rope_neox", mutates_args=()) def split_norm_rope_neox( q_emb: torch.Tensor, k_emb: torch.Tensor, v_out: torch.Tensor, qkv: torch.Tensor, rotary_pos_embedding: torch.Tensor, q_norm_weight: torch.Tensor, k_norm_weight: torch.Tensor, positions: torch.Tensor, num_tokens: int, max_seqlen: int, head_num: int, kv_head_num: int, head_dim: int, rotary_dim: int, emb_batch_size: int = 1, ) -> None: xtorch_ops.split_norm_rope_neox( q_emb, k_emb, v_out, qkv, rotary_pos_embedding, q_norm_weight, k_norm_weight, positions, num_tokens, max_seqlen, head_num, kv_head_num, head_dim, rotary_dim, ) @impl("_C::split_norm_rope_neox", "CUDA") def split_norm_rope_neox_cuda( q_emb: torch.Tensor, k_emb: torch.Tensor, v_out: torch.Tensor, qkv: torch.Tensor, rotary_pos_embedding: torch.Tensor, q_norm_weight: torch.Tensor, k_norm_weight: torch.Tensor, positions: torch.Tensor, num_tokens: int, max_seqlen: int, head_num: int, kv_head_num: int, head_dim: int, rotary_dim: int, emb_batch_size: int = 1, ) -> None: xtorch_ops.split_norm_rope_neox( q_emb, k_emb, v_out, qkv, rotary_pos_embedding, q_norm_weight, k_norm_weight, positions, num_tokens, max_seqlen, head_num, kv_head_num, head_dim, rotary_dim, ) def _fake_split_norm_rope_neox( q_emb: torch.Tensor, k_emb: torch.Tensor, v_out: torch.Tensor, qkv: torch.Tensor, rotary_pos_embedding: torch.Tensor, q_norm_weight: torch.Tensor, k_norm_weight: torch.Tensor, positions: torch.Tensor, num_tokens: int, max_seqlen: int, head_num: int, kv_head_num: int, head_dim: int, rotary_dim: int, emb_batch_size: int = 1, ): q_emb.fake_shape = q_emb.shape q_emb.fake_dtype = q_emb.dtype k_emb.fake_shape = k_emb.shape k_emb.fake_dtype = k_emb.dtype v_out.fake_shape = v_out.shape v_out.fake_dtype = v_out.dtype return None split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox) # register fake op impl here # for torch.dynamo from torch.library import register_fake if hasattr(torch.ops.custom_ops, "fc_fusion"): @register_fake("custom_ops::fc_fusion") def fc_fusion_fake( self: torch.Tensor, other: torch.Tensor, bias: Optional[torch.Tensor], self_trans: bool, other_trans: bool, *, alpha: float = 1.0, beta: float = 0.0, act: int = 1, multi_stream: bool = False, out: torch.Tensor, ) -> None: pass @custom_op("_C::silu_and_mul", mutates_args=()) def silu_and_mul( out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True ) -> None: xtorch_ops.swiglu( x=x, y=out, ) @impl("_C::silu_and_mul", "CUDA") def silu_and_mul_cuda( out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True ) -> None: xtorch_ops.swiglu( x=x, y=out, ) def _fake_silu_and_mul( out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True ): return None silu_and_mul.register_fake(_fake_silu_and_mul) @custom_op("_C::swigluoai_and_mul", mutates_args=()) def swigluoai_and_mul( x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0, axis: int = -1, turn: bool = True, ) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" gate, up = x[..., ::2], x[..., 1::2] gate = gate.clamp(min=None, max=limit) up = up.clamp(min=-limit, max=limit) glu = gate * torch.sigmoid(gate * alpha) gated_output = (up + 1) * glu return gated_output @impl("_C::swigluoai_and_mul", "CUDA") def swigluoai_and_mul_cuda( x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0, axis: int = -1, turn: bool = True, ) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" gate, up = x[..., ::2], x[..., 1::2] gate = gate.clamp(min=None, max=limit) up = up.clamp(min=-limit, max=limit) glu = gate * torch.sigmoid(gate * alpha) gated_output = (up + 1) * glu return gated_output def _fake_swigluoai_and_mul( x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0, axis: int = -1, turn: bool = True, ) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" gate, up = x[..., ::2], x[..., 1::2] gate = gate.clamp(min=None, max=limit) up = up.clamp(min=-limit, max=limit) glu = gate * torch.sigmoid(gate * alpha) gated_output = (up + 1) * glu return gated_output swigluoai_and_mul.register_fake(_fake_swigluoai_and_mul) @custom_op("_C::moe_softmax_topk", mutates_args=()) def moe_softmax_topk( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, axis: int = -1, turn: bool = True, ) -> None: xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) @impl("_C::moe_softmax_topk", "CUDA") def moe_softmax_topk_cuda( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, axis: int = -1, turn: bool = True, ) -> None: xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic) def _fake_moe_softmax_topk( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, axis: int = -1, turn: bool = True, ) -> None: return None moe_softmax_topk.register_fake(_fake_moe_softmax_topk) @custom_op("_C::moe_ffn_block", mutates_args=()) def moe_ffn_block( out: torch.Tensor, x: torch.Tensor, expert_num: int, moe_top_k: int, gate_w: torch.Tensor, inter_w: torch.Tensor, output_w: torch.Tensor, renormalize: bool = True, use_grouped_topk: bool = False, expert_group_num: Optional[int] = 0, topk_group: Optional[int] = 0, w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> None: xtorch_ops.moe_ffn_block( x=x, gate_w=gate_w, inter_w=inter_w, output_w=output_w, expert_num=expert_num, moe_top_k=moe_top_k, topk_group=topk_group, renormalize=renormalize, use_grouped_topk=use_grouped_topk, expert_group_num=expert_group_num, out=out, ) @impl("_C::moe_ffn_block", "CUDA") def moe_ffn_block_cuda( out: torch.Tensor, x: torch.Tensor, expert_num: int, moe_top_k: int, gate_w: torch.Tensor, inter_w: torch.Tensor, output_w: torch.Tensor, renormalize: bool = True, use_grouped_topk: bool = False, expert_group_num: Optional[int] = 0, topk_group: Optional[int] = 0, w1_bias: Optional[torch.Tensor] = None, w2_bias: Optional[torch.Tensor] = None, ) -> None: xtorch_ops.moe_ffn_block( x=x, gate_w=gate_w, inter_w=inter_w, output_w=output_w, expert_num=expert_num, moe_top_k=moe_top_k, topk_group=topk_group, renormalize=renormalize, use_grouped_topk=use_grouped_topk, expert_group_num=expert_group_num, out=out, ) def _fake_moe_ffn_block( out: torch.Tensor, x: torch.Tensor, expert_num: int, moe_top_k: int, gate_w: torch.Tensor, inter_w: torch.Tensor, output_w: torch.Tensor, renormalize: bool = True, use_grouped_topk: bool = False, expert_group_num: Optional[int] = 0, topk_group: Optional[int] = 0, ): return None moe_ffn_block.register_fake(_fake_moe_ffn_block) @custom_op("_C::moe_ffn_per_token_block", mutates_args=()) def moe_ffn_per_token_block( x: torch.Tensor, inter_weight: torch.Tensor, inter_scale: torch.Tensor, outer_weight: torch.Tensor, outer_scale: torch.Tensor, top_k: int, global_num_experts: int, linear_weights: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, activation: str = "silu", output: Optional[torch.Tensor] = None, use_expert_parallel: bool = False, ep_size: int = 1, ep_rank: int = 0, ) -> None: xtorch_ops.moe_ffn_per_token_block( x=x, inter_weight=inter_weight, inter_scale=inter_scale, outer_weight=outer_weight, outer_scale=outer_scale, gate_weight=linear_weights, expert_num=global_num_experts, moe_top_k=top_k, act_type=activation, use_expert_parallel=use_expert_parallel, ep_size=ep_size, ep_rank=ep_rank, out=output, ) @impl("_C::moe_ffn_per_token_block", "CUDA") def moe_ffn_per_token_block_cuda( x: torch.Tensor, inter_weight: torch.Tensor, inter_scale: torch.Tensor, outer_weight: torch.Tensor, outer_scale: torch.Tensor, top_k: int, global_num_experts: int, linear_weights: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, activation: str = "silu", output: Optional[torch.Tensor] = None, use_expert_parallel: bool = False, ep_size: int = 1, ep_rank: int = 0, ) -> None: xtorch_ops.moe_ffn_per_token_block( x=x, inter_weight=inter_weight, inter_scale=inter_scale, outer_weight=outer_weight, outer_scale=outer_scale, gate_weight=linear_weights, expert_num=global_num_experts, moe_top_k=top_k, act_type=activation, use_expert_parallel=use_expert_parallel, ep_size=ep_size, ep_rank=ep_rank, out=output, ) def _fake_moe_ffn_per_token_block( x: torch.Tensor, inter_weight: torch.Tensor, inter_scale: torch.Tensor, outer_weight: torch.Tensor, outer_scale: torch.Tensor, top_k: int, global_num_experts: int, linear_weights: Optional[torch.Tensor] = None, expert_map: Optional[torch.Tensor] = None, activation: str = "silu", output: Optional[torch.Tensor] = None, use_expert_parallel: bool = False, ep_size: int = 1, ep_rank: int = 0, ) -> None: # Fake implementation can be a no-op or a simple operation if output is not None: output.copy_(x) # Example: simply copy input to output # Register the fake implementation moe_ffn_per_token_block.register_fake(_fake_moe_ffn_per_token_block) @custom_op("_C::rotary_embedding", mutates_args=()) def rotary_embedding( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: xtorch_ops.rotary_embedding( positions=positions, query=query, key=key, head_size=head_size, cos_sin_cache=cos_sin_cache, is_neox=is_neox, ) @impl("_C::rotary_embedding", "CUDA") def rotary_embedding_cuda( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: xtorch_ops.rotary_embedding( positions=positions, query=query, key=key, head_size=head_size, cos_sin_cache=cos_sin_cache, is_neox=is_neox, ) def _fake_rotary_embedding( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, ) -> None: return None rotary_embedding.register_fake(_fake_rotary_embedding) @custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=()) def gemm_I8_I8_bf16_nt( x_q: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, out: torch.Tensor, ) -> None: xtorch_ops.gemm_I8_I8_bf16_nt( lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out ) @impl("_C::gemm_I8_I8_bf16_nt", "CUDA") def gemm_I8_I8_bf16_nt_cuda( x_q: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, out: torch.Tensor, ) -> None: xtorch_ops.gemm_I8_I8_bf16_nt( lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out ) def _fake_gemm_I8_I8_bf16_nt( x_q: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, out: torch.Tensor, ) -> None: return None gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt) @custom_op("_C::moe_softmax_topk_norm", mutates_args=()) def moe_softmax_topk_norm( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, stable: bool = True, ) -> None: xtorch_ops.moe_softmax_topk_norm( x, normed_score, topk_index, block_statistic, stable ) @impl("_C::moe_softmax_topk_norm", "CUDA") def moe_softmax_topk_norm_cuda( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, stable: bool = True, ) -> None: xtorch_ops.moe_softmax_topk_norm( x, normed_score, topk_index, block_statistic, stable ) def _fake_moe_softmax_topk_norm( x: torch.Tensor, normed_score: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, stable: bool = True, ) -> None: return None moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm) @custom_op("_C::gen_block_statistic", mutates_args=()) def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None: xtorch_ops.gen_block_statistic(topk_ids, block_statistic) @impl("_C::gen_block_statistic", "CUDA") def gen_block_statistic_cuda( topk_ids: torch.Tensor, block_statistic: torch.Tensor ) -> None: xtorch_ops.gen_block_statistic(topk_ids, block_statistic) def fake_gen_block_statistic( topk_ids: torch.Tensor, block_statistic: torch.Tensor ) -> None: return None gen_block_statistic.register_fake(fake_gen_block_statistic) @custom_op("_C::moe_pre_sorted", mutates_args=()) def moe_pre_sorted( x: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, moe_expand: torch.Tensor, moe_index: torch.Tensor, expert_m: torch.Tensor, sorted_tokens_num_lod: torch.Tensor, index_have_neg: bool = False, ) -> None: xtorch_ops.moe_pre_sorted( x, topk_index, block_statistic, moe_expand, moe_index, expert_m, sorted_tokens_num_lod, ) @impl("_C::moe_pre_sorted", "CUDA") def moe_pre_sorted_cuda( x: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, moe_expand: torch.Tensor, moe_index: torch.Tensor, expert_m: torch.Tensor, sorted_tokens_num_lod: torch.Tensor, index_have_neg: bool = False, ) -> None: xtorch_ops.moe_pre_sorted( x, topk_index, block_statistic, moe_expand, moe_index, expert_m, sorted_tokens_num_lod, ) def fake_moe_pre_sorted( x: torch.Tensor, topk_index: torch.Tensor, block_statistic: torch.Tensor, moe_expand: torch.Tensor, moe_index: torch.Tensor, expert_m: torch.Tensor, sorted_tokens_num_lod: torch.Tensor, index_have_neg: bool = False, ) -> None: return None moe_pre_sorted.register_fake(fake_moe_pre_sorted) @custom_op("_C::moe_fc", mutates_args=()) def moe_fc( x: torch.Tensor, weight: torch.Tensor, sorted_tokens_num_lod: torch.Tensor, sorted_tokens_idx: torch.Tensor, moe_topk: int, y: torch.Tensor, act: Optional[str] = None, x_perchannel_max: Optional[torch.Tensor] = None, w_perchannel_max: Optional[torch.Tensor] = None, topk_ids: Optional[torch.Tensor] = None, topk_w: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, tgemm_type: Optional[str] = None, tweight_type: Optional[str] = None, scale_n: Optional[int] = 0, scale_k: Optional[int] = 0, use_pack_int4: Optional[bool] = False, sort_mode: Optional[bool] = True, ) -> None: xtorch_ops.moe_fc( x=x, weight=weight, sorted_tokens_num_lod=sorted_tokens_num_lod, sorted_tokens_idx=sorted_tokens_idx, moe_topk=moe_topk, y=y, act=act, x_perchannel_max=x_perchannel_max, w_perchannel_max=w_perchannel_max, topk_ids=topk_ids, topk_w=topk_w, bias=bias, tgemm_type=tgemm_type, tweight_type=tweight_type, scale_n=scale_n, scale_k=scale_k, use_pack_int4=use_pack_int4, sort_mode=sort_mode, ) @impl("_C::moe_fc", "CUDA") def moe_fc_cuda( x: torch.Tensor, weight: torch.Tensor, sorted_tokens_num_lod: torch.Tensor, sorted_tokens_idx: torch.Tensor, moe_topk: int, y: torch.Tensor, act: Optional[str] = None, x_perchannel_max: Optional[torch.Tensor] = None, w_perchannel_max: Optional[torch.Tensor] = None, topk_ids: Optional[torch.Tensor] = None, topk_w: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, tgemm_type: Optional[str] = None, tweight_type: Optional[str] = None, scale_n: Optional[int] = 0, scale_k: Optional[int] = 0, use_pack_int4: Optional[bool] = False, sort_mode: Optional[bool] = True, ) -> None: xtorch_ops.moe_fc( x=x, weight=weight, sorted_tokens_num_lod=sorted_tokens_num_lod, sorted_tokens_idx=sorted_tokens_idx, moe_topk=moe_topk, y=y, act=act, x_perchannel_max=x_perchannel_max, w_perchannel_max=w_perchannel_max, topk_ids=topk_ids, topk_w=topk_w, bias=bias, tgemm_type=tgemm_type, tweight_type=tweight_type, scale_n=scale_n, scale_k=scale_k, use_pack_int4=use_pack_int4, sort_mode=sort_mode, ) def fake_moe_fc( x: torch.Tensor, weight: torch.Tensor, sorted_tokens_num_lod: torch.Tensor, sorted_tokens_idx: torch.Tensor, moe_topk: int, y: torch.Tensor, act: Optional[str] = None, x_perchannel_max: Optional[torch.Tensor] = None, w_perchannel_max: Optional[torch.Tensor] = None, topk_ids: Optional[torch.Tensor] = None, topk_w: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, tgemm_type: Optional[str] = None, tweight_type: Optional[str] = None, scale_n: Optional[int] = 0, scale_k: Optional[int] = 0, use_pack_int4: Optional[bool] = False, sort_mode: Optional[bool] = True, ) -> None: return None moe_fc.register_fake(fake_moe_fc) @custom_op("_C::moe_post", mutates_args=()) def moe_post( x: torch.Tensor, moe_index: torch.Tensor, normed_scale: torch.Tensor, dequant_scale: torch.Tensor, y: torch.Tensor, ) -> None: xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) @impl("_C::moe_post", "CUDA") def moe_post_cuda( x: torch.Tensor, moe_index: torch.Tensor, normed_scale: torch.Tensor, dequant_scale: torch.Tensor, y: torch.Tensor, ) -> None: xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y) def fake_moe_post( x: torch.Tensor, moe_index: torch.Tensor, normed_scale: torch.Tensor, dequant_scale: torch.Tensor, y: torch.Tensor, ) -> None: return None moe_post.register_fake(fake_moe_post) @custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=()) def moe_sigmoid_group_topk_norm( x: torch.Tensor, topk_index: torch.Tensor, norm_score: torch.Tensor, block_static: torch.Tensor, bias: torch.Tensor, scale: float, n_group: int, topk_group: int, ) -> None: xtorch_ops.moe_sigmoid_group_topk_norm( x=x, norm_score=norm_score, topk_index=topk_index, block_static=block_static, bias=bias, n_group=n_group, topk_group=topk_group, scale=scale, ) @impl("_C::moe_sigmoid_group_topk_norm", "CUDA") def moe_sigmoid_group_topk_norm_cuda( x: torch.Tensor, topk_index: torch.Tensor, norm_score: torch.Tensor, block_static: torch.Tensor, bias: torch.Tensor, scale: float, n_group: int, topk_group: int, ) -> None: xtorch_ops.moe_sigmoid_group_topk_norm( x=x, norm_score=norm_score, topk_index=topk_index, block_static=block_static, bias=bias, n_group=n_group, topk_group=topk_group, scale=scale, ) def _fake_moe_sigmoid_group_topk_norm( x: torch.Tensor, topk_index: torch.Tensor, norm_score: torch.Tensor, block_static: torch.Tensor, bias: torch.Tensor, scale: float, n_group: int, topk_group: int, ) -> None: return None moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm) ################################################## # --------------- awq_dequantize ----------------- ################################################## @custom_op("_C::awq_dequantize", mutates_args=()) def awq_dequantize( qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, quant_type: int = 0, align_type: int = 1, ) -> torch.Tensor: weight = torch.empty( (qweight.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=qweight.device, ) group_m = int(qweight.shape[0] / scales.shape[0]) xtorch_ops.awq_dequantize( qweight=qweight, scales=scales, zeros=zeros, weight=weight, group_m=group_m, quant_type=quant_type, align_type=align_type, ) return weight @impl("_C::awq_dequantize", "CUDA") def awq_dequantize_cuda( qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, quant_type: int = 0, align_type: int = 1, ) -> torch.Tensor: weight = torch.empty( (qweight.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=qweight.device, ) group_m = int(qweight.shape[0] / scales.shape[0]) out = xtorch_ops.awq_dequantize( qweight=qweight, scales=scales, zeros=zeros, weight=weight, group_m=group_m, quant_type=quant_type, align_type=align_type, ) return weight def _fake_awq_dequantize( qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, quant_type: int = 0, align_type: int = 1, ) -> torch.Tensor: weight = torch.empty( (qweight.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=qweight.device, ) return weight awq_dequantize.register_fake(_fake_awq_dequantize) ################################################## # ------------------ awq_gemm ------------------- ################################################## @custom_op("_C::awq_gemm", mutates_args=()) def awq_gemm( x: torch.Tensor, qweight: torch.Tensor, scale: torch.Tensor, zeros: torch.Tensor, align_type: int = 1, ) -> torch.Tensor: out = torch.empty( (x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device ) group_size = int(qweight.shape[0] / scale.shape[0]) xtorch_ops.awq_gemm( x=x, w=qweight, scale=scale, zeros=zeros, out=out, align_type=align_type, group_size=group_size, ) return out @impl("_C::awq_gemm", "CUDA") def awq_gemm_cuda( x: torch.Tensor, qweight: torch.Tensor, scale: torch.Tensor, zeros: torch.Tensor, align_type: int = 1, ) -> torch.Tensor: out = torch.empty( (x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device ) group_size = int(qweight.shape[0] / scale.shape[0]) xtorch_ops.awq_gemm( x=x, w=qweight, scale=scale, zeros=zeros, out=out, align_type=align_type, group_size=group_size, ) return out def _fake_awq_gemm( x: torch.Tensor, qweight: torch.Tensor, scale: torch.Tensor, zeros: torch.Tensor, align_type: int = 1, ) -> torch.Tensor: out = torch.empty( (x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device ) return out awq_gemm.register_fake(_fake_awq_gemm) ################################################## # ---------------- gptq_shuffle ------------------ ################################################## @custom_op("_C::gptq_shuffle", mutates_args=()) def gptq_shuffle( q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int, ) -> None: xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit) @impl("_C::gptq_shuffle", "CUDA") def gptq_shuffle_cuda( q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int, ) -> None: xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit) def _fake_gptq_shuffle( q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int, ) -> None: return None gptq_shuffle.register_fake(_fake_gptq_shuffle) ################################################## # ------------- concat_and_cache_mla ------------- ################################################## @custom_op("_C::concat_and_cache_mla", mutates_args=()) def concat_and_cache_mla( kv_c: torch.Tensor, # [num_tokens, kv_lora_rank] k_pe: torch.Tensor, # [num_tokens, pe_dim] kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)] slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens] ) -> None: xtorch_ops.concat_and_cache_mla( kv_c=kv_c, k_pe=k_pe, slot_mapping=slot_mapping, kv_cache=kv_cache, ) @impl("_C::concat_and_cache_mla", "CUDA") def concat_and_cache_mla_cuda( kv_c: torch.Tensor, # [num_tokens, kv_lora_rank] k_pe: torch.Tensor, # [num_tokens, pe_dim] kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)] slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens] ) -> None: xtorch_ops.concat_and_cache_mla( kv_c=kv_c, k_pe=k_pe, slot_mapping=slot_mapping, kv_cache=kv_cache, ) def _fake_concat_and_cache_mla( kv_c: torch.Tensor, # [num_tokens, kv_lora_rank] k_pe: torch.Tensor, # [num_tokens, pe_dim] kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)] slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens] ) -> None: return None concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla) ###################################################### # -------------- scaled_int8_quant ------------------- ###################################################### @custom_op("_C::scaled_int8_quant", mutates_args=()) def scaled_int8_quant( x: torch.Tensor, scale: torch.Tensor, azp: Optional[torch.Tensor] = None, symmetric: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]: static = False x_q = torch.empty_like(x, dtype=torch.int8, device=x.device) if scale is not None: # static static = True torch.ops.xspeedgate_ops.static_scaled_int8_quant(x_q, x, scale, azp) else: # dynamic scale = torch.empty( (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 ) azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32) if symmetric: # NOTE: For quant2d ops, scale represents max. xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True) else: torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant( x_q, x.contiguous(), scale, azp ) return x_q, scale, azp, static @impl("_C::scaled_int8_quant", "CUDA") def scaled_int8_quant_cuda( x: torch.Tensor, scale: torch.Tensor, azp: Optional[torch.Tensor] = None, symmetric: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]: static = False x_q = torch.empty_like(x, dtype=torch.int8, device=x.device) if scale is not None: # static static = True torch.ops.xspeedgate_ops.static_scaled_int8_quant(x_q, x, scale, azp) else: # dynamic scale = torch.empty( (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 ) azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32) if symmetric: # NOTE: For quant2d ops, scale represents max. xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True) else: torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant( x_q, x.contiguous(), scale, azp ) return x_q, scale, azp, static def _fake_scaled_int8_quant( x: torch.Tensor, scale: torch.Tensor, azp: Optional[torch.Tensor] = None, symmetric: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]: x_q = torch.empty_like(x, dtype=torch.int8, device=x.device) scale = torch.empty( (x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32 ) azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32) return x_q, scale, azp, False scaled_int8_quant.register_fake(_fake_scaled_int8_quant) ###################################################### # ---------------- cutlass_scaled_mm ----------------- ###################################################### @custom_op("_C::cutlass_scaled_mm", mutates_args=()) def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) return out @impl("_C::cutlass_scaled_mm", "CUDA") def cutlass_scaled_mm_cuda( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) torch.ops.xspeedgate_ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) return out def fake_cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) cutlass_scaled_mm.register_fake(fake_cutlass_scaled_mm) ###################################################### # ------------ cutlass_scaled_mm_azp ----------------- ###################################################### @custom_op("_C::cutlass_scaled_mm_azp", mutates_args=()) def cutlass_scaled_mm_azp( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, azp_adj: torch.Tensor, azp: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp( out, a, b, scale_a, scale_b, azp_adj, azp, bias ) return out @impl("_C::cutlass_scaled_mm_azp", "CUDA") def cutlass_scaled_mm_azp_cuda( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, azp_adj: torch.Tensor, azp: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp( out, a, b, scale_a, scale_b, azp_adj, azp, bias ) return out def fake_cutlass_scaled_mm_azp( a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, azp_adj: torch.Tensor, azp: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) cutlass_scaled_mm_azp.register_fake(fake_cutlass_scaled_mm_azp) ################################################## # ------------------ matmul --------------------- ################################################## @custom_op("_C::matmul", mutates_args=()) def matmul( x: torch.Tensor, w: torch.Tensor, out_dtype: torch.dtype, x_trans: bool = False, w_trans: bool = True, alpha: float = 1.0, beta: float = 0.0, bias: torch.Tensor = None, x_max: torch.Tensor = None, w_max: torch.Tensor = None, x_pc_max: torch.Tensor = None, w_pc_max: torch.Tensor = None, ) -> torch.Tensor: out = torch.empty( (x.shape[0], w.shape[0] if w_trans else w.shape[1]), dtype=out_dtype, device=x.device, ) xtorch_ops.matmul( x=x.contiguous(), w=w.contiguous(), out=out, x_trans=x_trans, w_trans=w_trans, alpha=alpha, beta=beta, bias=bias, x_max=x_max, w_max=w_max, x_pc_max=x_pc_max, w_pc_max=w_pc_max, ) return out @impl("_C::matmul", "CUDA") def matmul_cuda( x: torch.Tensor, w: torch.Tensor, out_dtype: torch.dtype, x_trans: bool = False, w_trans: bool = True, alpha: float = 1.0, beta: float = 0.0, bias: torch.Tensor = None, x_max: torch.Tensor = None, w_max: torch.Tensor = None, x_pc_max: torch.Tensor = None, w_pc_max: torch.Tensor = None, ) -> torch.Tensor: out = torch.empty( (x.shape[0], w.shape[0] if w_trans else w.shape[1]), dtype=out_dtype, device=x.device, ) xtorch_ops.matmul( x=x.contiguous(), w=w.contiguous(), out=out, x_trans=x_trans, w_trans=w_trans, alpha=alpha, beta=beta, bias=bias, x_max=x_max, w_max=w_max, x_pc_max=x_pc_max, w_pc_max=w_pc_max, ) return out def _fake_matmul( x: torch.Tensor, w: torch.Tensor, out_dtype: torch.dtype, x_trans: bool = False, w_trans: bool = True, alpha: float = 1.0, beta: float = 0.0, bias: torch.Tensor = None, x_max: torch.Tensor = None, w_max: torch.Tensor = None, x_pc_max: torch.Tensor = None, w_pc_max: torch.Tensor = None, ) -> torch.Tensor: return torch.empty( (x.shape[0], w.shape[0] if w_trans else w.shape[1]), dtype=out_dtype, device=x.device, ) matmul.register_fake(_fake_matmul) ################################################## # ------------------- quant2d -------------------- ################################################## @custom_op("_C::quant2d", mutates_args=()) def quant2d( x: torch.Tensor, x_q: torch.Tensor, max: torch.Tensor, force_sdnn: bool = False, ) -> None: xtorch_ops.quant2d( x=x, y=x_q, max=max, force_sdnn=force_sdnn, ) @impl("_C::quant2d", "CUDA") def quant2d_cuda( x: torch.Tensor, x_q: torch.Tensor, max: torch.Tensor, force_sdnn: bool = False, ) -> None: xtorch_ops.quant2d( x=x, y=x_q, max=max, force_sdnn=force_sdnn, ) def _fake_quant2d( x: torch.Tensor, x_q: torch.Tensor, max: torch.Tensor, force_sdnn: bool = False, ) -> None: return None quant2d.register_fake(_fake_quant2d)