Qwen2.5-VL eagle3 infer (#8801)

This commit is contained in:
Lzhang-hub
2025-09-08 11:44:34 +08:00
committed by GitHub
parent 7802586cab
commit 37d83c6e6d
9 changed files with 114 additions and 5 deletions

View File

@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros(
(3, self.max_num_token), dtype=torch.int64
)
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
seq_lens = self.seq_lens[:num_seqs]
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
positions = self.positions[:num_tokens]
mrope_positions = self.mrope_positions[:, :num_tokens]
topk_p = self.topk_p[:num_seqs]
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
mrope_positions=mrope_positions,
global_num_tokens_gpu=global_num_tokens,
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
global_dp_buffer_len=global_dp_buffer_len,

View File

@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros(
(3, self.max_num_token), dtype=torch.int64
)
if self.eagle_worker.speculative_algorithm.is_eagle3():
self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
accept_length = self.accept_length[:bs]
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
mrope_positions = self.mrope_positions[:, :num_tokens]
hidden_states = self.hidden_states[:num_tokens]
next_token_logits_buffer = self.next_token_logits_buffer[:bs]
@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
seq_lens_sum=seq_lens.sum().item(),
return_logprob=False,
positions=positions,
mrope_positions=mrope_positions,
global_num_tokens_gpu=self.global_num_tokens_gpu,
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),

View File

@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
from sglang.srt.managers.mm_utils import embed_mm_inputs
from sglang.srt.managers.schedule_batch import (
ScheduleBatch,
get_last_loc,