pass a_scale from fp8 quant result instead of hard code to 1.0f (#10241)
Co-authored-by: Yichen Wang <yichen.wang@bytedance.com> Co-authored-by: Jinwu Guo <641876696@qq.com>
This commit is contained in:
@@ -147,8 +147,8 @@ def cutlass_w4a8_moe(
|
||||
k,
|
||||
)
|
||||
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.half)
|
||||
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.half)
|
||||
c1 = torch.empty((m * topk, n * 2), device=device, dtype=torch.bfloat16)
|
||||
c2 = torch.zeros((m * topk, k), device=device, dtype=torch.bfloat16)
|
||||
|
||||
cutlass_w4a8_moe_mm(
|
||||
c1,
|
||||
@@ -166,7 +166,7 @@ def cutlass_w4a8_moe(
|
||||
topk,
|
||||
)
|
||||
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.half)
|
||||
intermediate = torch.empty((m * topk, n), device=device, dtype=torch.bfloat16)
|
||||
silu_and_mul(c1, intermediate)
|
||||
|
||||
intermediate_q = torch.empty(
|
||||
|
||||
Reference in New Issue
Block a user