Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
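The change is almost entirely a mechanical rename of the custom-ops module: call sites keep their arguments and only the import target changes. A minimal before/after sketch of the pattern (the tensor shapes below are placeholders, and actually running it requires a Kunlun XPU build of PyTorch plus the kunlun_ops extension, neither of which is assumed here):

import torch
import kunlun_ops  # previously: import xtorch_ops

x = torch.randn(4, 1024, dtype=torch.float16)    # placeholder activations
weight = torch.ones(1024, dtype=torch.float16)   # placeholder RMSNorm weight
out = torch.empty_like(x)

# Identical arguments before and after the migration; only the module changed:
#   xtorch_ops.rmsnorm(x, weight.to(torch.float32), 1e-6, out=out)   # old
kunlun_ops.rmsnorm(x, weight.to(torch.float32), 1e-6, out=out)       # new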
@@ -70,7 +70,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer,
|
|||||||
from vllm_kunlun.ops.activation import SiluAndMul
|
from vllm_kunlun.ops.activation import SiluAndMul
|
||||||
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
from vllm_kunlun.ops._kunlun_ops import KunlunOps as ops
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask
|
from vllm.model_executor.layers.vocab_parallel_embedding import get_masked_input_and_mask
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
|
|
||||||
@torch.compile(dynamic=True, backend="aot_eager")
|
@torch.compile(dynamic=True, backend="aot_eager")
|
||||||
@@ -640,7 +640,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
|||||||
last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view(
|
last_recurrent_state = last_recurrent_state.transpose(-1, -2).contiguous().to(ssm_state.dtype).view(
|
||||||
last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1])
|
last_recurrent_state.shape[0], -1, last_recurrent_state.shape[-1])
|
||||||
cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1])
|
cast_ssm_state = ssm_state.view(ssm_state.shape[0], 1, -1, ssm_state.shape[-1])
|
||||||
xtorch_ops.reshape_and_cache_flash(
|
kunlun_ops.reshape_and_cache_flash(
|
||||||
last_recurrent_state,
|
last_recurrent_state,
|
||||||
last_recurrent_state,
|
last_recurrent_state,
|
||||||
cast_ssm_state,
|
cast_ssm_state,
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM, Qwen3Model
|
|||||||
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
from vllm.model_executor.models.utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||||
maybe_prefix, merge_multimodal_embeddings)
|
maybe_prefix, merge_multimodal_embeddings)
|
||||||
from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
from vllm.model_executor.models.vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
from einops import repeat
|
from einops import repeat
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from vllm.logger import init_logger
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
logger.info(f"Load custom ops library success!")
|
logger.info(f"Load custom ops library success!")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
logger.warning("Import error msg: %s", e.msg)
|
logger.warning("Import error msg: %s", e.msg)
|
||||||
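For downstream code that may encounter either package name while the migration lands, a small compatibility shim can alias whichever module is importable. This is a hypothetical sketch, not part of this commit:

# Hypothetical compatibility shim: prefer the new module name, fall back to the old one.
try:
    import kunlun_ops as custom_ops
except ImportError:
    try:
        import xtorch_ops as custom_ops
    except ImportError:
        custom_ops = None  # custom kernels unavailable; callers must check before use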
@@ -71,7 +71,7 @@ class KunlunOps:
|
|||||||
):
|
):
|
||||||
""" PagedAttentionV1 """
|
""" PagedAttentionV1 """
|
||||||
# block_size = value_cache.shape[2]
|
# block_size = value_cache.shape[2]
|
||||||
xtorch_ops.paged_attention(
|
kunlun_ops.paged_attention(
|
||||||
x=query,
|
x=query,
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
@@ -114,7 +114,7 @@ class KunlunOps:
|
|||||||
):
|
):
|
||||||
""" PagedAttentionV2 """
|
""" PagedAttentionV2 """
|
||||||
# block_size = value_cache.shape[2]
|
# block_size = value_cache.shape[2]
|
||||||
xtorch_ops.paged_attention(
|
kunlun_ops.paged_attention(
|
||||||
x=query,
|
x=query,
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
@@ -133,7 +133,7 @@ class KunlunOps:
|
|||||||
def silu_and_mul(out: torch.Tensor,
|
def silu_and_mul(out: torch.Tensor,
|
||||||
x: torch.Tensor):
|
x: torch.Tensor):
|
||||||
""" silu and mul """
|
""" silu and mul """
|
||||||
xtorch_ops.silu_and_mul(
|
kunlun_ops.silu_and_mul(
|
||||||
x,
|
x,
|
||||||
axis=-1,
|
axis=-1,
|
||||||
turn=True,
|
turn=True,
|
||||||
@@ -145,7 +145,7 @@ class KunlunOps:
|
|||||||
def quick_gelu(out: torch.Tensor,
|
def quick_gelu(out: torch.Tensor,
|
||||||
x: torch.Tensor):
|
x: torch.Tensor):
|
||||||
""" quick gelu """
|
""" quick gelu """
|
||||||
xtorch_ops.quick_gelu(
|
kunlun_ops.quick_gelu(
|
||||||
x,
|
x,
|
||||||
out=out,
|
out=out,
|
||||||
)
|
)
|
||||||
@@ -159,7 +159,7 @@ class KunlunOps:
|
|||||||
epsilon,
|
epsilon,
|
||||||
):
|
):
|
||||||
"""rms_norm"""
|
"""rms_norm"""
|
||||||
xtorch_ops.rmsnorm(
|
kunlun_ops.rmsnorm(
|
||||||
x, weight.to(torch.float32), epsilon, out=out
|
x, weight.to(torch.float32), epsilon, out=out
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,7 +172,7 @@ class KunlunOps:
|
|||||||
):
|
):
|
||||||
"""fused_add_rms_norm"""
|
"""fused_add_rms_norm"""
|
||||||
output = torch.empty_like(x)
|
output = torch.empty_like(x)
|
||||||
xtorch_ops.add_rmsnorm(
|
kunlun_ops.add_rmsnorm(
|
||||||
x, residual, weight.to(torch.float32), epsilon, out=output
|
x, residual, weight.to(torch.float32), epsilon, out=output
|
||||||
)
|
)
|
||||||
fused_input = x + residual
|
fused_input = x + residual
|
||||||
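For reference, the semantics assumed for the fused add_rmsnorm call above can be written in plain PyTorch as a residual add followed by RMS normalization over the last dimension; this is a sketch of the assumed behaviour, not the kernel's actual implementation:

import torch

def add_rmsnorm_reference(x, residual, weight, epsilon):
    # Add the residual, then RMS-normalize and scale by the weight.
    fused_input = x + residual
    h = fused_input.to(torch.float32)
    variance = h.pow(2).mean(dim=-1, keepdim=True)
    normed = h * torch.rsqrt(variance + epsilon)
    return (normed * weight.to(torch.float32)).to(x.dtype), fused_input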
@@ -222,7 +222,7 @@ class KunlunOps:
|
|||||||
key_x = key.contiguous()
|
key_x = key.contiguous()
|
||||||
query_x_dim = query_x.dim()
|
query_x_dim = query_x.dim()
|
||||||
assert is_neox_style
|
assert is_neox_style
|
||||||
xtorch_ops.mrotary_embedding_neox(
|
kunlun_ops.mrotary_embedding_neox(
|
||||||
positions,
|
positions,
|
||||||
query_x,
|
query_x,
|
||||||
key_x,
|
key_x,
|
||||||
@@ -240,7 +240,7 @@ class KunlunOps:
|
|||||||
dst,
|
dst,
|
||||||
block_mapping):
|
block_mapping):
|
||||||
""" swap_blocks """
|
""" swap_blocks """
|
||||||
xtorch_ops.swap_blocks(
|
kunlun_ops.swap_blocks(
|
||||||
src,
|
src,
|
||||||
dst,
|
dst,
|
||||||
block_mapping
|
block_mapping
|
||||||
@@ -255,7 +255,7 @@ class KunlunOps:
|
|||||||
for i in range(len(key_caches)):
|
for i in range(len(key_caches)):
|
||||||
key_caches[i] = key_caches[i].contiguous()
|
key_caches[i] = key_caches[i].contiguous()
|
||||||
value_caches[i] = value_caches[i].contiguous()
|
value_caches[i] = value_caches[i].contiguous()
|
||||||
xtorch_ops.copy_blocks(
|
kunlun_ops.copy_blocks(
|
||||||
key_caches,
|
key_caches,
|
||||||
value_caches,
|
value_caches,
|
||||||
block_mapping,
|
block_mapping,
|
||||||
@@ -272,7 +272,7 @@ class KunlunOps:
|
|||||||
):
|
):
|
||||||
""" reshape_and_cache """
|
""" reshape_and_cache """
|
||||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||||
xtorch_ops.reshape_and_cache(
|
kunlun_ops.reshape_and_cache(
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
key_cache,
|
key_cache,
|
||||||
@@ -308,7 +308,7 @@ class KunlunOps:
|
|||||||
repeat = Qh // KVh
|
repeat = Qh // KVh
|
||||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||||
value = value.repeat_interleave(repeat, dim=2)
|
value = value.repeat_interleave(repeat, dim=2)
|
||||||
xtorch_ops.attention(
|
kunlun_ops.attention(
|
||||||
q=query,
|
q=query,
|
||||||
k_cache=key,
|
k_cache=key,
|
||||||
v_cache=value,
|
v_cache=value,
|
||||||
@@ -337,7 +337,7 @@ class KunlunOps:
|
|||||||
else:
|
else:
|
||||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||||
|
|
||||||
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
||||||
out=out, out_scale=out_scale , residual_tensor=residual)
|
out=out, out_scale=out_scale , residual_tensor=residual)
|
||||||
|
|
||||||
if residual is None:
|
if residual is None:
|
||||||
@@ -360,7 +360,7 @@ class KunlunOps:
|
|||||||
else:
|
else:
|
||||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||||
|
|
||||||
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
|
kunlun_ops.quant_rmsnorm(x, weight, bias, eps,
|
||||||
out=out, out_scale=out_scale)
|
out=out, out_scale=out_scale)
|
||||||
return out, out_scale
|
return out, out_scale
|
||||||
|
|
||||||
@@ -388,7 +388,7 @@ class KunlunOps:
|
|||||||
dtype=torch.float16,
|
dtype=torch.float16,
|
||||||
device=weight.device)
|
device=weight.device)
|
||||||
output_bs_shape = [-1]
|
output_bs_shape = [-1]
|
||||||
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||||
weight, smoother,
|
weight, smoother,
|
||||||
input_scale,
|
input_scale,
|
||||||
weight_scale,
|
weight_scale,
|
||||||
@@ -642,7 +642,7 @@ class KunlunOps:
|
|||||||
"""mla pa block"""
|
"""mla pa block"""
|
||||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||||
device=hidden_states.device)
|
device=hidden_states.device)
|
||||||
xtorch_ops.xft_multi_head_latent_page_attention_block(
|
kunlun_ops.xft_multi_head_latent_page_attention_block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
q_lora_rank,
|
q_lora_rank,
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
@@ -688,7 +688,7 @@ class KunlunOps:
|
|||||||
threshold: float = 20.0,
|
threshold: float = 20.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""fused_gdn_gating"""
|
"""fused_gdn_gating"""
|
||||||
output = xtorch_ops.fused_gdn_gating(
|
output = kunlun_ops.fused_gdn_gating(
|
||||||
A_log,
|
A_log,
|
||||||
a,
|
a,
|
||||||
dt_bias,
|
dt_bias,
|
||||||
@@ -713,7 +713,7 @@ class KunlunOps:
|
|||||||
2. Delta Rule Update: performs a parallel recurrent update of a state-space model (SSM), combined with a local attention mechanism.
|
2. Delta Rule Update: performs a parallel recurrent update of a state-space model (SSM), combined with a local attention mechanism.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
|
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
|
||||||
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
|
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
|
||||||
cu_seqlens)
|
cu_seqlens)
|
||||||
return (o, final_state)
|
return (o, final_state)
|
||||||
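The docstring above names the gated delta rule recurrence. A sequential plain-PyTorch reference of that recurrence is sketched below; it omits the fused kernel's handling of cu_seqlens, the initial-state source, and the optional q/k L2 normalization, and the kernel's exact conventions may differ:

import torch

def gated_delta_rule_reference(q, k, v, g, beta, scale, h0=None):
    # q, k: [T, d_k]; v: [T, d_v]; g (log decay) and beta (write strength): [T].
    q, k, v, g, beta = (t.to(torch.float32) for t in (q, k, v, g, beta))
    d_k, d_v = k.shape[-1], v.shape[-1]
    S = torch.zeros(d_k, d_v) if h0 is None else h0.to(torch.float32).clone()
    outputs = []
    for t in range(k.shape[0]):
        S = S * torch.exp(g[t])                      # gated decay of the running state
        v_err = v[t] - k[t] @ S                      # delta-rule prediction error
        S = S + beta[t] * torch.outer(k[t], v_err)   # rank-1 correction of the state
        outputs.append((q[t] * scale) @ S)           # read out with the scaled query
    return torch.stack(outputs), S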
@@ -93,7 +93,7 @@ class SiluAndMul(CustomOp):
|
|||||||
|
|
||||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""forward_cuda"""
|
"""forward_cuda"""
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
output_shape = (x.shape[:-1] + (d, ))
|
output_shape = (x.shape[:-1] + (d, ))
|
||||||
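The halved output shape reflects the standard gated-activation semantics: the last dimension is split in two, SiLU is applied to the first half, and the result is multiplied element-wise by the second half. A plain-PyTorch reference of those assumed semantics (not the kunlun_ops kernel itself):

import torch
import torch.nn.functional as F

def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]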
@@ -103,7 +103,7 @@ class SiluAndMul(CustomOp):
|
|||||||
|
|
||||||
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
|
def forward_kunlun(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""forward_kunlun"""
|
"""forward_kunlun"""
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
d = x.shape[-1] // 2
|
d = x.shape[-1] // 2
|
||||||
output_shape = (x.shape[:-1] + (d, ))
|
output_shape = (x.shape[:-1] + (d, ))
|
||||||
@@ -251,14 +251,14 @@ class GeluAndMul(CustomOp):
|
|||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
# from vllm import _custom_ops as ops
|
# from vllm import _custom_ops as ops
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
# d = x.shape[-1] // 2
|
# d = x.shape[-1] // 2
|
||||||
# output_shape = (x.shape[:-1] + (d, ))
|
# output_shape = (x.shape[:-1] + (d, ))
|
||||||
out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
|
out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
|
||||||
if self.approximate == "none":
|
if self.approximate == "none":
|
||||||
# ops.gelu_and_mul(out, x)
|
# ops.gelu_and_mul(out, x)
|
||||||
print(x,x.shape)
|
print(x,x.shape)
|
||||||
xtorch_ops.gelu(x, out)
|
kunlun_ops.gelu(x, out)
|
||||||
elif self.approximate == "tanh":
|
elif self.approximate == "tanh":
|
||||||
ops.gelu_tanh_and_mul(out, x)
|
ops.gelu_tanh_and_mul(out, x)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import torch
|
|||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -104,7 +104,7 @@ def flash_mla_with_kvcache(
|
|||||||
is_context = False
|
is_context = False
|
||||||
vo_head_dim = -1
|
vo_head_dim = -1
|
||||||
|
|
||||||
xtorch_ops.paged_attention(out,
|
kunlun_ops.paged_attention(out,
|
||||||
q,
|
q,
|
||||||
k_cache, None,
|
k_cache, None,
|
||||||
block_table,
|
block_table,
|
||||||
@@ -149,7 +149,7 @@ def kunlun_flash_mla_with_kvcache(
|
|||||||
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
|
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
|
||||||
"""
|
"""
|
||||||
assert not is_fp8_kvcache, "For now, the kernel does not support uint8 kv cache."
|
assert not is_fp8_kvcache, "For now, the kernel does not support uint8 kv cache."
|
||||||
assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only supports seq_len_q <= 2 for now."
|
assert q.shape[1] <= 2, "kunlun_ops.fwd_kvcache_mla only supports seq_len_q <= 2 for now."
|
||||||
if softmax_scale is None:
|
if softmax_scale is None:
|
||||||
softmax_scale = q.shape[-1] ** (-0.5)
|
softmax_scale = q.shape[-1] ** (-0.5)
|
||||||
if indices is not None:
|
if indices is not None:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
@@ -16,7 +16,7 @@ def merge_attn_states(
|
|||||||
output_lse: Optional[torch.Tensor] = None,
|
output_lse: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
return xtorch_ops.attention_merge_stage(
|
return kunlun_ops.attention_merge_stage(
|
||||||
prefix_output,
|
prefix_output,
|
||||||
prefix_lse,
|
prefix_lse,
|
||||||
suffix_output,
|
suffix_output,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
|
|
||||||
class FusedRecurrentFunction(torch.autograd.Function):
|
class FusedRecurrentFunction(torch.autograd.Function):
|
||||||
@@ -31,7 +31,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
|||||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||||
use_qk_l2norm_in_kernel: bool = False):
|
use_qk_l2norm_in_kernel: bool = False):
|
||||||
|
|
||||||
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwdv2(
|
||||||
q.contiguous(),
|
q.contiguous(),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
v.contiguous(),
|
v.contiguous(),
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
|
|
||||||
BT_LIST = [8, 16, 32, 64, 128]
|
BT_LIST = [8, 16, 32, 64, 128]
|
||||||
@@ -149,5 +149,5 @@ def l2norm_fwd(x: torch.Tensor,
|
|||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
output_dtype: Optional[torch.dtype] = None):
|
output_dtype: Optional[torch.dtype] = None):
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
xtorch_ops.l2norm(x, out, eps)
|
kunlun_ops.l2norm(x, out, eps)
|
||||||
return out
|
return out
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm as OriGemmaRMSNorm
|
||||||
from vllm.model_executor.layers import layernorm
|
from vllm.model_executor.layers import layernorm
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
def vllm_kunlun_forward_cuda(
|
def vllm_kunlun_forward_cuda(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import torch.nn.functional as F
|
|||||||
|
|
||||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
|
|
||||||
@triton.jit()
|
@triton.jit()
|
||||||
@@ -1212,7 +1212,7 @@ def torch_causal_conv1d_update(
|
|||||||
tmp_hidden_states = hidden_states_new[:, :, -state_len:]
|
tmp_hidden_states = hidden_states_new[:, :, -state_len:]
|
||||||
ori_shape = tmp_hidden_states.shape
|
ori_shape = tmp_hidden_states.shape
|
||||||
tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(ori_shape)
|
tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(ori_shape)
|
||||||
xtorch_ops.reshape_and_cache_flash(
|
kunlun_ops.reshape_and_cache_flash(
|
||||||
tmp_hidden_states,
|
tmp_hidden_states,
|
||||||
tmp_hidden_states,
|
tmp_hidden_states,
|
||||||
cast_conv_state,
|
cast_conv_state,
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class KunlunCompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
|
class KunlunCompressedTensorsW8A8Int8MoEMethod(CompressedTensorsW8A8Int8MoEMethod):
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
# NOTE: xtorch_ops uses max as scale
|
# NOTE: kunlun_ops uses max as scale
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
layer.w13_weight_scale.mul_(127.0)
|
layer.w13_weight_scale.mul_(127.0)
|
||||||
layer.w2_weight_scale.mul_(127.0)
|
layer.w2_weight_scale.mul_(127.0)
|
||||||
|
|||||||
@@ -28,9 +28,9 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import kunlun_ops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import xtorch_ops
|
|
||||||
from vllm.attention.backends.abstract import (
|
from vllm.attention.backends.abstract import (
|
||||||
AttentionBackend,
|
AttentionBackend,
|
||||||
AttentionImpl,
|
AttentionImpl,
|
||||||
@@ -39,6 +39,7 @@ from vllm.attention.backends.abstract import (
|
|||||||
AttentionType,
|
AttentionType,
|
||||||
)
|
)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.utils import cdiv
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
AttentionCGSupport,
|
AttentionCGSupport,
|
||||||
CommonAttentionMetadata,
|
CommonAttentionMetadata,
|
||||||
@@ -227,9 +228,9 @@ class KunlunMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""__post_init__"""
|
"""__post_init__"""
|
||||||
self.attn_bias: Optional[List[AttentionBias]] = None
|
self.attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||||
self.encoder_attn_bias: Optional[List[AttentionBias]] = None
|
self.encoder_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||||
self.cross_attn_bias: Optional[List[AttentionBias]] = None
|
self.cross_attn_bias: Optional[List[AttentionBias]] = None # noqa: F821
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_all_encoder_attn_metadata_set(self):
|
def is_all_encoder_attn_metadata_set(self):
|
||||||
@@ -572,12 +573,11 @@ class KunlunAttentionMetadataBuilder:
|
|||||||
"""build"""
|
"""build"""
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||||
max_query_len = common_attn_metadata.max_query_len
|
|
||||||
common_prefix_len = common_prefix_len
|
common_prefix_len = common_prefix_len
|
||||||
block_table_tensor = common_attn_metadata.block_table_tensor
|
block_table_tensor = common_attn_metadata.block_table_tensor
|
||||||
slot_mapping = common_attn_metadata.slot_mapping
|
slot_mapping = common_attn_metadata.slot_mapping
|
||||||
|
|
||||||
max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
|
|
||||||
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
query_start_loc_host = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1]
|
||||||
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
|
query_start_loc = common_attn_metadata.query_start_loc_cpu[: num_reqs + 1].to(
|
||||||
self.device, non_blocking=True
|
self.device, non_blocking=True
|
||||||
@@ -771,7 +771,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
# not cached. This happens during the initial memory
|
# not cached. This happens during the initial memory
|
||||||
value = value.contiguous()
|
value = value.contiguous()
|
||||||
if key_cache.is_contiguous():
|
if key_cache.is_contiguous():
|
||||||
xtorch_ops.reshape_and_cache(
|
kunlun_ops.reshape_and_cache(
|
||||||
key[: attn_metadata.num_actual_tokens],
|
key[: attn_metadata.num_actual_tokens],
|
||||||
value[: attn_metadata.num_actual_tokens],
|
value[: attn_metadata.num_actual_tokens],
|
||||||
key_cache,
|
key_cache,
|
||||||
@@ -781,7 +781,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
else:
|
else:
|
||||||
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
cast_key_cache = key_cache.squeeze(1).unsqueeze(-2)
|
||||||
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
cast_value_cache = value_cache.squeeze(1).unsqueeze(-2)
|
||||||
xtorch_ops.reshape_and_cache_flash(
|
kunlun_ops.reshape_and_cache_flash(
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
cast_key_cache,
|
cast_key_cache,
|
||||||
@@ -791,7 +791,6 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
|
|
||||||
assert attn_type == AttentionType.DECODER
|
assert attn_type == AttentionType.DECODER
|
||||||
# Decoder self-attention supports chunked prefill.
|
# Decoder self-attention supports chunked prefill.
|
||||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||||
# Only enforce this shape-constraint for decoder
|
# Only enforce this shape-constraint for decoder
|
||||||
# self-attention
|
# self-attention
|
||||||
@@ -811,7 +810,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
|
|
||||||
# Prefix cache
|
# Prefix cache
|
||||||
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
if prefill_meta.query_start_loc_host[-1] != prefill_meta.kv_lod_cpu[-1]:
|
||||||
xtorch_ops.prefill_attention(
|
kunlun_ops.prefill_attention(
|
||||||
q=prefill_query,
|
q=prefill_query,
|
||||||
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
k=key_cache, # Key Cache [block_num, head, block_size, dim]
|
||||||
v=value_cache,
|
v=value_cache,
|
||||||
@@ -827,7 +826,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
softmax_lse=None,
|
softmax_lse=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
xtorch_ops.prefill_attention(
|
kunlun_ops.prefill_attention(
|
||||||
q=prefill_query,
|
q=prefill_query,
|
||||||
k=prefill_key,
|
k=prefill_key,
|
||||||
v=prefill_value,
|
v=prefill_value,
|
||||||
@@ -860,9 +859,9 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
decode_meta.block_tables * 2
|
decode_meta.block_tables * 2
|
||||||
) # only test in Qwen3-Next
|
) # only test in Qwen3-Next
|
||||||
|
|
||||||
sig = inspect.signature(xtorch_ops.speculative_attention)
|
sig = inspect.signature(kunlun_ops.speculative_attention)
|
||||||
if "max_window_size" in sig.parameters:
|
if "max_window_size" in sig.parameters:
|
||||||
xtorch_ops.speculative_attention(
|
kunlun_ops.speculative_attention(
|
||||||
out=output[:num_decode_tokens],
|
out=output[:num_decode_tokens],
|
||||||
# Only MLA support q len > 1 right now
|
# Only MLA support q len > 1 right now
|
||||||
q=decode_query.unsqueeze(0),
|
q=decode_query.unsqueeze(0),
|
||||||
@@ -890,7 +889,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif not attn_metadata.is_speculative:
|
elif not attn_metadata.is_speculative:
|
||||||
xtorch_ops.paged_attention(
|
kunlun_ops.paged_attention(
|
||||||
x=decode_query,
|
x=decode_query,
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
@@ -910,7 +909,7 @@ class KunlunAttentionImpl(AttentionImpl[KunlunMetadata]):
|
|||||||
out = output[:num_decode_tokens]
|
out = output[:num_decode_tokens]
|
||||||
assert out.is_contiguous()
|
assert out.is_contiguous()
|
||||||
|
|
||||||
xtorch_ops.speculative_attention(
|
kunlun_ops.speculative_attention(
|
||||||
out=out.view(batch_size, qlen, head_num, self.head_size),
|
out=out.view(batch_size, qlen, head_num, self.head_size),
|
||||||
q=decode_query.view(batch_size, qlen, head_num, head_dim),
|
q=decode_query.view(batch_size, qlen, head_num, head_dim),
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
|||||||
infer_global_hyperparameters,
|
infer_global_hyperparameters,
|
||||||
split_decodes_and_prefills)
|
split_decodes_and_prefills)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||||
@@ -1106,7 +1106,7 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
|
|||||||
) * q_len
|
) * q_len
|
||||||
sorted_tokens_idx = torch.arange(
|
sorted_tokens_idx = torch.arange(
|
||||||
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||||
xtorch_ops.mla_bmm_I8(
|
kunlun_ops.mla_bmm_I8(
|
||||||
x.contiguous(), # [1, 16, 512] torch.float16
|
x.contiguous(), # [1, 16, 512] torch.float16
|
||||||
self.W_UV, # [16, 128, 512] torch.int8
|
self.W_UV, # [16, 128, 512] torch.int8
|
||||||
self.W_UV_SCALE, # [2048, 1] torch.float32
|
self.W_UV_SCALE, # [2048, 1] torch.float32
|
||||||
@@ -1220,7 +1220,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
tp_q_head_num=q.size(1)
|
tp_q_head_num=q.size(1)
|
||||||
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
|
softmax_lse = torch.zeros(tp_q_head_num, q.size(0), dtype=torch.float32, device=q.device)
|
||||||
softmax_lse.fill_(float('-inf'))
|
softmax_lse.fill_(float('-inf'))
|
||||||
xtorch_ops.attention(
|
kunlun_ops.attention(
|
||||||
q=q,
|
q=q,
|
||||||
k_cache=k,
|
k_cache=k,
|
||||||
v_cache=maybe_padded_v,
|
v_cache=maybe_padded_v,
|
||||||
@@ -1406,7 +1406,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
self.W_UK_T = W_UK.transpose(1, 2).contiguous()
|
self.W_UK_T = W_UK.transpose(1, 2).contiguous()
|
||||||
self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
|
self.W_UK_SCALE = torch.empty([W_UK.shape[0] * W_UK.shape[2], 1],
|
||||||
dtype=torch.float, device=kv_b_proj_weight.device)
|
dtype=torch.float, device=kv_b_proj_weight.device)
|
||||||
xtorch_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
|
kunlun_ops.quant2d(w_uk_dq_trans, self.W_UK_T, self.W_UK_SCALE)
|
||||||
self.W_UV = W_UV.contiguous()
|
self.W_UV = W_UV.contiguous()
|
||||||
self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
|
self.W_UV_SCALE = W_UV_SCALE.contiguous().reshape(-1, 1)
|
||||||
else:
|
else:
|
||||||
@@ -1836,7 +1836,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
|
|
||||||
# write the latent and rope to kv cache
|
# write the latent and rope to kv cache
|
||||||
if kv_cache.numel() > 0:
|
if kv_cache.numel() > 0:
|
||||||
xtorch_ops.concat_and_cache_mla(
|
kunlun_ops.concat_and_cache_mla(
|
||||||
k_c_normed,
|
k_c_normed,
|
||||||
k_pe.squeeze(1),
|
k_pe.squeeze(1),
|
||||||
attn_metadata.slot_mapping.flatten(),
|
attn_metadata.slot_mapping.flatten(),
|
||||||
@@ -1885,7 +1885,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
|
|||||||
sorted_tokens_idx = torch.arange(
|
sorted_tokens_idx = torch.arange(
|
||||||
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
self.num_heads * q_len, dtype=torch.int, device="cuda")
|
||||||
extra_params = {"trans": False}
|
extra_params = {"trans": False}
|
||||||
xtorch_ops.mla_bmm_I8(
|
kunlun_ops.mla_bmm_I8(
|
||||||
decode_q_nope.contiguous(),
|
decode_q_nope.contiguous(),
|
||||||
self.W_UK_T,
|
self.W_UK_T,
|
||||||
self.W_UK_SCALE,
|
self.W_UK_SCALE,
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from packaging import version
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
import xtorch_ops
|
import kunlun_ops
|
||||||
import os
|
import os
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -200,16 +200,16 @@ def flashinfer_sample(
|
|||||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||||
if k is None:
|
if k is None:
|
||||||
# Top-p only.
|
# Top-p only.
|
||||||
next_token_ids = xtorch_ops.top_p_sampling_from_probs(
|
next_token_ids = kunlun_ops.top_p_sampling_from_probs(
|
||||||
probs,top_p=p, deterministic=True)
|
probs,top_p=p, deterministic=True)
|
||||||
elif p is None:
|
elif p is None:
|
||||||
# Top-k only.
|
# Top-k only.
|
||||||
next_token_ids = xtorch_ops.top_k_sampling_from_probs(
|
next_token_ids = kunlun_ops.top_k_sampling_from_probs(
|
||||||
probs, top_k=k, deterministic=True)
|
probs, top_k=k, deterministic=True)
|
||||||
else:
|
else:
|
||||||
# Both top-k and top-p.
|
# Both top-k and top-p.
|
||||||
k = k.to(torch.int32)
|
k = k.to(torch.int32)
|
||||||
next_token_ids = xtorch_ops.top_k_top_p_sampling_from_probs(
|
next_token_ids = kunlun_ops.top_k_top_p_sampling_from_probs(
|
||||||
probs, top_k=k, top_p=p, deterministic=True)
|
probs, top_k=k, top_p=p, deterministic=True)
|
||||||
|
|
||||||
return next_token_ids.view(-1)
|
return next_token_ids.view(-1)
|
||||||
|
|||||||
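For orientation, top-p (nucleus) sampling from a row-normalized probability tensor can be expressed in plain PyTorch as below. This is a generic sketch of the technique, not the kunlun_ops kernel's implementation:

import torch

def top_p_sample_reference(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    # probs: [batch, vocab], each row already softmax-normalized.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumprobs = sorted_probs.cumsum(dim=-1)
    # Keep the smallest prefix whose cumulative mass reaches top_p,
    # always retaining the highest-probability token.
    keep = (cumprobs - sorted_probs) < top_p
    keep[..., 0] = True
    masked = sorted_probs * keep
    masked = masked / masked.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(masked, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)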
@@ -405,7 +405,7 @@ def add_rmsnorm(
|
|||||||
residual_output: torch.Tensor = None,
|
residual_output: torch.Tensor = None,
|
||||||
output_max: torch.Tensor = None,
|
output_max: torch.Tensor = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.add_rmsnorm(
|
kunlun_ops.add_rmsnorm(
|
||||||
x,
|
x,
|
||||||
y, # originally written as residual; it is actually y here
|
y, # originally written as residual; it is actually y here
|
||||||
residual_output=residual_output,
|
residual_output=residual_output,
|
||||||
@@ -429,7 +429,7 @@ def add_rmsnorm_cuda(
|
|||||||
residual_output: torch.Tensor = None,
|
residual_output: torch.Tensor = None,
|
||||||
output_max: torch.Tensor = None,
|
output_max: torch.Tensor = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.add_rmsnorm(
|
kunlun_ops.add_rmsnorm(
|
||||||
x,
|
x,
|
||||||
y,
|
y,
|
||||||
residual_output=residual_output,
|
residual_output=residual_output,
|
||||||
@@ -451,7 +451,7 @@ def rmsnorm(
|
|||||||
residual_output: torch.Tensor = None,
|
residual_output: torch.Tensor = None,
|
||||||
output_max: torch.Tensor = None,
|
output_max: torch.Tensor = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.rmsnorm(
|
kunlun_ops.rmsnorm(
|
||||||
x,
|
x,
|
||||||
weight,
|
weight,
|
||||||
output,
|
output,
|
||||||
@@ -471,7 +471,7 @@ def rmsnorm_cuda(
|
|||||||
residual_output: torch.Tensor = None,
|
residual_output: torch.Tensor = None,
|
||||||
output_max: torch.Tensor = None,
|
output_max: torch.Tensor = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.rmsnorm(
|
kunlun_ops.rmsnorm(
|
||||||
x,
|
x,
|
||||||
weight,
|
weight,
|
||||||
output,
|
output,
|
||||||
@@ -541,7 +541,7 @@ def split_norm_rope_neox(
|
|||||||
rotary_dim: int,
|
rotary_dim: int,
|
||||||
emb_batch_size: int = 1,
|
emb_batch_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.split_norm_rope_neox(
|
kunlun_ops.split_norm_rope_neox(
|
||||||
q_emb,
|
q_emb,
|
||||||
k_emb,
|
k_emb,
|
||||||
v_out,
|
v_out,
|
||||||
@@ -577,7 +577,7 @@ def split_norm_rope_neox_cuda(
|
|||||||
rotary_dim: int,
|
rotary_dim: int,
|
||||||
emb_batch_size: int = 1,
|
emb_batch_size: int = 1,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.split_norm_rope_neox(
|
kunlun_ops.split_norm_rope_neox(
|
||||||
q_emb,
|
q_emb,
|
||||||
k_emb,
|
k_emb,
|
||||||
v_out,
|
v_out,
|
||||||
@@ -649,7 +649,7 @@ if hasattr(torch.ops.custom_ops, "fc_fusion"):
|
|||||||
def silu_and_mul(
|
def silu_and_mul(
|
||||||
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
|
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.swiglu(
|
kunlun_ops.swiglu(
|
||||||
x=x,
|
x=x,
|
||||||
y=out,
|
y=out,
|
||||||
)
|
)
|
||||||
@@ -659,7 +659,7 @@ def silu_and_mul(
|
|||||||
def silu_and_mul_cuda(
|
def silu_and_mul_cuda(
|
||||||
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
|
out: torch.Tensor, x: torch.Tensor, axis: int = -1, turn: bool = True
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.swiglu(
|
kunlun_ops.swiglu(
|
||||||
x=x,
|
x=x,
|
||||||
y=out,
|
y=out,
|
||||||
)
|
)
|
||||||
@@ -736,7 +736,7 @@ def moe_softmax_topk(
|
|||||||
axis: int = -1,
|
axis: int = -1,
|
||||||
turn: bool = True,
|
turn: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
||||||
|
|
||||||
|
|
||||||
@impl("_C::moe_softmax_topk", "CUDA")
|
@impl("_C::moe_softmax_topk", "CUDA")
|
||||||
@@ -748,7 +748,7 @@ def moe_softmax_topk_cuda(
|
|||||||
axis: int = -1,
|
axis: int = -1,
|
||||||
turn: bool = True,
|
turn: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
kunlun_ops.moe_softmax_topk(x, normed_score, topk_index, block_statistic)
|
||||||
|
|
||||||
|
|
||||||
def _fake_moe_softmax_topk(
|
def _fake_moe_softmax_topk(
|
||||||
@@ -781,7 +781,7 @@ def moe_ffn_block(
|
|||||||
w1_bias: Optional[torch.Tensor] = None,
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_ffn_block(
|
kunlun_ops.moe_ffn_block(
|
||||||
x=x,
|
x=x,
|
||||||
gate_w=gate_w,
|
gate_w=gate_w,
|
||||||
inter_w=inter_w,
|
inter_w=inter_w,
|
||||||
@@ -812,7 +812,7 @@ def moe_ffn_block_cuda(
|
|||||||
w1_bias: Optional[torch.Tensor] = None,
|
w1_bias: Optional[torch.Tensor] = None,
|
||||||
w2_bias: Optional[torch.Tensor] = None,
|
w2_bias: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_ffn_block(
|
kunlun_ops.moe_ffn_block(
|
||||||
x=x,
|
x=x,
|
||||||
gate_w=gate_w,
|
gate_w=gate_w,
|
||||||
inter_w=inter_w,
|
inter_w=inter_w,
|
||||||
@@ -863,7 +863,7 @@ def moe_ffn_per_token_block(
|
|||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
ep_rank: int = 0,
|
ep_rank: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_ffn_per_token_block(
|
kunlun_ops.moe_ffn_per_token_block(
|
||||||
x=x,
|
x=x,
|
||||||
inter_weight=inter_weight,
|
inter_weight=inter_weight,
|
||||||
inter_scale=inter_scale,
|
inter_scale=inter_scale,
|
||||||
@@ -897,7 +897,7 @@ def moe_ffn_per_token_block_cuda(
|
|||||||
ep_size: int = 1,
|
ep_size: int = 1,
|
||||||
ep_rank: int = 0,
|
ep_rank: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_ffn_per_token_block(
|
kunlun_ops.moe_ffn_per_token_block(
|
||||||
x=x,
|
x=x,
|
||||||
inter_weight=inter_weight,
|
inter_weight=inter_weight,
|
||||||
inter_scale=inter_scale,
|
inter_scale=inter_scale,
|
||||||
@@ -948,7 +948,7 @@ def rotary_embedding(
|
|||||||
cos_sin_cache: torch.Tensor,
|
cos_sin_cache: torch.Tensor,
|
||||||
is_neox: bool,
|
is_neox: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.rotary_embedding(
|
kunlun_ops.rotary_embedding(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
query=query,
|
query=query,
|
||||||
key=key,
|
key=key,
|
||||||
@@ -967,7 +967,7 @@ def rotary_embedding_cuda(
|
|||||||
cos_sin_cache: torch.Tensor,
|
cos_sin_cache: torch.Tensor,
|
||||||
is_neox: bool,
|
is_neox: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.rotary_embedding(
|
kunlun_ops.rotary_embedding(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
query=query,
|
query=query,
|
||||||
key=key,
|
key=key,
|
||||||
@@ -999,7 +999,7 @@ def gemm_I8_I8_bf16_nt(
|
|||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.gemm_I8_I8_bf16_nt(
|
kunlun_ops.gemm_I8_I8_bf16_nt(
|
||||||
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
|
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1012,7 +1012,7 @@ def gemm_I8_I8_bf16_nt_cuda(
|
|||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.gemm_I8_I8_bf16_nt(
|
kunlun_ops.gemm_I8_I8_bf16_nt(
|
||||||
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
|
lhs=(x_q, x_scale), rhs=(weight, weight_scale), out=out
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1038,7 +1038,7 @@ def moe_softmax_topk_norm(
|
|||||||
block_statistic: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
stable: bool = True,
|
stable: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_softmax_topk_norm(
|
kunlun_ops.moe_softmax_topk_norm(
|
||||||
x, normed_score, topk_index, block_statistic, stable
|
x, normed_score, topk_index, block_statistic, stable
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1051,7 +1051,7 @@ def moe_softmax_topk_norm_cuda(
|
|||||||
block_statistic: torch.Tensor,
|
block_statistic: torch.Tensor,
|
||||||
stable: bool = True,
|
stable: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_softmax_topk_norm(
|
kunlun_ops.moe_softmax_topk_norm(
|
||||||
x, normed_score, topk_index, block_statistic, stable
|
x, normed_score, topk_index, block_statistic, stable
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1071,14 +1071,14 @@ moe_softmax_topk_norm.register_fake(_fake_moe_softmax_topk_norm)
|
|||||||
|
|
||||||
@custom_op("_C::gen_block_statistic", mutates_args=())
|
@custom_op("_C::gen_block_statistic", mutates_args=())
|
||||||
def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None:
|
def gen_block_statistic(topk_ids: torch.Tensor, block_statistic: torch.Tensor) -> None:
|
||||||
xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
|
kunlun_ops.gen_block_statistic(topk_ids, block_statistic)
|
||||||
|
|
||||||
|
|
||||||
@impl("_C::gen_block_statistic", "CUDA")
|
@impl("_C::gen_block_statistic", "CUDA")
|
||||||
def gen_block_statistic_cuda(
|
def gen_block_statistic_cuda(
|
||||||
topk_ids: torch.Tensor, block_statistic: torch.Tensor
|
topk_ids: torch.Tensor, block_statistic: torch.Tensor
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.gen_block_statistic(topk_ids, block_statistic)
|
kunlun_ops.gen_block_statistic(topk_ids, block_statistic)
|
||||||
|
|
||||||
|
|
||||||
def fake_gen_block_statistic(
|
def fake_gen_block_statistic(
|
||||||
@@ -1101,7 +1101,7 @@ def moe_pre_sorted(
|
|||||||
sorted_tokens_num_lod: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
index_have_neg: bool = False,
|
index_have_neg: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_pre_sorted(
|
kunlun_ops.moe_pre_sorted(
|
||||||
x,
|
x,
|
||||||
topk_index,
|
topk_index,
|
||||||
block_statistic,
|
block_statistic,
|
||||||
@@ -1123,7 +1123,7 @@ def moe_pre_sorted_cuda(
|
|||||||
sorted_tokens_num_lod: torch.Tensor,
|
sorted_tokens_num_lod: torch.Tensor,
|
||||||
index_have_neg: bool = False,
|
index_have_neg: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_pre_sorted(
|
kunlun_ops.moe_pre_sorted(
|
||||||
x,
|
x,
|
||||||
topk_index,
|
topk_index,
|
||||||
block_statistic,
|
block_statistic,
|
||||||
@@ -1171,7 +1171,7 @@ def moe_fc(
|
|||||||
use_pack_int4: Optional[bool] = False,
|
use_pack_int4: Optional[bool] = False,
|
||||||
sort_mode: Optional[bool] = True,
|
sort_mode: Optional[bool] = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_fc(
|
kunlun_ops.moe_fc(
|
||||||
x=x,
|
x=x,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
@@ -1214,7 +1214,7 @@ def moe_fc_cuda(
|
|||||||
use_pack_int4: Optional[bool] = False,
|
use_pack_int4: Optional[bool] = False,
|
||||||
sort_mode: Optional[bool] = True,
|
sort_mode: Optional[bool] = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_fc(
|
kunlun_ops.moe_fc(
|
||||||
x=x,
|
x=x,
|
||||||
weight=weight,
|
weight=weight,
|
||||||
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
sorted_tokens_num_lod=sorted_tokens_num_lod,
|
||||||
@@ -1270,7 +1270,7 @@ def moe_post(
|
|||||||
dequant_scale: torch.Tensor,
|
dequant_scale: torch.Tensor,
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
||||||
|
|
||||||
|
|
||||||
@impl("_C::moe_post", "CUDA")
|
@impl("_C::moe_post", "CUDA")
|
||||||
@@ -1281,7 +1281,7 @@ def moe_post_cuda(
|
|||||||
dequant_scale: torch.Tensor,
|
dequant_scale: torch.Tensor,
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
kunlun_ops.moe_post(x, moe_index, normed_scale, dequant_scale, y)
|
||||||
|
|
||||||
|
|
||||||
def fake_moe_post(
|
def fake_moe_post(
|
||||||
@@ -1308,7 +1308,7 @@ def moe_sigmoid_group_topk_norm(
|
|||||||
n_group: int,
|
n_group: int,
|
||||||
topk_group: int,
|
topk_group: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_sigmoid_group_topk_norm(
|
kunlun_ops.moe_sigmoid_group_topk_norm(
|
||||||
x=x,
|
x=x,
|
||||||
norm_score=norm_score,
|
norm_score=norm_score,
|
||||||
topk_index=topk_index,
|
topk_index=topk_index,
|
||||||
@@ -1331,7 +1331,7 @@ def moe_sigmoid_group_topk_norm_cuda(
|
|||||||
n_group: int,
|
n_group: int,
|
||||||
topk_group: int,
|
topk_group: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.moe_sigmoid_group_topk_norm(
|
kunlun_ops.moe_sigmoid_group_topk_norm(
|
||||||
x=x,
|
x=x,
|
||||||
norm_score=norm_score,
|
norm_score=norm_score,
|
||||||
topk_index=topk_index,
|
topk_index=topk_index,
|
||||||
@@ -1376,7 +1376,7 @@ def awq_dequantize(
|
|||||||
device=qweight.device,
|
device=qweight.device,
|
||||||
)
|
)
|
||||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||||
xtorch_ops.awq_dequantize(
|
kunlun_ops.awq_dequantize(
|
||||||
qweight=qweight,
|
qweight=qweight,
|
||||||
scales=scales,
|
scales=scales,
|
||||||
zeros=zeros,
|
zeros=zeros,
|
||||||
@@ -1402,7 +1402,7 @@ def awq_dequantize_cuda(
|
|||||||
device=qweight.device,
|
device=qweight.device,
|
||||||
)
|
)
|
||||||
group_m = int(qweight.shape[0] / scales.shape[0])
|
group_m = int(qweight.shape[0] / scales.shape[0])
|
||||||
xtorch_ops.awq_dequantize(
|
out = kunlun_ops.awq_dequantize(
|
||||||
qweight=qweight,
|
qweight=qweight,
|
||||||
scales=scales,
|
scales=scales,
|
||||||
zeros=zeros,
|
zeros=zeros,
|
||||||
@@ -1447,7 +1447,7 @@ def awq_gemm(
|
|||||||
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
||||||
)
|
)
|
||||||
group_size = int(qweight.shape[0] / scale.shape[0])
|
group_size = int(qweight.shape[0] / scale.shape[0])
|
||||||
xtorch_ops.awq_gemm(
|
kunlun_ops.awq_gemm(
|
||||||
x=x,
|
x=x,
|
||||||
w=qweight,
|
w=qweight,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
@@ -1471,7 +1471,7 @@ def awq_gemm_cuda(
|
|||||||
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
(x.shape[0], qweight.shape[1] * 8), dtype=torch.float16, device=x.device
|
||||||
)
|
)
|
||||||
group_size = int(qweight.shape[0] / scale.shape[0])
|
group_size = int(qweight.shape[0] / scale.shape[0])
|
||||||
xtorch_ops.awq_gemm(
|
kunlun_ops.awq_gemm(
|
||||||
x=x,
|
x=x,
|
||||||
w=qweight,
|
w=qweight,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
@@ -1508,7 +1508,7 @@ def gptq_shuffle(
|
|||||||
q_perm: torch.Tensor,
|
q_perm: torch.Tensor,
|
||||||
bit: int,
|
bit: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||||
|
|
||||||
|
|
||||||
@impl("_C::gptq_shuffle", "CUDA")
|
@impl("_C::gptq_shuffle", "CUDA")
|
||||||
@@ -1517,7 +1517,7 @@ def gptq_shuffle_cuda(
|
|||||||
q_perm: torch.Tensor,
|
q_perm: torch.Tensor,
|
||||||
bit: int,
|
bit: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
kunlun_ops.gptq_shuffle(weight=q_weight, perm=q_perm, bit=bit)
|
||||||
|
|
||||||
|
|
||||||
def _fake_gptq_shuffle(
|
def _fake_gptq_shuffle(
|
||||||
@@ -1541,7 +1541,7 @@ def concat_and_cache_mla(
|
|||||||
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||||
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.concat_and_cache_mla(
|
kunlun_ops.concat_and_cache_mla(
|
||||||
kv_c=kv_c,
|
kv_c=kv_c,
|
||||||
k_pe=k_pe,
|
k_pe=k_pe,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
@@ -1556,7 +1556,7 @@ def concat_and_cache_mla_cuda(
|
|||||||
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
kv_cache: torch.Tensor, # [num_blocks, block_size, (kv_lora_rank + pe_dim)]
|
||||||
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
slot_mapping: torch.Tensor, # [num_tokens] or [num_actual_tokens]
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.concat_and_cache_mla(
|
kunlun_ops.concat_and_cache_mla(
|
||||||
kv_c=kv_c,
|
kv_c=kv_c,
|
||||||
k_pe=k_pe,
|
k_pe=k_pe,
|
||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
@@ -1598,7 +1598,7 @@ def scaled_int8_quant(
|
|||||||
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
||||||
if symmetric:
|
if symmetric:
|
||||||
# NOTE: For quant2d ops, scale represents max.
|
# NOTE: For quant2d ops, scale represents max.
|
||||||
xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
||||||
else:
|
else:
|
||||||
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
||||||
x_q, x.contiguous(), scale, azp
|
x_q, x.contiguous(), scale, azp
|
||||||
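The note above matters because symmetric int8 quantization is conventionally parameterized by scale = max|x| / 127, whereas these kernels take the max-abs value itself (hence the earlier w13/w2 weight scales being multiplied by 127). A plain-PyTorch sketch of that assumed convention:

import torch

def symmetric_int8_quant_reference(x: torch.Tensor):
    # Per-tensor symmetric quantization; the returned "scale" is max|x|, not max|x|/127.
    max_abs = x.abs().amax().to(torch.float32)
    x_q = torch.clamp(torch.round(x.to(torch.float32) * 127.0 / max_abs), -127, 127).to(torch.int8)
    return x_q, max_abs  # dequantize with x_q.float() * max_abs / 127.0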
@@ -1625,7 +1625,7 @@ def scaled_int8_quant_cuda(
|
|||||||
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
azp = None if symmetric else torch.empty_like(scale, dtype=torch.int32)
|
||||||
if symmetric:
|
if symmetric:
|
||||||
# NOTE: For quant2d ops, scale represents max.
|
# NOTE: For quant2d ops, scale represents max.
|
||||||
xtorch_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
kunlun_ops.quant2d(x=x.contiguous(), y=x_q, max=scale, force_sdnn=True)
|
||||||
else:
|
else:
|
||||||
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
torch.ops.xspeedgate_ops.dynamic_scaled_int8_quant(
|
||||||
x_q, x.contiguous(), scale, azp
|
x_q, x.contiguous(), scale, azp
|
||||||
@@ -1777,7 +1777,7 @@ def matmul(
|
|||||||
dtype=out_dtype,
|
dtype=out_dtype,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
)
|
)
|
||||||
xtorch_ops.matmul(
|
kunlun_ops.matmul(
|
||||||
x=x.contiguous(),
|
x=x.contiguous(),
|
||||||
w=w.contiguous(),
|
w=w.contiguous(),
|
||||||
out=out,
|
out=out,
|
||||||
@@ -1814,7 +1814,7 @@ def matmul_cuda(
|
|||||||
dtype=out_dtype,
|
dtype=out_dtype,
|
||||||
device=x.device,
|
device=x.device,
|
||||||
)
|
)
|
||||||
xtorch_ops.matmul(
|
kunlun_ops.matmul(
|
||||||
x=x.contiguous(),
|
x=x.contiguous(),
|
||||||
w=w.contiguous(),
|
w=w.contiguous(),
|
||||||
out=out,
|
out=out,
|
||||||
@@ -1865,7 +1865,7 @@ def quant2d(
|
|||||||
max: torch.Tensor,
|
max: torch.Tensor,
|
||||||
force_sdnn: bool = False,
|
force_sdnn: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.quant2d(
|
kunlun_ops.quant2d(
|
||||||
x=x,
|
x=x,
|
||||||
y=x_q,
|
y=x_q,
|
||||||
max=max,
|
max=max,
|
||||||
@@ -1880,7 +1880,7 @@ def quant2d_cuda(
|
|||||||
max: torch.Tensor,
|
max: torch.Tensor,
|
||||||
force_sdnn: bool = False,
|
force_sdnn: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.quant2d(
|
kunlun_ops.quant2d(
|
||||||
x=x,
|
x=x,
|
||||||
y=x_q,
|
y=x_q,
|
||||||
max=max,
|
max=max,
|
||||||
@@ -1954,7 +1954,7 @@ def I8_mqa_logits(
|
|||||||
is_causal: Optional[bool] = False,
|
is_causal: Optional[bool] = False,
|
||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.I8_mqa_logits(
|
kunlun_ops.I8_mqa_logits(
|
||||||
q=q,
|
q=q,
|
||||||
fused_kv_cache=fused_kv_cache,
|
fused_kv_cache=fused_kv_cache,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@@ -1984,7 +1984,7 @@ def I8_mqa_logits_cuda(
|
|||||||
is_causal: Optional[bool] = False,
|
is_causal: Optional[bool] = False,
|
||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.I8_mqa_logits(
|
kunlun_ops.I8_mqa_logits(
|
||||||
q=q,
|
q=q,
|
||||||
fused_kv_cache=fused_kv_cache,
|
fused_kv_cache=fused_kv_cache,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@@ -2034,7 +2034,8 @@ def I8_paged_mqa_logits(
|
|||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.I8_paged_mqa_logits(
|
kunlun_ops.I8_paged_mqa_logits(
|
||||||
q=q,
|
q=q,
|
||||||
fused_kv_cache=fused_kv_cache,
|
fused_kv_cache=fused_kv_cache,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@@ -2060,7 +2061,7 @@ def I8_paged_mqa_logits_cuda(
|
|||||||
out: torch.Tensor,
|
out: torch.Tensor,
|
||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.I8_paged_mqa_logits(
|
kunlun_ops.I8_paged_mqa_logits(
|
||||||
q=q,
|
q=q,
|
||||||
fused_kv_cache=fused_kv_cache,
|
fused_kv_cache=fused_kv_cache,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
@@ -2111,7 +2112,7 @@ def sparse_prefill_fwd_opt(
|
|||||||
is_causal: Optional[bool] = True,
|
is_causal: Optional[bool] = True,
|
||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.sparse_prefill_fwd_opt(
|
kunlun_ops.sparse_prefill_fwd_opt(
|
||||||
q=q,
|
q=q,
|
||||||
kv=kv,
|
kv=kv,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
@@ -2147,7 +2148,7 @@ def sparse_prefill_fwd_opt_cuda(
|
|||||||
is_causal: Optional[bool] = True,
|
is_causal: Optional[bool] = True,
|
||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.sparse_prefill_fwd_opt(
|
kunlun_ops.sparse_prefill_fwd_opt(
|
||||||
q=q,
|
q=q,
|
||||||
kv=kv,
|
kv=kv,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
@@ -2207,7 +2208,7 @@ def fwd_kvcache_mla(
|
|||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
kv_lod_xpu: Optional[torch.Tensor] = None,
|
kv_lod_xpu: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.fwd_kvcache_mla(
|
kunlun_ops.fwd_kvcache_mla(
|
||||||
q_c=q_c,
|
q_c=q_c,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
@@ -2241,7 +2242,7 @@ def fwd_kvcache_mla_cuda(
|
|||||||
use_xfa_boost: Optional[bool] = False,
|
use_xfa_boost: Optional[bool] = False,
|
||||||
kv_lod_xpu: Optional[torch.Tensor] = None,
|
kv_lod_xpu: Optional[torch.Tensor] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.fwd_kvcache_mla(
|
kunlun_ops.fwd_kvcache_mla(
|
||||||
q_c=q_c,
|
q_c=q_c,
|
||||||
kv_cache=kv_cache,
|
kv_cache=kv_cache,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
@@ -2293,7 +2294,7 @@ def dequant_int4(
|
|||||||
int4_signed: bool = True,
|
int4_signed: bool = True,
|
||||||
use_mode_fast: bool = False,
|
use_mode_fast: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.dequant_int4(
|
kunlun_ops.dequant_int4(
|
||||||
x=x,
|
x=x,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
zero=zero,
|
zero=zero,
|
||||||
@@ -2315,7 +2316,7 @@ def dequant_int4_cuda(
|
|||||||
int4_signed: bool = True,
|
int4_signed: bool = True,
|
||||||
use_mode_fast: bool = False,
|
use_mode_fast: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
xtorch_ops.dequant_int4(
|
kunlun_ops.dequant_int4(
|
||||||
x=x,
|
x=x,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
zero=zero,
|
zero=zero,
|
||||||
@@ -2350,7 +2351,10 @@ def fast_topkv2(
|
|||||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 for now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 for now"
|
||||||
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
topk_indices = kunlun_ops.fast_topkv2(
|
||||||
|
score=score,
|
||||||
|
lengths=lengths,
|
||||||
|
topk=topk)
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
@@ -2359,7 +2363,10 @@ def fast_topkv2_cuda(
|
|||||||
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
score: torch.Tensor, lengths: torch.Tensor, topk: Optional[int] = 2048
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 for now"
|
assert topk == 2048, "fast_topkv2 only supports topk = 2048 for now"
|
||||||
topk_indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
topk_indices = kunlun_ops.fast_topkv2(
|
||||||
|
score=score,
|
||||||
|
lengths=lengths,
|
||||||
|
topk=topk)
|
||||||
return topk_indices
|
return topk_indices
|
||||||
|
|
||||||
|
|
||||||
|
|||||||