[V1] clean up V1 code (#505)

Clean up V1 code:
1. Remove unused code.
2. Reformat code for clarity.

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-04-15 10:24:02 +08:00
Committed by: GitHub
Parent: f6af1d2471
Commit: c7f6584d75
2 changed files with 113 additions and 167 deletions
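
At a glance, the cleanup follows two patterns that repeat throughout the hunks below: cached aliases of config objects (and attributes derived from them) are dropped in favor of reading self.model_config / self.scheduler_config at the point of use, and the CPU staging tensors hard-code pin_memory=True instead of consulting is_pin_memory_available(). A minimal before/after sketch with hypothetical classes, not an excerpt from this diff:

# Sketch only: hypothetical ModelConfig/Runner classes illustrating the
# "drop the cached alias, read the config directly" cleanup.
from dataclasses import dataclass


@dataclass
class ModelConfig:
    max_model_len: int = 2048
    dtype: str = "float16"


class RunnerBefore:
    def __init__(self, model_config: ModelConfig):
        self.model_config = model_config
        # Redundant copies that can drift from their source of truth.
        self.max_model_len = model_config.max_model_len
        self.dtype = model_config.dtype


class RunnerAfter:
    def __init__(self, model_config: ModelConfig):
        # Keep only the config object; derive values where they are used.
        self.model_config = model_config


print(RunnerAfter(ModelConfig()).model_config.max_model_len)  # 2048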


@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
import numpy as np
import numpy.typing as npt
import torch
import torch.distributed
import torch.nn as nn
from vllm.attention import AttentionType
from vllm.attention.layer import Attention
@@ -36,11 +35,9 @@ from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
@@ -50,6 +47,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendMetadata)
from vllm_ascend.platform import NPUPlatform
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
@@ -60,61 +58,32 @@ NPU_PAGED_ATTENTION_MASK_VALUE = -10000
class NPUModelRunner:
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.prompt_adapter_config = vllm_config.prompt_adapter_config
self.observability_config = vllm_config.observability_config
model_config = self.model_config
cache_config = self.cache_config
scheduler_config = self.scheduler_config
parallel_config = self.parallel_config
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
cache_config.cache_dtype]
self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.max_model_len = model_config.max_model_len
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
self.is_multimodal_model = self.model_config.is_multimodal_model
self.block_size = vllm_config.cache_config.block_size
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
self.block_size)
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention)
self.num_query_heads = model_config.get_num_attention_heads(
parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size()
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
vllm_config.parallel_config, LayerBlockType.attention)
self.hidden_size = self.model_config.get_hidden_size()
# Multi-modal data support
self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY
self.uses_mrope = model_config.uses_mrope
self.uses_mrope = self.model_config.uses_mrope
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
model_config=model_config,
scheduler_config=scheduler_config,
self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
mm_registry=self.mm_registry)
self.max_num_encoder_input_tokens = encoder_compute_budget
self.encoder_cache_size = encoder_cache_size
# Lazy initialization
# self.model: nn.Module # Set after load_model
@@ -122,19 +91,16 @@ class NPUModelRunner:
# req_id -> (input_id -> encoder_output)
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
# Set up speculative decoding.
self.use_spec_decode = False
# Request states.
self.requests: Dict[str, CachedRequestState] = {}
# Persistent batch.
self.input_batch = InputBatch(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_model_len=self.model_config.max_model_len,
max_num_blocks_per_req=self.max_num_blocks_per_req,
device=self.device,
pin_memory=self.pin_memory,
vocab_size=model_config.get_vocab_size(),
pin_memory=True,
vocab_size=self.model_config.get_vocab_size(),
)
self.input_ids = torch.zeros(self.max_num_tokens,
@@ -165,16 +131,17 @@ class NPUModelRunner:
(3, self.max_num_tokens + 1),
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
pin_memory=True)
self.inputs_embeds = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
dtype=self.model_config.dtype,
device=self.device)
# OPTIMIZATION: Cache the tensors rather than creating them every step.
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
self.max_num_reqs + 1, self.model_config.max_model_len,
self.max_num_tokens),
dtype=np.int32)
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
# a faster version of creating a new tensor every time. Thus, we should
@@ -182,29 +149,23 @@ class NPUModelRunner:
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.input_ids_np = self.input_ids_cpu.numpy()
pin_memory=True)
self.positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
pin_memory=True)
self.positions_np = self.positions_cpu.numpy()
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
pin_memory=True)
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory)
pin_memory=True)
self.seq_lens_np = self.seq_lens_cpu.numpy()
self.input_positions_cpu = torch.arange(0,
@@ -220,7 +181,8 @@ class NPUModelRunner:
# Therefore, an environment variable is added here to dynamically set
# the size of the pre-constructed mask matrix based on requirements.
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
self.attn_mask_len = min(self.max_model_len, int(mask_len))
self.attn_mask_len = min(self.model_config.max_model_len,
int(mask_len))
self.attn_mask_npu = torch.full(
(self.attn_mask_len, self.attn_mask_len),
NPU_PAGED_ATTENTION_MASK_VALUE,
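
For reference, a standalone sketch of how the PAGED_ATTENTION_MASK_LEN override shown above takes effect; the 32768 model length is an assumed value, not taken from this diff:

import os

# Sketch only: mirrors the capping logic in the hunk above.
max_model_len = 32768  # assumed value for illustration
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)  # same default as the diff
attn_mask_len = min(max_model_len, int(mask_len))
print(attn_mask_len)  # 10000 unless the environment variable overrides it

# Example invocation: cap the pre-built mask at 4096 rows.
#   PAGED_ATTENTION_MASK_LEN=4096 python run_model.py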
@@ -384,8 +346,8 @@ class NPUModelRunner:
def get_model(self) -> nn.Module:
return self.model
def make_attention_mask(self, seq_lens, query_lens,
position) -> torch.Tensor:
def _make_attention_mask(self, seq_lens, query_lens,
position) -> torch.Tensor:
max_seq_len = max(seq_lens, default=0)
if max_seq_len <= self.attn_mask_len:
return torch.index_select(self.attn_mask_npu,
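
A minimal sketch of the precomputed-mask fast path used by _make_attention_mask above: when the longest sequence fits inside the prebuilt mask, rows are selected from it instead of rebuilding a mask each step. The triangular layout and shapes here are assumptions for illustration, not copied from the implementation:

import torch

MASK_VALUE = -10000  # matches NPU_PAGED_ATTENTION_MASK_VALUE in the diff
attn_mask_len = 8

# Assumed layout: additive causal mask, large negative above the diagonal.
attn_mask = torch.triu(
    torch.full((attn_mask_len, attn_mask_len), float(MASK_VALUE)),
    diagonal=1)

# For tokens currently at positions 3 and 5, pick the matching rows.
position = torch.tensor([3, 5])
step_mask = torch.index_select(attn_mask, 0, position)
print(step_mask.shape)  # torch.Size([2, 8])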
@@ -475,9 +437,9 @@ class NPUModelRunner:
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True)
attn_mask = self.make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions)
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
query_lens=num_scheduled_tokens,
position=positions)
attn_metadata = AscendMetadata(
seq_lens=query_lens,
@@ -653,22 +615,19 @@ class NPUModelRunner:
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
@torch.inference_mode()
def _dummy_run(
self,
num_tokens: int,
) -> torch.Tensor:
def _dummy_run(self) -> torch.Tensor:
model = self.model
if self.is_multimodal_model:
input_ids = None
inputs_embeds = self.inputs_embeds[:num_tokens]
inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
else:
input_ids = self.input_ids[:num_tokens]
input_ids = self.input_ids[:self.max_num_tokens]
inputs_embeds = None
if self.uses_mrope:
positions = self.mrope_positions[:, :num_tokens]
positions = self.mrope_positions[:, :self.max_num_tokens]
else:
positions = self.input_positions_cpu[:num_tokens]
positions = self.input_positions_cpu[:self.max_num_tokens]
if get_pp_group().is_first_rank:
intermediate_tensors = None
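
The hunk above drops the num_tokens parameter: the warm-up pass now always slices the persistent buffers at self.max_num_tokens, so the call site no longer threads a size through. A tiny sketch of that pattern with a hypothetical runner class, not the NPUModelRunner API:

import torch

class WarmupSketch:
    # Hypothetical stand-in used only to illustrate the simplification.
    def __init__(self, max_num_tokens: int):
        self.max_num_tokens = max_num_tokens
        self.input_ids = torch.zeros(max_num_tokens, dtype=torch.int32)

    def _dummy_run(self) -> torch.Tensor:
        # Always warm up at the largest shape the runner will ever execute.
        return self.input_ids[:self.max_num_tokens]

print(WarmupSketch(max_num_tokens=8)._dummy_run().shape)  # torch.Size([8])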
@@ -680,7 +639,7 @@ class NPUModelRunner:
dtype=self.model_config.dtype,
device=self.device))
intermediate_tensors = IntermediateTensors({
k: v[:num_tokens]
k: v[:self.max_num_tokens]
for k, v in self.intermediate_tensors.items()
})
@@ -719,7 +678,7 @@ class NPUModelRunner:
]
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.max_num_tokens)
hidden_states = self._dummy_run()
if get_pp_group().is_last_rank:
hidden_states = hidden_states[logit_indices]
@@ -727,7 +686,7 @@ class NPUModelRunner:
else:
logits = None
current_platform.synchronize()
NPUPlatform.synchronize()
del hidden_states, logits, dummy_kv_caches
self.encoder_cache.clear()
gc.collect()
@@ -739,10 +698,8 @@ class NPUModelRunner:
self.model = get_model(vllm_config=self.vllm_config)
if self.lora_config:
raise ValueError("LoRA model is not supported on NPU now.")
self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
m.consumed_memory / float(2**30))
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""