[Fix] Move ScheduleBatch out of SamplingInfo (#1556)
@@ -423,10 +423,14 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
+    # Has regex
+    has_regex: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
+        has_regex = any(req.regex_fsm for req in reqs)
 
         return cls(
             reqs=reqs,
@@ -435,6 +439,7 @@ class ScheduleBatch:
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            has_regex=has_regex,
         )
 
     def batch_size(self):
@@ -750,7 +755,9 @@ class ScheduleBatch:
             ]
         else:
             self.top_logprobs_nums = None
 
         self.has_stream = any(req.stream for req in self.reqs)
+        self.has_regex = any(req.regex_fsm for req in self.reqs)
+
         self.sampling_info.filter_batch(unfinished_indices, new_indices)
 
@@ -771,9 +778,11 @@ class ScheduleBatch:
             self.top_logprobs_nums.extend([0] * len(other.reqs))
         elif other.return_logprob:
             self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
-        self.has_stream = any(req.stream for req in self.reqs)
         self.reqs.extend(other.reqs)
 
         self.return_logprob = self.return_logprob or other.return_logprob
+        self.has_stream = self.has_stream or other.has_stream
+        self.has_regex = self.has_regex or other.has_regex
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
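
Note the semantics change in the flag handling above: the removed line recomputed has_stream from self.reqs before other.reqs was appended, while the new code ORs the two batches' precomputed flags after the merge. OR-ing the per-batch flags gives the same answer as recomputing any() over the combined request list, as this small check illustrates (toy booleans, not sglang code):

# any() over the concatenated request list equals OR-ing the per-batch flags.
a_stream = [False, True]     # req.stream for the requests already in self.reqs
b_stream = [False, False]    # req.stream for other.reqs
assert any(a_stream + b_stream) == (any(a_stream) or any(b_stream))
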
@@ -787,7 +796,11 @@ class ScheduleBatch:
         image_inputs = [r.image_inputs for r in self.reqs]
 
         lora_paths = [req.lora_path for req in self.reqs]
-        self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
+        if self.has_regex:
+            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_fsm_states = [
+                req.regex_fsm_state for req in self.reqs
+            ]
 
         return ModelWorkerBatch(
             forward_mode=self.forward_mode,
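
Taken together, the ScheduleBatch side now computes has_regex once in init_new, carries it through filtering and merging, and hands the per-request FSM data to the sampling info as plain lists in get_model_worker_batch. The toy below is a hedged, self-contained sketch of that hand-off; the SimpleNamespace requests and the stand-in FSM object are placeholders, not the real sglang types:

from types import SimpleNamespace

reqs = [
    SimpleNamespace(regex_fsm=None, regex_fsm_state=0),
    SimpleNamespace(regex_fsm=object(), regex_fsm_state=3),   # stand-in FSM
]
sampling_info = SimpleNamespace(regex_fsms=None, regex_fsm_states=None)

has_regex = any(req.regex_fsm for req in reqs)                 # as in init_new() above
if has_regex:                                                  # as in get_model_worker_batch()
    sampling_info.regex_fsms = [req.regex_fsm for req in reqs]
    sampling_info.regex_fsm_states = [req.regex_fsm_state for req in reqs]

print(has_regex, sampling_info.regex_fsm_states)               # True [0, 3]
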
@@ -84,10 +84,6 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
-        # This is only for regex_fsm. We notice a regression if we maintain the list of regex_fsm
-        # in SamplingBatchInfo, so we keep it here.
-        ret.schedule_batch = batch
-
         return ret
 
     def __len__(self):
@@ -113,7 +109,7 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
-        has_regex = any(req.regex_fsm is not None for req in self.schedule_batch.reqs)
+        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
 
         # Reset the vocab mask
         self.vocab_mask = None
@@ -122,11 +118,11 @@ class SamplingBatchInfo:
             self.vocab_mask = torch.zeros(
                 len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(self.schedule_batch.reqs):
-                if req.regex_fsm is not None:
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
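
For context on how the rebuilt mask is consumed: fill_(1) marks every token of a regex-constrained row as disallowed, and the FSM's allowed tokens are then cleared to 0, so the mask is True exactly where sampling must be blocked. The sketch below is a hedged, self-contained illustration of that convention and of one typical way to apply such a mask to logits; the fake FSM and the masking call are assumptions for illustration, and the real sampler is not part of this diff:

import torch
from types import SimpleNamespace

# Fake FSM: get_next_instruction(state).tokens lists the allowed token ids.
fake_fsm = SimpleNamespace(
    get_next_instruction=lambda state: SimpleNamespace(tokens=[1, 3])
)
regex_fsms, regex_fsm_states = [fake_fsm, None], [0, 0]
vocab_size = 8

vocab_mask = torch.zeros(len(regex_fsms), vocab_size, dtype=torch.bool)
for i, regex_fsm in enumerate(regex_fsms):
    if regex_fsm is not None:
        vocab_mask[i].fill_(1)                       # disallow everything ...
        vocab_mask[i][
            regex_fsm.get_next_instruction(regex_fsm_states[i]).tokens
        ] = 0                                        # ... except the FSM-allowed tokens

# One typical way to consume the mask: push disallowed logits to -inf before sampling.
logits = torch.randn(len(regex_fsms), vocab_size)
logits = logits.masked_fill(vocab_mask, float("-inf"))
probs = torch.softmax(logits, dim=-1)

Rows without an FSM keep an all-False mask, so their logits are left untouched.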