[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)
### What this PR does / why we need it?
A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:
Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;
### Does this PR introduce _any_ user-facing change?
No change at user-facing.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
57c22e57f9
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -119,6 +119,7 @@ class AscendAttentionState(Enum):
|
||||
|
||||
@dataclass
|
||||
class AscendMetadata:
|
||||
|
||||
# **************************** Basic Properties ****************************
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
# Current state of this attention run.
|
||||
@@ -149,11 +150,6 @@ class AscendMetadata:
|
||||
# (num_tokens,)
|
||||
slot_mapping: torch.Tensor = None
|
||||
|
||||
# ************************* DP Related Properties **************************
|
||||
with_prefill_across_dp: bool = False
|
||||
# Maximum number of tokens across dp group
|
||||
max_num_tokens_across_dp: int = 0
|
||||
|
||||
|
||||
class AscendAttentionMetadataBuilder:
|
||||
|
||||
@@ -164,12 +160,7 @@ class AscendAttentionMetadataBuilder:
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self,
|
||||
num_reqs,
|
||||
num_actual_tokens,
|
||||
max_query_len,
|
||||
max_num_tokens_across_dp: int = 0,
|
||||
with_prefill_across_dp: bool = False):
|
||||
def build(self, num_reqs, num_actual_tokens, max_query_len):
|
||||
|
||||
block_table = self.runner.input_batch.block_table[0].get_device_tensor(
|
||||
)
|
||||
@@ -196,18 +187,15 @@ class AscendAttentionMetadataBuilder:
|
||||
attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(),
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
attn_metadata = AscendMetadata(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
||||
with_prefill_across_dp=with_prefill_across_dp)
|
||||
attn_metadata = AscendMetadata(num_actual_tokens=num_actual_tokens,
|
||||
block_tables=block_table,
|
||||
query_start_loc=query_start_loc,
|
||||
query_lens=query_lens,
|
||||
seq_lens=seq_lens,
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
|
||||
@@ -127,8 +127,6 @@ class AscendTorchairMetadata:
|
||||
query_start_loc: torch.Tensor
|
||||
query_lens: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
# max value of number of tokens across dp group
|
||||
max_num_tokens_across_dp: int = 0
|
||||
# Maximum query length in the batch. None for decoding.
|
||||
max_query_len: Optional[int] = None
|
||||
# (num_tokens,). The indices of the token slots that input tokens will be
|
||||
@@ -139,7 +137,7 @@ class AscendTorchairMetadata:
|
||||
# Current state of this attention run.
|
||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||
attn_mask: Optional[torch.Tensor] = None
|
||||
with_prefill_across_dp: bool = False
|
||||
|
||||
decode: Optional[AscendDecodeMetadata] = None
|
||||
|
||||
|
||||
@@ -178,8 +176,9 @@ class AscendAttentionTorchairMetadataBuilder:
|
||||
|
||||
return graph_block_tables[:num_seqs, :max_blocks]
|
||||
|
||||
def build_dummy(self, num_reqs: int,
|
||||
num_actual_tokens: int) -> AscendTorchairMetadata:
|
||||
def build_torchair_graph_dummy(
|
||||
self, num_reqs: int,
|
||||
num_actual_tokens: int) -> AscendTorchairMetadata:
|
||||
device = self.runner.device
|
||||
_, max_blocks = self.runner.graph_block_tables.shape
|
||||
block_table = torch.zeros((num_reqs, max_blocks),
|
||||
@@ -214,7 +213,6 @@ class AscendAttentionTorchairMetadataBuilder:
|
||||
seq_lens=seq_lens,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_state=AscendAttentionState.DecodeOnly,
|
||||
max_num_tokens_across_dp=num_reqs,
|
||||
decode=decode_metadata)
|
||||
return attn_metadata
|
||||
|
||||
@@ -222,9 +220,7 @@ class AscendAttentionTorchairMetadataBuilder:
|
||||
num_reqs,
|
||||
num_actual_tokens,
|
||||
max_query_len,
|
||||
graph_pad_size: int = -1,
|
||||
max_num_tokens_across_dp: int = 0,
|
||||
with_prefill_across_dp: bool = False):
|
||||
graph_pad_size: int = -1):
|
||||
|
||||
device = self.runner.device
|
||||
|
||||
@@ -263,7 +259,6 @@ class AscendAttentionTorchairMetadataBuilder:
|
||||
pad_value = 1
|
||||
padded_seq_lens = seq_lens.tolist() + [pad_value
|
||||
] * graph_pad_size
|
||||
max_num_tokens_across_dp = len(padded_seq_lens)
|
||||
|
||||
seq_lens = torch.from_numpy(
|
||||
np.array(padded_seq_lens).astype(np.int32))
|
||||
@@ -303,9 +298,7 @@ class AscendAttentionTorchairMetadataBuilder:
|
||||
max_query_len=max_query_len,
|
||||
slot_mapping=slot_mapping,
|
||||
attn_mask=attn_mask,
|
||||
attn_state=attn_state,
|
||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
||||
with_prefill_across_dp=with_prefill_across_dp)
|
||||
attn_state=attn_state)
|
||||
return attn_metadata
|
||||
|
||||
|
||||
|
||||
@@ -126,9 +126,6 @@ class AscendMLAMetadata:
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
max_num_tokens_across_dp: int = 0
|
||||
with_prefill_across_dp: bool = False
|
||||
|
||||
query_lens: Optional[list[int]] = None
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
@@ -302,8 +299,8 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
return graph_block_tables[:num_seqs, :max_blocks]
|
||||
|
||||
def build_dummy(self, num_reqs: int,
|
||||
num_actual_tokens: int) -> AscendMLAMetadata:
|
||||
def build_torchair_graph_dummy(
|
||||
self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
|
||||
device = self.runner.device
|
||||
_, max_blocks = self.runner.graph_block_tables.shape
|
||||
block_table = torch.zeros((num_reqs, max_blocks),
|
||||
@@ -353,8 +350,6 @@ class AscendMLAMetadataBuilder:
|
||||
num_actual_tokens: int,
|
||||
max_query_len: int,
|
||||
graph_pad_size: int = -1,
|
||||
max_num_tokens_across_dp: int = 0,
|
||||
with_prefill_across_dp: bool = False,
|
||||
query_start_loc: torch.Tensor = None,
|
||||
) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
@@ -498,8 +493,6 @@ class AscendMLAMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
||||
with_prefill_across_dp=with_prefill_across_dp,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user