540 lines
23 KiB
Python
540 lines
23 KiB
Python
|
|
"""A VACC worker class."""
|
||
|
|
import gc
|
||
|
|
import os
|
||
|
|
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch.distributed
|
||
|
|
|
||
|
|
import vllm.envs as envs
|
||
|
|
from vllm.config import VllmConfig
|
||
|
|
from vllm.device_allocator.cumem import CuMemAllocator
|
||
|
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||
|
|
init_distributed_environment,
|
||
|
|
set_custom_all_reduce)
|
||
|
|
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
from vllm.lora.request import LoRARequest
|
||
|
|
from vllm.model_executor import set_random_seed
|
||
|
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||
|
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||
|
|
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
|
||
|
|
SequenceGroupMetadata, SequenceGroupMetadataDelta)
|
||
|
|
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
|
||
|
|
memory_profiling)
|
||
|
|
from vllm.worker.cache_engine import CacheEngine
|
||
|
|
from vllm_vacc.vllm.worker.vacc_model_runner import VACCModelRunner,ModelRunnerBase
|
||
|
|
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||
|
|
WorkerInput)
|
||
|
|
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
|
||
|
|
from vllm.worker.pooling_model_runner import PoolingModelRunner
|
||
|
|
|
||
|
|
logger = init_logger(__name__)

# BLOCK_GROUP_SIZE from the VACC model vars; used below to size the
# deepseek_v3 KV-cache block limits in determine_num_available_blocks.
from vllm_vacc.vllm.model_executor.models.vars import BLOCK_GROUP_SIZE as env_blk_grp_size

