vllm-ascend support chunked prefill (#1172)
### What this PR does / why we need it? vllm-ascend support chunked prefill for MLA --------- Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -31,6 +31,7 @@ The following table lists the additional configuration options available in vLLM
|
|||||||
| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. |
|
| `expert_tensor_parallel_size` | str | `0` | Expert tensor parallel size the model to use. |
|
||||||
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. |
|
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf case. |
|
||||||
| `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
| `expert_map_path` | str | None | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
|
||||||
|
| `chunked_prefill_for_mla` | bool | `False` | Whether to enable the fused operator-like chunked_prefill. |
|
||||||
|
|
||||||
The details of each config option are as follows:
|
The details of each config option are as follows:
|
||||||
|
|
||||||
|
|||||||
74
tests/singlecard/test_chunked.py
Normal file
74
tests/singlecard/test_chunked.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
"""
|
||||||
|
Compare the outputs of vLLM with and without aclgraph.
|
||||||
|
|
||||||
|
Run `pytest tests/compile/test_aclgraph.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
MODELS = ["deepseek-ai/DeepSeek-V2-Lite"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
|
||||||
|
reason="new chunked only support on v1")
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("max_tokens", [1])
|
||||||
|
def test_models(
|
||||||
|
model: str,
|
||||||
|
max_tokens: int,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
return
|
||||||
|
with monkeypatch.context() as m:
|
||||||
|
prompts = "The president of the United States is"
|
||||||
|
|
||||||
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_model = LLM(model,
|
||||||
|
long_prefill_token_threshold=4,
|
||||||
|
enforce_eager=True)
|
||||||
|
output_chunked = vllm_model.generate(prompts, sampling_params)
|
||||||
|
logprobs_chunked = output_chunked.outputs[0].logprobs
|
||||||
|
del vllm_model
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
|
vllm_model = LLM(model,
|
||||||
|
enforce_eager=True,
|
||||||
|
additional_config={
|
||||||
|
'ascend_scheduler_config': {
|
||||||
|
'enabled': True
|
||||||
|
},
|
||||||
|
})
|
||||||
|
output = vllm_model.generate(prompts, sampling_params)
|
||||||
|
logprobs = output.outputs[0].logprobs
|
||||||
|
del vllm_model
|
||||||
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
|
logprobs_similarity = torch.cosine_similarity(
|
||||||
|
logprobs_chunked.flatten(), logprobs.flatten(), dim=0)
|
||||||
|
assert logprobs_similarity > 0.95
|
||||||
@@ -39,6 +39,8 @@ class AscendConfig:
|
|||||||
self.expert_tensor_parallel_size = int(
|
self.expert_tensor_parallel_size = int(
|
||||||
additional_config.get("expert_tensor_parallel_size", 0))
|
additional_config.get("expert_tensor_parallel_size", 0))
|
||||||
self.expert_map_path = additional_config.get("expert_map_path", None)
|
self.expert_map_path = additional_config.get("expert_map_path", None)
|
||||||
|
self.chunked_prefill_for_mla = additional_config.get(
|
||||||
|
"chunked_prefill_for_mla", False)
|
||||||
|
|
||||||
|
|
||||||
class TorchairGraphConfig:
|
class TorchairGraphConfig:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID
|
|||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
|
from vllm.utils import cdiv, round_down
|
||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
||||||
@@ -69,6 +70,18 @@ class AscendMLABackend(AttentionBackend):
|
|||||||
@dataclass
|
@dataclass
|
||||||
class AscendMLAPrefillMetadata:
|
class AscendMLAPrefillMetadata:
|
||||||
""" Prefill Specific Metadata for Ascend"""
|
""" Prefill Specific Metadata for Ascend"""
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkedContextMetadata:
|
||||||
|
# New for MLA (compared to FlashAttention)
|
||||||
|
# For handling chunked prefill
|
||||||
|
cu_seq_lens: torch.Tensor
|
||||||
|
starts: torch.Tensor
|
||||||
|
seq_tot: list[int]
|
||||||
|
max_seq_lens: list[int]
|
||||||
|
workspace: torch.Tensor
|
||||||
|
chunk_seq_lens: torch.Tensor
|
||||||
|
|
||||||
attn_mask: torch.Tensor
|
attn_mask: torch.Tensor
|
||||||
query_lens: list[int]
|
query_lens: list[int]
|
||||||
seq_lens: list[int]
|
seq_lens: list[int]
|
||||||
@@ -78,6 +91,7 @@ class AscendMLAPrefillMetadata:
|
|||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
max_query_len: int
|
max_query_len: int
|
||||||
max_seq_lens: int
|
max_seq_lens: int
|
||||||
|
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -172,7 +186,32 @@ class AscendMLAMetadataBuilder:
|
|||||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||||
self.runner = runner
|
self.runner = runner
|
||||||
scheduler_config = runner.scheduler_config
|
scheduler_config = runner.scheduler_config
|
||||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
model_config = runner.model_config
|
||||||
|
self.block_size = runner.block_size
|
||||||
|
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
|
||||||
|
if self.chunked_prefill_enabled:
|
||||||
|
self.chunked_prefill_workspace_size = min(
|
||||||
|
# Max sure there is enough for 8 full length request or at least
|
||||||
|
# 4 pages of cache per request
|
||||||
|
max(8 * model_config.max_model_len,
|
||||||
|
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||||
|
# For long-context models try not to over-allocate limiting
|
||||||
|
# kv-cache space, limiting it to 64k tokens,
|
||||||
|
# which would result in the workspace being:
|
||||||
|
# 2*(576)*(64*1024) = 144mb
|
||||||
|
# (assuming 576 MLA head dim, and fp16)
|
||||||
|
# which would result in up-projected context being
|
||||||
|
# 2*(192*128)*(64*1024) = 3gb
|
||||||
|
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||||
|
128 * 1024)
|
||||||
|
assert self.chunked_prefill_workspace_size >= \
|
||||||
|
scheduler_config.max_num_seqs * self.block_size
|
||||||
|
self.chunked_prefill_workspace = torch.empty(
|
||||||
|
(self.chunked_prefill_workspace_size,
|
||||||
|
model_config.get_head_size()),
|
||||||
|
dtype=model_config.dtype,
|
||||||
|
device=runner.device,
|
||||||
|
)
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
|
|
||||||
@@ -350,6 +389,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
|
|
||||||
prefill_metadata = None
|
prefill_metadata = None
|
||||||
|
chunked_context_metadata = None
|
||||||
if self._num_prefills > 0:
|
if self._num_prefills > 0:
|
||||||
reqs_start = self._num_decodes # prefill_start
|
reqs_start = self._num_decodes # prefill_start
|
||||||
tokens_start = self._num_decode_tokens
|
tokens_start = self._num_decode_tokens
|
||||||
@@ -359,6 +399,41 @@ class AscendMLAMetadataBuilder:
|
|||||||
prefill_query_start_loc = query_start_loc[
|
prefill_query_start_loc = query_start_loc[
|
||||||
reqs_start:] - query_start_loc[reqs_start]
|
reqs_start:] - query_start_loc[reqs_start]
|
||||||
|
|
||||||
|
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
|
||||||
|
reqs_start:num_reqs]
|
||||||
|
max_context_len_cpu = context_lens_cpu.max().item()
|
||||||
|
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||||
|
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||||
|
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||||
|
num_prefills_with_context_cpu)
|
||||||
|
max_context_chunk = round_down(max_context_chunk,
|
||||||
|
self.block_size)
|
||||||
|
|
||||||
|
assert max_context_chunk > 0
|
||||||
|
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||||
|
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
||||||
|
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
|
||||||
|
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||||
|
chunk_starts + max_context_chunk)
|
||||||
|
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||||
|
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||||
|
self._num_prefills + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
pin_memory=True)
|
||||||
|
torch.cumsum(chunk_seq_lens,
|
||||||
|
dim=1,
|
||||||
|
out=cu_seq_lens_cpu[:, 1:],
|
||||||
|
dtype=torch.int32)
|
||||||
|
chunked_context_metadata = \
|
||||||
|
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||||
|
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||||
|
starts=chunk_starts.to(device, non_blocking=True),
|
||||||
|
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||||
|
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||||
|
chunk_seq_lens=chunk_seq_lens,
|
||||||
|
workspace=self.chunked_prefill_workspace,
|
||||||
|
)
|
||||||
|
|
||||||
prefill_metadata = AscendMLAPrefillMetadata(
|
prefill_metadata = AscendMLAPrefillMetadata(
|
||||||
attn_mask=self.runner.attn_mask,
|
attn_mask=self.runner.attn_mask,
|
||||||
query_lens=query_lens[tokens_start:],
|
query_lens=query_lens[tokens_start:],
|
||||||
@@ -369,6 +444,7 @@ class AscendMLAMetadataBuilder:
|
|||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
max_seq_lens=max_seq_lens,
|
max_seq_lens=max_seq_lens,
|
||||||
query_start_loc=prefill_query_start_loc,
|
query_start_loc=prefill_query_start_loc,
|
||||||
|
chunked_context=chunked_context_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
decode_metadata = None
|
decode_metadata = None
|
||||||
@@ -575,6 +651,83 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||||
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||||
|
|
||||||
|
def _compute_prefill_context(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
|
rope_dim: int,
|
||||||
|
attn_metadata: AscendMLAMetadata,
|
||||||
|
prefix_output: torch.Tensor,
|
||||||
|
prefix_lse: torch.Tensor,
|
||||||
|
):
|
||||||
|
prefill_metadata = attn_metadata.prefill
|
||||||
|
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||||
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
|
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||||
|
q_pe = query[..., self.qk_nope_head_dim:]
|
||||||
|
q_nope = query[..., :self.qk_nope_head_dim]
|
||||||
|
|
||||||
|
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
||||||
|
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
|
||||||
|
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
|
||||||
|
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
|
||||||
|
for i in range(iters):
|
||||||
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||||
|
|
||||||
|
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
||||||
|
seq_len = torch.stack([seq_len1, seq_len2])
|
||||||
|
kv_c_normed = torch.empty(toks,
|
||||||
|
kv_c_and_k_pe_cache.size(2),
|
||||||
|
latent_kv_dim,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
k_pe = torch.empty(toks,
|
||||||
|
kv_c_and_k_pe_cache.size(2),
|
||||||
|
rope_dim,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
|
||||||
|
torch_npu.atb.npu_paged_cache_load(
|
||||||
|
cache_kv_c,
|
||||||
|
cache_k_pe,
|
||||||
|
prefill_metadata.block_table,
|
||||||
|
seq_len2.to(query.device),
|
||||||
|
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||||
|
key=kv_c_normed,
|
||||||
|
value=k_pe,
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_c_normed = kv_c_normed.squeeze()
|
||||||
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||||
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
k_nope, v = kv_nope\
|
||||||
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||||
|
mask = torch.triu(
|
||||||
|
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||||
|
1)
|
||||||
|
torch_npu.atb.npu_ring_mla(
|
||||||
|
q_nope=q_nope,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_nope=k_nope,
|
||||||
|
k_rope=k_pe,
|
||||||
|
value=v,
|
||||||
|
mask=mask,
|
||||||
|
seqlen=seq_len,
|
||||||
|
head_num=self.num_heads,
|
||||||
|
kv_head_num=self.num_heads,
|
||||||
|
pre_out=prefix_output,
|
||||||
|
prev_lse=prefix_lse,
|
||||||
|
qk_scale=self.scale,
|
||||||
|
kernel_type="kernel_type_high_precision",
|
||||||
|
mask_type="no_mask",
|
||||||
|
input_layout="type_bsnd",
|
||||||
|
calc_type="calc_type_default",
|
||||||
|
output=prefix_output,
|
||||||
|
softmax_lse=prefix_lse)
|
||||||
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
def _forward_prefill(
|
def _forward_prefill(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -586,19 +739,29 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
assert attn_metadata.prefill is not None
|
assert attn_metadata.prefill is not None
|
||||||
|
|
||||||
num_tokens = query.size(0)
|
num_tokens = query.size(0)
|
||||||
attn_output = None
|
attn_output = torch.empty(num_tokens,
|
||||||
|
self.num_heads,
|
||||||
|
self.v_head_dim,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||||
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
||||||
|
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
|
||||||
if attn_metadata.attn_state in [
|
if attn_metadata.attn_state in [
|
||||||
AscendAttentionState.ChunkedPrefill,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
AscendAttentionState.SpecDecoding
|
AscendAttentionState.SpecDecoding
|
||||||
]:
|
] and not ascend_config.chunked_prefill_for_mla:
|
||||||
attn_output = torch.empty(num_tokens,
|
attn_output_torch = torch.empty(num_tokens,
|
||||||
self.num_heads * self.v_head_dim,
|
self.num_heads * self.v_head_dim,
|
||||||
dtype=query.dtype,
|
dtype=query.dtype,
|
||||||
device=query.device)
|
device=query.device)
|
||||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||||
vanilla_chunked_prefill_mla(
|
vanilla_chunked_prefill_mla(
|
||||||
output=attn_output,
|
output=attn_output_torch,
|
||||||
query=query,
|
query=query,
|
||||||
kv_cache=kv_c_and_k_pe_cache,
|
kv_cache=kv_c_and_k_pe_cache,
|
||||||
block_tables=attn_metadata.prefill.block_table,
|
block_tables=attn_metadata.prefill.block_table,
|
||||||
@@ -613,18 +776,47 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
scale=self.scale,
|
scale=self.scale,
|
||||||
alibi_slopes=None,
|
alibi_slopes=None,
|
||||||
causal=True)
|
causal=True)
|
||||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
elif attn_metadata.attn_state in [
|
||||||
attn_output = torch.empty(num_tokens,
|
AscendAttentionState.ChunkedPrefill,
|
||||||
self.num_heads,
|
AscendAttentionState.SpecDecoding
|
||||||
self.v_head_dim,
|
]:
|
||||||
dtype=query.dtype,
|
attn_lse = torch.empty(self.num_heads,
|
||||||
|
num_tokens,
|
||||||
|
dtype=torch.float32,
|
||||||
device=query.device)
|
device=query.device)
|
||||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
q_pe = query[..., self.qk_nope_head_dim:]
|
||||||
-1, self.num_heads,
|
q_nope = query[..., :self.qk_nope_head_dim]
|
||||||
self.qk_nope_head_dim + self.v_head_dim).split(
|
mask = torch.triu(
|
||||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||||
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
1) # 512: mask only support 512
|
||||||
dim=-1)
|
if attn_metadata.num_prefills > 1:
|
||||||
|
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
||||||
|
1)
|
||||||
|
torch_npu.atb.npu_ring_mla(
|
||||||
|
q_nope=q_nope,
|
||||||
|
q_rope=q_pe,
|
||||||
|
k_nope=k_nope,
|
||||||
|
k_rope=k_pe,
|
||||||
|
value=value,
|
||||||
|
mask=mask,
|
||||||
|
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||||
|
dtype=torch.int32),
|
||||||
|
head_num=self.num_heads,
|
||||||
|
kv_head_num=self.num_heads,
|
||||||
|
pre_out=None,
|
||||||
|
prev_lse=None,
|
||||||
|
qk_scale=self.scale,
|
||||||
|
kernel_type="kernel_type_high_precision",
|
||||||
|
mask_type="mask_type_triu",
|
||||||
|
input_layout="type_bsnd",
|
||||||
|
calc_type="calc_type_first_ring",
|
||||||
|
output=attn_output,
|
||||||
|
softmax_lse=attn_lse)
|
||||||
|
attn_output, attn_lse = self._compute_prefill_context( \
|
||||||
|
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||||
|
|
||||||
|
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
|
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||||
torch_npu._npu_flash_attention(
|
torch_npu._npu_flash_attention(
|
||||||
query=query,
|
query=query,
|
||||||
key=key,
|
key=key,
|
||||||
@@ -642,6 +834,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
)
|
)
|
||||||
attn_output = attn_output.reshape(
|
attn_output = attn_output.reshape(
|
||||||
[num_tokens, self.num_heads * self.v_head_dim])
|
[num_tokens, self.num_heads * self.v_head_dim])
|
||||||
|
if attn_metadata.attn_state in [
|
||||||
|
AscendAttentionState.ChunkedPrefill,
|
||||||
|
AscendAttentionState.SpecDecoding
|
||||||
|
] and not ascend_config.chunked_prefill_for_mla:
|
||||||
|
attn_output = attn_output_torch
|
||||||
|
|
||||||
current_ms_metadata = get_multistream_comm_context()
|
current_ms_metadata = get_multistream_comm_context()
|
||||||
if current_ms_metadata is None:
|
if current_ms_metadata is None:
|
||||||
|
|||||||
@@ -134,7 +134,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.lora_config = vllm_config.lora_config
|
self.lora_config = vllm_config.lora_config
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
self.speculative_config = vllm_config.speculative_config
|
self.speculative_config = vllm_config.speculative_config
|
||||||
self.chunked_prefill_enabled = vllm_config.scheduler_config.chunked_prefill_enabled
|
ascend_config = get_ascend_config()
|
||||||
|
if ascend_config.ascend_scheduler_config.enabled:
|
||||||
|
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
||||||
|
else:
|
||||||
|
self.chunked_prefill_enabled = True
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||||
@@ -1260,6 +1264,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||||
# the requests one by one. Optimize.
|
# the requests one by one. Optimize.
|
||||||
|
discard_sampled_tokens_req_indices = []
|
||||||
for i, req_id in enumerate(self.input_batch.req_ids):
|
for i, req_id in enumerate(self.input_batch.req_ids):
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
seq_len = (req_state.num_computed_tokens +
|
seq_len = (req_state.num_computed_tokens +
|
||||||
@@ -1270,6 +1275,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
generator = self.input_batch.generators.get(i)
|
generator = self.input_batch.generators.get(i)
|
||||||
if generator is not None:
|
if generator is not None:
|
||||||
generator.set_offset(generator.get_offset() - 4)
|
generator.set_offset(generator.get_offset() - 4)
|
||||||
|
discard_sampled_tokens_req_indices.append(i)
|
||||||
|
|
||||||
# NOTE: NPU -> CPU Sync happens here.
|
# NOTE: NPU -> CPU Sync happens here.
|
||||||
# Move as many CPU operations as possible before this sync point.
|
# Move as many CPU operations as possible before this sync point.
|
||||||
@@ -1290,6 +1296,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.input_batch.vocab_size,
|
self.input_batch.vocab_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for i in discard_sampled_tokens_req_indices:
|
||||||
|
valid_sampled_token_ids[i].clear()
|
||||||
|
|
||||||
spec_token_ids = self._get_spec_token_ids(
|
spec_token_ids = self._get_spec_token_ids(
|
||||||
valid_sampled_token_ids,
|
valid_sampled_token_ids,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
|
|||||||
Reference in New Issue
Block a user