Support non-contiguous query input for extend/decode attention (#7462)

This commit is contained in:
YanbingJiang
2025-07-03 10:59:45 +08:00
committed by GitHub
parent 40e5cb7a9c
commit b044400dd3
4 changed files with 29 additions and 13 deletions

View File

@@ -102,9 +102,10 @@ class TestDecodeAttention(CustomTestCase):
device=device,
)
# k_buffer, v_buffer, key and value supports non-contiguous tensors
# k_buffer, v_buffer, query, key and value support non-contiguous tensors
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
q = q.transpose(0, 1).contiguous().transpose(0, 1)
key = key.transpose(0, 1).contiguous().transpose(0, 1)
value = value.transpose(0, 1).contiguous().transpose(0, 1)
torch.ops.sgl_kernel.decode_attention_cpu(

View File

@@ -123,7 +123,8 @@ class TestExtendAttention(CustomTestCase):
(b_seq_len_extend[i], H_Q, D), dtype=dtype
)
# k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
# q_extend, k_extend, v_extend, k_buffer and v_buffer support non-contiguous tensors
q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)