Migrate XTorch operations to Kunlun operations (accelerating iteration) (#177)
Signed-off-by: dongxinyu03 <dongxinyu03@baidu.com>
This commit is contained in:
@@ -12,7 +12,7 @@ import torch.nn.functional as F
|
||||
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.triton_utils import tl, triton
|
||||
import xtorch_ops
|
||||
import kunlun_ops
|
||||
|
||||
|
||||
@triton.jit()
|
||||
@@ -1212,7 +1212,7 @@ def torch_causal_conv1d_update(
|
||||
tmp_hidden_states = hidden_states_new[:, :, -state_len:]
|
||||
ori_shape = tmp_hidden_states.shape
|
||||
tmp_hidden_states = tmp_hidden_states.transpose(1, 2).reshape(ori_shape)
|
||||
xtorch_ops.reshape_and_cache_flash(
|
||||
kunlun_ops.reshape_and_cache_flash(
|
||||
tmp_hidden_states,
|
||||
tmp_hidden_states,
|
||||
cast_conv_state,
|
||||
|
||||
Reference in New Issue
Block a user