vllm-ascend support chunked prefill (#1172)

### What this PR does / why we need it?
vllm-ascend support chunked prefill for MLA


---------

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2025-06-14 22:31:16 +08:00
committed by GitHub
parent a3b5af8307
commit ab5d110fcc
5 changed files with 303 additions and 20 deletions

View File

@@ -11,6 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
@@ -69,6 +70,18 @@ class AscendMLABackend(AttentionBackend):
@dataclass
class AscendMLAPrefillMetadata:
""" Prefill Specific Metadata for Ascend"""
@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
attn_mask: torch.Tensor
query_lens: list[int]
seq_lens: list[int]
@@ -78,6 +91,7 @@ class AscendMLAPrefillMetadata:
block_table: torch.Tensor
max_query_len: int
max_seq_lens: int
chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
@@ -172,7 +186,32 @@ class AscendMLAMetadataBuilder:
if metadata_cls is not None else AscendMLAMetadata # type: ignore
self.runner = runner
scheduler_config = runner.scheduler_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
model_config = runner.model_config
self.block_size = runner.block_size
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
model_config.get_head_size()),
dtype=model_config.dtype,
device=runner.device,
)
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -350,6 +389,7 @@ class AscendMLAMetadataBuilder:
query_start_loc = common_attn_metadata.query_start_loc
prefill_metadata = None
chunked_context_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
tokens_start = self._num_decode_tokens
@@ -359,6 +399,41 @@ class AscendMLAMetadataBuilder:
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
max_context_chunk = round_down(max_context_chunk,
self.block_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
self._num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = \
AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
query_lens=query_lens[tokens_start:],
@@ -369,6 +444,7 @@ class AscendMLAMetadataBuilder:
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
)
decode_metadata = None
@@ -575,6 +651,83 @@ class AscendMLAImpl(MLAAttentionImpl):
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
def _compute_prefill_context(
self,
query: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
rope_dim: int,
attn_metadata: AscendMLAMetadata,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
):
prefill_metadata = attn_metadata.prefill
if prefill_metadata is None or prefill_metadata.chunked_context is None:
return prefix_output, prefix_lse
iters = len(prefill_metadata.chunked_context.seq_tot)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
seq_len = torch.stack([seq_len1, seq_len2])
kv_c_normed = torch.empty(toks,
kv_c_and_k_pe_cache.size(2),
latent_kv_dim,
dtype=query.dtype,
device=query.device)
k_pe = torch.empty(toks,
kv_c_and_k_pe_cache.size(2),
rope_dim,
dtype=query.dtype,
device=query.device)
torch_npu.atb.npu_paged_cache_load(
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
seq_len2.to(query.device),
seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
)
kv_c_normed = kv_c_normed.squeeze()
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=v,
mask=mask,
seqlen=seq_len,
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=prefix_output,
prev_lse=prefix_lse,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="no_mask",
input_layout="type_bsnd",
calc_type="calc_type_default",
output=prefix_output,
softmax_lse=prefix_lse)
return prefix_output, prefix_lse
def _forward_prefill(
self,
query: torch.Tensor,
@@ -586,19 +739,29 @@ class AscendMLAImpl(MLAAttentionImpl):
assert attn_metadata.prefill is not None
num_tokens = query.size(0)
attn_output = None
attn_output = torch.empty(num_tokens,
self.num_heads,
self.v_head_dim,
dtype=query.dtype,
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
ascend_config = get_ascend_config()
if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding
]:
attn_output = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
device=query.device)
] and not ascend_config.chunked_prefill_for_mla:
attn_output_torch = torch.empty(num_tokens,
self.num_heads * self.v_head_dim,
dtype=query.dtype,
device=query.device)
# current requests is chunked in prefill, disable flash attention with chunked prefill
vanilla_chunked_prefill_mla(
output=attn_output,
output=attn_output_torch,
query=query,
kv_cache=kv_c_and_k_pe_cache,
block_tables=attn_metadata.prefill.block_table,
@@ -613,18 +776,47 @@ class AscendMLAImpl(MLAAttentionImpl):
scale=self.scale,
alibi_slopes=None,
causal=True)
elif attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding
]:
attn_lse = torch.empty(self.num_heads,
num_tokens,
dtype=torch.float32,
device=query.device)
q_pe = query[..., self.qk_nope_head_dim:]
q_nope = query[..., :self.qk_nope_head_dim]
mask = torch.triu(
torch.ones(512, 512, device=query.device, dtype=query.dtype),
1) # 512: mask only support 512
if attn_metadata.num_prefills > 1:
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
1)
torch_npu.atb.npu_ring_mla(
q_nope=q_nope,
q_rope=q_pe,
k_nope=k_nope,
k_rope=k_pe,
value=value,
mask=mask,
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
dtype=torch.int32),
head_num=self.num_heads,
kv_head_num=self.num_heads,
pre_out=None,
prev_lse=None,
qk_scale=self.scale,
kernel_type="kernel_type_high_precision",
mask_type="mask_type_triu",
input_layout="type_bsnd",
calc_type="calc_type_first_ring",
output=attn_output,
softmax_lse=attn_lse)
attn_output, attn_lse = self._compute_prefill_context( \
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
attn_output = torch.empty(num_tokens,
self.num_heads,
self.v_head_dim,
dtype=query.dtype,
device=query.device)
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads,
self.qk_nope_head_dim + self.v_head_dim).split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
key = torch.cat((k_nope, k_pe), dim=-1)
torch_npu._npu_flash_attention(
query=query,
key=key,
@@ -642,6 +834,11 @@ class AscendMLAImpl(MLAAttentionImpl):
)
attn_output = attn_output.reshape(
[num_tokens, self.num_heads * self.v_head_dim])
if attn_metadata.attn_state in [
AscendAttentionState.ChunkedPrefill,
AscendAttentionState.SpecDecoding
] and not ascend_config.chunked_prefill_for_mla:
attn_output = attn_output_torch
current_ms_metadata = get_multistream_comm_context()
if current_ms_metadata is None: