Qwen2.5-VL eagle3 infer (#8801)
This commit is contained in:
@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
|
||||
(self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
|
||||
)
|
||||
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, self.max_num_token), dtype=torch.int64
|
||||
)
|
||||
self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
|
||||
self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
|
||||
self.hidden_states = torch.zeros(
|
||||
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
seq_lens = self.seq_lens[:num_seqs]
|
||||
out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
|
||||
positions = self.positions[:num_tokens]
|
||||
mrope_positions = self.mrope_positions[:, :num_tokens]
|
||||
topk_p = self.topk_p[:num_seqs]
|
||||
topk_index = self.topk_index[:num_seqs]
|
||||
hidden_states = self.hidden_states[:num_seqs]
|
||||
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
|
||||
seq_lens_sum=seq_lens.sum().item(),
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
global_num_tokens_gpu=global_num_tokens,
|
||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
global_dp_buffer_len=global_dp_buffer_len,
|
||||
|
||||
@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
|
||||
self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
|
||||
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
|
||||
self.mrope_positions = torch.zeros(
|
||||
(3, self.max_num_token), dtype=torch.int64
|
||||
)
|
||||
|
||||
if self.eagle_worker.speculative_algorithm.is_eagle3():
|
||||
self.hidden_states = torch.zeros(
|
||||
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
accept_length = self.accept_length[:bs]
|
||||
out_cache_loc = self.out_cache_loc[:num_tokens]
|
||||
positions = self.positions[:num_tokens]
|
||||
mrope_positions = self.mrope_positions[:, :num_tokens]
|
||||
hidden_states = self.hidden_states[:num_tokens]
|
||||
next_token_logits_buffer = self.next_token_logits_buffer[:bs]
|
||||
|
||||
@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
|
||||
seq_lens_sum=seq_lens.sum().item(),
|
||||
return_logprob=False,
|
||||
positions=positions,
|
||||
mrope_positions=mrope_positions,
|
||||
global_num_tokens_gpu=self.global_num_tokens_gpu,
|
||||
global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
|
||||
dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
|
||||
|
||||
@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
||||
from sglang.srt.managers.mm_utils import embed_mm_inputs
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
ScheduleBatch,
|
||||
get_last_loc,
|
||||
|
||||
Reference in New Issue
Block a user