Support non-contiguous query input for extend/decode attention (#7462)
This commit is contained in:
@@ -102,9 +102,10 @@ class TestDecodeAttention(CustomTestCase):
             device=device,
         )

-        # k_buffer, v_buffer, key and value supports non-contiguous tensors
+        # k_buffer, v_buffer, query, key and value supports non-contiguous tensors
         k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
         v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
+        q = q.transpose(0, 1).contiguous().transpose(0, 1)
         key = key.transpose(0, 1).contiguous().transpose(0, 1)
         value = value.transpose(0, 1).contiguous().transpose(0, 1)
         torch.ops.sgl_kernel.decode_attention_cpu(
@@ -123,7 +123,8 @@ class TestExtendAttention(CustomTestCase):
                 (b_seq_len_extend[i], H_Q, D), dtype=dtype
             )

-        # k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
+        # q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
+        q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
         k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
         v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
         k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
Reference in New Issue
Block a user