[Fix] do not maintain regex_fsm in SamplingBatchInfo (#1555)
This commit is contained in:
@@ -59,7 +59,6 @@ class SamplingBatchInfo:
|
|||||||
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
||||||
)
|
)
|
||||||
|
|
||||||
ret.regex_fsms = [r.regex_fsm for r in reqs]
|
|
||||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||||
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
|
||||||
|
|
||||||
@@ -85,6 +84,10 @@ class SamplingBatchInfo:
|
|||||||
# Handle logit bias but only allocate when needed
|
# Handle logit bias but only allocate when needed
|
||||||
ret.logit_bias = None
|
ret.logit_bias = None
|
||||||
|
|
||||||
|
# This is only for regex_fsm. We notice a regression if we maintain the list of regex_fsm
|
||||||
|
# in SamplingBatchInfo, so we keep it here.
|
||||||
|
ret.schedule_batch = batch
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -110,18 +113,20 @@ class SamplingBatchInfo:
|
|||||||
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
self.linear_penalties = penalizer.apply(self.linear_penalties)
|
||||||
|
|
||||||
def update_regex_vocab_mask(self):
|
def update_regex_vocab_mask(self):
|
||||||
|
has_regex = any(req.regex_fsm is not None for req in self.schedule_batch.reqs)
|
||||||
|
|
||||||
# Reset the vocab mask
|
# Reset the vocab mask
|
||||||
self.vocab_mask = None
|
self.vocab_mask = None
|
||||||
|
|
||||||
if any(regex_fsm is not None for regex_fsm in self.regex_fsms):
|
if has_regex:
|
||||||
self.vocab_mask = torch.zeros(
|
self.vocab_mask = torch.zeros(
|
||||||
len(self.regex_fsms), self.vocab_size, dtype=torch.bool, device="cuda"
|
len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
|
||||||
)
|
)
|
||||||
for i, regex_fsm in enumerate(self.regex_fsms):
|
for i, req in enumerate(self.schedule_batch.reqs):
|
||||||
if regex_fsm is not None:
|
if req.regex_fsm is not None:
|
||||||
self.vocab_mask[i].fill_(1)
|
self.vocab_mask[i].fill_(1)
|
||||||
self.vocab_mask[i][
|
self.vocab_mask[i][
|
||||||
regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
|
req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
|
||||||
] = 0
|
] = 0
|
||||||
|
|
||||||
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
|
||||||
@@ -138,8 +143,6 @@ class SamplingBatchInfo:
|
|||||||
if value is not None: # logit_bias can be None
|
if value is not None: # logit_bias can be None
|
||||||
setattr(self, item, value[new_indices])
|
setattr(self, item, value[new_indices])
|
||||||
|
|
||||||
self.regex_fsms = [self.regex_fsms[i] for i in new_indices]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def merge_bias_tensor(
|
def merge_bias_tensor(
|
||||||
lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
|
lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
|
||||||
@@ -176,5 +179,3 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other)
|
self.logit_bias, other.logit_bias, len(self), len(other)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.regex_fsms.extend(other.regex_fsms)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user