Clean up imports (#5467)
@@ -3,9 +3,9 @@ import unittest
 
 import torch
 
-from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.test.test_utils import CustomTestCase
 
 
@@ -41,7 +41,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
 
     B, D = a.shape
     # Perform per-token quantization
-    a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+    a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
     # Repeat tokens to match topk
     a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     # Also repeat the scale
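The repeat step in this hunk duplicates each quantized token topk times so that rows line up with (token, expert) pairs in the fused MoE kernel, and the per-token scales must be duplicated the same way. A minimal shape-only sketch of that bookkeeping, using plain float tensors as stand-ins for the fp8 outputs (the scale-repeat line itself is cut off in the hunk, so a_s_rep below is an assumed form):

import torch

# Hypothetical sizes: B tokens of hidden size D, each routed to topk experts.
B, D, topk = 4, 8, 2
a_q = torch.randn(B, D)   # stand-in for the fp8-quantized activations
a_s = torch.rand(B, 1)    # stand-in for the per-token scales

# Row i of the expanded tensor corresponds to one (token, expert) pair.
a_q_rep = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
# Assumed mirror of the token repeat, so scale i still matches row i.
a_s_rep = a_s.view(B, -1).repeat(1, topk).reshape(-1, 1)

assert a_q_rep.shape == (B * topk, D)
assert a_s_rep.shape == (B * topk, 1)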
@@ -69,7 +69,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
     # Activation function
     act_out = SiluAndMul().forward_native(inter_out)
     # Quantize activation output with per-token
-    act_out_q, act_out_s = sgl_scaled_fp8_quant(
+    act_out_q, act_out_s = scaled_fp8_quant(
         act_out, use_per_token_if_dynamic=True
     )
 
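Both call sites change only the function name: the aliased sgl_scaled_fp8_quant from sglang.srt.custom_op is replaced by scaled_fp8_quant imported directly from sglang.srt.layers.quantization.fp8_kernel, with identical arguments. A minimal sketch of the renamed call in per-token dynamic mode, with the input shape and the CUDA/bfloat16 setup assumed rather than taken from this diff:

import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

# Assumed toy input: 16 tokens with hidden size 512 (the kernel targets GPU).
a = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")

# With use_per_token_if_dynamic=True, each row (token) gets its own dynamic
# scale: a_q holds the fp8 values, a_s one scale per token, as in the test.
a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)

Since the test now calls the same symbol everywhere, the custom_op alias import becomes redundant, which is what this cleanup removes.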