Qwen2.5-VL eagle3 infer (#8801)
@@ -629,6 +629,7 @@ def general_mm_embed_routine(
     embed_tokens = language_model.get_input_embeddings()
     if (
         not forward_batch.forward_mode.is_decode()
+        and not forward_batch.forward_mode.is_target_verify()
         and forward_batch.contains_mm_inputs()
     ):
         mm_inputs_list = [
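Note: a TARGET_VERIFY batch carries only freshly drafted text tokens while the request's multimodal metadata is still attached, so the new clause keeps the multimodal embedding merge from re-running during EAGLE verification. A minimal sketch of the guard as a standalone predicate (hypothetical helper, not part of the PR):

    def should_merge_mm_embeddings(forward_mode, has_mm_inputs: bool) -> bool:
        # Merge image/video embeddings only on genuine extend passes; decode
        # and target-verify passes reuse embeddings computed earlier.
        return (
            not forward_mode.is_decode()
            and not forward_mode.is_target_verify()
            and has_mm_inputs
        )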
@@ -317,7 +317,9 @@ class CudaGraphRunner:
             (self.max_num_token,), dtype=self._cache_loc_dtype()
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
         self.tbo_plugin = TboCudaGraphRunnerPlugin()

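Note: under speculative decoding each captured graph processes several tokens per sequence, so position buffers are indexed by token rather than by batch slot; sizing the M-RoPE buffer by max_bs would overflow as soon as num_tokens_per_bs > 1. Runnable toy with hypothetical sizes:

    import torch

    max_bs, num_tokens_per_bs = 8, 4       # e.g. draft tokens per verify pass
    max_num_token = max_bs * num_tokens_per_bs
    mrope_positions = torch.zeros((3, max_num_token), dtype=torch.int64)

    bs = 2
    num_tokens = bs * num_tokens_per_bs    # tokens actually used this replay
    print(mrope_positions[:, :num_tokens].shape)  # torch.Size([3, 8])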
@@ -532,7 +534,7 @@ class CudaGraphRunner:
             encoder_lens = self.encoder_lens[:bs]
         else:
             encoder_lens = None
-        mrope_positions = self.mrope_positions[:, :bs]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
         self.num_token_non_padded[...] = num_tokens

@@ -751,7 +753,7 @@ class CudaGraphRunner:
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
-            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+            self.mrope_positions[:, :raw_num_token].copy_(forward_batch.mrope_positions)
         if self.require_gathered_buffer:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
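Note: the replay-time copy has to widen with the token count for the same reason as the buffer above: a verify batch of raw_bs sequences carries raw_bs * k draft tokens, so slicing the destination by batch size would drop most of the incoming positions. Runnable toy with hypothetical sizes:

    import torch

    raw_bs, k = 2, 3                       # k draft tokens per sequence
    raw_num_token = raw_bs * k
    buffer = torch.zeros((3, 16), dtype=torch.int64)
    incoming = torch.arange(3 * raw_num_token, dtype=torch.int64).view(3, -1)
    buffer[:, :raw_num_token].copy_(incoming)  # ':raw_bs' would truncate to 2 cols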
@@ -441,7 +441,13 @@ class ForwardBatch:
         ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens

         if model_runner.model_is_mrope:
-            ret._compute_mrope_positions(model_runner, batch)
+            if (
+                ret.spec_info is not None
+                and getattr(ret.spec_info, "positions", None) is not None
+            ):
+                ret._compute_spec_mrope_positions(model_runner, batch)
+            else:
+                ret._compute_mrope_positions(model_runner, batch)

         # Init lora information
         if model_runner.server_args.enable_lora:
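Note: speculative batches expose the flat next-token positions on spec_info, which is what the new branch keys on; only when those are absent does the regular image-grid computation run. A tiny runnable sketch of the dispatch condition (SpecInfo here is a hypothetical stand-in):

    class SpecInfo:
        positions = None  # filled in by draft/verify batch preparation

    spec_info = SpecInfo()
    use_spec_path = (
        spec_info is not None and getattr(spec_info, "positions", None) is not None
    )
    print(use_spec_path)  # False until the speculative path fills positions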
@@ -507,6 +513,52 @@ class ForwardBatch:
             or self.contains_image_inputs()
         )

+    def _compute_spec_mrope_positions(
+        self, model_runner: ModelRunner, batch: ModelWorkerBatch
+    ):
+        # TODO support batched deltas
+        batch_size = self.seq_lens.shape[0]
+        device = model_runner.device
+        mm_inputs = batch.multimodal_inputs
+
+        if batch.forward_mode.is_draft_extend():  # draft_extend_after_decode
+            mrope_deltas = []
+            extend_lens = []
+            for batch_idx in range(batch_size):
+                extend_seq_len = batch.extend_seq_lens[batch_idx]
+                extend_lens.append(extend_seq_len)
+                mrope_delta = (
+                    torch.zeros(1, dtype=torch.int64)
+                    if mm_inputs[batch_idx] is None
+                    else mm_inputs[batch_idx].mrope_position_delta.squeeze(0)
+                )
+                mrope_deltas.append(mrope_delta.to(device=device))
+            position_chunks = torch.split(batch.spec_info.positions, extend_lens)
+            mrope_positions_list = [
+                pos_chunk + delta
+                for pos_chunk, delta in zip(position_chunks, mrope_deltas)
+            ]
+            next_input_positions = (
+                torch.cat(mrope_positions_list, dim=0).unsqueeze(0).repeat(3, 1)
+            )
+
+        else:  # target_verify or draft_decode
+            seq_positions = batch.spec_info.positions.view(batch_size, -1)
+            mrope_deltas = [
+                (
+                    torch.tensor([0], dtype=torch.int64)
+                    if mm_inputs[i] is None
+                    else mm_inputs[i].mrope_position_delta.squeeze(0)
+                )
+                for i in range(batch_size)
+            ]
+            mrope_delta_tensor = torch.stack(mrope_deltas, dim=0).to(device=device)
+            next_input_positions = (
+                (seq_positions + mrope_delta_tensor).flatten().unsqueeze(0).repeat(3, 1)
+            )
+
+        self.mrope_positions = next_input_positions
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
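Note: both branches shift the speculative 1-D positions by each request's mrope_position_delta and then broadcast one row to all three M-RoPE channels; for text tokens past the visual content, the temporal/height/width coordinates advance together, so the rows are identical. Runnable toy of the arithmetic (the delta values -4 and 0 are made up; 0 stands for a text-only request):

    import torch

    # target_verify / draft_decode: fixed draft length per request.
    batch_size = 2
    positions = torch.tensor([10, 11, 12, 7, 8, 9])         # (bs * draft_len,)
    deltas = torch.tensor([[-4], [0]], dtype=torch.int64)   # (bs, 1)
    seq_positions = positions.view(batch_size, -1)
    verify_mrope = (seq_positions + deltas).flatten().unsqueeze(0).repeat(3, 1)
    print(verify_mrope[0])  # tensor([6, 7, 8, 7, 8, 9]); all 3 rows identical

    # draft_extend: variable extend lengths, same per-request shift via split.
    extend_lens = [2, 1]
    chunks = torch.split(torch.tensor([20, 21, 5]), extend_lens)
    shifted = [c + d for c, d in zip(chunks, [torch.tensor([-4]), torch.tensor([0])])]
    extend_mrope = torch.cat(shifted).unsqueeze(0).repeat(3, 1)
    print(extend_mrope[0])  # tensor([16, 17, 5])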
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
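Note: the draft model detects M-RoPE purely from the target config's rope_scaling metadata and then forces rope_type back to "default", since its own layers build standard 1-D rotary embeddings and only the positions fed to them change. Sketch with a hypothetical Qwen2.5-VL-style dict:

    rope_scaling = {"rope_type": "mrope", "mrope_section": [16, 24, 24]}

    is_mrope_enabled = rope_scaling is not None and "mrope_section" in rope_scaling
    if is_mrope_enabled:
        # Keep the metadata but let rotary construction take the default path.
        rope_scaling["rope_type"] = "default"
    print(rope_scaling["rope_type"])  # default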
@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds

+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
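Note: with M-RoPE enabled the draft model ignores the flat positions argument and instead reads the (3, num_tokens) coordinate grid prepared on the ForwardBatch. Toy shape contrast:

    import torch

    positions = torch.arange(6)                            # (num_tokens,) 1-D RoPE
    mrope_positions = positions.unsqueeze(0).repeat(3, 1)  # (3, num_tokens) M-RoPE
    print(positions.shape, mrope_positions.shape)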
@@ -454,6 +454,9 @@ class Qwen2ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False

+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)

@@ -481,6 +484,10 @@ class Qwen2ForCausalLM(nn.Module):
         if self.capture_aux_hidden_states:
             hidden_states, aux_hidden_states = hidden_states

+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
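Note: the forward paths now initialize aux_hidden_states to None and unpack conditionally, because with EAGLE3 capture enabled the decoder is assumed to return a (last_hidden, aux_hidden) tuple instead of a bare tensor. Toy of the unpack pattern with made-up shapes (5 tokens, hidden size 64, three captured layers concatenated):

    import torch

    capture_aux_hidden_states = True
    hidden_states = (torch.randn(5, 64), torch.randn(5, 3 * 64))

    aux_hidden_states = None
    if capture_aux_hidden_states:
        hidden_states, aux_hidden_states = hidden_states
    print(hidden_states.shape, aux_hidden_states.shape)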
@@ -518,6 +518,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
         return pattern.pad_input_tokens(input_ids, mm_inputs)
@@ -588,9 +591,13 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             positions=positions,
         )

+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -644,5 +651,21 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        self.capture_aux_hidden_states = True
+        self.model.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+

 EntryClass = [Qwen2_5_VLForConditionalGeneration]
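Note: the default capture set {2, L // 2, L - 3} matches the EAGLE3 recipe of taking one early, one middle and one late decoder layer for the draft model's fused input; explicit layer_ids are shifted by one, presumably because index 0 of the captured hidden-state list is the embedding output. Sketch (the helper name and the 28-layer count are hypothetical):

    from typing import List, Optional

    def capture_layers(num_layers: int, layer_ids: Optional[List[int]] = None) -> List[int]:
        if layer_ids is None:
            return [2, num_layers // 2, num_layers - 3]
        return [i + 1 for i in layer_ids]

    print(capture_layers(28))               # [2, 14, 25]
    print(capture_layers(28, [1, 13, 24]))  # [2, 14, 25]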
@@ -91,6 +91,9 @@ class EAGLEDraftCudaGraphRunner:
             (self.max_num_token * self.speculative_num_steps,), dtype=torch.int64
         )
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )
         self.topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32)
         self.topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64)
         self.hidden_states = torch.zeros(
@@ -159,6 +162,7 @@ class EAGLEDraftCudaGraphRunner:
         seq_lens = self.seq_lens[:num_seqs]
         out_cache_loc = self.out_cache_loc[: num_tokens * self.speculative_num_steps]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         topk_p = self.topk_p[:num_seqs]
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]
@@ -224,6 +228,7 @@ class EAGLEDraftCudaGraphRunner:
             seq_lens_sum=seq_lens.sum().item(),
             return_logprob=False,
             positions=positions,
+            mrope_positions=mrope_positions,
             global_num_tokens_gpu=global_num_tokens,
             dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
             global_dp_buffer_len=global_dp_buffer_len,
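Note: the draft runner follows the same CUDA-graph discipline as the main CudaGraphRunner above: tensors baked into a captured graph must keep their storage, so the full (3, max_num_token) buffer is allocated once and every replay only refreshes and re-views its leading slice. Sketch of that contract (view_for is a hypothetical helper):

    import torch

    max_num_token = 16
    mrope_positions_buf = torch.zeros((3, max_num_token), dtype=torch.int64)

    def view_for(num_tokens: int) -> torch.Tensor:
        # Same storage on every replay, only the width differs.
        return mrope_positions_buf[:, :num_tokens]

    assert view_for(6).data_ptr() == mrope_positions_buf.data_ptr()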
@@ -80,6 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
         self.out_cache_loc = torch.ones((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
+        self.mrope_positions = torch.zeros(
+            (3, self.max_num_token), dtype=torch.int64
+        )

         if self.eagle_worker.speculative_algorithm.is_eagle3():
             self.hidden_states = torch.zeros(
@@ -189,6 +192,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         accept_length = self.accept_length[:bs]
         out_cache_loc = self.out_cache_loc[:num_tokens]
         positions = self.positions[:num_tokens]
+        mrope_positions = self.mrope_positions[:, :num_tokens]
         hidden_states = self.hidden_states[:num_tokens]
         next_token_logits_buffer = self.next_token_logits_buffer[:bs]

@@ -247,6 +251,7 @@ class EAGLEDraftExtendCudaGraphRunner:
            seq_lens_sum=seq_lens.sum().item(),
            return_logprob=False,
            positions=positions,
+           mrope_positions=mrope_positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
@@ -14,6 +14,7 @@ from sglang.srt.distributed import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.mm_utils import embed_mm_inputs
 from sglang.srt.managers.schedule_batch import (
     ScheduleBatch,
     get_last_loc,