feat: support compile torchair graph while warming up (#839)
### What this PR does / why we need it?

feat: support compiling the torchair graph while warming up

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -241,7 +241,44 @@ class AscendMLAMetadataBuilder:
|
||||
max_blocks] = block_tables[:num_seqs, :
|
||||
max_blocks]
|
||||
|
||||
return graph_block_tables
|
||||
return graph_block_tables[:num_seqs, :max_blocks]
|
||||
|
||||
def build_dummy(self, num_reqs: int,
                num_actual_tokens: int) -> AscendMLAMetadata:
    """Assemble placeholder decode-only MLA attention metadata.

    All tensors are synthesized here (zeros, ones, or ``PAD_SLOT_ID``)
    instead of being derived from real requests.

    Args:
        num_reqs: Length of the per-request dummy tensors and the row
            count of the dummy block table.
        num_actual_tokens: Stored as both ``num_input_tokens`` and
            ``num_actual_tokens`` on the returned metadata.

    Returns:
        An instance of ``self.metadata_cls`` in the
        ``AscendAttentionState.DecodeOnly`` state, with ``prefill=None``
        and a populated ``decode`` section.
    """
    device = self.runner.device
    # Only the column count of the runner's graph block table is needed.
    _, table_width = self.runner.graph_block_tables.shape

    # Every dummy tensor below is int32 on the runner's device.
    int32_on_device = dict(dtype=torch.int32, device=device)

    empty_table = torch.zeros((num_reqs, table_width), **int32_on_device)
    graph_table = self._get_graph_runner_block_tables(
        num_reqs, empty_table)

    unit_seq_lens = torch.ones(num_reqs, **int32_on_device)
    # Created as int32 and then widened to int64 via ``.long()``.
    zero_positions = torch.zeros(num_reqs, **int32_on_device).long()
    padded_slots = torch.full((num_reqs, ), PAD_SLOT_ID,
                              **int32_on_device)

    dummy_decode = AscendMLADecodeMetadata(
        input_positions=zero_positions,
        block_table=graph_table,
        seq_lens=unit_seq_lens,
        seq_lens_list=unit_seq_lens.tolist(),
        max_seq_lens=1,
    )

    # NOTE(review): num_decodes / num_decode_tokens are fixed at 1
    # regardless of num_reqs — confirm this is intended.
    return self.metadata_cls(  # type: ignore
        num_input_tokens=num_actual_tokens,
        num_actual_tokens=num_actual_tokens,
        slot_mapping=padded_slots,
        head_dim=self.runner.model_config.get_head_size(),
        num_decodes=1,
        num_decode_tokens=1,
        num_prefills=0,
        attn_mask=self.runner.attn_mask,
        attn_state=AscendAttentionState.DecodeOnly,
        prefill=None,
        decode=dummy_decode,
    )
|
||||
|
||||
def build(self,
|
||||
num_reqs: int,
|
||||
@@ -324,7 +361,7 @@ class AscendMLAMetadataBuilder:
|
||||
block_table = torch.cat([block_table, block_table_padding],
|
||||
dim=0)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_seqs, block_table)
|
||||
num_seqs + graph_pad_size, block_table)
|
||||
padding_0 = torch.zeros(graph_pad_size,
|
||||
dtype=input_positions.dtype,
|
||||
device=input_positions.device)
|
||||
|
||||
Reference in New Issue
Block a user