[FEAT] Refactor spec decode to support efficient padded speculation (#3528)
### What this PR does / why we need it?
1. Refactors `mtp_proposer.py`, splitting the torchair-related code out
into `torchair_mtp_proposer.py`.
2. Following https://github.com/vllm-project/vllm/pull/24539, implements
padded speculative decoding as described in
https://github.com/vllm-project/vllm/issues/21984.
### Does this PR introduce _any_ user-facing change?
Users can set `disable_padded_drafter_batch` in `speculative_config` to
disable or enable padded speculation; the default is `False`, i.e. padded
speculation is enabled.
Offline example:
```
speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}
```
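For online serving the same options can be passed as a JSON string. A hypothetical invocation (the model path and parallel size are placeholders, and this assumes vLLM's standard `--speculative-config` flag; adjust to your deployment):
```
vllm serve <deepseek-model-path> \
    --tensor-parallel-size 16 \
    --speculative-config '{"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": false}'
```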
### How was this patch tested?
- [x] eager with pad/unpad
- [x] aclgraph with pad/unpad
- [x] torchair with pad/unpad

Performance test of DeepSeek-R1 with tp16, dp1:
- aclgraph with pad ITL: 168 ms
- aclgraph with unpad ITL: 169 ms
- original ITL: 178 ms
- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
@@ -1,11 +1,15 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from vllm import SamplingParams
|
from vllm import SamplingParams
|
||||||
from vllm.config import CompilationConfig, CUDAGraphMode
|
from vllm.config import CompilationConfig, CUDAGraphMode
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
|
||||||
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sampling_config():
|
def sampling_config():
|
||||||
@@ -17,12 +21,12 @@ def model_name():
|
|||||||
return "wemaster/deepseek_mtp_main_random_bf16"
|
return "wemaster/deepseek_mtp_main_random_bf16"
|
||||||
|
|
||||||
|
|
||||||
def mtp_correctness(
|
def mtp_correctness(sampling_config: SamplingParams,
|
||||||
sampling_config: SamplingParams,
|
|
||||||
model_name: str,
|
model_name: str,
|
||||||
num_speculative_tokens: int,
|
num_speculative_tokens: int,
|
||||||
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
|
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
|
||||||
):
|
enforce_eager=False,
|
||||||
|
disable_padded_drafter_batch=True):
|
||||||
example_prompts = [
|
example_prompts = [
|
||||||
"Hello, my name is",
|
"Hello, my name is",
|
||||||
"The president of the United States is",
|
"The president of the United States is",
|
||||||
@@ -37,7 +41,7 @@ def mtp_correctness(
|
|||||||
tensor_parallel_size=1,
|
tensor_parallel_size=1,
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.7,
|
||||||
max_model_len=256,
|
max_model_len=256,
|
||||||
enforce_eager=False) as ref_llm:
|
enforce_eager=enforce_eager) as ref_llm:
|
||||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||||
|
|
||||||
graph_mode_str = "PIECEWISE"
|
graph_mode_str = "PIECEWISE"
|
||||||
@@ -54,8 +58,9 @@ def mtp_correctness(
|
|||||||
speculative_config={
|
speculative_config={
|
||||||
"method": "deepseek_mtp",
|
"method": "deepseek_mtp",
|
||||||
"num_speculative_tokens": num_speculative_tokens,
|
"num_speculative_tokens": num_speculative_tokens,
|
||||||
|
"disable_padded_drafter_batch": disable_padded_drafter_batch,
|
||||||
},
|
},
|
||||||
enforce_eager=False,
|
enforce_eager=enforce_eager,
|
||||||
max_model_len=2000,
|
max_model_len=2000,
|
||||||
compilation_config=CompilationConfig(
|
compilation_config=CompilationConfig(
|
||||||
cudagraph_mode=graph_mode_str),
|
cudagraph_mode=graph_mode_str),
|
||||||
@@ -82,6 +87,20 @@ def mtp_correctness(
|
|||||||
del spec_llm
|
del spec_llm
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp1_correctness_eager(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 1, enforce_eager=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp2_correctness_eager(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config, model_name, 2, enforce_eager=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip("TODO(cmq): Revert me when mtp aclgraph is fixed")
|
@pytest.mark.skip("TODO(cmq): Revert me when mtp aclgraph is fixed")
|
||||||
def test_mtp1_correctness_piecewise_graph(
|
def test_mtp1_correctness_piecewise_graph(
|
||||||
sampling_config: SamplingParams,
|
sampling_config: SamplingParams,
|
||||||
@@ -110,3 +129,47 @@ def test_mtp2_correctness_full_graph(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
|
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp1_correctness_eager_with_pad(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config,
|
||||||
|
model_name,
|
||||||
|
1,
|
||||||
|
enforce_eager=True,
|
||||||
|
disable_padded_drafter_batch=False)
|
||||||
|
|
||||||
|
|
||||||
|
def test_mtp2_correctness_eager_with_pad(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config,
|
||||||
|
model_name,
|
||||||
|
2,
|
||||||
|
enforce_eager=True,
|
||||||
|
disable_padded_drafter_batch=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
|
||||||
|
def test_mtp1_correctness_piecewise_graph_with_pad(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config,
|
||||||
|
model_name,
|
||||||
|
1,
|
||||||
|
disable_padded_drafter_batch=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
|
||||||
|
def test_mtp2_correctness_piecewise_graph_with_pad(
|
||||||
|
sampling_config: SamplingParams,
|
||||||
|
model_name: str,
|
||||||
|
):
|
||||||
|
mtp_correctness(sampling_config,
|
||||||
|
model_name,
|
||||||
|
2,
|
||||||
|
disable_padded_drafter_batch=False)
|
||||||
|
|||||||
@@ -19,14 +19,21 @@
|
|||||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||||
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
|
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer
|
||||||
|
from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
|
||||||
|
|
||||||
|
|
||||||
def get_spec_decode_method(method, vllm_config, device, runner):
|
def get_spec_decode_method(method,
|
||||||
|
vllm_config,
|
||||||
|
device,
|
||||||
|
runner,
|
||||||
|
is_torchair_graph=False):
|
||||||
if method == "ngram":
|
if method == "ngram":
|
||||||
return NgramProposer(vllm_config, device, runner)
|
return NgramProposer(vllm_config, device, runner)
|
||||||
elif method in ["eagle", "eagle3"]:
|
elif method in ["eagle", "eagle3"]:
|
||||||
return EagleProposer(vllm_config, device, runner)
|
return EagleProposer(vllm_config, device, runner)
|
||||||
elif method == 'deepseek_mtp':
|
elif method == 'deepseek_mtp':
|
||||||
|
if is_torchair_graph:
|
||||||
|
return TorchairMtpProposer(vllm_config, device, runner)
|
||||||
return MtpProposer(vllm_config, device, runner)
|
return MtpProposer(vllm_config, device, runner)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown speculative decoding method: "
|
raise ValueError("Unknown speculative decoding method: "
|
||||||
|
|||||||
@@ -1,37 +1,41 @@
|
|||||||
import types
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchair
|
|
||||||
from torchair import patch_for_hcom
|
|
||||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||||
get_layers_from_vllm_config, set_current_vllm_config)
|
get_layers_from_vllm_config, set_current_vllm_config)
|
||||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
from vllm.forward_context import BatchDescriptor
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
from vllm.model_executor.model_loader import get_model_loader
|
||||||
from vllm.model_executor.model_loader.utils import \
|
from vllm.model_executor.model_loader.utils import \
|
||||||
process_weights_after_loading
|
process_weights_after_loading
|
||||||
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
|
||||||
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
|
CommonAttentionMetadata)
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
|
from vllm.v1.utils import CpuGpuBuffer
|
||||||
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
|
||||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
|
|
||||||
TorchairDeepSeekMTP
|
|
||||||
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
|
|
||||||
TorchairCommonAttentionMetadata)
|
|
||||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||||
vllm_version_is)
|
vllm_version_is)
|
||||||
|
|
||||||
if vllm_version_is("0.11.0"):
|
if vllm_version_is("0.11.0"):
|
||||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||||
|
from vllm.utils import is_pin_memory_available
|
||||||
else:
|
else:
|
||||||
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
PADDING_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
@@ -45,34 +49,77 @@ class MtpProposer(Proposer):
|
|||||||
):
|
):
|
||||||
self.name = SpecDcodeType.MTP
|
self.name = SpecDcodeType.MTP
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.device = device
|
self.speculative_config = vllm_config.speculative_config
|
||||||
self.runner = runner
|
assert self.speculative_config is not None
|
||||||
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
|
self.draft_model_config = self.speculative_config.draft_model_config
|
||||||
|
self.method = self.speculative_config.method
|
||||||
|
|
||||||
# persistent buffers for graph
|
self.runner = runner
|
||||||
self.input_ids = torch.zeros(self.runner.max_num_tokens,
|
self.device = device
|
||||||
|
self.dtype = vllm_config.model_config.dtype
|
||||||
|
self.max_model_len = vllm_config.model_config.max_model_len
|
||||||
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
|
||||||
|
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||||
|
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||||
|
# We need to get the hidden size from the draft model config because
|
||||||
|
# the draft model's hidden size can be different from the target model's
|
||||||
|
# hidden size (e.g., Llama 3.3 70B).
|
||||||
|
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||||
|
|
||||||
|
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
|
||||||
|
self.draft_indexer_metadata_builder: Optional[
|
||||||
|
AttentionMetadataBuilder] = None
|
||||||
|
self.attn_layer_names: list[str] = []
|
||||||
|
self.indexer_layer_names: list[str] = []
|
||||||
|
|
||||||
|
self.use_aclgraph = self.runner._use_aclgraph()
|
||||||
|
|
||||||
|
self.cudagraph_batch_sizes = (list(
|
||||||
|
reversed(
|
||||||
|
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||||
|
if self.use_aclgraph else [])
|
||||||
|
|
||||||
|
# persistent buffers for aclgraph graph
|
||||||
|
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=device)
|
||||||
self.positions = torch.zeros(self.runner.max_num_tokens,
|
self.uses_mrope = self.vllm_config.model_config.uses_mrope
|
||||||
|
if self.uses_mrope:
|
||||||
|
# M-RoPE need (3, max_num_tokens)
|
||||||
|
self.mrope_positions = torch.zeros((3, self.max_num_tokens),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=self.device)
|
device=device)
|
||||||
|
else:
|
||||||
|
# RoPE need (max_num_tokens,)
|
||||||
|
self.positions = torch.zeros(self.max_num_tokens,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=device)
|
||||||
self.hidden_states = torch.zeros(
|
self.hidden_states = torch.zeros(
|
||||||
(self.runner.max_num_tokens,
|
(self.max_num_tokens, self.hidden_size),
|
||||||
vllm_config.model_config.get_hidden_size()),
|
dtype=self.dtype,
|
||||||
dtype=self.runner.dtype,
|
device=device)
|
||||||
device=self.device)
|
|
||||||
self.torchair_compiled_model = None # type: ignore
|
|
||||||
self.torchair_compiled_models = {} # type: ignore
|
|
||||||
self.torchair_graph_enabled = get_ascend_config(
|
|
||||||
).torchair_graph_config.enabled
|
|
||||||
self.enable_shared_expert_dp = get_ascend_config(
|
|
||||||
).enable_shared_expert_dp
|
|
||||||
# We need +1 here because the arange is used to set query_start_loc,
|
# We need +1 here because the arange is used to set query_start_loc,
|
||||||
# which has one more element than batch_size.
|
# which has one more element than batch_size.
|
||||||
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
|
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||||
1,
|
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
|
||||||
device=self.runner.device,
|
self.arange = torch.arange(max_num_slots_for_arange,
|
||||||
|
device=device,
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
self.inputs_embeds = torch.zeros(
|
||||||
|
(self.max_num_tokens, self.hidden_size),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=device)
|
||||||
|
|
||||||
|
self.backup_next_token_ids = CpuGpuBuffer(
|
||||||
|
max_batch_size,
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=is_pin_memory_available(),
|
||||||
|
device=device,
|
||||||
|
with_numpy=True,
|
||||||
|
)
|
||||||
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
|
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||||
"index_topk")
|
"index_topk")
|
||||||
|
|
||||||
@@ -89,12 +136,6 @@ class MtpProposer(Proposer):
|
|||||||
with set_default_torch_dtype(
|
with set_default_torch_dtype(
|
||||||
draft_model_config.dtype), set_current_vllm_config(
|
draft_model_config.dtype), set_current_vllm_config(
|
||||||
self.vllm_config):
|
self.vllm_config):
|
||||||
if self.torchair_graph_enabled or (
|
|
||||||
self.enable_shared_expert_dp
|
|
||||||
and self.vllm_config.model_config.use_mla):
|
|
||||||
self.model = TorchairDeepSeekMTP(
|
|
||||||
vllm_config=self.vllm_config).to(target_device)
|
|
||||||
else:
|
|
||||||
self.model = DeepSeekMTP(
|
self.model = DeepSeekMTP(
|
||||||
vllm_config=self.vllm_config).to(target_device)
|
vllm_config=self.vllm_config).to(target_device)
|
||||||
|
|
||||||
@@ -121,7 +162,7 @@ class MtpProposer(Proposer):
|
|||||||
num_tokens_across_dp=None,
|
num_tokens_across_dp=None,
|
||||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||||
batch_descriptor=None) -> None:
|
batch_descriptor=None) -> None:
|
||||||
if not self.torchair_graph_enabled:
|
|
||||||
(
|
(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
num_tokens_across_dp,
|
num_tokens_across_dp,
|
||||||
@@ -131,24 +172,7 @@ class MtpProposer(Proposer):
|
|||||||
moe_comm_type = self.runner._select_moe_comm_method(
|
moe_comm_type = self.runner._select_moe_comm_method(
|
||||||
num_tokens, with_prefill)
|
num_tokens, with_prefill)
|
||||||
|
|
||||||
is_running_torchair = self.torchair_graph_enabled and \
|
|
||||||
not with_prefill
|
|
||||||
|
|
||||||
if is_running_torchair:
|
|
||||||
skip_attn = False
|
|
||||||
if skip_attn:
|
|
||||||
attn_metadata = None
|
attn_metadata = None
|
||||||
else:
|
|
||||||
common_attn_metadata = TorchairCommonAttentionMetadata(
|
|
||||||
num_reqs=num_reqs,
|
|
||||||
num_actual_tokens=1,
|
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
|
||||||
)
|
|
||||||
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
|
|
||||||
common_attn_metadata)
|
|
||||||
|
|
||||||
input_ids = self.input_ids[:num_tokens]
|
input_ids = self.input_ids[:num_tokens]
|
||||||
positions = self.positions[:num_tokens]
|
positions = self.positions[:num_tokens]
|
||||||
@@ -166,32 +190,6 @@ class MtpProposer(Proposer):
|
|||||||
num_actual_tokens=0,
|
num_actual_tokens=0,
|
||||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||||
batch_descriptor=batch_descriptor):
|
batch_descriptor=batch_descriptor):
|
||||||
if is_running_torchair:
|
|
||||||
assert attn_metadata is not None
|
|
||||||
torch._dynamo.mark_static(input_ids)
|
|
||||||
torch._dynamo.mark_static(positions)
|
|
||||||
torch._dynamo.mark_static(previous_hidden_states)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.block_table)
|
|
||||||
torch._dynamo.mark_static(
|
|
||||||
attn_metadata.decode.input_positions)
|
|
||||||
if hasattr(attn_metadata.decode, "sin"):
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.sin)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.cos)
|
|
||||||
torch._dynamo.mark_static(get_forward_context().mc2_mask)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.slot_mapping)
|
|
||||||
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
|
|
||||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
|
||||||
num_tokens)
|
|
||||||
torchair_compiled_model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
hidden_states=previous_hidden_states,
|
|
||||||
inputs_embeds=None,
|
|
||||||
intermediate_tensors=None,
|
|
||||||
attn_metadata=attn_metadata,
|
|
||||||
kv_caches=self.runner.kv_caches[-1:],
|
|
||||||
spec_step_idx=0)
|
|
||||||
else:
|
|
||||||
self.model(input_ids=input_ids,
|
self.model(input_ids=input_ids,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=previous_hidden_states)
|
hidden_states=previous_hidden_states)
|
||||||
@@ -199,7 +197,7 @@ class MtpProposer(Proposer):
|
|||||||
break
|
break
|
||||||
|
|
||||||
def generate_token_ids(self,
|
def generate_token_ids(self,
|
||||||
valid_sampled_token_ids: list[list[int]],
|
sampled_token_ids: list[list[int]],
|
||||||
sampling_metadata: SamplingMetadata = None,
|
sampling_metadata: SamplingMetadata = None,
|
||||||
scheduler_output: SchedulerOutput = None,
|
scheduler_output: SchedulerOutput = None,
|
||||||
spec_decode_metadata: SpecDecodeMetadata = None,
|
spec_decode_metadata: SpecDecodeMetadata = None,
|
||||||
@@ -208,235 +206,240 @@ class MtpProposer(Proposer):
|
|||||||
hidden_states: torch.Tensor = None,
|
hidden_states: torch.Tensor = None,
|
||||||
attn_metadata=None,
|
attn_metadata=None,
|
||||||
aux_hidden_states: torch.Tensor = None):
|
aux_hidden_states: torch.Tensor = None):
|
||||||
|
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
|
||||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
||||||
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
|
||||||
next_token_ids: list[int] = []
|
|
||||||
for i, token_ids in enumerate(valid_sampled_token_ids):
|
if self.speculative_config.disable_padded_drafter_batch:
|
||||||
if token_ids:
|
# When padded-batch is disabled, the sampled_token_ids should be
|
||||||
# Common case.
|
# the cpu-side list[list[int]] of valid sampled tokens for each
|
||||||
next_token_id = token_ids[-1]
|
# request, with invalid requests having empty lists.
|
||||||
|
assert isinstance(sampled_token_ids, list), \
|
||||||
|
"sampled_token_ids should be a python list when" \
|
||||||
|
"padded-batch is disabled."
|
||||||
|
next_token_ids = self.prepare_next_token_ids_cpu(
|
||||||
|
sampled_token_ids, self.runner.requests,
|
||||||
|
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
|
||||||
else:
|
else:
|
||||||
# Partial prefill (rare case).
|
# When using padded-batch, the sampled_token_ids should be
|
||||||
# Get the next token id from the request state.
|
# the gpu tensor of sampled tokens for each request, of shape
|
||||||
req_id = self.runner.input_batch.req_ids[i]
|
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
|
||||||
req_state = self.runner.requests[req_id]
|
# value -1.
|
||||||
seq_len = (req_state.num_computed_tokens +
|
assert isinstance(sampled_token_ids, torch.Tensor), \
|
||||||
scheduler_output.num_scheduled_tokens[req_id])
|
"sampled_token_ids should be a torch.Tensor when" \
|
||||||
next_token_id = req_state.get_token_id(seq_len)
|
"padded-batch is enabled."
|
||||||
next_token_ids.append(next_token_id)
|
next_token_ids, valid_sampled_tokens_count = \
|
||||||
next_token_ids = torch.tensor(next_token_ids,
|
self.prepare_next_token_ids_padded(
|
||||||
dtype=torch.int32,
|
common_attn_metadata,
|
||||||
device=self.device)
|
sampled_token_ids,
|
||||||
accepted_token_indices = None
|
self.runner.requests,
|
||||||
|
self.runner.input_batch,
|
||||||
|
self.runner.discard_request_indices.gpu,
|
||||||
|
self.runner.num_discarded_requests
|
||||||
|
)
|
||||||
|
|
||||||
if spec_decode_metadata is None:
|
if spec_decode_metadata is None:
|
||||||
|
token_indices_to_sample = None
|
||||||
# input_ids can be None for multimodal models.
|
# input_ids can be None for multimodal models.
|
||||||
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
|
||||||
target_positions = positions[:num_scheduled_tokens]
|
target_positions = positions[:num_scheduled_tokens]
|
||||||
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
target_hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
target_slot_mapping = attn_metadata.slot_mapping
|
|
||||||
cu_num_tokens = attn_metadata.query_start_loc
|
|
||||||
else:
|
else:
|
||||||
# TODO(woosuk): Refactor this.
|
if self.speculative_config.disable_padded_drafter_batch:
|
||||||
num_draft_tokens = spec_decode_metadata.num_draft_tokens
|
token_indices_to_sample = None
|
||||||
num_rejected_tokens = [
|
common_attn_metadata, token_indices =\
|
||||||
n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
|
self._prepare_inputs(
|
||||||
for i, n in enumerate(num_draft_tokens)
|
common_attn_metadata,
|
||||||
]
|
sampled_token_ids,
|
||||||
num_rejected_tokens = torch.tensor(
|
spec_decode_metadata.num_draft_tokens)
|
||||||
num_rejected_tokens,
|
else:
|
||||||
dtype=torch.int32,
|
common_attn_metadata, token_indices, \
|
||||||
device=self.device,
|
token_indices_to_sample =\
|
||||||
)
|
self.prepare_inputs_padded(
|
||||||
cu_num_tokens, accepted_token_indices, target_token_ids, \
|
common_attn_metadata,
|
||||||
target_positions, target_hidden_states, target_slot_mapping = self._prepare_inputs(
|
spec_decode_metadata,
|
||||||
attn_metadata.query_start_loc,
|
valid_sampled_tokens_count)
|
||||||
num_rejected_tokens,
|
target_token_ids = self.runner.input_ids[token_indices]
|
||||||
self.runner.input_ids[:num_scheduled_tokens],
|
target_positions = positions[token_indices]
|
||||||
positions[:num_scheduled_tokens],
|
target_hidden_states = hidden_states[token_indices]
|
||||||
hidden_states[:num_scheduled_tokens],
|
|
||||||
attn_metadata.slot_mapping[:num_scheduled_tokens],
|
|
||||||
is_torchair_graph=self.runner._build_drafter_prepare_inputs_torchair_param(),
|
|
||||||
)
|
|
||||||
|
|
||||||
draft_token_ids = self._propose(
|
draft_token_ids = self._propose(
|
||||||
target_token_ids=target_token_ids,
|
target_token_ids=target_token_ids,
|
||||||
target_positions=target_positions,
|
target_positions=target_positions,
|
||||||
target_hidden_states=target_hidden_states,
|
target_hidden_states=target_hidden_states,
|
||||||
target_slot_mapping=target_slot_mapping,
|
|
||||||
next_token_ids=next_token_ids,
|
next_token_ids=next_token_ids,
|
||||||
cu_num_tokens=cu_num_tokens,
|
last_token_indices=token_indices_to_sample,
|
||||||
block_table=attn_metadata.block_tables,
|
common_attn_metadata=common_attn_metadata,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
token_indices=accepted_token_indices)
|
)
|
||||||
spec_token_ids = draft_token_ids.tolist()
|
|
||||||
return spec_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def _prepare_inputs(
|
def _prepare_inputs(
|
||||||
self,
|
self,
|
||||||
# [batch_size + 1]
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
cu_target_query_lens: torch.Tensor,
|
sampled_token_ids: list[list[int]],
|
||||||
# [batch_size]
|
num_draft_tokens: list[int],
|
||||||
num_rejected_tokens: torch.Tensor,
|
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
|
||||||
token_ids: torch.Tensor,
|
"""
|
||||||
positions: torch.Tensor,
|
This function is used to prepare the inputs for speculative decoding.
|
||||||
hidden_states: torch.Tensor,
|
It updates to the common_attn_metadata to account for the rejected
|
||||||
slot_mapping: torch.Tensor,
|
tokens (and newly sampled tokens). It also returns the token indices
|
||||||
is_torchair_graph: bool = False
|
of the tokens that should be fed to the speculator.
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
"""
|
||||||
torch.Tensor, torch.Tensor]:
|
# E.g.
|
||||||
# cu_target_query_lens: [0, a, a + b, a + b + c]
|
# common_attn_metadata.query_start_loc{_cpu}:
|
||||||
|
# [0, q1, q1 + q2, q1 + q2 + q3]
|
||||||
|
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
|
||||||
# num_rejected_tokens: [n1, n2, n3]
|
# num_rejected_tokens: [n1, n2, n3]
|
||||||
# num_tokens_per_req: [a - n1, b - n2, c - n3]
|
# This function computes the intermediate values:
|
||||||
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
|
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
|
||||||
# token_indices: [0, 1, ..., a - n1 - 1,
|
# And returns:
|
||||||
# a, a + 1, ..., a + b - n2 - 1,
|
# common_attn_metadata.query_start_loc{_cpu}:
|
||||||
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
|
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
||||||
# [0, a, a + b, a + b + c] -> [a, b, c]
|
# common_attn_metadata.seq_lens{_cpu}:
|
||||||
query_len_per_req = (cu_target_query_lens[1:] -
|
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
|
||||||
cu_target_query_lens[:-1])
|
# token_indices: [0, 1, ..., q1 - n1 - 1,
|
||||||
# [a, b, c] -> [a - n1, b - n2, c - n3]
|
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
|
||||||
num_tokens_per_req = query_len_per_req - num_rejected_tokens
|
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
|
||||||
if is_torchair_graph:
|
|
||||||
cu_num_tokens = cu_target_query_lens
|
|
||||||
relative_index = query_len_per_req - num_rejected_tokens - 1
|
|
||||||
token_indices = cu_num_tokens[:-1] + relative_index
|
|
||||||
# the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model
|
|
||||||
target_token_ids = token_ids
|
|
||||||
target_positions = positions
|
|
||||||
target_hidden_states = hidden_states
|
|
||||||
target_slot_mapping = slot_mapping
|
|
||||||
else:
|
|
||||||
cu_num_tokens = torch.empty_like(cu_target_query_lens)
|
|
||||||
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
|
|
||||||
cu_num_tokens[0] = 0
|
|
||||||
|
|
||||||
# FIXME(woosuk): Avoid synchronization.
|
num_rejected_tokens = [
|
||||||
num_tokens = cu_num_tokens[-1].item()
|
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
|
||||||
token_indices = torch.zeros(
|
for i, n in enumerate(num_draft_tokens)
|
||||||
num_tokens,
|
]
|
||||||
|
num_rejected_tokens = torch.tensor(num_rejected_tokens,
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
device = common_attn_metadata.query_start_loc.device
|
||||||
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
|
new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu - num_rejected_tokens
|
||||||
|
|
||||||
|
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
||||||
|
new_query_len_per_req = query_start_loc_cpu[
|
||||||
|
1:] - query_start_loc_cpu[:-1]
|
||||||
|
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
|
||||||
|
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
|
||||||
|
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
|
||||||
|
|
||||||
|
# [q1 - n1, q2 - n2, q3 - n3] ->
|
||||||
|
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
||||||
|
new_query_start_loc_cpu = torch.zeros(
|
||||||
|
query_start_loc_cpu.shape,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=cu_num_tokens.device,
|
pin_memory=is_pin_memory_available(),
|
||||||
)
|
)
|
||||||
|
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
|
||||||
|
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
|
||||||
|
|
||||||
BLOCK_SIZE = 1024
|
total_num_tokens = new_query_start_loc_np[-1]
|
||||||
self._prepare_input_kernel(
|
# Example assuming num_tokens_per_req_np = [2, 4, 3]
|
||||||
token_indices,
|
# this implies that `new_query_start_locs` is:
|
||||||
cu_target_query_lens,
|
# [0, 2, 6, 9] ->
|
||||||
cu_num_tokens,
|
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
|
||||||
block_size=BLOCK_SIZE,
|
# _r1_ ____r2____ ___r3__
|
||||||
|
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
|
||||||
|
new_num_tokens_per_req_np)
|
||||||
|
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
|
||||||
|
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
|
||||||
|
# _r1_ ____r2____ ___r3__
|
||||||
|
token_offests = (self.token_arange_np[:total_num_tokens] -
|
||||||
|
new_query_start_locs_expanded)
|
||||||
|
|
||||||
|
# Expand starting positions to match token pattern
|
||||||
|
# [0, q1, q1 + q2] ->
|
||||||
|
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
|
||||||
|
# _r1_ _____r2_______ ___________r3____________
|
||||||
|
old_query_start_locs_expanded = np.repeat(
|
||||||
|
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
|
||||||
|
# Final token indices are:
|
||||||
|
# [0, 1, // req 1
|
||||||
|
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
|
||||||
|
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
|
||||||
|
token_indices_np = token_offests + old_query_start_locs_expanded
|
||||||
|
token_indices = torch.from_numpy(token_indices_np).to(
|
||||||
|
device, non_blocking=True)
|
||||||
|
|
||||||
|
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
||||||
|
query_start_loc=new_query_start_loc_cpu.to(device,
|
||||||
|
non_blocking=True),
|
||||||
|
query_start_loc_cpu=new_query_start_loc_cpu,
|
||||||
|
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
|
||||||
|
seq_lens_cpu=new_seq_lens_cpu,
|
||||||
|
num_computed_tokens_cpu=common_attn_metadata.
|
||||||
|
num_computed_tokens_cpu,
|
||||||
|
num_reqs=common_attn_metadata.num_reqs,
|
||||||
|
num_actual_tokens=total_num_tokens,
|
||||||
|
max_query_len=new_query_len_per_req.max().item(),
|
||||||
|
block_table_tensor=common_attn_metadata.block_table_tensor,
|
||||||
|
slot_mapping=common_attn_metadata.slot_mapping[token_indices],
|
||||||
|
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
||||||
|
positions=common_attn_metadata.positions[token_indices],
|
||||||
|
attn_mask=self.runner.attn_mask,
|
||||||
|
spec_attn_mask=self.runner.spec_attn_mask,
|
||||||
|
attn_state=self.runner.attn_state,
|
||||||
|
graph_pad_size=self.runner.graph_pad_size,
|
||||||
|
decode_token_per_req=self.runner.decode_token_per_req,
|
||||||
)
|
)
|
||||||
target_token_ids = token_ids[token_indices]
|
return spec_common_attn_metadata, token_indices
|
||||||
target_positions = positions[token_indices]
|
|
||||||
target_hidden_states = hidden_states[token_indices]
|
|
||||||
target_slot_mapping = slot_mapping[token_indices]
|
|
||||||
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
|
|
||||||
|
|
||||||
def _propose(
|
def _propose(
|
||||||
self,
|
self,
|
||||||
# [num_tokens]
|
# [num_tokens]
|
||||||
target_token_ids: torch.Tensor,
|
target_token_ids: torch.Tensor,
|
||||||
# [num_tokens]
|
# [num_tokens] or [3, num_tokens] when M-RoPE is enabled
|
||||||
target_positions: torch.Tensor,
|
target_positions: torch.Tensor,
|
||||||
# [num_tokens, hidden_size]
|
# [num_tokens, hidden_size]
|
||||||
target_hidden_states: torch.Tensor,
|
target_hidden_states: torch.Tensor,
|
||||||
# [num_tokens]
|
|
||||||
target_slot_mapping: torch.Tensor,
|
|
||||||
# [batch_size]
|
# [batch_size]
|
||||||
next_token_ids: torch.Tensor,
|
next_token_ids: torch.Tensor,
|
||||||
# [batch_size + 1] starting with 0
|
last_token_indices: Optional[torch.Tensor],
|
||||||
cu_num_tokens: torch.Tensor,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
# [batch_size, max_num_blocks_per_req]
|
|
||||||
block_table: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
token_indices=None) -> torch.Tensor:
|
mm_embed_inputs: Optional[tuple[list[torch.Tensor],
|
||||||
|
torch.Tensor]] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
num_tokens = target_token_ids.shape[0]
|
num_tokens = target_token_ids.shape[0]
|
||||||
batch_size = next_token_ids.shape[0]
|
batch_size = next_token_ids.shape[0]
|
||||||
last_token_indices = cu_num_tokens[1:] - 1
|
|
||||||
|
if last_token_indices is None:
|
||||||
|
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||||
|
|
||||||
|
if self.method == "eagle3":
|
||||||
|
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||||
|
target_hidden_states = self.model.combine_hidden_states(
|
||||||
|
target_hidden_states)
|
||||||
|
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||||
|
|
||||||
# Shift the input ids by one token.
|
# Shift the input ids by one token.
|
||||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||||
# Replace the last token with the next token.
|
# Replace the last token with the next token.
|
||||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||||
if token_indices is not None and self.torchair_graph_enabled:
|
|
||||||
last_token_indices = token_indices
|
|
||||||
|
|
||||||
self.input_ids[last_token_indices] = next_token_ids
|
self.input_ids[last_token_indices] = next_token_ids
|
||||||
|
|
||||||
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
|
assert self.runner is not None
|
||||||
max_query_len = query_lens.max().item()
|
|
||||||
|
|
||||||
# FIXME: reorder_batch() needs to be called before build()
|
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
||||||
# because fields of attn_metadata_builder needs to be updated.
|
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
||||||
# However, currently reorder_batch() takes input_batch and
|
self.runner.get_model())
|
||||||
# scheduler_output as arguments, we should probably refactor
|
attn_metadata = {}
|
||||||
# the method to use new data structures which are independent
|
for layer_name in self.attn_layer_name:
|
||||||
# from input_batch and scheduler_output.
|
attn_metadata[layer_name] = attn_metadata_mtp
|
||||||
# self.runner.attn_metadata_builder.reorder_batch(
|
|
||||||
# input_batch=self.runner.input_batch,
|
|
||||||
# scheduler_output=self.runner.scheduler_output,
|
|
||||||
# )
|
|
||||||
is_running_torchair = self.torchair_graph_enabled and \
|
|
||||||
not self.runner.with_prefill
|
|
||||||
|
|
||||||
if is_running_torchair:
|
if self.use_aclgraph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||||
# Torchair graph mode, padding is same as the main model
|
|
||||||
num_input_tokens = self.runner.graph_pad_size
|
|
||||||
elif (self.runner.use_aclgraph
|
|
||||||
and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
|
|
||||||
# Acl graph mode, add padding to the batch size
|
# Acl graph mode, add padding to the batch size
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||||
else:
|
else:
|
||||||
# Eager mode, no padding needed
|
# Eager mode, no padding needed
|
||||||
num_input_tokens = num_tokens
|
num_input_tokens = num_tokens
|
||||||
|
|
||||||
seq_lens = target_positions[last_token_indices] + 1
|
# copy inputs to buffer for cudagraph
|
||||||
seq_lens = seq_lens.int()
|
|
||||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
|
||||||
query_start_loc=cu_num_tokens[:batch_size + 1],
|
|
||||||
query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
|
|
||||||
seq_lens_cpu=seq_lens.cpu(),
|
|
||||||
num_reqs=batch_size,
|
|
||||||
num_actual_tokens=num_tokens,
|
|
||||||
max_query_len=max_query_len,
|
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
|
||||||
block_table_tensor=self.runner.input_batch.block_table[0].
|
|
||||||
get_device_tensor(),
|
|
||||||
slot_mapping=target_slot_mapping,
|
|
||||||
positions=target_positions,
|
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
attn_state=self.runner.attn_state,
|
|
||||||
graph_pad_size=self.runner.graph_pad_size,
|
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
|
||||||
num_computed_tokens_cpu=None,
|
|
||||||
seq_lens=None)
|
|
||||||
|
|
||||||
if not self.torchair_graph_enabled:
|
|
||||||
builder = self.runner.attn_groups[0][0].get_metadata_builder()
|
|
||||||
attn_metadata_mtp = builder.build(0, common_attn_metadata,
|
|
||||||
self.runner.get_model())
|
|
||||||
|
|
||||||
attn_metadata = {}
|
|
||||||
for layer_name in self.attn_layer_name:
|
|
||||||
attn_metadata[layer_name] = attn_metadata_mtp
|
|
||||||
|
|
||||||
else:
|
|
||||||
attn_metadata = self.runner.attn_metadata_builder.build(
|
|
||||||
0, common_attn_metadata, self.runner.get_model())
|
|
||||||
|
|
||||||
self.positions[:num_tokens] = target_positions
|
self.positions[:num_tokens] = target_positions
|
||||||
self.hidden_states[:num_tokens] = target_hidden_states
|
self.hidden_states[:num_tokens] = target_hidden_states
|
||||||
|
# eager/acl piecewise mode need to update num_tokens_across_dp
|
||||||
if not self.torchair_graph_enabled:
|
|
||||||
# torch mode need to update num_tokens_across_dp
|
|
||||||
(num_input_tokens, num_tokens_across_dp,
|
(num_input_tokens, num_tokens_across_dp,
|
||||||
with_prefill) = self.runner._sync_metadata_across_dp(
|
with_prefill) = self.runner._sync_metadata_across_dp(
|
||||||
num_input_tokens, self.runner.with_prefill)
|
num_input_tokens, self.runner.with_prefill)
|
||||||
else:
|
|
||||||
# torchair mode can reuse self.runner.num_tokens_across_dp
|
|
||||||
num_tokens_across_dp = self.runner.num_tokens_across_dp
|
|
||||||
with_prefill = self.runner.with_prefill
|
|
||||||
|
|
||||||
moe_comm_type = self.runner._select_moe_comm_method(
|
moe_comm_type = self.runner._select_moe_comm_method(
|
||||||
num_input_tokens, with_prefill)
|
num_input_tokens, with_prefill)
|
||||||
@@ -444,6 +447,15 @@ class MtpProposer(Proposer):
|
|||||||
uniform_decode=False)
|
uniform_decode=False)
|
||||||
aclgraph_runtime_mode, batch_descriptor = \
|
aclgraph_runtime_mode, batch_descriptor = \
|
||||||
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
|
self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)
|
||||||
|
if aclgraph_runtime_mode not in [
|
||||||
|
CUDAGraphMode.PIECEWISE, CUDAGraphMode.NONE
|
||||||
|
]:
|
||||||
|
# Fallback to piecewise graph, when acl full graph is enabled
|
||||||
|
logger.debug(
|
||||||
|
"Currently the eagle proposer only supports cudagraph_mode "
|
||||||
|
f"PIECEWISE, and is forced to set graph mode from {aclgraph_runtime_mode} "
|
||||||
|
"to CUDAGraphMode.PIECEWISE")
|
||||||
|
aclgraph_runtime_mode = CUDAGraphMode.PIECEWISE
|
||||||
|
|
||||||
for step in range(self.num_speculative_tokens):
|
for step in range(self.num_speculative_tokens):
|
||||||
with set_ascend_forward_context(
|
with set_ascend_forward_context(
|
||||||
@@ -461,26 +473,11 @@ class MtpProposer(Proposer):
|
|||||||
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
with ProfileExecuteDuration().capture_async('mtp_forward'):
|
||||||
model_kwargs = {}
|
model_kwargs = {}
|
||||||
model_kwargs["attn_metadata"] = attn_metadata
|
model_kwargs["attn_metadata"] = attn_metadata
|
||||||
if self.torchair_graph_enabled:
|
|
||||||
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
|
|
||||||
if is_running_torchair:
|
|
||||||
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
|
|
||||||
num_input_tokens)
|
|
||||||
hidden_states = torchair_compiled_model(
|
|
||||||
input_ids=self.input_ids[:num_input_tokens],
|
|
||||||
positions=self.positions[:num_input_tokens],
|
|
||||||
hidden_states=self.
|
|
||||||
hidden_states[:num_input_tokens],
|
|
||||||
inputs_embeds=None,
|
|
||||||
intermediate_tensors=None,
|
|
||||||
spec_step_idx=0,
|
|
||||||
**model_kwargs)
|
|
||||||
else:
|
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
input_ids=self.input_ids[:num_input_tokens],
|
input_ids=self.input_ids[:num_input_tokens],
|
||||||
positions=self.positions[:num_input_tokens],
|
positions=self.positions[:num_input_tokens],
|
||||||
hidden_states=self.hidden_states[:num_input_tokens]
|
hidden_states=self.hidden_states[:num_input_tokens])
|
||||||
)
|
|
||||||
|
|
||||||
num_indices = last_token_indices.shape[0]
|
num_indices = last_token_indices.shape[0]
|
||||||
if lmhead_tp_enable():
|
if lmhead_tp_enable():
|
||||||
@@ -515,10 +512,7 @@ class MtpProposer(Proposer):
|
|||||||
if step == self.num_speculative_tokens - 1 or with_prefill:
|
if step == self.num_speculative_tokens - 1 or with_prefill:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self.torchair_graph_enabled:
|
|
||||||
attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
|
attn_metadata_i = attn_metadata[self.attn_layer_name[0]]
|
||||||
else:
|
|
||||||
attn_metadata_i = attn_metadata
|
|
||||||
|
|
||||||
if step == 0:
|
if step == 0:
|
||||||
positions = target_positions[last_token_indices]
|
positions = target_positions[last_token_indices]
|
||||||
@@ -529,21 +523,16 @@ class MtpProposer(Proposer):
|
|||||||
last_token_indices = self.arange[:batch_size]
|
last_token_indices = self.arange[:batch_size]
|
||||||
if attn_metadata_i.num_decode_tokens != 0:
|
if attn_metadata_i.num_decode_tokens != 0:
|
||||||
attn_metadata_i.num_decode_tokens = batch_size
|
attn_metadata_i.num_decode_tokens = batch_size
|
||||||
if is_running_torchair:
|
|
||||||
attn_metadata_i.num_actual_tokens = batch_size
|
|
||||||
attn_metadata_i.query_lens = [1] * batch_size
|
|
||||||
|
|
||||||
input_ids = draft_token_ids_list[-1].int()
|
input_ids = draft_token_ids_list[-1].int()
|
||||||
positions += 1
|
positions += 1
|
||||||
|
|
||||||
if not self.torchair_graph_enabled:
|
|
||||||
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
attn_metadata_i.decode.actual_seq_lengths_q = attn_metadata_i.query_start_loc[
|
||||||
1:batch_size + 1].tolist()
|
1:batch_size + 1].tolist()
|
||||||
attn_metadata_i.decode.cos = builder.cos_cache[
|
attn_metadata_i.decode.cos = builder.cos_cache[
|
||||||
positions].unsqueeze(1).unsqueeze(2)
|
positions].unsqueeze(1).unsqueeze(2)
|
||||||
attn_metadata_i.decode.sin = builder.sin_cache[
|
attn_metadata_i.decode.sin = builder.sin_cache[
|
||||||
positions].unsqueeze(1).unsqueeze(2)
|
positions].unsqueeze(1).unsqueeze(2)
|
||||||
|
|
||||||
# NOTE(woosuk): We should handle the case where the draft model
|
# NOTE(woosuk): We should handle the case where the draft model
|
||||||
# generates tokens beyond the max model length. Since it is complex
|
# generates tokens beyond the max model length. Since it is complex
|
||||||
# to remove such requests from the batch, we keep them in the batch
|
# to remove such requests from the batch, we keep them in the batch
|
||||||
@@ -601,61 +590,6 @@ class MtpProposer(Proposer):
|
|||||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def _get_torchair_lazy_compiled_model(self, batch_size: int):
|
|
||||||
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
|
|
||||||
-1]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
compiled_model = self.torchair_compiled_models.get(
|
|
||||||
batch_size
|
|
||||||
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
|
|
||||||
|
|
||||||
if compiled_model:
|
|
||||||
return compiled_model
|
|
||||||
|
|
||||||
patch_for_hcom()
|
|
||||||
config = torchair.CompilerConfig()
|
|
||||||
config.experimental_config.frozen_parameter = True
|
|
||||||
config.experimental_config.tiling_schedule_optimize = True
|
|
||||||
config.experimental_config.enable_view_optimize = \
|
|
||||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
|
||||||
torch.npu.set_compile_mode(jit_compile=False)
|
|
||||||
if not self.runner.use_cached_npu_graph:
|
|
||||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
|
||||||
self.torchair_compiled_model = torch.compile(
|
|
||||||
self.model,
|
|
||||||
dynamic=not self.use_sparse,
|
|
||||||
fullgraph=True,
|
|
||||||
backend=npu_backend)
|
|
||||||
return self.torchair_compiled_model
|
|
||||||
else:
|
|
||||||
# Generate a new forward proxy code object to prevent the invalidation of
|
|
||||||
# compilation cache caused by dynamo retracing
|
|
||||||
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
|
|
||||||
forward_fn = self.model.forward
|
|
||||||
code = forward_fn.__code__
|
|
||||||
# Mark code object with a new proxy name
|
|
||||||
modified_code = code.replace(co_name=forward_proxy_name, )
|
|
||||||
|
|
||||||
modified_func = types.FunctionType(modified_code,
|
|
||||||
forward_fn.__globals__,
|
|
||||||
name=forward_proxy_name,
|
|
||||||
argdefs=forward_fn.__defaults__)
|
|
||||||
|
|
||||||
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
|
|
||||||
self.model, nn.Module)
|
|
||||||
self.torchair_compiled_models[
|
|
||||||
batch_size] = torchair.inference.cache_compile(
|
|
||||||
self.model.__dict__[forward_proxy_name],
|
|
||||||
dynamic=not self.use_sparse,
|
|
||||||
fullgraph=True,
|
|
||||||
cache_dir=TORCHAIR_CACHE_DIR,
|
|
||||||
config=config,
|
|
||||||
ge_cache=False)
|
|
||||||
return self.torchair_compiled_models[batch_size]
|
|
||||||
|
|
||||||
# TODO Using torch instead of triton may result in poor performance
|
# TODO Using torch instead of triton may result in poor performance
|
||||||
def _prepare_input_kernel(self, out_ptr: torch.Tensor,
|
def _prepare_input_kernel(self, out_ptr: torch.Tensor,
|
||||||
cu_query_lens: torch.Tensor,
|
cu_query_lens: torch.Tensor,
|
||||||
@@ -676,3 +610,160 @@ class MtpProposer(Proposer):
|
|||||||
global_indices_flat = global_indices[mask]
|
global_indices_flat = global_indices[mask]
|
||||||
values_flat = values[mask]
|
values_flat = values[mask]
|
||||||
out_ptr[global_indices_flat] = values_flat
|
out_ptr[global_indices_flat] = values_flat
|
||||||
|
|
||||||
|
def prepare_next_token_ids_cpu(
|
||||||
|
self,
|
||||||
|
sampled_token_ids: list[list[int]],
|
||||||
|
requests: dict[str, CachedRequestState],
|
||||||
|
gpu_input_batch: InputBatch,
|
||||||
|
num_scheduled_tokens: dict[str, int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This function is used to prepare the inputs for speculative decoding.
|
||||||
|
It calculates the next token ids for each request based on the sampled
|
||||||
|
token ids from the CPU. If a request has no sampled token ids (e.g.,
|
||||||
|
during the initial decoding steps), it falls back to using the request
|
||||||
|
state to get the next token id.
|
||||||
|
"""
|
||||||
|
req_ids = gpu_input_batch.req_ids
|
||||||
|
next_token_ids: list[int] = []
|
||||||
|
for i, token_ids in enumerate(sampled_token_ids):
|
||||||
|
if token_ids:
|
||||||
|
# Common case.
|
||||||
|
next_token_id = token_ids[-1]
|
||||||
|
else:
|
||||||
|
# Partial prefill (rare case).
|
||||||
|
# Get the next token id from the request state.
|
||||||
|
req_id = req_ids[i]
|
||||||
|
req_state = requests[req_id]
|
||||||
|
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
|
||||||
|
req_id]
|
||||||
|
next_token_id = req_state.get_token_id(seq_len)
|
||||||
|
next_token_ids.append(next_token_id)
|
||||||
|
next_token_ids = torch.tensor(next_token_ids,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.input_ids.device)
|
||||||
|
return next_token_ids
|
||||||
|
|
||||||
|
def prepare_next_token_ids_padded(
|
||||||
|
self,
|
||||||
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
sampled_token_ids: torch.Tensor,
|
||||||
|
requests: dict[str, CachedRequestState],
|
||||||
|
gpu_input_batch: InputBatch,
|
||||||
|
discard_request_indices: torch.Tensor,
|
||||||
|
num_discarded_requests: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
This function is used to prepare the inputs for speculative decoding.
|
||||||
|
It calculates the next token ids and the number of valid sampled tokens
|
||||||
|
for each request, considering the "discarded" requests whose next token
|
||||||
|
is not sampled and comes from `request.get_token_id()` instead.
|
||||||
|
It also accounts for the rejected tokens in `sampled_token_ids`.
|
||||||
|
This function must use device functions to operate on the inputs, and
|
||||||
|
should not introduce any blocking CPU-GPU synchronization.
|
||||||
|
"""
|
||||||
|
# TODO(Ben): Combine this into a custom fused kernel
|
||||||
|
|
||||||
|
# Precompute get_token_id for when there is no valid next token
|
||||||
|
num_reqs = gpu_input_batch.num_reqs
|
||||||
|
self.backup_next_token_ids.np[:num_reqs] = np.array([
|
||||||
|
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
||||||
|
common_attn_metadata.seq_lens_cpu[i].item())
|
||||||
|
for i in range(num_reqs)
|
||||||
|
])
|
||||||
|
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||||
|
|
||||||
|
# Mask out the sampled tokens indices that should not be sampled.
|
||||||
|
discard_sampled_tokens_req_indices = discard_request_indices[:
|
||||||
|
num_discarded_requests]
|
||||||
|
|
||||||
|
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
||||||
|
valid_sampled_token_ids_gpu.index_fill_(
|
||||||
|
0, discard_sampled_tokens_req_indices, -1)
|
||||||
|
|
||||||
|
        # Generate a mask for all valid tokens within those requests
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1,
            last_valid_indices_safe.unsqueeze(1)).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size],
        )

        return next_token_ids, valid_sampled_tokens_count

    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.
        """
        num_draft_tokens_gpu = torch.cat([
            spec_decode_metadata.cu_num_draft_tokens[0:1],
            spec_decode_metadata.cu_num_draft_tokens[1:] -
            spec_decode_metadata.cu_num_draft_tokens[:-1],
        ])

        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu),
        )

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        new_query_len_per_req = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            positions=common_attn_metadata.positions,
            attn_mask=self.runner.attn_mask,
            spec_attn_mask=self.runner.spec_attn_mask,
            attn_state=self.runner.attn_state,
            graph_pad_size=self.runner.graph_pad_size,
            decode_token_per_req=self.runner.decode_token_per_req,
            num_computed_tokens_cpu=common_attn_metadata.num_computed_tokens_cpu,
            seq_lens=common_attn_metadata.seq_lens)

        token_indices_to_sample = (common_attn_metadata.query_start_loc[1:] -
                                   1 - num_rejected_tokens_gpu)

        return spec_common_attn_metadata, token_indices, token_indices_to_sample
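The index arithmetic above is easier to see with concrete numbers. A minimal, self-contained sketch (illustrative values, not part of this change) of how `token_indices_to_sample` lands on the last accepted token of each request while rejected draft positions remain in the batch as padding:

```
import torch

# Hypothetical batch: 2 requests, each scheduled with 1 draft token,
# so each request owns 2 flat query positions (1 verified + 1 draft).
query_start_loc = torch.tensor([0, 2, 4])
valid_sampled_tokens_count = torch.tensor([2, 1])  # req0 accepted its draft, req1 did not
num_draft_tokens = torch.tensor([1, 1])

num_rejected = torch.where(num_draft_tokens > 0,
                           num_draft_tokens + 1 - valid_sampled_tokens_count,
                           torch.zeros_like(num_draft_tokens))
# All 4 positions stay in the speculator input; only the last accepted
# position of each request is read when sampling the next draft.
token_indices_to_sample = query_start_loc[1:] - 1 - num_rejected
print(token_indices_to_sample.tolist())  # [1, 2]
```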
@@ -34,6 +34,7 @@ from vllm.logger import logger
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.torchair.utils import (
     TORCHAIR_CACHE_DIR, TorchairCommonAttentionMetadata,
     check_torchair_cache_exist, converting_weight_acl_format,
@@ -83,6 +84,20 @@ class NPUTorchairModelRunner(NPUModelRunner):
 
         self._check_batch_sizes_consistency()
 
+    def _set_up_drafter(self):
+        super()._set_up_drafter()
+        if self.speculative_config:
+            # Torchair does not support disable_padded_drafter_batch,
+            # so force the feature off here.
+            self.speculative_config.disable_padded_drafter_batch = True
+
+    def _get_drafter(self):
+        return get_spec_decode_method(self.speculative_config.method,
+                                      self.vllm_config,
+                                      self.device,
+                                      self,
+                                      is_torchair_graph=True)
+
     def _may_pad_kv_consumer_num_seq(self):
         # The PD disaggregation scenario needs redundant_batch_sizes so that
         # each batch's seq_len does not exceed 16 tokens.
         # self.max_num_reqs here is greater than the actual maximum request number
vllm_ascend/torchair/torchair_mtp_proposer.py (new file, 554 lines)
@@ -0,0 +1,554 @@
import types

import torch
import torch.nn as nn
import torchair
from torchair import patch_for_hcom
from vllm.config import (CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import \
    process_weights_after_loading
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode import MtpProposer
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
    TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import (TORCHAIR_CACHE_DIR,
                                        TorchairCommonAttentionMetadata)
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
                               vllm_version_is)

if vllm_version_is("0.11.0"):
    from vllm.model_executor.model_loader.utils import set_default_torch_dtype
else:
    from vllm.utils.torch_utils import set_default_torch_dtype

PADDING_SLOT_ID = -1


class TorchairMtpProposer(MtpProposer):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device,
        runner,
    ):
        super().__init__(vllm_config, device, runner)
        self.torchair_compiled_model = None  # type: ignore
        self.torchair_compiled_models = {}  # type: ignore

    def load_model(self, model) -> None:
        loader = get_model_loader(self.vllm_config.load_config)

        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config,
                                        AttentionLayerBase).keys())
        draft_model_config = \
            self.vllm_config.speculative_config.draft_model_config
        target_device = self.vllm_config.device_config.device

        with set_default_torch_dtype(
                draft_model_config.dtype), set_current_vllm_config(
                    self.vllm_config):
            self.model = TorchairDeepSeekMTP(
                vllm_config=self.vllm_config).to(target_device)

        draft_attn_layer_names = (get_layers_from_vllm_config(
            self.vllm_config, AttentionLayerBase).keys() -
                                  target_attn_layer_names)

        assert len(draft_attn_layer_names) == 1
        self.attn_layer_name = list(draft_attn_layer_names)

        self.model.load_weights(
            loader.get_all_weights(
                self.vllm_config.speculative_config.draft_model_config,
                self.model))
        process_weights_after_loading(self.model, draft_model_config,
                                      target_device)

    @torch.inference_mode()
    def dummy_run(self,
                  num_tokens: int,
                  with_prefill: bool = False,
                  skip_attn: bool = False,
                  num_reqs: int = 0,
                  num_tokens_across_dp=None,
                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                  batch_descriptor=None) -> None:
        moe_comm_type = self.runner._select_moe_comm_method(
            num_tokens, with_prefill)

        if not with_prefill:
            skip_attn = False
        if skip_attn:
            attn_metadata = None
        else:
            common_attn_metadata = TorchairCommonAttentionMetadata(
                num_reqs=num_reqs,
                num_actual_tokens=1,
                actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
                attn_mask=self.runner.attn_mask,
                spec_attn_mask=self.runner.spec_attn_mask,
                decode_token_per_req=self.runner.decode_token_per_req,
            )
            attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
                common_attn_metadata)

        input_ids = self.input_ids[:num_tokens]
        positions = self.positions[:num_tokens]
        previous_hidden_states = self.hidden_states[:num_tokens]
        for _ in range(self.num_speculative_tokens):
            with set_ascend_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_tokens,
                    with_prefill=with_prefill,
                    num_tokens_across_dp=num_tokens_across_dp,
                    reserved_mc2_mask=self.runner.reserved_mc2_mask,
                    moe_comm_type=moe_comm_type,
                    in_profile_run=self.runner.in_profile_run,
                    num_actual_tokens=0):
                if not with_prefill:
                    assert attn_metadata is not None
                    torch._dynamo.mark_static(input_ids)
                    torch._dynamo.mark_static(positions)
                    torch._dynamo.mark_static(previous_hidden_states)
                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
                    torch._dynamo.mark_static(
                        attn_metadata.decode.input_positions)
                    if hasattr(attn_metadata.decode, "sin"):
                        torch._dynamo.mark_static(attn_metadata.decode.sin)
                        torch._dynamo.mark_static(attn_metadata.decode.cos)
                    torch._dynamo.mark_static(get_forward_context().mc2_mask)
                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
                    torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
                    torchair_compiled_model = self._get_torchair_lazy_compiled_model(
                        num_tokens)
                    torchair_compiled_model(
                        input_ids=input_ids,
                        positions=positions,
                        hidden_states=previous_hidden_states,
                        inputs_embeds=None,
                        intermediate_tensors=None,
                        attn_metadata=attn_metadata,
                        kv_caches=self.runner.kv_caches[-1:],
                        spec_step_idx=0)
                else:
                    self.model(input_ids=input_ids,
                               positions=positions,
                               hidden_states=previous_hidden_states)
            if with_prefill:
                break

    def generate_token_ids(self,
                           valid_sampled_token_ids: list[list[int]],
                           sampling_metadata: SamplingMetadata = None,
                           scheduler_output: SchedulerOutput = None,
                           spec_decode_metadata: SpecDecodeMetadata = None,
                           positions: torch.Tensor = None,
                           num_scheduled_tokens: int = 0,
                           hidden_states: torch.Tensor = None,
                           attn_metadata=None,
                           aux_hidden_states: torch.Tensor = None):
        if attn_metadata is not None and isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata['model.layers.0.self_attn.attn']
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(valid_sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = self.runner.input_batch.req_ids[i]
                req_state = self.runner.requests[req_id]
                seq_len = (req_state.num_computed_tokens +
                           scheduler_output.num_scheduled_tokens[req_id])
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.device)
        accepted_token_indices = None
        if spec_decode_metadata is None:
            # input_ids can be None for multimodal models.
            target_token_ids = self.runner.input_ids[:num_scheduled_tokens]
            target_positions = positions[:num_scheduled_tokens]
            target_hidden_states = hidden_states[:num_scheduled_tokens]
            target_slot_mapping = attn_metadata.slot_mapping
            cu_num_tokens = attn_metadata.query_start_loc
        else:
            # TODO(woosuk): Refactor this.
            num_draft_tokens = spec_decode_metadata.num_draft_tokens
            num_rejected_tokens = [
                n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
                for i, n in enumerate(num_draft_tokens)
            ]
            num_rejected_tokens = torch.tensor(
                num_rejected_tokens,
                dtype=torch.int32,
                device=self.device,
            )
            cu_num_tokens, accepted_token_indices, target_token_ids, \
                target_positions, target_hidden_states, target_slot_mapping = self._torchair_prepare_inputs(
                    attn_metadata.query_start_loc,
                    num_rejected_tokens,
                    self.runner.input_ids[:num_scheduled_tokens],
                    positions[:num_scheduled_tokens],
                    hidden_states[:num_scheduled_tokens],
                    attn_metadata.slot_mapping[:num_scheduled_tokens],
                )

        draft_token_ids = self._propose_torchair(
            target_token_ids=target_token_ids,
            target_positions=target_positions,
            target_hidden_states=target_hidden_states,
            target_slot_mapping=target_slot_mapping,
            next_token_ids=next_token_ids,
            cu_num_tokens=cu_num_tokens,
            block_table=attn_metadata.block_tables,
            sampling_metadata=sampling_metadata,
            token_indices=accepted_token_indices)
        spec_token_ids = draft_token_ids.tolist()
        return spec_token_ids

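A quick worked example (illustrative values only) of the `num_rejected_tokens` expression used above: each request received `num_draft_tokens[i]` draft tokens, and the target model accepted `len(valid_sampled_token_ids[i])` tokens (verified drafts plus one bonus token).

```
num_draft_tokens = [2, 2, 0]
valid_sampled_token_ids = [[11, 12, 13], [21], [31]]

num_rejected_tokens = [
    n + 1 - len(valid_sampled_token_ids[i]) if n > 0 else 0
    for i, n in enumerate(num_draft_tokens)
]
print(num_rejected_tokens)  # [0, 2, 0]
```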
    def _torchair_prepare_inputs(
        self,
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        token_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]

        cu_num_tokens = cu_target_query_lens
        relative_index = query_len_per_req - num_rejected_tokens - 1
        token_indices = cu_num_tokens[:-1] + relative_index
        # The seq len of each batch is padded to 1 + num_speculative_tokens,
        # so the drafter input is the same shape as the main model's.
        target_token_ids = token_ids
        target_positions = positions
        target_hidden_states = hidden_states
        target_slot_mapping = slot_mapping

        return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping

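A small worked example (made-up lengths) of the index arithmetic documented in the comment block above; in this padded torchair path `cu_num_tokens` stays equal to `cu_target_query_lens`, and only the flat position of the last accepted token per request is computed:

```
import torch

# Per-request query lens [a, b, c] = [3, 2, 2]
cu_target_query_lens = torch.tensor([0, 3, 5, 7])
num_rejected_tokens = torch.tensor([1, 0, 1])

query_len_per_req = cu_target_query_lens[1:] - cu_target_query_lens[:-1]  # [3, 2, 2]
relative_index = query_len_per_req - num_rejected_tokens - 1              # [1, 1, 0]
token_indices = cu_target_query_lens[:-1] + relative_index
print(token_indices.tolist())  # [1, 4, 5] -> last accepted position per request
```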
    def _propose_torchair(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [num_tokens]
        target_slot_mapping: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # [batch_size + 1] starting with 0
        cu_num_tokens: torch.Tensor,
        # [batch_size, max_num_blocks_per_req]
        block_table: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        token_indices=None) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]
        last_token_indices = cu_num_tokens[1:] - 1

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        if token_indices is not None:
            last_token_indices = token_indices

        self.input_ids[last_token_indices] = next_token_ids

        query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
        max_query_len = query_lens.max().item()

        # FIXME: reorder_batch() needs to be called before build()
        # because fields of attn_metadata_builder need to be updated.
        # However, currently reorder_batch() takes input_batch and
        # scheduler_output as arguments; we should probably refactor
        # the method to use new data structures which are independent
        # from input_batch and scheduler_output.
        # self.runner.attn_metadata_builder.reorder_batch(
        #     input_batch=self.runner.input_batch,
        #     scheduler_output=self.runner.scheduler_output,
        # )

        if not self.runner.with_prefill:
            # Torchair graph mode: padding is the same as for the main model.
            num_input_tokens = self.runner.graph_pad_size
        elif (self.runner.use_aclgraph
              and num_tokens <= self.runner.aclgraph_batch_sizes[-1]):
            # Acl graph mode: pad up to the captured batch size.
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
        else:
            # Eager mode: no padding needed.
            num_input_tokens = num_tokens

        seq_lens = target_positions[last_token_indices] + 1
        seq_lens = seq_lens.int()
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=cu_num_tokens[:batch_size + 1],
            query_start_loc_cpu=cu_num_tokens[:batch_size + 1].cpu(),
            seq_lens_cpu=seq_lens.cpu(),
            num_reqs=batch_size,
            num_actual_tokens=num_tokens,
            max_query_len=max_query_len,
            actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
            block_table_tensor=self.runner.input_batch.block_table[0].
            get_device_tensor(),
            slot_mapping=target_slot_mapping,
            positions=target_positions,
            attn_mask=self.runner.attn_mask,
            spec_attn_mask=self.runner.spec_attn_mask,
            attn_state=self.runner.attn_state,
            graph_pad_size=self.runner.graph_pad_size,
            decode_token_per_req=self.runner.decode_token_per_req,
            num_computed_tokens_cpu=None,
            seq_lens=None)

        attn_metadata = self.runner.attn_metadata_builder.build(
            0, common_attn_metadata, self.runner.get_model())

        self.positions[:num_tokens] = target_positions
        self.hidden_states[:num_tokens] = target_hidden_states

        # torchair mode can reuse self.runner.num_tokens_across_dp
        num_tokens_across_dp = self.runner.num_tokens_across_dp
        with_prefill = self.runner.with_prefill

        moe_comm_type = self.runner._select_moe_comm_method(
            num_input_tokens, with_prefill)
        batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens,
                                           uniform_decode=False)
        aclgraph_runtime_mode, batch_descriptor = \
            self.runner.aclgraph_dispatcher.dispatch(batch_descriptor)

        for step in range(self.num_speculative_tokens):
            with set_ascend_forward_context(
                    attn_metadata,
                    self.vllm_config,
                    num_tokens=num_input_tokens,
                    with_prefill=with_prefill,
                    num_tokens_across_dp=num_tokens_across_dp,
                    reserved_mc2_mask=self.runner.reserved_mc2_mask,
                    moe_comm_type=moe_comm_type,
                    aclgraph_runtime_mode=aclgraph_runtime_mode,
                    in_profile_run=self.runner.in_profile_run,
                    num_actual_tokens=num_tokens):
                with ProfileExecuteDuration().capture_async('mtp_forward'):
                    model_kwargs = {}
                    model_kwargs["attn_metadata"] = attn_metadata
                    model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
                    if not self.runner.with_prefill:
                        torchair_compiled_model = self._get_torchair_lazy_compiled_model(
                            num_input_tokens)
                        hidden_states = torchair_compiled_model(
                            input_ids=self.input_ids[:num_input_tokens],
                            positions=self.positions[:num_input_tokens],
                            hidden_states=self.hidden_states[:num_input_tokens],
                            inputs_embeds=None,
                            intermediate_tensors=None,
                            spec_step_idx=0,
                            **model_kwargs)
                    else:
                        hidden_states = self.model(
                            input_ids=self.input_ids[:num_input_tokens],
                            positions=self.positions[:num_input_tokens],
                            hidden_states=self.hidden_states[:num_input_tokens])

                num_indices = last_token_indices.shape[0]
                if lmhead_tp_enable():
                    if not self.runner.with_prefill:
                        max_num_reqs_across_dp = num_input_tokens
                    else:
                        max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
                    last_token_indices = nn.functional.pad(
                        last_token_indices,
                        (0, max_num_reqs_across_dp - num_indices))

                sample_hidden_states = hidden_states[last_token_indices]
                logits = self.model.compute_logits(sample_hidden_states)
                if lmhead_tp_enable() and num_indices < logits.shape[0]:
                    logits = logits[:num_indices]
                draft_token_ids = logits.argmax(dim=-1)

                if self.num_speculative_tokens == 1:
                    # [batch_size, 1]
                    return draft_token_ids.view(-1, 1)

                if step == 0:
                    draft_token_ids_list = [draft_token_ids]
                else:
                    draft_token_ids_list.append(draft_token_ids)

                # Prepare inputs for the next MTP step.
                # mtp > 1: prefill exits after the first step; decode exits on
                # the last step.
                if with_prefill:
                    for _ in range(self.num_speculative_tokens - 1):
                        draft_token_ids_list.append(draft_token_ids)
                if step == self.num_speculative_tokens - 1 or with_prefill:
                    break

                attn_metadata_i = attn_metadata

                if step == 0:
                    positions = target_positions[last_token_indices]
                    hidden_states = hidden_states[last_token_indices]
                    slot_mapping = attn_metadata_i.slot_mapping[last_token_indices]
                    attn_metadata_i.slot_mapping.fill_(-1)
                    attn_metadata_i.query_start_loc = self.arange[:batch_size + 1]
                    last_token_indices = self.arange[:batch_size]
                    if attn_metadata_i.num_decode_tokens != 0:
                        attn_metadata_i.num_decode_tokens = batch_size
                    if not self.runner.with_prefill:
                        attn_metadata_i.num_actual_tokens = batch_size
                    attn_metadata_i.query_lens = [1] * batch_size

                input_ids = draft_token_ids_list[-1].int()
                positions += 1

                # NOTE(woosuk): We should handle the case where the draft model
                # generates tokens beyond the max model length. Since it is complex
                # to remove such requests from the batch, we keep them in the batch
                # but adjust the position ids and slot mappings to avoid the
                # out-of-range access during the model execution. The draft tokens
                # generated with this adjustment should be ignored.
                exceeds_max_model_len = positions >= self.runner.model_config.max_model_len
                # Mask out the position ids that exceed the max model length.
                # Otherwise, we may get out-of-range error in RoPE.
                clamped_positions = torch.where(exceeds_max_model_len, 0,
                                                positions)
                # Increment the sequence lengths.
                attn_metadata_i.seq_lens[:batch_size] += 1
                # For the requests that exceed the max model length, we set the
                # sequence length to 1 to minimize their overheads in attention.
                exceeds_max_model_len_cpu = exceeds_max_model_len.to(
                    attn_metadata_i.seq_lens.device, non_blocking=True)
                attn_metadata_i.seq_lens[:batch_size].masked_fill_(
                    exceeds_max_model_len_cpu, 1)
                # Mask out the slot mappings that exceed the max model length.
                # Otherwise, the KV cache will be inadvertently updated with the
                # padding tokens.
                slot_mapping += 1
                slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID)

                # Copy inputs to the persistent buffers used by the graph.
                self.input_ids[:batch_size] = input_ids
                self.positions[:batch_size] = clamped_positions
                self.hidden_states[:hidden_states.shape[0]] = hidden_states
                attn_metadata_i.slot_mapping[:batch_size] = slot_mapping

                if attn_metadata_i.prefill is not None:
                    attn_metadata_i.prefill.seq_lens = attn_metadata_i.seq_lens
                    attn_metadata_i.prefill.seq_lens_list = attn_metadata_i.prefill.seq_lens.tolist()
                    attn_metadata_i.prefill.context_lens = attn_metadata_i.seq_lens
                    attn_metadata_i.prefill.input_positions = self.positions[:num_input_tokens]
                    attn_metadata_i.prefill.max_seq_lens += 1
                    attn_metadata_i.prefill.max_seq_lens = min(
                        attn_metadata_i.prefill.max_seq_lens,
                        self.runner.model_config.max_model_len)
                if attn_metadata_i.decode is not None:
                    attn_metadata_i.decode.seq_lens = attn_metadata_i.seq_lens
                    attn_metadata_i.decode.seq_lens_list = attn_metadata_i.decode.seq_lens.tolist()
                    attn_metadata_i.decode.input_positions = self.positions[:num_input_tokens]
                    attn_metadata_i.decode.max_seq_lens += 1
                    attn_metadata_i.decode.max_seq_lens = min(
                        attn_metadata_i.decode.max_seq_lens,
                        self.runner.model_config.max_model_len)

        # mtp > 1: [batch_size, k]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

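The shift-and-replace step at the top of `_propose_torchair` follows the `[a1, b1, b2, ...]` comments; an illustrative sketch with hypothetical token ids:

```
import torch

target_token_ids = torch.tensor([101, 201, 202, 301, 302, 303])  # [a1, b1, b2, c1, c2, c3]
next_token_ids = torch.tensor([102, 203, 304])                    # [a2, b3, c4]
cu_num_tokens = torch.tensor([0, 1, 3, 6])

input_ids = target_token_ids.clone()
input_ids[:-1] = target_token_ids[1:]        # [b1, b2, c1, c2, c3, c3]
last_token_indices = cu_num_tokens[1:] - 1   # [0, 2, 5]
input_ids[last_token_indices] = next_token_ids
print(input_ids.tolist())  # [102, 202, 203, 302, 303, 304] = [a2, b2, b3, c2, c3, c4]
```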
    def _get_torchair_lazy_compiled_model(self, batch_size: int):
        if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
                -1]:
            raise ValueError(
                f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
            )

        compiled_model = self.torchair_compiled_models.get(
            batch_size
        ) if self.runner.use_cached_npu_graph else self.torchair_compiled_model

        if compiled_model:
            return compiled_model

        patch_for_hcom()
        config = torchair.CompilerConfig()
        config.experimental_config.frozen_parameter = True
        config.experimental_config.tiling_schedule_optimize = True
        config.experimental_config.enable_view_optimize = \
            get_ascend_config().torchair_graph_config.enable_view_optimize
        torch.npu.set_compile_mode(jit_compile=False)
        if not self.runner.use_cached_npu_graph:
            npu_backend = torchair.get_npu_backend(compiler_config=config)
            self.torchair_compiled_model = torch.compile(
                self.model,
                dynamic=not self.use_sparse,
                fullgraph=True,
                backend=npu_backend)
            return self.torchair_compiled_model
        else:
            # Generate a new forward proxy code object to prevent the
            # compilation cache from being invalidated by dynamo retracing.
            forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
            forward_fn = self.model.forward
            code = forward_fn.__code__
            # Mark the code object with a new proxy name.
            modified_code = code.replace(co_name=forward_proxy_name, )

            modified_func = types.FunctionType(modified_code,
                                               forward_fn.__globals__,
                                               name=forward_proxy_name,
                                               argdefs=forward_fn.__defaults__)

            self.model.__dict__[forward_proxy_name] = modified_func.__get__(
                self.model, nn.Module)
            self.torchair_compiled_models[
                batch_size] = torchair.inference.cache_compile(
                    self.model.__dict__[forward_proxy_name],
                    dynamic=not self.use_sparse,
                    fullgraph=True,
                    cache_dir=TORCHAIR_CACHE_DIR,
                    config=config,
                    ge_cache=False)
            return self.torchair_compiled_models[batch_size]
@@ -133,6 +133,7 @@ from vllm_ascend.spec_decode import get_spec_decode_method
 from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
+from vllm_ascend.torchair.torchair_mtp_proposer import TorchairMtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
                                enable_sp, get_ascend_soc_version, is_310p,
@@ -369,32 +370,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.attn_mask_builder = AttentionMaskBuilder(
             self.model_config.max_model_len, self.dtype)
 
-        # Set up speculative decoding.
-        self.spec_attn_mask = None
-        self.drafter: Optional[Union[NgramProposer, EagleProposer,
-                                     MtpProposer]] = None
-        self.actual_seq_lengths_q: list[int] = []
-        self.decode_token_per_req = 1
-        if self.speculative_config:
-            spec_token_num = self.speculative_config.num_speculative_tokens
-            assert spec_token_num > 0
-            self.decode_token_per_req = 1 + spec_token_num
-            self.spec_attn_mask = torch.triu(torch.ones(2048,
-                                                        2048,
-                                                        dtype=torch.bool),
-                                             diagonal=1).to(self.device)
-            if get_pp_group().is_last_rank:
-                self.drafter = get_spec_decode_method(
-                    self.speculative_config.method, self.vllm_config,
-                    self.device, self)
-                if vllm_version_is("0.11.0"):
-                    self.rejection_sampler = AscendRejectionSampler()
-                else:
-                    self.rejection_sampler = AscendRejectionSampler(
-                        self.sampler)
-            self.actual_seq_lengths_q = list(
-                range(self.decode_token_per_req, self.max_num_tokens + 1,
-                      self.decode_token_per_req))
+        self._set_up_drafter()
 
         # kv role
         self.is_kv_producer = False
@@ -590,6 +566,39 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # TODO: EVS Support (Video tokens pruning) (see vllm#22980)
         self.is_multimodal_pruning_enabled = False
 
+    def _set_up_drafter(self):
+        # Set up speculative decoding.
+        self.spec_attn_mask = None
+        self.drafter: Optional[Union[NgramProposer, EagleProposer, MtpProposer,
+                                     TorchairMtpProposer]] = None
+        self.actual_seq_lengths_q: list[int] = []
+        self.decode_token_per_req = 1
+        if self.speculative_config:
+            spec_token_num = self.speculative_config.num_speculative_tokens
+            assert spec_token_num > 0
+            self.decode_token_per_req = 1 + spec_token_num
+            self.spec_attn_mask = torch.triu(torch.ones(2048,
+                                                        2048,
+                                                        dtype=torch.bool),
+                                             diagonal=1).to(self.device)
+            if get_pp_group().is_last_rank:
+                self.drafter = self._get_drafter()
+                if vllm_version_is("0.11.0"):
+                    self.rejection_sampler = AscendRejectionSampler()
+                else:
+                    self.rejection_sampler = AscendRejectionSampler(
+                        self.sampler)
+            self.actual_seq_lengths_q = list(
+                range(self.decode_token_per_req, self.max_num_tokens + 1,
+                      self.decode_token_per_req))
+        self.discard_request_indices = self._make_buffer(self.max_num_reqs,
+                                                         dtype=torch.int64)
+        self.num_discarded_requests = 0
+
+    def _get_drafter(self):
+        return get_spec_decode_method(self.speculative_config.method,
+                                      self.vllm_config, self.device, self)
+
     def _may_pad_kv_consumer_num_seq(self):
         # For Full Graph + MTP in a PD (Prefill/Decode) disaggregation scenario,
         # we may want to pad self.max_num_seqs in kv_consumer nodes to avoid
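For orientation, a small arithmetic sketch (hypothetical sizes) of the `actual_seq_lengths_q` list built in `_set_up_drafter` when speculative decoding is enabled: every decode request occupies `decode_token_per_req` query slots, so the candidate query lengths are its multiples.

```
num_speculative_tokens = 1
decode_token_per_req = 1 + num_speculative_tokens
max_num_tokens = 8

actual_seq_lengths_q = list(
    range(decode_token_per_req, max_num_tokens + 1, decode_token_per_req))
print(actual_seq_lengths_q)  # [2, 4, 6, 8]
```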
@@ -609,7 +618,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         tp_size = self.parallel_config.tensor_parallel_size
         # Use integer arithmetic for ceiling division.
         num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
-        self.mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        self.mc2_tokens_capacity: int = num_tokens_per_tp_rank * tp_size
 
     def _make_buffer(self,
                      *size: Union[int, torch.SymInt],
@@ -1522,6 +1531,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
         attn_metadata: dict[str, Any] = {}
 
+        # Record the index of requests that should not be sampled,
+        # so that we can clear the sampled tokens before returning.
+        num_tokens = [
+            self.requests[r].num_tokens for r in self.input_batch.req_ids
+        ]
+        num_tokens_np = np.array(num_tokens, dtype=np.int32)
+        num_reqs = self.input_batch.num_reqs
+        discard_requests_mask = self.seq_lens_np[:num_reqs] < num_tokens_np
+        discard_request_indices = np.nonzero(discard_requests_mask)[0]
+        self.num_discarded_requests = len(discard_request_indices)
+        self.discard_request_indices.np[:self.num_discarded_requests] = (
+            discard_request_indices)
+        self.discard_request_indices.copy_to_gpu(self.num_discarded_requests)
+
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
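An illustrative sketch (toy values) of the discard bookkeeping added above: a request whose scheduled sequence length is still shorter than its known token count is mid chunked-prefill, so the token sampled for it must be cleared later.

```
import numpy as np

seq_lens = np.array([16, 7, 32], dtype=np.int32)     # tokens visible after this step
num_tokens = np.array([16, 24, 32], dtype=np.int32)  # known prompt + output tokens per request

discard_requests_mask = seq_lens < num_tokens
discard_request_indices = np.nonzero(discard_requests_mask)[0]
print(discard_request_indices.tolist())  # [1] -> only request 1 is mid-prefill
```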
@@ -1615,7 +1638,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
         num_computed_tokens_cpu = (
             self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
-        spec_decode_common_attn_metadata = None
+        self.spec_decode_common_attn_metadata = None
         if use_spec_decode and self.need_accepted_tokens:
             self.num_accepted_tokens.np[:num_reqs] = (
                 self.input_batch.num_accepted_tokens_cpu[:num_reqs])
@@ -1676,7 +1699,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=self.query_start_loc[:num_reqs + 1],
             query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens_cpu=self.seq_lens_cpu,
+            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
             seq_lens=self.seq_lens_cpu[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=slot_mapping_size,
@@ -1700,8 +1723,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )
 
         if self.speculative_config and \
-            spec_decode_common_attn_metadata is None:
-            spec_decode_common_attn_metadata = common_attn_metadata
+            self.spec_decode_common_attn_metadata is None:
+            self.spec_decode_common_attn_metadata = common_attn_metadata
 
         for attn_group in self.attn_groups[kv_cache_group_id]:
             common_prefix_len = 0
@@ -1998,7 +2021,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
     def propose_draft_token_ids(
         self,
-        valid_sampled_token_ids: list[list[int]],
+        valid_sampled_token_ids: Union[torch.Tensor, list[list[int]]],
         sampling_metadata: SamplingMetadata,
         scheduler_output: "SchedulerOutput",
         spec_decode_metadata: SpecDecodeMetadata,
@@ -2255,6 +2278,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 logits = self.apply_grammar_bitmask(
                     scheduler_output, logits)
 
+        with ProfileExecuteDuration().capture_async("Sample"):
             # Sample the next token and get logprobs if needed.
             sampling_metadata = self.input_batch.sampling_metadata
             if spec_decode_metadata is None:
@@ -2296,21 +2320,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             if self.need_accepted_tokens:
                 self._update_states_after_model_execute(output_token_ids)
 
-            discard_sampled_tokens_req_indices: list[int] = []
-            # TODO(woosuk): The following loop can be slow since it iterates over
-            # the requests one by one. Optimize.
-            discard_sampled_tokens_req_indices = []
-            for i, req_id in enumerate(self.input_batch.req_ids):
-                req_state = self.requests[req_id]
-                seq_len = (req_state.num_computed_tokens +
-                           scheduler_output.num_scheduled_tokens[req_id])
-                if seq_len < req_state.num_tokens:
-                    # Ignore the sampled token.
-                    # Rewind the generator state as if the token was not sampled.
-                    generator = self.input_batch.generators.get(i)
+            discard_sampled_tokens_req_indices = \
+                self.discard_request_indices.np[:self.num_discarded_requests]
+            for i in discard_sampled_tokens_req_indices:
+                generator = self.input_batch.generators.get(int(i))
                 if generator is not None:
                     generator.set_offset(generator.get_offset() - 4)
-                    discard_sampled_tokens_req_indices.append(i)
 
             # Copy some objects so they don't get modified after returning.
             # This is important when using async scheduling.
@@ -2346,10 +2361,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             )
             # Mask out the sampled tokens that should not be sampled.
             for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+                valid_sampled_token_ids[int(i)].clear()
         else:
             valid_sampled_token_ids = []
-            invalid_req_indices = list(discard_sampled_tokens_req_indices)
+            invalid_req_indices = discard_sampled_tokens_req_indices.tolist(
+            )
             invalid_req_indices_set = set(invalid_req_indices)
             assert sampled_token_ids.shape[-1] == 1
 
@@ -2394,9 +2410,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
 
-        if self.speculative_config:
+        def propose_draft_token_ids(sampled_token_ids):
+            assert self.spec_decode_common_attn_metadata is not None
             self._draft_token_ids = self.propose_draft_token_ids(
-                valid_sampled_token_ids,
+                sampled_token_ids,
                 sampling_metadata,
                 scheduler_output,
                 spec_decode_metadata,
@@ -2407,6 +2424,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 aux_hidden_states,
             )
 
+        with ProfileExecuteDuration().capture_async("Draft"):
+            if self.speculative_config:
+                use_padded_batch_for_eagle = self.speculative_config and \
+                    self.speculative_config.method == "deepseek_mtp" and \
+                    not self.speculative_config.disable_padded_drafter_batch
+                if use_padded_batch_for_eagle:
+                    # EAGLE speculative decoding can use the GPU sampled tokens
+                    # as inputs, and does not need to wait for bookkeeping to finish.
+                    propose_draft_token_ids(sampler_output.sampled_token_ids)
+            if self.speculative_config and not use_padded_batch_for_eagle:
+                # ngram and other speculative decoding methods use the sampled
+                # tokens on the CPU, so they are run after bookkeeping.
+                propose_draft_token_ids(valid_sampled_token_ids)
+
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
@@ -92,8 +92,10 @@ class CachedRequestState:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             return self.prompt_token_ids[idx]
-        else:
+        elif idx - self.num_prompt_tokens < len(self.output_token_ids):
             return self.output_token_ids[idx - self.num_prompt_tokens]
+        else:
+            return -1
 
 
 class InputBatch:
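A standalone sketch (hypothetical token lists) of the new `get_token_id` behaviour: an index past the tokens known so far now returns -1 instead of raising, which the padded drafter path may hit when it asks for a next token that has not been produced yet.

```
prompt_token_ids = [1, 2, 3, 4]
output_token_ids = [5, 6]

def get_token_id(idx: int) -> int:
    if idx < len(prompt_token_ids):
        return prompt_token_ids[idx]
    elif idx - len(prompt_token_ids) < len(output_token_ids):
        return output_token_ids[idx - len(prompt_token_ids)]
    else:
        return -1

print(get_token_id(3), get_token_id(5), get_token_id(9))  # 4 6 -1
```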