xc-llm-kunlun/vllm_kunlun/vllm_utils_wrapper.py
Commit b3c30a3cb9: [Feature] Support XiaoMi MIMO Flash V2 (#62), 2025-12-31
"""vllm_utils_wrapper.py"""
import vllm.distributed.parallel_state as parallel_state
import vllm.utils as _orig
from typing import (Any, Callable, Optional, Union, get_origin, get_args, List)
from types import SimpleNamespace
import torch
from torch.library import Library
import inspect
import typing
from torch.library import register_fake
import vllm_kunlun._kunlun
def patch_annotations_for_schema(func):
"""patch_annotations_for_schema"""
sig = inspect.signature(func)
new_params = []
for name, param in sig.parameters.items():
ann = param.annotation
if get_origin(ann) is typing.Union and type(None) in get_args(ann):
inner_type = [a for a in get_args(ann) if a is not type(None)][0]
if get_origin(inner_type) is list: # Optional[list[int]]
inner_args = get_args(inner_type)
new_ann = Optional[List[inner_args[0] if inner_args else typing.Any]]
param = param.replace(annotation=new_ann)
elif get_origin(ann) is list:
args = get_args(ann)
new_ann = List[args[0] if args else typing.Any]
param = param.replace(annotation=new_ann)
new_params.append(param)
func.__signature__ = sig.replace(parameters=new_params)
return func
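# Illustrative sketch of the rewrite (the function `f` is hypothetical):
#
#   def f(x: Optional[list[int]] = None) -> None: ...
#   patch_annotations_for_schema(f)
#   # inspect.signature(f) now reports x: Optional[typing.List[int]], which
#   # torch.library.infer_schema can parse on the PyTorch builds targeted here.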
def supports_custom_op() -> bool:
"""supports_custom_op"""
return hasattr(torch.library, "custom_op")
vllm_lib = Library("vllm", "FRAGMENT") # noqa
def direct_register_custom_op(
op_name: str,
op_func: Callable,
mutates_args: Optional[list[str]] = None,
fake_impl: Optional[Callable] = None,
target_lib: Optional[Library] = None,
dispatch_key: str = "CUDA",
tags: tuple[torch.Tag, ...] = (),
):
"""
`torch.library.custom_op` can have significant overhead because it
needs to consider complicated dispatching logic. This function
directly registers a custom op and dispatches it to the CUDA backend.
See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
for more details.
By default, the custom op is registered to the vLLM library. If you
want to register it to a different library, you can pass the library
object to the `target_lib` argument.
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
"""
if not supports_custom_op():
from vllm.platforms import current_platform
assert not current_platform.is_cuda_alike(), (
"cuda platform needs torch>=2.4 to support custom op, "
"chances are you are using an old version of pytorch "
"or a custom build of pytorch. It is recommended to "
"use vLLM in a fresh new environment and let it install "
"the required dependencies.")
return
if mutates_args is None:
mutates_args = []
import torch.library
if hasattr(torch.library, "infer_schema"):
patched_func = patch_annotations_for_schema(op_func)
schema_str = torch.library.infer_schema(op_func,
mutates_args=mutates_args)
else:
# for pytorch 2.4
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
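# Usage sketch (op name and bodies are illustrative, not part of vLLM):
#
#   def scale2(x: torch.Tensor) -> torch.Tensor:
#       return x * 2
#
#   def scale2_fake(x: torch.Tensor) -> torch.Tensor:
#       return torch.empty_like(x)
#
#   direct_register_custom_op("scale2", scale2, fake_impl=scale2_fake)
#   y = torch.ops.vllm.scale2(x)   # dispatches via the "CUDA" key by default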
def vllm_kunlun_weak_ref_tensor(tensor: Any) -> Any:
"""
Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive.
"""
if isinstance(tensor, torch.Tensor):
return torch.ops._kunlun.weak_ref_tensor(tensor)
else:
return tensor
def vllm_kunlun_weak_ref_tensors(
tensors: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor]]
) -> Union[torch.Tensor, list[Any], tuple[Any], Any]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
"""
if isinstance(tensors, torch.Tensor):
return vllm_kunlun_weak_ref_tensor(tensors)
if isinstance(tensors, list):
return [vllm_kunlun_weak_ref_tensor(t) for t in tensors]
if isinstance(tensors, tuple):
return tuple(vllm_kunlun_weak_ref_tensor(t) for t in tensors)
raise ValueError("Invalid type for tensors")
_wrapped = SimpleNamespace(**_orig.__dict__)
_wrapped.direct_register_custom_op = direct_register_custom_op
_wrapped.weak_ref_tensor = vllm_kunlun_weak_ref_tensor
_wrapped.weak_ref_tensors = vllm_kunlun_weak_ref_tensors
import sys
sys.modules["vllm.utils"] = _wrapped
_original_all_reduce = parallel_state.GroupCoordinator.all_reduce
_original_all_gather = parallel_state.GroupCoordinator.all_gather
def vllm_kunlun_all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
"""vllm_kunlun_all_reduce"""
if self.world_size == 1:
return input_
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
def vllm_kunlun_all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""vllm_kunlun_all_reduce"""
world_size = self.world_size
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(output_tensor,
input_,
group=self.device_group)
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size *
input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
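# Worked shape example (assuming world_size == 4): input_ of shape (2, 3) with
# dim=-1 gathers into (4, 2, 3); movedim(0, 1) yields (2, 4, 3) and the final
# reshape gives (2, 12), i.e. per-rank slices concatenated along dim 1.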
parallel_state.GroupCoordinator.all_reduce = vllm_kunlun_all_reduce
parallel_state.GroupCoordinator.all_gather = vllm_kunlun_all_gather
from torch.library import custom_op, impl
from vllm import _custom_ops as ops
import os
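# The `_C::*` registrations below are no-op schema stubs: they declare the
# operator signatures vLLM's CUDA path expects, with empty bodies, presumably
# so graph tracing succeeds on Kunlun while the real kernels come from the
# xtorch_ops-backed ops registered further down.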
@custom_op("_C::rms_norm", mutates_args=())
def rms_norm(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@custom_op("_C::fused_add_rms_norm", mutates_args=())
def fused_add_rms_norm(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@custom_op("_C::static_scaled_fp8_quant", mutates_args=())
def static_scaled_fp8_quant(
result : torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor
)->None:
pass
@impl("_C::static_scaled_fp8_quant", "CUDA")
def static_scaled_fp8_quant_xpu(
result : torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor
)->None:
pass
@custom_op("_C::dynamic_scaled_fp8_quant", mutates_args=())
def dynamic_scaled_fp8_quant(
result : torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor
)->None:
pass
@impl("_C::dynamic_scaled_fp8_quant", "CUDA")
def dynamic_scaled_fp8_quant_xpu(
result : torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor
)->None:
pass
@custom_op("_C::dynamic_per_token_scaled_fp8_quant", mutates_args=())
def dynamic_per_token_scaled_fp8_quant(
result : torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
scale_ub: Optional[torch.Tensor]
)->None:
pass
@impl("_C::dynamic_per_token_scaled_fp8_quant", "CUDA")
def dynamic_per_token_scaled_fp8_quant_xpu(
result : torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
scale_ub: Optional[torch.Tensor]
)->None:
pass
@custom_op("_C::rms_norm_static_fp8_quant", mutates_args=())
def rms_norm_static_fp8_quant(
result : torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@impl("_C::rms_norm_static_fp8_quant", "CUDA")
def rms_norm_static_fp8_quant_xpu(
result : torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@custom_op("_C::fused_add_rms_norm_static_fp8_quant", mutates_args=())
def fused_add_rms_norm_static_fp8_quant(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@impl("_C::fused_add_rms_norm_static_fp8_quant", "CUDA")
def fused_add_rms_norm_static_fp8_quant_xpu(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@custom_op("_C::rms_norm_dynamic_per_token_quant", mutates_args=())
def rms_norm_dynamic_per_token_quant(
result : torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
scale_ub: Optional[torch.Tensor],
residual: Optional[torch.Tensor]
)->None:
pass
@impl("_C::rms_norm_dynamic_per_token_quant", "CUDA")
def rms_norm_dynamic_per_token_quant_xpu(
result : torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float,
scale_ub: Optional[torch.Tensor],
residual: Optional[torch.Tensor]
)->None:
pass
@custom_op("_C::silu_and_mul", mutates_args=())
def silu_and_mul(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@impl("_C::silu_and_mul", "CUDA")
def silu_and_mul_xpu(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@custom_op("_C::silu_and_mul_quant", mutates_args=())
def silu_and_mul_quant(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
@impl("_C::silu_and_mul_quant", "CUDA")
def silu_and_mul_quant_xpu(
result : torch.Tensor,
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
epsilon: float
)->None:
pass
import xtorch_ops
@custom_op("_C::add_rmsnorm", mutates_args=())
def add_rmsnorm(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
xtorch_ops.add_rmsnorm(
x,
y, # 原来写 residual这里其实是 y
residual_output=residual_output,
weight=weight,
eps=eps,
output=output,
)
@impl("_C::add_rmsnorm", "CUDA")
def add_rmsnorm_cuda(
x: torch.Tensor,
y: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweaved: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
smooth: torch.Tensor = None,
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
xtorch_ops.add_rmsnorm(
x,
y,
residual_output=residual_output,
weight=weight,
eps=eps,
output=output,
)
@custom_op("_C::rmsnorm", mutates_args=())
def rmsnorm(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweave: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
xtorch_ops.rmsnorm(
x,
weight,
output,
eps,
)
@impl("_C::rmsnorm", "CUDA")
def rmsnorm_cuda(
x: torch.Tensor,
weight: torch.Tensor,
output: torch.Tensor,
eps: float = 1e-5,
interweave: bool = False,
store_output_before_norm: bool = True,
bias: torch.Tensor = None,
residual_output: torch.Tensor = None,
output_max: torch.Tensor = None,
) -> None:
xtorch_ops.rmsnorm(
x,
weight,
output,
eps,
)
def _fake_rmsnorm(x, weight, output, eps=1e-5, interweave=False,
store_output_before_norm=True, bias=None,
residual_output=None, output_max=None):
    # record shape/dtype on the fake output; the op itself returns nothing
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
rmsnorm.register_fake(_fake_rmsnorm)
def _fake_add_rmsnorm(x, y, weight, output, eps=1e-5,
interweaved=False, store_output_before_norm=True,
bias=None, smooth=None, residual_output=None,
output_max=None):
output.fake_shape = x.shape
output.fake_dtype = x.dtype
return None
add_rmsnorm.register_fake(_fake_add_rmsnorm)
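# Minimal call sketch for the op above (shapes illustrative; the kernel is
# assumed to write rmsnorm(x + y) * weight into `output` in place):
#
#   x, y = torch.randn(8, 128), torch.randn(8, 128)
#   w, out = torch.ones(128), torch.empty(8, 128)
#   torch.ops._C.add_rmsnorm(x, y, w, out)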
@custom_op("_C::split_norm_rope_neox", mutates_args=())
def split_norm_rope_neox(
q_emb: torch.Tensor,
k_emb: torch.Tensor,
v_out: torch.Tensor,
qkv: torch.Tensor,
rotary_pos_embedding: torch.Tensor,
q_norm_weight: torch.Tensor,
k_norm_weight: torch.Tensor,
positions: torch.Tensor,
num_tokens: int,
max_seqlen: int,
head_num: int,
kv_head_num: int,
head_dim: int,
rotary_dim: int,
    emb_batch_size: int = 1
) -> None:
xtorch_ops.split_norm_rope_neox(
q_emb,
k_emb,
v_out,
qkv,
rotary_pos_embedding,
q_norm_weight,
k_norm_weight,
positions,
num_tokens,
max_seqlen,
head_num,
kv_head_num,
head_dim,
rotary_dim,
)
@impl("_C::split_norm_rope_neox", "CUDA")
def split_norm_rope_neox_cuda(
q_emb: torch.Tensor,
k_emb: torch.Tensor,
v_out: torch.Tensor,
qkv: torch.Tensor,
rotary_pos_embedding: torch.Tensor,
q_norm_weight: torch.Tensor,
k_norm_weight: torch.Tensor,
positions: torch.Tensor,
num_tokens: int,
max_seqlen: int,
head_num: int,
kv_head_num: int,
head_dim: int,
rotary_dim: int,
    emb_batch_size: int = 1
) -> None:
xtorch_ops.split_norm_rope_neox(
q_emb,
k_emb,
v_out,
qkv,
rotary_pos_embedding,
q_norm_weight,
k_norm_weight,
positions,
num_tokens,
max_seqlen,
head_num,
kv_head_num,
head_dim,
rotary_dim,
)
def _fake_split_norm_rope_neox(
q_emb: torch.Tensor,
k_emb: torch.Tensor,
v_out: torch.Tensor,
qkv: torch.Tensor,
rotary_pos_embedding: torch.Tensor,
q_norm_weight: torch.Tensor,
k_norm_weight: torch.Tensor,
positions: torch.Tensor,
num_tokens: int,
max_seqlen: int,
head_num: int,
kv_head_num: int,
head_dim: int,
rotary_dim: int,
    emb_batch_size: int = 1):
q_emb.fake_shape = q_emb.shape
q_emb.fake_dtype = q_emb.dtype
k_emb.fake_shape = k_emb.shape
k_emb.fake_dtype = k_emb.dtype
v_out.fake_shape = v_out.shape
v_out.fake_dtype = v_out.dtype
return None
split_norm_rope_neox.register_fake(_fake_split_norm_rope_neox)
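# Layout assumption (illustrative): `qkv` holds the fused projection of shape
# [num_tokens, (head_num + 2 * kv_head_num) * head_dim]; the kernel splits it
# into q_emb / k_emb / v_out and applies per-head RMSNorm plus NeoX-style
# rotary embedding in a single pass.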
# Register fake op implementations here for torch.dynamo.
if hasattr(torch.ops.custom_ops, "fc_fusion"):
@register_fake("custom_ops::fc_fusion")
def fc_fusion_fake(self: torch.Tensor,
other: torch.Tensor,
bias: Optional[torch.Tensor],
self_trans: bool,
other_trans: bool,
*,
alpha: float=1.0,
beta: float=0.0,
act: int=1,
multi_stream: bool=False,
out: torch.Tensor
) -> None:
pass
@custom_op("_C::swiglu", mutates_args=())
def swiglu(
x: torch.Tensor,
y: torch.Tensor,
axis: int=-1,
turn: bool=True
) -> None:
xtorch_ops.swiglu(
x,
y,
)
@impl("_C::swiglu", "CUDA")
def swiglu_cuda(
x: torch.Tensor,
y: torch.Tensor,
axis: int=-1,
turn: bool=True
) -> None:
xtorch_ops.swiglu(
x,
y,
)
def _fake_swiglu(
x: torch.Tensor,
y: torch.Tensor,
axis: int=-1,
turn: bool=True):
return None
swiglu.register_fake(_fake_swiglu)
@custom_op("_C::swigluoai_and_mul", mutates_args=())
def swigluoai_and_mul(
x: torch.Tensor,
alpha: float = 1.702,
limit: float = 7.0,
axis: int = -1,
turn: bool = True
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
@impl("_C::swigluoai_and_mul", "CUDA")
def swigluoai_and_mul_cuda(
x: torch.Tensor,
alpha: float = 1.702,
limit: float = 7.0,
axis: int = -1,
turn: bool = True
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
def _fake_swigluoai_and_mul(
x: torch.Tensor,
alpha: float = 1.702,
limit: float = 7.0,
axis: int = -1,
turn: bool = True
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
glu = gate * torch.sigmoid(gate * alpha)
gated_output = (up + 1) * glu
return gated_output
swigluoai_and_mul.register_fake(_fake_swigluoai_and_mul)
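# Shape sketch: gate/up are read interleaved along the last dim, so an input
# of shape (..., 2 * d) produces (..., d), e.g.
#
#   x = torch.randn(4, 256)
#   out = torch.ops._C.swigluoai_and_mul(x)   # shape (4, 128)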
@custom_op("_C::moe_softmax_topk", mutates_args=())
def moe_softmax_topk(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
axis: int = -1,
turn: bool = True
) -> None:
xtorch_ops.moe_softmax_topk(
x,
normed_score,
topk_index,
block_statistic
)
@impl("_C::moe_softmax_topk", "CUDA")
def moe_softmax_topk_cuda(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
axis: int = -1,
turn: bool = True
) -> None:
xtorch_ops.moe_softmax_topk(
x,
normed_score,
topk_index,
block_statistic
)
def _fake_moe_softmax_topk(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
axis: int = -1,
turn: bool = True
) -> None:
return None
moe_softmax_topk.register_fake(_fake_moe_softmax_topk)
@custom_op("_C::moe_ffn_block", mutates_args=())
def moe_ffn_block(
out: torch.Tensor,
x: torch.Tensor,
expert_num: int,
moe_top_k: int,
gate_w: torch.Tensor,
inter_w: torch.Tensor,
output_w: torch.Tensor,
renormalize: bool = True,
use_grouped_topk: bool = False,
expert_group_num: Optional[int] = 0,
topk_group: Optional[int] = 0,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.moe_ffn_block(
x=x,
gate_w=gate_w,
inter_w=inter_w,
output_w=output_w,
expert_num=expert_num,
moe_top_k=moe_top_k,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=expert_group_num,
out=out,
)
@impl("_C::moe_ffn_block", "CUDA")
def moe_ffn_block_cuda(
out: torch.Tensor,
x: torch.Tensor,
expert_num: int,
moe_top_k: int,
gate_w: torch.Tensor,
inter_w: torch.Tensor,
output_w: torch.Tensor,
renormalize: bool = True,
use_grouped_topk: bool = False,
expert_group_num: Optional[int] = 0,
topk_group: Optional[int] = 0,
w1_bias: Optional[torch.Tensor] = None,
w2_bias: Optional[torch.Tensor] = None,
) -> None:
xtorch_ops.moe_ffn_block(
x=x,
gate_w=gate_w,
inter_w=inter_w,
output_w=output_w,
expert_num=expert_num,
moe_top_k=moe_top_k,
topk_group=topk_group,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
expert_group_num=expert_group_num,
out=out,
)
def _fake_moe_ffn_block(
    out: torch.Tensor,
    x: torch.Tensor,
    expert_num: int,
    moe_top_k: int,
    gate_w: torch.Tensor,
    inter_w: torch.Tensor,
    output_w: torch.Tensor,
    renormalize: bool = True,
    use_grouped_topk: bool = False,
    expert_group_num: Optional[int] = 0,
    topk_group: Optional[int] = 0,
    w1_bias: Optional[torch.Tensor] = None,
    w2_bias: Optional[torch.Tensor] = None,
) -> None:
    return None
moe_ffn_block.register_fake(_fake_moe_ffn_block)
@custom_op("_C::moe_ffn_per_token_block", mutates_args=())
def moe_ffn_per_token_block(
x: torch.Tensor,
inter_weight: torch.Tensor,
inter_scale: torch.Tensor,
outer_weight: torch.Tensor,
outer_scale: torch.Tensor,
top_k: int,
global_num_experts: int,
linear_weights: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
activation: str = "silu",
output: Optional[torch.Tensor] = None,
use_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0
) -> None:
xtorch_ops.moe_ffn_per_token_block(
x=x,
inter_weight=inter_weight,
inter_scale=inter_scale,
outer_weight=outer_weight,
outer_scale=outer_scale,
gate_weight=linear_weights,
expert_num=global_num_experts,
moe_top_k=top_k,
act_type=activation,
use_expert_parallel=use_expert_parallel,
ep_size=ep_size,
ep_rank=ep_rank,
out=output,
)
@impl("_C::moe_ffn_per_token_block", "CUDA")
def moe_ffn_per_token_block_cuda(
x: torch.Tensor,
inter_weight: torch.Tensor,
inter_scale: torch.Tensor,
outer_weight: torch.Tensor,
outer_scale: torch.Tensor,
top_k: int,
global_num_experts: int,
linear_weights: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
activation: str = "silu",
output: Optional[torch.Tensor] = None,
use_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0
) -> None:
xtorch_ops.moe_ffn_per_token_block(
x=x,
inter_weight=inter_weight,
inter_scale=inter_scale,
outer_weight=outer_weight,
outer_scale=outer_scale,
gate_weight=linear_weights,
expert_num=global_num_experts,
moe_top_k=top_k,
act_type=activation,
use_expert_parallel=use_expert_parallel,
ep_size=ep_size,
ep_rank=ep_rank,
out=output,
)
def _fake_moe_ffn_per_token_block(
x: torch.Tensor,
inter_weight: torch.Tensor,
inter_scale: torch.Tensor,
outer_weight: torch.Tensor,
outer_scale: torch.Tensor,
top_k: int,
global_num_experts: int,
linear_weights: Optional[torch.Tensor] = None,
expert_map: Optional[torch.Tensor] = None,
activation: str = "silu",
output: Optional[torch.Tensor] = None,
use_expert_parallel: bool = False,
ep_size: int = 1,
ep_rank: int = 0
) -> None:
# Fake implementation can be a no-op or a simple operation
if output is not None:
output.copy_(x) # Example: simply copy input to output
# Register the fake implementation
moe_ffn_per_token_block.register_fake(_fake_moe_ffn_per_token_block)
@custom_op("_C::rotary_embedding", mutates_args=())
def rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
xtorch_ops.rotary_embedding(
positions=positions,
query=query,
key=key,
head_size=head_size,
cos_sin_cache=cos_sin_cache,
is_neox=is_neox)
@impl("_C::rotary_embedding", "CUDA")
def rotary_embedding_cuda(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
xtorch_ops.rotary_embedding(
positions=positions,
query=query,
key=key,
head_size=head_size,
cos_sin_cache=cos_sin_cache,
is_neox=is_neox)
def _fake_rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
return None
rotary_embedding.register_fake(_fake_rotary_embedding)
@custom_op("_C::quant2d", mutates_args=())
def quant2d(
x: torch.Tensor,
y: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool,
) -> None:
xtorch_ops.quant2d(
x=x,
y=y,
max=max,
force_sdnn=force_sdnn
)
@impl("_C::quant2d", "CUDA")
def quant2d_cuda(
x: torch.Tensor,
y: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool,
) -> None:
xtorch_ops.quant2d(
x=x,
y=y,
max=max,
force_sdnn=force_sdnn
)
def _fake_quant2d(
x: torch.Tensor,
y: torch.Tensor,
max: torch.Tensor,
force_sdnn: bool,
) -> None:
return None
quant2d.register_fake(_fake_quant2d)
@custom_op("_C::gemm_I8_I8_bf16_nt", mutates_args=())
def gemm_I8_I8_bf16_nt(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
xtorch_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale),
rhs=(weight, weight_scale),
out=out
)
@impl("_C::gemm_I8_I8_bf16_nt", "CUDA")
def gemm_I8_I8_bf16_nt_cuda(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
xtorch_ops.gemm_I8_I8_bf16_nt(
lhs=(x_q, x_scale),
rhs=(weight, weight_scale),
out=out
)
def _fake_gemm_I8_I8_bf16_nt(
x_q: torch.Tensor,
x_scale: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
return None
gemm_I8_I8_bf16_nt.register_fake(_fake_gemm_I8_I8_bf16_nt)
@custom_op("_C::moe_softmax_topk_norm", mutates_args=())
def moe_softmax_topk_norm(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
stable: bool = True
) -> None:
xtorch_ops.moe_softmax_topk_norm(
x,
normed_score,
topk_index,
block_statistic,
stable
)
@impl("_C::moe_softmax_topk_norm", "CUDA")
def moe_softmax_topk_norm_cuda(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
stable: bool = True
) -> None:
xtorch_ops.moe_softmax_topk_norm(
x,
normed_score,
topk_index,
block_statistic,
stable
)
def _fake_moe_softmax_topk_norm(
x: torch.Tensor,
normed_score: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
stable: bool = True
) -> None:
return None
moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm)
@custom_op("_C::gen_block_statistic", mutates_args=())
def gen_block_statistic(
topk_ids: torch.Tensor,
block_statistic: torch.Tensor
)-> None:
xtorch_ops.gen_block_statistic(
topk_ids,block_statistic
)
@impl("_C::gen_block_statistic", "CUDA")
def gen_block_statistic_cuda(
topk_ids: torch.Tensor,
block_statistic: torch.Tensor
)-> None:
xtorch_ops.gen_block_statistic(
topk_ids,block_statistic
)
def fake_gen_block_statistic(
topk_ids: torch.Tensor,
block_statistic: torch.Tensor
)-> None:
return None
gen_block_statistic.register_fake(fake_gen_block_statistic)
@custom_op("_C::moe_pre_sorted", mutates_args=())
def moe_pre_sorted(
x: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
moe_expand: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False
) -> None:
xtorch_ops.moe_pre_sorted(
x,
topk_index,
block_statistic,
moe_expand,
moe_index,
expert_m,
sorted_tokens_num_lod)
@impl("_C::moe_pre_sorted", "CUDA")
def moe_pre_sorted_cuda(
x: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
moe_expand: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False
) -> None:
xtorch_ops.moe_pre_sorted(
x,
topk_index,
block_statistic,
moe_expand,
moe_index,
expert_m,
sorted_tokens_num_lod)
def fake_moe_pre_sorted(
x: torch.Tensor,
topk_index: torch.Tensor,
block_statistic: torch.Tensor,
moe_expand: torch.Tensor,
moe_index: torch.Tensor,
expert_m: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
index_have_neg: bool = False
) -> None:
return None
moe_pre_sorted.register_fake(fake_moe_pre_sorted)
@custom_op("_C::moe_fc", mutates_args=())
def moe_fc(
x: torch.Tensor,
weight: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
    w_perchannel_max: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True
) -> None:
xtorch_ops.moe_fc(
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_topk,
y=y,
act=act,
x_perchannel_max=x_perchannel_max,
w_perchannel_max=w_perchannel_max,
topk_ids=topk_ids,
topk_w=topk_w,
bias=bias,
tgemm_type=tgemm_type,
tweight_type=tweight_type,
scale_n=scale_n,
scale_k=scale_k,
use_pack_int4=use_pack_int4,
sort_mode=sort_mode)
@impl("_C::moe_fc", "CUDA")
def moe_fc_cuda(
x: torch.Tensor,
weight: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
    w_perchannel_max: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True
) -> None:
xtorch_ops.moe_fc(
x=x,
weight=weight,
sorted_tokens_num_lod=sorted_tokens_num_lod,
sorted_tokens_idx=sorted_tokens_idx,
moe_topk=moe_topk,
y=y,
act=act,
x_perchannel_max=x_perchannel_max,
w_perchannel_max=w_perchannel_max,
topk_ids=topk_ids,
topk_w=topk_w,
bias=bias,
tgemm_type=tgemm_type,
tweight_type=tweight_type,
scale_n=scale_n,
scale_k=scale_k,
use_pack_int4=use_pack_int4,
sort_mode=sort_mode)
def fake_moe_fc(
x: torch.Tensor,
weight: torch.Tensor,
sorted_tokens_num_lod: torch.Tensor,
sorted_tokens_idx: torch.Tensor,
moe_topk: int,
y: torch.Tensor,
act: Optional[str] = None,
x_perchannel_max: Optional[torch.Tensor] = None,
    w_perchannel_max: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
topk_w: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
tgemm_type: Optional[str] = None,
tweight_type: Optional[str] = None,
scale_n: Optional[int] = 0,
scale_k: Optional[int] = 0,
use_pack_int4: Optional[bool] = False,
sort_mode: Optional[bool] = True
) -> None:
return None
moe_fc.register_fake(fake_moe_fc)
@custom_op("_C::moe_post", mutates_args=())
def moe_post(
x: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
dequant_scale: torch.Tensor,
y: torch.Tensor
) -> None:
xtorch_ops.moe_post(
x,
moe_index,
normed_scale,
dequant_scale,
y
)
@impl("_C::moe_post", "CUDA")
def moe_post_cuda(
x: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
dequant_scale: torch.Tensor,
y: torch.Tensor
) -> None:
xtorch_ops.moe_post(
x,
moe_index,
normed_scale,
dequant_scale,
y)
def fake_moe_post(
x: torch.Tensor,
moe_index: torch.Tensor,
normed_scale: torch.Tensor,
dequant_scale: torch.Tensor,
y: torch.Tensor
) -> None:
return None
moe_post.register_fake(fake_moe_post)
@custom_op("_C::moe_sigmoid_group_topk_norm", mutates_args=())
def moe_sigmoid_group_topk_norm(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int
) -> None:
xtorch_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
block_static=block_static,
bias=bias,
n_group=n_group,
topk_group=topk_group,
scale=scale,
)
@impl("_C::moe_sigmoid_group_topk_norm", "CUDA")
def moe_sigmoid_group_topk_norm_cuda(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int
) -> None:
xtorch_ops.moe_sigmoid_group_topk_norm(
x=x,
norm_score=norm_score,
topk_index=topk_index,
block_static=block_static,
bias=bias,
n_group=n_group,
topk_group=topk_group,
scale=scale,
)
def _fake_moe_sigmoid_group_topk_norm(
x: torch.Tensor,
topk_index: torch.Tensor,
norm_score: torch.Tensor,
block_static: torch.Tensor,
bias: torch.Tensor,
scale: float,
n_group: int,
topk_group: int
) -> None:
return None
moe_sigmoid_group_topk_norm.register_fake(_fake_moe_sigmoid_group_topk_norm)
##################################################
# --------------- awq_dequantize -----------------
##################################################
@custom_op("_C::awq_dequantize", mutates_args=())
def awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
    group_m = qweight.shape[0] // scales.shape[0]
xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
@impl("_C::awq_dequantize", "CUDA")
def awq_dequantize_cuda(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
    group_m = qweight.shape[0] // scales.shape[0]
    xtorch_ops.awq_dequantize(
qweight=qweight,
scales=scales,
zeros=zeros,
weight=weight,
group_m=group_m,
quant_type=quant_type,
align_type=align_type,
)
return weight
def _fake_awq_dequantize(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
quant_type: int = 0,
align_type: int = 1,
) -> torch.Tensor:
weight = torch.empty(
(qweight.shape[0], qweight.shape[1] * 8),
dtype=torch.float16,
device=qweight.device,
)
return weight
awq_dequantize.register_fake(_fake_awq_dequantize)
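# Note: AWQ packs eight 4-bit weights per int32 column, hence the
# `qweight.shape[1] * 8` output width; e.g. a (4096, 512) int32 qweight
# dequantizes to a (4096, 4096) fp16 weight.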
##################################################
# ------------------ awq_gemm -------------------
##################################################
@custom_op("_C::awq_gemm", mutates_args=())
def awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
    group_size = qweight.shape[0] // scale.shape[0]
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
@impl("_C::awq_gemm", "CUDA")
def awq_gemm_cuda(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
    group_size = qweight.shape[0] // scale.shape[0]
xtorch_ops.awq_gemm(
x=x,
w=qweight,
scale=scale,
zeros=zeros,
out=out,
align_type=align_type,
group_size=group_size,
)
return out
def _fake_awq_gemm(
x: torch.Tensor,
qweight: torch.Tensor,
scale: torch.Tensor,
zeros: torch.Tensor,
align_type: int = 1,
) -> torch.Tensor:
out = torch.empty(
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
)
return out
awq_gemm.register_fake(_fake_awq_gemm)
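# Note: `group_size` is recovered from the row ratio
# qweight.shape[0] // scale.shape[0]; e.g. K = 4096 with 32 scale rows
# implies group_size = 128.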
##################################################
# ---------------- gptq_shuffle ------------------
##################################################
@custom_op("_C::gptq_shuffle", mutates_args=())
def gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
@impl("_C::gptq_shuffle", "CUDA")
def gptq_shuffle_cuda(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
def _fake_gptq_shuffle(
q_weight: torch.Tensor,
q_perm: torch.Tensor,
bit: int,
) -> None:
return None
gptq_shuffle.register_fake(_fake_gptq_shuffle)