"""vllm_utils_wrapper.py"""
import vllm.distributed.parallel_state as parallel_state
import vllm.utils as _orig
from typing import Any, Callable, Optional, Union, get_origin, get_args, List, Tuple
from types import SimpleNamespace
import torch
from torch.library import Library
import inspect
import typing
from torch.library import register_fake
import vllm_kunlun._kunlun
import vllm.envs as envs
def patch_annotations_for_schema(func):
"""patch_annotations_for_schema"""
sig = inspect.signature(func)
new_params = []
for name, param in sig.parameters.items():
ann = param.annotation
if get_origin(ann) is typing.Union and type(None) in get_args(ann):
inner_type = [a for a in get_args(ann) if a is not type(None)][0]
if get_origin(inner_type) is list: # Optional[list[int]]
inner_args = get_args(inner_type)
new_ann = Optional[List[inner_args[0] if inner_args else typing.Any]]
param = param.replace(annotation=new_ann)
elif get_origin(ann) is list:
args = get_args(ann)
new_ann = List[args[0] if args else typing.Any]
param = param.replace(annotation=new_ann)
new_params.append(param)
func.__signature__ = sig.replace(parameters=new_params)
return func
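# Illustration of the rewrite this performs (hypothetical signature, for
# exposition only): a parameter annotated `Optional[list[int]]` becomes
# `Optional[List[int]]`, and a bare `list[int]` becomes `List[int]`:
#
#     def op(ids: Optional[list[int]]) -> None: ...
#     patch_annotations_for_schema(op)
#     # inspect.signature(op) now reports ids: Optional[List[int]]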
def supports_custom_op() -> bool:
"""supports_custom_op"""
return hasattr(torch.library, "custom_op")
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: Optional[list[str]] = None,
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
tags: tuple[torch.Tag, ...] = (),
):
"""
`torch.library.custom_op` can have significant overhead because it
needs to consider complicated dispatching logic. This function
directly registers a custom op and dispatches it to the CUDA backend.
See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
for more details.
By default, the custom op is registered to the vLLM library. If you
want to register it to a different library, you can pass the library
object to the `target_lib` argument.
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
"""
if not supports_custom_op():
from vllm.platforms import current_platform
assert not current_platform.is_cuda_alike(), (
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "
"or a custom build of pytorch. It is recommended to "
"use vLLM in a fresh new environment and let it install "
"the required dependencies."
)
return
if mutates_args is None:
mutates_args = []
import torch.library
if hasattr(torch.library, "infer_schema"):
patched_func = patch_annotations_for_schema(op_func)
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
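# Usage sketch (hypothetical op, not registered by this module): the op lands
# in the `vllm` namespace by default, so it becomes callable through
# torch.ops.vllm.<op_name> once defined here.
#
#     def my_copy(dst: torch.Tensor, src: torch.Tensor) -> None:
#         dst.copy_(src)
#
#     def my_copy_fake(dst: torch.Tensor, src: torch.Tensor) -> None:
#         pass  # output metadata is carried by the mutated `dst`
#
#     direct_register_custom_op(
#         op_name="my_copy",
#         op_func=my_copy,
#         mutates_args=["dst"],
#         fake_impl=my_copy_fake,
#     )
#     torch.ops.vllm.my_copy(dst, src)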
def vllm_kunlun_weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive.
"""
if isinstance(tensor, torch.Tensor):
return torch.ops._kunlun.weak_ref_tensor(tensor)
else:
return tensor
def vllm_kunlun_weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]],
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
"""
if isinstance(tensors, torch.Tensor):
return vllm_kunlun_weak_ref_tensor(tensors)
if isinstance(tensors, list):
return [vllm_kunlun_weak_ref_tensor(t) for t in tensors]
if isinstance(tensors, tuple):
return tuple(vllm_kunlun_weak_ref_tensor(t) for t in tensors)
raise ValueError("Invalid type for tensors")
import socket

vllm_port = envs.VLLM_PORT

def _get_open_port() -> int:
"""Return an open port, scanning upward from VLLM_PORT."""
global vllm_port
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", vllm_port))
port = vllm_port
vllm_port += 1
# Return the port that was actually probed, not the incremented one.
return port
except OSError:
# try ipv6
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
_wrapped = SimpleNamespace(**_orig.__dict__)
_wrapped.direct_register_custom_op = direct_register_custom_op
_wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor
_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors
_wrapped._get_open_port = _get_open_port
import sys
sys.modules["vllm.utils"] = _wrapped
_original_all_reduce = parallel_state.GroupCoordinator.all_reduce
_original_all_gather = parallel_state.GroupCoordinator.all_gather
def vllm_kunlun_all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""vllm_kunlun_all_reduce"""
if self.world_size == 1:
return input_
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
def vllm_kunlun_all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""vllm_kunlun_all_reduce"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert (
-input_.dim() <= dim < input_.dim()
), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty(
(world_size,) + input_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(
input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
)
return output_tensor
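# Worked example: with world_size = 4 and input_ of shape (B, H), gathering
# along dim=-1 stacks the per-rank tensors into (4, B, H), moves the new axis
# to the gather dim, and reshapes to (B, 4 * H).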
parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce
parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
from torch.library import custom_op, impl
from vllm import _custom_ops as ops
from typing import Optional, List
import os
@custom_op("_C::rms_norm", mutates_args=())
def rms_norm(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
@custom_op("_C::fused_add_rms_norm", mutates_args=())
def fused_add_rms_norm(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
@custom_op("_C::static_scaled_fp8_quant", mutates_args=())
def static_scaled_fp8_quant(
result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
pass
@impl("_C::static_scaled_fp8_quant", "CUDA")
def static_scaled_fp8_quant_xpu(
result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
pass
@custom_op("_C::dynamic_scaled_fp8_quant", mutates_args=())
def dynamic_scaled_fp8_quant(
result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
pass
@impl("_C::dynamic_scaled_fp8_quant", "CUDA")
def dynamic_scaled_fp8_quant_xpu(
result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
pass
@custom_op("_C::dynamic_per_token_scaled_fp8_quant", mutates_args=())
def dynamic_per_token_scaled_fp8_quant(
result: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
scale_ub: Optional[torch.Tensor],
) -> None:
pass
@impl("_C::dynamic_per_token_scaled_fp8_quant", "CUDA")
def dynamic_per_token_scaled_fp8_quant_xpu(
result: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
scale_ub: Optional[torch.Tensor],
) -> None:
pass
@custom_op("_C::rms_norm_static_fp8_quant", mutates_args=())
def rms_norm_static_fp8_quant(
result: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
@impl("_C::rms_norm_static_fp8_quant", "CUDA")
def rms_norm_static_fp8_quant_xpu(
result: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
@custom_op("_C::fused_add_rms_norm_static_fp8_quant", mutates_args=())
def fused_add_rms_norm_static_fp8_quant(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
@impl("_C::fused_add_rms_norm_static_fp8_quant", "CUDA")
def fused_add_rms_norm_static_fp8_quant_xpu(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
@custom_op("_C::rms_norm_dynamic_per_token_quant", mutates_args=())
def rms_norm_dynamic_per_token_quant(
result: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
scale_ub: Optional[torch.Tensor],
residual: Optional[torch.Tensor],
) -> None:
pass
@impl("_C::rms_norm_dynamic_per_token_quant", "CUDA")
def rms_norm_dynamic_per_token_quant_xpu(
result: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
scale_ub: Optional[torch.Tensor],
residual: Optional[torch.Tensor],
) -> None:
pass
@custom_op("_C::silu_and_mul_quant", mutates_args=())
def silu_and_mul_quant(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
@impl("_C::silu_and_mul_quant", "CUDA")
def silu_and_mul_quant_xpu(
result: torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
) -> None:
pass
import xtorch_ops
@custom_op("_C::add_rmsnorm", mutates_args=())
def add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: Optional[torch.Tensor] = None,
smooth: Optional[torch.Tensor] = None,
residual_output: Optional[torch.Tensor] = None,
output_max: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.add_rmsnorm(
x,
y,  # originally written as `residual`; this argument is actually y
residual_output=residual_output,
weight=weight,
eps=eps,
output=output,
)
@impl("_C::add_rmsnorm", "CUDA")
def add_rmsnorm_cuda(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: Optional[torch.Tensor] = None,
smooth: Optional[torch.Tensor] = None,
residual_output: Optional[torch.Tensor] = None,
output_max: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.add_rmsnorm(
x,
y,
residual_output=residual_output,
weight=weight,
eps=eps,
output=output,
)
@custom_op("_C::rmsnorm", mutates_args=())
def rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweave: bool = False,
store_output_before_norm: bool = True,
bias: Optional[torch.Tensor] = None,
residual_output: Optional[torch.Tensor] = None,
output_max: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.rmsnorm(
x,
weight,
output,
eps,
)
@impl("_C::rmsnorm", "CUDA")
def rmsnorm_cuda(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweave: bool = False,
store_output_before_norm: bool = True,
bias: Optional[torch.Tensor] = None,
residual_output: Optional[torch.Tensor] = None,
output_max: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.rmsnorm(
x,
weight,
output,
eps,
)
def _fake_rmsnorm(
x,
weight,
output,
eps=1e-5,
interweave=False,
store_output_before_norm=True,
bias=None,
residual_output=None,
output_max=None,
):
# Record fake shape/dtype metadata; the op itself returns nothing.
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
rmsnorm.register_fake(_fake_rmsnorm)
def _fake_add_rmsnorm(
x,
y,
weight,
output,
eps=1e-5,
interweaved=False,
store_output_before_norm=True,
bias=None,
smooth=None,
residual_output=None,
output_max=None,
):
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
add_rmsnorm.register_fake(_fake_add_rmsnorm)
@custom_op("_C::split_norm_rope_neox", mutates_args=())
def split_norm_rope_neox(
q_emb: torch.Tensor,
k_emb: torch.Tensor,
v_out: torch.Tensor,
qkv: torch.Tensor,
rotary_pos_embedding: torch.Tensor,
q_norm_weight: torch.Tensor,
k_norm_weight: torch.Tensor,
positions: torch.Tensor,
num_tokens: int,
max_seqlen: int,
head_num: int,
kv_head_num: int,
head_dim: int,
rotary_dim: int,
emb_batch_size: int = 1,
) -> None:
xtorch_ops.split_norm_rope_neox(
q_emb,
k_emb,
v_out,
qkv,
rotary_pos_embedding,
q_norm_weight,
k_norm_weight,
positions,
num_tokens,
max_seqlen,
head_num,
kv_head_num,
head_dim,
rotary_dim,
)
@impl("_C::split_norm_rope_neox", "CUDA")
def split_norm_rope_neox_cuda(
q_emb: torch.Tensor,
k_emb: torch.Tensor,
v_out: torch.Tensor,
qkv: torch.Tensor,
rotary_pos_embedding: torch.Tensor,
q_norm_weight: torch.Tensor,
k_norm_weight: torch.Tensor,
positions: torch.Tensor,
num_tokens: int,
max_seqlen: int,
head_num: int,
kv_head_num: int,
head_dim: int,
rotary_dim: int,
emb_batch_size: int = 1,
) -> None:
xtorch_ops.split_norm_rope_neox(
q_emb,
k_emb,
v_out,
qkv,
rotary_pos_embedding,
q_norm_weight,
k_norm_weight,
positions,
num_tokens,
max_seqlen,
head_num,
kv_head_num,
head_dim,
rotary_dim,
)
def _fake_split_norm_rope_neox(
q_emb: torch.Tensor,
k_emb: torch.Tensor,
v_out: torch.Tensor,
qkv: torch.Tensor,
rotary_pos_embedding: torch.Tensor,
q_norm_weight: torch.Tensor,
k_norm_weight: torch.Tensor,
positions: torch.Tensor,
num_tokens: int,
max_seqlen: int,
head_num: int,
kv_head_num: int,
head_dim: int,
rotary_dim: int,
emb_batch_size: int = 1,
):
q_emb.fake_shape = q_emb.shape
q_emb.fake_dtype = q_emb.dtype
k_emb.fake_shape = k_emb.shape
k_emb.fake_dtype = k_emb.dtype
v_out.fake_shape = v_out.shape
v_out.fake_dtype = v_out.dtype
return None
split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox)
# Register fake op implementations below so torch.dynamo can trace them.
if hasattr(torch.ops.custom_ops, "fc_fusion"):
@register_fake("custom_ops::fc_fusion")
def fc_fusion_fake(
self: torch.Tensor,
other: torch.Tensor,
bias: Optional[torch.Tensor],
self_trans: bool,
other_trans: bool,
*,
alpha: float = 1.0,
beta: float = 0.0,
act: int = 1,
multi_stream: bool = False,
out: torch.Tensor,
) -> None:
pass
@custom_op("_C::silu_and_mul", mutates_args=())
def silu_and_mul(
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
) -> None:
xtorch_ops.swiglu(
x=x,
y=out,
)
@impl("_C::silu_and_mul", "CUDA")
def silu_and_mul_cuda(
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
) -> None:
xtorch_ops.swiglu(
x=x,
y=out,
)
def _fake_silu_and_mul(
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
):
return None
silu_and_mul.register_fake(_fake_silu_and_mul)
@custom_op("_C::swigluoai_and_mul", mutates_args=())
def swigluoai_and_mul(
x: torch.Tensor,
alpha: float = 1.702,
limit: float = 7.0,
axis: int = -1,
turn: bool = True,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
@impl("_C::swigluoai_and_mul", "CUDA")
def swigluoai_and_mul_cuda(
x: torch.Tensor,
alpha: float = 1.702,
limit: float = 7.0,
axis: int = -1,
turn: bool = True,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
def _fake_swigluoai_and_mul(
x: torch.Tensor,
alpha: float = 1.702,
limit: float = 7.0,
axis: int = -1,
turn: bool = True,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
swigluoai_and_mul.register_fake(_fake_swigluoai_and_mul)
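# Layout note: the last dimension of x interleaves gate/up pairs, so
# x[..., ::2] is the gate half and x[..., 1::2] the up half, and the output
# has half the input's trailing dimension. A minimal sketch:
#
#     x = torch.randn(8, 64)
#     y = torch.ops._C.swigluoai_and_mul(x)  # y.shape == (8, 32)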
@custom_op("_C::moe_softmax_topk", mutates_args=())
def moe_softmax_topk(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
axis: int = -1,
turn: bool = True,
) -> None:
xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
@impl("_C::moe_softmax_topk", "CUDA")
def moe_softmax_topk_cuda(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
axis: int = -1,
turn: bool = True,
) -> None:
xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
def _fake_moe_softmax_topk(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
axis: int = -1,
turn: bool = True,
) -> None:
return None
moe_softmax_topk.register_fake(_fake_moe_softmax_topk)
@custom_op("_C::moe_ffn_block", mutates_args=())
def moe_ffn_block(
out: torch.Tensor,
x: torch.Tensor,
expert_num: int,
moe_top_k: int,
gate_w: torch.Tensor,
inter_w: torch.Tensor,
output_w: torch.Tensor,
renormalize: bool = True,
use_grouped_topk: bool = False,
expert_group_num: Optional[int] = 0,
topk_group: Optional[int] = 0,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.moe_ffn_block(
x=x,
gate_w=gate_w,
inter_w=inter_w,
output_w=output_w,
expert_num=expert_num,
moe_top_k=moe_top_k,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=expert_group_num,
out=out,
)
@impl("_C::moe_ffn_block", "CUDA")
def moe_ffn_block_cuda(
out: torch.Tensor,
x: torch.Tensor,
expert_num: int,
moe_top_k: int,
gate_w: torch.Tensor,
inter_w: torch.Tensor,
output_w: torch.Tensor,
renormalize: bool = True,
use_grouped_topk: bool = False,
expert_group_num: Optional[int] = 0,
topk_group: Optional[int] = 0,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.moe_ffn_block(
x=x,
gate_w=gate_w,
inter_w=inter_w,
output_w=output_w,
expert_num=expert_num,
moe_top_k=moe_top_k,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=expert_group_num,
out=out,
)
def _fake_moe_ffn_block(
out: torch.Tensor,
x: torch.Tensor,
expert_num: int,
moe_top_k: int,
gate_w: torch.Tensor,
inter_w: torch.Tensor,
output_w: torch.Tensor,
renormalize: bool = True,
use_grouped_topk: bool = False,
expert_group_num: Optional[int] = 0,
topk_group: Optional[int] = 0,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
):
return None
moe_ffn_block.register_fake(_fake_moe_ffn_block)
@custom_op("_C::moe_ffn_per_token_block", mutates_args=())
def moe_ffn_per_token_block(
x: torch.Tensor,
inter_weight: torch.Tensor,
inter_scale: torch.Tensor,
outer_weight: torch.Tensor,
outer_scale: torch.Tensor,
top_k: int,
global_num_experts: int,
linear_weights: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
activation: str = "silu",
output: Optional[torch.Tensor] = None,
use_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0,
) -> None:
xtorch_ops.moe_ffn_per_token_block(
x=x,
inter_weight=inter_weight,
inter_scale=inter_scale,
outer_weight=outer_weight,
outer_scale=outer_scale,
gate_weight=linear_weights,
expert_num=global_num_experts,
moe_top_k=top_k,
act_type=activation,
use_expert_parallel=use_expert_parallel,
ep_size=ep_size,
ep_rank=ep_rank,
out=output,
)
@impl("_C::moe_ffn_per_token_block", "CUDA")
def moe_ffn_per_token_block_cuda(
x: torch.Tensor,
inter_weight: torch.Tensor,
inter_scale: torch.Tensor,
outer_weight: torch.Tensor,
outer_scale: torch.Tensor,
top_k: int,
global_num_experts: int,
linear_weights: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
activation: str = "silu",
output: Optional[torch.Tensor] = None,
use_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0,
) -> None:
xtorch_ops.moe_ffn_per_token_block(
x=x,
inter_weight=inter_weight,
inter_scale=inter_scale,
outer_weight=outer_weight,
outer_scale=outer_scale,
gate_weight=linear_weights,
expert_num=global_num_experts,
moe_top_k=top_k,
act_type=activation,
use_expert_parallel=use_expert_parallel,
ep_size=ep_size,
ep_rank=ep_rank,
out=output,
)
def _fake_moe_ffn_per_token_block(
x: torch.Tensor,
inter_weight: torch.Tensor,
inter_scale: torch.Tensor,
outer_weight: torch.Tensor,
outer_scale: torch.Tensor,
top_k: int,
global_num_experts: int,
linear_weights: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
activation: str = "silu",
output: Optional[torch.Tensor] = None,
use_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0,
) -> None:
# Fake implementation can be a no-op or a simple operation
if output is not None:
output.copy_(x) # Example: simply copy input to output
# Register the fake implementation
moe_ffn_per_token_block.register_fake(_fake_moe_ffn_per_token_block)
@custom_op("_C::rotary_embedding", mutates_args=())
def rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
xtorch_ops.rotary_embedding(
positions=positions,
query=query,
key=key,
head_size=head_size,
cos_sin_cache=cos_sin_cache,
is_neox=is_neox,
)
@impl("_C::rotary_embedding", "CUDA")
def rotary_embedding_cuda(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
xtorch_ops.rotary_embedding(
positions=positions,
query=query,
key=key,
head_size=head_size,
cos_sin_cache=cos_sin_cache,
is_neox=is_neox,
)
def _fake_rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
return None
rotary_embedding.register_fake(_fake_rotary_embedding)
@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
def gemm_I8_I8_bf16_nt(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
xtorch_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
)
@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
def gemm_I8_I8_bf16_nt_cuda(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
xtorch_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
)
def _fake_gemm_I8_I8_bf16_nt(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
return None
gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
def moe_softmax_topk_norm(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
stable: bool = True,
) -> None:
xtorch_ops.moe_softmax_topk_norm(
x, normed_score, topk_index, block_statistic, stable
)
@impl("_C::moe_softmax_topk_norm", "CUDA")
def moe_softmax_topk_norm_cuda(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
stable: bool = True,
) -> None:
xtorch_ops.moe_softmax_topk_norm(
x, normed_score, topk_index, block_statistic, stable
)
def _fake_moe_softmax_topk_norm(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
stable: bool = True,
) -> None:
return None
moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm)
@custom_op("_C::gen_block_statistic", mutates_args=())
def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None:
xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
@impl("_C::gen_block_statistic", "CUDA")
def gen_block_statistic_cuda(
topk_ids: torch.Tensor, block_statistic: torch.Tensor
) -> None:
xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
def fake_gen_block_statistic(
topk_ids: torch.Tensor, block_statistic: torch.Tensor
) -> None:
return None
gen_block_statistic.register_fake(fake_gen_block_statistic)
@custom_op("_C::moe_pre_sorted", mutates_args=())
def moe_pre_sorted(
x: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
moe_expand: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False,
) -> None:
xtorch_ops.moe_pre_sorted(
x,
topk_index,
block_statistic,
moe_expand,
moe_index,
expert_m,
sorted_tokens_num_lod,
)
@impl("_C::moe_pre_sorted", "CUDA")
def moe_pre_sorted_cuda(
x: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
moe_expand: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False,
) -> None:
xtorch_ops.moe_pre_sorted(
x,
topk_index,
block_statistic,
moe_expand,
moe_index,
expert_m,
sorted_tokens_num_lod,
)
def fake_moe_pre_sorted(
x: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
moe_expand: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False,
) -> None:
return None
moe_pre_sorted.register_fake(fake_moe_pre_sorted)
@custom_op("_C::moe_fc", mutates_args=())
def moe_fc(
x: torch.Tensor,
weight: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
w_perchannel_max: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True,
) -> None:
xtorch_ops.moe_fc(
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_topk,
y=y,
act=act,
x_perchannel_max=x_perchannel_max,
w_perchannel_max=w_perchannel_max,
topk_ids=topk_ids,
topk_w=topk_w,
bias=bias,
tgemm_type=tgemm_type,
tweight_type=tweight_type,
scale_n=scale_n,
scale_k=scale_k,
use_pack_int4=use_pack_int4,
sort_mode=sort_mode,
)
@impl("_C::moe_fc", "CUDA")
def moe_fc_cuda(
x: torch.Tensor,
weight: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
w_perchannel_max: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True,
) -> None:
xtorch_ops.moe_fc(
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_topk,
y=y,
act=act,
x_perchannel_max=x_perchannel_max,
w_perchannel_max=w_perchannel_max,
topk_ids=topk_ids,
topk_w=topk_w,
bias=bias,
tgemm_type=tgemm_type,
tweight_type=tweight_type,
scale_n=scale_n,
scale_k=scale_k,
use_pack_int4=use_pack_int4,
sort_mode=sort_mode,
)
def fake_moe_fc(
x: torch.Tensor,
weight: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
w_perchannel_max: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True,
) -> None:
return None
moe_fc.register_fake(fake_moe_fc)
@custom_op("_C::moe_post", mutates_args=())
def moe_post(
x: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
dequant_scale: torch.Tensor,
y: torch.Tensor,
) -> None:
xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
@impl("_C::moe_post", "CUDA")
def moe_post_cuda(
x: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
dequant_scale: torch.Tensor,
y: torch.Tensor,
) -> None:
xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
def fake_moe_post(
x: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
dequant_scale: torch.Tensor,
y: torch.Tensor,
) -> None:
return None
moe_post.register_fake(fake_moe_post)
@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
def moe_sigmoid_group_topk_norm(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int,
) -> None:
xtorch_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
block_static=block_static,
bias=bias,
n_group=n_group,
topk_group=topk_group,
scale=scale,
)
@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
def moe_sigmoid_group_topk_norm_cuda(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int,
) -> None:
xtorch_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
block_static=block_static,
bias=bias,
n_group=n_group,
topk_group=topk_group,
scale=scale,
)
def _fake_moe_sigmoid_group_topk_norm(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int,
) -> None:
return None
moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)
##################################################
# --------------- awq_dequantize -----------------
##################################################
@custom_op("_C::awq_dequantize", mutates_args=())
def awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
@impl("_C::awq_dequantize", "CUDA")
def awq_dequantize_cuda(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
group_m = int(qweight.shape[0] / scales.shape[0])
xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
def _fake_awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
return weight
awq_dequantize.register_fake(_fake_awq_dequantize)
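# Shape note: `qweight` packs eight 4-bit values into each int32 element,
# hence the dequantized weight has qweight.shape[1] * 8 columns; group_m is
# the number of weight rows that share one row of scales/zeros.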
##################################################
# ------------------ awq_gemm -------------------
##################################################
@custom_op("_C::awq_gemm", mutates_args=())
def awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
@impl("_C::awq_gemm", "CUDA")
def awq_gemm_cuda(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
group_size = int(qweight.shape[0] / scale.shape[0])
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
def _fake_awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
return out
awq_gemm.register_fake(_fake_awq_gemm)
##################################################
# ---------------- gptq_shuffle ------------------
##################################################
@custom_op("_C::gptq_shuffle", mutates_args=())
def gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
@impl("_C::gptq_shuffle", "CUDA")
def gptq_shuffle_cuda(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
def _fake_gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
return None
gptq_shuffle.register_fake(_fake_gptq_shuffle)
##################################################
# ------------- concat_and_cache_mla -------------
##################################################
@custom_op("_C::concat_and_cache_mla", mutates_args=())
def concat_and_cache_mla(
kv_c: torch.Tensor, # [num_tokens, kv_lora_rank]
k_pe: torch.Tensor, # [num_tokens, pe_dim]
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
) -> None:
xtorch_ops.concat_and_cache_mla(
kv_c=kv_c,
k_pe=k_pe,
slot_mapping=slot_mapping,
kv_cache=kv_cache,
)
@impl("_C::concat_and_cache_mla", "CUDA")
def concat_and_cache_mla_cuda(
kv_c: torch.Tensor, # [num_tokens, kv_lora_rank]
k_pe: torch.Tensor, # [num_tokens, pe_dim]
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
) -> None:
xtorch_ops.concat_and_cache_mla(
kv_c=kv_c,
k_pe=k_pe,
slot_mapping=slot_mapping,
kv_cache=kv_cache,
)
def _fake_concat_and_cache_mla(
kv_c: torch.Tensor, # [num_tokens, kv_lora_rank]
k_pe: torch.Tensor, # [num_tokens, pe_dim]
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
) -> None:
return None
concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla)
######################################################
# -------------- scaled_int8_quant -------------------
######################################################
@custom_op("_C::scaled_int8_quant", mutates_args=())
def scaled_int8_quant(
x: torch.Tensor,
scale: Optional[torch.Tensor],
azp: Optional[torch.Tensor] = None,
symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]:
static = False
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
if scale is not None: # static
static = True
torch.ops.xspeedgate_ops.static_scaled_int8_quant(x_q, x, scale, azp)
else: # dynamic
scale = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
if symmetric:
# NOTE: For quant2d ops, scale represents max.
xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
else:
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
x_q, x.contiguous(), scale, azp
)
return x_q, scale, azp, static
@impl("_C::scaled_int8_quant", "CUDA")
def scaled_int8_quant_cuda(
x: torch.Tensor,
scale: Optional[torch.Tensor],
azp: Optional[torch.Tensor] = None,
symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]:
static = False
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
if scale is not None: # static
static = True
torch.ops.xspeedgate_ops.static_scaled_int8_quant(x_q, x, scale, azp)
else: # dynamic
scale = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
if symmetric:
# NOTE: For quant2d ops, scale represents max.
xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
else:
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
x_q, x.contiguous(), scale, azp
)
return x_q, scale, azp, static
def _fake_scaled_int8_quant(
x: torch.Tensor,
scale: Optional[torch.Tensor],
azp: Optional[torch.Tensor] = None,
symmetric: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], bool]:
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
scale = torch.empty(
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
)
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
return x_q, scale, azp, False
scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
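# Usage sketch (assumes the xspeedgate_ops/xtorch_ops extensions are loaded):
# passing a preallocated `scale` takes the static path; passing None
# quantizes dynamically with one float32 scale per token row.
#
#     x = torch.randn(16, 128, device="cuda", dtype=torch.float16)
#     x_q, scale, azp, static = torch.ops._C.scaled_int8_quant(x, None)
#     # x_q: int8 with x's shape; scale: (16, 1) float32; static is False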
######################################################
# ---------------- cutlass_scaled_mm -----------------
######################################################
@custom_op("_C::cutlass_scaled_mm", mutates_args=())
def cutlass_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
)
return out
@impl("_C::cutlass_scaled_mm", "CUDA")
def cutlass_scaled_mm_cuda(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
)
return out
def fake_cutlass_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
cutlass_scaled_mm.register_fake(fake_cutlass_scaled_mm)
######################################################
# ------------ cutlass_scaled_mm_azp -----------------
######################################################
@custom_op("_C::cutlass_scaled_mm_azp", mutates_args=())
def cutlass_scaled_mm_azp(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
azp_adj: torch.Tensor,
azp: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
)
return out
@impl("_C::cutlass_scaled_mm_azp", "CUDA")
def cutlass_scaled_mm_azp_cuda(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
azp_adj: torch.Tensor,
azp: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
)
return out
def fake_cutlass_scaled_mm_azp(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
azp_adj: torch.Tensor,
azp: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
cutlass_scaled_mm_azp.register_fake(fake_cutlass_scaled_mm_azp)
##################################################
# ------------------ matmul ---------------------
##################################################
@custom_op("_C::matmul", mutates_args=())
def matmul(
x: torch.Tensor,
w: torch.Tensor,
out_dtype: torch.dtype,
x_trans: bool = False,
w_trans: bool = True,
alpha: float = 1.0,
beta: float = 0.0,
bias: Optional[torch.Tensor] = None,
x_max: Optional[torch.Tensor] = None,
w_max: Optional[torch.Tensor] = None,
x_pc_max: Optional[torch.Tensor] = None,
w_pc_max: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
dtype=out_dtype,
device=x.device,
)
xtorch_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=out,
x_trans=x_trans,
w_trans=w_trans,
alpha=alpha,
beta=beta,
bias=bias,
x_max=x_max,
w_max=w_max,
x_pc_max=x_pc_max,
w_pc_max=w_pc_max,
)
return out
@impl("_C::matmul", "CUDA")
def matmul_cuda(
x: torch.Tensor,
w: torch.Tensor,
out_dtype: torch.dtype,
x_trans: bool = False,
w_trans: bool = True,
alpha: float = 1.0,
beta: float = 0.0,
bias: Optional[torch.Tensor] = None,
x_max: Optional[torch.Tensor] = None,
w_max: Optional[torch.Tensor] = None,
x_pc_max: Optional[torch.Tensor] = None,
w_pc_max: Optional[torch.Tensor] = None,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
dtype=out_dtype,
device=x.device,
)
xtorch_ops.matmul(
x=x.contiguous(),
w=w.contiguous(),
out=out,
x_trans=x_trans,
w_trans=w_trans,
alpha=alpha,
beta=beta,
bias=bias,
x_max=x_max,
w_max=w_max,
x_pc_max=x_pc_max,
w_pc_max=w_pc_max,
)
return out
def _fake_matmul(
x: torch.Tensor,
w: torch.Tensor,
out_dtype: torch.dtype,
x_trans: bool = False,
w_trans: bool = True,
alpha: float = 1.0,
beta: float = 0.0,
bias: Optional[torch.Tensor] = None,
x_max: Optional[torch.Tensor] = None,
w_max: Optional[torch.Tensor] = None,
x_pc_max: Optional[torch.Tensor] = None,
w_pc_max: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty(
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
dtype=out_dtype,
device=x.device,
)
matmul.register_fake(_fake_matmul)
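# Shape note: with w_trans=True (the default) `w` is stored as (N, K) and
# x of shape (M, K) yields out of shape (M, N); with w_trans=False, `w` is
# (K, N). A minimal sketch (assumes the xtorch_ops extension is loaded):
#
#     x = torch.randn(4, 512, device="cuda", dtype=torch.bfloat16)
#     w = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
#     out = torch.ops._C.matmul(x, w, torch.bfloat16)  # (4, 1024)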
##################################################
# ------------------- quant2d --------------------
##################################################
@custom_op("_C::quant2d", mutates_args=())
def quant2d(
x: torch.Tensor,
x_q: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool = False,
) -> None:
xtorch_ops.quant2d(
x=x,
y=x_q,
max=max,
force_sdnn=force_sdnn,
)
@impl("_C::quant2d", "CUDA")
def quant2d_cuda(
x: torch.Tensor,
x_q: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool = False,
) -> None:
xtorch_ops.quant2d(
x=x,
y=x_q,
max=max,
force_sdnn=force_sdnn,
)
def _fake_quant2d(
x: torch.Tensor,
x_q: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool = False,
) -> None:
return None
quant2d.register_fake(_fake_quant2d)
##################################################
# --------------- penalties -----------------
##################################################
@custom_op("_C::apply_repetition_penalties_", mutates_args=())
def apply_repetition_penalties_(
logits: torch.Tensor,
prompt_mask: torch.Tensor,
output_mask: torch.Tensor,
repetition_penalties: torch.Tensor
) -> None:
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, logits.size(1))
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
@impl("_C::apply_repetition_penalties_", "CUDA")
def apply_repetition_penalties_cuda(
logits: torch.Tensor,
prompt_mask: torch.Tensor,
output_mask: torch.Tensor,
repetition_penalties: torch.Tensor
) -> None:
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, logits.size(1))
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
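# Worked example: with penalty 1.2, a seen token with logit 2.0 becomes
# 2.0 / 1.2 ≈ 1.67 and a seen token with logit -2.0 becomes -2.0 * 1.2 = -2.4;
# both moves make the token strictly less likely. Unseen tokens keep a
# penalty of 1.0 and are unchanged.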
##################################################
# --------------- I8_mqa_logits -----------------
##################################################
@custom_op("_C::I8_mqa_logits", mutates_args=())
def I8_mqa_logits(
q: torch.Tensor,
fused_kv_cache: List[torch.Tensor],
weights: torch.Tensor,
context_q_lens: List[torch.Tensor],
context_k_lens: List[torch.Tensor],
logits: torch.Tensor,
clean_logits: bool,
max_seq_q: Optional[int] = 0,
max_seq_k: Optional[int] = 0,
is_causal: Optional[bool] = False,
use_xfa_boost: Optional[bool] = False,
) -> None:
xtorch_ops.I8_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
context_q_lens=context_q_lens,
context_k_lens=context_k_lens,
logits=logits,
clean_logits=clean_logits,
max_seq_q=max_seq_q,
max_seq_k=max_seq_k,
is_causal=is_causal,
use_xfa_boost=use_xfa_boost,
)
return None
@impl("_C::I8_mqa_logits", "CUDA")
def I8_mqa_logits_cuda(
q: torch.Tensor,
fused_kv_cache: List[torch.Tensor],
weights: torch.Tensor,
context_q_lens: List[torch.Tensor],
context_k_lens: List[torch.Tensor],
logits: torch.Tensor,
clean_logits: bool,
max_seq_q: Optional[int] = 0,
max_seq_k: Optional[int] = 0,
is_causal: Optional[bool] = False,
use_xfa_boost: Optional[bool] = False,
) -> None:
xtorch_ops.I8_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
context_q_lens=context_q_lens,
context_k_lens=context_k_lens,
logits=logits,
clean_logits=clean_logits,
max_seq_q=max_seq_q,
max_seq_k=max_seq_k,
is_causal=is_causal,
use_xfa_boost=use_xfa_boost,
)
return None
def _fake_I8_mqa_logits(
q: torch.Tensor,
fused_kv_cache: List[torch.Tensor],
weights: torch.Tensor,
context_q_lens: List[torch.Tensor],
context_k_lens: List[torch.Tensor],
logits: torch.Tensor,
clean_logits: bool,
max_seq_q: Optional[int] = 0,
max_seq_k: Optional[int] = 0,
is_causal: Optional[bool] = False,
use_xfa_boost: Optional[bool] = False,
) -> None:
return None
I8_mqa_logits.register_fake(_fake_I8_mqa_logits)
##################################################
# ------------- I8_paged_mqa_logits --------------
##################################################
@custom_op("_C::I8_paged_mqa_logits", mutates_args=())
def I8_paged_mqa_logits(
q: torch.Tensor,
fused_kv_cache: List[torch.Tensor],
weights: torch.Tensor,
context_lens: List[torch.Tensor],
block_table: torch.Tensor,
max_context_len: int,
clean_logits: bool,
out: torch.Tensor,
use_xfa_boost: Optional[bool] = False) -> None:
xtorch_ops.I8_paged_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
context_lens=context_lens,
block_table=block_table,
max_context_len=max_context_len,
clean_logits=clean_logits,
out=out,
use_xfa_boost=use_xfa_boost)
return None
@impl("_C::I8_paged_mqa_logits", "CUDA")
def I8_paged_mqa_logits_cuda(
q: torch.Tensor,
fused_kv_cache: List[torch.Tensor],
weights: torch.Tensor,
context_lens: List[torch.Tensor],
block_table: torch.Tensor,
max_context_len: int,
clean_logits: bool,
out: torch.Tensor,
use_xfa_boost: Optional[bool] = False) -> None:
xtorch_ops.I8_paged_mqa_logits(
q=q,
fused_kv_cache=fused_kv_cache,
weights=weights,
context_lens=context_lens,
block_table=block_table,
max_context_len=max_context_len,
clean_logits=clean_logits,
out=out,
use_xfa_boost=use_xfa_boost)
return None
def _fake_I8_paged_mqa_logits(
q: torch.Tensor,
fused_kv_cache: List[torch.Tensor],
weights: torch.Tensor,
context_lens: List[torch.Tensor],
block_table: torch.Tensor,
max_context_len: int,
clean_logits: bool,
out: torch.Tensor,
use_xfa_boost: Optional[bool] = False) -> None:
return None
I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits)
##################################################
# ----------- sparse_prefill_fwd_opt -------------
##################################################
@custom_op("_C::sparse_prefill_fwd_opt", mutates_args=())
def sparse_prefill_fwd_opt(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
out: torch.Tensor,
max_logits: torch.Tensor,
lse: torch.Tensor,
sm_scale: float,
qlod_cpu: Optional[torch.Tensor] = None,
qlod_xpu: Optional[torch.Tensor] = None,
kvlod_cpu: Optional[torch.Tensor] = None,
kvlod_xpu: Optional[torch.Tensor] = None,
d_v: Optional[int] = -1,
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False) -> None:
xtorch_ops.sparse_prefill_fwd_opt(
q=q,
kv=kv,
indices=indices,
out=out,
max_logits=max_logits,
lse=lse,
sm_scale=sm_scale,
qlod_cpu=qlod_cpu,
qlod_xpu=qlod_xpu,
kvlod_cpu=kvlod_cpu,
kvlod_xpu=kvlod_xpu,
d_v=d_v,
is_causal=is_causal,
use_xfa_boost=use_xfa_boost)
return None
@impl("_C::sparse_prefill_fwd_opt", "CUDA")
def sparse_prefill_fwd_opt_cuda(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
out: torch.Tensor,
max_logits: torch.Tensor,
lse: torch.Tensor,
sm_scale: float,
qlod_cpu: Optional[torch.Tensor] = None,
qlod_xpu: Optional[torch.Tensor] = None,
kvlod_cpu: Optional[torch.Tensor] = None,
kvlod_xpu: Optional[torch.Tensor] = None,
d_v: Optional[int] = -1,
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False) -> None:
xtorch_ops.sparse_prefill_fwd_opt(
q=q,
kv=kv,
indices=indices,
out=out,
max_logits=max_logits,
lse=lse,
sm_scale=sm_scale,
qlod_cpu=qlod_cpu,
qlod_xpu=qlod_xpu,
kvlod_cpu=kvlod_cpu,
kvlod_xpu=kvlod_xpu,
d_v=d_v,
is_causal=is_causal,
use_xfa_boost=use_xfa_boost)
return None
def _fake_sparse_prefill_fwd_opt(
q: torch.Tensor,
kv: torch.Tensor,
indices: torch.Tensor,
out: torch.Tensor,
max_logits: torch.Tensor,
lse: torch.Tensor,
sm_scale: float,
qlod_cpu: Optional[torch.Tensor] = None,
qlod_xpu: Optional[torch.Tensor] = None,
kvlod_cpu: Optional[torch.Tensor] = None,
kvlod_xpu: Optional[torch.Tensor] = None,
d_v: Optional[int] = -1,
is_causal: Optional[bool] = True,
use_xfa_boost: Optional[bool] = False) -> None:
return None
sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt)
##################################################
# ------------------ fwd_kvcache_mla -------------
##################################################
@custom_op("_C::fwd_kvcache_mla", mutates_args=())
def fwd_kvcache_mla(
q_c: torch.Tensor,
kv_cache: torch.Tensor,
indices: torch.Tensor,
kv_lod_cpu: torch.Tensor,
out: torch.Tensor,
max_logits: torch.Tensor,
p_sums: torch.Tensor,
softmax_scale: float,
max_seq_kv: int,
q_r: Optional[torch.Tensor] = None,
pe_cache: Optional[torch.Tensor] = None,
use_xfa_boost: Optional[bool] = False,
kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
xtorch_ops.fwd_kvcache_mla(
q_c=q_c,
kv_cache=kv_cache,
indices=indices,
kv_lod_cpu=kv_lod_cpu,
out=out,
max_logits=max_logits,
p_sums=p_sums,
softmax_scale=softmax_scale,
max_seq_kv=max_seq_kv,
q_r=q_r,
pe_cache=pe_cache,
use_xfa_boost=use_xfa_boost,
kv_lod_xpu=kv_lod_xpu)
return None
@impl("_C::fwd_kvcache_mla", "CUDA")
def fwd_kvcache_mla_cuda(
q_c: torch.Tensor,
kv_cache: torch.Tensor,
indices: torch.Tensor,
kv_lod_cpu: torch.Tensor,
out: torch.Tensor,
max_logits: torch.Tensor,
p_sums: torch.Tensor,
softmax_scale: float,
max_seq_kv: int,
q_r: Optional[torch.Tensor] = None,
pe_cache: Optional[torch.Tensor] = None,
use_xfa_boost: Optional[bool] = False,
kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
xtorch_ops.fwd_kvcache_mla(
q_c=q_c,
kv_cache=kv_cache,
indices=indices,
kv_lod_cpu=kv_lod_cpu,
out=out,
max_logits=max_logits,
p_sums=p_sums,
softmax_scale=softmax_scale,
max_seq_kv=max_seq_kv,
q_r=q_r,
pe_cache=pe_cache,
use_xfa_boost=use_xfa_boost,
kv_lod_xpu=kv_lod_xpu)
return None
def _fake_fwd_kvcache_mla(
q_c: torch.Tensor,
kv_cache: torch.Tensor,
indices: torch.Tensor,
kv_lod_cpu: torch.Tensor,
out: torch.Tensor,
max_logits: torch.Tensor,
p_sums: torch.Tensor,
softmax_scale: float,
max_seq_kv: int,
q_r: Optional[torch.Tensor] = None,
pe_cache: Optional[torch.Tensor] = None,
use_xfa_boost: Optional[bool] = False,
kv_lod_xpu: Optional[torch.Tensor] = None) -> None:
return None
fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)