fix typo in deep gemm benchmarking (#3991)
This commit is contained in:
@@ -211,7 +211,6 @@ def get_benchmark(tp_size):
|
|||||||
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||||
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
# 预处理数据,在计时之前完成
|
|
||||||
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
x_fp8, x_scale = per_token_cast_to_fp8(x)
|
||||||
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
y_fp8, y_scale = per_block_cast_to_fp8(y)
|
||||||
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
|
||||||
|
|||||||
Reference in New Issue
Block a user