Use torch.compile for scaling penalty (#3133)
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import itertools
|
import itertools
|
||||||
import time
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|||||||
@@ -3,11 +3,16 @@ from typing import List
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
|
||||||
from sglang.srt.utils import is_cuda_available
|
from sglang.srt.utils import get_compiler_backend
|
||||||
|
|
||||||
is_cuda = is_cuda_available()
|
|
||||||
if is_cuda:
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
from sgl_kernel import sampling_scaling_penalties
|
def apply_scaling_penalties(logits, scaling_penalties):
|
||||||
|
logits[:] = torch.where(
|
||||||
|
logits > 0,
|
||||||
|
logits / scaling_penalties,
|
||||||
|
logits * scaling_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
||||||
@@ -61,16 +66,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
|
|||||||
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
|
||||||
|
|
||||||
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
|
||||||
if is_cuda:
|
apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
|
||||||
return sampling_scaling_penalties(
|
|
||||||
logits, self.cumulated_repetition_penalties
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return torch.where(
|
|
||||||
logits > 0,
|
|
||||||
logits / self.cumulated_repetition_penalties,
|
|
||||||
logits * self.cumulated_repetition_penalties,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
|
||||||
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
|
||||||
|
|||||||
@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.utils import is_cuda_available
|
|
||||||
|
|
||||||
is_cuda = is_cuda_available()
|
|
||||||
if is_cuda:
|
|
||||||
from sgl_kernel import sampling_scaling_penalties
|
|
||||||
|
|
||||||
import sglang.srt.sampling.penaltylib as penaltylib
|
import sglang.srt.sampling.penaltylib as penaltylib
|
||||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||||
|
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
|
||||||
|
apply_scaling_penalties,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -386,14 +383,7 @@ class SamplingBatchInfo:
|
|||||||
|
|
||||||
# repetition
|
# repetition
|
||||||
if self.scaling_penalties is not None:
|
if self.scaling_penalties is not None:
|
||||||
if is_cuda:
|
apply_scaling_penalties(logits, self.scaling_penalties)
|
||||||
logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
|
|
||||||
else:
|
|
||||||
logits[:] = torch.where(
|
|
||||||
logits > 0,
|
|
||||||
logits / self.scaling_penalties,
|
|
||||||
logits * self.scaling_penalties,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply regex vocab_mask
|
# Apply regex vocab_mask
|
||||||
if self.vocab_mask is not None:
|
if self.vocab_mask is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user