[Enhancement] Add padding for ACL Graph (#803)

### What this PR does / why we need it?
Add padding for ACL Graph and move the graph batch size adjustment logic
into utils.py

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-05-12 20:26:22 +08:00
committed by GitHub
parent efabd722eb
commit 701b0fd95e
4 changed files with 97 additions and 79 deletions

View File

@@ -104,6 +104,7 @@ class AscendAttentionState(Enum):
@dataclass
class AscendMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
block_tables: torch.Tensor
@@ -125,7 +126,6 @@ class AscendMetadata:
is_only_prefill: bool = False
# Current state of this attention run.
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
attn_mask: Optional[torch.Tensor] = None
@@ -149,7 +149,8 @@ class AscendAttentionMetadataBuilder:
attn_mask = self.runner.attn_mask
attn_state = self.runner.attn_state
attn_metadata = AscendMetadata(block_tables=block_table,
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_lens=query_lens,
seq_lens=seq_lens,
max_query_len=max_query_len,
@@ -234,9 +235,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
output=output,
layer_name=layer.layer_name)
else:
num_tokens = query.shape[0]
if attn_metadata is None:
return output.view(num_tokens, self.hidden_size)
num_actual_tokens = attn_metadata.num_actual_tokens
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
attn_type = self.attn_type
if attn_type != AttentionType.DECODER:
@@ -255,11 +256,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
if self.key_cache is None:
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
slots = attn_metadata.slot_mapping
torch_npu._npu_reshape_and_cache(key=key,
value=value,
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
torch_npu._npu_reshape_and_cache(
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
key_cache=self.key_cache,
value_cache=self.value_cache,
slot_indices=slots)
if hasattr(layer, 'quant_method'):
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata