Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -7,7 +7,7 @@ import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
import xtorch_ops
|
||||
import kunlun_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -104,7 +104,7 @@ def flash_mla_with_kvcache(
|
||||
is_context = False
|
||||
vo_head_dim = -1
|
||||
|
||||
xtorch_ops.paged_attention(out,
|
||||
kunlun_ops.paged_attention(out,
|
||||
q,
|
||||
k_cache, None,
|
||||
block_table,
|
||||
@@ -149,7 +149,7 @@ def kunlun_flash_mla_with_kvcache(
|
||||
p_sums: (batch_size, seq_len_q, num_heads_q), torch.float32.
|
||||
"""
|
||||
assert not is_fp8_kvcache, "By now, the kernel does not support uint8 kv cache."
|
||||
assert q.shape[1] <= 2, "xtorch_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
|
||||
assert q.shape[1] <= 2, "kunlun_ops.fwd_kvcache_mla only support seq_len_q <= 2 for now."
|
||||
if softmax_scale is None:
|
||||
softmax_scale = q.shape[-1] ** (-0.5)
|
||||
if indices is not None:
|
||||
|
||||
Reference in New Issue
Block a user