2025-12-10 12:05:39 +08:00
|
|
|
|
"""vllm_utils_wrapper.py"""
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
import inspect
|
2026-02-11 12:04:14 +08:00
|
|
|
|
import socket
|
2025-12-10 12:05:39 +08:00
|
|
|
|
import typing
|
2026-02-11 12:04:14 +08:00
|
|
|
|
from types import SimpleNamespace
|
|
|
|
|
|
from typing import Any, Callable, List, Optional, Tuple, Union, get_args, get_origin
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
import vllm.distributed.parallel_state as parallel_state
|
2026-01-08 11:05:48 +08:00
|
|
|
|
import vllm.envs as envs
|
2026-02-11 12:04:14 +08:00
|
|
|
|
import vllm.utils as _orig
|
|
|
|
|
|
from torch.library import Library, register_fake
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2026-02-11 18:52:18 +08:00
|
|
|
|
# Load the compiled Kunlun native extension. First try the installed-package
# path (vllm_kunlun._kunlun); fall back to a relative import (presumably an
# in-tree build). If both fail, warn and continue rather than crash at import.
try:
    import vllm_kunlun._kunlun  # noqa: F401
except ImportError as e:
    try:
        from . import _kunlun  # noqa: F401, F403
    except ImportError:
        # `e` is the original (package-path) failure; the warning reports it.
        print(f"Warning: Failed to load vllm_kunlun native extension: {e}")
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def patch_annotations_for_schema(func):
    """Rewrite PEP 585 ``list[...]`` parameter annotations to ``typing.List``.

    ``torch.library.infer_schema`` on some torch versions only understands
    the ``typing`` generics, so ``list[X]`` and ``Optional[list[X]]``
    annotations are translated in place by installing a patched
    ``__signature__`` on *func*. All other annotations are left untouched
    and *func* itself is returned.
    """
    signature = inspect.signature(func)
    patched = []

    for parameter in signature.parameters.values():
        annotation = parameter.annotation

        if get_origin(annotation) is typing.Union and type(None) in get_args(annotation):
            # Optional[...] — look at the wrapped (non-None) type.
            wrapped = next(a for a in get_args(annotation) if a is not type(None))
            if get_origin(wrapped) is list:  # Optional[list[int]]
                element_args = get_args(wrapped)
                parameter = parameter.replace(
                    annotation=Optional[
                        List[element_args[0] if element_args else typing.Any]
                    ]
                )
        elif get_origin(annotation) is list:
            element_args = get_args(annotation)
            parameter = parameter.replace(
                annotation=List[element_args[0] if element_args else typing.Any]
            )

        patched.append(parameter)

    func.__signature__ = signature.replace(parameters=patched)
    return func
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def supports_custom_op() -> bool:
    """Report whether this torch build exposes ``torch.library.custom_op``.

    The attribute landed in torch 2.4; without it the registration helpers
    in this module cannot define custom ops.
    """
    return getattr(torch.library, "custom_op", None) is not None
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
# Shared FRAGMENT library handle: ops registered via direct_register_custom_op
# default into the "vllm" namespace and stay alive as long as this module.
vllm_lib = Library("vllm", "FRAGMENT")  # noqa
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def direct_register_custom_op(
    op_name: str,
    op_func: Callable,
    mutates_args: Optional[list[str]] = None,
    fake_impl: Optional[Callable] = None,
    target_lib: Optional[Library] = None,
    dispatch_key: str = "CUDA",
    tags: tuple[torch.Tag, ...] = (),
):
    """
    `torch.library.custom_op` can have significant overhead because it
    needs to consider complicated dispatching logic. This function
    directly registers a custom op and dispatches it to the CUDA backend.
    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
    for more details.

    By default, the custom op is registered to the vLLM library. If you
    want to register it to a different library, you can pass the library
    object to the `target_lib` argument.

    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
    library object. If you want to bind the operator to a different library,
    make sure the library object is alive when the operator is used.
    """
    # On torch builds without torch.library.custom_op we silently skip
    # registration (tolerable only off CUDA-like platforms).
    if not supports_custom_op():
        from vllm.platforms import current_platform

        assert not current_platform.is_cuda_alike(), (
            "cuda platform needs torch>=2.4 to support custom op, "
            "chances are you are using an old version of pytorch "
            "or a custom build of pytorch. It is recommended to "
            "use vLLM in a fresh new environment and let it install "
            "the required dependencies."
        )
        return

    if mutates_args is None:
        mutates_args = []

    import torch.library

    if hasattr(torch.library, "infer_schema"):
        # Normalize PEP 585 `list[...]` annotations first so newer torch's
        # schema inference does not reject them.
        patch_annotations_for_schema(op_func)
        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
    else:
        # for pytorch 2.4
        import torch._custom_op.impl

        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)

    # Define the schema, bind the eager impl to `dispatch_key`, and (when
    # provided) attach the fake/meta implementation for tracing.
    my_lib = target_lib or vllm_lib
    my_lib.define(op_name + schema_str, tags=tags)
    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
    if fake_impl is not None:
        my_lib._register_fake(op_name, fake_impl)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def vllm_kunlun_weak_ref_tensor(tensor: Any) -> Any:
    """Return a weak-reference view of *tensor*.

    For a ``torch.Tensor`` this delegates to the native
    ``_kunlun.weak_ref_tensor`` op, producing a tensor that shares the
    same data without keeping the original alive. Any non-tensor value is
    returned unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    return torch.ops._kunlun.weak_ref_tensor(tensor)
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def vllm_kunlun_weak_ref_tensors(
    tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]],
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
    """Weak-reference a single tensor, or every tensor in a list/tuple.

    Raises:
        ValueError: if *tensors* is not a tensor, list, or tuple.
    """
    if isinstance(tensors, torch.Tensor):
        return vllm_kunlun_weak_ref_tensor(tensors)
    if isinstance(tensors, (list, tuple)):
        refs = [vllm_kunlun_weak_ref_tensor(item) for item in tensors]
        return refs if isinstance(tensors, list) else tuple(refs)
    raise ValueError("Invalid type for tensors")
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
|
|
|
|
|
# Running port counter seeded from the VLLM_PORT setting; advanced by
# _get_open_port. NOTE(review): envs.VLLM_PORT may be None when unset —
# confirm it is always configured before _get_open_port runs.
vllm_port = envs.VLLM_PORT
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-08 11:05:48 +08:00
|
|
|
|
def _get_open_port() -> int:
    """Return a free TCP port, preferring the configured ``vllm_port`` range.

    Probes the current module-level ``vllm_port`` by binding it; on success
    that port is returned and the counter advances so the next call hands
    out the following port. If the probe port is busy, fall back to an
    OS-assigned ephemeral port on an IPv6 socket.

    NOTE(review): if ``vllm_port`` is None (VLLM_PORT unset) the bind raises
    TypeError, which is not caught here — confirm callers guarantee a value.
    """
    global vllm_port
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", vllm_port))
            # Bug fix: return the port that was actually verified free and
            # advance the counter for the next caller. The old code
            # incremented first and returned vllm_port + 1 — a port that was
            # never probed — while discarding the one it had just bound.
            port = vllm_port
            vllm_port += 1
            return port
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
# Build a shallow copy of the real ``vllm.utils`` namespace and overlay the
# Kunlun-specific replacements defined above; installed into sys.modules later.
_wrapped = SimpleNamespace(**_orig.__dict__)

_wrapped.direct_register_custom_op = direct_register_custom_op

_wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor

_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors

_wrapped._get_open_port = _get_open_port
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-02-11 12:04:14 +08:00
|
|
|
|
import sys # noqa: E402
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
# Install the patched namespace so any future ``import vllm.utils`` (and
# already-held module references resolved by attribute) sees the overrides.
sys.modules["vllm.utils"] = _wrapped
|
|
|
|
|
|
|
|
|
|
|
|
# Keep handles to the stock GroupCoordinator collectives before they are
# monkey-patched further below.
_original_all_reduce = parallel_state.GroupCoordinator.all_reduce

_original_all_gather = parallel_state.GroupCoordinator.all_gather
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def vllm_kunlun_all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    """In-place all-reduce of *input_* across this coordinator's device group.

    A single-rank group returns the tensor untouched; otherwise the plain
    ``torch.distributed.all_reduce`` runs on ``self.device_group`` and the
    (now reduced) input tensor is returned.
    """
    if self.world_size != 1:
        torch.distributed.all_reduce(input_, group=self.device_group)
    return input_
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def vllm_kunlun_all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """All-gather *input_* across the device group, concatenating along *dim*.

    Every rank contributes its local tensor; the result has
    ``world_size * input_.size(dim)`` entries along *dim*. With a single
    rank the input is returned unchanged.
    """
    world_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert (
        -input_.dim() <= dim < input_.dim()
    ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"

    if dim < 0:
        dim += input_.dim()  # normalize to a non-negative axis

    local_shape = input_.size()
    # Stack all ranks' contributions along a fresh leading axis.
    gathered = torch.empty(
        (world_size,) + local_shape, dtype=input_.dtype, device=input_.device
    )
    torch.distributed.all_gather_into_tensor(gathered, input_, group=self.device_group)
    # Fold the rank axis into the requested dimension and flatten it away.
    gathered = gathered.movedim(0, dim)
    return gathered.reshape(
        local_shape[:dim] + (world_size * local_shape[dim],) + local_shape[dim + 1 :]
    )
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
# Monkey-patch GroupCoordinator so every coordinator uses the plain
# torch.distributed collectives defined above (originals were saved earlier).
parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce

parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-11 12:04:14 +08:00
|
|
|
|
from typing import Optional # noqa: E402
|
|
|
|
|
|
|
|
|
|
|
|
import torch # noqa: E402
|
|
|
|
|
|
from torch.library import custom_op, impl # noqa: E402
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::rms_norm", mutates_args=())
def rms_norm(
    result: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """Schema-only stub registering ``_C::rms_norm``; the body is a no-op.

    NOTE(review): ``result`` reads like an output buffer, yet the op is
    declared with ``mutates_args=()`` — confirm that is intentional.
    """
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::fused_add_rms_norm", mutates_args=())
def fused_add_rms_norm(
    result: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """Schema-only stub registering ``_C::fused_add_rms_norm``; no-op body."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::static_scaled_fp8_quant", mutates_args=())
