feat: support compile torchair graph while warming up (#839)

### What this PR does / why we need it?
feat: support compile torchair graph while warming up

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-05-31 06:03:03 +08:00
committed by GitHub
parent d9fb027068
commit 507ae627ca
7 changed files with 242 additions and 234 deletions

View File

@@ -241,7 +241,44 @@ class AscendMLAMetadataBuilder:
max_blocks] = block_tables[:num_seqs, :
max_blocks]
return graph_block_tables
return graph_block_tables[:num_seqs, :max_blocks]
def build_dummy(self, num_reqs: int,
num_actual_tokens: int) -> AscendMLAMetadata:
device = self.runner.device
_, max_blocks = self.runner.graph_block_tables.shape
block_table = torch.zeros((num_reqs, max_blocks),
dtype=torch.int32,
device=device)
block_table = self._get_graph_runner_block_tables(
num_reqs, block_table)
seq_lens = torch.ones(num_reqs, dtype=torch.int32, device=device)
input_positions = torch.zeros(num_reqs,
dtype=torch.int32,
device=device).long()
slot_mapping = torch.full((num_reqs, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=device)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens.tolist(),
max_seq_lens=1)
return self.metadata_cls( # type: ignore
num_input_tokens=num_actual_tokens,
num_actual_tokens=num_actual_tokens,
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
num_decodes=1,
num_decode_tokens=1,
num_prefills=0,
attn_mask=self.runner.attn_mask,
attn_state=AscendAttentionState.DecodeOnly,
prefill=None,
decode=decode_metadata,
)
def build(self,
num_reqs: int,
@@ -324,7 +361,7 @@ class AscendMLAMetadataBuilder:
block_table = torch.cat([block_table, block_table_padding],
dim=0)
block_table = self._get_graph_runner_block_tables(
num_seqs, block_table)
num_seqs + graph_pad_size, block_table)
padding_0 = torch.zeros(graph_pad_size,
dtype=input_positions.dtype,
device=input_positions.device)