[EAGLE] Fix some boundary situation when retract reqs and req's max token = 1 (#2939)

Co-authored-by: josephyou <josephyou@tencent.com>
This commit is contained in:
996_icu
2025-01-21 09:46:43 +08:00
committed by GitHub
parent 60b2a44a80
commit b730aa6b9e
2 changed files with 10 additions and 0 deletions

View File

@@ -228,6 +228,14 @@ class EAGLEDraftInput(SpecInfo):
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
def filter_batch(
self,
new_indices: torch.Tensor,
):
self.sample_output = self.sample_output[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
def prepare_for_decode(self, batch: ScheduleBatch):
prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
top = torch.topk(prob, self.topk, dim=-1)