def static_scaled_fp8_quant(
    result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
    """Schema-only stub registering ``_C::static_scaled_fp8_quant``; no-op body."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::static_scaled_fp8_quant", "CUDA")
def static_scaled_fp8_quant_xpu(
    result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
    """No-op "CUDA"-key dispatch entry for ``_C::static_scaled_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::dynamic_scaled_fp8_quant", mutates_args=())
def dynamic_scaled_fp8_quant(
    result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
    """Schema-only stub registering ``_C::dynamic_scaled_fp8_quant``; no-op body."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::dynamic_scaled_fp8_quant", "CUDA")
def dynamic_scaled_fp8_quant_xpu(
    result: torch.Tensor, input: torch.Tensor, scale: torch.Tensor
) -> None:
    """No-op "CUDA"-key dispatch entry for ``_C::dynamic_scaled_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::dynamic_per_token_scaled_fp8_quant", mutates_args=())
def dynamic_per_token_scaled_fp8_quant(
    result: torch.Tensor,
    input: torch.Tensor,
    scale: torch.Tensor,
    scale_ub: Optional[torch.Tensor],
) -> None:
    """Schema-only stub registering ``_C::dynamic_per_token_scaled_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::dynamic_per_token_scaled_fp8_quant", "CUDA")
def dynamic_per_token_scaled_fp8_quant_xpu(
    result: torch.Tensor,
    input: torch.Tensor,
    scale: torch.Tensor,
    scale_ub: Optional[torch.Tensor],
) -> None:
    """No-op "CUDA"-key dispatch for ``_C::dynamic_per_token_scaled_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::rms_norm_static_fp8_quant", mutates_args=())
def rms_norm_static_fp8_quant(
    result: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """Schema-only stub registering ``_C::rms_norm_static_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::rms_norm_static_fp8_quant", "CUDA")
def rms_norm_static_fp8_quant_xpu(
    result: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """No-op "CUDA"-key dispatch entry for ``_C::rms_norm_static_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::fused_add_rms_norm_static_fp8_quant", mutates_args=())
def fused_add_rms_norm_static_fp8_quant(
    result: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """Schema-only stub registering ``_C::fused_add_rms_norm_static_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::fused_add_rms_norm_static_fp8_quant", "CUDA")
def fused_add_rms_norm_static_fp8_quant_xpu(
    result: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """No-op "CUDA"-key dispatch for ``_C::fused_add_rms_norm_static_fp8_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::rms_norm_dynamic_per_token_quant", mutates_args=())
def rms_norm_dynamic_per_token_quant(
    result: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
    scale_ub: Optional[torch.Tensor],
    residual: Optional[torch.Tensor],
) -> None:
    """Schema-only stub registering ``_C::rms_norm_dynamic_per_token_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::rms_norm_dynamic_per_token_quant", "CUDA")
def rms_norm_dynamic_per_token_quant_xpu(
    result: torch.Tensor,
    input: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
    scale_ub: Optional[torch.Tensor],
    residual: Optional[torch.Tensor],
) -> None:
    """No-op "CUDA"-key dispatch for ``_C::rms_norm_dynamic_per_token_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@custom_op("_C::silu_and_mul_quant", mutates_args=())
def silu_and_mul_quant(
    result: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """Schema-only stub registering ``_C::silu_and_mul_quant``; no-op body."""
    pass
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@impl("_C::silu_and_mul_quant", "CUDA")
def silu_and_mul_quant_xpu(
    result: torch.Tensor,
    input: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    scale: torch.Tensor,
    epsilon: float,
) -> None:
    """No-op "CUDA"-key dispatch entry for ``_C::silu_and_mul_quant``."""
    pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2026-02-13 14:07:10 +08:00
|
|
|
|
import kunlun_ops # noqa: E402
|
2026-02-11 12:04:14 +08:00
|
|
|
|
import torch # noqa: E402
|
|
|
|
|
|
from torch.library import custom_op, impl # noqa: E402
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::add_rmsnorm", mutates_args=())
def add_rmsnorm(
    x: torch.Tensor,
    y: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    interweaved: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    smooth: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    output_max: torch.Tensor = None,
) -> None:
    """Fused residual-add + RMSNorm delegating to ``kunlun_ops.add_rmsnorm``.

    ``output`` (and ``residual_output`` when given) are presumably written
    in place by the kernel — it is opaque here. Returns nothing.

    NOTE(review): ``interweaved``, ``store_output_before_norm``, ``bias``,
    ``smooth`` and ``output_max`` are accepted but never forwarded to the
    kernel — confirm they are intentionally ignored.
    """
    kunlun_ops.add_rmsnorm(
        x,
        y,
        residual_output=residual_output,
        weight=weight,
        eps=eps,
        output=output,
    )
|
|
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::add_rmsnorm", "CUDA")
def add_rmsnorm_cuda(
    x: torch.Tensor,
    y: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    interweaved: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    smooth: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    output_max: torch.Tensor = None,
) -> None:
    """"CUDA"-key dispatch entry for ``_C::add_rmsnorm``.

    Forwards to ``kunlun_ops.add_rmsnorm`` exactly as the schema stub does.
    NOTE(review): ``interweaved``, ``store_output_before_norm``, ``bias``,
    ``smooth`` and ``output_max`` are not forwarded — confirm intended.
    """
    kunlun_ops.add_rmsnorm(
        x,
        y,
        residual_output=residual_output,
        weight=weight,
        eps=eps,
        output=output,
    )
|
|
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::rmsnorm", mutates_args=())
def rmsnorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    interweave: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    output_max: torch.Tensor = None,
) -> None:
    """RMSNorm of ``x`` into ``output`` via ``kunlun_ops.rmsnorm``.

    NOTE(review): only ``x``, ``weight``, ``output`` and ``eps`` reach the
    kernel; the remaining parameters are ignored — confirm intended.
    """
    kunlun_ops.rmsnorm(
        x,
        weight,
        output,
        eps,
    )
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@impl("_C::rmsnorm", "CUDA")
def rmsnorm_cuda(
    x: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    interweave: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    output_max: torch.Tensor = None,
) -> None:
    """"CUDA"-key dispatch entry for ``_C::rmsnorm``; forwards to
    ``kunlun_ops.rmsnorm``. Parameters beyond ``x``/``weight``/``output``/
    ``eps`` are not forwarded (see the schema stub).
    """
    kunlun_ops.rmsnorm(
        x,
        weight,
        output,
        eps,
    )
|
2025-12-10 17:51:24 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2026-02-11 12:04:14 +08:00
|
|
|
|
import torch # noqa: E402
|
2025-12-10 17:51:24 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
|
|
|
|
|
def _fake_rmsnorm(
|
|
|
|
|
|
x,
|
|
|
|
|
|
weight,
|
|
|
|
|
|
output,
|
|
|
|
|
|
eps=1e-5,
|
|
|
|
|
|
interweave=False,
|
|
|
|
|
|
store_output_before_norm=True,
|
|
|
|
|
|
bias=None,
|
|
|
|
|
|
residual_output=None,
|
|
|
|
|
|
output_max=None,
|
|
|
|
|
|
):
|
2025-12-10 17:51:24 +08:00
|
|
|
|
# 设置 shape/dtype,但不返回值
|
2025-12-10 12:05:39 +08:00
|
|
|
|
output.fake_shape = x.shape
|
|
|
|
|
|
output.fake_dtype = x.dtype
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
# Attach the meta implementation so fake-tensor tracing can run the op.
rmsnorm.register_fake(_fake_rmsnorm)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
|
|
|
|
|
def _fake_add_rmsnorm(
|
|
|
|
|
|
x,
|
|
|
|
|
|
y,
|
|
|
|
|
|
weight,
|
|
|
|
|
|
output,
|
|
|
|
|
|
eps=1e-5,
|
|
|
|
|
|
interweaved=False,
|
|
|
|
|
|
store_output_before_norm=True,
|
|
|
|
|
|
bias=None,
|
|
|
|
|
|
smooth=None,
|
|
|
|
|
|
residual_output=None,
|
|
|
|
|
|
output_max=None,
|
|
|
|
|
|
):
|
2025-12-10 12:05:39 +08:00
|
|
|
|
output.fake_shape = x.shape
|
|
|
|
|
|
output.fake_dtype = x.dtype
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
# Attach the meta implementation so fake-tensor tracing can run the op.
add_rmsnorm.register_fake(_fake_add_rmsnorm)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2026-02-28 11:15:50 +08:00
|
|
|
|
@custom_op("_C::gemma_add_rmsnorm", mutates_args=())
def gemma_add_rmsnorm(
    x: torch.Tensor,
    y: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    enable_pdl: bool = False,
    interweaved: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    smooth: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    force_sdnn: bool = False,
) -> None:
    """Gemma-style fused add + RMSNorm.

    All arguments are forwarded verbatim to ``kunlun_ops.gemma_add_rmsnorm``,
    which presumably writes ``output`` (and ``residual_output``) in place.
    """
    kunlun_ops.gemma_add_rmsnorm(
        x,
        y,
        weight=weight,
        output=output,
        eps=eps,
        enable_pdl=enable_pdl,
        interweaved=interweaved,
        store_output_before_norm=store_output_before_norm,
        bias=bias,
        smooth=smooth,
        residual_output=residual_output,
        force_sdnn=force_sdnn,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::gemma_add_rmsnorm", "CUDA")
def gemma_add_rmsnorm_cuda(
    x: torch.Tensor,
    y: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    enable_pdl: bool = False,
    interweaved: bool = False,
    store_output_before_norm: bool = True,
    bias: torch.Tensor = None,
    smooth: torch.Tensor = None,
    residual_output: torch.Tensor = None,
    force_sdnn: bool = False,
) -> None:
    """"CUDA"-key dispatch entry for ``_C::gemma_add_rmsnorm``; forwards all
    arguments verbatim to ``kunlun_ops.gemma_add_rmsnorm``.
    """
    kunlun_ops.gemma_add_rmsnorm(
        x,
        y,
        weight=weight,
        output=output,
        eps=eps,
        enable_pdl=enable_pdl,
        interweaved=interweaved,
        store_output_before_norm=store_output_before_norm,
        bias=bias,
        smooth=smooth,
        residual_output=residual_output,
        force_sdnn=force_sdnn,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fake_gemma_add_rmsnorm(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
y: torch.Tensor,
|
|
|
|
|
|
weight: torch.Tensor,
|
|
|
|
|
|
output: torch.Tensor,
|
|
|
|
|
|
eps: float = 1e-5,
|
|
|
|
|
|
enable_pdl: bool = False,
|
|
|
|
|
|
interweaved: bool = False,
|
|
|
|
|
|
store_output_before_norm: bool = True,
|
|
|
|
|
|
bias: torch.Tensor = None,
|
|
|
|
|
|
smooth: torch.Tensor = None,
|
|
|
|
|
|
residual_output: torch.Tensor = None,
|
|
|
|
|
|
force_sdnn: bool = False,
|
|
|
|
|
|
):
|
|
|
|
|
|
output.fake_shape = x.shape
|
|
|
|
|
|
output.fake_dtype = x.dtype
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Attach the meta implementation so fake-tensor tracing can run the op.
gemma_add_rmsnorm.register_fake(_fake_gemma_add_rmsnorm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@custom_op("_C::gemma_rmsnorm", mutates_args=())
def gemma_rmsnorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    enable_pdl: bool = False,
    interweave: bool = False,
    bias: torch.Tensor = None,
    force_sdnn: bool = False,
) -> None:
    """Gemma-style RMSNorm; all arguments are forwarded positionally to
    ``kunlun_ops.gemma_rmsnorm``, which presumably writes ``output`` in place.
    """
    kunlun_ops.gemma_rmsnorm(
        x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::gemma_rmsnorm", "CUDA")
def gemma_rmsnorm_cuda(
    x: torch.Tensor,
    weight: torch.Tensor,
    output: torch.Tensor,
    eps: float = 1e-5,
    enable_pdl: bool = False,
    interweave: bool = False,
    bias: torch.Tensor = None,
    force_sdnn: bool = False,
) -> None:
    """"CUDA"-key dispatch entry for ``_C::gemma_rmsnorm``; forwards all
    arguments positionally to ``kunlun_ops.gemma_rmsnorm``.
    """
    kunlun_ops.gemma_rmsnorm(
        x, weight, output, eps, enable_pdl, interweave, bias, force_sdnn
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fake_gemma_rmsnorm(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
weight: torch.Tensor,
|
|
|
|
|
|
output: torch.Tensor,
|
|
|
|
|
|
eps: float = 1e-5,
|
|
|
|
|
|
enable_pdl: bool = False,
|
|
|
|
|
|
interweave: bool = False,
|
|
|
|
|
|
bias: torch.Tensor = None,
|
|
|
|
|
|
force_sdnn: bool = False,
|
|
|
|
|
|
):
|
|
|
|
|
|
# 设置 shape/dtype,但不返回值
|
|
|
|
|
|
output.fake_shape = x.shape
|
|
|
|
|
|
output.fake_dtype = x.dtype
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Attach the meta implementation so fake-tensor tracing can run the op.
gemma_rmsnorm.register_fake(_fake_gemma_rmsnorm)
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@custom_op("_C::split_norm_rope_neox", mutates_args=())
def split_norm_rope_neox(
    q_emb: torch.Tensor,
    k_emb: torch.Tensor,
    v_out: torch.Tensor,
    qkv: torch.Tensor,
    rotary_pos_embedding: torch.Tensor,
    q_norm_weight: torch.Tensor,
    k_norm_weight: torch.Tensor,
    positions: torch.Tensor,
    num_tokens: int,
    max_seqlen: int,
    head_num: int,
    kv_head_num: int,
    head_dim: int,
    rotary_dim: int,
    emb_batch_size: int = 1,
) -> None:
    """Fused QKV split + Q/K norm + NeoX-style RoPE via
    ``kunlun_ops.split_norm_rope_neox``; results presumably land in
    ``q_emb``/``k_emb``/``v_out`` in place.

    NOTE(review): ``emb_batch_size`` is accepted but never forwarded to the
    kernel — confirm the kernel default matches.
    """
    kunlun_ops.split_norm_rope_neox(
        q_emb,
        k_emb,
        v_out,
        qkv,
        rotary_pos_embedding,
        q_norm_weight,
        k_norm_weight,
        positions,
        num_tokens,
        max_seqlen,
        head_num,
        kv_head_num,
        head_dim,
        rotary_dim,
    )
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@impl("_C::split_norm_rope_neox", "CUDA")
def split_norm_rope_neox_cuda(
    q_emb: torch.Tensor,
    k_emb: torch.Tensor,
    v_out: torch.Tensor,
    qkv: torch.Tensor,
    rotary_pos_embedding: torch.Tensor,
    q_norm_weight: torch.Tensor,
    k_norm_weight: torch.Tensor,
    positions: torch.Tensor,
    num_tokens: int,
    max_seqlen: int,
    head_num: int,
    kv_head_num: int,
    head_dim: int,
    rotary_dim: int,
    emb_batch_size: int = 1,
) -> None:
    """"CUDA"-key dispatch entry for ``_C::split_norm_rope_neox``.

    NOTE(review): ``emb_batch_size`` is accepted but never forwarded —
    confirm intended.
    """
    kunlun_ops.split_norm_rope_neox(
        q_emb,
        k_emb,
        v_out,
        qkv,
        rotary_pos_embedding,
        q_norm_weight,
        k_norm_weight,
        positions,
        num_tokens,
        max_seqlen,
        head_num,
        kv_head_num,
        head_dim,
        rotary_dim,
    )
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
def _fake_split_norm_rope_neox(
|
|
|
|
|
|
q_emb: torch.Tensor,
|
|
|
|
|
|
k_emb: torch.Tensor,
|
|
|
|
|
|
v_out: torch.Tensor,
|
|
|
|
|
|
qkv: torch.Tensor,
|
|
|
|
|
|
rotary_pos_embedding: torch.Tensor,
|
|
|
|
|
|
q_norm_weight: torch.Tensor,
|
|
|
|
|
|
k_norm_weight: torch.Tensor,
|
|
|
|
|
|
positions: torch.Tensor,
|
|
|
|
|
|
num_tokens: int,
|
|
|
|
|
|
max_seqlen: int,
|
|
|
|
|
|
head_num: int,
|
|
|
|
|
|
kv_head_num: int,
|
|
|
|
|
|
head_dim: int,
|
|
|
|
|
|
rotary_dim: int,
|
2026-01-06 13:51:53 +08:00
|
|
|
|
emb_batch_size: int = 1,
|
|
|
|
|
|
):
|
2025-12-10 17:51:24 +08:00
|
|
|
|
q_emb.fake_shape = q_emb.shape
|
|
|
|
|
|
q_emb.fake_dtype = q_emb.dtype
|
|
|
|
|
|
k_emb.fake_shape = k_emb.shape
|
|
|
|
|
|
k_emb.fake_dtype = k_emb.dtype
|
|
|
|
|
|
v_out.fake_shape = v_out.shape
|
|
|
|
|
|
v_out.fake_dtype = v_out.dtype
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
# Attach the meta implementation so fake-tensor tracing can run the op.
split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox)
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
|
|
|
|
|
# register fake op impl here
|
|
|
|
|
|
# for torch.dynamo
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
# Register a fake (meta) kernel for the native custom_ops::fc_fusion op, but
# only when the loaded extension actually exposes it.
if hasattr(torch.ops.custom_ops, "fc_fusion"):

    @register_fake("custom_ops::fc_fusion")
    def fc_fusion_fake(
        self: torch.Tensor,
        other: torch.Tensor,
        bias: Optional[torch.Tensor],
        self_trans: bool,
        other_trans: bool,
        *,
        alpha: float = 1.0,
        beta: float = 0.0,
        act: int = 1,
        multi_stream: bool = False,
        out: torch.Tensor,
    ) -> None:
        """No-op fake implementation of ``custom_ops::fc_fusion`` for tracing."""
        pass
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-31 11:31:26 +08:00
|
|
|
|
@custom_op("_C::silu_and_mul", mutates_args=())
def silu_and_mul(
    out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
) -> None:
    """SiLU-and-mul (SwiGLU) activation; writes the result into ``out``.

    Dispatches to the Kunlun native ``swiglu`` kernel.  ``axis`` and
    ``turn`` are accepted for schema compatibility but are not forwarded —
    presumably the kernel defaults match; TODO confirm.
    """
    # NOTE(review): the op mutates ``out`` but declares ``mutates_args=()``
    # — verify this is intentional for this torch version.
    kunlun_ops.swiglu(
        x=x,
        y=out,
    )


@impl("_C::silu_and_mul", "CUDA")
def silu_and_mul_cuda(
    out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.swiglu(
        x=x,
        y=out,
    )


def _fake_silu_and_mul(
    out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
):
    # Fake (meta) impl: the real op mutates ``out`` in place and returns
    # nothing, so a no-op suffices.
    return None


silu_and_mul.register_fake(_fake_silu_and_mul)
|
2025-12-10 17:51:24 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
def _swigluoai_ref(
    x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
) -> torch.Tensor:
    """Reference SwiGLU-OAI over interleaved gate/up columns.

    ``gate = x[..., ::2]`` clamped to ``(-inf, limit]``, ``up = x[..., 1::2]``
    clamped to ``[-limit, limit]``; returns ``(up + 1) * gate * sigmoid(gate
    * alpha)``.  Shared by the op, its CUDA binding and its fake so the math
    is defined exactly once.
    """
    gate, up = x[..., ::2], x[..., 1::2]
    gate = gate.clamp(min=None, max=limit)
    up = up.clamp(min=-limit, max=limit)
    glu = gate * torch.sigmoid(gate * alpha)
    return (up + 1) * glu


@custom_op("_C::swigluoai_and_mul", mutates_args=())
def swigluoai_and_mul(
    x: torch.Tensor,
    alpha: float = 1.702,
    limit: float = 7.0,
    axis: int = -1,
    turn: bool = True,
) -> torch.Tensor:
    """PyTorch-native implementation equivalent to forward().

    ``axis`` and ``turn`` are accepted for schema compatibility but unused
    — TODO confirm they can be dropped from the schema.
    """
    return _swigluoai_ref(x, alpha, limit)


@impl("_C::swigluoai_and_mul", "CUDA")
def swigluoai_and_mul_cuda(
    x: torch.Tensor,
    alpha: float = 1.702,
    limit: float = 7.0,
    axis: int = -1,
    turn: bool = True,
) -> torch.Tensor:
    """CUDA-key binding; same PyTorch-native math as the default impl."""
    return _swigluoai_ref(x, alpha, limit)


def _fake_swigluoai_and_mul(
    x: torch.Tensor,
    alpha: float = 1.702,
    limit: float = 7.0,
    axis: int = -1,
    turn: bool = True,
) -> torch.Tensor:
    """Fake (meta) impl: runs the same traceable math so output shape and
    dtype are derived correctly under FakeTensor mode."""
    return _swigluoai_ref(x, alpha, limit)


swigluoai_and_mul.register_fake(_fake_swigluoai_and_mul)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::moe_softmax_topk", mutates_args=())
def moe_softmax_topk(
    x: torch.Tensor,
    normed_score: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    axis: int = -1,
    turn: bool = True,
) -> None:
    """MoE routing: softmax + top-k over router logits ``x``.

    Writes into the pre-allocated ``normed_score`` / ``topk_index`` /
    ``block_statistic`` buffers via the Kunlun kernel.  ``axis`` and
    ``turn`` are schema-compatibility parameters not forwarded to the
    kernel — TODO confirm.
    """
    kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)


@impl("_C::moe_softmax_topk", "CUDA")
def moe_softmax_topk_cuda(
    x: torch.Tensor,
    normed_score: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    axis: int = -1,
    turn: bool = True,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)


def _fake_moe_softmax_topk(
    x: torch.Tensor,
    normed_score: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    axis: int = -1,
    turn: bool = True,
) -> None:
    # Fake (meta) impl: the real op mutates its output buffers in place.
    return None


moe_softmax_topk.register_fake(_fake_moe_softmax_topk)
|
|
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::moe_ffn_block", mutates_args=())
def moe_ffn_block(
    out: torch.Tensor,
    x: torch.Tensor,
    expert_num: int,
    moe_top_k: int,
    gate_w: torch.Tensor,
    inter_w: torch.Tensor,
    output_w: torch.Tensor,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    expert_group_num: Optional[int] = 0,
    topk_group: Optional[int] = 0,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    """Fused MoE FFN block; writes the result into ``out``.

    Routes tokens ``x`` with ``gate_w`` and runs the expert FFNs
    (``inter_w`` → ``output_w``) via the Kunlun kernel.
    NOTE(review): ``w1_bias``/``w2_bias`` are accepted but not forwarded to
    the kernel — confirm the kernel has no bias support before relying on
    them.
    """
    kunlun_ops.moe_ffn_block(
        x=x,
        gate_w=gate_w,
        inter_w=inter_w,
        output_w=output_w,
        expert_num=expert_num,
        moe_top_k=moe_top_k,
        topk_group=topk_group,
        renormalize=renormalize,
        use_grouped_topk=use_grouped_topk,
        expert_group_num=expert_group_num,
        out=out,
    )


@impl("_C::moe_ffn_block", "CUDA")
def moe_ffn_block_cuda(
    out: torch.Tensor,
    x: torch.Tensor,
    expert_num: int,
    moe_top_k: int,
    gate_w: torch.Tensor,
    inter_w: torch.Tensor,
    output_w: torch.Tensor,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    expert_group_num: Optional[int] = 0,
    topk_group: Optional[int] = 0,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_ffn_block(
        x=x,
        gate_w=gate_w,
        inter_w=inter_w,
        output_w=output_w,
        expert_num=expert_num,
        moe_top_k=moe_top_k,
        topk_group=topk_group,
        renormalize=renormalize,
        use_grouped_topk=use_grouped_topk,
        expert_group_num=expert_group_num,
        out=out,
    )


def _fake_moe_ffn_block(
    out: torch.Tensor,
    x: torch.Tensor,
    expert_num: int,
    moe_top_k: int,
    gate_w: torch.Tensor,
    inter_w: torch.Tensor,
    output_w: torch.Tensor,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    expert_group_num: Optional[int] = 0,
    topk_group: Optional[int] = 0,
    # BUGFIX: the fake must match the op schema; ``w1_bias``/``w2_bias``
    # were missing, which breaks the fake when callers pass those args.
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
):
    # Fake (meta) impl: the real op mutates ``out`` in place.
    return None


moe_ffn_block.register_fake(_fake_moe_ffn_block)
|
|
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
|
2025-12-10 12:05:39 +08:00
|
|
|
|
@custom_op("_C::moe_ffn_per_token_block", mutates_args=())
def moe_ffn_per_token_block(
    x: torch.Tensor,
    inter_weight: torch.Tensor,
    inter_scale: torch.Tensor,
    outer_weight: torch.Tensor,
    outer_scale: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    linear_weights: Optional[torch.Tensor] = None,
    expert_map: Optional[torch.Tensor] = None,
    activation: str = "silu",
    output: Optional[torch.Tensor] = None,
    use_expert_parallel: bool = False,
    ep_size: int = 1,
    ep_rank: int = 0,
) -> None:
    """Per-token-block-quantized MoE FFN; writes the result into ``output``.

    Maps vLLM-style argument names onto the Kunlun kernel's names
    (``linear_weights`` → ``gate_weight``, ``global_num_experts`` →
    ``expert_num``, ``top_k`` → ``moe_top_k``, ``activation`` →
    ``act_type``).
    NOTE(review): ``expert_map`` is accepted but not forwarded to the
    kernel — confirm expert-parallel remapping is handled via
    ``ep_size``/``ep_rank`` instead.
    """
    kunlun_ops.moe_ffn_per_token_block(
        x=x,
        inter_weight=inter_weight,
        inter_scale=inter_scale,
        outer_weight=outer_weight,
        outer_scale=outer_scale,
        gate_weight=linear_weights,
        expert_num=global_num_experts,
        moe_top_k=top_k,
        act_type=activation,
        use_expert_parallel=use_expert_parallel,
        ep_size=ep_size,
        ep_rank=ep_rank,
        out=output,
    )


@impl("_C::moe_ffn_per_token_block", "CUDA")
def moe_ffn_per_token_block_cuda(
    x: torch.Tensor,
    inter_weight: torch.Tensor,
    inter_scale: torch.Tensor,
    outer_weight: torch.Tensor,
    outer_scale: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    linear_weights: Optional[torch.Tensor] = None,
    expert_map: Optional[torch.Tensor] = None,
    activation: str = "silu",
    output: Optional[torch.Tensor] = None,
    use_expert_parallel: bool = False,
    ep_size: int = 1,
    ep_rank: int = 0,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_ffn_per_token_block(
        x=x,
        inter_weight=inter_weight,
        inter_scale=inter_scale,
        outer_weight=outer_weight,
        outer_scale=outer_scale,
        gate_weight=linear_weights,
        expert_num=global_num_experts,
        moe_top_k=top_k,
        act_type=activation,
        use_expert_parallel=use_expert_parallel,
        ep_size=ep_size,
        ep_rank=ep_rank,
        out=output,
    )


def _fake_moe_ffn_per_token_block(
    x: torch.Tensor,
    inter_weight: torch.Tensor,
    inter_scale: torch.Tensor,
    outer_weight: torch.Tensor,
    outer_scale: torch.Tensor,
    top_k: int,
    global_num_experts: int,
    linear_weights: Optional[torch.Tensor] = None,
    expert_map: Optional[torch.Tensor] = None,
    activation: str = "silu",
    output: Optional[torch.Tensor] = None,
    use_expert_parallel: bool = False,
    ep_size: int = 1,
    ep_rank: int = 0,
) -> None:
    # BUGFIX: the previous fake did ``output.copy_(x)`` — a real data op
    # inside a meta implementation, which also assumed ``output`` and ``x``
    # share a shape.  The real op only mutates ``output`` in place, so the
    # fake is a no-op like every other fake in this module.
    return None


# Register the fake implementation
moe_ffn_per_token_block.register_fake(_fake_moe_ffn_per_token_block)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@custom_op("_C::rotary_embedding", mutates_args=())
def rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Apply rotary position embedding to ``query``/``key`` in place.

    Mirrors vLLM's ``rotary_embedding`` custom-op signature and dispatches
    to the Kunlun kernel; ``is_neox`` selects the NeoX (rotate-half)
    layout.
    """
    kunlun_ops.rotary_embedding(
        positions=positions,
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=is_neox,
    )


@impl("_C::rotary_embedding", "CUDA")
def rotary_embedding_cuda(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.rotary_embedding(
        positions=positions,
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=is_neox,
    )


def _fake_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    # Fake (meta) impl: the real op mutates ``query``/``key`` in place.
    return None


rotary_embedding.register_fake(_fake_rotary_embedding)
|
|
|
|
|
|
|
2025-12-31 10:16:33 +08:00
|
|
|
|
|
|
|
|
|
|
@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
def gemm_I8_I8_bf16_nt(
    x_q: torch.Tensor,
    x_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out: torch.Tensor,
) -> None:
    """int8 x int8 -> bf16 GEMM (NT layout); writes the result into ``out``.

    Both operands are passed to the Kunlun kernel as
    ``(quantized tensor, per-channel scale)`` pairs.
    """
    kunlun_ops.gemm_I8_I8_bf16_nt(
        lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
    )


@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
def gemm_I8_I8_bf16_nt_cuda(
    x_q: torch.Tensor,
    x_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out: torch.Tensor,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.gemm_I8_I8_bf16_nt(
        lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
    )


def _fake_gemm_I8_I8_bf16_nt(
    x_q: torch.Tensor,
    x_scale: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    out: torch.Tensor,
) -> None:
    # Fake (meta) impl: the real op mutates ``out`` in place.
    return None


gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
def moe_softmax_topk_norm(
    x: torch.Tensor,
    normed_score: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    stable: bool = True,
) -> None:
    """MoE routing: softmax + top-k with renormalized scores.

    Writes into the pre-allocated ``normed_score`` / ``topk_index`` /
    ``block_statistic`` buffers; ``stable`` is forwarded to the kernel
    (presumably stable top-k tie-breaking — TODO confirm).
    """
    kunlun_ops.moe_softmax_topk_norm(
        x, normed_score, topk_index, block_statistic, stable
    )


@impl("_C::moe_softmax_topk_norm", "CUDA")
def moe_softmax_topk_norm_cuda(
    x: torch.Tensor,
    normed_score: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    stable: bool = True,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_softmax_topk_norm(
        x, normed_score, topk_index, block_statistic, stable
    )


def _fake_moe_softmax_topk_norm(
    x: torch.Tensor,
    normed_score: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    stable: bool = True,
) -> None:
    # Fake (meta) impl: the real op mutates its output buffers in place.
    return None


moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm)
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@custom_op("_C::gen_block_statistic", mutates_args=())
def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None:
    """Build the per-block expert statistic from routed ``topk_ids``.

    Writes into the pre-allocated ``block_statistic`` buffer via the
    Kunlun kernel.
    """
    kunlun_ops.gen_block_statistic(topk_ids, block_statistic)


@impl("_C::gen_block_statistic", "CUDA")
def gen_block_statistic_cuda(
    topk_ids: torch.Tensor, block_statistic: torch.Tensor
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.gen_block_statistic(topk_ids, block_statistic)


def fake_gen_block_statistic(
    topk_ids: torch.Tensor, block_statistic: torch.Tensor
) -> None:
    # Fake (meta) impl: the real op mutates ``block_statistic`` in place.
    return None


gen_block_statistic.register_fake(fake_gen_block_statistic)
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@custom_op("_C::moe_pre_sorted", mutates_args=())
def moe_pre_sorted(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    moe_expand: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    index_have_neg: bool = False,
) -> None:
    """Pre-sort/expand tokens by expert before the grouped MoE GEMM.

    Fills the pre-allocated ``moe_expand`` / ``moe_index`` / ``expert_m`` /
    ``sorted_tokens_num_lod`` buffers via the Kunlun kernel.
    NOTE(review): ``index_have_neg`` is accepted but not forwarded to the
    kernel — confirm whether the kernel needs it or the parameter is dead.
    """
    kunlun_ops.moe_pre_sorted(
        x,
        topk_index,
        block_statistic,
        moe_expand,
        moe_index,
        expert_m,
        sorted_tokens_num_lod,
    )


@impl("_C::moe_pre_sorted", "CUDA")
def moe_pre_sorted_cuda(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    moe_expand: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    index_have_neg: bool = False,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_pre_sorted(
        x,
        topk_index,
        block_statistic,
        moe_expand,
        moe_index,
        expert_m,
        sorted_tokens_num_lod,
    )


def fake_moe_pre_sorted(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    block_statistic: torch.Tensor,
    moe_expand: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    index_have_neg: bool = False,
) -> None:
    # Fake (meta) impl: the real op mutates its output buffers in place.
    return None


moe_pre_sorted.register_fake(fake_moe_pre_sorted)
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@custom_op("_C::moe_fc", mutates_args=())
def moe_fc(
    x: torch.Tensor,
    weight: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    sorted_tokens_idx: torch.Tensor,
    moe_topk: int,
    y: torch.Tensor,
    act: Optional[str] = None,
    x_perchannel_max: Optional[torch.Tensor] = None,
    w_perchannel_max: Optional[torch.Tensor] = None,
    topk_ids: Optional[torch.Tensor] = None,
    topk_w: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    tgemm_type: Optional[str] = None,
    tweight_type: Optional[str] = None,
    scale_n: Optional[int] = 0,
    scale_k: Optional[int] = 0,
    use_pack_int4: Optional[bool] = False,
    sort_mode: Optional[bool] = True,
) -> None:
    """Grouped MoE fully-connected layer; writes the result into ``y``.

    Straight keyword-for-keyword pass-through to the Kunlun ``moe_fc``
    kernel, operating on tokens pre-sorted by expert
    (``sorted_tokens_num_lod`` / ``sorted_tokens_idx``).  Optional
    quantization inputs (``*_perchannel_max``, ``scale_n``/``scale_k``,
    ``use_pack_int4``) and fused activation ``act`` are forwarded as-is.
    """
    kunlun_ops.moe_fc(
        x=x,
        weight=weight,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=moe_topk,
        y=y,
        act=act,
        x_perchannel_max=x_perchannel_max,
        w_perchannel_max=w_perchannel_max,
        topk_ids=topk_ids,
        topk_w=topk_w,
        bias=bias,
        tgemm_type=tgemm_type,
        tweight_type=tweight_type,
        scale_n=scale_n,
        scale_k=scale_k,
        use_pack_int4=use_pack_int4,
        sort_mode=sort_mode,
    )


@impl("_C::moe_fc", "CUDA")
def moe_fc_cuda(
    x: torch.Tensor,
    weight: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    sorted_tokens_idx: torch.Tensor,
    moe_topk: int,
    y: torch.Tensor,
    act: Optional[str] = None,
    x_perchannel_max: Optional[torch.Tensor] = None,
    w_perchannel_max: Optional[torch.Tensor] = None,
    topk_ids: Optional[torch.Tensor] = None,
    topk_w: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    tgemm_type: Optional[str] = None,
    tweight_type: Optional[str] = None,
    scale_n: Optional[int] = 0,
    scale_k: Optional[int] = 0,
    use_pack_int4: Optional[bool] = False,
    sort_mode: Optional[bool] = True,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_fc(
        x=x,
        weight=weight,
        sorted_tokens_num_lod=sorted_tokens_num_lod,
        sorted_tokens_idx=sorted_tokens_idx,
        moe_topk=moe_topk,
        y=y,
        act=act,
        x_perchannel_max=x_perchannel_max,
        w_perchannel_max=w_perchannel_max,
        topk_ids=topk_ids,
        topk_w=topk_w,
        bias=bias,
        tgemm_type=tgemm_type,
        tweight_type=tweight_type,
        scale_n=scale_n,
        scale_k=scale_k,
        use_pack_int4=use_pack_int4,
        sort_mode=sort_mode,
    )


def fake_moe_fc(
    x: torch.Tensor,
    weight: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    sorted_tokens_idx: torch.Tensor,
    moe_topk: int,
    y: torch.Tensor,
    act: Optional[str] = None,
    x_perchannel_max: Optional[torch.Tensor] = None,
    w_perchannel_max: Optional[torch.Tensor] = None,
    topk_ids: Optional[torch.Tensor] = None,
    topk_w: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    tgemm_type: Optional[str] = None,
    tweight_type: Optional[str] = None,
    scale_n: Optional[int] = 0,
    scale_k: Optional[int] = 0,
    use_pack_int4: Optional[bool] = False,
    sort_mode: Optional[bool] = True,
) -> None:
    # Fake (meta) impl: the real op mutates ``y`` in place.
    return None


moe_fc.register_fake(fake_moe_fc)
|
2025-12-10 12:05:39 +08:00
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2025-12-10 17:51:24 +08:00
|
|
|
|
@custom_op("_C::moe_post", mutates_args=())
def moe_post(
    x: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    dequant_scale: torch.Tensor,
    y: torch.Tensor,
) -> None:
    """MoE post-processing: gather expert outputs back to token order,
    apply routing weights (``normed_scale``) and dequantization
    (``dequant_scale``), writing into ``y``.
    """
    kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)


@impl("_C::moe_post", "CUDA")
def moe_post_cuda(
    x: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    dequant_scale: torch.Tensor,
    y: torch.Tensor,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)


def fake_moe_post(
    x: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    dequant_scale: torch.Tensor,
    y: torch.Tensor,
) -> None:
    # Fake (meta) impl: the real op mutates ``y`` in place.
    return None


moe_post.register_fake(fake_moe_post)
|
2025-12-24 13:45:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-12-31 10:16:33 +08:00
|
|
|
|
@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
def moe_sigmoid_group_topk_norm(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    norm_score: torch.Tensor,
    block_static: torch.Tensor,
    bias: torch.Tensor,
    scale: float,
    n_group: int,
    topk_group: int,
) -> None:
    """Sigmoid + grouped top-k routing with renormalization
    (DeepSeek-style ``n_group``/``topk_group`` selection).

    Writes into the pre-allocated ``topk_index`` / ``norm_score`` /
    ``block_static`` buffers via the Kunlun kernel; note the kernel takes
    its arguments by keyword in a different order than this signature.
    """
    kunlun_ops.moe_sigmoid_group_topk_norm(
        x=x,
        norm_score=norm_score,
        topk_index=topk_index,
        block_static=block_static,
        bias=bias,
        n_group=n_group,
        topk_group=topk_group,
        scale=scale,
    )


@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
def moe_sigmoid_group_topk_norm_cuda(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    norm_score: torch.Tensor,
    block_static: torch.Tensor,
    bias: torch.Tensor,
    scale: float,
    n_group: int,
    topk_group: int,
) -> None:
    # CUDA-dispatch-key binding: same Kunlun kernel as the default impl.
    kunlun_ops.moe_sigmoid_group_topk_norm(
        x=x,
        norm_score=norm_score,
        topk_index=topk_index,
        block_static=block_static,
        bias=bias,
        n_group=n_group,
        topk_group=topk_group,
        scale=scale,
    )


def _fake_moe_sigmoid_group_topk_norm(
    x: torch.Tensor,
    topk_index: torch.Tensor,
    norm_score: torch.Tensor,
    block_static: torch.Tensor,
    bias: torch.Tensor,
    scale: float,
    n_group: int,
    topk_group: int,
) -> None:
    # Fake (meta) impl: the real op mutates its output buffers in place.
    return None


moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-12-24 13:45:55 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
# --------------- awq_dequantize -----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::awq_dequantize", mutates_args=())
|
|
|
|
|
|
def awq_dequantize(
|
|
|
|
|
|
qweight: torch.Tensor,
|
|
|
|
|
|
scales: torch.Tensor,
|
|
|
|
|
|
zeros: torch.Tensor,
|
|
|
|
|
|
quant_type: int = 0,
|
|
|
|
|
|
align_type: int = 1,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
weight = torch.empty(
|
|
|
|
|
|
(qweight.shape[0], qweight.shape[1] * 8),
|
|
|
|
|
|
dtype=torch.float16,
|
|
|
|
|
|
device=qweight.device,
|
|
|
|
|
|
)
|
|
|
|
|
|
group_m = int(qweight.shape[0] / scales.shape[0])
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.awq_dequantize(
|
2025-12-24 13:45:55 +08:00
|
|
|
|
qweight=qweight,
|
|
|
|
|
|
scales=scales,
|
|
|
|
|
|
zeros=zeros,
|
|
|
|
|
|
weight=weight,
|
|
|
|
|
|
group_m=group_m,
|
|
|
|
|
|
quant_type=quant_type,
|
|
|
|
|
|
align_type=align_type,
|
|
|
|
|
|
)
|
|
|
|
|
|
return weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::awq_dequantize", "CUDA")
|
|
|
|
|
|
def awq_dequantize_cuda(
|
|
|
|
|
|
qweight: torch.Tensor,
|
|
|
|
|
|
scales: torch.Tensor,
|
|
|
|
|
|
zeros: torch.Tensor,
|
|
|
|
|
|
quant_type: int = 0,
|
|
|
|
|
|
align_type: int = 1,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
weight = torch.empty(
|
|
|
|
|
|
(qweight.shape[0], qweight.shape[1] * 8),
|
|
|
|
|
|
dtype=torch.float16,
|
|
|
|
|
|
device=qweight.device,
|
|
|
|
|
|
)
|
|
|
|
|
|
group_m = int(qweight.shape[0] / scales.shape[0])
|
2026-02-13 14:07:10 +08:00
|
|
|
|
kunlun_ops.awq_dequantize(
|
2025-12-24 13:45:55 +08:00
|
|
|
|
qweight=qweight,
|
|
|
|
|
|
scales=scales,
|
|
|
|
|
|
zeros=zeros,
|
|
|
|
|
|
weight=weight,
|
|
|
|
|
|
group_m=group_m,
|
|
|
|
|
|
quant_type=quant_type,
|
|
|
|
|
|
align_type=align_type,
|
|
|
|
|
|
)
|
|
|
|
|
|
return weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fake_awq_dequantize(
|
|
|
|
|
|
qweight: torch.Tensor,
|
|
|
|
|
|
scales: torch.Tensor,
|
|
|
|
|
|
zeros: torch.Tensor,
|
|
|
|
|
|
quant_type: int = 0,
|
|
|
|
|
|
align_type: int = 1,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
weight = torch.empty(
|
|
|
|
|
|
(qweight.shape[0], qweight.shape[1] * 8),
|
|
|
|
|
|
dtype=torch.float16,
|
|
|
|
|
|
device=qweight.device,
|
|
|
|
|
|
)
|
|
|
|
|
|
return weight
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the shape-only fake so torch.compile / meta tracing can trace the op.
awq_dequantize.register_fake(_fake_awq_dequantize)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ------------------ awq_gemm -------------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::awq_gemm", mutates_args=())
|
|
|
|
|
|
def awq_gemm(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
qweight: torch.Tensor,
|
|
|
|
|
|
scale: torch.Tensor,
|
|
|
|
|
|
zeros: torch.Tensor,
|
|
|
|
|
|
align_type: int = 1,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty(
|
|
|
|
|
|
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
|
|
|
|
|
)
|
|
|
|
|
|
group_size = int(qweight.shape[0] / scale.shape[0])
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.awq_gemm(
|
2025-12-24 13:45:55 +08:00
|
|
|
|
x=x,
|
|
|
|
|
|
w=qweight,
|
|
|
|
|
|
scale=scale,
|
|
|
|
|
|
zeros=zeros,
|
|
|
|
|
|
out=out,
|
|
|
|
|
|
align_type=align_type,
|
|
|
|
|
|
group_size=group_size,
|
|
|
|
|
|
)
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::awq_gemm", "CUDA")
|
|
|
|
|
|
def awq_gemm_cuda(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
qweight: torch.Tensor,
|
|
|
|
|
|
scale: torch.Tensor,
|
|
|
|
|
|
zeros: torch.Tensor,
|
|
|
|
|
|
align_type: int = 1,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty(
|
|
|
|
|
|
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
|
|
|
|
|
)
|
|
|
|
|
|
group_size = int(qweight.shape[0] / scale.shape[0])
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.awq_gemm(
|
2025-12-24 13:45:55 +08:00
|
|
|
|
x=x,
|
|
|
|
|
|
w=qweight,
|
|
|
|
|
|
scale=scale,
|
|
|
|
|
|
zeros=zeros,
|
|
|
|
|
|
out=out,
|
|
|
|
|
|
align_type=align_type,
|
|
|
|
|
|
group_size=group_size,
|
|
|
|
|
|
)
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fake_awq_gemm(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
qweight: torch.Tensor,
|
|
|
|
|
|
scale: torch.Tensor,
|
|
|
|
|
|
zeros: torch.Tensor,
|
|
|
|
|
|
align_type: int = 1,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty(
|
|
|
|
|
|
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
|
|
|
|
|
)
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the shape-only fake so torch.compile / meta tracing can trace the op.
awq_gemm.register_fake(_fake_awq_gemm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ---------------- gptq_shuffle ------------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::gptq_shuffle", mutates_args=())
|
|
|
|
|
|
def gptq_shuffle(
|
|
|
|
|
|
q_weight: torch.Tensor,
|
|
|
|
|
|
q_perm: torch.Tensor,
|
|
|
|
|
|
bit: int,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
2025-12-24 13:45:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::gptq_shuffle", "CUDA")
|
|
|
|
|
|
def gptq_shuffle_cuda(
|
|
|
|
|
|
q_weight: torch.Tensor,
|
|
|
|
|
|
q_perm: torch.Tensor,
|
|
|
|
|
|
bit: int,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
2025-12-24 13:45:55 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fake_gptq_shuffle(
|
|
|
|
|
|
q_weight: torch.Tensor,
|
|
|
|
|
|
q_perm: torch.Tensor,
|
|
|
|
|
|
bit: int,
|
|
|
|
|
|
) -> None:
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-05 22:55:35 +08:00
|
|
|
|
# Register the no-op fake so torch.compile / meta tracing can trace the op.
gptq_shuffle.register_fake(_fake_gptq_shuffle)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2026-01-05 22:55:35 +08:00
|
|
|
|
##################################################
|
2026-01-06 13:51:53 +08:00
|
|
|
|
# ------------- concat_and_cache_mla -------------
|
2026-01-05 22:55:35 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::concat_and_cache_mla", mutates_args=())
|
|
|
|
|
|
def concat_and_cache_mla(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
kv_c: torch.Tensor, # [num_tokens, kv_lora_rank]
|
|
|
|
|
|
k_pe: torch.Tensor, # [num_tokens, pe_dim]
|
|
|
|
|
|
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
|
|
|
|
|
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
2026-01-05 22:55:35 +08:00
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.concat_and_cache_mla(
|
2026-01-05 22:55:35 +08:00
|
|
|
|
kv_c=kv_c,
|
|
|
|
|
|
k_pe=k_pe,
|
|
|
|
|
|
slot_mapping=slot_mapping,
|
|
|
|
|
|
kv_cache=kv_cache,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2026-01-05 22:55:35 +08:00
|
|
|
|
@impl("_C::concat_and_cache_mla", "CUDA")
|
|
|
|
|
|
def concat_and_cache_mla_cuda(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
kv_c: torch.Tensor, # [num_tokens, kv_lora_rank]
|
|
|
|
|
|
k_pe: torch.Tensor, # [num_tokens, pe_dim]
|
|
|
|
|
|
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
|
|
|
|
|
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
2026-01-05 22:55:35 +08:00
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.concat_and_cache_mla(
|
2026-01-05 22:55:35 +08:00
|
|
|
|
kv_c=kv_c,
|
|
|
|
|
|
k_pe=k_pe,
|
|
|
|
|
|
slot_mapping=slot_mapping,
|
|
|
|
|
|
kv_cache=kv_cache,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
2026-01-05 22:55:35 +08:00
|
|
|
|
def _fake_concat_and_cache_mla(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
kv_c: torch.Tensor, # [num_tokens, kv_lora_rank]
|
|
|
|
|
|
k_pe: torch.Tensor, # [num_tokens, pe_dim]
|
|
|
|
|
|
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
|
|
|
|
|
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
2026-01-05 22:55:35 +08:00
|
|
|
|
) -> None:
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
|
|
|
|
|
# Register the no-op fake so torch.compile / meta tracing can trace the op.
concat_and_cache_mla.register_fake(_fake_concat_and_cache_mla)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
######################################################
|
|
|
|
|
|
# -------------- scaled_int8_quant -------------------
|
|
|
|
|
|
######################################################
|
|
|
|
|
|
@custom_op("_C::scaled_int8_quant", mutates_args=())
|
|
|
|
|
|
def scaled_int8_quant(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
scale: torch.Tensor,
|
|
|
|
|
|
azp: Optional[torch.Tensor] = None,
|
|
|
|
|
|
symmetric: bool = True,
|
|
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
|
|
|
|
|
|
static = False
|
|
|
|
|
|
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
|
|
|
|
|
|
if scale is not None: # static
|
|
|
|
|
|
static = True
|
|
|
|
|
|
torch.ops.xspeedgate_ops.static_scaled_int8_quant(x_q, x, scale, azp)
|
|
|
|
|
|
else: # dynamic
|
|
|
|
|
|
scale = torch.empty(
|
|
|
|
|
|
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
|
|
|
|
|
)
|
|
|
|
|
|
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
|
|
|
|
|
if symmetric:
|
|
|
|
|
|
# NOTE: For quant2d ops, scale represents max.
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
2026-01-06 13:51:53 +08:00
|
|
|
|
else:
|
|
|
|
|
|
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
|
|
|
|
|
x_q, x.contiguous(), scale, azp
|
|
|
|
|
|
)
|
|
|
|
|
|
return x_q, scale, azp, static
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::scaled_int8_quant", "CUDA")
|
|
|
|
|
|
def scaled_int8_quant_cuda(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
scale: torch.Tensor,
|
|
|
|
|
|
azp: Optional[torch.Tensor] = None,
|
|
|
|
|
|
symmetric: bool = True,
|
|
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
|
|
|
|
|
|
static = False
|
|
|
|
|
|
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
|
|
|
|
|
|
if scale is not None: # static
|
|
|
|
|
|
static = True
|
|
|
|
|
|
torch.ops.xspeedgate_ops.static_scaled_int8_quant(x_q, x, scale, azp)
|
|
|
|
|
|
else: # dynamic
|
|
|
|
|
|
scale = torch.empty(
|
|
|
|
|
|
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
|
|
|
|
|
)
|
|
|
|
|
|
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
|
|
|
|
|
if symmetric:
|
|
|
|
|
|
# NOTE: For quant2d ops, scale represents max.
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
2026-01-06 13:51:53 +08:00
|
|
|
|
else:
|
|
|
|
|
|
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
|
|
|
|
|
x_q, x.contiguous(), scale, azp
|
|
|
|
|
|
)
|
|
|
|
|
|
return x_q, scale, azp, static
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-06 16:07:29 +08:00
|
|
|
|
def _fake_scaled_int8_quant(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
scale: torch.Tensor,
|
|
|
|
|
|
azp: Optional[torch.Tensor] = None,
|
|
|
|
|
|
symmetric: bool = True,
|
|
|
|
|
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, bool]:
|
2026-01-06 16:07:29 +08:00
|
|
|
|
x_q = torch.empty_like(x, dtype=torch.int8, device=x.device)
|
|
|
|
|
|
scale = torch.empty(
|
|
|
|
|
|
(x.numel() // x.shape[-1], 1), device=x.device, dtype=torch.float32
|
|
|
|
|
|
)
|
|
|
|
|
|
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
2026-01-06 13:51:53 +08:00
|
|
|
|
return x_q, scale, azp, False
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-06 16:07:29 +08:00
|
|
|
|
# Register the shape-only fake so torch.compile / meta tracing can trace the op.
scaled_int8_quant.register_fake(_fake_scaled_int8_quant)
|
2026-01-06 13:51:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
######################################################
|
|
|
|
|
|
# ---------------- cutlass_scaled_mm -----------------
|
|
|
|
|
|
######################################################
|
|
|
|
|
|
@custom_op("_C::cutlass_scaled_mm", mutates_args=())
|
|
|
|
|
|
def cutlass_scaled_mm(
|
|
|
|
|
|
a: torch.Tensor,
|
|
|
|
|
|
b: torch.Tensor,
|
|
|
|
|
|
scale_a: torch.Tensor,
|
|
|
|
|
|
scale_b: torch.Tensor,
|
|
|
|
|
|
out_dtype: torch.dtype,
|
|
|
|
|
|
bias: Optional[torch.Tensor] = None,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
2026-01-06 20:52:12 +08:00
|
|
|
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
|
|
|
|
|
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
|
|
|
|
|
|
)
|
2026-01-06 13:51:53 +08:00
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::cutlass_scaled_mm", "CUDA")
|
|
|
|
|
|
def cutlass_scaled_mm_cuda(
|
|
|
|
|
|
a: torch.Tensor,
|
|
|
|
|
|
b: torch.Tensor,
|
|
|
|
|
|
scale_a: torch.Tensor,
|
|
|
|
|
|
scale_b: torch.Tensor,
|
|
|
|
|
|
out_dtype: torch.dtype,
|
|
|
|
|
|
bias: Optional[torch.Tensor] = None,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
2026-01-06 20:52:12 +08:00
|
|
|
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm(
|
|
|
|
|
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, bias
|
|
|
|
|
|
)
|
2026-01-06 13:51:53 +08:00
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fake_cutlass_scaled_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake (meta) kernel for ``_C::cutlass_scaled_mm``: shape/dtype only."""
    out_shape = (a.shape[0], b.shape[1])
    return torch.empty(out_shape, dtype=out_dtype, device=a.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the shape-only fake so torch.compile / meta tracing can trace the op.
cutlass_scaled_mm.register_fake(fake_cutlass_scaled_mm)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
######################################################
|
|
|
|
|
|
# ------------ cutlass_scaled_mm_azp -----------------
|
|
|
|
|
|
######################################################
|
|
|
|
|
|
@custom_op("_C::cutlass_scaled_mm_azp", mutates_args=())
|
|
|
|
|
|
def cutlass_scaled_mm_azp(
|
|
|
|
|
|
a: torch.Tensor,
|
|
|
|
|
|
b: torch.Tensor,
|
|
|
|
|
|
scale_a: torch.Tensor,
|
|
|
|
|
|
scale_b: torch.Tensor,
|
|
|
|
|
|
out_dtype: torch.dtype,
|
|
|
|
|
|
azp_adj: torch.Tensor,
|
|
|
|
|
|
azp: Optional[torch.Tensor] = None,
|
|
|
|
|
|
bias: Optional[torch.Tensor] = None,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
|
|
|
|
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
2026-01-06 20:52:12 +08:00
|
|
|
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
|
2026-01-06 13:51:53 +08:00
|
|
|
|
)
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::cutlass_scaled_mm_azp", "CUDA")
|
|
|
|
|
|
def cutlass_scaled_mm_azp_cuda(
|
|
|
|
|
|
a: torch.Tensor,
|
|
|
|
|
|
b: torch.Tensor,
|
|
|
|
|
|
scale_a: torch.Tensor,
|
|
|
|
|
|
scale_b: torch.Tensor,
|
|
|
|
|
|
out_dtype: torch.dtype,
|
|
|
|
|
|
azp_adj: torch.Tensor,
|
|
|
|
|
|
azp: Optional[torch.Tensor] = None,
|
|
|
|
|
|
bias: Optional[torch.Tensor] = None,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device)
|
|
|
|
|
|
torch.ops.xspeedgate_ops.cutlass_scaled_mm_azp(
|
2026-01-06 20:52:12 +08:00
|
|
|
|
out, a.contiguous(), b.contiguous(), scale_a, scale_b, azp_adj, azp, bias
|
2026-01-06 13:51:53 +08:00
|
|
|
|
)
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fake_cutlass_scaled_mm_azp(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    azp_adj: torch.Tensor,
    azp: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake (meta) kernel for ``_C::cutlass_scaled_mm_azp``: shape/dtype only."""
    out_shape = (a.shape[0], b.shape[1])
    return torch.empty(out_shape, dtype=out_dtype, device=a.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the shape-only fake so torch.compile / meta tracing can trace the op.
cutlass_scaled_mm_azp.register_fake(fake_cutlass_scaled_mm_azp)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ------------------ matmul ---------------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::matmul", mutates_args=())
|
|
|
|
|
|
def matmul(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
w: torch.Tensor,
|
|
|
|
|
|
out_dtype: torch.dtype,
|
|
|
|
|
|
x_trans: bool = False,
|
|
|
|
|
|
w_trans: bool = True,
|
|
|
|
|
|
alpha: float = 1.0,
|
|
|
|
|
|
beta: float = 0.0,
|
|
|
|
|
|
bias: torch.Tensor = None,
|
|
|
|
|
|
x_max: torch.Tensor = None,
|
|
|
|
|
|
w_max: torch.Tensor = None,
|
|
|
|
|
|
x_pc_max: torch.Tensor = None,
|
|
|
|
|
|
w_pc_max: torch.Tensor = None,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty(
|
|
|
|
|
|
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
|
|
|
|
|
|
dtype=out_dtype,
|
|
|
|
|
|
device=x.device,
|
|
|
|
|
|
)
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.matmul(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
x=x.contiguous(),
|
|
|
|
|
|
w=w.contiguous(),
|
|
|
|
|
|
out=out,
|
|
|
|
|
|
x_trans=x_trans,
|
|
|
|
|
|
w_trans=w_trans,
|
|
|
|
|
|
alpha=alpha,
|
|
|
|
|
|
beta=beta,
|
|
|
|
|
|
bias=bias,
|
|
|
|
|
|
x_max=x_max,
|
|
|
|
|
|
w_max=w_max,
|
|
|
|
|
|
x_pc_max=x_pc_max,
|
|
|
|
|
|
w_pc_max=w_pc_max,
|
|
|
|
|
|
)
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::matmul", "CUDA")
|
|
|
|
|
|
def matmul_cuda(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
w: torch.Tensor,
|
|
|
|
|
|
out_dtype: torch.dtype,
|
|
|
|
|
|
x_trans: bool = False,
|
|
|
|
|
|
w_trans: bool = True,
|
|
|
|
|
|
alpha: float = 1.0,
|
|
|
|
|
|
beta: float = 0.0,
|
|
|
|
|
|
bias: torch.Tensor = None,
|
|
|
|
|
|
x_max: torch.Tensor = None,
|
|
|
|
|
|
w_max: torch.Tensor = None,
|
|
|
|
|
|
x_pc_max: torch.Tensor = None,
|
|
|
|
|
|
w_pc_max: torch.Tensor = None,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
out = torch.empty(
|
|
|
|
|
|
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
|
|
|
|
|
|
dtype=out_dtype,
|
|
|
|
|
|
device=x.device,
|
|
|
|
|
|
)
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.matmul(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
x=x.contiguous(),
|
|
|
|
|
|
w=w.contiguous(),
|
|
|
|
|
|
out=out,
|
|
|
|
|
|
x_trans=x_trans,
|
|
|
|
|
|
w_trans=w_trans,
|
|
|
|
|
|
alpha=alpha,
|
|
|
|
|
|
beta=beta,
|
|
|
|
|
|
bias=bias,
|
|
|
|
|
|
x_max=x_max,
|
|
|
|
|
|
w_max=w_max,
|
|
|
|
|
|
x_pc_max=x_pc_max,
|
|
|
|
|
|
w_pc_max=w_pc_max,
|
|
|
|
|
|
)
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fake_matmul(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
w: torch.Tensor,
|
|
|
|
|
|
out_dtype: torch.dtype,
|
|
|
|
|
|
x_trans: bool = False,
|
|
|
|
|
|
w_trans: bool = True,
|
|
|
|
|
|
alpha: float = 1.0,
|
|
|
|
|
|
beta: float = 0.0,
|
|
|
|
|
|
bias: torch.Tensor = None,
|
|
|
|
|
|
x_max: torch.Tensor = None,
|
|
|
|
|
|
w_max: torch.Tensor = None,
|
|
|
|
|
|
x_pc_max: torch.Tensor = None,
|
|
|
|
|
|
w_pc_max: torch.Tensor = None,
|
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
|
return torch.empty(
|
2026-01-06 16:07:29 +08:00
|
|
|
|
(x.shape[0], w.shape[0] if w_trans else w.shape[1]),
|
2026-01-06 13:51:53 +08:00
|
|
|
|
dtype=out_dtype,
|
|
|
|
|
|
device=x.device,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the shape-only fake so torch.compile / meta tracing can trace the op.
matmul.register_fake(_fake_matmul)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ------------------- quant2d --------------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::quant2d", mutates_args=())
|
|
|
|
|
|
def quant2d(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
x_q: torch.Tensor,
|
|
|
|
|
|
max: torch.Tensor,
|
|
|
|
|
|
force_sdnn: bool = False,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.quant2d(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
x=x,
|
|
|
|
|
|
y=x_q,
|
|
|
|
|
|
max=max,
|
|
|
|
|
|
force_sdnn=force_sdnn,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@impl("_C::quant2d", "CUDA")
|
|
|
|
|
|
def quant2d_cuda(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
x_q: torch.Tensor,
|
|
|
|
|
|
max: torch.Tensor,
|
|
|
|
|
|
force_sdnn: bool = False,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.quant2d(
|
2026-01-06 13:51:53 +08:00
|
|
|
|
x=x,
|
|
|
|
|
|
y=x_q,
|
|
|
|
|
|
max=max,
|
|
|
|
|
|
force_sdnn=force_sdnn,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _fake_quant2d(
|
|
|
|
|
|
x: torch.Tensor,
|
|
|
|
|
|
x_q: torch.Tensor,
|
|
|
|
|
|
max: torch.Tensor,
|
|
|
|
|
|
force_sdnn: bool = False,
|
|
|
|
|
|
) -> None:
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the no-op fake so torch.compile / meta tracing can trace the op.
quant2d.register_fake(_fake_quant2d)
|
2026-01-13 20:22:14 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# --------------- penalties -----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::apply_repetition_penalties_", mutates_args=())
|
|
|
|
|
|
def apply_repetition_penalties_(
|
|
|
|
|
|
logits: torch.Tensor,
|
|
|
|
|
|
prompt_mask: torch.Tensor,
|
|
|
|
|
|
output_mask: torch.Tensor,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
repetition_penalties: torch.Tensor,
|
2026-01-13 20:22:14 +08:00
|
|
|
|
) -> None:
|
|
|
|
|
|
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
|
2026-01-26 18:56:05 +08:00
|
|
|
|
1, logits.size(1)
|
|
|
|
|
|
)
|
2026-01-13 20:22:14 +08:00
|
|
|
|
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
2026-01-26 18:56:05 +08:00
|
|
|
|
penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
|
2026-01-13 20:22:14 +08:00
|
|
|
|
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
|
|
|
|
|
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
|
|
|
|
|
logits *= scaling
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-13 20:22:14 +08:00
|
|
|
|
@impl("_C::apply_repetition_penalties_", "CUDA")
|
2026-02-11 12:04:14 +08:00
|
|
|
|
def apply_repetition_penalties_cuda(
|
2026-01-13 20:22:14 +08:00
|
|
|
|
logits: torch.Tensor,
|
|
|
|
|
|
prompt_mask: torch.Tensor,
|
|
|
|
|
|
output_mask: torch.Tensor,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
repetition_penalties: torch.Tensor,
|
2026-01-13 20:22:14 +08:00
|
|
|
|
) -> None:
|
|
|
|
|
|
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
|
2026-01-26 18:56:05 +08:00
|
|
|
|
1, logits.size(1)
|
|
|
|
|
|
)
|
2026-01-13 20:22:14 +08:00
|
|
|
|
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
|
2026-01-26 18:56:05 +08:00
|
|
|
|
penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0)
|
2026-01-13 20:22:14 +08:00
|
|
|
|
# If logits are positive, divide by penalty, otherwise multiply by penalty.
|
|
|
|
|
|
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
|
|
|
|
|
|
logits *= scaling
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
# --------------- I8_mqa_logits -----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::I8_mqa_logits", mutates_args=())
|
|
|
|
|
|
def I8_mqa_logits(
|
|
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
fused_kv_cache: List[torch.Tensor],
|
|
|
|
|
|
weights: torch.Tensor,
|
|
|
|
|
|
context_q_lens: List[torch.Tensor],
|
|
|
|
|
|
context_k_lens: List[torch.Tensor],
|
|
|
|
|
|
logits: torch.Tensor,
|
|
|
|
|
|
clean_logits: bool,
|
|
|
|
|
|
max_seq_q: Optional[int] = 0,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
max_seq_k: Optional[int] = 0,
|
2026-01-17 16:52:02 +08:00
|
|
|
|
is_causal: Optional[bool] = False,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.I8_mqa_logits(
|
2026-01-17 16:52:02 +08:00
|
|
|
|
q=q,
|
|
|
|
|
|
fused_kv_cache=fused_kv_cache,
|
|
|
|
|
|
weights=weights,
|
|
|
|
|
|
context_q_lens=context_q_lens,
|
|
|
|
|
|
context_k_lens=context_k_lens,
|
|
|
|
|
|
logits=logits,
|
|
|
|
|
|
clean_logits=clean_logits,
|
|
|
|
|
|
max_seq_q=max_seq_q,
|
|
|
|
|
|
max_seq_k=max_seq_k,
|
|
|
|
|
|
is_causal=is_causal,
|
|
|
|
|
|
use_xfa_boost=use_xfa_boost,
|
|
|
|
|
|
)
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
@impl("_C::I8_mqa_logits", "CUDA")
|
|
|
|
|
|
def I8_mqa_logits_cuda(
|
|
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
fused_kv_cache: List[torch.Tensor],
|
|
|
|
|
|
weights: torch.Tensor,
|
|
|
|
|
|
context_q_lens: List[torch.Tensor],
|
|
|
|
|
|
context_k_lens: List[torch.Tensor],
|
|
|
|
|
|
logits: torch.Tensor,
|
|
|
|
|
|
clean_logits: bool,
|
|
|
|
|
|
max_seq_q: Optional[int] = 0,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
max_seq_k: Optional[int] = 0,
|
2026-01-17 16:52:02 +08:00
|
|
|
|
is_causal: Optional[bool] = False,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.I8_mqa_logits(
|
2026-01-17 16:52:02 +08:00
|
|
|
|
q=q,
|
|
|
|
|
|
fused_kv_cache=fused_kv_cache,
|
|
|
|
|
|
weights=weights,
|
|
|
|
|
|
context_q_lens=context_q_lens,
|
|
|
|
|
|
context_k_lens=context_k_lens,
|
|
|
|
|
|
logits=logits,
|
|
|
|
|
|
clean_logits=clean_logits,
|
|
|
|
|
|
max_seq_q=max_seq_q,
|
|
|
|
|
|
max_seq_k=max_seq_k,
|
|
|
|
|
|
is_causal=is_causal,
|
|
|
|
|
|
use_xfa_boost=use_xfa_boost,
|
|
|
|
|
|
)
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
def _fake_I8_mqa_logits(
|
|
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
fused_kv_cache: List[torch.Tensor],
|
|
|
|
|
|
weights: torch.Tensor,
|
|
|
|
|
|
context_q_lens: List[torch.Tensor],
|
|
|
|
|
|
context_k_lens: List[torch.Tensor],
|
|
|
|
|
|
logits: torch.Tensor,
|
|
|
|
|
|
clean_logits: bool,
|
|
|
|
|
|
max_seq_q: Optional[int] = 0,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
max_seq_k: Optional[int] = 0,
|
2026-01-17 16:52:02 +08:00
|
|
|
|
is_causal: Optional[bool] = False,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-01-17 16:52:02 +08:00
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
# Register the no-op fake so torch.compile / meta tracing can trace the op.
I8_mqa_logits.register_fake(_fake_I8_mqa_logits)
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
# ------------- I8_paged_mqa_logits --------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::I8_paged_mqa_logits", mutates_args=())
|
|
|
|
|
|
def I8_paged_mqa_logits(
|
|
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
fused_kv_cache: List[torch.Tensor],
|
|
|
|
|
|
weights: torch.Tensor,
|
|
|
|
|
|
context_lens: List[torch.Tensor],
|
|
|
|
|
|
block_table: torch.Tensor,
|
|
|
|
|
|
max_context_len: int,
|
|
|
|
|
|
clean_logits: bool,
|
|
|
|
|
|
out: torch.Tensor,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-02-13 14:07:10 +08:00
|
|
|
|
kunlun_ops.I8_paged_mqa_logits(
|
2026-01-17 16:52:02 +08:00
|
|
|
|
q=q,
|
|
|
|
|
|
fused_kv_cache=fused_kv_cache,
|
|
|
|
|
|
weights=weights,
|
|
|
|
|
|
context_lens=context_lens,
|
|
|
|
|
|
block_table=block_table,
|
|
|
|
|
|
max_context_len=max_context_len,
|
|
|
|
|
|
clean_logits=clean_logits,
|
|
|
|
|
|
out=out,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost=use_xfa_boost,
|
|
|
|
|
|
)
|
2026-01-17 16:52:02 +08:00
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
@impl("_C::I8_paged_mqa_logits", "CUDA")
|
|
|
|
|
|
def I8_paged_mqa_logits_cuda(
|
|
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
fused_kv_cache: List[torch.Tensor],
|
|
|
|
|
|
weights: torch.Tensor,
|
|
|
|
|
|
context_lens: List[torch.Tensor],
|
|
|
|
|
|
block_table: torch.Tensor,
|
|
|
|
|
|
max_context_len: int,
|
|
|
|
|
|
clean_logits: bool,
|
|
|
|
|
|
out: torch.Tensor,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.I8_paged_mqa_logits(
|
2026-01-17 16:52:02 +08:00
|
|
|
|
q=q,
|
|
|
|
|
|
fused_kv_cache=fused_kv_cache,
|
|
|
|
|
|
weights=weights,
|
|
|
|
|
|
context_lens=context_lens,
|
|
|
|
|
|
block_table=block_table,
|
|
|
|
|
|
max_context_len=max_context_len,
|
|
|
|
|
|
clean_logits=clean_logits,
|
|
|
|
|
|
out=out,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost=use_xfa_boost,
|
|
|
|
|
|
)
|
2026-01-17 16:52:02 +08:00
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
def _fake_I8_paged_mqa_logits(
|
2026-01-26 18:56:05 +08:00
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
fused_kv_cache: List[torch.Tensor],
|
|
|
|
|
|
weights: torch.Tensor,
|
|
|
|
|
|
context_lens: List[torch.Tensor],
|
|
|
|
|
|
block_table: torch.Tensor,
|
|
|
|
|
|
max_context_len: int,
|
|
|
|
|
|
clean_logits: bool,
|
|
|
|
|
|
out: torch.Tensor,
|
|
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-01-17 16:52:02 +08:00
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
# Register the no-op fake so torch.compile / meta tracing can trace the op.
I8_paged_mqa_logits.register_fake(_fake_I8_paged_mqa_logits)
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
# ----------- sparse_prefill_fwd_opt -------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::sparse_prefill_fwd_opt", mutates_args=())
|
|
|
|
|
|
def sparse_prefill_fwd_opt(
|
2026-01-26 18:56:05 +08:00
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
kv: torch.Tensor,
|
|
|
|
|
|
indices: torch.Tensor,
|
|
|
|
|
|
out: torch.Tensor,
|
|
|
|
|
|
max_logits: torch.Tensor,
|
|
|
|
|
|
lse: torch.Tensor,
|
|
|
|
|
|
sm_scale: float,
|
|
|
|
|
|
qlod_cpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
qlod_xpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
kvlod_cpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
kvlod_xpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
d_v: Optional[int] = -1,
|
|
|
|
|
|
is_causal: Optional[bool] = True,
|
|
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.sparse_prefill_fwd_opt(
|
2026-01-17 16:52:02 +08:00
|
|
|
|
q=q,
|
|
|
|
|
|
kv=kv,
|
|
|
|
|
|
indices=indices,
|
|
|
|
|
|
out=out,
|
|
|
|
|
|
max_logits=max_logits,
|
|
|
|
|
|
lse=lse,
|
|
|
|
|
|
sm_scale=sm_scale,
|
|
|
|
|
|
qlod_cpu=qlod_cpu,
|
|
|
|
|
|
qlod_xpu=qlod_xpu,
|
|
|
|
|
|
kvlod_cpu=kvlod_cpu,
|
|
|
|
|
|
kvlod_xpu=kvlod_xpu,
|
|
|
|
|
|
d_v=d_v,
|
|
|
|
|
|
is_causal=is_causal,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost=use_xfa_boost,
|
|
|
|
|
|
)
|
2026-01-17 16:52:02 +08:00
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
@impl("_C::sparse_prefill_fwd_opt", "CUDA")
|
|
|
|
|
|
def sparse_prefill_fwd_opt_cuda(
|
2026-01-26 18:56:05 +08:00
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
kv: torch.Tensor,
|
|
|
|
|
|
indices: torch.Tensor,
|
|
|
|
|
|
out: torch.Tensor,
|
|
|
|
|
|
max_logits: torch.Tensor,
|
|
|
|
|
|
lse: torch.Tensor,
|
|
|
|
|
|
sm_scale: float,
|
|
|
|
|
|
qlod_cpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
qlod_xpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
kvlod_cpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
kvlod_xpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
d_v: Optional[int] = -1,
|
|
|
|
|
|
is_causal: Optional[bool] = True,
|
|
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-02-12 18:13:00 +08:00
|
|
|
|
kunlun_ops.sparse_prefill_fwd_opt(
|
2026-01-17 16:52:02 +08:00
|
|
|
|
q=q,
|
|
|
|
|
|
kv=kv,
|
|
|
|
|
|
indices=indices,
|
|
|
|
|
|
out=out,
|
|
|
|
|
|
max_logits=max_logits,
|
|
|
|
|
|
lse=lse,
|
|
|
|
|
|
sm_scale=sm_scale,
|
|
|
|
|
|
qlod_cpu=qlod_cpu,
|
|
|
|
|
|
qlod_xpu=qlod_xpu,
|
|
|
|
|
|
kvlod_cpu=kvlod_cpu,
|
|
|
|
|
|
kvlod_xpu=kvlod_xpu,
|
|
|
|
|
|
d_v=d_v,
|
|
|
|
|
|
is_causal=is_causal,
|
2026-01-26 18:56:05 +08:00
|
|
|
|
use_xfa_boost=use_xfa_boost,
|
|
|
|
|
|
)
|
2026-01-17 16:52:02 +08:00
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
def _fake_sparse_prefill_fwd_opt(
|
2026-01-26 18:56:05 +08:00
|
|
|
|
q: torch.Tensor,
|
|
|
|
|
|
kv: torch.Tensor,
|
|
|
|
|
|
indices: torch.Tensor,
|
|
|
|
|
|
out: torch.Tensor,
|
|
|
|
|
|
max_logits: torch.Tensor,
|
|
|
|
|
|
lse: torch.Tensor,
|
|
|
|
|
|
sm_scale: float,
|
|
|
|
|
|
qlod_cpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
qlod_xpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
kvlod_cpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
kvlod_xpu: Optional[torch.Tensor] = None,
|
|
|
|
|
|
d_v: Optional[int] = -1,
|
|
|
|
|
|
is_causal: Optional[bool] = True,
|
|
|
|
|
|
use_xfa_boost: Optional[bool] = False,
|
|
|
|
|
|
) -> None:
|
2026-01-17 16:52:02 +08:00
|
|
|
|
return None
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
# Register the no-op fake so torch.compile / meta tracing can trace the op.
sparse_prefill_fwd_opt.register_fake(_fake_sparse_prefill_fwd_opt)
|
|
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
2026-01-17 16:52:02 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
# ------------------ fwd_kvcache_mla -------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::fwd_kvcache_mla", mutates_args=())
def fwd_kvcache_mla(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    """Default registration of ``_C::fwd_kvcache_mla``.

    Forwards every argument unchanged to the Kunlun MLA kv-cache attention
    kernel; results are produced through the tensor arguments
    (``out`` / ``max_logits`` / ``p_sums``) and the op returns ``None``.

    NOTE(review): the kernel writes into ``out``/``max_logits``/``p_sums``
    while the op is registered with ``mutates_args=()`` — confirm this is
    intentional for the custom-op schema.
    """
    kernel_args = dict(
        q_c=q_c,
        kv_cache=kv_cache,
        indices=indices,
        kv_lod_cpu=kv_lod_cpu,
        out=out,
        max_logits=max_logits,
        p_sums=p_sums,
        softmax_scale=softmax_scale,
        max_seq_kv=max_seq_kv,
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
        kv_lod_xpu=kv_lod_xpu,
    )
    kunlun_ops.fwd_kvcache_mla(**kernel_args)
    return None


@impl("_C::fwd_kvcache_mla", "CUDA")
def fwd_kvcache_mla_cuda(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    """CUDA dispatch-key implementation; calls the same Kunlun kernel as
    the default path."""
    kernel_args = dict(
        q_c=q_c,
        kv_cache=kv_cache,
        indices=indices,
        kv_lod_cpu=kv_lod_cpu,
        out=out,
        max_logits=max_logits,
        p_sums=p_sums,
        softmax_scale=softmax_scale,
        max_seq_kv=max_seq_kv,
        q_r=q_r,
        pe_cache=pe_cache,
        use_xfa_boost=use_xfa_boost,
        kv_lod_xpu=kv_lod_xpu,
    )
    kunlun_ops.fwd_kvcache_mla(**kernel_args)
    return None


def _fake_fwd_kvcache_mla(
    q_c: torch.Tensor,
    kv_cache: torch.Tensor,
    indices: torch.Tensor,
    kv_lod_cpu: torch.Tensor,
    out: torch.Tensor,
    max_logits: torch.Tensor,
    p_sums: torch.Tensor,
    softmax_scale: float,
    max_seq_kv: int,
    q_r: Optional[torch.Tensor] = None,
    pe_cache: Optional[torch.Tensor] = None,
    use_xfa_boost: Optional[bool] = False,
    kv_lod_xpu: Optional[torch.Tensor] = None,
) -> None:
    """Fake (meta) implementation for tracing: the op returns ``None``."""
    return None


fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
|
2026-01-23 10:29:52 +08:00
|
|
|
|
|
2026-01-26 18:56:05 +08:00
|
|
|
|
|
|
|
|
|
|
##################################################
|
2026-01-27 19:56:22 +08:00
|
|
|
|
# --------------- dequant_int4 -------------------
|
2026-01-26 18:56:05 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
def _dequant_int4_dispatch(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    y: torch.Tensor,
    group_m: int,
    int4_signed: bool,
    use_mode_fast: bool,
) -> None:
    """Forward an int4 dequantization request to the Kunlun kernel.

    Shared by the default and CUDA registrations of ``_C::dequant_int4``
    (their bodies were previously duplicated verbatim). The kernel writes
    the dequantized result into ``y``.
    """
    kunlun_ops.dequant_int4(
        x=x,
        scale=scale,
        zero=zero,
        y=y,
        group_m=group_m,
        int4_signed=int4_signed,
        use_mode_fast=use_mode_fast,
    )


@custom_op("_C::dequant_int4", mutates_args=())
def dequant_int4(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    y: torch.Tensor,
    group_m: int,
    int4_signed: bool = True,
    use_mode_fast: bool = False,
) -> None:
    """Default registration of ``_C::dequant_int4``.

    Dequantizes int4-packed ``x`` with per-group ``scale``/``zero`` into
    ``y`` via the Kunlun kernel; returns ``None``.

    NOTE(review): the kernel writes into ``y`` while the op declares
    ``mutates_args=()`` — confirm this is intentional.
    """
    _dequant_int4_dispatch(x, scale, zero, y, group_m, int4_signed, use_mode_fast)
    return None


@impl("_C::dequant_int4", "CUDA")
def dequant_int4_cuda(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    y: torch.Tensor,
    group_m: int,
    int4_signed: bool = True,
    use_mode_fast: bool = False,
) -> None:
    """CUDA dispatch-key implementation; same kernel as the default path."""
    _dequant_int4_dispatch(x, scale, zero, y, group_m, int4_signed, use_mode_fast)
    return None


def _fake_dequant_int4(
    x: torch.Tensor,
    scale: torch.Tensor,
    zero: torch.Tensor,
    y: torch.Tensor,
    group_m: int,
    int4_signed: bool = True,
    use_mode_fast: bool = False,
) -> None:
    """Fake (meta) implementation for tracing: the op returns ``None``."""
    return None


dequant_int4.register_fake(_fake_dequant_int4)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-23 10:29:52 +08:00
|
|
|
|
##################################################
|
|
|
|
|
|
# ------------------ fast_topkv2 -------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::fast_topkv2", mutates_args=())
def fast_topkv2(
    score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
) -> torch.Tensor:
    """Default registration of ``_C::fast_topkv2``.

    Returns the top-k indices of ``score`` (per row, bounded by ``lengths``)
    computed by the Kunlun kernel. Only ``topk == 2048`` is supported.
    """
    assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
    return kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)


@impl("_C::fast_topkv2", "CUDA")
def fast_topkv2_cuda(
    score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
) -> torch.Tensor:
    """CUDA dispatch-key implementation; same kernel as the default path."""
    assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
    return kunlun_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)


def _fake_fast_topkv2(
    score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
) -> torch.Tensor:
    """Fake (meta) implementation: an empty int32 index tensor of shape
    ``(score.size(0), topk)`` matching the real op's output layout."""
    assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
    return score.new_empty((score.size(0), topk), dtype=torch.int32)


fast_topkv2.register_fake(_fake_fast_topkv2)
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ----------------- LoRA ops --------------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# -------------- sgmv_shrink_lora ----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::sgmv_shrink_lora", mutates_args=())
def sgmv_shrink_lora(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> torch.Tensor:
    """Default registration of ``_C::sgmv_shrink_lora``.

    Only ``inputs``, ``lora_a_weights``, ``seq_len_tensor``,
    ``lora_indices_tensor``, ``output_tensor`` and ``scaling`` are forwarded
    to the kernel; the remaining parameters are kept for op-schema
    compatibility but not used here.

    NOTE(review): a cluster variant (``sgmv_shrink_cluster``) exists; the
    SDNN path is the one currently enabled.
    """
    ops = torch.ops.xspeedgate_ops
    return ops.sgmv_shrink_sdnn(
        inputs,
        lora_a_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        scaling,
    )


@impl("_C::sgmv_shrink_lora", "CUDA")
def sgmv_shrink_lora_cuda(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> torch.Tensor:
    """CUDA dispatch-key implementation; same SDNN kernel as the default
    path."""
    ops = torch.ops.xspeedgate_ops
    return ops.sgmv_shrink_sdnn(
        inputs,
        lora_a_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        scaling,
    )


def _fake_sgmv_shrink_lora(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> torch.Tensor:
    """Fake (meta) implementation: the real op returns ``output_tensor``,
    so the fake passes it straight through."""
    return output_tensor


sgmv_shrink_lora.register_fake(_fake_sgmv_shrink_lora)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# -------------- sgmv_expand_lora ----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::sgmv_expand_lora", mutates_args=())
def sgmv_expand_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Default registration of ``_C::sgmv_expand_lora``.

    Forwards ``inputs``, ``lora_b_weights``, ``seq_len_tensor``,
    ``lora_indices_tensor`` and ``output_tensor`` to the SDNN kernel with a
    trailing positional ``0`` (presumably a slice offset — confirm against
    the xspeedgate_ops signature). The remaining parameters, including
    ``add_inputs``, are kept for op-schema compatibility but not forwarded.

    NOTE(review): a cluster variant (``sgmv_expand_cluster``) exists; the
    SDNN path is the one currently enabled.
    """
    ops = torch.ops.xspeedgate_ops
    return ops.sgmv_expand_sdnn(
        inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
    )


@impl("_C::sgmv_expand_lora", "CUDA")
def sgmv_expand_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """CUDA dispatch-key implementation; same SDNN kernel as the default
    path."""
    ops = torch.ops.xspeedgate_ops
    return ops.sgmv_expand_sdnn(
        inputs, lora_b_weights, seq_len_tensor, lora_indices_tensor, output_tensor, 0
    )


def _fake_sgmv_expand_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Fake (meta) implementation: the real op returns ``output_tensor``,
    so the fake passes it straight through."""
    return output_tensor


sgmv_expand_lora.register_fake(_fake_sgmv_expand_lora)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ----------- sgmv_expand_slice_lora -------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::sgmv_expand_slice_lora", mutates_args=())
def sgmv_expand_slice_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Default registration of ``_C::sgmv_expand_slice_lora``.

    Like ``sgmv_expand_lora`` but writes into a slice of ``output_tensor``
    starting at ``slice_offset`` (forwarded as the kernel's trailing
    positional argument). ``slice_size``, ``add_inputs`` and the other
    unforwarded parameters are kept for op-schema compatibility.
    """
    ops = torch.ops.xspeedgate_ops
    return ops.sgmv_expand_cluster(
        inputs,
        lora_b_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        slice_offset,
    )


@impl("_C::sgmv_expand_slice_lora", "CUDA")
def sgmv_expand_slice_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """CUDA dispatch-key implementation; same cluster kernel as the default
    path."""
    ops = torch.ops.xspeedgate_ops
    return ops.sgmv_expand_cluster(
        inputs,
        lora_b_weights,
        seq_len_tensor,
        lora_indices_tensor,
        output_tensor,
        slice_offset,
    )


def _fake_sgmv_expand_slice_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> torch.Tensor:
    """Fake (meta) implementation: the real op returns ``output_tensor``,
    so the fake passes it straight through."""
    return output_tensor


sgmv_expand_slice_lora.register_fake(_fake_sgmv_expand_slice_lora)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# -------------- bgmv_shrink_lora ----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::bgmv_shrink_lora", mutates_args=())
def bgmv_shrink_lora(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> torch.Tensor:
    """Default registration of ``_C::bgmv_shrink_lora``.

    Forwards ``inputs``, ``lora_a_weights``, ``lora_indices_tensor``,
    ``output_tensor`` and ``scaling`` to the cluster kernel; the remaining
    parameters are kept for op-schema compatibility but not used here.
    """
    ops = torch.ops.xspeedgate_ops
    return ops.bgmv_shrink_cluster(
        inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
    )


@impl("_C::bgmv_shrink_lora", "CUDA")
def bgmv_shrink_lora_cuda(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> torch.Tensor:
    """CUDA dispatch-key implementation; same cluster kernel as the default
    path."""
    ops = torch.ops.xspeedgate_ops
    return ops.bgmv_shrink_cluster(
        inputs, lora_a_weights, lora_indices_tensor, output_tensor, scaling
    )


def _fake_bgmv_shrink_lora(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    expert_m: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> torch.Tensor:
    """Fake (meta) implementation: the real op returns ``output_tensor``,
    so the fake passes it straight through."""
    return output_tensor


bgmv_shrink_lora.register_fake(_fake_bgmv_shrink_lora)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# -------------- bgmv_expand_lora ----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::bgmv_expand_lora", mutates_args=())
def bgmv_expand_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Default registration of ``_C::bgmv_expand_lora``.

    Forwards ``inputs``, ``lora_b_weights``, ``lora_indices_tensor`` and
    ``output_tensor`` to the cluster kernel with a trailing positional ``0``
    (presumably a slice offset — confirm against the xspeedgate_ops
    signature). ``add_inputs`` and the other unforwarded parameters are
    kept for op-schema compatibility.
    """
    ops = torch.ops.xspeedgate_ops
    return ops.bgmv_expand_cluster(
        inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
    )


@impl("_C::bgmv_expand_lora", "CUDA")
def bgmv_expand_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> torch.Tensor:
    """CUDA dispatch-key implementation; same cluster kernel as the default
    path."""
    ops = torch.ops.xspeedgate_ops
    return ops.bgmv_expand_cluster(
        inputs, lora_b_weights, lora_indices_tensor, output_tensor, 0
    )


def _fake_bgmv_expand_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Fake (meta) implementation: the real op returns ``output_tensor``,
    so the fake passes it straight through."""
    return output_tensor


bgmv_expand_lora.register_fake(_fake_bgmv_expand_lora)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ----------- bgmv_expand_slice_lora -------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
@custom_op("_C::bgmv_expand_slice_lora", mutates_args=())
def bgmv_expand_slice_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Default registration of ``_C::bgmv_expand_slice_lora``.

    Like ``bgmv_expand_lora`` but writes into a slice of ``output_tensor``
    starting at ``slice_offset`` (forwarded as the kernel's trailing
    positional argument). ``slice_size``, ``add_inputs`` and the other
    unforwarded parameters are kept for op-schema compatibility.
    """
    ops = torch.ops.xspeedgate_ops
    return ops.bgmv_expand_cluster(
        inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
    )


@impl("_C::bgmv_expand_slice_lora", "CUDA")
def bgmv_expand_slice_lora_cuda(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> torch.Tensor:
    """CUDA dispatch-key implementation; same cluster kernel as the default
    path."""
    ops = torch.ops.xspeedgate_ops
    return ops.bgmv_expand_cluster(
        inputs, lora_b_weights, lora_indices_tensor, output_tensor, slice_offset
    )


def _fake_bgmv_expand_slice_lora(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    block_statistic: torch.Tensor,
    sorted_tokens_num_lod: torch.Tensor,
    moe_index: torch.Tensor,
    normed_scale: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> torch.Tensor:
    """Fake (meta) implementation: the real op returns ``output_tensor``,
    so the fake passes it straight through."""
    return output_tensor


bgmv_expand_slice_lora.register_fake(_fake_bgmv_expand_slice_lora)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
# ----------- lora_matmul_inplace ----------------
|
|
|
|
|
|
##################################################
|
|
|
|
|
|
def _lora_matmul_dispatch(
    x: torch.Tensor,
    w: torch.Tensor,
    output_tensor: torch.Tensor,
    x_trans: bool,
    w_trans: bool,
    alpha: float,
    beta: float,
) -> None:
    """Forward a matmul request to the Kunlun kernel, writing into
    ``output_tensor``.

    Shared by the default and CUDA registrations of
    ``_C::lora_matmul_inplace`` (their bodies were previously duplicated
    verbatim). Inputs are made contiguous before dispatch (a no-op copy
    when they already are). ``alpha``/``beta`` are presumably BLAS-style
    GEMM scaling factors — confirm against ``kunlun_ops.matmul``.
    """
    kunlun_ops.matmul(
        x=x.contiguous(),
        w=w.contiguous(),
        out=output_tensor,
        x_trans=x_trans,
        w_trans=w_trans,
        alpha=alpha,
        beta=beta,
    )


@custom_op("_C::lora_matmul_inplace", mutates_args=())
def lora_matmul_inplace(
    x: torch.Tensor,
    w: torch.Tensor,
    output_tensor: torch.Tensor,
    x_trans: bool = False,
    w_trans: bool = True,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> None:
    """Default registration of ``_C::lora_matmul_inplace``.

    Computes a matmul of ``x`` and ``w`` (optionally transposed) into
    ``output_tensor`` via the Kunlun kernel; returns ``None``.

    NOTE(review): the kernel writes into ``output_tensor`` while the op
    declares ``mutates_args=()`` — confirm this is intentional.
    """
    _lora_matmul_dispatch(x, w, output_tensor, x_trans, w_trans, alpha, beta)


@impl("_C::lora_matmul_inplace", "CUDA")
def lora_matmul_inplace_cuda(
    x: torch.Tensor,
    w: torch.Tensor,
    output_tensor: torch.Tensor,
    x_trans: bool = False,
    w_trans: bool = True,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> None:
    """CUDA dispatch-key implementation; same kernel as the default path."""
    _lora_matmul_dispatch(x, w, output_tensor, x_trans, w_trans, alpha, beta)


def _fake_lora_matmul_inplace(
    x: torch.Tensor,
    w: torch.Tensor,
    output_tensor: torch.Tensor,
    x_trans: bool = False,
    w_trans: bool = True,
    alpha: float = 1.0,
    beta: float = 1.0,
) -> None:
    """Fake (meta) implementation for tracing: the op returns ``None``."""
    return None


lora_matmul_inplace.register_fake(_fake_lora_matmul_inplace)
|