Replace torch.jit.script with torch.compile in get_masked_input_and_mask to fix benchmark underreporting (#8733)
This commit is contained in:
@@ -26,7 +26,12 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
method_has_implemented_embedding,
|
method_has_implemented_embedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedEmbeddingMethod
|
||||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
|
from sglang.srt.utils import (
|
||||||
|
cpu_has_amx_support,
|
||||||
|
get_compiler_backend,
|
||||||
|
is_cpu,
|
||||||
|
set_weight_attrs,
|
||||||
|
)
|
||||||
|
|
||||||
DEFAULT_VOCAB_PADDING_SIZE = 64
|
DEFAULT_VOCAB_PADDING_SIZE = 64
|
||||||
|
|
||||||
@@ -117,7 +122,7 @@ class VocabParallelEmbeddingShardIndices:
|
|||||||
assert self.num_added_elements <= self.num_added_elements_padded
|
assert self.num_added_elements <= self.num_added_elements_padded
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
def get_masked_input_and_mask(
|
def get_masked_input_and_mask(
|
||||||
input_: torch.Tensor,
|
input_: torch.Tensor,
|
||||||
org_vocab_start_index: int,
|
org_vocab_start_index: int,
|
||||||
@@ -126,7 +131,7 @@ def get_masked_input_and_mask(
|
|||||||
added_vocab_start_index: int,
|
added_vocab_start_index: int,
|
||||||
added_vocab_end_index: int,
|
added_vocab_end_index: int,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
# torch.jit.script will fuse all of the pointwise ops below
|
# torch.compile will fuse all of the pointwise ops below
|
||||||
# into a single kernel, making it very fast
|
# into a single kernel, making it very fast
|
||||||
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
|
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
|
||||||
added_vocab_mask = (input_ >= added_vocab_start_index) & (
|
added_vocab_mask = (input_ >= added_vocab_start_index) & (
|
||||||
|
|||||||
Reference in New Issue
Block a user