From ab5d110fcc35ca11330977450141b1d7176f21e7 Mon Sep 17 00:00:00 2001 From: fems14 <74094523+fems14@users.noreply.github.com> Date: Sat, 14 Jun 2025 22:31:16 +0800 Subject: [PATCH] vllm-ascend support chunked prefill (#1172) ### What this PR does / why we need it? vllm-ascend support chunked prefill for MLA --------- Signed-off-by: fems14 <1804143737@qq.com> --- docs/source/user_guide/additional_config.md | 1 + tests/singlecard/test_chunked.py | 74 ++++++ vllm_ascend/ascend_config.py | 2 + vllm_ascend/attention/mla_v1.py | 235 ++++++++++++++++++-- vllm_ascend/worker/model_runner_v1.py | 11 +- 5 files changed, 303 insertions(+), 20 deletions(-) create mode 100644 tests/singlecard/test_chunked.py diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md index 778938a..d4756ef 100644 --- a/docs/source/user_guide/additional_config.md +++ b/docs/source/user_guide/additional_config.md @@ -31,6 +31,7 @@ The following table lists the additional configuration options available in vLLM | `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. | | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. | | `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. | +| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. | The details of each config option are as follows: diff --git a/tests/singlecard/test_chunked.py b/tests/singlecard/test_chunked.py new file mode 100644 index 0000000..2240b88 --- /dev/null +++ b/tests/singlecard/test_chunked.py @@ -0,0 +1,74 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Compare the outputs of vLLM running chunked prefill for MLA against the
+ascend-scheduler (non-chunked) baseline.
+
+Run `pytest tests/singlecard/test_chunked.py`.
+"""
+
+import os
+
+import pytest
+import torch
+from vllm import LLM, SamplingParams
+
+MODELS = ["deepseek-ai/DeepSeek-V2-Lite"]
+
+
+@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
+                    reason="new chunked only support on v1")
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_tokens", [1])
+def test_models(
+    model: str,
+    max_tokens: int,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    return
+    with monkeypatch.context() as m:
+        prompts = "The president of the United States is"
+
+        m.setenv("VLLM_USE_V1", "1")
+
+        sampling_params = SamplingParams(
+            max_tokens=max_tokens,
+            temperature=0.0,
+        )
+
+        vllm_model = LLM(model,
+                         long_prefill_token_threshold=4,
+                         enforce_eager=True)
+        output_chunked = vllm_model.generate(prompts, sampling_params)
+        logprobs_chunked = output_chunked.outputs[0].logprobs
+        del vllm_model
+        torch.npu.empty_cache()
+
+        vllm_model = LLM(model,
+                         enforce_eager=True,
+                         additional_config={
+                             'ascend_scheduler_config': {
+                                 'enabled': True
+                             },
+                         })
+        output = vllm_model.generate(prompts, sampling_params)
+        logprobs = output.outputs[0].logprobs
+        del vllm_model
+        torch.npu.empty_cache()
+
+        logprobs_similarity = torch.cosine_similarity(
+            logprobs_chunked.flatten(), logprobs.flatten(), dim=0)
+        assert logprobs_similarity > 0.95
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index 2d34283..defa7fd 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -39,6 +39,8 @@ class AscendConfig: self.expert_tensor_parallel_size = int( additional_config.get("expert_tensor_parallel_size", 0)) self.expert_map_path = additional_config.get("expert_map_path", None) + self.chunked_prefill_for_mla = additional_config.get( + "chunked_prefill_for_mla", False) class TorchairGraphConfig: diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index e07d59a..43cb71c 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -11,6 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID from vllm.config import get_current_vllm_config from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) +from vllm.utils import cdiv, round_down from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV @@ -69,6 +70,18 @@ class AscendMLABackend(AttentionBackend): @dataclass class AscendMLAPrefillMetadata: """ Prefill Specific Metadata for Ascend""" + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + chunk_seq_lens: torch.Tensor + attn_mask: torch.Tensor query_lens: list[int] seq_lens: list[int] @@ -78,6 +91,7 @@ class AscendMLAPrefillMetadata: block_table: torch.Tensor max_query_len: int max_seq_lens: int + chunked_context: Optional[ChunkedContextMetadata] = None @dataclass @@ -172,7 +186,32 @@ class AscendMLAMetadataBuilder: if metadata_cls is not None else AscendMLAMetadata # type: ignore self.runner = runner scheduler_config = runner.scheduler_config - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled + model_config = runner.model_config + self.block_size = runner.block_size + self.chunked_prefill_enabled = runner.chunked_prefill_enabled + if self.chunked_prefill_enabled: + 
self.chunked_prefill_workspace_size = min( + # Max sure there is enough for 8 full length request or at least + # 4 pages of cache per request + max(8 * model_config.max_model_len, + 4 * scheduler_config.max_num_seqs * self.block_size), + # For long-context models try not to over-allocate limiting + # kv-cache space, limiting it to 64k tokens, + # which would result in the workspace being: + # 2*(576)*(64*1024) = 144mb + # (assuming 576 MLA head dim, and fp16) + # which would result in up-projected context being + # 2*(192*128)*(64*1024) = 3gb + # (assuming 192 QK head dim, 128 heads, and fp16) + 128 * 1024) + assert self.chunked_prefill_workspace_size >= \ + scheduler_config.max_num_seqs * self.block_size + self.chunked_prefill_workspace = torch.empty( + (self.chunked_prefill_workspace_size, + model_config.get_head_size()), + dtype=model_config.dtype, + device=runner.device, + ) ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled @@ -350,6 +389,7 @@ class AscendMLAMetadataBuilder: query_start_loc = common_attn_metadata.query_start_loc prefill_metadata = None + chunked_context_metadata = None if self._num_prefills > 0: reqs_start = self._num_decodes # prefill_start tokens_start = self._num_decode_tokens @@ -359,6 +399,41 @@ class AscendMLAMetadataBuilder: prefill_query_start_loc = query_start_loc[ reqs_start:] - query_start_loc[reqs_start] + context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[ + reqs_start:num_reqs] + max_context_len_cpu = context_lens_cpu.max().item() + num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() + if self.chunked_prefill_enabled and max_context_len_cpu > 0: + max_context_chunk = (self.chunked_prefill_workspace_size // + num_prefills_with_context_cpu) + max_context_chunk = round_down(max_context_chunk, + self.block_size) + + assert max_context_chunk > 0 + num_chunks = cdiv(max_context_len_cpu, max_context_chunk) + chunk_starts = 
torch.arange(num_chunks, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk + chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + cu_seq_lens_cpu = torch.zeros(num_chunks, + self._num_prefills + 1, + dtype=torch.int32, + pin_memory=True) + torch.cumsum(chunk_seq_lens, + dim=1, + out=cu_seq_lens_cpu[:, 1:], + dtype=torch.int32) + chunked_context_metadata = \ + AscendMLAPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), + starts=chunk_starts.to(device, non_blocking=True), + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + chunk_seq_lens=chunk_seq_lens, + workspace=self.chunked_prefill_workspace, + ) + prefill_metadata = AscendMLAPrefillMetadata( attn_mask=self.runner.attn_mask, query_lens=query_lens[tokens_start:], @@ -369,6 +444,7 @@ class AscendMLAMetadataBuilder: max_query_len=max_query_len, max_seq_lens=max_seq_lens, query_start_loc=prefill_query_start_loc, + chunked_context=chunked_context_metadata, ) decode_metadata = None @@ -575,6 +651,83 @@ class AscendMLAImpl(MLAAttentionImpl): self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) + def _compute_prefill_context( + self, + query: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + rope_dim: int, + attn_metadata: AscendMLAMetadata, + prefix_output: torch.Tensor, + prefix_lse: torch.Tensor, + ): + prefill_metadata = attn_metadata.prefill + if prefill_metadata is None or prefill_metadata.chunked_context is None: + return prefix_output, prefix_lse + + iters = len(prefill_metadata.chunked_context.seq_tot) + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + + seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) + latent_kv_dim = 
kv_c_and_k_pe_cache.size(3) - rope_dim + cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim] + cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:] + for i in range(iters): + toks = prefill_metadata.chunked_context.seq_tot[i] + + seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i] + seq_len = torch.stack([seq_len1, seq_len2]) + kv_c_normed = torch.empty(toks, + kv_c_and_k_pe_cache.size(2), + latent_kv_dim, + dtype=query.dtype, + device=query.device) + k_pe = torch.empty(toks, + kv_c_and_k_pe_cache.size(2), + rope_dim, + dtype=query.dtype, + device=query.device) + + torch_npu.atb.npu_paged_cache_load( + cache_kv_c, + cache_k_pe, + prefill_metadata.block_table, + seq_len2.to(query.device), + seq_starts=prefill_metadata.chunked_context.starts[i], + key=kv_c_normed, + value=k_pe, + ) + + kv_c_normed = kv_c_normed.squeeze() + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \ + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv_nope\ + .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) + mask = torch.triu( + torch.ones(512, 512, device=query.device, dtype=query.dtype), + 1) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=v, + mask=mask, + seqlen=seq_len, + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=prefix_output, + prev_lse=prefix_lse, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + mask_type="no_mask", + input_layout="type_bsnd", + calc_type="calc_type_default", + output=prefix_output, + softmax_lse=prefix_lse) + return prefix_output, prefix_lse + def _forward_prefill( self, query: torch.Tensor, @@ -586,19 +739,29 @@ class AscendMLAImpl(MLAAttentionImpl): assert attn_metadata.prefill is not None num_tokens = query.size(0) - attn_output = None + attn_output = torch.empty(num_tokens, + self.num_heads, + self.v_head_dim, + dtype=query.dtype, + device=query.device) + k_nope, value 
= self.kv_b_proj(kv_c_normed)[0].view( + -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) # Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache + ascend_config = get_ascend_config() + if attn_metadata.attn_state in [ AscendAttentionState.ChunkedPrefill, AscendAttentionState.SpecDecoding - ]: - attn_output = torch.empty(num_tokens, - self.num_heads * self.v_head_dim, - dtype=query.dtype, - device=query.device) + ] and not ascend_config.chunked_prefill_for_mla: + attn_output_torch = torch.empty(num_tokens, + self.num_heads * self.v_head_dim, + dtype=query.dtype, + device=query.device) # current requests is chunked in prefill, disable flash attention with chunked prefill vanilla_chunked_prefill_mla( - output=attn_output, + output=attn_output_torch, query=query, kv_cache=kv_c_and_k_pe_cache, block_tables=attn_metadata.prefill.block_table, @@ -613,18 +776,47 @@ class AscendMLAImpl(MLAAttentionImpl): scale=self.scale, alibi_slopes=None, causal=True) + elif attn_metadata.attn_state in [ + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.SpecDecoding + ]: + attn_lse = torch.empty(self.num_heads, + num_tokens, + dtype=torch.float32, + device=query.device) + q_pe = query[..., self.qk_nope_head_dim:] + q_nope = query[..., :self.qk_nope_head_dim] + mask = torch.triu( + torch.ones(512, 512, device=query.device, dtype=query.dtype), + 1) # 512: mask only support 512 + if attn_metadata.num_prefills > 1: + mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1, + 1) + torch_npu.atb.npu_ring_mla( + q_nope=q_nope, + q_rope=q_pe, + k_nope=k_nope, + k_rope=k_pe, + value=value, + mask=mask, + seqlen=torch.tensor(attn_metadata.prefill.query_lens, + dtype=torch.int32), + head_num=self.num_heads, + kv_head_num=self.num_heads, + pre_out=None, + prev_lse=None, + qk_scale=self.scale, + kernel_type="kernel_type_high_precision", + 
mask_type="mask_type_triu", + input_layout="type_bsnd", + calc_type="calc_type_first_ring", + output=attn_output, + softmax_lse=attn_lse) + attn_output, attn_lse = self._compute_prefill_context( \ + query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse) + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: - attn_output = torch.empty(num_tokens, - self.num_heads, - self.v_head_dim, - dtype=query.dtype, - device=query.device) - k_nope, value = self.kv_b_proj(kv_c_normed)[0].view( - -1, self.num_heads, - self.qk_nope_head_dim + self.v_head_dim).split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), - dim=-1) + key = torch.cat((k_nope, k_pe), dim=-1) torch_npu._npu_flash_attention( query=query, key=key, @@ -642,6 +834,11 @@ class AscendMLAImpl(MLAAttentionImpl): ) attn_output = attn_output.reshape( [num_tokens, self.num_heads * self.v_head_dim]) + if attn_metadata.attn_state in [ + AscendAttentionState.ChunkedPrefill, + AscendAttentionState.SpecDecoding + ] and not ascend_config.chunked_prefill_for_mla: + attn_output = attn_output_torch current_ms_metadata = get_multistream_comm_context() if current_ms_metadata is None: diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c358793..6d226da 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -134,7 +134,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.lora_config = vllm_config.lora_config self.scheduler_config = vllm_config.scheduler_config self.speculative_config = vllm_config.speculative_config - self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled + ascend_config = get_ascend_config() + if ascend_config.ascend_scheduler_config.enabled: + self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled + else: + self.chunked_prefill_enabled = True self.device = 
device self.is_multimodal_model = self.model_config.is_multimodal_model @@ -1260,6 +1264,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. + discard_sampled_tokens_req_indices = [] for i, req_id in enumerate(self.input_batch.req_ids): req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + @@ -1270,6 +1275,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): generator = self.input_batch.generators.get(i) if generator is not None: generator.set_offset(generator.get_offset() - 4) + discard_sampled_tokens_req_indices.append(i) # NOTE: NPU -> CPU Sync happens here. # Move as many CPU operations as possible before this sync point. @@ -1290,6 +1296,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.vocab_size, ) + for i in discard_sampled_tokens_req_indices: + valid_sampled_token_ids[i].clear() + spec_token_ids = self._get_spec_token_ids( valid_sampled_token_ids, sampling_metadata,