Fix some ci issue and refactor modelrunner (#2445)
### What this PR does / why we need it?
Fix some ci issue and refactor modelrunner
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with existing test.
- vLLM version: v0.10.0
- vLLM main:
4d9c61993a
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -3,12 +3,13 @@ from typing import TYPE_CHECKING, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
@@ -17,11 +18,14 @@ from vllm.utils import cdiv, round_down
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
|
||||
npu_stream_switch, npu_wait_tensor)
|
||||
from vllm_ascend.utils import npu_prefetch
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
@@ -172,20 +176,24 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
runner,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
metadata_cls: Optional[AscendMLAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
model_config = runner.model_config
|
||||
self.block_size = runner.block_size
|
||||
self.chunked_prefill_enabled = runner.chunked_prefill_enabled
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.device = device
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.max_blocks = (vllm_config.model_config.max_model_len +
|
||||
self.block_size - 1) // self.block_size
|
||||
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
|
||||
if self.chunked_prefill_enabled:
|
||||
self.chunked_prefill_workspace_size = min(
|
||||
# Max sure there is enough for 8 full length request or at least
|
||||
# 4 pages of cache per request
|
||||
max(8 * model_config.max_model_len,
|
||||
max(8 * self.model_config.max_model_len,
|
||||
4 * scheduler_config.max_num_seqs * self.block_size),
|
||||
# For long-context models try not to over-allocate limiting
|
||||
# kv-cache space, limiting it to 64k tokens,
|
||||
@@ -200,13 +208,13 @@ class AscendMLAMetadataBuilder:
|
||||
scheduler_config.max_num_seqs * self.block_size
|
||||
self.chunked_prefill_workspace = torch.empty(
|
||||
(self.chunked_prefill_workspace_size,
|
||||
model_config.get_head_size()),
|
||||
dtype=model_config.dtype,
|
||||
device=runner.device,
|
||||
self.model_config.get_head_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=device,
|
||||
)
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.rope_dim = self.runner.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||
self.cos_cache = None
|
||||
self.sin_cache = None
|
||||
|
||||
@@ -220,8 +228,6 @@ class AscendMLAMetadataBuilder:
|
||||
# better naming here)
|
||||
decodes = []
|
||||
prefills = []
|
||||
num_decode_tokens = 0
|
||||
num_prefill_tokens = 0
|
||||
|
||||
for i, req_id in enumerate(input_batch.req_ids):
|
||||
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
@@ -231,18 +237,14 @@ class AscendMLAMetadataBuilder:
|
||||
if self.torchair_graph_enabled:
|
||||
if num_tokens - num_spec_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
# For eager mode we treat spec decoding as chunked prefill.
|
||||
else:
|
||||
if num_tokens == 1:
|
||||
decodes.append(i)
|
||||
num_decode_tokens += num_tokens
|
||||
else:
|
||||
prefills.append(i)
|
||||
num_prefill_tokens += num_tokens
|
||||
|
||||
# We hope that this is fairly minimal since decodes
|
||||
# should be around for a number of iterations so hopefully they are
|
||||
@@ -273,26 +275,15 @@ class AscendMLAMetadataBuilder:
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
self._num_decodes = num_decodes
|
||||
self._num_prefills = num_prefills
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
|
||||
max_blocks = self.max_blocks
|
||||
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
|
||||
|
||||
if isinstance(self.runner.graph_block_tables, np.ndarray):
|
||||
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
|
||||
dtype=block_tables.dtype,
|
||||
device=block_tables.device)
|
||||
else:
|
||||
graph_block_tables = self.runner.graph_block_tables.to(
|
||||
device=block_tables.device, dtype=block_tables.dtype)
|
||||
graph_block_tables = torch.zeros((num_seqs, max_blocks),
|
||||
dtype=block_tables.dtype,
|
||||
device=block_tables.device)
|
||||
|
||||
num_blocks = block_tables.size(1)
|
||||
if num_blocks <= max_blocks:
|
||||
@@ -304,18 +295,20 @@ class AscendMLAMetadataBuilder:
|
||||
max_blocks] = block_tables[:num_seqs, :
|
||||
max_blocks]
|
||||
|
||||
return graph_block_tables[:num_seqs, :max_blocks]
|
||||
return graph_block_tables[:, :max_blocks]
|
||||
|
||||
def build_torchair_graph_dummy(
|
||||
self, num_reqs: int, num_actual_tokens: int) -> AscendMLAMetadata:
|
||||
device = self.runner.device
|
||||
_, max_blocks = self.runner.graph_block_tables.shape
|
||||
block_table = torch.zeros((num_reqs, max_blocks),
|
||||
self,
|
||||
common_attn_metadata: TorchairCommonAttentionMetadata,
|
||||
) -> AscendMLAMetadata:
|
||||
device = self.device
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
block_table = torch.zeros((num_reqs, self.max_blocks),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
block_table = self._get_graph_runner_block_tables(
|
||||
num_reqs, block_table)
|
||||
num_tokens = num_reqs * self.runner.decode_token_per_req
|
||||
num_tokens = num_reqs * common_attn_metadata.decode_token_per_req
|
||||
seq_lens = torch.zeros(num_reqs, dtype=torch.int32, device=device)
|
||||
seq_lens_list = [0] * num_reqs
|
||||
input_positions = torch.zeros(num_tokens,
|
||||
@@ -333,16 +326,16 @@ class AscendMLAMetadataBuilder:
|
||||
1,
|
||||
1,
|
||||
self.rope_dim,
|
||||
dtype=self.runner.dtype,
|
||||
dtype=self.model_config.dtype,
|
||||
device=device)
|
||||
cos = torch.ones(num_tokens,
|
||||
1,
|
||||
1,
|
||||
self.rope_dim,
|
||||
dtype=self.runner.dtype,
|
||||
dtype=self.model_config.dtype,
|
||||
device=device)
|
||||
if self.runner.speculative_config is not None and\
|
||||
self.runner.speculative_config.method == 'deepseek_mtp':
|
||||
if self.vllm_config.speculative_config is not None and\
|
||||
self.vllm_config.speculative_config.method == 'deepseek_mtp':
|
||||
attn_state = AscendAttentionState.SpecDecoding
|
||||
num_decode_tokens = 2
|
||||
else:
|
||||
@@ -354,20 +347,21 @@ class AscendMLAMetadataBuilder:
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=1,
|
||||
attn_mask=self.runner.spec_attn_mask,
|
||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q[:num_reqs],
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=common_attn_metadata.
|
||||
actual_seq_lengths_q[:num_reqs],
|
||||
sin=sin,
|
||||
cos=cos,
|
||||
)
|
||||
return self.metadata_cls( # type: ignore
|
||||
num_input_tokens=num_actual_tokens,
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
num_input_tokens=common_attn_metadata.num_actual_tokens,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
num_decodes=1,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=0,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_state=attn_state,
|
||||
prefill=None,
|
||||
decode=decode_metadata,
|
||||
@@ -378,58 +372,68 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
def build(
|
||||
self,
|
||||
num_reqs: int,
|
||||
num_actual_tokens: int,
|
||||
max_query_len: int,
|
||||
graph_pad_size: int = -1,
|
||||
query_start_loc: torch.Tensor = None,
|
||||
enable_dbo_across_dp: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
common_attn_metadata: AscendCommonAttentionMetadata,
|
||||
model: nn.Module,
|
||||
) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_actual_tokens = common_attn_metadata.num_actual_tokens
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
if self.torchair_graph_enabled and common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
decode_threshold = common_attn_metadata.decode_token_per_req
|
||||
else:
|
||||
# TODO(xyx): remove the if condition after mla supports torch mode speculative decoding
|
||||
decode_threshold = 1
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(common_attn_metadata, decode_threshold=decode_threshold)
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
|
||||
|
||||
# Note(simon): be careful about the CPU <> GPU memory movement in this
|
||||
# function. We should avoid GPU -> CPU sync as much as possible because
|
||||
# it blocks on all previous kernels.
|
||||
device = self.runner.device
|
||||
device = self.device
|
||||
|
||||
block_table = (self.runner.input_batch.block_table[0].
|
||||
get_device_tensor()[:num_reqs])
|
||||
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True)
|
||||
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
|
||||
slot_mapping = common_attn_metadata.slot_mapping_cpu[:
|
||||
num_actual_tokens].to(
|
||||
device,
|
||||
non_blocking=
|
||||
True)
|
||||
input_positions = common_attn_metadata.positions[:
|
||||
num_actual_tokens].long(
|
||||
)
|
||||
|
||||
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
|
||||
query_lens = seq_lens_cpu - self.runner.input_batch.num_computed_tokens_cpu_tensor[:
|
||||
num_reqs]
|
||||
seq_lens = seq_lens_cpu
|
||||
max_query_len = query_lens.max().item()
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
if self.cos_cache is None:
|
||||
self.cos_cache = self.runner.get_model(
|
||||
).model.layers[0].self_attn.rotary_emb.cos_cached
|
||||
self.sin_cache = self.runner.get_model(
|
||||
).model.layers[0].self_attn.rotary_emb.sin_cached
|
||||
if self.cos_cache.dtype != self.runner.dtype: # type: ignore
|
||||
self.cos_cache = model.model.layers[
|
||||
0].self_attn.rotary_emb.cos_cached
|
||||
self.sin_cache = model.model.layers[
|
||||
0].self_attn.rotary_emb.sin_cached
|
||||
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
|
||||
self.cos_cache = self.cos_cache.to( # type: ignore
|
||||
self.runner.dtype) # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
self.sin_cache = self.sin_cache.to( # type: ignore
|
||||
self.runner.dtype) # type: ignore
|
||||
self.model_config.dtype) # type: ignore
|
||||
|
||||
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
|
||||
query_lens = query_seq_lens_cpu[:num_reqs]
|
||||
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
|
||||
num_computed_tokens_cpu = (seq_lens - query_lens)
|
||||
|
||||
prefill_metadata = None
|
||||
chunked_context_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
tokens_start = self._num_decode_tokens
|
||||
if num_prefills > 0:
|
||||
reqs_start = num_decodes # prefill_start
|
||||
tokens_start = num_decode_tokens
|
||||
max_query_len = query_lens[tokens_start:].max().item()
|
||||
max_seq_lens = seq_lens[tokens_start:].max().item()
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
context_lens_cpu = self.runner.input_batch.num_computed_tokens_cpu_tensor[
|
||||
reqs_start:num_reqs]
|
||||
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
|
||||
max_context_len_cpu = context_lens_cpu.max().item()
|
||||
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
|
||||
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
|
||||
@@ -441,12 +445,12 @@ class AscendMLAMetadataBuilder:
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
|
||||
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, self._num_prefills) * max_context_chunk
|
||||
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
cu_seq_lens_cpu = torch.zeros(num_chunks,
|
||||
self._num_prefills + 1,
|
||||
num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
pin_memory=True)
|
||||
torch.cumsum(chunk_seq_lens,
|
||||
@@ -470,7 +474,7 @@ class AscendMLAMetadataBuilder:
|
||||
prefill_input_positions].unsqueeze( # type: ignore
|
||||
1).unsqueeze(2)
|
||||
prefill_metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=self.runner.attn_mask,
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
query_lens=query_lens[tokens_start:],
|
||||
seq_lens=seq_lens,
|
||||
context_lens=seq_lens[tokens_start:],
|
||||
@@ -485,14 +489,15 @@ class AscendMLAMetadataBuilder:
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
graph_pad_size = common_attn_metadata.graph_pad_size
|
||||
use_torchair_graph = graph_pad_size != -1
|
||||
if self._num_decodes > 0:
|
||||
if num_decodes > 0:
|
||||
actual_seq_lengths_q = query_start_loc[1:].tolist()
|
||||
max_seq_lens = seq_lens[:self._num_decodes].max().item()
|
||||
seq_lens = seq_lens[:self._num_decode_tokens]
|
||||
input_positions = input_positions[:self._num_decode_tokens]
|
||||
block_table = block_table[:self._num_decode_tokens, ...]
|
||||
if use_torchair_graph and self.runner.attn_state in [
|
||||
max_seq_lens = seq_lens[:num_decodes].max().item()
|
||||
seq_lens = seq_lens[:num_decode_tokens]
|
||||
input_positions = input_positions[:num_decode_tokens]
|
||||
block_table = block_table[:num_decode_tokens, ...]
|
||||
if use_torchair_graph and common_attn_metadata.attn_state in [
|
||||
AscendAttentionState.DecodeOnly,
|
||||
AscendAttentionState.SpecDecoding
|
||||
]:
|
||||
@@ -500,10 +505,10 @@ class AscendMLAMetadataBuilder:
|
||||
num_token_pad_size = 0
|
||||
if graph_pad_size != 0:
|
||||
pad_value = 0
|
||||
num_token_pad_size = graph_pad_size - self._num_decode_tokens
|
||||
num_token_pad_size = graph_pad_size - num_decode_tokens
|
||||
num_reqs_pad_size = (
|
||||
graph_pad_size // self.runner.decode_token_per_req -
|
||||
num_reqs)
|
||||
graph_pad_size //
|
||||
common_attn_metadata.decode_token_per_req - num_reqs)
|
||||
padded_seq_lens = seq_lens.tolist(
|
||||
) + [pad_value] * num_reqs_pad_size
|
||||
else:
|
||||
@@ -531,14 +536,14 @@ class AscendMLAMetadataBuilder:
|
||||
input_positions = torch.cat(
|
||||
[input_positions, position_padding])
|
||||
actual_seq_lengths_q = query_start_loc[1:].tolist(
|
||||
) + self.runner.actual_seq_lengths_q[num_reqs:num_reqs +
|
||||
num_reqs_pad_size]
|
||||
) + common_attn_metadata.actual_seq_lengths_q[
|
||||
num_reqs:num_reqs + num_reqs_pad_size]
|
||||
else:
|
||||
seq_lens_list = seq_lens.tolist()
|
||||
# mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
|
||||
batch_size = slot_mapping.size(0)
|
||||
if actual_seq_lengths_q[-1] != batch_size \
|
||||
and self.runner.attn_state == AscendAttentionState.SpecDecoding:
|
||||
and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
|
||||
actual_seq_lengths_q[-1] = batch_size
|
||||
|
||||
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
|
||||
@@ -552,7 +557,7 @@ class AscendMLAMetadataBuilder:
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_list=seq_lens_list,
|
||||
max_seq_lens=max_seq_lens,
|
||||
attn_mask=self.runner.spec_attn_mask,
|
||||
attn_mask=common_attn_metadata.spec_attn_mask,
|
||||
actual_seq_lengths_q=actual_seq_lengths_q,
|
||||
sin=sin,
|
||||
cos=cos)
|
||||
@@ -561,18 +566,18 @@ class AscendMLAMetadataBuilder:
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
query_lens=query_lens.tolist(),
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
attn_mask=self.runner.attn_mask,
|
||||
attn_state=self.runner.attn_state,
|
||||
head_dim=self.model_config.get_head_size(),
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
attn_mask=common_attn_metadata.attn_mask,
|
||||
attn_state=common_attn_metadata.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
enable_dbo_across_dp=enable_dbo_across_dp,
|
||||
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user