fix unexpected answer in EAGLE mode (#9252)
This commit is contained in:
@@ -177,11 +177,24 @@ class EagleDraftInput:
|
||||
)
|
||||
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
||||
|
||||
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
    """Shrink the draft input to the requests that are still running.

    Args:
        new_indices: Indices (into the current batch) of the requests to keep.
        has_been_filtered: When True (the default, preserving the previous
            behavior), the per-request tensors have already been filtered
            upstream, so they only need to be truncated to
            ``len(new_indices)``. When False, the tensors are still in the
            original batch order and rows must be gathered by ``new_indices``
            — plain prefix truncation would silently keep the wrong rows
            whenever a finished request is not at the tail of the batch.
    """
    if has_been_filtered:
        if len(new_indices) != len(self.topk_p):
            # Lengths should agree once the upstream filter already ran;
            # a mismatch indicates an inconsistency worth surfacing.
            import logging

            logging.getLogger(__name__).warning(
                f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
            )
        self.topk_p = self.topk_p[: len(new_indices)]
        self.topk_index = self.topk_index[: len(new_indices)]
        self.hidden_states = self.hidden_states[: len(new_indices)]
        self.verified_id = self.verified_id[: len(new_indices)]
    else:
        # Gather the surviving rows with one indexed select per tensor.
        self.topk_p = self.topk_p[new_indices]
        self.topk_index = self.topk_index[new_indices]
        self.hidden_states = self.hidden_states[new_indices]
        self.verified_id = self.verified_id[new_indices]
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
    """Drop draft-input rows belonging to requests that have finished.

    Args:
        new_indices: Indices of the requests that remain in the batch.
        has_been_filtered: True when the tensors were already filtered by
            ``unfinished_index`` (in eagle_utils.py:verify), so only a prefix
            truncation is needed; False (e.g. in draft_extend) when the rows
            must still be gathered by ``new_indices``.
    """
    keep = len(new_indices)
    if has_been_filtered and keep != len(self.topk_p):
        # The upstream filter should have left exactly `keep` rows; a
        # mismatch signals an inconsistency between scheduler and verifier.
        logger.warning(
            f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
        )
    for name in ("topk_p", "topk_index", "hidden_states", "verified_id"):
        value = getattr(self, name)
        # Already-filtered tensors only need truncation; otherwise gather
        # the surviving rows by index.
        setattr(self, name, value[:keep] if has_been_filtered else value[new_indices])
def merge_batch(self, spec_info: EagleDraftInput):
|
||||
if self.hidden_states is None:
|
||||
|
||||
@@ -836,6 +836,21 @@ class EAGLEWorker(TpModelWorker):
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info is batch.spec_info
|
||||
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||
has_finished, unfinished_req_index = False, []
|
||||
for i, req in enumerate(batch.reqs):
|
||||
if req.finished():
|
||||
has_finished = True
|
||||
else:
|
||||
unfinished_req_index.append(i)
|
||||
if has_finished:
|
||||
unfinished_index_device = torch.tensor(
|
||||
unfinished_req_index,
|
||||
dtype=torch.int64,
|
||||
device=batch.spec_info.topk_p.device,
|
||||
)
|
||||
batch.spec_info.filter_batch(
|
||||
unfinished_index_device, has_been_filtered=False
|
||||
)
|
||||
|
||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||
assert isinstance(batch.spec_info, EagleDraftInput)
|
||||
|
||||
Reference in New Issue
Block a user