vllm-ascend support chunked prefill (#1172)
### What this PR does / why we need it? vllm-ascend support chunked prefill for MLA --------- Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -31,6 +31,7 @@ The following table lists the additional configuration options available in vLLM
|
||||
| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. |
|
||||
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. |
|
||||
| `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
||||
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
|
||||
|
||||
The details of each config option are as follows:
|
||||
|
||||
|
||||
74
tests/singlecard/test_chunked.py
Normal file
74
tests/singlecard/test_chunked.py
Normal file
@@ -0,0 +1,74 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Compare the outputs of vLLM with and without aclgraph.
|
||||
|
||||
Run `pytest tests/compile/test_aclgraph.py`.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
MODELS = ["deepseek-ai/DeepSeek-V2-Lite"]
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
|
||||
reason="new chunked only support on v1")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_tokens", [1])
|
||||
def test_models(
|
||||
model: str,
|
||||
max_tokens: int,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
return
|
||||
with monkeypatch.context() as m:
|
||||
prompts = "The president of the United States is"
|
||||
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=max_tokens,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
vllm_model = LLM(model,
|
||||
long_prefill_token_threshold=4,
|
||||
enforce_eager=True)
|
||||
output_chunked = vllm_model.generate(prompts, sampling_params)
|
||||
logprobs_chunked = output_chunked.outputs[0].logprobs
|
||||
del vllm_model
|
||||
torch.npu.empty_cache()
|
||||
|
||||
vllm_model = LLM(model,
|
||||
enforce_eager=True,
|
||||
additional_config={
|
||||
'ascend_scheduler_config': {
|
||||
'enabled': True
|
||||
},
|
||||
})
|
||||
output = vllm_model.generate(prompts, sampling_params)
|
||||
logprobs = output.outputs[0].logprobs
|
||||
del vllm_model
|
||||
torch.npu.empty_cache()
|
||||
|
||||
logprobs_similarity = torch.cosine_similarity(
|
||||
logprobs_chunked.flatten(), logprobs.flatten(), dim=0)
|
||||
assert logprobs_similarity > 0.95
|
||||
@@ -39,6 +39,8 @@ class AscendConfig:
|
||||
self.expert_tensor_parallel_size = int(
|
||||
additional_config.get("expert_tensor_parallel_size", 0))
|
||||
self.expert_map_path = additional_config.get("expert_map_path", None)
|
||||
self.chunked_prefill_for_mla = additional_config.get(
|
||||
"chunked_prefill_for_mla", False)
|
||||
|
||||
|
||||
class TorchairGraphConfig:
|
||||
|
||||
@@ -11,6 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
||||
@@ -69,6 +70,18 @@ class AscendMLABackend(AttentionBackend):
|
||||
@dataclass
|
||||
class AscendMLAPrefillMetadata:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
workspace: torch.Tensor
|
||||
chunk_seq_lens: torch.Tensor
|
||||
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: list[int]
|
||||
seq_lens: list[int]
|
||||
@@ -78,6 +91,7 @@ class AscendMLAPrefillMetadata:
|
||||
block_table: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_lens: int
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -172,7 +186,32 @@ class AscendMLAMetadataBuilder:
|
||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
model_config = runner.model_config
|
||||
self.block_size = runner.block_size
|
||||
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(8 * model_config.max_model_len,
|
||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||
# For long-context models try not to over-allocate limiting
|
||||
# kv-cache space, limiting it to 64k tokens,
|
||||
# which would result in the workspace being:
|
||||
# 2*(576)*(64*1024) = 144mb
|
||||
# (assuming 576 MLA head dim, and fp16)
|
||||
# which would result in up-projected context being
|
||||
# 2*(192*128)*(64*1024) = 3gb
|
||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * self.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
model_config.get_head_size()),
|
||||
dtype=model_config.dtype,
|
||||
device=runner.device,
|
||||
)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
@@ -350,6 +389,7 @@ class AscendMLAMetadataBuilder:
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
|
||||
prefill_metadata = None
|
||||
chunked_context_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
tokens_start = self._num_decode_tokens
|
||||
@@ -359,6 +399,41 @@ class AscendMLAMetadataBuilder:
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
|
||||
reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||
num_prefills_with_context_cpu)
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.block_size)
|
||||
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
self._num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
chunked_context_metadata = \
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
prefill_metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=self.runner.attn_mask,
|
||||
query_lens=query_lens[tokens_start:],
|
||||
@@ -369,6 +444,7 @@ class AscendMLAMetadataBuilder:
|
||||
max_query_len=max_query_len,
|
||||
max_seq_lens=max_seq_lens,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
chunked_context=chunked_context_metadata,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
@@ -575,6 +651,83 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
rope_dim: int,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
q_pe = query[..., self.qk_nope_head_dim:]
|
||||
q_nope = query[..., :self.qk_nope_head_dim]
|
||||
|
||||
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
|
||||
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
|
||||
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
|
||||
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
||||
seq_len = torch.stack([seq_len1, seq_len2])
|
||||
kv_c_normed = torch.empty(toks,
|
||||
kv_c_and_k_pe_cache.size(2),
|
||||
latent_kv_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_pe = torch.empty(toks,
|
||||
kv_c_and_k_pe_cache.size(2),
|
||||
rope_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
seq_len2.to(query.device),
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
|
||||
kv_c_normed = kv_c_normed.squeeze()
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
mask = torch.triu(
|
||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||
1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -586,19 +739,29 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
assert attn_metadata.prefill is not None
|
||||
|
||||
num_tokens = query.size(0)
|
||||
attn_output = None
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
] and not ascend_config.chunked_prefill_for_mla:
|
||||
attn_output_torch = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output,
|
||||
output=attn_output_torch,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
@@ -613,18 +776,47 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
elif attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
q_pe = query[..., self.qk_nope_head_dim:]
|
||||
q_nope = query[..., :self.qk_nope_head_dim]
|
||||
mask = torch.triu(
|
||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||
1) # 512: mask only support 512
|
||||
if attn_metadata.num_prefills > 1:
|
||||
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
||||
1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
||||
dim=-1)
|
||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
@@ -642,6 +834,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
)
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding
|
||||
] and not ascend_config.chunked_prefill_for_mla:
|
||||
attn_output = attn_output_torch
|
||||
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
|
||||
@@ -134,7 +134,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.scheduler_config = vllm_config.scheduler_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.ascend_scheduler_config.enabled:
|
||||
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
||||
else:
|
||||
self.chunked_prefill_enabled = True
|
||||
self.device = device
|
||||
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
@@ -1260,6 +1264,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||
# the requests one by one. Optimize.
|
||||
discard_sampled_tokens_req_indices = []
|
||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||
req_state = self.requests[req_id]
|
||||
seq_len = (req_state.num_computed_tokens +
|
||||
@@ -1270,6 +1275,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
generator = self.input_batch.generators.get(i)
|
||||
if generator is not None:
|
||||
generator.set_offset(generator.get_offset() - 4)
|
||||
discard_sampled_tokens_req_indices.append(i)
|
||||
|
||||
# NOTE: NPU -> CPU Sync happens here.
|
||||
# Move as many CPU operations as possible before this sync point.
|
||||
@@ -1290,6 +1296,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.vocab_size,
|
||||
)
|
||||
|
||||
for i in discard_sampled_tokens_req_indices:
|
||||
valid_sampled_token_ids[i].clear()
|
||||
|
||||
spec_token_ids = self._get_spec_token_ids(
|
||||
valid_sampled_token_ids,
|
||||
sampling_metadata,
|
||||
|
||||
Reference in New Issue
Block a user