[BugFix] Fix incorrect hidden_states_tensor in pd disaggregation + eagle (#9976)

This commit is contained in:
Ziming Huang
2025-09-18 01:37:14 +08:00
committed by GitHub
parent 77098aea7b
commit b73ac629cd

View File

@@ -421,9 +421,14 @@ class SchedulerDisaggregationPrefillMixin:
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
req.hidden_states_tensor = (
logits_output.hidden_states[last_hidden_index].cpu().clone()
)
if self.spec_algorithm.is_eagle3():
req.hidden_states_tensor = (
batch.spec_info.hidden_states[i].cpu().clone()
)
else:
req.hidden_states_tensor = (
logits_output.hidden_states[last_hidden_index].cpu().clone()
)
hidden_state_offset += extend_input_len_per_req[i]
else:
req.hidden_states_tensor = None