[V1] MTP supports torchair (#2145)
### What this PR does / why we need it?
This PR adds MTP (Multi-Token Prediction) support for:
- [x] V0 Scheduler
- [x] TorchAir
- [x] Single DP
- [x] Multi DP
- [x] Disaggregated PD
Known issues:
- [ ] The V1 Scheduler (chunked prefill) is not supported yet; support will land in the next few weeks.
- [ ] vLLM v0.10.0 does not support Prometheus metrics with `DP > 1` and speculative decoding right now; as a workaround, comment out lines 171-175 in `vllm/vllm/v1/metrics/loggers.py`:
```
if (len(self.engine_indexes) > 1
        and vllm_config.speculative_config is not None):
    raise NotImplementedError("Prometheus metrics with Spec Decoding "
                              "with >1 EngineCore per AsyncLLM is not "
                              "supported yet.")
```
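With the workaround applied, that guard is simply disabled. A minimal sketch of the patched block (the exact line numbers may shift between vLLM builds):
```
# Workaround sketch for vllm/vllm/v1/metrics/loggers.py (around lines 171-175):
# disable the spec-decoding metrics guard by commenting it out.
# if (len(self.engine_indexes) > 1
#         and vllm_config.speculative_config is not None):
#     raise NotImplementedError("Prometheus metrics with Spec Decoding "
#                               "with >1 EngineCore per AsyncLLM is not "
#                               "supported yet.")
```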
Here is an example of starting an online server with torchair enabled:
```
python -m vllm.entrypoints.openai.api_server \
    --model="/weights/DeepSeek-R1_w8a8/" \
    --trust-remote-code \
    --max-model-len 40000 \
    --tensor-parallel-size 4 \
    --data-parallel-size 4 \
    --max-num-seqs 16 \
    --no-enable-prefix-caching \
    --enable-expert-parallel \
    --served-model-name deepseekr1 \
    --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
    --quantization ascend \
    --host 0.0.0.0 \
    --port 1234 \
    --additional-config '{"ascend_scheduler_config":{"enabled":true,"enable_chunked_prefill":false},"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]},"enable_weight_nz_layout":true}' \
    --gpu-memory-utilization 0.9
```
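Once the server is up, a quick sanity check can be done with the OpenAI-compatible client. This is a minimal sketch; it assumes the host, port, and served model name from the command above:
```
from openai import OpenAI

# Point the OpenAI-compatible client at the vLLM server started above.
client = OpenAI(base_url="http://127.0.0.1:1234/v1", api_key="EMPTY")

completion = client.completions.create(
    model="deepseekr1",          # --served-model-name from the server command
    prompt="The capital of France is",
    max_tokens=16,
    temperature=0,
)
print(completion.choices[0].text)
```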
Here is an offline example with torchair enabled:
```
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=16, temperature=0)
# Create an LLM.
llm = LLM(
    model="/home/data/DeepSeek-R1_w8a8/",
    tensor_parallel_size=16,
    max_num_seqs=16,
    gpu_memory_utilization=0.9,
    distributed_executor_backend="mp",
    enable_expert_parallel=True,
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    enforce_eager=False,
    max_model_len=2000,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "graph_batch_sizes": [16],
            "enable_multistream_shared_expert": False,
        },
        "ascend_scheduler_config": {
            "enabled": True
        },
        # "expert_tensor_parallel_size": 16,
    },
)

# Generate texts from the prompts.
# llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
# llm.stop_profile()
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
- vLLM version: v0.10.0
- vLLM main: 302962e806
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
```
@@ -156,7 +156,7 @@ class AscendAttentionTorchairMetadataBuilder:
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
 
         max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        assert max_batch_size >= num_seqs
+        assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
 
         if isinstance(self.runner.graph_block_tables, np.ndarray):
             graph_block_tables = torch.zeros((max_batch_size, max_blocks),
```
```
@@ -259,26 +259,34 @@ class AscendAttentionTorchairMetadataBuilder:
         if use_torchair_graph and self.runner.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
+            num_reqs_pad_size = 0
+            num_token_pad_size = 0
             if graph_pad_size != 0:
+                pad_value = 0
+                num_token_pad_size = graph_pad_size - num_actual_tokens
+                num_reqs_pad_size = (
+                    graph_pad_size // self.runner.decode_token_per_req -
+                    num_reqs)
                 pad_value = 1
                 padded_seq_lens = seq_lens.tolist() + [pad_value
-                                                       ] * graph_pad_size
+                                                       ] * num_reqs_pad_size
+
                 seq_lens = torch.from_numpy(
                     np.array(padded_seq_lens).astype(np.int32))
-                padding = torch.full((graph_pad_size, ),
+                padding = torch.full((num_token_pad_size, ),
                                      PAD_SLOT_ID,
                                      dtype=slot_mapping.dtype,
                                      device=slot_mapping.device)
                 slot_mapping = torch.cat([slot_mapping, padding])
                 block_table_padding = torch.zeros(
-                    (graph_pad_size, ) + block_table.shape[1:],
+                    (num_reqs_pad_size, ) + block_table.shape[1:],
                     dtype=block_table.dtype,
                     device=block_table.device)
                 block_table = torch.cat([block_table, block_table_padding],
                                         dim=0)
                 block_table = self._get_graph_runner_block_tables(
-                    num_seqs + graph_pad_size, block_table)
-                padding_0 = torch.zeros(graph_pad_size,
+                    num_seqs + num_reqs_pad_size, block_table)
+                padding_0 = torch.zeros(num_token_pad_size,
                                         dtype=input_positions.dtype,
                                         device=input_positions.device)
                 input_positions = torch.cat([input_positions, padding_0])
```
```
@@ -93,6 +93,7 @@ class AscendMLADecodeMetadata:
     seq_lens: torch.Tensor
     max_seq_lens: int
     seq_lens_list: list[int]
+    actual_seq_lengths_q: Optional[list[int]] = None
     attn_mask: Optional[torch.Tensor] = None
     sin: torch.Tensor = None
     cos: torch.Tensor = None
```
```
@@ -283,7 +284,7 @@ class AscendMLAMetadataBuilder:
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
 
         max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        assert max_batch_size >= num_seqs
+        assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
 
         if isinstance(self.runner.graph_block_tables, np.ndarray):
             graph_block_tables = torch.zeros((max_batch_size, max_blocks),
```
```
@@ -314,11 +315,13 @@ class AscendMLAMetadataBuilder:
                                   device=device)
         block_table = self._get_graph_runner_block_tables(
             num_reqs, block_table)
-        seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
-        input_positions = torch.zeros(num_reqs,
+        num_tokens = num_reqs * self.runner.decode_token_per_req
+        seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
+        seq_lens_list = [0] * num_reqs
+        input_positions = torch.zeros(num_tokens,
                                       dtype=torch.int32,
                                       device=device).long()
-        slot_mapping = torch.full((num_reqs, ),
+        slot_mapping = torch.full((num_tokens, ),
                                   PAD_SLOT_ID,
                                   dtype=torch.int32,
                                   device=device)
```
```
@@ -326,37 +329,46 @@ class AscendMLAMetadataBuilder:
                                   -1,
                                   dtype=torch.int32,
                                   device=device)
-        sin = torch.ones(num_reqs,
+        sin = torch.ones(num_tokens,
                          1,
                          1,
                          self.rope_dim,
                          dtype=self.runner.dtype,
                          device=device)
-        cos = torch.ones(num_reqs,
+        cos = torch.ones(num_tokens,
                          1,
                          1,
                          self.rope_dim,
                          dtype=self.runner.dtype,
                          device=device)
+        if self.runner.speculative_config is not None and\
+                self.runner.speculative_config.method == 'deepseek_mtp':
+            attn_state = AscendAttentionState.SpecDecoding
+            num_decode_tokens = 2
+        else:
+            attn_state = AscendAttentionState.DecodeOnly
+            num_decode_tokens = 1
         decode_metadata = AscendMLADecodeMetadata(
             input_positions=input_positions,
             block_table=block_table,
             seq_lens=seq_lens,
-            seq_lens_list=seq_lens.tolist(),
+            seq_lens_list=seq_lens_list,
             max_seq_lens=1,
             attn_mask=self.runner.spec_attn_mask,
+            actual_seq_lengths_q=self.runner.actual_seq_lengths_q[:num_reqs],
             sin=sin,
-            cos=cos)
+            cos=cos,
+        )
         return self.metadata_cls(  # type: ignore
             num_input_tokens=num_actual_tokens,
             num_actual_tokens=num_actual_tokens,
             slot_mapping=slot_mapping,
             head_dim=self.runner.model_config.get_head_size(),
             num_decodes=1,
-            num_decode_tokens=1,
+            num_decode_tokens=num_decode_tokens,
             num_prefills=0,
             attn_mask=self.runner.attn_mask,
-            attn_state=AscendAttentionState.DecodeOnly,
+            attn_state=attn_state,
             prefill=None,
             decode=decode_metadata,
             query_start_loc=query_start_loc,
```
```
@@ -473,6 +485,7 @@ class AscendMLAMetadataBuilder:
         decode_metadata = None
         use_torchair_graph = graph_pad_size != -1
         if self._num_decodes > 0:
+            actual_seq_lengths_q = query_start_loc[1:].tolist()
             max_seq_lens = seq_lens[:self._num_decodes].max().item()
             seq_lens = seq_lens[:self._num_decode_tokens]
             input_positions = input_positions[:self._num_decode_tokens]
```
```
@@ -481,33 +494,51 @@ class AscendMLAMetadataBuilder:
                     AscendAttentionState.DecodeOnly,
+                    AscendAttentionState.SpecDecoding
             ]:
                 num_seqs = len(seq_lens)
+                num_reqs_pad_size = 0
+                num_token_pad_size = 0
                 if graph_pad_size != 0:
-                    pad_value = 1
-                    padded_seq_lens = seq_lens.tolist() + [pad_value
-                                                           ] * graph_pad_size
+                    pad_value = 0
+                    num_token_pad_size = graph_pad_size - self._num_decode_tokens
+                    num_reqs_pad_size = (
+                        graph_pad_size // self.runner.decode_token_per_req -
+                        num_reqs)
+                    padded_seq_lens = seq_lens.tolist(
+                    ) + [pad_value] * num_reqs_pad_size
+                else:
+                    padded_seq_lens = seq_lens.tolist()
-
                 seq_lens = torch.from_numpy(
                     np.array(padded_seq_lens).astype(np.int32))
-                padding = torch.full((graph_pad_size, ),
-                                     PAD_SLOT_ID,
-                                     dtype=slot_mapping.dtype,
-                                     device=slot_mapping.device)
-                slot_mapping = torch.cat([slot_mapping, padding])
+                seq_lens_list = padded_seq_lens
+                slot_padding = torch.full((num_token_pad_size, ),
+                                          PAD_SLOT_ID,
+                                          dtype=slot_mapping.dtype,
+                                          device=slot_mapping.device)
+                slot_mapping = torch.cat([slot_mapping, slot_padding])
                 block_table_padding = torch.zeros(
-                    (graph_pad_size, ) + block_table.shape[1:],
+                    (num_reqs_pad_size, ) + block_table.shape[1:],
                     dtype=block_table.dtype,
                     device=block_table.device)
                 block_table = torch.cat([block_table, block_table_padding],
                                         dim=0)
                 block_table = self._get_graph_runner_block_tables(
-                    num_seqs + graph_pad_size, block_table)
-                padding_0 = torch.zeros(graph_pad_size,
-                                        dtype=input_positions.dtype,
-                                        device=input_positions.device)
-                input_positions = torch.cat([input_positions, padding_0])
+                    num_reqs + num_reqs_pad_size, block_table)
+                position_padding = torch.zeros(num_token_pad_size,
+                                               dtype=input_positions.dtype,
+                                               device=input_positions.device)
+                input_positions = torch.cat(
+                    [input_positions, position_padding])
+                actual_seq_lengths_q = query_start_loc[1:].tolist(
+                ) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
+                                                     num_reqs_pad_size]
             else:
                 seq_lens_list = seq_lens.tolist()
+                # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
+                batch_size = slot_mapping.size(0)
+                if actual_seq_lengths_q[-1] != batch_size \
+                        and self.runner.attn_state == AscendAttentionState.SpecDecoding:
+                    actual_seq_lengths_q[-1] = batch_size
 
             cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                 1).unsqueeze(2)
             sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
```
```
@@ -517,9 +548,10 @@ class AscendMLAMetadataBuilder:
                 input_positions=input_positions,
                 block_table=block_table,
                 seq_lens=seq_lens,
-                seq_lens_list=seq_lens.tolist(),
+                seq_lens_list=seq_lens_list,
                 max_seq_lens=max_seq_lens,
                 attn_mask=self.runner.spec_attn_mask,
+                actual_seq_lengths_q=actual_seq_lengths_q,
                 sin=sin,
                 cos=cos)
```
```
@@ -965,31 +997,10 @@ class AscendMLAImpl(MLAAttentionImpl):
         assert decode_meta is not None
         num_tokens = q_nope.size(0)
         if self.running_in_graph:
-            # TorchAir's shape is [bs, num_heads_per_rank, q_seq_len, dim]
-            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                assert num_tokens % self.spec_token_num == 0
-                q_nope = q_nope.view(num_tokens // (self.spec_token_num + 1),
-                                     self.spec_token_num + 1, self.num_heads,
-                                     -1)
-                q_pe = q_pe.view(num_tokens // (self.spec_token_num + 1),
-                                 self.spec_token_num + 1, self.num_heads, -1)
-                if not self.enable_kv_nz:
-                    q_nope = q_nope.transpose(1, 2).contiguous()
-                    q_pe = q_pe.transpose(1, 2).contiguous()
-                sparse_mode = 3
-                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
-            else:
-                if self.enable_kv_nz:
-                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
-                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
-                else:
-                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
-                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
-                sparse_mode = 0
-                spec_attn_mask = None
             # shape of knope/k_pe for npu graph mode should be:
             # [num_blocks, num_kv_heads, block_size, self.kv_lora_rank/self.qk_rope_head_dim]
             block_size = kv_c_and_k_pe_cache[0].shape[1]
+            actual_seq_lengths = None
             if self.enable_kv_nz:
                 k_nope = k_nope.view(-1, self.num_kv_heads,
                                      self.kv_lora_rank // 16, block_size, 16)
```
```
@@ -1003,6 +1014,25 @@ class AscendMLAImpl(MLAAttentionImpl):
                                      self.qk_rope_head_dim)
                 input_layout = "BNSD"
 
+            if attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
+                assert num_tokens % self.spec_token_num == 0
+                input_layout = "TND"
+                # [bs * q_seq_len, num_heads_per_rank, dim]
+                q_nope = q_nope.view(num_tokens, self.num_heads, -1)
+                q_pe = q_pe.view(num_tokens, self.num_heads, -1)
+                sparse_mode = 3
+                spec_attn_mask = attn_metadata.decode.attn_mask  # type:ignore
+                actual_seq_lengths = decode_meta.actual_seq_lengths_q
+            else:
+                if self.enable_kv_nz:
+                    q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1)
+                    q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1)
+                else:
+                    q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1)
+                    q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1)
+                sparse_mode = 0
+                spec_attn_mask = None
+
             attn_output, _ = torch_npu.npu_fused_infer_attention_score(
                 q_nope,
                 k_nope,
```
```
@@ -1020,7 +1050,7 @@ class AscendMLAImpl(MLAAttentionImpl):
                 block_table=decode_meta.block_table,
                 block_size=block_size,
                 actual_seq_lengths_kv=decode_meta.seq_lens_list,
-            )
+                actual_seq_lengths=actual_seq_lengths)
         else:
             # The MLA_PA path will be used as default path in the future, `_npu_paged_attention_mla` will
             # be removed after the torch_npu contains `torch_npu.atb.npu_multi_head_latent_attention` become
```