[V1] MTP supports torchair (#2145)

### What this PR does / why we need it?
Support MTP  with:

- [x]  V0 Scheduler
- [x]  TorchAir
- [x]  Single DP
- [x]  Multi DP
- [x]  Disaggregate PD

Known issues:
- [ ] Not support V1 Scheduler (chunked prefill), will be supported in a
few weeks
- [ ] vllm v0.10.0 does not support metrics with `DP > 1` right now,
need to comment out the line 171-175 in file
`vllm/vllm/v1/metrics/loggers.py`
```
            if (len(self.engine_indexes) > 1
                and vllm_config.speculative_config is not None):
            raise NotImplementedError("Prometheus metrics with Spec Decoding "
                                      "with >1 EngineCore per AsyncLLM is not "
                                      "supported yet.")
```

To start an online server with torchair enabled, here is an example:
```
python -m vllm.entrypoints.openai.api_server \
 --model="/weights/DeepSeek-R1_w8a8/" \
 --trust-remote-code \
 --max-model-len 40000 \
 --tensor-parallel-size 4 \
 --data_parallel_size 4 \
 --max-num-seqs 16 \
 --no-enable-prefix-caching \
 --enable_expert_parallel \
 --served-model-name deepseekr1 \
 --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
 --quantization ascend \
 --host 0.0.0.0 \
 --port 1234 \
 --additional-config '{"ascend_scheduler_config":{"enabled":true,"enable_chunked_prefill":false},"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]},"enable_weight_nz_layout":true}' \
 --gpu_memory_utilization 0.9 
``` 

offline example with torchair enabled
```
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=16, temperature=0)
# Create an LLM.
llm = LLM(
    model="/home/data/DeepSeek-R1_w8a8/",
    tensor_parallel_size=16,
    max_num_seqs=16,
    gpu_memory_utilization=0.9,
    distributed_executor_backend="mp",
    enable_expert_parallel=True,
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    enforce_eager=False,
    max_model_len=2000,
    additional_config = {
       'torchair_graph_config': {
            'enabled': True,
            "graph_batch_sizes": [16],
            'enable_multistream_shared_expert': False,
        },
       "ascend_scheduler_config": {
            "enabled": True
        },
        # 'expert_tensor_parallel_size': 16,
    }
)

# Generate texts from the prompts.
# llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
# llm.stop_profile()
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- vLLM version: v0.10.0
- vLLM main:
302962e806

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
xuyexiong
2025-08-06 19:37:43 +08:00
committed by GitHub
parent bf84f2dbfa
commit 26fc36b0e0
12 changed files with 541 additions and 160 deletions

View File

@@ -206,9 +206,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=self.dtype,
device=self.device)
self.graph_block_tables = np.zeros(
(self.max_num_reqs, self.max_num_blocks_per_req), dtype=np.int32)
# Set up Attention
self.attn_backend = get_attn_backend(
0,
@@ -231,8 +228,19 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.use_eagle = False
self.drafter: Optional[Union[NgramProposer, EagleProposer,
MtpProposer]] = None
self.actual_seq_lengths_q = []
self.spec_token_num = 0
self.decode_token_per_req = 1
if self.speculative_config:
self.use_spec_decode = True
self.spec_token_num = self.speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
self.decode_token_per_req = 1 + self.spec_token_num
self.actual_seq_lengths_q = [
len for len in
range(self.decode_token_per_req, self.max_num_tokens +
1, self.decode_token_per_req)
]
self.spec_attn_mask = torch.triu(torch.ones(2048,
2048,
dtype=torch.bool),
@@ -253,6 +261,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
f"{self.speculative_config.method}")
self.rejection_sampler = AscendRejectionSampler()
# Persistent batch.
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
@@ -338,9 +347,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
if len(self.torchair_graph_batch_sizes) == 0:
# TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
self.torchair_graph_batch_sizes = [self.max_num_reqs]
self.check_torchair_graph_batch_sizes()
# graph_block_tables shape: [num_request, cell(max_model_len / block_size)]
self.graph_block_tables = np.zeros(
(self.torchair_graph_batch_sizes[-1] // self.decode_token_per_req,
(self.model_config.max_model_len + self.block_size - 1) //
self.block_size),
dtype=np.int32)
torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes)
@@ -558,17 +573,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
start_index = end_token_index
end_token_index += len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self.input_batch.num_tokens[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
num_spec_tokens = len(spec_token_ids)
start_index = self.input_batch.num_tokens_no_spec[req_index]
end_token_index = start_index + num_spec_tokens
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec tokens.
self.input_batch.num_tokens[req_index] += num_spec_tokens
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
@@ -586,6 +601,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Append to the end.
req_index = None
self.input_batch.add_request(req_state, req_index)
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
req_index = self.input_batch.num_reqs - 1
start_index = len(req_state.prompt_token_ids) + len(
req_state.output_token_ids)
end_token_index = start_index + len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index] = spec_token_ids
self.input_batch.num_tokens[req_index] = end_token_index
# Condense the batched states if there are empty indices.
if removed_req_indices:
@@ -615,6 +640,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
if self.dp_size == 1:
if self.torchair_graph_enabled and not with_prefill:
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
num_tokens)
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
return num_tokens, None, with_prefill, enable_dbo
if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
@@ -1108,6 +1137,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
elif np.all(num_scheduled_tokens == 1):
attn_state = AscendAttentionState.DecodeOnly
if self.speculative_config and self.speculative_config.method == 'deepseek_mtp':
# SpecDecoding now supports seq_len=1 and seq_len=2
# In Prefilling Decoding Disaggregation scenario, SpecDecoding need to supports seq_len=1
attn_state = AscendAttentionState.SpecDecoding
# Speculative decoding.
elif np.all(num_valid_tokens == 1):
if self.use_eagle:
@@ -1154,10 +1187,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
total_num_scheduled_tokens, with_prefill, enable_dbo)
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
if self.torchair_graph_enabled and not with_prefill:
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
extra_builder_kwargs['graph_pad_size'] = graph_pad_size
self.graph_pad_size = padded_num_tokens_across_dp
extra_builder_kwargs[
'graph_pad_size'] = self.graph_pad_size # type: ignore
else:
self.graph_pad_size = -1
if self.vllm_config.model_config.use_mla:
extra_builder_kwargs[
@@ -1837,10 +1874,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
num_reqs = min(num_tokens, max_num_reqs)
if with_prefill:
num_reqs = num_tokens
else:
num_reqs = (num_tokens + self.decode_token_per_req -
1) // self.decode_token_per_req
num_reqs = min(num_reqs, max_num_reqs)
min_tokens_per_req = num_tokens // num_reqs
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
assert sum(num_scheduled_tokens_list) == num_tokens
assert len(num_scheduled_tokens_list) == num_reqs
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
dtype=np.int32)
@@ -1852,7 +1896,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# we can't skip_attn, it will cause graph recompile.
if self.torchair_graph_enabled and not with_prefill:
attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_tokens, num_actual_tokens=1)
num_reqs=num_reqs, num_actual_tokens=1)
elif skip_attn:
attn_metadata = None
else:
@@ -1913,6 +1957,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
if self.speculative_config:
torch._dynamo.mark_static(
attn_metadata.decode.attn_mask)
for kv in self.kv_caches:
assert isinstance(
kv, tuple), "kv_cache must be a tuple"
@@ -1949,6 +1996,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.use_spec_decode and isinstance(
self.drafter, EagleProposer):
self.drafter.dummy_run(num_tokens)
if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
assert isinstance(self.drafter, MtpProposer)
self.drafter.dummy_run(
num_tokens=num_tokens,
with_prefill=with_prefill,
skip_attn=skip_attn,
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp)
return hidden_states
@contextmanager
@@ -2071,9 +2127,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
m.consumed_memory / float(2**30))
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.max_num_reqs:
if batch_size < 0 or batch_size > self.torchair_graph_batch_sizes[-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_num_reqs:{self.max_num_reqs}"
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
@@ -2537,7 +2593,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.device)
accepted_token_indices = None
if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
@@ -2557,14 +2613,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.int32,
device=self.device,
)
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
cu_num_tokens, accepted_token_indices, target_token_ids, \
target_positions, target_hidden_states, target_slot_mapping = self.drafter.prepare_inputs(
attn_metadata.query_start_loc,
num_rejected_tokens,
self.input_ids[:num_scheduled_tokens],
positions[:num_scheduled_tokens],
hidden_states[:num_scheduled_tokens],
attn_metadata.slot_mapping[:num_scheduled_tokens],
is_torchair_graph=self.torchair_graph_enabled,
)
target_token_ids = self.input_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
draft_token_ids = self.drafter.propose(
target_token_ids=target_token_ids,
@@ -2575,7 +2633,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
cu_num_tokens=cu_num_tokens,
block_table=attn_metadata.block_tables,
sampling_metadata=sampling_metadata,
)
token_indices=accepted_token_indices)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
@@ -2686,11 +2744,56 @@ class NPUModelRunner(LoRAModelRunnerMixin):
start_graph_batch_size *= 2
def select_torchair_padded_batch_size(self, batch_size: int):
selected_batch_size = self.max_num_reqs
for padded_batch_size in self.torchair_graph_batch_sizes:
if batch_size <= padded_batch_size < selected_batch_size:
selected_batch_size = padded_batch_size
return selected_batch_size
if batch_size <= padded_batch_size:
# we treat batch_size as num of requests
return padded_batch_size
raise ValueError(
f"cur batch_size is invalid, torchair_graph_batch_sizes is "
f"{self.torchair_graph_batch_sizes}, but cur batch_size is {batch_size}."
)
def check_torchair_graph_batch_sizes(self):
# return graph_batch_sizes according to the max number of tokens
# first pad according to the number of requests
if len(self.torchair_graph_batch_sizes) == 0:
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
else:
self.torchair_graph_batch_sizes = sorted(
self.torchair_graph_batch_sizes)
while self.torchair_graph_batch_sizes[-1] > self.max_num_reqs:
self.torchair_graph_batch_sizes.pop()
if len(self.torchair_graph_batch_sizes) == 0:
logger.warning(
"torch_graph_batch_sizes is invalid, reset it to [1, max_num_seqs]"
)
self.torchair_graph_batch_sizes = [1, self.max_num_reqs]
if self.torchair_graph_batch_sizes[-1] < self.max_num_reqs:
self.torchair_graph_batch_sizes.append(self.max_num_reqs)
# padded max number tokens = max_num_req * decode_token_per_req
self.torchair_graph_batch_sizes = [
graph_batch_size * self.decode_token_per_req
for graph_batch_size in self.torchair_graph_batch_sizes
]
# NOTE: when enable_expert_parallel, we need to check if `graph_batch_size` is divisible by `tp_size`
tp_size = self.parallel_config.tensor_parallel_size
if self.parallel_config.enable_expert_parallel:
new_graph_batch_sizes = []
for graph_batch_size in self.torchair_graph_batch_sizes:
cur_graph_batch_size = (graph_batch_size + tp_size -
1) // tp_size * tp_size
if cur_graph_batch_size not in new_graph_batch_sizes and \
cur_graph_batch_size <= self.scheduler_config.max_num_batched_tokens:
new_graph_batch_sizes.append(cur_graph_batch_size)
elif cur_graph_batch_size > self.scheduler_config.max_num_batched_tokens \
and self.decode_token_per_req > 1:
logger.warning(
f"torchair_graph_batch_sizes {cur_graph_batch_size} is bigger than max_num_batched_tokens",
f"{self.scheduler_config.max_num_batched_tokens} will skip this batch size."
)
self.torchair_graph_batch_sizes = new_graph_batch_sizes
def get_supported_pooling_tasks(self):
model = self.get_model()

View File

@@ -1,14 +1,23 @@
import types
import torch
import torch.nn as nn
import torchair
import vllm.envs as envs_vllm
from torchair import patch_for_hcom
from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.forward_context import get_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, set_default_torch_dtype)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.utils import ProfileExecuteDuration
class MtpProposer:
@@ -22,7 +31,21 @@ class MtpProposer:
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.block_size = vllm_config.cache_config.block_size
self.hidden_size = vllm_config.model_config.get_hidden_size()
self.runner = runner
# persistent buffers for graph
self.input_ids = torch.zeros(self.runner.max_num_tokens,
dtype=torch.int32,
device=self.runner.device)
self.positions = torch.zeros(self.runner.max_num_tokens,
dtype=torch.int64,
device=self.runner.device)
self.hidden_states = torch.zeros(
(self.runner.max_num_tokens, self.hidden_size),
dtype=self.runner.dtype,
device=self.runner.device)
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
@staticmethod
def prepare_inputs(
@@ -30,7 +53,13 @@ class MtpProposer:
cu_target_query_lens: torch.Tensor,
# [batch_size]
num_rejected_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
token_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
slot_mapping: torch.Tensor,
is_torchair_graph: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
@@ -38,63 +67,80 @@ class MtpProposer:
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
# [0, a, a + b, a + b + c] -> [a, b, c]
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
if is_torchair_graph:
cu_num_tokens = cu_target_query_lens
relative_index = query_len_per_req - num_rejected_tokens - 1
token_indices = cu_num_tokens[:-1] + relative_index
# the seq len of each bath is padded to 1+num_speculative_tokens, thus input is same as the main model
target_token_ids = token_ids
target_positions = positions
target_hidden_states = hidden_states
target_slot_mapping = slot_mapping
else:
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
cu_num_tokens = torch.empty_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
cu_num_tokens[0] = 0
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
)
# FIXME(woosuk): Avoid synchronization.
num_tokens = cu_num_tokens[-1].item()
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_num_tokens.device,
)
BLOCK_SIZE = 1024
prepare_input_kernel(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
return cu_num_tokens, token_indices
BLOCK_SIZE = 1024
prepare_input_kernel(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
target_token_ids = token_ids[token_indices]
target_positions = positions[token_indices]
target_hidden_states = hidden_states[token_indices]
target_slot_mapping = slot_mapping[token_indices]
return cu_num_tokens, token_indices, target_token_ids, target_positions, target_hidden_states, target_slot_mapping
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [num_tokens]
target_slot_mapping: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
# [batch_size + 1] starting with 0
cu_num_tokens: torch.Tensor,
# [batch_size, max_num_blocks_per_req]
block_table: torch.Tensor,
sampling_metadata: SamplingMetadata,
token_indices=None) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1
input_ids = torch.empty_like(target_token_ids)
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
if token_indices is not None and self.runner.torchair_graph_enabled:
last_token_indices = token_indices
else:
seq_lens = target_positions[last_token_indices] + 1
seq_lens = seq_lens.cpu()
self.input_ids[last_token_indices] = next_token_ids
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()
@@ -109,20 +155,76 @@ class MtpProposer:
# input_batch=self.runner.input_batch,
# scheduler_output=self.runner.scheduler_output,
# )
extra_builder_kwargs = {}
is_running_torchair = self.runner.torchair_graph_enabled and \
not self.runner.with_prefill
if is_running_torchair:
extra_builder_kwargs['graph_pad_size'] = self.runner.graph_pad_size
num_input_tokens = self.runner.graph_pad_size
else:
num_input_tokens = num_tokens
attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
query_start_loc=cu_num_tokens,
)
**extra_builder_kwargs)
with set_ascend_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
input_ids=input_ids,
positions=target_positions,
previous_hidden_states=target_hidden_states,
)
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if attn_metadata.prefill is not None:
attn_metadata.prefill.query_lens = query_lens.cpu()
attn_metadata.prefill.input_positions = target_positions
attn_metadata.prefill.seq_lens = seq_lens
if not self.runner.torchair_graph_enabled:
# torch mode need to update num_tokens_across_dp
# TODO: adapt enable_dbo later
(num_input_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._get_forward_metadata_across_dp_and_pad(
num_tokens, self.runner.with_prefill, False)
attn_metadata.slot_mapping = target_slot_mapping
else:
# torchair mode can reuse self.runner.num_tokens_across_dp
num_tokens_across_dp = self.runner.num_tokens_across_dp
with_prefill = self.runner.with_prefill
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):
model_kwargs = {}
model_kwargs["attn_metadata"] = attn_metadata
if self.runner.torchair_graph_enabled:
model_kwargs["kv_caches"] = self.runner.kv_caches[-1:]
if is_running_torchair:
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_input_tokens)
hidden_states = torchair_compiled_model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
inputs_embeds=None,
intermediate_tensors=None,
spec_step_idx=0,
**model_kwargs)
else:
hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
previous_hidden_states=self.
hidden_states[:num_input_tokens],
kv_caches=self.runner.kv_caches[-1:])
sample_hidden_states = hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
@@ -159,6 +261,123 @@ class MtpProposer:
process_weights_after_loading(self.model, draft_model_config,
target_device)
@torch.inference_mode()
def dummy_run(self,
num_tokens: int,
with_prefill: bool = False,
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp=None) -> None:
if not self.runner.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self.runner._get_forward_metadata_across_dp_and_pad(
num_tokens, with_prefill, False)
is_running_torchair = self.runner.torchair_graph_enabled and \
not with_prefill
if is_running_torchair:
skip_attn = False
if skip_attn:
attn_metadata = None
else:
attn_metadata = self.runner.attn_metadata_builder.build_torchair_graph_dummy(
num_reqs=num_reqs, num_actual_tokens=1)
input_ids = self.input_ids[:num_tokens]
positions = self.positions[:num_tokens]
previous_hidden_states = self.hidden_states[:num_tokens]
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,
with_prefill=with_prefill,
num_tokens_across_dp=num_tokens_across_dp,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
torch._dynamo.mark_static(positions)
torch._dynamo.mark_static(previous_hidden_states)
torch._dynamo.mark_static(attn_metadata.decode.block_table)
torch._dynamo.mark_static(attn_metadata.decode.input_positions)
if hasattr(attn_metadata.decode, "sin"):
torch._dynamo.mark_static(attn_metadata.decode.sin)
torch._dynamo.mark_static(attn_metadata.decode.cos)
torch._dynamo.mark_static(get_forward_context().mc2_mask)
torch._dynamo.mark_static(attn_metadata.slot_mapping)
torch._dynamo.mark_static(attn_metadata.decode.attn_mask)
torchair_compiled_model = self._get_torchair_lazy_compiled_model(
num_tokens)
torchair_compiled_model(
input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states,
inputs_embeds=None,
intermediate_tensors=None,
attn_metadata=attn_metadata,
kv_caches=self.runner.kv_caches[-1:],
spec_step_idx=0)
else:
self.model(input_ids=input_ids,
positions=positions,
previous_hidden_states=previous_hidden_states)
def _get_torchair_lazy_compiled_model(self, batch_size: int):
if batch_size < 0 or batch_size > self.runner.torchair_graph_batch_sizes[
-1]:
raise ValueError(
f"Bad graph batch size:{batch_size}! max_graph_batch_sizes:{self.runner.torchair_graph_batch_sizes[-1]}"
)
compiled_model = self.torchair_compiled_models.get(
batch_size
) if self.runner.use_cached_npu_graph else self.torchair_compiled_model
if compiled_model:
return compiled_model
patch_for_hcom()
config = torchair.CompilerConfig()
config.experimental_config.frozen_parameter = True
config.experimental_config.tiling_schedule_optimize = True
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.runner.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
# compilation cache caused by dynamo retracing
forward_proxy_name = f"{self.model.__class__.__name__}_forward_with_batch_size_{batch_size}"
forward_fn = self.model.forward
code = forward_fn.__code__
# Mark code object with a new proxy name
modified_code = code.replace(co_name=forward_proxy_name, )
modified_func = types.FunctionType(modified_code,
forward_fn.__globals__,
name=forward_proxy_name,
argdefs=forward_fn.__defaults__)
self.model.__dict__[forward_proxy_name] = modified_func.__get__(
self.model, nn.Module)
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
fullgraph=envs_vllm.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
config=config,
ge_cache=False)
return self.torchair_compiled_models[batch_size]
# TODO Using torch instead of triton may result in poor performance
def prepare_input_kernel(out_ptr: torch.Tensor, cu_query_lens: torch.Tensor,