vllm-ascend support chunked prefill (#1172)
### What this PR does / why we need it? vllm-ascend support chunked prefill for MLA --------- Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.utils import cdiv, round_down
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention import _ALLOWED_NUM_QUERIES_PER_KV
|
||||
@@ -69,6 +70,18 @@ class AscendMLABackend(AttentionBackend):
|
||||
@dataclass
|
||||
class AscendMLAPrefillMetadata:
|
||||
""" Prefill Specific Metadata for Ascend"""
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
workspace: torch.Tensor
|
||||
chunk_seq_lens: torch.Tensor
|
||||
|
||||
attn_mask: torch.Tensor
|
||||
query_lens: list[int]
|
||||
seq_lens: list[int]
|
||||
@@ -78,6 +91,7 @@ class AscendMLAPrefillMetadata:
|
||||
block_table: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_lens: int
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -172,7 +186,32 @@ class AscendMLAMetadataBuilder:
|
||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
model_config = runner.model_config
|
||||
self.block_size = runner.block_size
|
||||
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(8 * model_config.max_model_len,
|
||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||
# For long-context models try not to over-allocate limiting
|
||||
# kv-cache space, limiting it to 64k tokens,
|
||||
# which would result in the workspace being:
|
||||
# 2*(576)*(64*1024) = 144mb
|
||||
# (assuming 576 MLA head dim, and fp16)
|
||||
# which would result in up-projected context being
|
||||
# 2*(192*128)*(64*1024) = 3gb
|
||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||
128 * 1024)
|
||||
assert self.chunked_prefill_workspace_size >= \
|
||||
scheduler_config.max_num_seqs * self.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
model_config.get_head_size()),
|
||||
dtype=model_config.dtype,
|
||||
device=runner.device,
|
||||
)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
@@ -350,6 +389,7 @@ class AscendMLAMetadataBuilder:
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
|
||||
prefill_metadata = None
|
||||
chunked_context_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
tokens_start = self._num_decode_tokens
|
||||
@@ -359,6 +399,41 @@ class AscendMLAMetadataBuilder:
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
|
||||
reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
max_context_chunk = (self.chunked_prefill_workspace_size //
|
||||
num_prefills_with_context_cpu)
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.block_size)
|
||||
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
self._num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
dim=1,
|
||||
out=cu_seq_lens_cpu[:, 1:],
|
||||
dtype=torch.int32)
|
||||
chunked_context_metadata = \
|
||||
AscendMLAPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
|
||||
starts=chunk_starts.to(device, non_blocking=True),
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
chunk_seq_lens=chunk_seq_lens,
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
prefill_metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=self.runner.attn_mask,
|
||||
query_lens=query_lens[tokens_start:],
|
||||
@@ -369,6 +444,7 @@ class AscendMLAMetadataBuilder:
|
||||
max_query_len=max_query_len,
|
||||
max_seq_lens=max_seq_lens,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
chunked_context=chunked_context_metadata,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
@@ -575,6 +651,83 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
def _compute_prefill_context(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
rope_dim: int,
|
||||
attn_metadata: AscendMLAMetadata,
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
):
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
if prefill_metadata is None or prefill_metadata.chunked_context is None:
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
q_pe = query[..., self.qk_nope_head_dim:]
|
||||
q_nope = query[..., :self.qk_nope_head_dim]
|
||||
|
||||
seq_len1 = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32)
|
||||
latent_kv_dim = kv_c_and_k_pe_cache.size(3) - rope_dim
|
||||
cache_kv_c = kv_c_and_k_pe_cache[:, :, :, :latent_kv_dim]
|
||||
cache_k_pe = kv_c_and_k_pe_cache[:, :, :, latent_kv_dim:]
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
|
||||
seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
|
||||
seq_len = torch.stack([seq_len1, seq_len2])
|
||||
kv_c_normed = torch.empty(toks,
|
||||
kv_c_and_k_pe_cache.size(2),
|
||||
latent_kv_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_pe = torch.empty(toks,
|
||||
kv_c_and_k_pe_cache.size(2),
|
||||
rope_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
|
||||
torch_npu.atb.npu_paged_cache_load(
|
||||
cache_kv_c,
|
||||
cache_k_pe,
|
||||
prefill_metadata.block_table,
|
||||
seq_len2.to(query.device),
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
key=kv_c_normed,
|
||||
value=k_pe,
|
||||
)
|
||||
|
||||
kv_c_normed = kv_c_normed.squeeze()
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
mask = torch.triu(
|
||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||
1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=v,
|
||||
mask=mask,
|
||||
seqlen=seq_len,
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=prefix_output,
|
||||
prev_lse=prefix_lse,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="no_mask",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_default",
|
||||
output=prefix_output,
|
||||
softmax_lse=prefix_lse)
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -586,19 +739,29 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
assert attn_metadata.prefill is not None
|
||||
|
||||
num_tokens = query.size(0)
|
||||
attn_output = None
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
k_pe = k_pe.expand((*k_nope.shape[:-1], -1))
|
||||
# Here is only 2 possibility of input, ChunkedPrefill or PrefillNoCache
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
] and not ascend_config.chunked_prefill_for_mla:
|
||||
attn_output_torch = torch.empty(num_tokens,
|
||||
self.num_heads * self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
# current requests is chunked in prefill, disable flash attention with chunked prefill
|
||||
vanilla_chunked_prefill_mla(
|
||||
output=attn_output,
|
||||
output=attn_output_torch,
|
||||
query=query,
|
||||
kv_cache=kv_c_and_k_pe_cache,
|
||||
block_tables=attn_metadata.prefill.block_table,
|
||||
@@ -613,18 +776,47 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
scale=self.scale,
|
||||
alibi_slopes=None,
|
||||
causal=True)
|
||||
elif attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
attn_lse = torch.empty(self.num_heads,
|
||||
num_tokens,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
q_pe = query[..., self.qk_nope_head_dim:]
|
||||
q_nope = query[..., :self.qk_nope_head_dim]
|
||||
mask = torch.triu(
|
||||
torch.ones(512, 512, device=query.device, dtype=query.dtype),
|
||||
1) # 512: mask only support 512
|
||||
if attn_metadata.num_prefills > 1:
|
||||
mask = mask.unsqueeze(0).repeat(attn_metadata.num_prefills, 1,
|
||||
1)
|
||||
torch_npu.atb.npu_ring_mla(
|
||||
q_nope=q_nope,
|
||||
q_rope=q_pe,
|
||||
k_nope=k_nope,
|
||||
k_rope=k_pe,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seqlen=torch.tensor(attn_metadata.prefill.query_lens,
|
||||
dtype=torch.int32),
|
||||
head_num=self.num_heads,
|
||||
kv_head_num=self.num_heads,
|
||||
pre_out=None,
|
||||
prev_lse=None,
|
||||
qk_scale=self.scale,
|
||||
kernel_type="kernel_type_high_precision",
|
||||
mask_type="mask_type_triu",
|
||||
input_layout="type_bsnd",
|
||||
calc_type="calc_type_first_ring",
|
||||
output=attn_output,
|
||||
softmax_lse=attn_lse)
|
||||
attn_output, attn_lse = self._compute_prefill_context( \
|
||||
query, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse)
|
||||
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
attn_output = torch.empty(num_tokens,
|
||||
self.num_heads,
|
||||
self.v_head_dim,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
k_nope, value = self.kv_b_proj(kv_c_normed)[0].view(
|
||||
-1, self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim).split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
key = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
|
||||
dim=-1)
|
||||
key = torch.cat((k_nope, k_pe), dim=-1)
|
||||
torch_npu._npu_flash_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
@@ -642,6 +834,11 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
)
|
||||
attn_output = attn_output.reshape(
|
||||
[num_tokens, self.num_heads * self.v_head_dim])
|
||||
if attn_metadata.attn_state in [
|
||||
AscendAttentionState.ChunkedPrefill,
|
||||
AscendAttentionState.SpecDecoding
|
||||
] and not ascend_config.chunked_prefill_for_mla:
|
||||
attn_output = attn_output_torch
|
||||
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
if current_ms_metadata is None:
|
||||
|
||||
Reference in New Issue
Block a user