Qwen2.5-VL eagle3 infer (#8801)

This commit is contained in:
Lzhang-hub
2025-09-08 11:44:34 +08:00
committed by GitHub
parent 7802586cab
commit 37d83c6e6d
9 changed files with 114 additions and 5 deletions

View File

@@ -317,7 +317,9 @@ class CudaGraphRunner:
(self.max_num_token,), dtype=self._cache_loc_dtype()
)
self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
self.mrope_positions = torch.zeros(
(3, self.max_num_token), dtype=torch.int64
)
self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
self.tbo_plugin = TboCudaGraphRunnerPlugin()
@@ -532,7 +534,7 @@ class CudaGraphRunner:
encoder_lens = self.encoder_lens[:bs]
else:
encoder_lens = None
mrope_positions = self.mrope_positions[:, :bs]
mrope_positions = self.mrope_positions[:, :num_tokens]
next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
self.num_token_non_padded[...] = num_tokens
@@ -751,7 +753,7 @@ class CudaGraphRunner:
if self.is_encoder_decoder:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
if self.require_gathered_buffer:
self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)

View File

@@ -441,7 +441,13 @@ class ForwardBatch:
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
if model_runner.model_is_mrope:
ret._compute_mrope_positions(model_runner, batch)
if (
ret.spec_info is not None
and getattr(ret.spec_info, "positions", None) is not None
):
ret._compute_spec_mrope_positions(model_runner, batch)
else:
ret._compute_mrope_positions(model_runner, batch)
# Init lora information
if model_runner.server_args.enable_lora:
@@ -507,6 +513,52 @@ class ForwardBatch:
or self.contains_image_inputs()
)
def _compute_spec_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):
# TODO support batched deltas
batch_size = self.seq_lens.shape[0]
device = model_runner.device
mm_inputs = batch.multimodal_inputs
if batch.forward_mode.is_draft_extend(): # draft_extend_after_decode
mrope_deltas = []
extend_lens = []
for batch_idx in range(batch_size):
extend_seq_len = batch.extend_seq_lens[batch_idx]
extend_lens.append(extend_seq_len)
mrope_delta = (
torch.zeros(1, dtype=torch.int64)
if mm_inputs[batch_idx] is None
else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
)
mrope_deltas.append(mrope_delta.to(device=device))
position_chunks = torch.split(batch.spec_info.positions, extend_lens)
mrope_positions_list = [
pos_chunk + delta
for pos_chunk, delta in zip(position_chunks, mrope_deltas)
]
next_input_positions = (
torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
)
else: # target_verify or draft_decode
seq_positions = batch.spec_info.positions.view(batch_size, -1)
mrope_deltas = [
(
torch.tensor([0], dtype=torch.int64)
if mm_inputs[i] is None
else mm_inputs[i].mrope_position_delta.squeeze(0)
)
for i in range(batch_size)
]
mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
next_input_positions = (
(seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
)
self.mrope_positions = next_input_positions
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):