[feat]dcp pcp support aclgraph (#3731)
### What this PR does / why we need it?
DCP and PCP now support full ACL graph capture, including the MLA attention_v1 backend.
- vLLM version: v0.11.0rc3
- vLLM main: c9461e05a4
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
```diff
@@ -140,6 +140,9 @@ class AscendMLADecodeMetadata:
     cos: torch.Tensor = None
     num_computed_tokens_of_pcp_dcp: Optional[list[Optional[list[Optional[
         list[int]]]]]] = None
+    seq_mask_pcp: torch.Tensor = None
+    seq_mask_dcp: torch.Tensor = None
+    cp_seq_len: torch.Tensor = None


 @dataclass
```
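A short note on the new fields (my reading of the diff, not text from the PR): `seq_mask_pcp`/`seq_mask_dcp` are per-request uint8 masks over pcp/dcp ranks, and `cp_seq_len` is the KV length held locally on this rank, pre-clamped so the attention kernel never sees zero. A toy sketch of the expected shapes, with hypothetical sizes:

```python
import torch

batch, pcp_size, dcp_size = 2, 2, 2  # hypothetical sizes, for illustration only

# seq_mask_pcp[b, p] == 1 iff pcp rank p holds any KV cache for request b.
seq_mask_pcp = torch.ones(batch, pcp_size, dtype=torch.uint8)
# seq_mask_dcp[b, d] == 1 iff dcp rank d (within this pcp rank) holds any KV.
seq_mask_dcp = torch.ones(batch, dcp_size, dtype=torch.uint8)
# cp_seq_len[b]: KV length stored on this (pcp_rank, dcp_rank), clamped to >= 1.
cp_seq_len = torch.ones(batch, dtype=torch.int32)
```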
```diff
@@ -259,6 +262,24 @@ class AscendMLAMetadataBuilder:
         self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
         self.cos_cache = None
         self.sin_cache = None
+        self.pcp_size = get_prefill_context_model_parallel_world_size(
+        ) if prefill_context_parallel_enable() else 1
+        self.cp_rank = get_prefill_context_model_parallel_rank(
+        ) if self.pcp_size > 1 else 0
+        self.dcp_size = get_decode_context_model_parallel_world_size()
+        self.dcp_rank = get_decode_context_model_parallel_rank(
+        ) if self.dcp_size > 1 else 0
+        decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
+                                      0)
+        max_num_seqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
+        self.seq_mask_pcp_buf = torch.empty(max_num_seqs,
+                                            self.pcp_size,
+                                            dtype=torch.uint8,
+                                            device=device)
+        self.seq_mask_dcp_buf = torch.empty(max_num_seqs,
+                                            self.dcp_size,
+                                            dtype=torch.uint8,
+                                            device=device)

     def reorder_batch(self, input_batch: "InputBatch",
                       scheduler_output: "SchedulerOutput") -> bool:
```
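Why fixed-size buffers rather than fresh tensors each step: my understanding is that for graph replay the metadata tensors need stable device storage, so the builder allocates them once at the worst-case batch size and copies each step's values into a slice. A minimal sketch of that pattern (the names and helper below are illustrative, not the PR's API):

```python
import torch

max_num_seqs, pcp_size = 8, 2
# Allocate once, at worst-case size, so a captured graph can keep pointing
# at the same storage across replays.
seq_mask_pcp_buf = torch.empty(max_num_seqs, pcp_size, dtype=torch.uint8)

def refresh(step_mask: torch.Tensor) -> torch.Tensor:
    """Copy this step's mask into the persistent buffer and return the view."""
    n, p = step_mask.shape
    seq_mask_pcp_buf[:n, :p].copy_(step_mask, non_blocking=True)
    return seq_mask_pcp_buf[:n, :p]

mask_view = refresh(torch.ones(3, pcp_size, dtype=torch.uint8))
```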
```diff
@@ -463,6 +484,41 @@ class AscendMLAMetadataBuilder:
             block_table = block_table[:num_decodes, ...]
             seq_lens_list = seq_lens.tolist()

+            if num_computed_tokens_of_pcp_dcp is not None:
+                num_computed_tokens_of_cp_dcp_array = np.array(
+                    num_computed_tokens_of_pcp_dcp
+                )[:num_decodes]  # [bs, pcp_size, dcp_size]
+                seq_mask_pcp = torch.where(
+                    torch.tensor(
+                        num_computed_tokens_of_cp_dcp_array.sum(2)) == 0, 0,
+                    1).to(torch.uint8)
+                self.seq_mask_pcp_buf[:seq_mask_pcp.shape[0], :seq_mask_pcp.
+                                      shape[1]].copy_(seq_mask_pcp,
+                                                      non_blocking=True)
+                seq_mask_pcp_shape = (seq_mask_pcp.shape[0],
+                                      seq_mask_pcp.shape[1])
+
+                seq_mask_dcp = torch.where(
+                    torch.tensor(
+                        num_computed_tokens_of_cp_dcp_array[:,
+                                                            self.cp_rank, :])
+                    == 0, 0, 1).to(torch.uint8)
+                self.seq_mask_dcp_buf[:seq_mask_dcp.shape[0], :seq_mask_dcp.
+                                      shape[1]].copy_(seq_mask_dcp,
+                                                      non_blocking=True)
+                seq_mask_dcp_shape = (seq_mask_dcp.shape[0],
+                                      seq_mask_dcp.shape[1])
+
+                cp_seq_len = num_computed_tokens_of_cp_dcp_array[:,
+                                                                 self.cp_rank,
+                                                                 self.dcp_rank]
+                cp_seq_len = torch.tensor(cp_seq_len, dtype=torch.int32)
+                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+            else:
+                seq_mask_pcp_shape = (0, 0)
+                seq_mask_dcp_shape = (0, 0)
+                cp_seq_len = None
+
             # TODO: After the fullgraph supports MTP, the if branch needs to deleted
             assert self.cos_cache is not None
             assert self.sin_cache is not None
```
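To make the mask/length derivation above concrete, here is a small standalone example with made-up numbers (pcp_size=2, dcp_size=2, cp_rank=0, dcp_rank=1); it mirrors the logic of the added block but is not code from the PR:

```python
import numpy as np
import torch

# num_computed[b][p][d]: KV tokens of request b held by prefill-context rank p
# and decode-context rank d (toy values).
num_computed = np.array([
    [[4, 4], [3, 0]],   # request 0
    [[0, 0], [5, 2]],   # request 1
])
cp_rank, dcp_rank = 0, 1

# 1 where a pcp rank holds any KV for the request, 0 otherwise.
seq_mask_pcp = torch.where(
    torch.tensor(num_computed.sum(2)) == 0, 0, 1).to(torch.uint8)
# Same, but per dcp rank within this pcp rank.
seq_mask_dcp = torch.where(
    torch.tensor(num_computed[:, cp_rank, :]) == 0, 0, 1).to(torch.uint8)
# Local KV length on this exact (cp_rank, dcp_rank); clamp 0 -> 1 because the
# attention kernel rejects zero-length sequences (the mask zeroes it out later).
cp_seq_len = torch.tensor(num_computed[:, cp_rank, dcp_rank], dtype=torch.int32)
cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)

print(seq_mask_pcp.tolist())  # [[1, 1], [0, 1]]
print(seq_mask_dcp.tolist())  # [[1, 1], [0, 0]]
print(cp_seq_len.tolist())    # [4, 1]
```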
```diff
@@ -485,7 +541,14 @@ class AscendMLAMetadataBuilder:
                     sin=sin,
                     cos=cos,
                     num_computed_tokens_of_pcp_dcp=
-                    num_computed_tokens_of_pcp_dcp)
+                    num_computed_tokens_of_pcp_dcp,
+                    seq_mask_pcp=self.
+                    seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
+                                     seq_mask_pcp_shape[1]],
+                    seq_mask_dcp=self.
+                    seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
+                                     seq_mask_dcp_shape[1]],
+                    cp_seq_len=cp_seq_len)
             else:
                 cos[:num_decode_tokens,
                     ...] = self.cos_cache[input_positions].unsqueeze(
```
```diff
@@ -505,7 +568,14 @@ class AscendMLAMetadataBuilder:
                     sin=sin[:num_decode_tokens, ...],
                     cos=cos[:num_decode_tokens, ...],
                     num_computed_tokens_of_pcp_dcp=
-                    num_computed_tokens_of_pcp_dcp)
+                    num_computed_tokens_of_pcp_dcp,
+                    seq_mask_pcp=self.
+                    seq_mask_pcp_buf[:seq_mask_pcp_shape[0], :
+                                     seq_mask_pcp_shape[1]],
+                    seq_mask_dcp=self.
+                    seq_mask_dcp_buf[:seq_mask_dcp_shape[0], :
+                                     seq_mask_dcp_shape[1]],
+                    cp_seq_len=cp_seq_len)

         return self.metadata_cls(  # type: ignore
             num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
```
```diff
@@ -1590,36 +1660,63 @@ class AscendMLAImpl(MLAAttentionImpl):
         q_nope = q_nope.view(num_tokens, num_heads, -1)
         q_pe = q_pe.view(num_tokens, num_heads, -1)
-        # use pcp & dcp split computed token nums from scheduler to compute actual seq_len and seq_mask
-        num_computed_tokens_of_pcp_dcp = np.array(
-            decode_meta.num_computed_tokens_of_pcp_dcp
-        )[:attn_metadata.num_decodes]  # [bs, pcp_size, dcp_size]
-        seq_mask_pcp = torch.where(
-            torch.tensor(num_computed_tokens_of_pcp_dcp.sum(2)) == 0, 0,
-            1).to(torch.uint8).to(q_pe.device)
-        seq_mask_dcp = torch.where(
-            torch.tensor(
-                num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, :]) == 0, 0,
-            1).to(torch.uint8).to(q_pe.device)
-        seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank,
-                                                 self.dcp_rank]
-        seq_len = torch.tensor(seq_len, dtype=torch.int32)
-        # npu_multi_head_latent_attention does not support seq_len = 0,
-        # update where seq_len == 0 to 1.
-        # This will not influence result, since we will use seq_mask to update lse.
-        seq_len = torch.where(seq_len == 0, 1, seq_len)
+        seq_mask_pcp = decode_meta.seq_mask_pcp
+        seq_mask_dcp = decode_meta.seq_mask_dcp
+        seq_len = decode_meta.cp_seq_len

-        if torch.sum(seq_len).item() == 0:
-            # Case that no kv_cache has been stored on this rank, no need to do following computation.
-            attn_output = torch.zeros(
-                [num_tokens, num_heads, self.kv_lora_rank],
-                dtype=q_nope.dtype,
-                device=q_nope.device)
-            softmax_lse = torch.full((num_tokens, num_heads, 1),
-                                     float('-inf'),
-                                     dtype=q_nope.dtype,
-                                     device=q_nope.device)
+        common_kwargs = {
+            "return_lse": True,
+            "calc_type": "calc_type_ring",
+        }
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        if forward_context.capturing:
+            stream = torch_npu.npu.current_stream()
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu.atb._npu_multi_head_latent_attention_get_workspace(
+                    q_nope, q_pe, k_nope, k_pe, decode_meta.block_table,
+                    seq_len, num_heads, self.scale, self.num_kv_heads,
+                    **common_kwargs)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty((num_tokens, num_heads, 1),
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+            graph_params.attn_params[num_tokens].append(
+                (weak_ref_tensors(q_nope), weak_ref_tensors(q_pe),
+                 weak_ref_tensors(k_nope), weak_ref_tensors(k_pe),
+                 decode_meta.block_table, seq_len, num_heads, self.scale,
+                 self.num_kv_heads, weak_ref_tensors(attn_output),
+                 weak_ref_tensors(softmax_lse)))
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu.atb.npu_multi_head_latent_attention(
+                q_nope,
+                q_pe,
+                k_nope,
+                k_pe,
+                decode_meta.block_table,
+                seq_len,
+                num_heads,
+                self.scale,
+                self.num_kv_heads,
+                **common_kwargs,
+                workspace=workspace,
+                output=attn_output,
+                lse=softmax_lse)
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
         else:
-            attn_output, softmax_lse = torch_npu.atb.npu_multi_head_latent_attention(
+            attn_output = torch.empty_like(q_nope)
+            softmax_lse = torch.empty((num_tokens, num_heads, 1),
+                                      dtype=q_nope.dtype,
+                                      device=q_nope.device)
+            torch_npu.atb.npu_multi_head_latent_attention(
                 q_nope,
                 q_pe,
                 k_nope,
```
```diff
@@ -1630,7 +1727,9 @@ class AscendMLAImpl(MLAAttentionImpl):
                 self.scale,
                 self.num_kv_heads,
                 return_lse=True,
-                calc_type="calc_type_ring")
+                calc_type="calc_type_ring",
+                output=attn_output,
+                lse=softmax_lse)

         if self.dcp_size > 1:
             # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
```
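The comment above points at the cross-rank reduction that follows: the diff says each dcp rank's partial `attn_output` and `softmax_lse` are concatenated along the last dimension; presumably they are then exchanged across dcp ranks and combined with a log-sum-exp weighting, with `seq_mask_dcp` zeroing ranks that hold no KV (this matches the earlier note about using seq_mask to update lse). That combine step is not part of the shown diff, so the following is only a hedged sketch of the standard lse-merge math, with made-up tensors and without the actual communication op:

```python
import torch

# Toy partial results from two dcp ranks for one request:
# out[i]: [num_heads, v_head_dim], lse[i]: [num_heads, 1] (log-sum-exp of scores).
num_heads, v_head_dim = 4, 8
out = [torch.randn(num_heads, v_head_dim) for _ in range(2)]
lse = [torch.randn(num_heads, 1) for _ in range(2)]
# Pretend rank 1 holds no KV for this request: mask its contribution out
# by sending its lse to -inf (the role seq_mask_dcp plays in the diff).
lse[1] = torch.full_like(lse[1], float("-inf"))

# Standard log-sum-exp combine of partial attention outputs.
lse_all = torch.stack(lse, dim=0)                 # [ranks, num_heads, 1]
out_all = torch.stack(out, dim=0)                 # [ranks, num_heads, v_head_dim]
lse_max = lse_all.max(dim=0, keepdim=True).values
weights = torch.exp(lse_all - lse_max)            # unnormalized per-rank weights
merged = (weights * out_all).sum(0) / weights.sum(0)

assert torch.allclose(merged, out[0])  # rank 1 was fully masked out
```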