Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import xtorch_ops
|
||||
import kunlun_ops
|
||||
logger.info(f"Load custom ops library success!")
|
||||
except ImportError as e:
|
||||
logger.warning("Import error msg: %s", e.msg)
|
||||
@@ -71,7 +71,7 @@ class KunlunOps:
|
||||
):
|
||||
""" PagedAttentionV1 """
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
kunlun_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -114,7 +114,7 @@ class KunlunOps:
|
||||
):
|
||||
""" PagedAttentionV2 """
|
||||
# block_size = value_cache.shape[2]
|
||||
xtorch_ops.paged_attention(
|
||||
kunlun_ops.paged_attention(
|
||||
x=query,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -133,7 +133,7 @@ class KunlunOps:
|
||||
def silu_and_mul(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" silu and mul """
|
||||
xtorch_ops.silu_and_mul(
|
||||
kunlun_ops.silu_and_mul(
|
||||
x,
|
||||
axis=-1,
|
||||
turn=True,
|
||||
@@ -145,7 +145,7 @@ class KunlunOps:
|
||||
def quick_gelu(out: torch.Tensor,
|
||||
x: torch.Tensor):
|
||||
""" quick gelu """
|
||||
xtorch_ops.quick_gelu(
|
||||
kunlun_ops.quick_gelu(
|
||||
x,
|
||||
out=out,
|
||||
)
|
||||
@@ -159,7 +159,7 @@ class KunlunOps:
|
||||
epsilon,
|
||||
):
|
||||
"""rms_norm"""
|
||||
xtorch_ops.rmsnorm(
|
||||
kunlun_ops.rmsnorm(
|
||||
x, weight.to(torch.float32), epsilon, out=out
|
||||
)
|
||||
|
||||
@@ -172,7 +172,7 @@ class KunlunOps:
|
||||
):
|
||||
"""fused_add_rms_norm"""
|
||||
output = torch.empty_like(x)
|
||||
xtorch_ops.add_rmsnorm(
|
||||
kunlun_ops.add_rmsnorm(
|
||||
x, residual, weight.to(torch.float32), epsilon, out=output
|
||||
)
|
||||
fused_input = x + residual
|
||||
@@ -222,7 +222,7 @@ class KunlunOps:
|
||||
key_x = key.contiguous()
|
||||
query_x_dim = query_x.dim()
|
||||
assert is_neox_style
|
||||
xtorch_ops.mrotary_embedding_neox(
|
||||
kunlun_ops.mrotary_embedding_neox(
|
||||
positions,
|
||||
query_x,
|
||||
key_x,
|
||||
@@ -240,7 +240,7 @@ class KunlunOps:
|
||||
dst,
|
||||
block_mapping):
|
||||
""" swap_blocks """
|
||||
xtorch_ops.swap_blocks(
|
||||
kunlun_ops.swap_blocks(
|
||||
src,
|
||||
dst,
|
||||
block_mapping
|
||||
@@ -255,7 +255,7 @@ class KunlunOps:
|
||||
for i in range(len(key_caches)):
|
||||
key_caches[i] = key_caches[i].contiguous()
|
||||
value_caches[i] = value_caches[i].contiguous()
|
||||
xtorch_ops.copy_blocks(
|
||||
kunlun_ops.copy_blocks(
|
||||
key_caches,
|
||||
value_caches,
|
||||
block_mapping,
|
||||
@@ -272,7 +272,7 @@ class KunlunOps:
|
||||
):
|
||||
""" reshape_and_cache """
|
||||
# slot_mapping_cast = slot_mapping.to(torch.int32)
|
||||
xtorch_ops.reshape_and_cache(
|
||||
kunlun_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
@@ -308,7 +308,7 @@ class KunlunOps:
|
||||
repeat = Qh // KVh
|
||||
key = key.repeat_interleave(repeat, dim=2) # [B, T, Qh, Hd]
|
||||
value = value.repeat_interleave(repeat, dim=2)
|
||||
xtorch_ops.attention(
|
||||
kunlun_ops.attention(
|
||||
q=query,
|
||||
k_cache=key,
|
||||
v_cache=value,
|
||||
@@ -337,7 +337,7 @@ class KunlunOps:
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
||||
kunlun_ops.quant_fusedresidual_rmsnorm(x, residual, weight, bias, eps,
|
||||
out=out, out_scale=out_scale , residual_tensor=residual)
|
||||
|
||||
if residual is None:
|
||||
@@ -360,7 +360,7 @@ class KunlunOps:
|
||||
else:
|
||||
out_scale = torch.empty(12, device=x.device, dtype=torch.float)
|
||||
|
||||
xtorch_ops.quant_rmsnorm(x, weight, bias, eps,
|
||||
kunlun_ops.quant_rmsnorm(x, weight, bias, eps,
|
||||
out=out, out_scale=out_scale)
|
||||
return out, out_scale
|
||||
|
||||
@@ -388,7 +388,7 @@ class KunlunOps:
|
||||
dtype=torch.float16,
|
||||
device=weight.device)
|
||||
output_bs_shape = [-1]
|
||||
xtorch_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
kunlun_ops.smooth_quant_matmul_column_row_kernels(input_tensor,
|
||||
weight, smoother,
|
||||
input_scale,
|
||||
weight_scale,
|
||||
@@ -642,7 +642,7 @@ class KunlunOps:
|
||||
"""mla pa block"""
|
||||
output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
xtorch_ops.xft_multi_head_latent_page_attention_block(
|
||||
kunlun_ops.xft_multi_head_latent_page_attention_block(
|
||||
hidden_states,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
@@ -688,7 +688,7 @@ class KunlunOps:
|
||||
threshold: float = 20.0,
|
||||
) -> torch.Tensor:
|
||||
"""fused_gdn_gating"""
|
||||
output = xtorch_ops.fused_gdn_gating(
|
||||
output = kunlun_ops.fused_gdn_gating(
|
||||
A_log,
|
||||
a,
|
||||
dt_bias,
|
||||
@@ -713,7 +713,7 @@ class KunlunOps:
|
||||
2. Delta Rule Update: performs the recurrent update of a parallel state space model (SSM), combined with a local attention mechanism.
|
||||
'''
|
||||
|
||||
o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwd(
|
||||
o, final_state = kunlun_ops.fused_recurrent_gated_delta_rule_fwd(
|
||||
q, k, v, g, beta, scale, h0_source, output_final_state, use_qk_l2norm_in_kernel,
|
||||
cu_seqlens)
|
||||
return (o, final_state)
|
||||
Reference in New Issue
Block a user