[Feat] Adapted mtp function to Qwen3-next (#3918)
### What this PR does / why we need it?
Adapts mtp function to Qwen3-next.
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -20,10 +20,17 @@
|
|||||||
|
|
||||||
Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
|
Run `pytest tests/e2e/multicard/test_qwen3_next.py`.
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
|
# NZ will cause precision error in Qwen3-Next
|
||||||
|
# When it is fixed, this set-up can be removed
|
||||||
|
_IS_ENABLE_NZ = "VLLM_ASCEND_ENABLE_NZ"
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
|
||||||
def test_models_distributed_Qwen3_NEXT_TP4():
|
def test_models_distributed_Qwen3_NEXT_TP4():
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@@ -36,8 +43,10 @@ def test_models_distributed_Qwen3_NEXT_TP4():
|
|||||||
distributed_executor_backend="mp",
|
distributed_executor_backend="mp",
|
||||||
enforce_eager=True) as vllm_model:
|
enforce_eager=True) as vllm_model:
|
||||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
|
||||||
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
|
def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
@@ -54,3 +63,50 @@ def test_models_distributed_Qwen3_NEXT_TP4_FULL_DECODE_ONLY():
|
|||||||
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
|
"cudagraph_capture_sizes": [1, 8, 24, 48, 60]
|
||||||
}) as vllm_model:
|
}) as vllm_model:
|
||||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {_IS_ENABLE_NZ: "0"})
|
||||||
|
def test_models_distributed_Qwen3_NEXT_MTP_TP4_SIMILARITY():
|
||||||
|
example_prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
"The president of the United States is",
|
||||||
|
"The capital of France is",
|
||||||
|
"The future of AI is",
|
||||||
|
]
|
||||||
|
max_tokens = 20
|
||||||
|
|
||||||
|
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
max_model_len=4096,
|
||||||
|
gpu_memory_utilization=0.8,
|
||||||
|
distributed_executor_backend="mp") as vllm_model:
|
||||||
|
ref_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
del vllm_model
|
||||||
|
|
||||||
|
with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
|
||||||
|
tensor_parallel_size=4,
|
||||||
|
max_model_len=4096,
|
||||||
|
gpu_memory_utilization=0.8,
|
||||||
|
distributed_executor_backend="mp",
|
||||||
|
speculative_config={
|
||||||
|
"method": "qwen3_next_mtp",
|
||||||
|
"num_speculative_tokens": 1
|
||||||
|
}) as spec_vllm_model:
|
||||||
|
spec_outputs = spec_vllm_model.generate_greedy(example_prompts,
|
||||||
|
max_tokens)
|
||||||
|
del spec_vllm_model
|
||||||
|
|
||||||
|
matches = 0
|
||||||
|
misses = 0
|
||||||
|
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||||
|
ref_token_ids = ref_output[0]
|
||||||
|
spec_token_ids = spec_output[0]
|
||||||
|
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
|
||||||
|
matches += 1
|
||||||
|
else:
|
||||||
|
misses += 1
|
||||||
|
print(f"ref_output: {ref_output[1]}")
|
||||||
|
print(f"spec_output: {spec_output[1]}")
|
||||||
|
|
||||||
|
assert matches > int(0.66 * len(ref_outputs))
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
|||||||
mock_get_dcp_group.return_value = dcp_group
|
mock_get_dcp_group.return_value = dcp_group
|
||||||
|
|
||||||
self.mock_vllm_config = MagicMock()
|
self.mock_vllm_config = MagicMock()
|
||||||
|
self.mock_vllm_config.speculative_config = None
|
||||||
self.mock_vllm_config.model_config.max_model_len = 640
|
self.mock_vllm_config.model_config.max_model_len = 640
|
||||||
self.mock_vllm_config.cache_config.block_size = 64
|
self.mock_vllm_config.cache_config.block_size = 64
|
||||||
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
self.mock_vllm_config.compilation_config.cudagraph_mode = None
|
||||||
|
|||||||
@@ -252,6 +252,17 @@ class AscendAttentionMetadataBuilder:
|
|||||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||||
) if self.dcp_size > 1 else 0
|
) if self.dcp_size > 1 else 0
|
||||||
|
|
||||||
|
self.speculative_config = vllm_config.speculative_config
|
||||||
|
self.decode_threshold = 1
|
||||||
|
if self.speculative_config:
|
||||||
|
spec_token_num = self.speculative_config.num_speculative_tokens
|
||||||
|
self.decode_threshold += spec_token_num
|
||||||
|
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
|
||||||
|
npu_fused_infer_attention_score TND layout's limit of 16, \
|
||||||
|
got {self.decode_threshold}"
|
||||||
|
|
||||||
|
AscendAttentionMetadataBuilder.reorder_batch_threshold = self.decode_threshold
|
||||||
|
|
||||||
def reorder_batch(self, input_batch,
|
def reorder_batch(self, input_batch,
|
||||||
scheduler_output: "SchedulerOutput") -> bool:
|
scheduler_output: "SchedulerOutput") -> bool:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -35,6 +35,10 @@ def register_model():
|
|||||||
"PanguProMoEForCausalLM",
|
"PanguProMoEForCausalLM",
|
||||||
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
|
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
|
||||||
)
|
)
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
"Qwen3NextForCausalLM",
|
"Qwen3NextForCausalLM",
|
||||||
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
|
"vllm_ascend.models.qwen3_next:CustomQwen3NextForCausalLM")
|
||||||
|
|
||||||
|
ModelRegistry.register_model(
|
||||||
|
"Qwen3NextMTP", "vllm_ascend.models.qwen3_next_mtp:CustomQwen3NextMTP")
|
||||||
|
|||||||
@@ -260,6 +260,24 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
|
|||||||
mixed_qkv_spec = None
|
mixed_qkv_spec = None
|
||||||
mixed_qkv_non_spec = mixed_qkv
|
mixed_qkv_non_spec = mixed_qkv
|
||||||
|
|
||||||
|
# 2.1: process the mutli-query part
|
||||||
|
if spec_sequence_masks is not None:
|
||||||
|
mixed_qkv_spec = mixed_qkv_spec.view(
|
||||||
|
attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
|
||||||
|
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
|
||||||
|
mixed_qkv_spec = causal_conv1d_update(
|
||||||
|
mixed_qkv_spec,
|
||||||
|
conv_state,
|
||||||
|
conv_weights,
|
||||||
|
self.conv1d.bias,
|
||||||
|
self.activation,
|
||||||
|
conv_state_indices=spec_state_indices_tensor[:, 0]
|
||||||
|
[:attn_metadata.num_spec_decodes],
|
||||||
|
num_accepted_tokens=num_accepted_tokens,
|
||||||
|
validate_data=False,
|
||||||
|
)
|
||||||
|
mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')
|
||||||
|
|
||||||
# 2.2: process the remaining part
|
# 2.2: process the remaining part
|
||||||
if attn_metadata.num_prefills > 0:
|
if attn_metadata.num_prefills > 0:
|
||||||
# - "cache_indices" updates the conv_state cache in positions
|
# - "cache_indices" updates the conv_state cache in positions
|
||||||
|
|||||||
109
vllm_ascend/models/qwen3_next_mtp.py
Normal file
109
vllm_ascend/models/qwen3_next_mtp.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Inference-only Qwen3Next MTP model."""
|
||||||
|
import torch
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||||
|
from vllm.model_executor.models.interfaces import SupportsPP
|
||||||
|
from vllm.model_executor.models.qwen3_next_mtp import (
|
||||||
|
Qwen3NextMTP, Qwen3NextMultiTokenPredictor)
|
||||||
|
from vllm.model_executor.models.utils import (
|
||||||
|
make_empty_intermediate_tensors_factory, maybe_prefix)
|
||||||
|
from vllm.transformers_utils.configs import Qwen3NextConfig
|
||||||
|
|
||||||
|
from vllm_ascend.models.qwen3_next import (CustomQwen3NextDecoderLayer,
|
||||||
|
Qwen3NextRMSNorm)
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
|
class CustomQwen3NextMultiTokenPredictor(Qwen3NextMultiTokenPredictor):
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
super(Qwen3NextMultiTokenPredictor, self).__init__()
|
||||||
|
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
quant_config = vllm_config.quant_config
|
||||||
|
lora_config = vllm_config.lora_config
|
||||||
|
config: Qwen3NextConfig = model_config.hf_config
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
lora_vocab = ((lora_config.lora_extra_vocab_size *
|
||||||
|
(lora_config.max_loras or 1)) if lora_config else 0)
|
||||||
|
self.vocab_size = config.vocab_size + lora_vocab
|
||||||
|
self.org_vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||||
|
self.num_mtp_layers = getattr(config, "num_nextn_predict_layers", 1)
|
||||||
|
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
self.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fc = ColumnParallelLinear(self.config.hidden_size * 2,
|
||||||
|
self.config.hidden_size,
|
||||||
|
gather_output=True,
|
||||||
|
bias=False,
|
||||||
|
return_bias=False,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f'{prefix}.fc')
|
||||||
|
|
||||||
|
# use old version mtp layer name to avoid a exception in vllm
|
||||||
|
self.layers = torch.nn.ModuleList(
|
||||||
|
CustomQwen3NextDecoderLayer(
|
||||||
|
vllm_config,
|
||||||
|
layer_type="full_attention",
|
||||||
|
prefix=f'{prefix}.layers.{self.mtp_start_layer_idx + idx}',
|
||||||
|
) for idx in range(self.num_mtp_layers))
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
make_empty_intermediate_tensors_factory(
|
||||||
|
["hidden_states", "residual"], config.hidden_size))
|
||||||
|
|
||||||
|
self.norm = Qwen3NextRMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.pre_fc_norm_hidden = Qwen3NextRMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
self.pre_fc_norm_embedding = Qwen3NextRMSNorm(config.hidden_size,
|
||||||
|
eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
|
||||||
|
@support_torch_compile
|
||||||
|
class CustomQwen3NextMTP(Qwen3NextMTP, SupportsPP):
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"qkv_proj": [
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
],
|
||||||
|
"gate_up_proj": ["up_proj", "down_proj"]
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
config = vllm_config.model_config.hf_config
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
assert not cache_config.enable_prefix_caching, \
|
||||||
|
"Qwen3NextMTP currently does not support prefix caching"
|
||||||
|
|
||||||
|
self.quant_config = vllm_config.quant_config
|
||||||
|
|
||||||
|
super(Qwen3NextMTP, self).__init__()
|
||||||
|
self.config = config
|
||||||
|
self.model = CustomQwen3NextMultiTokenPredictor(
|
||||||
|
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"))
|
||||||
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size)
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
self.model.make_empty_intermediate_tensors)
|
||||||
@@ -55,7 +55,7 @@ def causal_conv1d_ref(
|
|||||||
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
||||||
dtype_in) # (batch, dim, width - 1)
|
dtype_in) # (batch, dim, width - 1)
|
||||||
if final_states_out is not None:
|
if final_states_out is not None:
|
||||||
final_states_out.copy_(final_states)
|
final_states_out[..., :(width - 1)].copy_(final_states)
|
||||||
else:
|
else:
|
||||||
final_states_out = final_states
|
final_states_out = final_states
|
||||||
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
||||||
|
|||||||
@@ -29,9 +29,9 @@ def get_spec_decode_method(method,
|
|||||||
is_torchair_graph=False):
|
is_torchair_graph=False):
|
||||||
if method == "ngram":
|
if method == "ngram":
|
||||||
return NgramProposer(vllm_config, device, runner)
|
return NgramProposer(vllm_config, device, runner)
|
||||||
elif method in ["eagle", "eagle3"]:
|
elif method in ("eagle", "eagle3"):
|
||||||
return EagleProposer(vllm_config, device, runner)
|
return EagleProposer(vllm_config, device, runner)
|
||||||
elif method == 'deepseek_mtp':
|
elif method in ('deepseek_mtp', 'qwen3_next_mtp'):
|
||||||
if is_torchair_graph:
|
if is_torchair_graph:
|
||||||
return TorchairMtpProposer(vllm_config, device, runner)
|
return TorchairMtpProposer(vllm_config, device, runner)
|
||||||
return MtpProposer(vllm_config, device, runner)
|
return MtpProposer(vllm_config, device, runner)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import importlib
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -12,7 +13,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
|||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.utils import \
|
from vllm.model_executor.model_loader.utils import \
|
||||||
process_weights_after_loading
|
process_weights_after_loading
|
||||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
|
||||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.utils import cdiv
|
from vllm.utils import cdiv
|
||||||
@@ -42,6 +42,26 @@ logger = init_logger(__name__)
|
|||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
PADDING_SLOT_ID = -1
|
||||||
|
|
||||||
|
_MTP_MODELS = {
|
||||||
|
"DeepseekV3ForCausalLM":
|
||||||
|
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
||||||
|
"Qwen3NextForCausalLM":
|
||||||
|
("vllm_ascend.models.qwen3_next_mtp", "CustomQwen3NextMTP")
|
||||||
|
}
|
||||||
|
|
||||||
|
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
||||||
|
|
||||||
|
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
||||||
|
|
||||||
|
|
||||||
|
def _load_model(architecture):
|
||||||
|
if architecture not in _MTP_MODELS:
|
||||||
|
raise ValueError("Invalid architecture for mtp.")
|
||||||
|
module_name, model_name = _MTP_MODELS[architecture]
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
model = getattr(module, model_name)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
class MtpProposer(Proposer):
|
class MtpProposer(Proposer):
|
||||||
|
|
||||||
@@ -150,9 +170,7 @@ class MtpProposer(Proposer):
|
|||||||
with set_default_torch_dtype(
|
with set_default_torch_dtype(
|
||||||
draft_model_config.dtype), set_current_vllm_config(
|
draft_model_config.dtype), set_current_vllm_config(
|
||||||
self.vllm_config):
|
self.vllm_config):
|
||||||
self.model = DeepSeekMTP(
|
self._init_mtp_model()
|
||||||
vllm_config=self.vllm_config).to(target_device)
|
|
||||||
|
|
||||||
draft_attn_layer_names = (get_layers_from_vllm_config(
|
draft_attn_layer_names = (get_layers_from_vllm_config(
|
||||||
self.vllm_config, AttentionLayerBase).keys() -
|
self.vllm_config, AttentionLayerBase).keys() -
|
||||||
target_attn_layer_names)
|
target_attn_layer_names)
|
||||||
@@ -228,8 +246,7 @@ class MtpProposer(Proposer):
|
|||||||
attn_metadata=None,
|
attn_metadata=None,
|
||||||
aux_hidden_states: torch.Tensor = None):
|
aux_hidden_states: torch.Tensor = None):
|
||||||
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
||||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
attn_metadata = self._get_attn_metadata(attn_metadata)
|
||||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
|
||||||
|
|
||||||
if self.speculative_config.disable_padded_drafter_batch:
|
if self.speculative_config.disable_padded_drafter_batch:
|
||||||
# When padded-batch is disabled, the sampled_token_ids should be
|
# When padded-batch is disabled, the sampled_token_ids should be
|
||||||
@@ -311,6 +328,20 @@ class MtpProposer(Proposer):
|
|||||||
|
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
|
def _init_mtp_model(self):
|
||||||
|
architecture = self.vllm_config.model_config.architecture
|
||||||
|
target_device = self.vllm_config.device_config.device
|
||||||
|
model = _load_model(architecture)
|
||||||
|
self.model = model(vllm_config=self.vllm_config).to(target_device)
|
||||||
|
|
||||||
|
def _get_attn_metadata(self, attn_metadata):
|
||||||
|
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||||
|
architecture = self.vllm_config.model_config.architecture
|
||||||
|
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
|
||||||
|
attn_metadata = attn_metadata[layer_name]
|
||||||
|
|
||||||
|
return attn_metadata
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
|||||||
@@ -1852,7 +1852,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
extra_attn_metadata_args = dict(
|
extra_attn_metadata_args = dict(
|
||||||
num_accepted_tokens=self.num_accepted_tokens.
|
num_accepted_tokens=self.num_accepted_tokens.
|
||||||
gpu[:num_reqs],
|
gpu[:num_reqs],
|
||||||
num_draft_tokens=self.num_draft_tokens.
|
num_decode_draft_tokens_cpu=self.num_draft_tokens.
|
||||||
gpu[:num_reqs],
|
gpu[:num_reqs],
|
||||||
)
|
)
|
||||||
attn_metadata_i = builder.build(
|
attn_metadata_i = builder.build(
|
||||||
@@ -1948,11 +1948,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
attn_state = AscendAttentionState.SpecDecoding
|
attn_state = AscendAttentionState.SpecDecoding
|
||||||
# Speculative decoding.
|
# Speculative decoding.
|
||||||
elif np.all(num_valid_tokens == 1):
|
elif np.all(num_valid_tokens == 1):
|
||||||
if self.drafter and (self.drafter.name == SpecDcodeType.EAGLE
|
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
|
||||||
or self.drafter.name == SpecDcodeType.EAGLE3):
|
|
||||||
attn_state = AscendAttentionState.ChunkedPrefill
|
|
||||||
else:
|
|
||||||
attn_state = AscendAttentionState.SpecDecoding
|
attn_state = AscendAttentionState.SpecDecoding
|
||||||
|
else:
|
||||||
|
attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
# splitfuse
|
# splitfuse
|
||||||
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
|
||||||
attn_state = AscendAttentionState.ChunkedPrefill
|
attn_state = AscendAttentionState.ChunkedPrefill
|
||||||
@@ -2548,7 +2547,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
with ProfileExecuteDuration().capture_async("Draft"):
|
with ProfileExecuteDuration().capture_async("Draft"):
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
use_padded_batch_for_eagle = self.speculative_config and \
|
use_padded_batch_for_eagle = self.speculative_config and \
|
||||||
self.speculative_config.method == "deepseek_mtp" and \
|
self.speculative_config.method in ("deepseek_mtp", "qwen3_next_mtp") and \
|
||||||
not self.speculative_config.disable_padded_drafter_batch
|
not self.speculative_config.disable_padded_drafter_batch
|
||||||
if use_padded_batch_for_eagle:
|
if use_padded_batch_for_eagle:
|
||||||
# EAGLE speculative decoding can use the GPU sampled tokens
|
# EAGLE speculative decoding can use the GPU sampled tokens
|
||||||
|
|||||||
Reference in New Issue
Block a user