# Default tensor-parallel group id. generate_tp_group_id() overwrites this
# module-level value with a signed 32-bit id derived from `<cwd>/.bootinfos`,
# and generate_rank_info_list() passes it to the TP group's generate_group_id().
TP_GROUP_ID = 1234
|
||
|
|
class VACCWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes the model on a group of VACC cores.

    Each worker is bound to the device ``vacc:<local_rank>``, owns one model
    runner and, after ``initialize_cache``, one ``CacheEngine`` per
    pipeline-parallel stage.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[ModelRunnerBase]] = None,
    ) -> None:
        """Create the worker and its model runner.

        Args:
            vllm_config: Global vLLM configuration.
            local_rank: Index of the VACC device used by this worker.
            rank: Global rank of this worker.
            distributed_init_method: Passed to
                ``init_distributed_environment`` in ``init_device``.
            is_driver_worker: Whether this worker is the driver worker.
            model_runner_cls: Accepted for interface compatibility but
                currently ignored; the runner class is always chosen from
                the model config below.
        """
        WorkerBase.__init__(self, vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Return hidden states from target model if the draft model is an
        # mlp_speculator (or another draft type that consumes them).
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {} if speculative_config is None \
            or (speculative_config.draft_model_config.hf_config.model_type ==
                model_config.hf_config.model_type) \
            or (speculative_config.draft_model_config.hf_config.model_type
                not in ["medusa", "mlp_speculator", "eagle", "deepseek_mtp"]) \
            else {"return_hidden_states": True}

        ModelRunnerClass: Type[ModelRunnerBase] = VACCModelRunner
        if model_config.runner_type == "pooling":
            ModelRunnerClass = PoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = EncoderDecoderModelRunner
        # NOTE: `model_runner_cls` is not consulted here; the runner class is
        # selected purely from the model config on VACC.

        self.model_runner: ModelRunnerBase = ModelRunnerClass(
            vllm_config=self.vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        # Initialize gpu_cache as embedding models don't initialize kv_caches
        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
        # request_id -> full metadata, rebuilt from deltas sent by scheduler.
        self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {}

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.PrivateUse1,  # vacc
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        """Start the torch profiler; raises RuntimeError if not enabled."""
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        """Stop the torch profiler; raises RuntimeError if not enabled."""
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def sleep(self, level: int = 1) -> None:
        """Release device memory through the CuMem allocator.

        Level 1 offloads memory tagged "weights"; any other level offloads
        nothing (empty tag tuple). Logs the freed and remaining memory.
        """
        free_bytes_before_sleep = torch.vacc.mem_get_info()[0]
        allocator = CuMemAllocator.get_instance()
        allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
        free_bytes_after_sleep, total = torch.vacc.mem_get_info()
        freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
        used_bytes = total - free_bytes_after_sleep
        assert freed_bytes >= 0, "Memory usage increased after sleeping."
        logger.info(
            "Sleep mode freed %.2f GiB memory, "
            "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
            used_bytes / GiB_bytes)

    def wake_up(self) -> None:
        """Restore memory released by ``sleep``."""
        allocator = CuMemAllocator.get_instance()
        allocator.wake_up()

    def init_device(self) -> None:
        """Bind this worker to ``vacc:<local_rank>`` and set up distributed.

        Raises:
            RuntimeError: if the configured device type is not "vacc" or the
                device cannot be initialized.
        """
        if self.device_config.device.type != "vacc":
            raise RuntimeError(
                f"Not support device type: {self.device_config.device}")

        # Build the name up front: if torch.device() itself fails,
        # self.device does not exist yet and cannot be used in the message.
        device_name = f"vacc:{self.local_rank}"
        try:
            self.device = torch.device(device_name)
            torch.vacc.set_device(self.device)
            gc.collect()
            torch.vacc.empty_cache()
        except Exception as e:
            raise RuntimeError(
                f"device init fail: {e}, self.device: {device_name}, "
                "check /dev/* or VACC_VISIBLE_DEVICES") from e

        # Initialize the distributed environment.
        init_worker_distributed_environment(self.vllm_config, self.rank,
                                            self.distributed_init_method,
                                            self.local_rank)
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self.model_runner.save_sharded_state(
            path,
            pattern=pattern,
            max_size=max_size,
        )

    def save_tensorized_model(
        self,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        self.model_runner.save_tensorized_model(
            tensorizer_config=tensorizer_config, )

    def _profile_available_kv_cache_memory(self) -> float:
        """Measure the KV-cache budget (bytes) with a profiling forward pass.

        Returns ``total_memory * gpu_memory_utilization - peak_memory``, where
        the peak includes memory not tracked by the torch allocator.
        """
        torch.vacc.empty_cache()
        torch.vacc.reset_peak_memory_stats()
        total_memory = torch.vacc.mem_get_info()[1]
        self.model_runner.profile_run()
        torch.vacc.synchronize()
        peak_memory = torch.vacc.max_memory_allocated()
        torch.vacc.empty_cache()
        torch_allocated_bytes = torch.vacc.memory_stats(
        )["allocated_bytes.all.current"]
        free_bytes, total_bytes = torch.vacc.mem_get_info()
        total_allocated_bytes = total_bytes - free_bytes
        non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
        if non_torch_allocations > 0:
            peak_memory += non_torch_allocations
        return (total_memory * self.cache_config.gpu_memory_utilization -
                peak_memory)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        The device KV-cache budget is ``VLLM_VACC_KVCACHE_SPACE`` GiB
        (default 16). If that is 0, the budget is measured with a profiling
        run instead.

        For ``deepseek_v3`` the device block count is clamped to at least
        ``(MAX_SEQ_NUM + 1) * BLOCK_GROUP_SIZE // block_size`` and at most the
        blocks needed for ``max_model_len`` rounded up to BLOCK_GROUP_SIZE.
        Other model types have no extra bounds.

        Returns:
            ``(num_gpu_blocks * pipeline_parallel_size, num_cpu_blocks)``;
            ``num_cpu_blocks`` is sized from ``swap_space_bytes``.
        """
        available_kv_cache_memory = int(
            os.getenv("VLLM_VACC_KVCACHE_SPACE", "16")) * GiB_bytes
        max_seq_num = int(os.getenv("MAX_SEQ_NUM", 4))
        # Both bounds default to "none" (0) so non-deepseek models work;
        # previously the lower bound was only defined in the deepseek branch.
        max_num_gpu_blocks = 0
        minimum_num_gpu_blocks_required = 0

        if available_kv_cache_memory == 0:
            available_kv_cache_memory = (
                self._profile_available_kv_cache_memory())

        if self.model_config.hf_config.model_type == "deepseek_v3":
            assert self.model_config.max_model_len <= 65536, f"unsupported max model len, should less equal 65536 but got {self.model_config.max_model_len}"
            block_size = self.cache_config.block_size
            # Rules:
            # 1. always reserve N * BLOCK_GROUP_SIZE blocks
            # 2. no less than (MAX_SEQ_NUM + 1) * BLOCK_GROUP_SIZE blocks
            minimum_num_gpu_blocks_required = (
                (max_seq_num + 1) * env_blk_grp_size // block_size)
            max_model_len = ((self.model_config.max_model_len +
                              env_blk_grp_size - 1) // env_blk_grp_size *
                             env_blk_grp_size)
            max_num_gpu_blocks = max_model_len // block_size

        # limited by available_kv_cache_memory
        cache_block_size = self.get_cache_block_size_bytes()
        if cache_block_size == 0:
            num_gpu_blocks = 0
            num_cpu_blocks = 0
        else:
            num_gpu_blocks = int(available_kv_cache_memory // cache_block_size)
            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                                 cache_block_size)
            assert num_gpu_blocks >= minimum_num_gpu_blocks_required, \
                f"num_gpu_blocks should >= {minimum_num_gpu_blocks_required} please increase VLLM_VACC_KVCACHE_SPACE"
        torch.vacc.empty_cache()
        if self.model_runner.lora_manager:
            self.model_runner.remove_all_loras()
        gc.collect()
        if max_num_gpu_blocks != 0:
            num_gpu_blocks = min(max_num_gpu_blocks, num_gpu_blocks)

        num_gpu_blocks = max(num_gpu_blocks, minimum_num_gpu_blocks_required)
        return (num_gpu_blocks * self.parallel_config.pipeline_parallel_size,
                num_cpu_blocks)

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Validate the block counts, allocate the KV cache and warm up."""
        raise_if_cache_size_invalid(num_gpu_blocks,
                                    self.cache_config.block_size,
                                    self.cache_config.is_attention_free,
                                    self.model_config.max_model_len)

        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks

        self._init_cache_engine()
        self._warm_up_model()

    def _init_cache_engine(self) -> None:
        """Create one CacheEngine per pipeline stage and bind its caches."""
        pp_size = self.parallel_config.pipeline_parallel_size
        self.cache_engine = [
            CacheEngine(self.cache_config, self.model_config,
                        self.parallel_config, self.device_config)
            for _ in range(pp_size)
        ]
        self.gpu_cache = [
            self.cache_engine[ve].gpu_cache for ve in range(pp_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.gpu_cache)

        self.model_runner.block_size = self.cache_engine[0].block_size
        assert all(self.gpu_cache[ve] is not None for ve in range(pp_size))

        # Populate the cache to warmup the memory
        for ve in range(pp_size):
            for layer_cache in self.gpu_cache[ve]:
                layer_cache.fill_(0)

    def _warm_up_model(self) -> None:
        if not self.model_config.enforce_eager:
            # Graph capture (model_runner.capture_model) is disabled here, so
            # the model runs eagerly regardless of this setting.
            logger.info("VACC is not support model_config.enforce_eager = %s",
                        self.model_config.enforce_eager)
        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)

    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return self.gpu_cache

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """Convert the request's block operations into tensors."""
        virtual_engine = execute_model_req.virtual_engine
        num_steps = execute_model_req.num_steps
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
        # they contain parameters to launch vaccmemcpyasync.
        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
                                         device="cpu",
                                         dtype=torch.int64).view(-1, 2)
        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
                                          device="cpu",
                                          dtype=torch.int64).view(-1, 2)
        # `blocks_to_copy` is a device tensor. The src and tgt of
        # blocks to copy are in the same device, and `blocks_to_copy`
        # can be used directly within vacc kernels.
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device=self.device,
                                      dtype=torch.int64).view(-1, 2)

        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
            num_steps=num_steps,
        )

    @torch.inference_mode()
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """Apply swap-in, swap-out and copy operations to the KV cache."""
        virtual_engine = worker_input.virtual_engine
        # Issue cache operations.
        if (worker_input.blocks_to_swap_in is not None
                and worker_input.blocks_to_swap_in.numel() > 0):
            self.cache_engine[virtual_engine].swap_in(
                worker_input.blocks_to_swap_in)
        if (worker_input.blocks_to_swap_out is not None
                and worker_input.blocks_to_swap_out.numel() > 0):
            self.cache_engine[virtual_engine].swap_out(
                worker_input.blocks_to_swap_out)
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy)

    def _get_cached_seq_group_metadata(
            self,
            seq_group_metadata_list: List[Union[SequenceGroupMetadata,
                                                SequenceGroupMetadataDelta]],
            finished_request_ids: List[str]) -> List[SequenceGroupMetadata]:
        """Return a list of cached Sequence Group Metadata after updating its
        state.

        It is used because scheduler only sends delta to workers to reduce
        the data payload size. The function also cleans up cache based on
        a given `finished_request_ids`.
        """
        new_seq_group_metadata_list = []
        for metadata_or_delta in seq_group_metadata_list:
            request_id = metadata_or_delta.request_id
            if request_id not in self._seq_group_metadata_cache:
                # The first prefill.
                assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                self._seq_group_metadata_cache[request_id] = metadata_or_delta
            else:
                # The first prefill is already cached.
                if isinstance(metadata_or_delta, SequenceGroupMetadataDelta):
                    self._seq_group_metadata_cache[request_id].apply_delta(
                        metadata_or_delta)
                else:
                    # If metadata snapshot is sent again, it is
                    # preempted. Reset the cache because we need to start
                    # from scratch.
                    assert isinstance(metadata_or_delta, SequenceGroupMetadata)
                    self._seq_group_metadata_cache[
                        request_id] = metadata_or_delta

            new_seq_group_metadata_list.append(
                self._seq_group_metadata_cache[request_id])

        # Clean up finished ids. A request can finish (e.g. be aborted)
        # before this worker ever cached it, so tolerate missing ids.
        for finished_id in finished_request_ids:
            self._seq_group_metadata_cache.pop(finished_id, None)

        return new_seq_group_metadata_list

    def _execute_model_spmd(
        self,
        execute_model_req: Optional[ExecuteModelRequest],
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Optional[List[SamplerOutput]]:
        """Expand metadata deltas into full metadata, then run the model."""
        if execute_model_req is not None:
            new_seq_group_metadata_list = self._get_cached_seq_group_metadata(
                execute_model_req.seq_group_metadata_list,
                execute_model_req.finished_requests_ids)

            execute_model_req.seq_group_metadata_list = (
                new_seq_group_metadata_list)
        output = super()._execute_model_spmd(execute_model_req,
                                             intermediate_tensors)
        return output

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        # Delegate to the prompt-adapter API (as pin/list do); calling
        # remove_lora here would remove a LoRA that happens to share the id.
        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        return self.model_runner.list_prompt_adapters()

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    def get_cache_block_size_bytes(self) -> int:
        """Get the size of the KV cache block size in bytes.
        """
        return CacheEngine.get_cache_block_size(self.cache_config,
                                                self.model_config,
                                                self.parallel_config)
|
||
|
|
|
||
|
|
def collect_comm_op_infos(rank):
    """Issue a tagged all-gather over the TP device group.

    Each rank contributes one random float32 element on ``<device_type>:<rank>``
    and receives one element per group member. The options carry a negative
    timeout of -1111 seconds, which the original code marks as "use for
    collect tags" — presumably a sentinel interpreted by the VCCL backend
    rather than a real timeout (TODO confirm).
    """
    from vllm.distributed import get_tp_group

    group = get_tp_group().device_group
    target = f'{group._device_types[0].type}:{rank}'
    member_count = group.size()

    outgoing = [torch.rand((1,), dtype=torch.float32).to(target)]
    incoming = [[
        torch.rand(1, dtype=torch.float32).to(target)
        for _ in range(member_count)
    ]]

    import datetime
    options = torch.distributed.distributed_c10d.AllgatherOptions()
    # seconds -1111 use for collect tags
    options.timeout = datetime.timedelta(seconds=-1111)
    group.allgather(incoming, outgoing, opts=options)
|
||
|
|
|
||
|
|
def generate_rank_info_list():
    """Have the TP group generate its rank/device infos and its group id.

    The group id passed in is the module-level TP_GROUP_ID (normally set
    beforehand by generate_tp_group_id()).
    """
    from vllm.distributed import get_tp_group

    tp_group = get_tp_group()
    tp_group.generate_rank_device_infos()
    tp_group.generate_group_id(TP_GROUP_ID)
|
||
|
|
|
||
|
|
def generate_tp_group_id():
    """Set the module-level TP_GROUP_ID from `<cwd>/.bootinfos`.

    The first line of that file (newline included), or the string "default"
    when the file is missing or cannot be read, is hashed with
    ``uuid.uuid5(NAMESPACE_URL, ...)``. The low 32 bits of the hash are
    reinterpreted as a signed int32 and stored in TP_GROUP_ID.
    """
    global TP_GROUP_ID
    from pathlib import Path
    import uuid

    boot_file = f'{Path.cwd()}/.bootinfos'
    seed_text = "default"
    if os.path.exists(boot_file):
        try:
            with open(boot_file) as reader:
                seed_text = reader.readline()
        except Exception as e:
            print("[WARN] bootinfo load fail ", e)

    low32 = uuid.uuid5(uuid.NAMESPACE_URL, seed_text).int & 0xFFFFFFFF
    # Two's-complement reinterpretation: map [2**31, 2**32) to negatives.
    TP_GROUP_ID = low32 - 2**32 if low32 >= 2**31 else low32
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
|
||
|
|
def init_worker_distributed_environment(
    vllm_config: VllmConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    local_rank: int = -1,
    backend: str = "vccl",
) -> None:
    """Initialize the distributed environment.

    Configures custom all-reduce, initializes the process world with
    ``backend`` (default "vccl"), the tensor/pipeline model-parallel groups
    and KV transfer, waits on a TP-group barrier, and finally derives the TP
    group id and rank/device infos.
    """
    from vllm.distributed import get_tp_group

    par_cfg = vllm_config.parallel_config
    set_custom_all_reduce(not par_cfg.disable_custom_all_reduce)

    init_distributed_environment(par_cfg.world_size, rank,
                                 distributed_init_method, local_rank, backend)
    ensure_model_parallel_initialized(par_cfg.tensor_parallel_size,
                                      par_cfg.pipeline_parallel_size)
    ensure_kv_transfer_initialized(vllm_config)

    # Ensure every TP rank has finished initialization before continuing.
    get_tp_group().barrier()

    # VCCL comm-info collection for fused VNNL/VCCL ops is currently
    # disabled: collect_comm_op_infos(rank)
    generate_tp_group_id()
    generate_rank_info_list()
|
||
|
|
|
||
|
|
|
||
|
|
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
                                max_model_len) -> None:
    """Validate the number of device KV-cache blocks.

    Validation is skipped entirely when the environment variable
    ``VLLM_VACC_KVCACHE_SPACE`` is not set.

    Args:
        num_gpu_blocks: Number of device KV-cache blocks allocated.
        block_size: Tokens per block.
        is_attention_free: Whether the model has no attention (and so needs
            no KV cache).
        max_model_len: Maximum sequence length the model must support.

    Raises:
        ValueError: if an attention-free model has blocks allocated, if an
            attention model has no blocks, or if the blocks cannot hold
            ``max_model_len`` tokens.
    """
    if os.getenv("VLLM_VACC_KVCACHE_SPACE", None) is None:
        # Previously this message named VLLM_OPENVINO_KVCACHE_SPACE and
        # described the opposite of the condition actually checked.
        logger.info("Skip executing the function about raise_if_cache_size_invalid "
                    "because environment `VLLM_VACC_KVCACHE_SPACE` is not set")
        return

    if is_attention_free and num_gpu_blocks != 0:
        raise ValueError("No memory should be allocated for the cache blocks "
                         f"for an attention-free model, but {num_gpu_blocks} "
                         "blocks are allocated.")
    if not is_attention_free and num_gpu_blocks <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")
    max_seq_len = block_size * num_gpu_blocks
    if not is_attention_free and max_model_len > max_seq_len:
        raise ValueError(
            f"The model's max seq len ({max_model_len}) "
            "is larger than the maximum number of tokens that can be "
            f"stored in KV cache ({max_seq_len}). Try increasing "
            "`gpu_memory_utilization` or decreasing `max_model_len` when "
            "initializing the engine.")
|