Use torch.compile for scaling penalty (#3133)

This commit is contained in:
Lianmin Zheng
2025-01-25 18:27:33 -08:00
committed by GitHub
parent da6f8081f6
commit 27acf63bbd
3 changed files with 14 additions and 29 deletions

View File

@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 import torch
-from sglang.srt.utils import is_cuda_available
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    apply_scaling_penalties,
+)
 logger = logging.getLogger(__name__)
@@ -386,14 +383,7 @@ class SamplingBatchInfo:
         # repetition
         if self.scaling_penalties is not None:
-            if is_cuda:
-                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-            else:
-                logits[:] = torch.where(
-                    logits > 0,
-                    logits / self.scaling_penalties,
-                    logits * self.scaling_penalties,
-                )
+            apply_scaling_penalties(logits, self.scaling_penalties)
         # Apply regex vocab_mask
         if self.vocab_mask is not None: