From f024795e57c3589a63df2457d3d64771989d4ed7 Mon Sep 17 00:00:00 2001
From: YyWangCS <104079332+YyWangCS@users.noreply.github.com>
Date: Mon, 4 Aug 2025 10:02:51 +0800
Subject: [PATCH] Replace torch.jit.script with torch.compile in
 get_masked_input_and_mask to fix benchmark underreporting (#8733)

---
 python/sglang/srt/layers/vocab_parallel_embedding.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py
index ab1ced99a..66abb7541 100644
--- a/python/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import (
     method_has_implemented_embedding,
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_compiler_backend,
+    is_cpu,
+    set_weight_attrs,
+)
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices:
         assert self.num_added_elements <= self.num_added_elements_padded
 
 
-@torch.jit.script
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def get_masked_input_and_mask(
     input_: torch.Tensor,
     org_vocab_start_index: int,
@@ -126,7 +131,7 @@
     added_vocab_start_index: int,
     added_vocab_end_index: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    # torch.jit.script will fuse all of the pointwise ops below
+    # torch.compile will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
     org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
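
Reviewer note (not part of the patch): for anyone reproducing the change outside
sglang, here is a minimal runnable sketch of the pattern the diff adopts. It is a
simplification under stated assumptions: the BACKEND constant stands in for
sglang's get_compiler_backend() hardware dispatch, the function name is made up
for illustration, and the masking is reduced to a single vocab range (the real
function also handles the added-vocab range and a padding offset).

    from typing import Tuple

    import torch

    # Stand-in for sglang.srt.utils.get_compiler_backend(); "inductor" is
    # torch.compile's default backend.
    BACKEND = "inductor"


    @torch.compile(dynamic=True, backend=BACKEND)
    def masked_input_and_mask_sketch(
        input_: torch.Tensor,
        vocab_start_index: int,
        vocab_end_index: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # As in the patched function, the pointwise comparisons and arithmetic
        # below are fused into a single kernel by the compiler backend.
        vocab_mask = (input_ >= vocab_start_index) & (input_ < vocab_end_index)
        # Map in-shard token ids to shard-local offsets; out-of-shard ids go to 0.
        masked_input = vocab_mask * (input_ - vocab_start_index)
        return masked_input, ~vocab_mask


    ids = torch.tensor([3, 50, 120, 999])
    local_ids, oov_mask = masked_input_and_mask_sketch(ids, 0, 128)
    print(local_ids)  # tensor([  3,  50, 120,   0])
    print(oov_mask)   # tensor([False, False, False,  True])

dynamic=True compiles with symbolic shapes up front, so the varying batch and
sequence sizes seen at serving time do not each trigger a recompilation; the
returned mask marks out-of-shard tokens whose embeddings are zeroed after lookup.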