Clean up batch data structures: Introducing ModelWorkerBatch (#1544)

Author: Lianmin Zheng
Date: 2024-09-30 06:41:49 -07:00
Committed by: GitHub
Parent: 36d5acfca5
Commit: 63ba2f8d7b
9 changed files with 274 additions and 155 deletions


@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List
 import torch
 
 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.constrained import RegexGuide
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -22,13 +23,17 @@ class SamplingBatchInfo:
     top_ks: torch.Tensor = None
     min_ps: torch.Tensor = None
 
-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Bias Tensors
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
 
+    # FSM states
+    regex_fsms: List[RegexGuide] = None
+    regex_fsm_states: List[int] = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
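
The two new fields are parallel, per-request lists: regex_fsms[i] holds the compiled RegexGuide for request i (or None when that request has no regex constraint) and regex_fsm_states[i] holds that guide's current integer state. A toy sketch of how such parallel lists line up with a batch, using a made-up Request class rather than sglang's:

# Toy illustration (made-up Request class, not sglang's): the FSM fields are
# parallel lists indexed by the request's position in the batch.
class Request:
    def __init__(self, regex_fsm=None, regex_fsm_state=0):
        self.regex_fsm = regex_fsm
        self.regex_fsm_state = regex_fsm_state

reqs = [Request(), Request(regex_fsm="guide_b", regex_fsm_state=3)]
regex_fsms = [r.regex_fsm for r in reqs]               # [None, 'guide_b']
regex_fsm_states = [r.regex_fsm_state for r in reqs]   # [0, 3]
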
@@ -54,6 +59,8 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         )
+        ret.regex_fsms = [r.regex_fsm for r in reqs]
+
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
         ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
 
         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
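
For context: min-p sampling keeps only tokens whose probability is at least min_p times that of the most likely token, and the need_min_p_sampling flag (grouped under "Dispatch in CUDA graph" above) presumably lets the sampler skip that filtering when no request in the batch sets min_p. A standalone sketch of the filtering rule, not sglang's sampling kernel:

import torch

# Standalone sketch of min-p filtering: keep tokens whose probability is at least
# min_p * max probability, then renormalize the survivors.
probs = torch.tensor([0.50, 0.30, 0.15, 0.05])
min_p = 0.2                                   # threshold = 0.2 * 0.50 = 0.10

keep = probs >= min_p * probs.max()           # [True, True, True, False]
filtered = torch.where(keep, probs, torch.zeros_like(probs))
filtered = filtered / filtered.sum()          # renormalized over the kept tokens
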
@@ -102,24 +109,22 @@ class SamplingBatchInfo:
                     )
                 self.linear_penalties = penalizer.apply(self.linear_penalties)
 
-    def update_regex_vocab_mask(self, batch: ScheduleBatch):
-        has_regex = any(req.regex_fsm is not None for req in batch.reqs)
-
+    def update_regex_vocab_mask(self):
         # Reset the vocab mask
         self.vocab_mask = None
 
-        if has_regex:
+        if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
             self.vocab_mask = torch.zeros(
-                batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda"
+                len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(batch.reqs):
-                if req.regex_fsm is not None:
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0
 
-    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+    def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
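
The mask construction above now reads only from self.regex_fsms and self.regex_fsm_states instead of reaching back into a ScheduleBatch. A self-contained toy version of the same masking pattern, with plain lists of allowed token ids standing in for RegexGuide instructions:

import torch

# Toy version of the masking loop: rows with a constraint start fully masked (True)
# and only the FSM-allowed token ids are unmasked (False). Row 0 has no constraint.
vocab_size = 8
allowed_tokens = [None, [1, 4, 5]]     # stand-in for get_next_instruction(state).tokens

vocab_mask = torch.zeros(len(allowed_tokens), vocab_size, dtype=torch.bool)
for i, tokens in enumerate(allowed_tokens):
    if tokens is not None:
        vocab_mask[i].fill_(True)      # mask everything for this row...
        vocab_mask[i][tokens] = False  # ...then allow the FSM's permitted tokens
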
@@ -129,9 +134,11 @@ class SamplingBatchInfo:
             "min_ps",
             "logit_bias",
         ]:
-            self_val = getattr(self, item, None)
-            if self_val is not None:  # logit_bias can be None
-                setattr(self, item, self_val[new_indices])
+            value = getattr(self, item, None)
+            if value is not None:  # logit_bias can be None
+                setattr(self, item, value[new_indices])
+
+        self.regex_fsms = [self.regex_fsms[i] for i in new_indices]
 
     @staticmethod
     def merge_bias_tensor(
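
filter_batch keeps only the still-running requests: tensor-valued fields are gathered with new_indices, and the regex_fsms list is rebuilt with a comprehension over the same indices. A toy illustration of that pattern with stand-in values:

import torch

# Toy illustration of the filter_batch pattern (stand-in values, not sglang objects).
temperatures = torch.tensor([[0.7], [1.0], [0.2]])
regex_fsms = ["fsm_a", None, "fsm_c"]
new_indices = torch.tensor([0, 2])                  # requests 0 and 2 are still running

temperatures = temperatures[new_indices]            # tensor fields: fancy indexing
regex_fsms = [regex_fsms[i] for i in new_indices]   # list fields: comprehension
# temperatures.shape == (2, 1); regex_fsms == ['fsm_a', 'fsm_c']
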
@@ -153,7 +160,7 @@ class SamplingBatchInfo:
 
         return None
 
-    def merge(self, other: "SamplingBatchInfo"):
+    def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
@@ -169,3 +176,5 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other)
         )
+
+        self.regex_fsms.extend(other.regex_fsms)
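
merge_batch mirrors filter_batch: tensor fields from the two batches are concatenated along the batch dimension (merge_bias_tensor handles the case where a bias tensor is None), and list fields such as regex_fsms are simply extended. A toy illustration:

import torch

# Toy illustration of the merge_batch pattern: concatenate tensor fields along the
# batch dimension and extend list fields (mirroring self.regex_fsms.extend(...)).
a_temps, b_temps = torch.tensor([[0.7]]), torch.tensor([[1.0], [0.2]])
a_fsms, b_fsms = ["fsm_a"], [None, "fsm_c"]

merged_temps = torch.cat([a_temps, b_temps], dim=0)   # shape (3, 1)
a_fsms.extend(b_fsms)                                  # ['fsm_a', None, 'fsm_c']
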