diff --git a/python/sglang/srt/managers/mm_utils.py b/python/sglang/srt/managers/mm_utils.py
index bedf50a66..f495904d5 100644
--- a/python/sglang/srt/managers/mm_utils.py
+++ b/python/sglang/srt/managers/mm_utils.py
@@ -629,6 +629,7 @@ def general_mm_embed_routine(
     embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
+        and not forward_batch.forward_mode.is_target_verify()
         and forward_batch.contains_mm_inputs()
     ):
         mm_inputs_list = [
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 2effec9c0..8413b164b 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -317,7 +317,9 @@ class CudaGraphRunner:
             (self.max_num_token,), dtype=self._cache_loc_dtype()
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
         self.tbo_plugin = TboCudaGraphRunnerPlugin()
 
@@ -532,7 +534,7 @@ class CudaGraphRunner:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-        mrope_positions = self.mrope_positions[:, :bs]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens
 
@@ -751,7 +753,7 @@ class CudaGraphRunner:
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 8904e89f1..dbe99fc3b 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -441,7 +441,13 @@ class ForwardBatch:
             ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
 
         if model_runner.model_is_mrope:
-            ret._compute_mrope_positions(model_runner, batch)
+            if (
+                ret.spec_info is not None
+                and getattr(ret.spec_info, "positions", None) is not None
+            ):
+                ret._compute_spec_mrope_positions(model_runner, batch)
+            else:
+                ret._compute_mrope_positions(model_runner, batch)
 
         # Init lora information
         if model_runner.server_args.enable_lora:
@@ -507,6 +513,52 @@
             or self.contains_image_inputs()
         )
 
+    def _compute_spec_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        # TODO support batched deltas
+        batch_size = self.seq_lens.shape[0]
+        device = model_runner.device
+        mm_inputs = batch.multimodal_inputs
+
+        if batch.forward_mode.is_draft_extend():  # draft_extend_after_decode
+            mrope_deltas = []
+            extend_lens = []
+            for batch_idx in range(batch_size):
+                extend_seq_len = batch.extend_seq_lens[batch_idx]
+                extend_lens.append(extend_seq_len)
+                mrope_delta = (
+                    torch.zeros(1, dtype=torch.int64)
+                    if mm_inputs[batch_idx] is None
+                    else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
+                )
+                mrope_deltas.append(mrope_delta.to(device=device))
+
+            position_chunks = torch.split(batch.spec_info.positions, extend_lens)
+            mrope_positions_list = [
+                pos_chunk + delta
+                for pos_chunk, delta in zip(position_chunks, mrope_deltas)
+            ]
+            next_input_positions = (
+                torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
+            )
+
+        else:  # target_verify or draft_decode
+            seq_positions = batch.spec_info.positions.view(batch_size, -1)
+            mrope_deltas = [
+                (
+                    torch.tensor([0], dtype=torch.int64)
+                    if mm_inputs[i] is None
+                    else mm_inputs[i].mrope_position_delta.squeeze(0)
+                )
+                for i in range(batch_size)
+            ]
+            mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
+            next_input_positions = (
+                (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
+            )
+
+        self.mrope_positions = next_input_positions
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py
index 5e632d5e4..87ae7ade5 100644
--- a/python/sglang/srt/models/llama_eagle3.py
+++ b/python/sglang/srt/models/llama_eagle3.py
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds
 
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 531f5b6e9..b3d5fb9ad 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -454,6 +454,9 @@ class Qwen2ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)
 
@@ -481,6 +484,10 @@
         if self.capture_aux_hidden_states:
             hidden_states, aux_hidden_states = hidden_states
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py
index 82370de54..9afb2b1ab 100644
--- a/python/sglang/srt/models/qwen2_5_vl.py
+++ b/python/sglang/srt/models/qwen2_5_vl.py
@@ -518,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -588,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -644,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Qwen2_5_VLForConditionalGeneration]
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 3ee3b1c54..66d2d5a34 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -91,6 +91,9 @@
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index 4f4403fee..18ab617bd 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -80,6 +80,9 @@
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
 
         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]
 
@@ -247,6 +251,7 @@
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 56c120a0f..2c3940943 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,