[V1] clean up V1 code (#505)
Clean up V1 code: 1. remove useless code. 2. format code to be clear. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from vllm.attention import AttentionType
|
from vllm.attention import AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
@@ -36,11 +35,9 @@ from vllm.logger import logger
|
|||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
|
||||||
LayerBlockType, cdiv, is_pin_memory_available)
|
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
@@ -50,6 +47,7 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
|||||||
|
|
||||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||||
AscendMetadata)
|
AscendMetadata)
|
||||||
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@@ -60,61 +58,32 @@ NPU_PAGED_ATTENTION_MASK_VALUE = -10000
|
|||||||
class NPUModelRunner:
|
class NPUModelRunner:
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||||
|
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
self.model_config = vllm_config.model_config
|
self.model_config = vllm_config.model_config
|
||||||
self.cache_config = vllm_config.cache_config
|
|
||||||
self.lora_config = vllm_config.lora_config
|
self.lora_config = vllm_config.lora_config
|
||||||
self.load_config = vllm_config.load_config
|
|
||||||
self.parallel_config = vllm_config.parallel_config
|
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
self.speculative_config = vllm_config.speculative_config
|
|
||||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
|
||||||
self.observability_config = vllm_config.observability_config
|
|
||||||
|
|
||||||
model_config = self.model_config
|
|
||||||
cache_config = self.cache_config
|
|
||||||
scheduler_config = self.scheduler_config
|
|
||||||
parallel_config = self.parallel_config
|
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.pin_memory = is_pin_memory_available()
|
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||||
self.dtype = self.model_config.dtype
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||||
if cache_config.cache_dtype == "auto":
|
self.block_size)
|
||||||
self.kv_cache_dtype = self.dtype
|
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
else:
|
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
|
||||||
cache_config.cache_dtype]
|
|
||||||
|
|
||||||
self.is_multimodal_model = model_config.is_multimodal_model
|
|
||||||
self.sliding_window = model_config.get_sliding_window()
|
|
||||||
self.block_size = cache_config.block_size
|
|
||||||
self.max_model_len = model_config.max_model_len
|
|
||||||
self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
|
|
||||||
self.max_num_tokens = scheduler_config.max_num_batched_tokens
|
|
||||||
self.max_num_reqs = scheduler_config.max_num_seqs
|
|
||||||
|
|
||||||
# Model-related.
|
# Model-related.
|
||||||
self.num_attn_layers = model_config.get_num_layers_by_block_type(
|
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
||||||
parallel_config, LayerBlockType.attention)
|
vllm_config.parallel_config, LayerBlockType.attention)
|
||||||
self.num_query_heads = model_config.get_num_attention_heads(
|
self.hidden_size = self.model_config.get_hidden_size()
|
||||||
parallel_config)
|
|
||||||
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
|
||||||
self.head_size = model_config.get_head_size()
|
|
||||||
self.hidden_size = model_config.get_hidden_size()
|
|
||||||
|
|
||||||
# Multi-modal data support
|
# Multi-modal data support
|
||||||
self.input_registry = INPUT_REGISTRY
|
self.input_registry = INPUT_REGISTRY
|
||||||
self.mm_registry = MULTIMODAL_REGISTRY
|
self.mm_registry = MULTIMODAL_REGISTRY
|
||||||
self.uses_mrope = model_config.uses_mrope
|
self.uses_mrope = self.model_config.uses_mrope
|
||||||
|
|
||||||
encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
|
self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
|
||||||
model_config=model_config,
|
model_config=self.model_config,
|
||||||
scheduler_config=scheduler_config,
|
scheduler_config=self.scheduler_config,
|
||||||
mm_registry=self.mm_registry)
|
mm_registry=self.mm_registry)
|
||||||
self.max_num_encoder_input_tokens = encoder_compute_budget
|
|
||||||
self.encoder_cache_size = encoder_cache_size
|
|
||||||
|
|
||||||
# Lazy initialization
|
# Lazy initialization
|
||||||
# self.model: nn.Module # Set after load_model
|
# self.model: nn.Module # Set after load_model
|
||||||
@@ -122,19 +91,16 @@ class NPUModelRunner:
|
|||||||
# req_id -> (input_id -> encoder_output)
|
# req_id -> (input_id -> encoder_output)
|
||||||
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||||
|
|
||||||
# Set up speculative decoding.
|
|
||||||
self.use_spec_decode = False
|
|
||||||
|
|
||||||
# Request states.
|
# Request states.
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
# Persistent batch.
|
# Persistent batch.
|
||||||
self.input_batch = InputBatch(
|
self.input_batch = InputBatch(
|
||||||
max_num_reqs=self.max_num_reqs,
|
max_num_reqs=self.max_num_reqs,
|
||||||
max_model_len=self.max_model_len,
|
max_model_len=self.model_config.max_model_len,
|
||||||
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
max_num_blocks_per_req=self.max_num_blocks_per_req,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=True,
|
||||||
vocab_size=model_config.get_vocab_size(),
|
vocab_size=self.model_config.get_vocab_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||||
@@ -165,16 +131,17 @@ class NPUModelRunner:
|
|||||||
(3, self.max_num_tokens + 1),
|
(3, self.max_num_tokens + 1),
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=True)
|
||||||
|
|
||||||
self.inputs_embeds = torch.zeros(
|
self.inputs_embeds = torch.zeros(
|
||||||
(self.max_num_tokens, self.hidden_size),
|
(self.max_num_tokens, self.hidden_size),
|
||||||
dtype=self.dtype,
|
dtype=self.model_config.dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
|
|
||||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||||
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
|
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
|
||||||
self.max_num_reqs + 1, self.max_model_len, self.max_num_tokens),
|
self.max_num_reqs + 1, self.model_config.max_model_len,
|
||||||
|
self.max_num_tokens),
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
|
# NOTE(woosuk): These tensors are "stateless", i.e., they are literally
|
||||||
# a faster version of creating a new tensor every time. Thus, we should
|
# a faster version of creating a new tensor every time. Thus, we should
|
||||||
@@ -182,29 +149,23 @@ class NPUModelRunner:
|
|||||||
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
|
self.input_ids_cpu = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=True)
|
||||||
self.input_ids_np = self.input_ids_cpu.numpy()
|
|
||||||
self.positions_cpu = torch.zeros(self.max_num_tokens,
|
self.positions_cpu = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=True)
|
||||||
self.positions_np = self.positions_cpu.numpy()
|
self.positions_np = self.positions_cpu.numpy()
|
||||||
|
|
||||||
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
|
self.slot_mapping_cpu = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=True)
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
|
|
||||||
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device="cpu",
|
|
||||||
pin_memory=self.pin_memory)
|
|
||||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
|
||||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=self.pin_memory)
|
pin_memory=True)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
|
|
||||||
self.input_positions_cpu = torch.arange(0,
|
self.input_positions_cpu = torch.arange(0,
|
||||||
@@ -220,7 +181,8 @@ class NPUModelRunner:
|
|||||||
# Therefore, an environment variable is added here to dynamically set
|
# Therefore, an environment variable is added here to dynamically set
|
||||||
# the size of the pre-constructed mask matrix based on requirements.
|
# the size of the pre-constructed mask matrix based on requirements.
|
||||||
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
||||||
self.attn_mask_len = min(self.max_model_len, int(mask_len))
|
self.attn_mask_len = min(self.model_config.max_model_len,
|
||||||
|
int(mask_len))
|
||||||
self.attn_mask_npu = torch.full(
|
self.attn_mask_npu = torch.full(
|
||||||
(self.attn_mask_len, self.attn_mask_len),
|
(self.attn_mask_len, self.attn_mask_len),
|
||||||
NPU_PAGED_ATTENTION_MASK_VALUE,
|
NPU_PAGED_ATTENTION_MASK_VALUE,
|
||||||
@@ -384,8 +346,8 @@ class NPUModelRunner:
|
|||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def make_attention_mask(self, seq_lens, query_lens,
|
def _make_attention_mask(self, seq_lens, query_lens,
|
||||||
position) -> torch.Tensor:
|
position) -> torch.Tensor:
|
||||||
max_seq_len = max(seq_lens, default=0)
|
max_seq_len = max(seq_lens, default=0)
|
||||||
if max_seq_len <= self.attn_mask_len:
|
if max_seq_len <= self.attn_mask_len:
|
||||||
return torch.index_select(self.attn_mask_npu,
|
return torch.index_select(self.attn_mask_npu,
|
||||||
@@ -475,9 +437,9 @@ class NPUModelRunner:
|
|||||||
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
|
||||||
self.device, non_blocking=True)
|
self.device, non_blocking=True)
|
||||||
|
|
||||||
attn_mask = self.make_attention_mask(seq_lens=seq_lens,
|
attn_mask = self._make_attention_mask(seq_lens=seq_lens,
|
||||||
query_lens=num_scheduled_tokens,
|
query_lens=num_scheduled_tokens,
|
||||||
position=positions)
|
position=positions)
|
||||||
|
|
||||||
attn_metadata = AscendMetadata(
|
attn_metadata = AscendMetadata(
|
||||||
seq_lens=query_lens,
|
seq_lens=query_lens,
|
||||||
@@ -653,22 +615,19 @@ class NPUModelRunner:
|
|||||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _dummy_run(
|
def _dummy_run(self) -> torch.Tensor:
|
||||||
self,
|
|
||||||
num_tokens: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
model = self.model
|
model = self.model
|
||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
input_ids = None
|
input_ids = None
|
||||||
inputs_embeds = self.inputs_embeds[:num_tokens]
|
inputs_embeds = self.inputs_embeds[:self.max_num_tokens]
|
||||||
else:
|
else:
|
||||||
input_ids = self.input_ids[:num_tokens]
|
input_ids = self.input_ids[:self.max_num_tokens]
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
|
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
positions = self.mrope_positions[:, :num_tokens]
|
positions = self.mrope_positions[:, :self.max_num_tokens]
|
||||||
else:
|
else:
|
||||||
positions = self.input_positions_cpu[:num_tokens]
|
positions = self.input_positions_cpu[:self.max_num_tokens]
|
||||||
|
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
intermediate_tensors = None
|
intermediate_tensors = None
|
||||||
@@ -680,7 +639,7 @@ class NPUModelRunner:
|
|||||||
dtype=self.model_config.dtype,
|
dtype=self.model_config.dtype,
|
||||||
device=self.device))
|
device=self.device))
|
||||||
intermediate_tensors = IntermediateTensors({
|
intermediate_tensors = IntermediateTensors({
|
||||||
k: v[:num_tokens]
|
k: v[:self.max_num_tokens]
|
||||||
for k, v in self.intermediate_tensors.items()
|
for k, v in self.intermediate_tensors.items()
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -719,7 +678,7 @@ class NPUModelRunner:
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Trigger compilation for general shape.
|
# Trigger compilation for general shape.
|
||||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
hidden_states = self._dummy_run()
|
||||||
|
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
hidden_states = hidden_states[logit_indices]
|
hidden_states = hidden_states[logit_indices]
|
||||||
@@ -727,7 +686,7 @@ class NPUModelRunner:
|
|||||||
else:
|
else:
|
||||||
logits = None
|
logits = None
|
||||||
|
|
||||||
current_platform.synchronize()
|
NPUPlatform.synchronize()
|
||||||
del hidden_states, logits, dummy_kv_caches
|
del hidden_states, logits, dummy_kv_caches
|
||||||
self.encoder_cache.clear()
|
self.encoder_cache.clear()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -739,10 +698,8 @@ class NPUModelRunner:
|
|||||||
self.model = get_model(vllm_config=self.vllm_config)
|
self.model = get_model(vllm_config=self.vllm_config)
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
raise ValueError("LoRA model is not supported on NPU now.")
|
raise ValueError("LoRA model is not supported on NPU now.")
|
||||||
|
|
||||||
self.model_memory_usage = m.consumed_memory
|
|
||||||
logger.info("Loading model weights took %.4f GB",
|
logger.info("Loading model weights took %.4f GB",
|
||||||
self.model_memory_usage / float(2**30))
|
m.consumed_memory / float(2**30))
|
||||||
|
|
||||||
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -21,17 +21,15 @@ import gc
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import ParallelConfig, VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
init_distributed_environment,
|
init_distributed_environment,
|
||||||
set_custom_all_reduce)
|
set_custom_all_reduce)
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
@@ -40,17 +38,22 @@ from vllm.v1.outputs import ModelRunnerOutput
|
|||||||
from vllm.v1.utils import bind_kv_cache
|
from vllm.v1.utils import bind_kv_cache
|
||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
|
from vllm_ascend.platform import NPUPlatform
|
||||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||||
|
|
||||||
|
|
||||||
class NPUWorker(WorkerBase):
|
class NPUWorker(WorkerBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
vllm_config: VllmConfig,
|
self,
|
||||||
local_rank: int,
|
vllm_config: VllmConfig,
|
||||||
rank: int,
|
local_rank: int,
|
||||||
distributed_init_method: str,
|
rank: int,
|
||||||
is_driver_worker: bool = False):
|
distributed_init_method: str,
|
||||||
|
is_driver_worker: bool = False,
|
||||||
|
# Additional parameters for compatibility with vllm
|
||||||
|
**kwargs):
|
||||||
|
"""Initialize the worker for Ascend."""
|
||||||
# Register ops when worker init.
|
# Register ops when worker init.
|
||||||
from vllm_ascend import ops # noqa: F401
|
from vllm_ascend import ops # noqa: F401
|
||||||
|
|
||||||
@@ -59,19 +62,6 @@ class NPUWorker(WorkerBase):
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
distributed_init_method=distributed_init_method,
|
distributed_init_method=distributed_init_method,
|
||||||
is_driver_worker=is_driver_worker)
|
is_driver_worker=is_driver_worker)
|
||||||
|
|
||||||
self.vllm_config = vllm_config
|
|
||||||
self.model_config = vllm_config.model_config
|
|
||||||
self.cache_config = vllm_config.cache_config
|
|
||||||
self.lora_config = vllm_config.lora_config
|
|
||||||
self.load_config = vllm_config.load_config
|
|
||||||
self.parallel_config = vllm_config.parallel_config
|
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
|
||||||
self.device_config = vllm_config.device_config
|
|
||||||
self.speculative_config = vllm_config.speculative_config
|
|
||||||
self.prompt_adapter_config = vllm_config.prompt_adapter_config
|
|
||||||
self.observability_config = vllm_config.observability_config
|
|
||||||
|
|
||||||
if self.cache_config.cache_dtype == "auto":
|
if self.cache_config.cache_dtype == "auto":
|
||||||
self.cache_dtype = self.model_config.dtype
|
self.cache_dtype = self.model_config.dtype
|
||||||
else:
|
else:
|
||||||
@@ -82,53 +72,21 @@ class NPUWorker(WorkerBase):
|
|||||||
# note: lazy import to avoid importing torch before initializing
|
# note: lazy import to avoid importing torch before initializing
|
||||||
from vllm.utils import init_cached_hf_modules
|
from vllm.utils import init_cached_hf_modules
|
||||||
init_cached_hf_modules()
|
init_cached_hf_modules()
|
||||||
# Torch profiler. Enabled and configured through env vars:
|
|
||||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
|
||||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
|
||||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
|
||||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
|
||||||
torch_profiler_trace_dir)
|
|
||||||
|
|
||||||
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
self.profiler = self._init_profiler()
|
||||||
export_type=torch_npu.profiler.ExportType.Text,
|
|
||||||
profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
|
|
||||||
msprof_tx=False,
|
|
||||||
aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
|
|
||||||
l2_cache=False,
|
|
||||||
op_attr=False,
|
|
||||||
data_simplification=False,
|
|
||||||
record_op_args=False,
|
|
||||||
gc_detect_threshold=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.profiler = torch_npu.profiler.profile(
|
|
||||||
activities=[
|
|
||||||
torch_npu.profiler.ProfilerActivity.CPU,
|
|
||||||
torch_npu.profiler.ProfilerActivity.NPU,
|
|
||||||
],
|
|
||||||
with_stack=True,
|
|
||||||
profile_memory=True,
|
|
||||||
with_modules=True,
|
|
||||||
experimental_config=experimental_config,
|
|
||||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
|
||||||
torch_profiler_trace_dir))
|
|
||||||
else:
|
|
||||||
self.profiler = None
|
|
||||||
|
|
||||||
def init_device(self):
|
def init_device(self):
|
||||||
if self.device_config.device.type == "npu":
|
if self.device_config.device.type == "npu":
|
||||||
self.device = torch.device(f"npu:{self.local_rank}")
|
self.device = torch.device(f"npu:{self.local_rank}")
|
||||||
current_platform.set_device(self.device)
|
NPUPlatform.set_device(self.device)
|
||||||
|
NPUPlatform.empty_cache()
|
||||||
current_platform.empty_cache()
|
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||||
self.init_npu_memory = current_platform.mem_get_info()[0]
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(
|
info = f"Not support device type: {self.device_config.device}"
|
||||||
f"Not support device type: {self.device_config.device}")
|
logger.error(info)
|
||||||
|
raise RuntimeError(info)
|
||||||
# Initialize the distributed environment.
|
# Initialize the distributed environment.
|
||||||
init_worker_distributed_environment(self.parallel_config, self.rank,
|
self._init_worker_distributed_environment()
|
||||||
self.distributed_init_method,
|
|
||||||
self.local_rank)
|
|
||||||
# Set random seed.
|
# Set random seed.
|
||||||
set_random_seed(self.model_config.seed)
|
set_random_seed(self.model_config.seed)
|
||||||
|
|
||||||
@@ -140,14 +98,15 @@ class NPUWorker(WorkerBase):
|
|||||||
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
kv_cache_spec = self.model_runner.get_kv_cache_spec()
|
||||||
for layer_name, layer_spec in kv_cache_spec.items():
|
for layer_name, layer_spec in kv_cache_spec.items():
|
||||||
if isinstance(layer_spec, FullAttentionSpec):
|
if isinstance(layer_spec, FullAttentionSpec):
|
||||||
dtype = layer_spec.dtype
|
|
||||||
|
|
||||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||||
# it by reference, rather by specializing on the value ``None``.
|
# it by reference, rather by specializing on the value ``None``.
|
||||||
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
|
npu_k_cache = torch.tensor([],
|
||||||
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
|
dtype=layer_spec.dtype,
|
||||||
|
device=self.device)
|
||||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
npu_v_cache = torch.tensor([],
|
||||||
|
dtype=layer_spec.dtype,
|
||||||
|
device=self.device)
|
||||||
|
kv_caches[layer_name] = (npu_k_cache, npu_v_cache)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -159,7 +118,7 @@ class NPUWorker(WorkerBase):
|
|||||||
|
|
||||||
# Profile the memory usage of the model and get the maximum number of
|
# Profile the memory usage of the model and get the maximum number of
|
||||||
# cache blocks that can be allocated with the remaining free memory.
|
# cache blocks that can be allocated with the remaining free memory.
|
||||||
current_platform.empty_cache()
|
NPUPlatform.empty_cache()
|
||||||
|
|
||||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||||
# of the model.
|
# of the model.
|
||||||
@@ -167,7 +126,7 @@ class NPUWorker(WorkerBase):
|
|||||||
|
|
||||||
# Calculate the number of blocks that can be allocated with the
|
# Calculate the number of blocks that can be allocated with the
|
||||||
# profiled peak memory.
|
# profiled peak memory.
|
||||||
free_npu_memory, total_npu_memory = current_platform.mem_get_info()
|
free_npu_memory, total_npu_memory = NPUPlatform.mem_get_info()
|
||||||
# NOTE(woosuk): Here we assume that the other processes using the same
|
# NOTE(woosuk): Here we assume that the other processes using the same
|
||||||
# GPU did not change their memory usage during the profiling.
|
# GPU did not change their memory usage during the profiling.
|
||||||
peak_memory = self.init_npu_memory - free_npu_memory
|
peak_memory = self.init_npu_memory - free_npu_memory
|
||||||
@@ -180,7 +139,7 @@ class NPUWorker(WorkerBase):
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
# TODO: don`t need impl this func after empty_cache in
|
# TODO: don`t need impl this func after empty_cache in
|
||||||
# Worker.determine_num_available_blocks() unified`
|
# Worker.determine_num_available_blocks() unified`
|
||||||
current_platform.empty_cache()
|
NPUPlatform.empty_cache()
|
||||||
usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory
|
usable_memory_size = total_npu_memory * self.cache_config.gpu_memory_utilization - peak_memory
|
||||||
npu_kv_cache_bytes = max(usable_memory_size, 0)
|
npu_kv_cache_bytes = max(usable_memory_size, 0)
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -228,17 +187,47 @@ class NPUWorker(WorkerBase):
|
|||||||
else:
|
else:
|
||||||
self.profiler.stop()
|
self.profiler.stop()
|
||||||
|
|
||||||
|
def _init_worker_distributed_environment(self) -> None:
|
||||||
|
"""Initialize the distributed environment."""
|
||||||
|
set_custom_all_reduce(
|
||||||
|
not self.parallel_config.disable_custom_all_reduce)
|
||||||
|
init_distributed_environment(self.parallel_config.world_size,
|
||||||
|
self.rank, self.distributed_init_method,
|
||||||
|
self.local_rank, "hccl")
|
||||||
|
ensure_model_parallel_initialized(
|
||||||
|
self.parallel_config.tensor_parallel_size,
|
||||||
|
self.parallel_config.pipeline_parallel_size)
|
||||||
|
|
||||||
def init_worker_distributed_environment(
|
def _init_profiler(self):
|
||||||
parallel_config: ParallelConfig,
|
# Torch profiler. Enabled and configured through env vars:
|
||||||
rank: int,
|
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||||
distributed_init_method: Optional[str] = None,
|
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||||
local_rank: int = -1) -> None:
|
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||||
"""Initialize the distributed environment."""
|
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
torch_profiler_trace_dir)
|
||||||
|
|
||||||
init_distributed_environment(parallel_config.world_size, rank,
|
experimental_config = torch_npu.profiler._ExperimentalConfig(
|
||||||
distributed_init_method, local_rank, "hccl")
|
export_type=torch_npu.profiler.ExportType.Text,
|
||||||
|
profiler_level=torch_npu.profiler.ProfilerLevel.Level0,
|
||||||
|
msprof_tx=False,
|
||||||
|
aic_metrics=torch_npu.profiler.AiCMetrics.AiCoreNone,
|
||||||
|
l2_cache=False,
|
||||||
|
op_attr=False,
|
||||||
|
data_simplification=False,
|
||||||
|
record_op_args=False,
|
||||||
|
gc_detect_threshold=None,
|
||||||
|
)
|
||||||
|
|
||||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
return torch_npu.profiler.profile(
|
||||||
parallel_config.pipeline_parallel_size)
|
activities=[
|
||||||
|
torch_npu.profiler.ProfilerActivity.CPU,
|
||||||
|
torch_npu.profiler.ProfilerActivity.NPU,
|
||||||
|
],
|
||||||
|
with_stack=True,
|
||||||
|
profile_memory=True,
|
||||||
|
with_modules=True,
|
||||||
|
experimental_config=experimental_config,
|
||||||
|
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
||||||
|
torch_profiler_trace_dir))
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user