Restructure sgl-kernel benchmark (#10861)
This commit is contained in:
@@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import is_npu, set_weight_attrs
|
||||
|
||||
_is_npu = is_npu()
|
||||
if not _is_npu:
|
||||
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe import MoeRunnerConfig
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
|
||||
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
|
||||
w_s,
|
||||
)
|
||||
|
||||
-from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
|
||||
+from deep_gemm import fp8_m_grouped_gemm_nt_masked
|
||||
|
||||
with torch.inference_mode():
|
||||
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
|
||||
-m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
||||
+fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
|
||||
out = oe[:, :M, :]
|
||||
|
||||
self.assertTrue(
|
||||
|
||||
Reference in New Issue
Block a user