Fix batch invariant ops (#11368)
This commit is contained in:
@@ -77,8 +77,6 @@ def matmul_kernel_persistent(
|
||||
k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
|
||||
num_tiles = num_pid_m * num_pid_n
|
||||
|
||||
tile_id_c = start_pid - NUM_SMS
|
||||
|
||||
offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
|
||||
@@ -120,10 +118,6 @@ def matmul_kernel_persistent(
|
||||
)
|
||||
accumulator = tl.dot(a, b, accumulator)
|
||||
|
||||
tile_id_c += NUM_SMS
|
||||
pid_m, pid_n = _compute_pid(
|
||||
tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
|
||||
)
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if C_LARGE:
|
||||
@@ -137,6 +131,10 @@ def matmul_kernel_persistent(
|
||||
accumulator += bias
|
||||
if c_ptr.dtype.element_ty == tl.float8e4nv:
|
||||
c = accumulator.to(tl.float8e4nv)
|
||||
elif c_ptr.dtype.element_ty == tl.bfloat16:
|
||||
c = accumulator.to(tl.bfloat16)
|
||||
elif c_ptr.dtype.element_ty == tl.float32:
|
||||
c = accumulator.to(tl.float32)
|
||||
else:
|
||||
c = accumulator.to(tl.float16)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
Reference in New Issue
Block a user