Fix spec filter batch when target extend (#10991)
This commit is contained in:
@@ -1736,7 +1736,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
|
||||||
if self.spec_info:
|
if self.spec_info:
|
||||||
self.spec_info.filter_batch(keep_indices_device)
|
if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
|
||||||
|
has_been_filtered = False
|
||||||
|
else:
|
||||||
|
has_been_filtered = True
|
||||||
|
self.spec_info.filter_batch(
|
||||||
|
new_indices=keep_indices_device,
|
||||||
|
has_been_filtered=has_been_filtered,
|
||||||
|
)
|
||||||
|
|
||||||
def merge_batch(self, other: "ScheduleBatch"):
|
def merge_batch(self, other: "ScheduleBatch"):
|
||||||
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
|
||||||
|
|||||||
@@ -405,7 +405,7 @@ class NgramVerifyInput:
|
|||||||
|
|
||||||
return logits_output, self.verified_id, self.accept_length.sum().item()
|
return logits_output, self.verified_id, self.accept_length.sum().item()
|
||||||
|
|
||||||
def filter_batch(self, new_indices: torch.Tensor):
|
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def merge_batch(self, spec_info: NgramVerifyInput):
|
def merge_batch(self, spec_info: NgramVerifyInput):
|
||||||
|
|||||||
Reference in New Issue
Block a user