[Enhancement] Add padding for ACL Graph (#803)
### What this PR does / why we need it? Add padding for ACL Graph and refactor graph batch size adjustments to utils.py --------- Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -104,6 +104,7 @@ class AscendAttentionState(Enum):
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
block_tables: torch.Tensor
|
||||
@@ -125,7 +126,6 @@ class AscendMetadata:
|
||||
is_only_prefill: bool = False
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@@ -149,7 +149,8 @@ class AscendAttentionMetadataBuilder:
|
||||
attn_mask = self.runner.attn_mask
|
||||
attn_state = self.runner.attn_state
|
||||
|
||||
attn_metadata = AscendMetadata(block_tables=block_table,
|
||||
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
@@ -234,9 +235,9 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
output=output,
|
||||
layer_name=layer.layer_name)
|
||||
else:
|
||||
num_tokens = query.shape[0]
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, self.hidden_size)
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
attn_type = self.attn_type
|
||||
if attn_type != AttentionType.DECODER:
|
||||
@@ -255,11 +256,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
if self.key_cache is None:
|
||||
self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
torch_npu._npu_reshape_and_cache(key=key,
|
||||
value=value,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=key[:num_actual_tokens],
|
||||
value=value[:num_actual_tokens],
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
slot_indices=slots)
|
||||
|
||||
if hasattr(layer, 'quant_method'):
|
||||
# TODO: Add attr (num_prefills, prefill_metadata, decode_metadata) to AscendMetadata
|
||||
|
||||
Reference in New Issue
Block a user