295 lines
12 KiB
Python
295 lines
12 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
|
|
|
from typing import get_args
|
|
|
|
from vllm.platforms import current_platform
|
|
from vllm.config import (
|
|
ModelConfig,
|
|
VllmConfig,
|
|
SchedulerConfig,
|
|
)
|
|
from vllm.config.cache import CacheDType
|
|
from vllm.engine.arg_utils import (
|
|
EngineArgs,
|
|
_raise_unsupported_error,
|
|
)
|
|
from vllm.logger import init_logger
|
|
from vllm.usage.usage_lib import UsageContext
|
|
|
|
import vllm_mlu._mlu_utils as mlu_envs
|
|
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
@classmethod
def vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults(
    cls,
    model_config: ModelConfig,
) -> tuple[bool, bool]:
    """Return ``(default_chunked_prefill, default_prefix_caching)`` for the model."""
    if model_config.runner_type == "pooling":
        assert model_config.pooler_config is not None

        # For pooling runners both defaults follow one capability check:
        # incremental prefill is only possible with causal last-token pooling.
        pooling_type = model_config.pooler_config.pooling_type
        incremental_ok = (
            pooling_type is not None
            and pooling_type.lower() == "last"
            and getattr(model_config.hf_config, "is_causal", True)
        )
        return incremental_ok, incremental_ok

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: mlu-v1 default use unchunked scheduler
    '''
    chunked_default = not mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    # Disable prefix caching default for hybrid models
    # since the feature is still experimental.
    return chunked_default, not model_config.is_hybrid
|
|
|
|
def vllm__engine__arg_utils__EngineArgs___set_default_args(
    self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
    """Set Default Arguments for V1 Engine.

    Fills in ``enable_chunked_prefill``, ``enable_prefix_caching``,
    ``max_num_seqs`` and ``max_num_batched_tokens`` whenever the user left
    them unset, and warns when a pooling model is manually forced into a
    mode whose default is disabled.
    """
    (
        default_chunked_prefill,
        default_prefix_caching,
    ) = self.get_chunked_prefill_prefix_caching_defaults(model_config)

    # Chunked prefill: apply the model-derived default when unset; if the
    # user explicitly enabled it on a pooling model whose default is off,
    # keep their choice but warn that it is unsupported.
    if self.enable_chunked_prefill is None:
        self.enable_chunked_prefill = default_chunked_prefill

        logger.debug(
            "%s chunked prefill by default",
            "Enabling" if default_chunked_prefill else "Disabling",
        )
    elif (
        model_config.runner_type == "pooling"
        and self.enable_chunked_prefill
        and not default_chunked_prefill
    ):
        logger.warning(
            "This model does not officially support chunked prefill. "
            "Enabling this manually may cause the engine to crash "
            "or produce incorrect outputs.",
        )

    # Prefix caching: same default-or-warn policy as chunked prefill above.
    if self.enable_prefix_caching is None:
        self.enable_prefix_caching = default_prefix_caching

        logger.debug(
            "%s prefix caching by default",
            "Enabling" if default_prefix_caching else "Disabling",
        )
    elif (
        model_config.runner_type == "pooling"
        and self.enable_prefix_caching
        and not default_prefix_caching
    ):
        logger.warning(
            "This model does not officially support prefix caching. "
            "Enabling this manually may cause the engine to crash "
            "or produce incorrect outputs.",
        )

    # Batch-size defaults are looked up per (PP x TP) world size.
    world_size = self.pipeline_parallel_size * self.tensor_parallel_size
    (
        default_max_num_batched_tokens,
        default_max_num_seqs,
    ) = self.get_batch_defaults(world_size)

    # Remember which values the user left unset: the clamping below is only
    # applied to defaulted values, never to explicit user settings.
    orig_max_num_batched_tokens = self.max_num_batched_tokens
    orig_max_num_seqs = self.max_num_seqs

    if self.max_num_seqs is None:
        self.max_num_seqs = default_max_num_seqs.get(
            usage_context,
            SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
        )

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: only set max_num_batched_tokens when enable chunked_prefill
    '''
    # NOTE(review): despite the @brief above, max_num_batched_tokens is
    # defaulted whenever it is unset; chunked prefill only influences the
    # max() clamp below — confirm whether the @brief is stale.
    if self.max_num_batched_tokens is None:
        self.max_num_batched_tokens = default_max_num_batched_tokens.get(
            usage_context,
            SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
        )

    if orig_max_num_batched_tokens is None:
        if not self.enable_chunked_prefill:
            # If max_model_len is too short, use the default for higher throughput.
            # Without chunked prefill the budget is raised to at least
            # max_model_len so a full-length prompt fits in one batch.
            self.max_num_batched_tokens = max(
                model_config.max_model_len,
                self.max_num_batched_tokens,
            )

        # When using default settings,
        # Ensure max_num_batched_tokens does not exceed model limit.
        # Some models (e.g., Whisper) have embeddings tied to max length.
        self.max_num_batched_tokens = min(
            self.max_num_seqs * model_config.max_model_len,
            self.max_num_batched_tokens,
        )

        logger.debug(
            "Defaulting max_num_batched_tokens to %d for %s usage context.",
            self.max_num_batched_tokens,
            usage_context.value if usage_context else None,
        )

    if orig_max_num_seqs is None:
        # Clamp a defaulted max_num_seqs so it never exceeds the token budget.
        if self.max_num_batched_tokens is not None:  # For type checking
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

        logger.debug(
            "Defaulting max_num_seqs to %d for %s usage context.",
            self.max_num_seqs,
            usage_context.value if usage_context else None,
        )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
|
|
|
|
|
|
_VALID_QUANT_ATTN_QKV_DTYPE = ['int8', 'fp8', 'fp8_e4m3']
|
|
|
|
def vllm__engine__arg_utils__EngineArgs__create_engine_config(
    self,
    usage_context: UsageContext | None = None,
    headless: bool = False,
) -> VllmConfig:
    """
    Create the VllmConfig.

    Propagates ``mlu_config.decoder_attn_dtype`` into ``kv_cache_dtype``,
    delegates to the original (pre-hijack) implementation, then validates
    MLU-specific constraints on the resulting config.

    NOTE: If VllmConfig is incompatible, we raise an error.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: propagate decoder_attn_dtype to kv_cache_dtype and validate
            MLU-specific engine config constraints.
    '''
    # A quantized decoder attention dtype implies the KV cache is stored in
    # that dtype, so mirror it into kv_cache_dtype before the config is built.
    # Use the module-level constant instead of re-spelling the literal list.
    if self.mlu_config and "decoder_attn_dtype" in self.mlu_config:
        decoder_attn_dtype = self.mlu_config.get("decoder_attn_dtype")
        if decoder_attn_dtype in _VALID_QUANT_ATTN_QKV_DTYPE:
            self.kv_cache_dtype = decoder_attn_dtype

    # Build the config via the original (pre-hijack) implementation.
    engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(
        self, usage_context, headless)

    world_size = engine_config.parallel_config.world_size_across_dp
    tensor_parallel_size = engine_config.parallel_config.tensor_parallel_size

    # Custom embedding/logit-layer TP size must lie between the regular TP
    # size and the DP-wide world size.
    embedding_tp_size = engine_config.mlu_config.layer_embedding_logit_tp_size
    if embedding_tp_size:
        assert embedding_tp_size >= tensor_parallel_size and embedding_tp_size <= world_size, (
            f"embedding_tp_size = {embedding_tp_size} out of bounds. "
            f"Require {tensor_parallel_size} ≤ size ≤ {world_size}")

    # Custom dense-MLP TP size checks.
    dense_mlp_tp_size = engine_config.mlu_config.layer_dense_mlp_tp_size
    if dense_mlp_tp_size:
        assert dense_mlp_tp_size >= 1 and dense_mlp_tp_size <= world_size, (
            f"dense_mlp_tp_size = {dense_mlp_tp_size} out of bounds. Require 1 ≤ size ≤ {world_size}")
        if dense_mlp_tp_size != world_size:
            # A non-world-size split is a "custom" split, incompatible with
            # dpsk mcc and with longcat under TP > 1.
            assert not engine_config.mlu_config.is_dpsk_mcc_enabled, (
                "dense_mlp_tp_size is not supported when dpsk mcc is enabled.")
            if engine_config.model_config.is_longcat_flash and tensor_parallel_size > 1:
                raise ValueError(
                    "For now, for longcat model, custom dense mlp tp split in data parallel requires dpXtp1. "
                    "Necessity of this constraint requires further investigation.")
        if engine_config.model_config.is_longcat_flash and dense_mlp_tp_size < tensor_parallel_size:
            raise ValueError(
                f"For longcat model, custom dense mlp tp_size {dense_mlp_tp_size} "
                f"must be greater than or equal to tensor_parallel_size {tensor_parallel_size}")
        if engine_config.model_config.is_deepseek_mla and dense_mlp_tp_size % tensor_parallel_size != 0:
            raise ValueError(
                f"For deepseek mla model, custom mlp tp size {dense_mlp_tp_size} must "
                f"be divisible by {tensor_parallel_size}")

    # Context (prefill) mlugraph cannot coexist with data parallel,
    # speculative decoding, or prefill sequence parallel; force it off.
    if ((engine_config.parallel_config.data_parallel_size > 1
         or engine_config.speculative_config is not None
         or engine_config.mlu_config.prefill_use_sequence_parallel)
            and engine_config.mlu_config.prefill_enable_mlugraph):
        logger.info("Data parallel or sequence parallel or speculative is enabled, forcing context mlugraph to be disabled.")
        engine_config.mlu_config.prefill_enable_mlugraph = False

    # decoder_attn_dtype must be a known cache dtype and is only meaningful
    # for deepseek-mla and glm4_moe models.
    if engine_config.mlu_config.decoder_attn_dtype:
        if engine_config.mlu_config.decoder_attn_dtype not in get_args(CacheDType):
            raise ValueError(f"MLU backend does not support {engine_config.mlu_config.decoder_attn_dtype} "
                             f"decoder_attn_dtype for now")
        is_glm4_moe = (hasattr(engine_config.model_config.hf_text_config, "model_type") and
                       engine_config.model_config.hf_text_config.model_type == "glm4_moe")
        if (not (engine_config.model_config.is_deepseek_mla or is_glm4_moe)
                and engine_config.mlu_config.decoder_attn_dtype != "auto"):
            raise ValueError("mlu_config.decoder_attn_dtype only support deepseek_mla and glm4_moe model")

    # sequence parallel checks: prefill sequence parallel is restricted to
    # deepseek models and is mutually exclusive with chunked prefill, mcc,
    # and data parallel.
    seq_parallel = engine_config.mlu_config.prefill_use_sequence_parallel
    if (seq_parallel
            and engine_config.model_config.hf_text_config.model_type not in ["deepseek_v32", "deepseek_v3"]):
        raise ValueError("Prefill sequence parallel can only use in deepseek model.")
    if seq_parallel and engine_config.scheduler_config.enable_chunked_prefill:
        raise ValueError("Prefill sequence parallel can not use with chunked prefill for now.")
    if seq_parallel and engine_config.mlu_config.is_dpsk_mcc_enabled:
        raise ValueError("Prefill sequence parallel can not use with mcc.")
    if seq_parallel and engine_config.parallel_config.data_parallel_size > 1:
        raise ValueError("Prefill sequence parallel can not use with data parallel.")
    # NOTE(review): the check below assumes quant_config is not None whenever
    # sequence parallel is enabled for deepseek_v3 — confirm upstream
    # guarantees this, otherwise it raises AttributeError instead.
    if (seq_parallel
            and engine_config.model_config.hf_text_config.model_type == "deepseek_v3"
            and engine_config.quant_config.get_name() != "SmoothQuant"):
        raise ValueError("Prefill sequence parallel can only use SmoothQuant for deepseek_v3.")

    # Disaggregated prefill (kv transfer) constraints:
    #   1. only DeepSeek-V3/R1 models are supported
    #   2. int8 KV cache (kv8) and prefix caching are not supported
    if self.kv_transfer_config is not None:
        if engine_config.model_config.hf_config.model_type != "deepseek_v3":
            raise ValueError("Disagg only support DeepSeek-V3/R1")
        if engine_config.cache_config.cache_dtype == "int8":
            raise ValueError("Disagg does not support KV cache dtype is int8")
        if engine_config.cache_config.enable_prefix_caching:
            raise ValueError("Disagg does not support prefix caching")

        # kv_transfer_config may still be a raw dict at this point.
        if isinstance(self.kv_transfer_config, dict):
            kv_connector = self.kv_transfer_config.get("kv_connector")
            kv_role = self.kv_transfer_config.get("kv_role")
        else:
            kv_connector = self.kv_transfer_config.kv_connector
            kv_role = self.kv_transfer_config.kv_role

        if kv_connector != "LMCacheConnectorV1":
            raise ValueError("Disagg only support LMCacheConnectorV1 connector")
        if kv_role == "kv_consumer" and not self.enable_chunked_prefill:
            raise ValueError("Disagg decoder only support chunk scheduler")

    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return engine_config
|
|
|
|
|
|
# Install the MLU overrides: each call replaces the listed upstream
# EngineArgs method with the corresponding vllm_mlu implementation
# defined above in this module.
MluHijackObject.apply_hijack(EngineArgs,
                             EngineArgs._set_default_args,
                             vllm__engine__arg_utils__EngineArgs___set_default_args)
MluHijackObject.apply_hijack(EngineArgs,
                             EngineArgs.create_engine_config,
                             vllm__engine__arg_utils__EngineArgs__create_engine_config)
MluHijackObject.apply_hijack(EngineArgs,
                             EngineArgs.get_chunked_prefill_prefix_caching_defaults,
                             vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults)
|