Fix some CI issues and refactor model runner (#2445)
### What this PR does / why we need it?
Fix some CI issues and refactor the model runner.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with existing tests.
- vLLM version: v0.10.0
- vLLM main: 4d9c61993a
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
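
The core of the refactor is that the Torchair attention metadata builder no longer reaches into the model runner for per-batch state; the builder is constructed from `VllmConfig` and a device, and batch state is passed explicitly via metadata dataclasses. A minimal sketch of the new dummy-graph call pattern, using the names from the diff below; the wrapper function `build_dummy_metadata` is hypothetical and only illustrates the flow, and the runner fields it reads are assumed to be populated elsewhere:

```python
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata


def build_dummy_metadata(runner, num_reqs: int):
    # Hypothetical helper: per-batch state now travels in a dataclass
    # instead of being read off the runner inside the builder.
    common_attn_metadata = TorchairCommonAttentionMetadata(
        num_reqs=num_reqs,
        num_actual_tokens=1,
        actual_seq_lengths_q=runner.actual_seq_lengths_q,
        attn_mask=runner.attn_mask,
        spec_attn_mask=runner.spec_attn_mask,
        decode_token_per_req=runner.decode_token_per_req,
    )
    return runner.attn_metadata_builder.build_torchair_graph_dummy(
        common_attn_metadata)
```

The hunks below cover the Torchair attention metadata builder, the model runner call site, and the new dataclass in `torchair/utils`.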
@@ -20,15 +20,20 @@ from typing import List, Optional, Tuple, Type
 import numpy as np
 import torch
+import torch.nn as nn
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
                                               AttentionType)
 from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.config import VllmConfig
+from vllm.utils import cdiv

 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                 AscendAttentionMetadataBuilder,
                                                 AscendAttentionState,
                                                 AscendMetadata)
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d)

@@ -91,22 +96,26 @@ class AscendTorchairMetadata(AscendMetadata):

 class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):

-    def __init__(self, runner):
-        super().__init__(runner)
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        device: torch.device,
+    ):
+        super().__init__(vllm_config, device)
+        self.max_num_blocks_per_req = cdiv(
+            self.model_config.max_model_len,
+            self.vllm_config.cache_config.block_size)
+        self.max_blocks = (self.model_config.max_model_len +
+                           self.vllm_config.cache_config.block_size -
+                           1) // self.vllm_config.cache_config.block_size

     def _get_graph_runner_block_tables(
             self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
-
-        max_batch_size, max_blocks = self.runner.graph_block_tables.shape
-        assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
-
-        if isinstance(self.runner.graph_block_tables, np.ndarray):
-            graph_block_tables = torch.zeros((max_batch_size, max_blocks),
-                                             dtype=block_tables.dtype,
-                                             device=block_tables.device)
-        else:
-            graph_block_tables = self.runner.graph_block_tables.to(
-                device=block_tables.device, dtype=block_tables.dtype)
+        max_blocks = self.max_blocks
+
+        graph_block_tables = torch.zeros((num_seqs, max_blocks),
+                                         dtype=block_tables.dtype,
+                                         device=block_tables.device)

         num_blocks = block_tables.size(1)
         if num_blocks <= max_blocks:
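
Note on the new `__init__`: `max_num_blocks_per_req` and `max_blocks` are both the ceiling division of `max_model_len` by the KV-cache block size, once via `cdiv` and once spelled out, so the two fields end up equal. A quick self-contained check with illustrative numbers (the values are assumptions, not taken from any real config):

```python
# Illustrative values only; a real VllmConfig would supply these.
max_model_len = 4096
block_size = 128


def cdiv(a: int, b: int) -> int:
    # Same semantics as vllm.utils.cdiv: integer ceiling division.
    return -(a // -b)


max_num_blocks_per_req = cdiv(max_model_len, block_size)
max_blocks = (max_model_len + block_size - 1) // block_size
assert max_num_blocks_per_req == max_blocks == 32
```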
@@ -118,14 +127,14 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
                                max_blocks] = block_tables[:num_seqs, :
                                                           max_blocks]

-        return graph_block_tables[:num_seqs, :max_blocks]
+        return graph_block_tables[:, :max_blocks]

     def build_torchair_graph_dummy(
-            self, num_reqs: int,
-            num_actual_tokens: int) -> AscendTorchairMetadata:
-        device = self.runner.device
-        _, max_blocks = self.runner.graph_block_tables.shape
-        block_table = torch.zeros((num_reqs, max_blocks),
+            self, common_attn_metadata: TorchairCommonAttentionMetadata
+    ) -> AscendTorchairMetadata:
+        device = self.device
+        num_reqs = common_attn_metadata.num_reqs
+        block_table = torch.zeros((num_reqs, self.max_blocks),
                                   dtype=torch.int32,
                                   device=device)
         block_table = self._get_graph_runner_block_tables(
@@ -150,7 +159,7 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
             max_seq_lens=1)

         attn_metadata = AscendTorchairMetadata(
-            num_actual_tokens=num_actual_tokens,
+            num_actual_tokens=common_attn_metadata.num_actual_tokens,
             block_tables=block_table,
             query_lens=0,
             query_start_loc=query_start_loc,
@@ -160,52 +169,50 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
             decode=decode_metadata)
         return attn_metadata

-    def build(self,
-              num_reqs,
-              num_actual_tokens,
-              max_query_len,
-              enable_dbo_across_dp: bool = False,
-              is_only_prefill: bool = False,
-              *args,
-              **kwargs):
-
-        if 'graph_pad_size' in kwargs:
-            graph_pad_size = kwargs['graph_pad_size']
-        else:
-            graph_pad_size = -1  # default value
-
-        device = self.runner.device
-
-        block_table = self.runner.input_batch.block_table[0].get_device_tensor(
-        )
-        block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
+    def build(
+        self,
+        common_attn_metadata: AscendCommonAttentionMetadata,
+        model: nn.Module,
+    ):
+        num_reqs = common_attn_metadata.num_reqs
+        num_actual_tokens = common_attn_metadata.num_actual_tokens
+
+        block_table = common_attn_metadata.block_table_tensor
+        block_table[:num_reqs, :self.max_num_blocks_per_req] = (
             block_table[:num_reqs])

-        query_lens = self.runner.query_lens
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs]
-        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
-            self.runner.device, non_blocking=True)
-        attn_mask = self.runner.attn_mask
+        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
+        slot_mapping = common_attn_metadata.slot_mapping_cpu[:
                                                              num_actual_tokens].to(
                                                                  self.device,
                                                                  non_blocking=
                                                                  True)
+        attn_mask = common_attn_metadata.attn_mask

-        attn_state = self.runner.attn_state
+        attn_state = common_attn_metadata.attn_state
         if is_310p() and attn_state == AscendAttentionState.PrefillNoCache:
             mask_nz = nd_to_nz_2d(attn_mask)
             attn_mask = torch_npu.npu_format_cast(mask_nz.contiguous(), 29)

-        query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
-        query_start_loc = query_start_loc_cpu.to(self.runner.device,
+        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs
                                                                        + 1]
+        query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
-        input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
-            device, non_blocking=True).long()
+        query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
+        input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )

         decode_metadata = None
+        graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size > -1
-        if self.runner.attn_state in [
+        if common_attn_metadata.attn_state in [
                 AscendAttentionState.DecodeOnly,
         ]:
             max_seq_lens = seq_lens.max().item()
             num_seqs = len(seq_lens)
-            if use_torchair_graph and self.runner.attn_state in [
+            if use_torchair_graph and common_attn_metadata.attn_state in [
                     AscendAttentionState.DecodeOnly,
             ]:
                 num_reqs_pad_size = 0
@@ -214,8 +221,8 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
                 pad_value = 0
                 num_token_pad_size = graph_pad_size - num_actual_tokens
                 num_reqs_pad_size = (
-                    graph_pad_size // self.runner.decode_token_per_req -
-                    num_reqs)
+                    graph_pad_size //
+                    common_attn_metadata.decode_token_per_req - num_reqs)
             pad_value = 1
             padded_seq_lens = seq_lens.tolist() + [pad_value
                                                    ] * num_reqs_pad_size
@@ -255,11 +262,11 @@ class AscendAttentionTorchairMetadataBuilder(AscendAttentionMetadataBuilder):
             query_start_loc=query_start_loc,
             query_lens=query_lens,
             seq_lens=seq_lens,
-            max_query_len=max_query_len,
+            max_query_len=common_attn_metadata.max_query_len,
             slot_mapping=slot_mapping,
             attn_mask=attn_mask,
             attn_state=attn_state,
-            enable_dbo_across_dp=enable_dbo_across_dp)
+            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp)
         return attn_metadata

@@ -26,7 +26,8 @@ from vllm.forward_context import get_forward_context
 from vllm.logger import logger

 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
+from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
+                                        check_torchair_cache_exist,
                                         register_torchair_model,
                                         write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
@@ -71,8 +72,16 @@ class NPUTorchairModelRunner(NPUModelRunner):
         # NOTE: If torchair graph mode and not with_prefill,
         # we can't skip_attn, it will cause graph recompile.
         if not with_prefill:
+            common_attn_metadata = TorchairCommonAttentionMetadata(
+                num_reqs=num_reqs,
+                num_actual_tokens=1,
+                actual_seq_lengths_q=self.actual_seq_lengths_q,
+                attn_mask=self.attn_mask,
+                spec_attn_mask=self.spec_attn_mask,
+                decode_token_per_req=self.decode_token_per_req,
+            )
             attn_metadata = self.attn_metadata_builder.build_torchair_graph_dummy(
-                num_reqs=num_reqs, num_actual_tokens=1)
+                common_attn_metadata)
         else:
             attn_metadata = super()._build_attention_metadata(
                 with_prefill, num_reqs, skip_attn)

@@ -2,6 +2,7 @@ import fcntl
 import os
 import shutil
 from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass

 import torch

@@ -20,6 +21,32 @@ TORCHAIR_CACHE_DIR = os.getenv(
     'TORCHAIR_CACHE_HOME', os.path.join(os.getcwd(), TORCHAIR_CACHE_PATH_NAME))


+@dataclass
+class TorchairCommonAttentionMetadata:
+    """
+    Per-batch attention metadata, shared across layers and backends.
+    AttentionMetadataBuilder instances use it to construct per-layer metadata.
+
+    For many of the tensors we keep both GPU and CPU versions.
+    """
+
+    num_reqs: int
+    """Number of requests"""
+
+    num_actual_tokens: int
+    """Total number of tokens in batch"""
+
+    decode_token_per_req: int
+
+    actual_seq_lengths_q: list[int]
+
+    attn_mask: torch.Tensor = None
+
+    spec_attn_mask: torch.Tensor = None
+
+    graph_pad_size: int = -1
+
+
 @contextmanager
 def _file_lock(file_descriptor, lock_type):
     fcntl.flock(file_descriptor, lock_type)
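
The defaults on the new dataclass matter for control flow: `graph_pad_size` stays at `-1` unless the runner opts into torchair-graph padding, mirroring the `use_torchair_graph = graph_pad_size > -1` check in `build()` above. A small sketch of those defaults; the field values here are made up purely for illustration:

```python
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata

# Illustrative values only; the runner supplies real ones.
meta = TorchairCommonAttentionMetadata(
    num_reqs=2,
    num_actual_tokens=2,
    decode_token_per_req=1,
    actual_seq_lengths_q=[1, 2],
)
assert meta.graph_pad_size == -1      # default: no torchair-graph padding
assert meta.attn_mask is None         # masks are optional and filled in later

meta.graph_pad_size = 8               # runner requests a padded graph batch
assert meta.graph_pad_size > -1       # same convention build() uses to gate padding
```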