Fix vocab mask update bug (#1376)
```diff
@@ -652,8 +652,6 @@ class ScheduleBatch:
             self.req_pool_indices, self.seq_lens - 1
         ] = self.out_cache_loc
 
-        self.sampling_info.update_regex_vocab_mask(self)
-
     def filter_batch(self, unfinished_indices: List[int]):
         if unfinished_indices is None or len(unfinished_indices) == 0:
             # Filter out all requests
```
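The call removed above was the vocab-mask refresh on `ScheduleBatch`'s decode-preparation path; the hunks below move that refresh into `InputMetadata` construction so it runs once per forward pass. For context, a minimal illustrative sketch of what a regex/grammar vocab mask does at sampling time; `apply_vocab_mask` and the shapes are hypothetical stand-ins, not SGLang's API:

```python
import torch

def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
    # logits: [batch, vocab]; vocab_mask: [batch, vocab], True = disallowed.
    # Masked logits become -inf, i.e. probability 0 after softmax.
    return logits.masked_fill(vocab_mask, float("-inf"))

logits = torch.randn(2, 8)
mask = torch.zeros(2, 8, dtype=torch.bool)
mask[0, 3:] = True  # request 0 may only emit tokens 0..2 at this step
probs = torch.softmax(apply_vocab_mask(logits, mask), dim=-1)
assert probs[0, 3:].sum() == 0
```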
```diff
@@ -195,7 +195,8 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
-        ret.sampling_info.prepare_penalties()
+        ret.sampling_info.update_penalties()
+        ret.sampling_info.update_regex_vocab_mask(batch)
 
         ret.compute_positions(batch)
 
```
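With this hunk, `from_schedule_batch` owns the per-step refresh of sampling state: the renamed `update_penalties` (its definition changes in a later hunk) and `update_regex_vocab_mask`, now rebuilt from the current `batch` just before the forward pass. Since an `InputMetadata` is constructed for every forward pass, the mask can no longer go stale when requests are filtered out or batches are merged between steps, which is presumably the bug the commit title refers to.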
```diff
@@ -34,6 +34,9 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
+    def __len__(self):
+        return len(self.temperatures)
+
     def can_run_in_cuda_graph(self):
         # Vocab bias and min_ps are not supported in CUDA graph
         return (
```
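The new `__len__` lets callers treat the sampling info as sized by its batch dimension: `temperatures` holds one row per request, so its length is the batch size. A small self-contained sketch of the idiom (the `Info` class is a stand-in, not SGLang code):

```python
import dataclasses
import torch

@dataclasses.dataclass
class Info:
    temperatures: torch.Tensor  # one row per request

    def __len__(self):
        # len() of a tensor is the size of dim 0, i.e. the batch size
        return len(self.temperatures)

info = Info(temperatures=torch.ones(3, 1))
assert len(info) == 3  # consumed below as merge_bias_tensor(..., len(self), len(other))
```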
```diff
@@ -118,11 +121,9 @@ class SamplingBatchInfo:
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
-        ret.update_regex_vocab_mask(batch)
-
         return ret
 
-    def prepare_penalties(self):
+    def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None
 
```
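`update_penalties` (formerly `prepare_penalties`) clears `scaling_penalties` and `linear_penalties` so they can be repopulated for the current batch. As a hedged sketch of how two such tensors are conventionally applied to logits (the usual repetition- and presence-penalty formulation; not necessarily SGLang's exact sampler code):

```python
import torch

def apply_penalties(
    logits: torch.Tensor,
    scaling_penalties: torch.Tensor | None = None,
    linear_penalties: torch.Tensor | None = None,
) -> torch.Tensor:
    if scaling_penalties is not None:
        # Repetition-penalty style: divide positive logits, multiply negative
        # ones, so penalized tokens always become less likely.
        logits = torch.where(
            logits > 0, logits / scaling_penalties, logits * scaling_penalties
        )
    if linear_penalties is not None:
        # Presence/frequency-penalty style: additive per-token bias.
        logits = logits + linear_penalties
    return logits

logits = torch.tensor([[1.0, -1.0, 0.5, 0.0]])
out = apply_penalties(logits, scaling_penalties=torch.tensor([[2.0]]))
assert out[0, 0] == 0.5 and out[0, 1] == -2.0
```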
```diff
@@ -174,6 +175,25 @@ class SamplingBatchInfo:
             if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
+    @staticmethod
+    def merge_bias_tensor(
+        lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
+    ):
+        # bias tensor can be None
+        if lhs is not None or rhs is not None:
+            shape, dtype = None, None
+            if lhs is not None:
+                shape, dtype = lhs.shape[1:], lhs.dtype
+            else:
+                shape, dtype = rhs.shape[1:], rhs.dtype
+            if lhs is None:
+                lhs = torch.empty((bs1, *shape), dtype=dtype, device="cuda").fill_(default)
+            if rhs is None:
+                rhs = torch.empty((bs2, *shape), dtype=dtype, device="cuda").fill_(default)
+            return torch.cat([lhs, rhs])
+
+        return None
+
     def merge(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
```
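A usage sketch of `merge_bias_tensor`'s contract, reproduced on CPU for illustration (the real helper allocates on `"cuda"`): when only one side carries a bias tensor, the missing side is materialized filled with `default` so the concatenation stays row-aligned with the merged batch.

```python
import torch

lhs = torch.tensor([[0.0, -1.0], [2.0, 0.0]])  # batch of 2 with a bias set
rhs = None                                     # batch of 3 with no bias

bs2, default = 3, 0
shape, dtype = lhs.shape[1:], lhs.dtype        # infer layout from the non-None side
rhs = torch.empty((bs2, *shape), dtype=dtype).fill_(default)
merged = torch.cat([lhs, rhs])

assert merged.shape == (5, 2)
assert merged[2:].eq(default).all()  # the filled rows are a neutral bias
```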
```diff
@@ -187,19 +207,6 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
-        # logit_bias can be None
-        if self.logit_bias is not None or other.logit_bias is not None:
-            vocab_size = (
-                self.logit_bias.shape[1]
-                if self.logit_bias is not None
-                else other.logit_bias.shape[1]
-            )
-            if self.logit_bias is None:
-                self.logit_bias = torch.zeros(
-                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            if other.logit_bias is None:
-                other.logit_bias = torch.zeros(
-                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
-                )
-            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other)
+        )
```
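The inline block this replaces hard-coded `dtype=torch.float32` and took batch sizes from `len(self.reqs)`; the helper instead infers shape and dtype from whichever side is present and sizes the fill via the new `__len__`, so the merge respects the bias dtype, and two `None` sides still merge to `None`.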