Support non-contiguous query input for extend/decode attention (#7462)
This commit is contained in:
@@ -102,9 +102,10 @@ class TestDecodeAttention(CustomTestCase):
|
||||
device=device,
|
||||
)
|
||||
|
||||
# k_buffer, v_buffer, key and value supports non-contiguous tensors
|
||||
# k_buffer, v_buffer, query, key and value supports non-contiguous tensors
|
||||
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
q = q.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
key = key.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
value = value.transpose(0, 1).contiguous().transpose(0, 1)
|
||||
torch.ops.sgl_kernel.decode_attention_cpu(
|
||||
|
||||
Reference in New Issue
Block a user