diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 86ef0498f..91caf99db 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -430,24 +430,12 @@ class SchedulerDisaggregationPrefillMixin: self.tree_cache.cache_unfinished_req(req) # update the tree and lock req.add_latency(RequestStage.PREFILL_FORWARD) self.disagg_prefill_inflight_queue.append(req) - if ( - logits_output is not None - and logits_output.hidden_states is not None - ): - last_hidden_index = ( - hidden_state_offset + extend_input_len_per_req[i] - 1 - ) + if self.spec_algorithm.is_eagle() and batch.spec_info is not None: req.output_topk_p = batch.spec_info.topk_p[i] req.output_topk_index = batch.spec_info.topk_index[i] - if self.spec_algorithm.is_eagle3(): - req.hidden_states_tensor = ( - batch.spec_info.hidden_states[i].cpu().clone() - ) - else: - req.hidden_states_tensor = ( - logits_output.hidden_states[last_hidden_index].cpu().clone() - ) - hidden_state_offset += extend_input_len_per_req[i] + req.hidden_states_tensor = ( + batch.spec_info.hidden_states[i].cpu().clone() + ) else: req.hidden_states_tensor = None if req.return_logprob: