[Misc] Clean up uesless code for LLM initialize (#1373)
This PR aims to clean up the useless code for LLM setup. It helps to make the code more clear. 1. remove useless `self.xxx` property 2. change `set_random_seed` to `seed_everything` 3. remove `set_custom_all_reduce`, it's only used for cuda This is just a code clean. no change for any code logic. Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -49,8 +49,7 @@ from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
|||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import DeviceMemoryProfiler, LazyLoader, cdiv
|
||||||
LayerBlockType, LazyLoader, cdiv)
|
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec)
|
KVCacheSpec)
|
||||||
@@ -137,82 +136,69 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.lora_config = vllm_config.lora_config
|
self.lora_config = vllm_config.lora_config
|
||||||
self.scheduler_config = vllm_config.scheduler_config
|
self.scheduler_config = vllm_config.scheduler_config
|
||||||
self.speculative_config = vllm_config.speculative_config
|
self.speculative_config = vllm_config.speculative_config
|
||||||
ascend_config = get_ascend_config()
|
|
||||||
if ascend_config.ascend_scheduler_config.enabled:
|
|
||||||
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
|
||||||
else:
|
|
||||||
self.chunked_prefill_enabled = True
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
|
||||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||||
self.block_size)
|
self.block_size)
|
||||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
||||||
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
self.graph_block_tables = np.zeros(
|
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
(self.vllm_config.scheduler_config.max_num_seqs,
|
self.device = device
|
||||||
(self.model_config.max_model_len + self.block_size - 1) //
|
|
||||||
self.block_size),
|
|
||||||
dtype=np.int32)
|
|
||||||
|
|
||||||
# Model-related.
|
|
||||||
self.num_attn_layers = self.model_config.get_num_layers_by_block_type(
|
|
||||||
vllm_config.parallel_config, LayerBlockType.attention)
|
|
||||||
self.hidden_size = self.model_config.get_hidden_size()
|
|
||||||
self.dtype = self.model_config.dtype
|
self.dtype = self.model_config.dtype
|
||||||
cache_config = vllm_config.cache_config
|
self.sampler = Sampler()
|
||||||
if cache_config.cache_dtype == "auto":
|
|
||||||
self.kv_cache_dtype = self.dtype
|
|
||||||
else:
|
|
||||||
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
|
||||||
cache_config.cache_dtype]
|
|
||||||
|
|
||||||
self.head_size = self.model_config.get_head_size()
|
|
||||||
self.attn_backend = get_attn_backend(
|
|
||||||
self.head_size,
|
|
||||||
self.dtype,
|
|
||||||
self.kv_cache_dtype,
|
|
||||||
self.block_size,
|
|
||||||
self.model_config.is_attention_free,
|
|
||||||
use_mla=self.model_config.use_mla,
|
|
||||||
)
|
|
||||||
if self.attn_backend is None:
|
|
||||||
error_msg = (
|
|
||||||
f"Error with get_att_backend: {self.head_size=}, "
|
|
||||||
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
|
|
||||||
f"{self.model_config.is_attention_free=}, "
|
|
||||||
f"{self.model_config.use_mla=}")
|
|
||||||
logger.error(error_msg)
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Non-Attention backend is not supported by V1 NPUModelRunner.")
|
|
||||||
|
|
||||||
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
|
||||||
weakref.proxy(self))
|
|
||||||
|
|
||||||
# Multi-modal data support
|
# Multi-modal data support
|
||||||
self.input_registry = INPUT_REGISTRY
|
self.input_registry = INPUT_REGISTRY
|
||||||
self.mm_registry = MULTIMODAL_REGISTRY
|
self.mm_registry = MULTIMODAL_REGISTRY
|
||||||
self.uses_mrope = self.model_config.uses_mrope
|
|
||||||
|
|
||||||
self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
|
self.max_num_encoder_input_tokens, self.encoder_cache_size = compute_encoder_budget(
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
scheduler_config=self.scheduler_config,
|
scheduler_config=self.scheduler_config,
|
||||||
mm_registry=self.mm_registry)
|
mm_registry=self.mm_registry)
|
||||||
|
|
||||||
# Lazy initialization
|
# Lazy initialization, these will be set after __init__
|
||||||
# self.model: nn.Module # Set after load_model
|
|
||||||
self.kv_caches: List[torch.Tensor] = []
|
self.kv_caches: List[torch.Tensor] = []
|
||||||
# req_id -> (input_id -> encoder_output)
|
|
||||||
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
|
||||||
|
self.attn_mask = None
|
||||||
|
self.attn_state = None
|
||||||
|
self.requests: Dict[str, CachedRequestState] = {}
|
||||||
|
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
||||||
|
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
if ascend_config.ascend_scheduler_config.enabled:
|
||||||
|
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
||||||
|
else:
|
||||||
|
self.chunked_prefill_enabled = True
|
||||||
|
|
||||||
|
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||||
|
if self.is_multimodal_model:
|
||||||
|
self.inputs_embeds = torch.zeros(
|
||||||
|
(self.max_num_tokens, self.model_config.get_hidden_size()),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
self.graph_block_tables = np.zeros(
|
||||||
|
(self.max_num_reqs,
|
||||||
|
(self.model_config.max_model_len + self.block_size - 1) //
|
||||||
|
self.block_size),
|
||||||
|
dtype=np.int32)
|
||||||
|
|
||||||
|
# Set up Attention
|
||||||
|
self.attn_backend = get_attn_backend(
|
||||||
|
0,
|
||||||
|
self.dtype,
|
||||||
|
None,
|
||||||
|
self.block_size,
|
||||||
|
self.model_config.is_attention_free,
|
||||||
|
use_mla=self.model_config.use_mla,
|
||||||
|
)
|
||||||
|
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
|
||||||
|
weakref.proxy(self))
|
||||||
|
|
||||||
# Set up speculative decoding.
|
# Set up speculative decoding.
|
||||||
self.use_aux_hidden_state_outputs = False
|
self.use_aux_hidden_state_outputs = False
|
||||||
self.use_spec_decode = False
|
self.use_spec_decode = False
|
||||||
self.spec_attn_mask = None
|
self.spec_attn_mask = None
|
||||||
self.use_eagle = False
|
self.use_eagle = False
|
||||||
|
self.drafter = None
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
self.use_spec_decode = True
|
self.use_spec_decode = True
|
||||||
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
self.spec_attn_mask = torch.triu(torch.ones(2048,
|
||||||
@@ -235,10 +221,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
f"{self.speculative_config.method}")
|
f"{self.speculative_config.method}")
|
||||||
self.rejection_sampler = AscendRejectionSampler()
|
self.rejection_sampler = AscendRejectionSampler()
|
||||||
|
|
||||||
# Request states.
|
|
||||||
self.requests: Dict[str, CachedRequestState] = {}
|
|
||||||
# Persistent batch.
|
|
||||||
|
|
||||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
@@ -251,9 +233,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.seq_lens = torch.zeros(self.max_num_reqs,
|
self.seq_lens = torch.zeros(self.max_num_reqs,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
# None in the first PP rank. The rest are set after load_model.
|
|
||||||
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
|
||||||
|
|
||||||
|
self.uses_mrope = self.model_config.uses_mrope
|
||||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||||
if self.uses_mrope:
|
if self.uses_mrope:
|
||||||
# NOTE: `mrope_positions` is implemented with one additional dummy
|
# NOTE: `mrope_positions` is implemented with one additional dummy
|
||||||
@@ -276,12 +257,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
|
self.mrope_positions_np = self.mrope_positions_cpu.numpy()
|
||||||
|
|
||||||
if self.is_multimodal_model:
|
|
||||||
self.inputs_embeds = torch.zeros(
|
|
||||||
(self.max_num_tokens, self.hidden_size),
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device)
|
|
||||||
|
|
||||||
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
# OPTIMIZATION: Cache the tensors rather than creating them every step.
|
||||||
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
|
self.arange_np: npt.NDArray[np.int32] = np.arange(max(
|
||||||
self.max_num_reqs + 1, self.model_config.max_model_len,
|
self.max_num_reqs + 1, self.model_config.max_model_len,
|
||||||
@@ -305,24 +280,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||||
|
|
||||||
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
self.query_start_loc_np = self.query_start_loc_cpu.numpy()
|
||||||
|
|
||||||
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
self.seq_lens_cpu = torch.zeros(self.max_num_reqs,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
pin_memory=True)
|
pin_memory=True)
|
||||||
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
self.seq_lens_np = self.seq_lens_cpu.numpy()
|
||||||
|
|
||||||
self.input_positions_cpu = torch.arange(0,
|
|
||||||
self.max_num_tokens,
|
|
||||||
device="cpu")
|
|
||||||
self.attn_mask = None
|
|
||||||
self.attn_state = None
|
|
||||||
self.use_aclgraph = (self.vllm_config.compilation_config.level
|
self.use_aclgraph = (self.vllm_config.compilation_config.level
|
||||||
== CompilationLevel.PIECEWISE
|
== CompilationLevel.PIECEWISE
|
||||||
and not self.model_config.enforce_eager)
|
and not self.model_config.enforce_eager)
|
||||||
@@ -339,28 +307,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# Therefore, an environment variable is added here to dynamically set
|
# Therefore, an environment variable is added here to dynamically set
|
||||||
# the size of the pre-constructed mask matrix based on requirements.
|
# the size of the pre-constructed mask matrix based on requirements.
|
||||||
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
mask_len = os.getenv("PAGED_ATTENTION_MASK_LEN", 10000)
|
||||||
self.attn_mask_len = min(self.model_config.max_model_len,
|
attn_mask_len = min(self.model_config.max_model_len, int(mask_len))
|
||||||
int(mask_len))
|
|
||||||
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
self.attn_mask_builder = AttentionMaskBuilder.initialize_from_len(
|
||||||
self.attn_mask_len, self.dtype)
|
attn_mask_len, self.dtype)
|
||||||
|
|
||||||
self.sampler = Sampler()
|
|
||||||
|
|
||||||
self.torchair_compiled_model = None # type: ignore
|
self.torchair_compiled_model = None # type: ignore
|
||||||
self.torchair_compiled_models = {} # type: ignore
|
self.torchair_compiled_models = {} # type: ignore
|
||||||
ascend_config = get_ascend_config()
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled and self.vllm_config.model_config.use_mla
|
|
||||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
||||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
||||||
|
|
||||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||||
self.init_torchair_graph_batch_sizes()
|
self.init_torchair_graph_batch_sizes()
|
||||||
|
|
||||||
if len(self.torchair_graph_batch_sizes) == 0:
|
if len(self.torchair_graph_batch_sizes) == 0:
|
||||||
# TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
|
# TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
|
||||||
self.torchair_graph_batch_sizes = [
|
self.torchair_graph_batch_sizes = [self.max_num_reqs]
|
||||||
self.scheduler_config.max_num_seqs
|
|
||||||
]
|
|
||||||
|
|
||||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||||
self.torchair_graph_batch_sizes)
|
self.torchair_graph_batch_sizes)
|
||||||
@@ -368,9 +328,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
torch._logging.set_logs(
|
torch._logging.set_logs(
|
||||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||||
|
|
||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
|
||||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
|
||||||
|
|
||||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""Update the cached states and the persistent batch with the scheduler
|
"""Update the cached states and the persistent batch with the scheduler
|
||||||
output.
|
output.
|
||||||
@@ -1702,8 +1659,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# for dummy run with LoRA so that the num_reqs collectively
|
# for dummy run with LoRA so that the num_reqs collectively
|
||||||
# has num_tokens in total.
|
# has num_tokens in total.
|
||||||
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
|
||||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
num_reqs = self.max_num_reqs if num_tokens >= self.max_num_reqs else num_tokens
|
||||||
num_reqs = max_num_reqs if num_tokens >= max_num_reqs else num_tokens
|
|
||||||
min_tokens_per_req = num_tokens // num_reqs
|
min_tokens_per_req = num_tokens // num_reqs
|
||||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||||
@@ -1805,14 +1761,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# For profile, have maximum num_reqs and that collectively have
|
# For profile, have maximum num_reqs and that collectively have
|
||||||
# maximum num_tokens.
|
# maximum num_tokens.
|
||||||
num_reqs = self.scheduler_config.max_num_seqs
|
min_tokens_per_req = self.max_num_tokens // self.max_num_reqs
|
||||||
num_tokens = self.max_num_tokens
|
|
||||||
min_tokens_per_req = num_tokens // num_reqs
|
|
||||||
|
|
||||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
num_scheduled_tokens_list = [min_tokens_per_req] * self.max_num_reqs
|
||||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
num_scheduled_tokens_list[
|
||||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
-1] += self.max_num_tokens % self.max_num_reqs
|
||||||
assert len(num_scheduled_tokens_list) == num_reqs
|
assert sum(num_scheduled_tokens_list) == self.max_num_tokens
|
||||||
|
assert len(num_scheduled_tokens_list) == self.max_num_reqs
|
||||||
|
|
||||||
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
num_scheduled_tokens = np.array(num_scheduled_tokens_list,
|
||||||
dtype=np.int32)
|
dtype=np.int32)
|
||||||
@@ -1840,15 +1795,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||||
self.model = get_model(vllm_config=self.vllm_config)
|
self.model = get_model(vllm_config=self.vllm_config)
|
||||||
if hasattr(self, "drafter"):
|
if self.drafter:
|
||||||
logger.info("Loading drafter model...")
|
logger.info("Loading drafter model...")
|
||||||
if self.use_aux_hidden_state_outputs:
|
if self.use_aux_hidden_state_outputs:
|
||||||
self.drafter.load_model(self.model)
|
self.drafter.load_model(self.model)
|
||||||
|
self.model.set_aux_hidden_state_layers(
|
||||||
|
self.model.get_eagle3_aux_hidden_state_layers())
|
||||||
else:
|
else:
|
||||||
self.drafter.load_model()
|
self.drafter.load_model()
|
||||||
if self.use_aux_hidden_state_outputs:
|
|
||||||
self.model.set_aux_hidden_state_layers(
|
|
||||||
self.model.get_eagle3_aux_hidden_state_layers())
|
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
self.model = self.load_lora_model(self.model,
|
self.model = self.load_lora_model(self.model,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
@@ -1934,7 +1888,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
vocab_size=self.model_config.get_vocab_size(),
|
vocab_size=self.model_config.get_vocab_size(),
|
||||||
block_sizes=[self.cache_config.block_size],
|
block_sizes=[self.block_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_cache_sizes = {}
|
kv_cache_sizes = {}
|
||||||
@@ -2014,7 +1968,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
forward_ctx = self.vllm_config.compilation_config.static_forward_context
|
forward_ctx = self.vllm_config.compilation_config.static_forward_context
|
||||||
block_size = self.vllm_config.cache_config.block_size
|
|
||||||
use_mla = self.vllm_config.model_config.use_mla
|
use_mla = self.vllm_config.model_config.use_mla
|
||||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||||
for layer_name, attn_module in forward_ctx.items():
|
for layer_name, attn_module in forward_ctx.items():
|
||||||
@@ -2026,7 +1979,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
assert isinstance(attn_module, Attention)
|
assert isinstance(attn_module, Attention)
|
||||||
if attn_module.attn_type == AttentionType.DECODER:
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
block_size=block_size,
|
block_size=self.block_size,
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
dtype=attn_module.dtype,
|
dtype=attn_module.dtype,
|
||||||
@@ -2115,6 +2068,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
start_idx = self.input_batch.num_tokens_no_spec[i]
|
start_idx = self.input_batch.num_tokens_no_spec[i]
|
||||||
end_idx = start_idx + num_sampled_ids
|
end_idx = start_idx + num_sampled_ids
|
||||||
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
|
self.input_batch.token_ids_cpu[i, start_idx:end_idx] = sampled_ids
|
||||||
|
assert self.drafter is not None
|
||||||
drafter_output = self.drafter.propose(
|
drafter_output = self.drafter.propose(
|
||||||
self.input_batch.token_ids_cpu[i, :end_idx])
|
self.input_batch.token_ids_cpu[i, :end_idx])
|
||||||
if drafter_output is None or len(drafter_output) == 0:
|
if drafter_output is None or len(drafter_output) == 0:
|
||||||
@@ -2171,6 +2125,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
assert self.drafter is not None
|
||||||
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
cu_num_tokens, token_indices = self.drafter.prepare_inputs(
|
||||||
attn_metadata.query_start_loc,
|
attn_metadata.query_start_loc,
|
||||||
num_rejected_tokens,
|
num_rejected_tokens,
|
||||||
@@ -2179,7 +2134,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
target_positions = positions[token_indices]
|
target_positions = positions[token_indices]
|
||||||
target_hidden_states = hidden_states[token_indices]
|
target_hidden_states = hidden_states[token_indices]
|
||||||
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
|
target_slot_mapping = attn_metadata.slot_mapping[token_indices]
|
||||||
|
assert self.drafter is not None
|
||||||
draft_token_ids = self.drafter.propose(
|
draft_token_ids = self.drafter.propose(
|
||||||
target_token_ids=target_token_ids,
|
target_token_ids=target_token_ids,
|
||||||
target_positions=target_positions,
|
target_positions=target_positions,
|
||||||
@@ -2200,7 +2155,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||||
|
|
||||||
while (start_graph_batch_size <= self.scheduler_config.max_num_seqs):
|
while (start_graph_batch_size <= self.max_num_reqs):
|
||||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||||
start_graph_batch_size *= 2
|
start_graph_batch_size *= 2
|
||||||
|
|
||||||
|
|||||||
@@ -26,12 +26,10 @@ from torch_npu.op_plugin.atb._atb_ops import _register_atb_extensions
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||||
init_distributed_environment,
|
init_distributed_environment)
|
||||||
set_custom_all_reduce)
|
|
||||||
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor import set_random_seed
|
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
@@ -93,7 +91,6 @@ class NPUWorker(WorkerBase):
|
|||||||
self.profiler = self._init_profiler()
|
self.profiler = self._init_profiler()
|
||||||
|
|
||||||
def sleep(self, level: int = 1) -> None:
|
def sleep(self, level: int = 1) -> None:
|
||||||
NPUPlatform.set_device(self.device)
|
|
||||||
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
|
free_bytes_before_sleep = NPUPlatform.mem_get_info()[0]
|
||||||
allocator = CaMemAllocator.get_instance()
|
allocator = CaMemAllocator.get_instance()
|
||||||
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
allocator.sleep(offload_tags=("weights", ) if level == 1 else tuple())
|
||||||
@@ -116,22 +113,18 @@ class NPUWorker(WorkerBase):
|
|||||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||||
|
|
||||||
def init_device(self):
|
def init_device(self):
|
||||||
if self.device_config.device.type == "npu":
|
device = torch.device(f"npu:{self.local_rank}")
|
||||||
self.device = torch.device(f"npu:{self.local_rank}")
|
NPUPlatform.set_device(device)
|
||||||
NPUPlatform.set_device(self.device)
|
NPUPlatform.empty_cache()
|
||||||
NPUPlatform.empty_cache()
|
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
||||||
self.init_npu_memory = NPUPlatform.mem_get_info()[0]
|
|
||||||
else:
|
|
||||||
info = f"Not support device type: {self.device_config.device}"
|
|
||||||
logger.error(info)
|
|
||||||
raise RuntimeError(info)
|
|
||||||
# Initialize the distributed environment.
|
# Initialize the distributed environment.
|
||||||
self._init_worker_distributed_environment()
|
self._init_worker_distributed_environment()
|
||||||
# Set random seed.
|
# Set random seed.
|
||||||
set_random_seed(self.model_config.seed)
|
NPUPlatform.seed_everything(self.model_config.seed)
|
||||||
|
|
||||||
# Init ModelRunner here, so that we have access to self.device.
|
# Init ModelRunner here, so that we have access to self.device.
|
||||||
self.model_runner = NPUModelRunner(self.vllm_config, self.device)
|
self.model_runner = NPUModelRunner(self.vllm_config, device)
|
||||||
|
|
||||||
def determine_available_memory(self) -> int:
|
def determine_available_memory(self) -> int:
|
||||||
# Profile the memory usage of the model and get the maximum number of
|
# Profile the memory usage of the model and get the maximum number of
|
||||||
@@ -205,7 +198,7 @@ class NPUWorker(WorkerBase):
|
|||||||
self.model_runner.capture_model()
|
self.model_runner.capture_model()
|
||||||
# Reset the seed to ensure that the random state is not affected by
|
# Reset the seed to ensure that the random state is not affected by
|
||||||
# the model initialization and profiling.
|
# the model initialization and profiling.
|
||||||
set_random_seed(self.model_config.seed)
|
NPUPlatform.seed_everything(self.model_config.seed)
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
return self.model_runner.get_model()
|
return self.model_runner.get_model()
|
||||||
@@ -261,8 +254,6 @@ class NPUWorker(WorkerBase):
|
|||||||
def _init_worker_distributed_environment(self) -> None:
|
def _init_worker_distributed_environment(self) -> None:
|
||||||
"""Initialize the distributed environment."""
|
"""Initialize the distributed environment."""
|
||||||
parallel_config = self.vllm_config.parallel_config
|
parallel_config = self.vllm_config.parallel_config
|
||||||
set_custom_all_reduce(
|
|
||||||
not self.parallel_config.disable_custom_all_reduce)
|
|
||||||
init_distributed_environment(self.parallel_config.world_size,
|
init_distributed_environment(self.parallel_config.world_size,
|
||||||
self.rank, self.distributed_init_method,
|
self.rank, self.distributed_init_method,
|
||||||
self.local_rank, "hccl")
|
self.local_rank, "hccl")
|
||||||
|
|||||||
Reference in New Issue
Block a user