[sgl-kernel][3/N]Support Expert Specialization Grouped GEMM (#11674)
This commit is contained in:
@@ -133,6 +133,7 @@ def bench_es(
|
||||
d_strides = torch.full(
|
||||
(num_groups,), c_out.stride(0), device=device, dtype=torch.int64
|
||||
)
|
||||
workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
|
||||
|
||||
def run_cutlass():
|
||||
es_fp8_blockwise_scaled_grouped_mm(
|
||||
@@ -146,6 +147,7 @@ def bench_es(
|
||||
d_strides,
|
||||
problem_sizes,
|
||||
expert_offsets[:-1],
|
||||
workspace,
|
||||
)
|
||||
|
||||
run_cutlass()
|
||||
|
||||
Reference in New Issue
Block a user