[sgl-kernel][3/N]Support Expert Specialization Grouped GEMM (#11674)
This commit is contained in:
@@ -168,7 +168,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
].t() # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
|
||||
b_stack = b_stack.transpose(1, 2) # Transpose Matrix B to Column-Major.
|
||||
b_scale_stack = b_scale_stack.transpose(1, 2)
|
||||
|
||||
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
|
||||
c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
|
||||
a_strides = torch.full(
|
||||
(num_experts,), a_stack.stride(0), device=device, dtype=torch.int64
|
||||
@@ -188,6 +188,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
|
||||
d_strides,
|
||||
problem_sizes,
|
||||
expert_offsets[:-1],
|
||||
workspace,
|
||||
)
|
||||
|
||||
for g in range(num_experts):
|
||||
|
||||
Reference in New Issue
Block a user