BLackwell cutlass mla: Add check for bad page size/block num combinations (#5431)
This commit is contained in:
@@ -74,9 +74,11 @@ def cutlass_mla_decode(
|
||||
f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
|
||||
)
|
||||
assert H == 128, f"H must be 128, but got {H}"
|
||||
# TODO: There is currently an illegal memory access issue with page size !=
|
||||
# 128. Change this when it is fixed.
|
||||
assert PAGE_SIZE == 128, f"PAGE_SIZE must be 128, but got {PAGE_SIZE}"
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
assert B_block_table == B_q
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
# TODO(kaixih@nvidia): support fp8
|
||||
assert q_nope_and_q_pe.dtype in (
|
||||
|
||||
Reference in New Issue
Block a user