[Minor] clean up multimodal processor and tokenizer manager (#7624)
This commit is contained in:
@@ -10,7 +10,6 @@ import torch
|
||||
import sglang.srt.sampling.penaltylib as penaltylib
|
||||
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
||||
from sglang.srt.sampling.sampling_params import TOP_K_ALL
|
||||
from sglang.srt.utils import merge_bias_tensor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
|
||||
)
|
||||
|
||||
|
||||
def merge_bias_tensor(
    lhs: Optional[torch.Tensor],
    rhs: Optional[torch.Tensor],
    bs1: int,
    bs2: int,
    device: str,
    default: float,
) -> Optional[torch.Tensor]:
    """Merge two optional per-request bias tensors when two batches are merged.

    The tensors are concatenated along dimension 0. If exactly one side is
    ``None``, a filler tensor is synthesized for it so that the result still
    has ``bs1 + bs2`` rows (given the present side has its stated batch size).

    Args:
        lhs: Bias tensor of the left-hand batch, or ``None`` if it has none.
        rhs: Bias tensor of the right-hand batch, or ``None`` if it has none.
        bs1: Batch size of the left-hand batch; used as the leading dimension
            of the filler tensor when ``lhs`` is ``None``.
        bs2: Batch size of the right-hand batch; used as the leading dimension
            of the filler tensor when ``rhs`` is ``None``.
        device: Device on which a filler tensor is allocated. An input tensor
            that is already present is used as-is and is not moved.
        default: Value every element of a filler tensor is set to.

    Returns:
        ``torch.cat([lhs, rhs])`` after filling in a missing side, or ``None``
        if both inputs are ``None``.
    """
    if lhs is None and rhs is None:
        return None

    if lhs is None or rhs is None:
        # Exactly one side is missing: borrow the trailing dimensions and dtype
        # from the side that exists so the two halves can be concatenated.
        present = rhs if lhs is None else lhs
        trailing_shape, dtype = present.shape[1:], present.dtype

        if lhs is None:
            lhs = torch.full(
                (bs1, *trailing_shape), default, device=device, dtype=dtype
            )
        else:
            rhs = torch.full(
                (bs2, *trailing_shape), default, device=device, dtype=dtype
            )

    return torch.cat([lhs, rhs])
|
||||
|
||||
Reference in New Issue
Block a user