support FULL graph mode for GQA (#3970)
### What this PR does / why we need it?
The current library only supports the FullDecodeOnly graph mode, which
enables full graph execution during the decode. This PR extends support
to allow full graph execution in both the prefill and decode, referred
to as FULL graph mode.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -195,6 +195,7 @@ class AscendMetadataForDecode:
|
||||
class AscendMetadata:
|
||||
# **************************** Basic Properties ************************** #
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
fia_attn_mask: Optional[torch.Tensor] = None
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
@@ -215,6 +216,7 @@ class AscendMetadata:
|
||||
seq_lens: torch.Tensor = None
|
||||
seq_lens_list: List[int] = None # type: ignore
|
||||
actual_seq_lengths_q: List[int] = None # type: ignore
|
||||
query_start_loc_list: List[int] = None # type: ignore
|
||||
|
||||
query_start_loc: torch.Tensor = None
|
||||
query_lens: torch.Tensor = None
|
||||
@@ -241,7 +243,8 @@ class AscendMetadata:
|
||||
class AscendAttentionMetadataBuilder:
|
||||
# Does this backend/builder support ACL Graphs for attention (default: no).
|
||||
aclgraph_support: ClassVar[AttentionCGSupport] = \
|
||||
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
AttentionCGSupport.ALWAYS
|
||||
# AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
|
||||
# Does this backend/builder reorder the batch?
|
||||
# If not, set this to None. Otherwise set it to the query
|
||||
# length that will be pulled into the front of the batch.
|
||||
@@ -321,6 +324,7 @@ class AscendAttentionMetadataBuilder:
|
||||
num_actual_tokens_pcp_padded]
|
||||
# slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens]
|
||||
attn_mask = common_attn_metadata.attn_mask
|
||||
fia_attn_mask = common_attn_metadata.fia_attn_mask
|
||||
attn_state = common_attn_metadata.attn_state
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
||||
num_reqs
|
||||
@@ -471,6 +475,7 @@ class AscendAttentionMetadataBuilder:
|
||||
num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_start_loc_list=query_start_loc_cpu[1:].tolist(),
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens.tolist(),
|
||||
@@ -478,6 +483,7 @@ class AscendAttentionMetadataBuilder:
|
||||
actual_seq_lengths_q=query_start_loc_cpu[1:].tolist(),
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
fia_attn_mask=fia_attn_mask,
|
||||
attn_state=attn_state,
|
||||
num_prefills=num_prefills,
|
||||
num_decodes=num_decodes,
|
||||
@@ -565,6 +571,113 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.dcp_group = get_dcp_group(
|
||||
).device_group if self.dcp_size > 1 else None
|
||||
|
||||
def full_graph_attention(self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_metadata: AscendMetadata,
|
||||
output: torch.Tensor,
|
||||
num_tokens=0):
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
block_size = 128
|
||||
block_table = None
|
||||
actual_seq_lengths_kv = attn_metadata.query_start_loc_list
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
batch_size = attn_metadata.query_lens.shape[0]
|
||||
block_table = attn_metadata.block_tables[:batch_size, :]
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
num_block, block_size, _, _ = self.key_cache.shape # type: ignore
|
||||
key = self.key_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
value = self.value_cache.view( # type: ignore
|
||||
num_block, block_size, -1)
|
||||
block_table = attn_metadata.block_tables
|
||||
actual_seq_lengths_kv = attn_metadata.seq_lens_list
|
||||
|
||||
num_tokens = attn_metadata.query_start_loc_list[-1]
|
||||
query = query[:num_tokens]
|
||||
graph_params = get_graph_params()
|
||||
query_start_loc = attn_metadata.query_start_loc_list
|
||||
# Prepare tensors for attention output
|
||||
# TODO: Refactor this to step-level instead of layer-level
|
||||
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
softmax_lse = torch.empty(1, dtype=query.dtype, device=query.device)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.fia_attn_mask,
|
||||
block_table=block_table,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=query_start_loc,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
sparse_mode=3,
|
||||
scale=self.scale,
|
||||
)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append(
|
||||
(weak_ref_tensors(query), weak_ref_tensors(key),
|
||||
weak_ref_tensors(value), weak_ref_tensors(block_table),
|
||||
weak_ref_tensors(attn_metadata.fia_attn_mask), block_size,
|
||||
actual_seq_lengths_kv, query_start_loc, self.num_kv_heads,
|
||||
self.num_heads, self.scale, weak_ref_tensors(output),
|
||||
weak_ref_tensors(softmax_lse)))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
atten_mask=attn_metadata.fia_attn_mask,
|
||||
block_table=block_table,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=query_start_loc,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
num_key_value_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
workspace=workspace,
|
||||
out=[output, softmax_lse],
|
||||
)
|
||||
|
||||
output = output.view(num_tokens, self.num_heads, self.head_size)
|
||||
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
return output, num_tokens
|
||||
|
||||
def _forward_prefill_no_cache(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
@@ -692,70 +805,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
output = output.view(batch_size, self.num_heads, self.head_size)
|
||||
else:
|
||||
graph_params = get_graph_params()
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
update_graph_params_workspaces(num_tokens,
|
||||
weak_ref_tensors(workspace))
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
weak_ref_tensors(query),
|
||||
weak_ref_tensors(self.key_cache),
|
||||
weak_ref_tensors(self.value_cache),
|
||||
self.num_kv_heads,
|
||||
self.num_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens,
|
||||
weak_ref_tensors(output),
|
||||
))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
return output
|
||||
|
||||
def _forward_v1_style(
|
||||
@@ -819,7 +878,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
scale=self.scale,
|
||||
sparse_mode=3,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def _attention_with_nomask_and_mask(self, q: torch.Tensor,
|
||||
@@ -1481,47 +1539,51 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
num_decode_tokens:attn_metadata.
|
||||
num_actual_tokens_pcp_padded])
|
||||
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
intermediate_output = self._forward_pcp_dcp(
|
||||
query, key, value, kv_cache, attn_metadata, output)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||
intermediate_output = torch_npu.npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
head_num=self.num_heads,
|
||||
input_layout="TND",
|
||||
scale=self.scale,
|
||||
sparse_mode=4,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
pre_tockens=attn_metadata.max_query_len,
|
||||
next_tockens=attn_metadata.max_query_len,
|
||||
actual_seq_qlen=cum_seq_len,
|
||||
actual_seq_kvlen=cum_seq_len,
|
||||
)[0]
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
intermediate_output = self._forward_prefill_no_cache(
|
||||
query, key, value, attn_metadata, output, num_tokens)
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
intermediate_output = self._forward_prefill_cache_hit(
|
||||
query, attn_metadata, output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
intermediate_output = self._forward_decode_only(
|
||||
query, attn_metadata, output)
|
||||
# Normal V1 situation.
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
if not forward_context.capturing:
|
||||
if self.pcp_size * self.dcp_size > 1:
|
||||
intermediate_output = self._forward_pcp_dcp(
|
||||
query, key, value, kv_cache, attn_metadata, output)
|
||||
elif attn_type == AttentionType.ENCODER_ONLY:
|
||||
# TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
|
||||
cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
|
||||
intermediate_output = torch_npu.npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
head_num=self.num_heads,
|
||||
input_layout="TND",
|
||||
scale=self.scale,
|
||||
sparse_mode=4,
|
||||
atten_mask=attn_metadata.attn_mask,
|
||||
pre_tockens=attn_metadata.max_query_len,
|
||||
next_tockens=attn_metadata.max_query_len,
|
||||
actual_seq_qlen=cum_seq_len,
|
||||
actual_seq_kvlen=cum_seq_len,
|
||||
)[0]
|
||||
# V0-Style scheduler situation.
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
intermediate_output = self._forward_prefill_no_cache(
|
||||
query, key, value, attn_metadata, output, num_tokens)
|
||||
elif attn_metadata.attn_state == \
|
||||
AscendAttentionState.PrefillCacheHit:
|
||||
intermediate_output = self._forward_prefill_cache_hit(
|
||||
query, attn_metadata, output)
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
|
||||
intermediate_output = self._forward_decode_only(
|
||||
query, attn_metadata, output)
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
# npu_fused_infer_attention_score does not support cases
|
||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
||||
# Thus we need unpad it here.
|
||||
num_tokens = attn_metadata.query_start_loc[-1]
|
||||
query = query[:num_tokens]
|
||||
intermediate_output = self._forward_v1_style(
|
||||
query, attn_metadata, output)
|
||||
else:
|
||||
# npu_fused_infer_attention_score does not support cases
|
||||
# where query.shape[0] != attn_metadata.query_start_loc[-1].
|
||||
# Thus we need unpad it here.
|
||||
num_tokens = attn_metadata.query_start_loc[-1]
|
||||
query = query[:num_tokens]
|
||||
intermediate_output = self._forward_v1_style(
|
||||
query, attn_metadata, output)
|
||||
|
||||
intermediate_output, num_tokens = self.full_graph_attention(
|
||||
query, key, value, attn_metadata, output)
|
||||
output[:num_tokens] = intermediate_output[:num_tokens]
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user