Fix vocab mask update bug (#1376)
This commit is contained in:
@@ -652,8 +652,6 @@ class ScheduleBatch:
|
|||||||
self.req_pool_indices, self.seq_lens - 1
|
self.req_pool_indices, self.seq_lens - 1
|
||||||
] = self.out_cache_loc
|
] = self.out_cache_loc
|
||||||
|
|
||||||
self.sampling_info.update_regex_vocab_mask(self)
|
|
||||||
|
|
||||||
def filter_batch(self, unfinished_indices: List[int]):
|
def filter_batch(self, unfinished_indices: List[int]):
|
||||||
if unfinished_indices is None or len(unfinished_indices) == 0:
|
if unfinished_indices is None or len(unfinished_indices) == 0:
|
||||||
# Filter out all requests
|
# Filter out all requests
|
||||||
|
|||||||
@@ -195,7 +195,8 @@ class InputMetadata:
|
|||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
)
|
)
|
||||||
|
|
||||||
ret.sampling_info.prepare_penalties()
|
ret.sampling_info.update_penalties()
|
||||||
|
ret.sampling_info.update_regex_vocab_mask(batch)
|
||||||
|
|
||||||
ret.compute_positions(batch)
|
ret.compute_positions(batch)
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ class SamplingBatchInfo:
|
|||||||
linear_penalties: torch.Tensor = None
|
linear_penalties: torch.Tensor = None
|
||||||
scaling_penalties: torch.Tensor = None
|
scaling_penalties: torch.Tensor = None
|
||||||
|
|
||||||
|
def __len__(self):
    """Return the length of ``self.temperatures``."""
    batch_temperatures = self.temperatures
    return len(batch_temperatures)
|
||||||
|
|
||||||
def can_run_in_cuda_graph(self):
|
def can_run_in_cuda_graph(self):
|
||||||
# Vocab bias and min_ps are not supported in CUDA graph
|
# Vocab bias and min_ps are not supported in CUDA graph
|
||||||
return (
|
return (
|
||||||
@@ -118,11 +121,9 @@ class SamplingBatchInfo:
|
|||||||
# Handle logit bias but only allocate when needed
|
# Handle logit bias but only allocate when needed
|
||||||
ret.logit_bias = None
|
ret.logit_bias = None
|
||||||
|
|
||||||
ret.update_regex_vocab_mask(batch)
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def prepare_penalties(self):
|
def update_penalties(self):
|
||||||
self.scaling_penalties = None
|
self.scaling_penalties = None
|
||||||
self.linear_penalties = None
|
self.linear_penalties = None
|
||||||
|
|
||||||
@@ -174,6 +175,26 @@ class SamplingBatchInfo:
|
|||||||
if self_val is not None: # logit_bias can be None
|
if self_val is not None: # logit_bias can be None
|
||||||
setattr(self, item, self_val[new_indices])
|
setattr(self, item, self_val[new_indices])
|
||||||
|
|
||||||
|
@staticmethod
def merge_bias_tensor(
    lhs: torch.Tensor, rhs: torch.Tensor, bs1: int, bs2: int, default: int = 0
):
    """Concatenate two optional bias tensors along dim 0.

    Either tensor may be ``None``. A missing side is replaced by a tensor
    filled with ``default`` whose leading dimension is the given batch size
    and whose trailing shape, dtype and device are copied from the side that
    is present.

    Args:
        lhs: First bias tensor, or ``None``.
        rhs: Second bias tensor, or ``None``.
        bs1: Leading dimension to use when ``lhs`` has to be synthesized.
        bs2: Leading dimension to use when ``rhs`` has to be synthesized.
        default: Fill value for a synthesized tensor.

    Returns:
        ``torch.cat([lhs, rhs])`` after filling in a missing side, or
        ``None`` when both inputs are ``None``.
    """
    # bias tensor can be None
    if lhs is None and rhs is None:
        return None

    # Use whichever side exists as the template. Passing dtype/device
    # explicitly replaces the previous `with torch.dtype(dtype):`, which
    # raises TypeError because torch.dtype is not a context manager, and
    # keeps the filler on the same device as the tensor it is joined with.
    template = lhs if lhs is not None else rhs
    shape, dtype, device = template.shape[1:], template.dtype, template.device

    if lhs is None:
        lhs = torch.full((bs1, *shape), default, dtype=dtype, device=device)
    if rhs is None:
        rhs = torch.full((bs2, *shape), default, dtype=dtype, device=device)
    return torch.cat([lhs, rhs])
|
||||||
|
|
||||||
def merge(self, other: "SamplingBatchInfo"):
|
def merge(self, other: "SamplingBatchInfo"):
|
||||||
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
|
||||||
|
|
||||||
@@ -187,19 +208,6 @@ class SamplingBatchInfo:
|
|||||||
other_val = getattr(other, item, None)
|
other_val = getattr(other, item, None)
|
||||||
setattr(self, item, torch.concat([self_val, other_val]))
|
setattr(self, item, torch.concat([self_val, other_val]))
|
||||||
|
|
||||||
# logit_bias can be None
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
if self.logit_bias is not None or other.logit_bias is not None:
|
self.logit_bias, other.logit_bias, len(self), len(other)
|
||||||
vocab_size = (
|
)
|
||||||
self.logit_bias.shape[1]
|
|
||||||
if self.logit_bias is not None
|
|
||||||
else other.logit_bias.shape[1]
|
|
||||||
)
|
|
||||||
if self.logit_bias is None:
|
|
||||||
self.logit_bias = torch.zeros(
|
|
||||||
(len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
|
|
||||||
)
|
|
||||||
if other.logit_bias is None:
|
|
||||||
other.logit_bias = torch.zeros(
|
|
||||||
(len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
|
|
||||||
)
|
|
||||||
self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
|
|
||||||
|
|||||||
Reference in New Issue
Block a user