Eagle3 mm support, enablement on qwen3vl (#4848)
### What this PR does / why we need it?
Follows vLLM PR [vllm-project/vllm#20788](https://github.com/vllm-project/vllm/pull/20788): adds Eagle3 multimodal support and enables it for Qwen3-VL.

Target model: [Qwen/Qwen3-VL-8B-Instruct](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct)
Eagle3 draft model: [MNN/Qwen3-VL-8B-Instruct-Eagle3](https://www.modelscope.cn/models/MNN/Qwen3-VL-8B-Instruct-Eagle3)
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
`pytest ./tests/e2e/singlecard/test_completion_with_prompt_embeds.py -vv`

vLLM with Eagle3:
```bash
vllm serve /model/Qwen3-VL-8B-Instruct --enforce-eager --port 9100 --max-model-len 32768 --max-num-seqs 32 --tensor-parallel-size 2 --allowed-local-media-path /model/gx/images --speculative-config '{
"method": "eagle3",
"model": "/model/hf/Qwen3-VL-8B-Instruct-Eagle3",
"num_speculative_tokens": 3
}'
```
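For a quick manual check of the multimodal path (not part of this PR's test suite), a request like the one below can be sent to the Eagle3-enabled server; the image filename is a placeholder for any file under the `--allowed-local-media-path` directory:
```bash
# Hypothetical smoke-test request; replace demo.jpg with a real image
# under /model/gx/images (the --allowed-local-media-path directory).
curl http://127.0.0.1:9100/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "/model/Qwen3-VL-8B-Instruct",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "image_url", "image_url": {"url": "file:///model/gx/images/demo.jpg"}},
        {"type": "text", "text": "Describe this image."}
      ]
    }],
    "max_tokens": 128,
    "temperature": 0
  }'
```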
vLLM without Eagle3:
```bash
vllm serve /model/Qwen3-VL-8B-Instruct --enforce-eager --port 9100 --max-model-len 32768 --max-num-seqs 32 --tensor-parallel-size 2 --allowed-local-media-path /model/gx/images
```
Benchmark:
```bash
vllm bench serve --backend openai-chat --base-url http://127.0.0.1:9100 --tokenizer /model/Qwen3-VL-8B-Instruct --endpoint /v1/chat/completions --model /model/Qwen3-VL-8B-Instruct --dataset-name random --num-prompts 50 --max-concurrency 5 --temperature 0 --top-p 1.0 --seed 123
```
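Beyond comparing the two bench runs, draft-token acceptance can be spot-checked from the server's Prometheus metrics (a rough sketch; the exact spec-decode metric names vary across vLLM versions):
```bash
# Grep the spec-decode counters exposed by the running server;
# metric names differ between vLLM versions, hence the loose pattern.
curl -s http://127.0.0.1:9100/metrics | grep -iE "spec_decode.*(draft|accepted)"
```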
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: jesse <szxfml@gmail.com>
.github/workflows/misc/model_list.json
@@ -148,6 +148,7 @@
     "moonshotai/Kimi-K2-Thinking",
     "moonshotai/Kimi-Linear-48B-A3B-Instruct",
     "neuralmagic/Qwen2.5-3B-quantized.w8a8",
+    "MNN/Qwen3-VL-8B-Instruct-Eagle3",
     "nv-community/audio-flamingo-3",
     "nv-community/audio-flamingo-3-hf",
     "nvidia/audio-flamingo-3-hf",
@@ -234,4 +235,3 @@
     "xlangai/OpenCUA-7B"
   ]
 }
-
@@ -85,6 +85,14 @@ def eagle3_model_name():
     return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"


+@pytest.fixture
+def vl_model_name():
+    return "Qwen/Qwen3-VL-8B-Instruct"
+
+def vl_eagle3_model_name():
+    return "MNN/Qwen3-VL-8B-Instruct-Eagle3"
+
+
 def test_ngram_correctness(
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
@@ -129,6 +137,48 @@ def test_ngram_correctness(
     assert matches > int(0.66 * len(ref_outputs))


+def test_qwen3_vl_eagle_correctness(
+    test_prompts: list[list[dict[str, Any]]],
+    sampling_config: SamplingParams,
+    vl_model_name: str,
+):
+    '''
+    Compare the outputs of a original LLM and a speculative LLM
+    should be the same when using eagle speculative decoding.
+    '''
+    with VllmRunner(
+            vl_model_name,
+            max_model_len=1024,
+            cudagraph_capture_sizes=[1, 2, 4, 8],
+    ) as ref_llm:
+        ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
+
+    spec_model_name = vl_eagle3_model_name()
+    with VllmRunner(
+            vl_model_name,
+            speculative_config={
+                "method": "eagle3",
+                "model": spec_model_name,
+                "num_speculative_tokens": 2,
+            },
+            max_model_len=1024,
+            cudagraph_capture_sizes=[1, 2, 4, 8],
+    ) as runner:
+        spec_outputs = runner.model.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 70% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.66 * len(ref_outputs))
+
+
 def test_suffix_correctness(
     test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
@@ -50,6 +50,7 @@ class TestEagleProposerInitialization(TestBase):
         self.vllm_config.speculative_config.draft_model_config.get_hidden_size.return_value = 4096
         self.vllm_config.compilation_config.mode = CompilationMode.VLLM_COMPILE
         self.vllm_config.model_config.enforce_eager = False
+        self.vllm_config.model_config.uses_mrope = False
         self.vllm_config.speculative_config.enforce_eager = False
         self.vllm_config.scheduler_config.async_scheduling = False
         init_ascend_config(self.vllm_config)
@@ -156,6 +157,7 @@ class TestEagleProposerLoadModel(TestBase):
         }]

         mock_model = MagicMock()
+        mock_model.supports_multimodal = False
         mock_model.model.embed_tokens = MagicMock()
         mock_model.lm_head = MagicMock()
         mock_model.multimodal_cpu_fields = None
@@ -226,7 +228,7 @@ class TestEagleProposerLoadModel(TestBase):
         self.proposer.name = SpecDcodeType.EAGLE

         self.proposer.load_model(mock_model)
-        mock_model.get_language_model.assert_called_once()
+        self.assertEqual(mock_model.get_language_model.call_count, 2)
         self.assertIs(self.proposer.model.lm_head,
                       mock_model.get_language_model.return_value.lm_head)
@@ -149,8 +149,34 @@ class EagleProposer(VllmEagleProposer):
         assert len(draft_attn_layer_names) == 1
         self.attn_layer_names = list(draft_attn_layer_names)

+        if supports_multimodal(model):
+            # handle multimodality
+            if self.get_model_name(model) in [
+                    "Qwen2_5_VLForConditionalGeneration",
+                    "Qwen3VLForConditionalGeneration",
+            ]:
+                self.model.config.image_token_index = model.config.image_token_id
+            elif self.get_model_name(
+                    model) == "PixtralForConditionalGeneration":
+                self.model.config.image_token_index = (
+                    model.config.vision_config.image_token_id)
+            else:
+                self.model.config.image_token_index = (
+                    model.config.image_token_index)
+            target_language_model = model.get_language_model()
+        else:
+            target_language_model = model
+
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1:
+            if hasattr(target_language_model.model, "embed_tokens"):
+                target_embed_tokens = target_language_model.model.embed_tokens
+            elif hasattr(target_language_model.model, "embedding"):
+                target_embed_tokens = target_language_model.model.embedding
+            else:
+                raise AttributeError(
+                    "Target model does not have 'embed_tokens' or 'embedding' attribute"
+                )
             if self.method == "mtp":
                 if self.vllm_config.model_config.is_deepseek_mla and \
                     torch.equal(self.model.model.embed_tokens.weight,
@@ -161,7 +187,7 @@ class EagleProposer(VllmEagleProposer):
                         "The MTP head shares the same vocab embedding" \
                         " with the target model."
                     )
-                    self.model.model.embed_tokens = model.model.embed_tokens
+                    self.model.model.embed_tokens = target_embed_tokens
                 else:
                     logger.info(
                         " The MTP head loaded its own vocab embedding" \
@@ -172,13 +198,12 @@ class EagleProposer(VllmEagleProposer):
                         "The EAGLE head shares the same vocab embedding" \
                         " with the target model."
                     )
-                    self.model.model.embed_tokens = model.model.embed_tokens
+                    self.model.model.embed_tokens = target_embed_tokens
                 else:
                     logger.info(
                         "Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
                         " weights instead of sharing them with the target model."
                     )
-
         # share lm_head with the target model if needed
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
@@ -221,7 +246,7 @@ class EagleProposer(VllmEagleProposer):
                          dummy_compute_logits=lambda hidden_states: None,
                          is_profile=False):
        # update global cos, sin
-        update_cos_sin(self.positions[:num_tokens])
+        update_cos_sin(self._get_positions(num_tokens))

        attn_metadata = None
        if not self.use_cuda_graph:
@@ -265,7 +290,7 @@ class EagleProposer(VllmEagleProposer):
                attn_metadata[layer_name] = attn_metadata_eagle

            model_input_ids = self.input_ids[:num_tokens]
-            model_positions = self.positions[:num_tokens]
+            model_positions = self._get_positions(num_tokens)
            model_previous_hidden_states = self.hidden_states[:num_tokens]
            for i in range(self.num_speculative_tokens):
                if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
@@ -340,7 +365,6 @@ class EagleProposer(VllmEagleProposer):
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        self.input_ids[last_token_indices] = next_token_ids
-
        if self.use_cuda_graph and \
            num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -356,15 +380,28 @@ class EagleProposer(VllmEagleProposer):
            batch_descriptor = None

        # copy inputs to buffer for cudagraph
-        self.positions[:num_tokens] = target_positions
+        self._set_positions(num_tokens, target_positions)
        self.hidden_states[:num_tokens] = target_hidden_states

+        if self.supports_mm_inputs:
+            mm_embeds, is_mm_embed = mm_embed_inputs or (None, None)
+            inputs_embeds = self.model.embed_input_ids(
+                self.input_ids[:num_tokens],
+                multimodal_embeddings=mm_embeds,
+                is_multimodal=is_mm_embed)
+            self.inputs_embeds[:num_tokens] = inputs_embeds
+            inputs_embeds = self.inputs_embeds[:num_input_tokens]
+            input_ids = self.input_ids[:num_input_tokens]
+        else:
+            inputs_embeds = None
+            input_ids = self.input_ids[:num_input_tokens]
+
        # FIXME(woosuk): The below two ops cause synchronization. Optimize.
        builder = self.runner.attn_groups[0][0].get_metadata_builder()
        attn_metadata = builder.build(0, common_attn_metadata,
                                      self.runner.get_model())
        # update global cos, sin
-        update_cos_sin(self.positions[:num_input_tokens])
+        update_cos_sin(self._get_positions(num_input_tokens))
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
@@ -380,7 +417,7 @@ class EagleProposer(VllmEagleProposer):
        # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
        # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
        model_input_ids = self.input_ids[:num_input_tokens]
-        model_positions = self.positions[:num_input_tokens]
+        model_positions = self._get_positions(num_input_tokens)
        model_hidden_states = self.hidden_states[:num_input_tokens]

        model_hidden_states, model_positions = self.maybe_pad_and_reduce(
@@ -390,6 +427,7 @@ class EagleProposer(VllmEagleProposer):
            input_ids=model_input_ids,
            positions=model_positions,
            hidden_states=model_hidden_states,
+            inputs_embeds=inputs_embeds,
        )
        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
@@ -420,8 +458,10 @@ class EagleProposer(VllmEagleProposer):
                                             dtype=draft_token_ids.dtype,
                                             device=self.device)
        draft_token_ids_tensor[0] = draft_token_ids
-        positions = target_positions[last_token_indices]
+        if self.uses_mrope:
+            positions = target_positions[:, last_token_indices]
+        else:
+            positions = target_positions[last_token_indices]
        hidden_states = hidden_states[last_token_indices]
        last_token_indices = self.arange[:batch_size]

@@ -460,11 +500,18 @@ class EagleProposer(VllmEagleProposer):
        # but adjust the position ids and slot mappings to avoid the
        # out-of-range access during the model execution. The draft tokens
        # generated with this adjustment should be ignored.
-        exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len
-        # Mask out the position ids that exceed the max model length.
-        # Otherwise, we may get out-of-range error in RoPE.
-        clamped_positions = torch.where(exceeds_max_model_len, 0,
-                                        positions)
+        if self.uses_mrope:
+            exceeds_max_model_len = positions[
+                0] >= self.vllm_config.model_config.max_model_len
+            # Mask out the position ids that exceed the max model length.
+            # Otherwise, we may get out-of-range error in RoPE.
+            clamped_positions = torch.where(
+                exceeds_max_model_len.unsqueeze(0),
+                torch.zeros_like(positions), positions)
+        else:
+            exceeds_max_model_len = positions >= self.vllm_config.model_config.max_model_len
+            clamped_positions = torch.where(exceeds_max_model_len, 0,
+                                            positions)

        # TODO: Increment the sequence lengths.

@@ -485,12 +532,19 @@ class EagleProposer(VllmEagleProposer):
            block_size = attn_metadata_builder.kv_cache_spec.block_size

            # Compute the slot mapping.
-            block_numbers = (clamped_positions // block_size)
+            if self.uses_mrope:
+                block_numbers = clamped_positions[0] // block_size
+            else:
+                block_numbers = (clamped_positions // block_size)
            block_ids = attn_metadata.block_tables.gather(
                dim=1, index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
-            slot_mapping_tmp = (block_ids * block_size +
-                                clamped_positions % block_size)
+            if self.uses_mrope:
+                slot_mapping_tmp = (block_ids * block_size +
+                                    clamped_positions[0] % block_size)
+            else:
+                slot_mapping_tmp = (block_ids * block_size +
+                                    clamped_positions % block_size)

            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
@@ -504,14 +558,23 @@ class EagleProposer(VllmEagleProposer):
                                            PADDING_SLOT_ID)
            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
-            self.positions[:batch_size] = clamped_positions
+            self._set_positions(batch_size, clamped_positions)
            self.hidden_states[:batch_size] = hidden_states
+            if self.supports_mm_inputs:
+                self.inputs_embeds[:batch_size] = self.model.embed_input_ids(
+                    input_ids)
+
+                input_ids = self.input_ids[:input_batch_size]
+                inputs_embeds = self.inputs_embeds[:input_batch_size]
+            else:
+                input_ids = self.input_ids[:input_batch_size]
+                inputs_embeds = None
            attn_mask = self.attn_mask_builder.get_splitfuse_attn_mask()

            attn_metadata.attn_mask = attn_mask

            # update global cos, sin
-            update_cos_sin(self.positions[:input_batch_size])
+            update_cos_sin(self._get_positions(input_batch_size))

            # Run the model.
            with set_ascend_forward_context(
@@ -526,7 +589,7 @@ class EagleProposer(VllmEagleProposer):
                # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
                # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
                model_input_ids = self.input_ids[:input_batch_size]
-                model_positions = self.positions[:input_batch_size]
+                model_positions = self._get_positions(input_batch_size)
                model_hidden_states = self.hidden_states[:input_batch_size]

                model_hidden_states, model_positions = self.maybe_pad_and_reduce(
@@ -536,6 +599,7 @@ class EagleProposer(VllmEagleProposer):
                    input_ids=model_input_ids,
                    positions=model_positions,
                    hidden_states=model_hidden_states,
+                    inputs_embeds=inputs_embeds,
                )
                if self.method == "mtp":
                    last_hidden_states = ret_hidden_states
@@ -1354,14 +1354,14 @@ class NPUModelRunner(GPUModelRunner):
                    query_start_loc_pcp_full[1:num_reqs + 1] - 1
                target_token_ids = input_ids_pcp_full[:
                                                      num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
+                target_positions = self._get_positions(num_scheduled_tokens)
                target_hidden_states = hidden_states
            else:
                token_indices_to_sample = None
                # input_ids can be None for multimodal models.
                target_token_ids = self.input_ids.gpu[:
                                                      num_scheduled_tokens]
-                target_positions = positions[:num_scheduled_tokens]
+                target_positions = self._get_positions(num_scheduled_tokens)
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat([
                        h[:num_scheduled_tokens]
@@ -1402,7 +1402,7 @@ class NPUModelRunner(GPUModelRunner):
                target_hidden_states = hidden_states
            else:
                target_token_ids = self.input_ids.gpu[token_indices]
-                target_positions = positions[token_indices]
+                target_positions = self._get_positions(token_indices)
                if self.use_aux_hidden_state_outputs:
                    target_hidden_states = torch.cat(
                        [h[token_indices] for h in aux_hidden_states],