[Minor] clean up multimodal processor and tokenizer manager (#7624)

Author: Lianmin Zheng
Date:   2025-06-29 02:50:14 -07:00
Committed by: GitHub
Parent: 7c0db3a6c5
Commit: 071a1f51ae

9 changed files with 147 additions and 165 deletions


@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
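
For reference, a minimal usage sketch of the new helper (not part of the commit; the tensor values and batch sizes below are illustrative). When only one side carries a bias tensor, the missing side is materialized and filled with the default value so the concatenation stays aligned with the merged batch:

import torch

# Hypothetical batches: batch 1 (2 requests) carries per-token logit
# biases, batch 2 (3 requests) has none.
lhs = torch.tensor([[0.5, -1.0], [0.0, 2.0]])

merged = merge_bias_tensor(lhs, None, bs1=2, bs2=3, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([5, 2]); rows 2-4 hold the default 0.0

# If neither batch has a bias tensor, nothing is materialized.
assert merge_bias_tensor(None, None, 2, 3, "cpu", 0.0) is None

This mirrors the call site above: SamplingBatchInfo.merge_batch uses the helper to merge logit_bias across two running batches, passing len(self) and len(other) as the batch sizes.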