[BugFix] Fix incorrect hidden_states_tensor in pd disaggregation + eagle (#9976)
This commit is contained in:
@@ -421,9 +421,14 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
last_hidden_index = (
|
||||
hidden_state_offset + extend_input_len_per_req[i] - 1
|
||||
)
|
||||
req.hidden_states_tensor = (
|
||||
logits_output.hidden_states[last_hidden_index].cpu().clone()
|
||||
)
|
||||
if self.spec_algorithm.is_eagle3():
|
||||
req.hidden_states_tensor = (
|
||||
batch.spec_info.hidden_states[i].cpu().clone()
|
||||
)
|
||||
else:
|
||||
req.hidden_states_tensor = (
|
||||
logits_output.hidden_states[last_hidden_index].cpu().clone()
|
||||
)
|
||||
hidden_state_offset += extend_input_len_per_req[i]
|
||||
else:
|
||||
req.hidden_states_tensor = None
|
||||
|
||||
Reference in New Issue
Block a user