From 27acf63bbd37eeb82231eca611a9d2947dc74ac6 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sat, 25 Jan 2025 18:27:33 -0800
Subject: [PATCH] Use torch.compile for scaling penalty (#3133)

---
 .../benchmark_deepseekv3_moe_align_blocks.py |  1 -
 .../penalizers/repetition_penalty.py         | 24 ++++++++-----------
 .../srt/sampling/sampling_batch_info.py      | 18 ++++----------
 3 files changed, 14 insertions(+), 29 deletions(-)

diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
index d00f4985a..e2c4d8d35 100644
--- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
+++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
@@ -1,6 +1,5 @@
 import argparse
 import itertools
-import time
 
 import torch
 import triton
diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
index fcd5ff71c..0f714c548 100644
--- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
+++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py
@@ -3,11 +3,16 @@ from typing import List
 import torch
 
 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import get_compiler_backend
 
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def apply_scaling_penalties(logits, scaling_penalties):
+    logits[:] = torch.where(
+        logits > 0,
+        logits / scaling_penalties,
+        logits * scaling_penalties,
+    )
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
 
     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        if is_cuda:
-            return sampling_scaling_penalties(
-                logits, self.cumulated_repetition_penalties
-            )
-        else:
-            return torch.where(
-                logits > 0,
-                logits / self.cumulated_repetition_penalties,
-                logits * self.cumulated_repetition_penalties,
-            )
+        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
 
     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index a27ff1ad2..9521a34f4 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 
-from sglang.srt.utils import is_cuda_available
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
-
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    apply_scaling_penalties,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -386,14 +383,7 @@ class SamplingBatchInfo:
 
         # repetition
         if self.scaling_penalties is not None:
-            if is_cuda:
-                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-            else:
-                logits[:] = torch.where(
-                    logits > 0,
-                    logits / self.scaling_penalties,
-                    logits * self.scaling_penalties,
-                )
+            apply_scaling_penalties(logits, self.scaling_penalties)
 
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
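For context only (not part of the patch): below is a minimal standalone sketch of the torch.where-based scaling penalty that this change routes through torch.compile. It assumes the default torch.compile backend in place of sglang's get_compiler_backend(), and the tensor shapes and the 1.2 penalty value are made up for illustration.

    import torch

    @torch.compile(dynamic=True)
    def apply_scaling_penalties(logits, scaling_penalties):
        # Divide positive logits and multiply negative logits by the penalty,
        # so penalized tokens always become less likely. The update is in place.
        logits[:] = torch.where(
            logits > 0,
            logits / scaling_penalties,
            logits * scaling_penalties,
        )

    logits = torch.randn(2, 8)                 # hypothetical (batch, vocab) logits
    penalties = torch.full((2, 1), 1.2)        # hypothetical per-request penalty, broadcast over vocab
    apply_scaling_penalties(logits, penalties)

Fusing the comparison, division, and multiplication through torch.compile lets one code path serve all devices, which is why the patch drops the separate sgl_kernel sampling_scaling_penalties branch that was used only on CUDA.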