Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
-import xtorch_ops
+import kunlun_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -104,7 +104,7 @@ def flash_mla_with_kvcache(
|
||||
is_context = False
|
||||
vo_head_dim = -1
|
||||
|
||||
-xtorch_ops.paged_attention(out,
+kunlun_ops.paged_attention(out,
|
||||
q,
|
||||
k_cache, None,
|
||||
block_table,
|
||||
@@ -149,7 +149,7 @@ def kunlun_flash_mla_with_kvcache(
|
||||
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
|
||||
"""
|
||||
assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
|
||||
-assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
+assert q.shape[1] <= 2, "kunlun_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
if indices is not None:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
-import xtorch_ops
+import kunlun_ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ def merge_attn_states(
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
|
||||
-return xtorch_ops.attention_merge_stage(
+return kunlun_ops.attention_merge_stage(
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
|
||||
Reference in New Issue
Block a user