Fix batch invariant ops (#11368)

This commit is contained in:
Stefan He
2025-10-10 20:49:08 -07:00
committed by GitHub
parent 2674c1d280
commit eae9a9fb9d
3 changed files with 168 additions and 6 deletions

View File

@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
num_tiles = num_pid_m * num_pid_n
tile_id_c = start_pid - NUM_SMS
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
)
accumulator = tl.dot(a, b, accumulator)
tile_id_c += NUM_SMS
pid_m, pid_n = _compute_pid(
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if C_LARGE:
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
accumulator += bias
if c_ptr.dtype.element_ty == tl.float8e4nv:
c = accumulator.to(tl.float8e4nv)
elif c_ptr.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif c_ptr.dtype.element_ty == tl.float32:
c = accumulator.to(tl.float32)
else:
c = accumulator.to(tl.float16)
tl.store(c_ptrs, c, mask=c_mask)