Sampling penalties memory interface (#2870)
This commit is contained in:
@@ -7,6 +7,12 @@ from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
|
||||
is_cuda = is_cuda_available()
|
||||
if is_cuda:
|
||||
from sgl_kernel import sampling_scaling_penalties
|
||||
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -245,11 +251,14 @@ class SamplingBatchInfo:
|
||||
|
||||
# repetition
|
||||
if self.scaling_penalties is not None:
|
||||
logits[:] = torch.where(
|
||||
logits > 0,
|
||||
logits / self.scaling_penalties,
|
||||
logits * self.scaling_penalties,
|
||||
)
|
||||
if is_cuda:
|
||||
logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
|
||||
else:
|
||||
logits[:] = torch.where(
|
||||
logits > 0,
|
||||
logits / self.scaling_penalties,
|
||||
logits * self.scaling_penalties,
|
||||
)
|
||||
|
||||
# Apply regex vocab_mask
|
||||
if self.vocab_mask is not None:
|
||||
|
||||
Reference in New Issue
Block a user