[EAGLE] Fix some boundary situation when retract reqs and req's max token = 1 (#2939)
Co-authored-by: josephyou <josephyou@tencent.com>
This commit is contained in:
@@ -1112,6 +1112,8 @@ class ScheduleBatch:
         self.has_grammar = any(req.grammar for req in self.reqs)
 
         self.sampling_info.filter_batch(keep_indices, new_indices)
+        if self.spec_info:
+            self.spec_info.filter_batch(new_indices)
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -228,6 +228,14 @@ class EAGLEDraftInput(SpecInfo):
         assert len(batch.extend_lens) == 1
         batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
 
+    def filter_batch(
+        self,
+        new_indices: torch.Tensor,
+    ):
+        self.sample_output = self.sample_output[: len(new_indices)]
+        self.hidden_states = self.hidden_states[: len(new_indices)]
+        self.verified_id = self.verified_id[: len(new_indices)]
+
     def prepare_for_decode(self, batch: ScheduleBatch):
         prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
Reference in New Issue
Block a user