From b8574f695359e443e40ebb6a0fb6165b9e722674 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 6 Jan 2025 14:54:18 -0800 Subject: [PATCH] Clean up eagle code (#2756) --- python/sglang/srt/layers/logits_processor.py | 7 +- python/sglang/srt/managers/schedule_batch.py | 8 +- python/sglang/srt/managers/scheduler.py | 11 +- .../srt/model_executor/cuda_graph_runner.py | 9 +- .../srt/model_executor/forward_batch_info.py | 32 ++--- python/sglang/srt/speculative/eagle_utils.py | 121 ++++++++---------- python/sglang/srt/speculative/eagle_worker.py | 78 ++++++----- 7 files changed, 138 insertions(+), 128 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index bd4b41983..51e73d072 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -74,11 +74,6 @@ class LogitsMetadata: @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): - if forward_batch.spec_info: - capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode - else: - capture_hidden_mode = CaptureHiddenMode.NULL - if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob: extend_return_logprob = True extend_return_top_logprob = any( @@ -98,7 +93,7 @@ class LogitsMetadata: return cls( forward_mode=forward_batch.forward_mode, - capture_hidden_mode=capture_hidden_mode, + capture_hidden_mode=forward_batch.capture_hidden_mode, extend_return_logprob=extend_return_logprob, extend_return_top_logprob=extend_return_top_logprob, extend_seq_lens=forward_batch.extend_seq_lens, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2a5db9084..3b056cc5d 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -44,7 +44,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs @@ -1163,6 +1163,11 @@ class ScheduleBatch: input_embeds=self.input_embeds, spec_algorithm=self.spec_algorithm, spec_info=self.spec_info, + capture_hidden_mode=( + getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL) + if self.spec_info + else CaptureHiddenMode.NULL + ), ) def copy(self): @@ -1237,6 +1242,7 @@ class ModelWorkerBatch: # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None spec_info: Optional[SpecInfo] = None + capture_hidden_mode: CaptureHiddenMode = None @triton.jit diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 8956352ca..180b4d96f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -962,10 +962,13 @@ class Scheduler: self.tp_worker.forward_batch_generation(model_worker_batch) ) else: - logits_output, next_token_ids, model_worker_batch, spec_info = ( - self.draft_worker.forward_batch_speculative_generation(batch) - ) - batch.spec_info = spec_info + ( + logits_output, + next_token_ids, + model_worker_batch, + num_accepted_tokens, + ) = self.draft_worker.forward_batch_speculative_generation(batch) + self.num_generated_tokens += num_accepted_tokens elif batch.forward_mode.is_idle(): model_worker_batch = batch.get_model_worker_batch() self.tp_worker.forward_batch_idle(model_worker_batch) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1689f7d66..99e72a3d0 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -322,6 +322,8 @@ class CudaGraphRunner: global_num_tokens = None gathered_buffer = None + spec_info = self.get_spec_info(num_tokens, positions) + forward_batch = ForwardBatch( forward_mode=self.capture_forward_mode, batch_size=bs, @@ -341,7 +343,10 @@ class CudaGraphRunner: mrope_positions=mrope_positions, gathered_buffer=gathered_buffer, spec_algorithm=self.model_runner.spec_algorithm, - spec_info=self.get_spec_info(num_tokens, positions), + spec_info=spec_info, + capture_hidden_mode=( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ), ) # Attention backend @@ -446,10 +451,10 @@ class CudaGraphRunner: if self.model_runner.is_draft_worker: spec_info = EAGLEDraftInput() + spec_info.load_server_args(self.model_runner.server_args) spec_info.hidden_states = self.hidden_states[:num_tokens] spec_info.positions = positions spec_info.capture_hidden_mode = CaptureHiddenMode.FULL - spec_info.init(self.model_runner.server_args) else: spec_info = EagleVerifyInput( None, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 926961149..fab8b15a3 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -107,6 +107,21 @@ class ForwardMode(IntEnum): return self == ForwardMode.DUMMY_FIRST +class CaptureHiddenMode(IntEnum): + NULL = auto() + FULL = auto() + LAST = auto() + + def need_capture(self): + return self != CaptureHiddenMode.NULL + + def is_full(self): + return self == CaptureHiddenMode.FULL + + def is_last(self): + return self == CaptureHiddenMode.LAST + + @dataclass class ForwardBatch: """Store all inputs of a forward pass.""" @@ -174,6 +189,7 @@ class ForwardBatch: # Speculative decoding spec_info: SpecInfo = None spec_algorithm: SpeculativeAlgorithm = None + capture_hidden_mode: CaptureHiddenMode = None # For Qwen2-VL mrope_positions: torch.Tensor = None @@ -265,6 +281,7 @@ class ForwardBatch: sampling_info=batch.sampling_info, spec_algorithm=batch.spec_algorithm, spec_info=batch.spec_info, + capture_hidden_mode=batch.capture_hidden_mode, input_embeds=batch.input_embeds, ) @@ -400,18 +417,3 @@ def compute_position_torch( @maybe_torch_compile(dynamic=True) def clamp_position(seq_lens): return torch.clamp((seq_lens - 1), min=0).to(torch.int64) - - -class CaptureHiddenMode(IntEnum): - NULL = auto() - FULL = auto() - LAST = auto() - - def need_capture(self): - return self != CaptureHiddenMode.NULL - - def is_full(self): - return self == CaptureHiddenMode.FULL - - def is_last(self): - return self == CaptureHiddenMode.LAST diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index f76cca2d5..a6fcf2e57 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -9,12 +9,11 @@ import triton.language as tl from sglang.srt.layers.attention.flashinfer_backend import ( create_flashinfer_kv_indices_triton, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.speculative.build_eagle_tree import build_tree_kernel from sglang.srt.speculative.spec_info import SpecInfo if TYPE_CHECKING: - from python.sglang.srt.layers.sampler import SampleOutput from python.sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.server_args import ServerArgs @@ -179,19 +178,9 @@ def generate_draft_decode_kv_indices( class EAGLEDraftInput(SpecInfo): - hidden_states: torch.Tensor = None - verified_id: torch.Tensor = None - positions: torch.Tensor = None - accept_length: torch.Tensor = None - has_finished: bool = False - unfinished_index: List[int] = None - - def init(self, server_args: ServerArgs): + def __init__(self): self.prev_mode = ForwardMode.DECODE self.sample_output = None - self.topk: int = server_args.speculative_eagle_topk - self.num_verify_token: int = server_args.speculative_num_draft_tokens - self.spec_steps = server_args.speculative_num_steps self.scores: torch.Tensor = None self.score_list: List[torch.Tensor] = [] @@ -200,11 +189,20 @@ class EAGLEDraftInput(SpecInfo): self.parents_list: List[torch.Tensor] = [] self.cache_list: List[torch.Tenor] = [] self.iter = 0 - self.root_token: int = None - assert self.topk <= 10, "topk should <= 10" + self.hidden_states: torch.Tensor = None + self.verified_id: torch.Tensor = None + self.positions: torch.Tensor = None + self.accept_length: torch.Tensor = None + self.has_finished: bool = False + self.unfinished_index: List[int] = None - def prepare_for_extend(self, batch: ForwardBatch): + def load_server_args(self, server_args: ServerArgs): + self.topk: int = server_args.speculative_eagle_topk + self.num_verify_token: int = server_args.speculative_num_draft_tokens + self.spec_steps = server_args.speculative_num_steps + + def prepare_for_extend(self, batch: ScheduleBatch): req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) batch.out_cache_loc = out_cache_loc @@ -226,81 +224,72 @@ class EAGLEDraftInput(SpecInfo): pt += req.extend_input_len - seq_lens = [0] + batch.extend_lens - input_ids = batch.input_ids.tolist() - verified_id = batch.spec_info.verified_id.tolist() - model_input_ids = [] - for i in range(len(seq_lens) - 1): - model_input_ids.extend( - input_ids[seq_lens[i] + 1 : seq_lens[i + 1]] + [verified_id[i]] - ) - batch.input_ids = torch.tensor( - model_input_ids, dtype=torch.int32, device="cuda" - ) - - def capture_for_decode( - self, - sample_output: SampleOutput, - hidden_states: torch.Tensor, - prev_mode: ForwardMode, - ): - self.sample_output = sample_output - self.prev_mode = prev_mode - self.hidden_states = hidden_states + # TODO: support batching inputs + assert len(batch.extend_lens) == 1 + batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) def prepare_for_decode(self, batch: ScheduleBatch): - prob = self.sample_output # b * (1/topk), vocab + prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) top = torch.topk(prob, self.topk, dim=-1) - topk_index, topk_p = top.indices, top.values # b * (1/topk), topk - if self.prev_mode == ForwardMode.DECODE: + topk_index, topk_p = ( + top.indices, + top.values, + ) # shape: (b * top_k, top_k) or (b, top_k) + + if self.prev_mode.is_decode(): scores = torch.mul( self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) - ) # (b, topk) mul (b * topk ,topk) -> b, topk, topk + ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) topk_cs = torch.topk( scores.flatten(start_dim=1), self.topk, dim=-1 ) # (b, topk) topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values - self.scores = topk_cs_p - - selected_input_index = topk_cs_index.flatten() // self.topk # b* topk + selected_input_index = ( + topk_cs_index.flatten() // self.topk + ) # shape: (b * topk) batch.spec_info.hidden_states = batch.spec_info.hidden_states[ selected_input_index, : ] + topk_index = topk_index.reshape(-1, self.topk**2) batch.input_ids = torch.gather( topk_index, index=topk_cs_index, dim=1 ).flatten() - batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) - self.score_list.append(scores) # b, topk, topk - self.token_list.append(topk_index) # b, topk*topk + batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) + + self.scores = topk_cs_p + self.score_list.append(scores) # (b, topk, topk) + self.token_list.append(topk_index) # (b, topk * topk) self.origin_score_list.append(topk_p.reshape(topk_index.shape)) self.parents_list.append( topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk) - ) # b, topk - - elif self.prev_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND): - self.scores = topk_p # b, top_k - self.score_list.append(topk_p.unsqueeze(1)) - self.token_list.append(topk_index) - self.origin_score_list.append(topk_p) + ) # shape: (b, topk) + else: + # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND batch.spec_info.hidden_states = ( - batch.spec_info.hidden_states.repeat_interleave(self.topk, 0) + batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0) ) + batch.input_ids = topk_index.flatten() batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) + + self.scores = topk_p # shape: (b, topk) + self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk) + self.token_list.append(topk_index) # shape: (b, topk) + self.origin_score_list.append(topk_p) self.parents_list.append( torch.arange(-1, self.topk, dtype=torch.long, device="cuda") .unsqueeze(0) .repeat(self.scores.shape[0], 1) - ) # b, topk+1 + ) # shape: (b, topk + 1) self.cache_list.append(batch.out_cache_loc) self.positions = ( batch.seq_lens[:, None] + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter ).flatten() - bs = batch.seq_lens.numel() + bs = len(batch.seq_lens) assign_req_to_token_pool[(bs,)]( batch.req_pool_indices, batch.req_to_token_pool.req_to_token, @@ -419,11 +408,6 @@ class EAGLEDraftInput(SpecInfo): ) return bs, kv_indices, cum_kv_seq_len - def clear(self): - self.iter = 0 - self.score_list.clear() - self.positions = None - def clear_draft_cache(self, batch): draft_cache = torch.cat(self.cache_list, dim=0) batch.token_to_kv_pool.free(draft_cache) @@ -460,7 +444,6 @@ class EAGLEDraftInput(SpecInfo): [self.hidden_states, spec_info.hidden_states], axis=0 ) self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) - # self.positions = torch.cat([self.positions, spec_info.positions], axis=0) self.sample_output = torch.cat([self.sample_output, spec_info.sample_output]) @@ -568,9 +551,6 @@ class EagleVerifyInput(SpecInfo): ) accept_index = accept_index[accept_index != -1] - # extract_index = extract_index[extract_index != 0] - - draft_input = EAGLEDraftInput() accept_length_cpu = accept_length.tolist() verified_id = predict[accept_index] @@ -596,6 +576,7 @@ class EagleVerifyInput(SpecInfo): # retracted_reqs, new_token_ratio = batch.retract_decode() low = 0 + draft_input = EAGLEDraftInput() for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)): req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) req.check_finished() @@ -615,4 +596,10 @@ class EagleVerifyInput(SpecInfo): draft_input.unfinished_index = unfinished_index logits_output.next_token_logits = logits_output.next_token_logits[accept_index] - return draft_input, logits_output, verified_id, finished_extend_len + return ( + draft_input, + logits_output, + verified_id, + finished_extend_len, + accept_length_cpu, + ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 6701c66ac..16d54c43b 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -51,63 +51,72 @@ class EAGLEWorker(TpModelWorker): batch.spec_info.prepare_for_decode(batch) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) def forward_draft_extend(self, batch: ScheduleBatch): - self._swap_mem_pool(batch, self.model_runner) + self._set_mem_pool(batch, self.model_runner) batch.spec_info.prepare_for_extend(batch) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) - self._swap_mem_pool(batch, self.target_worker.model_runner) + self._set_mem_pool(batch, self.target_worker.model_runner) def forward_batch_speculative_generation(self, batch: ScheduleBatch): if batch.forward_mode.is_decode(): - prev_spec_info = batch.spec_info - self._swap_mem_pool(batch, self.model_runner) + # Draft + self._set_mem_pool(batch, self.model_runner) for i in range(self.server_args.speculative_num_steps): self.forward_draft_decode(batch) batch.spec_info.clear_draft_cache(batch) - self._swap_mem_pool(batch, self.target_worker.model_runner) + self._set_mem_pool(batch, self.target_worker.model_runner) + + # Verify ( next_draft_input, logits_output, verified_id, self.finish_extend_len, + accept_length_cpu, model_worker_batch, ) = self.verify(batch) - next_draft_input.init(self.server_args) + next_draft_input.load_server_args(self.server_args) batch.spec_info = next_draft_input # if it is None, means all requsets are finished if batch.spec_info.verified_id is not None: - self.forward_extend_after_decode(batch) - batch.spec_info = prev_spec_info - return logits_output, verified_id, model_worker_batch, next_draft_input + self.forward_draft_extend_after_decode(batch) + return ( + logits_output, + verified_id, + model_worker_batch, + sum(accept_length_cpu), + ) else: - spec_info = EAGLEDraftInput() - spec_info.init(self.server_args) + # Forward with the target model and get hidden states. + # We need the full hidden states to prefill the KV cache of the draft model. model_worker_batch = batch.get_model_worker_batch() - model_worker_batch.spec_info = spec_info - spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL logits_output, next_token_ids = self.target_worker.forward_batch_generation( model_worker_batch ) - model_worker_batch.spec_info.verified_id = next_token_ids - model_worker_batch.spec_info.hidden_states = logits_output.hidden_states + + # Forward with the draft model. + spec_info = EAGLEDraftInput() + spec_info.load_server_args(self.server_args) + spec_info.hidden_states = logits_output.hidden_states + spec_info.verified_id = next_token_ids batch.spec_info = spec_info self.forward_draft_extend(batch) - batch.spec_info = None - return logits_output, next_token_ids, model_worker_batch, spec_info + return logits_output, next_token_ids, model_worker_batch, 0 def verify(self, batch: ScheduleBatch): verify_input = batch.spec_info.prepare_for_verify(batch) - batch.forward_mode = ForwardMode.TARGET_VERIFY verify_input.prepare_for_verify(batch) + batch.forward_mode = ForwardMode.TARGET_VERIFY batch.spec_info = verify_input batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL model_worker_batch = batch.get_model_worker_batch() @@ -119,38 +128,41 @@ class EAGLEWorker(TpModelWorker): batch.forward_mode = ForwardMode.DECODE return res + (model_worker_batch,) - def _swap_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): + def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): batch.token_to_kv_pool = runner.token_to_kv_pool batch.req_to_token_pool = runner.req_to_token_pool - def forward_extend_after_decode(self, batch: ScheduleBatch): - self._swap_mem_pool(batch, self.model_runner) + def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + self._set_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.DRAFT_EXTEND if batch.spec_info.has_finished: index = batch.spec_info.unfinished_index seq_lens = batch.seq_lens batch.seq_lens = batch.seq_lens[index] + batch.spec_info.prepare_extend_after_decode(batch) model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) + batch.spec_info.hidden_states = logits_output.hidden_states self.capture_for_decode(logits_output, forward_batch) batch.forward_mode = ForwardMode.DECODE if batch.spec_info.has_finished: batch.seq_lens = seq_lens - self._swap_mem_pool(batch, self.target_worker.model_runner) + self._set_mem_pool(batch, self.target_worker.model_runner) - def capture_for_decode(self, logits_output, forward_batch): - if isinstance(logits_output, LogitsProcessorOutput): - logits = logits_output.next_token_logits + def capture_for_decode( + self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch + ): sample_output = torch.softmax( - logits, dim=-1 - ) # TODO: Support more sampling method @kavioyu - forward_batch.spec_info.capture_for_decode( - sample_output, logits_output.hidden_states, forward_batch.forward_mode - ) + logits_output.next_token_logits, dim=-1 + ) # TODO(kavioyu): Support more sampling methods + spec_info = forward_batch.spec_info + spec_info.sample_output = sample_output + spec_info.hidden_states = logits_output.hidden_states + spec_info.prev_mode = forward_batch.forward_mode # Don't support prefix share now. def finish_request(self, reqs: Union[Req, List[Req]]):