Fix bench script causing input data to reside in L2 cache (#7739)
This commit is contained in:
@@ -205,9 +205,9 @@ def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
|
|||||||
quantiles = [0.5, 0.2, 0.8]
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
if provider == "triton":
|
if provider == "triton":
|
||||||
fn = lambda: triton_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
|
fn = lambda: triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
|
||||||
elif provider == "sglang":
|
elif provider == "sglang":
|
||||||
fn = lambda: sglang_per_token_group_quant_8bit(x.clone(), group_size, dst_dtype)
|
fn = lambda: sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
|
||||||
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user