Clean up imports (#5467)
This commit is contained in:
@@ -93,9 +93,7 @@ class TestPerTokenGroupQuantFP8(TestFP8Base):
|
||||
A, A_quant_gt, scale_gt = self._make_A(
|
||||
M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
|
||||
)
|
||||
A_quant, scale = per_token_group_quant_fp8(
|
||||
x=A, group_size=self.group_size, dtype=self.quant_type
|
||||
)
|
||||
A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size)
|
||||
torch.testing.assert_close(scale, scale_gt)
|
||||
diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
|
||||
diff_count = (diff > 1e-5).count_nonzero()
|
||||
|
||||
@@ -3,9 +3,9 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
|
||||
B, D = a.shape
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||
a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||
# Repeat tokens to match topk
|
||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
# Also repeat the scale
|
||||
@@ -69,7 +69,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
# Activation function
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = sgl_scaled_fp8_quant(
|
||||
act_out_q, act_out_s = scaled_fp8_quant(
|
||||
act_out, use_per_token_if_dynamic=True
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user