Pass a_scale from the fp8 quant result instead of hard-coding it to 1.0f (#10241)

Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>
Co-authored-by: Jinwu Guo <641876696@qq.com>
This commit is contained in:
Rain Jiang
2025-09-10 12:56:05 -07:00
committed by GitHub
parent 91b3555d2d
commit 2286e85e77
3 changed files with 34 additions and 29 deletions

View File

@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
k,
)
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
cutlass_w4a8_moe_mm(
c1,
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
topk,
)
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
silu_and_mul(c1, intermediate)
intermediate_q = torch.empty(