Sampling penalties memory interface (#2870)

This commit is contained in:
Xiaoyu Zhang
2025-01-13 23:09:00 +08:00
committed by GitHub
parent c1e097ca66
commit d08c77c434
7 changed files with 251 additions and 41 deletions

View File

@@ -27,7 +27,7 @@ runtime_common = [
]
srt = [
"sglang[runtime_common]", "cuda-python",
"sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
"sgl-kernel>=0.0.2.post12", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
"flashinfer==0.1.6"
]

View File

@@ -3,6 +3,11 @@ from typing import List
import torch
from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
from sglang.srt.utils import is_cuda_available
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import sampling_scaling_penalties
class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]
def _apply(self, logits: torch.Tensor) -> torch.Tensor:
    """Apply the accumulated repetition penalties to a batch of logits.

    Positive logits are divided by the penalty and negative logits are
    multiplied by it, so a penalty > 1 always pushes probability down.

    Args:
        logits: Raw model logits, broadcast-compatible with
            ``self.cumulated_repetition_penalties``.

    Returns:
        The penalized logits tensor.
    """
    # NOTE: the diff rendering duplicated the pre-image torch.where return
    # ahead of this branch, which made the CUDA path unreachable dead code;
    # only the post-image logic is kept here.
    if is_cuda:
        # Fused CUDA kernel performs the divide/multiply split in one pass.
        return sampling_scaling_penalties(
            logits, self.cumulated_repetition_penalties
        )
    else:
        # Portable fallback: element-wise select between the two scalings.
        return torch.where(
            logits > 0,
            logits / self.cumulated_repetition_penalties,
            logits * self.cumulated_repetition_penalties,
        )
def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]

View File

@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
import torch
from sglang.srt.utils import is_cuda_available
is_cuda = is_cuda_available()
if is_cuda:
from sgl_kernel import sampling_scaling_penalties
import sglang.srt.sampling.penaltylib as penaltylib
logger = logging.getLogger(__name__)
@@ -245,11 +251,14 @@ class SamplingBatchInfo:
# repetition
if self.scaling_penalties is not None:
logits[:] = torch.where(
logits > 0,
logits / self.scaling_penalties,
logits * self.scaling_penalties,
)
if is_cuda:
logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
else:
logits[:] = torch.where(
logits > 0,
logits / self.scaling_penalties,
logits * self.scaling_penalties,
)
# Apply regex vocab_mask
if self.vocab_mask is not None:

View File

@@ -97,6 +97,10 @@ def is_flashinfer_available():
return torch.cuda.is_available() and torch.version.cuda
def is_cuda_available():
    """Report CUDA availability.

    Returns the CUDA version string (truthy) when a CUDA runtime is both
    present and usable, otherwise a falsy value — mirroring the truthiness
    contract of ``torch.cuda.is_available() and torch.version.cuda``.
    """
    return torch.version.cuda if torch.cuda.is_available() else False
def is_ipv6(address):
try:
ipaddress.IPv6Address(address)