[Model] Support DeepSeek-V4
This commit is contained in:
3
vllm_mlu/engine/__init__.py
Normal file
3
vllm_mlu/engine/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
294
vllm_mlu/engine/arg_utils.py
Normal file
294
vllm_mlu/engine/arg_utils.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import get_args
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.config import (
|
||||
ModelConfig,
|
||||
VllmConfig,
|
||||
SchedulerConfig,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.engine.arg_utils import (
|
||||
EngineArgs,
|
||||
_raise_unsupported_error,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
import vllm_mlu._mlu_utils as mlu_envs
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@classmethod
def vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults(
    cls,
    model_config: ModelConfig,
) -> tuple[bool, bool]:
    """Return the default (enable_chunked_prefill, enable_prefix_caching) pair.

    Pooling runners get both defaults from whether incremental prefill is
    supported; all other runners use the MLU scheduler environment switch.
    """
    if model_config.runner_type == "pooling":
        assert model_config.pooler_config is not None

        pooling_type = model_config.pooler_config.pooling_type
        # Incremental prefill only makes sense for causal last-token pooling.
        incremental_prefill_supported = (
            pooling_type is not None
            and pooling_type.lower() == "last"
            and getattr(model_config.hf_config, "is_causal", True)
        )
        return incremental_prefill_supported, incremental_prefill_supported

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: mlu-v1 default use unchunked scheduler
    '''
    default_chunked_prefill = not mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    # Disable prefix caching default for hybrid models
    # since the feature is still experimental.
    default_prefix_caching = not model_config.is_hybrid
    return default_chunked_prefill, default_prefix_caching
|
||||
|
||||
def vllm__engine__arg_utils__EngineArgs___set_default_args(
    self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
    """Set Default Arguments for V1 Engine.

    Hijack replacement for ``EngineArgs._set_default_args``: fills in
    chunked-prefill, prefix-caching, ``max_num_seqs`` and
    ``max_num_batched_tokens`` defaults for any value the user left unset.
    """
    (
        default_chunked_prefill,
        default_prefix_caching,
    ) = self.get_chunked_prefill_prefix_caching_defaults(model_config)

    # Only apply the default when the user did not set the flag explicitly.
    if self.enable_chunked_prefill is None:
        self.enable_chunked_prefill = default_chunked_prefill

        logger.debug(
            "%s chunked prefill by default",
            "Enabling" if default_chunked_prefill else "Disabling",
        )
    elif (
        model_config.runner_type == "pooling"
        and self.enable_chunked_prefill
        and not default_chunked_prefill
    ):
        # User forced chunked prefill on a pooling model that defaults it
        # off — warn but honor the explicit setting.
        logger.warning(
            "This model does not officially support chunked prefill. "
            "Enabling this manually may cause the engine to crash "
            "or produce incorrect outputs.",
        )

    if self.enable_prefix_caching is None:
        self.enable_prefix_caching = default_prefix_caching

        logger.debug(
            "%s prefix caching by default",
            "Enabling" if default_prefix_caching else "Disabling",
        )
    elif (
        model_config.runner_type == "pooling"
        and self.enable_prefix_caching
        and not default_prefix_caching
    ):
        logger.warning(
            "This model does not officially support prefix caching. "
            "Enabling this manually may cause the engine to crash "
            "or produce incorrect outputs.",
        )

    world_size = self.pipeline_parallel_size * self.tensor_parallel_size
    (
        default_max_num_batched_tokens,
        default_max_num_seqs,
    ) = self.get_batch_defaults(world_size)

    # Remember whether the user set these, before defaults are filled in.
    orig_max_num_batched_tokens = self.max_num_batched_tokens
    orig_max_num_seqs = self.max_num_seqs

    if self.max_num_seqs is None:
        self.max_num_seqs = default_max_num_seqs.get(
            usage_context,
            SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
        )

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: only set max_num_batched_tokens when enable chunked_prefill
    '''
    if self.max_num_batched_tokens is None:
        self.max_num_batched_tokens = default_max_num_batched_tokens.get(
            usage_context,
            SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
        )

    if orig_max_num_batched_tokens is None:
        if not self.enable_chunked_prefill:
            # If max_model_len is too short, use the default for higher throughput.
            self.max_num_batched_tokens = max(
                model_config.max_model_len,
                self.max_num_batched_tokens,
            )

        # When using default settings,
        # Ensure max_num_batched_tokens does not exceed model limit.
        # Some models (e.g., Whisper) have embeddings tied to max length.
        self.max_num_batched_tokens = min(
            self.max_num_seqs * model_config.max_model_len,
            self.max_num_batched_tokens,
        )

        logger.debug(
            "Defaulting max_num_batched_tokens to %d for %s usage context.",
            self.max_num_batched_tokens,
            usage_context.value if usage_context else None,
        )

    if orig_max_num_seqs is None:
        if self.max_num_batched_tokens is not None:  # For type checking
            # A batch can never hold more sequences than tokens.
            self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)

        logger.debug(
            "Defaulting max_num_seqs to %d for %s usage context.",
            self.max_num_seqs,
            usage_context.value if usage_context else None,
        )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
|
||||
|
||||
|
||||
# Quantized dtypes accepted for mlu_config["decoder_attn_dtype"]; when one of
# these is requested it is promoted to the engine-wide kv_cache_dtype.
_VALID_QUANT_ATTN_QKV_DTYPE = ['int8', 'fp8', 'fp8_e4m3']
|
||||
|
||||
def vllm__engine__arg_utils__EngineArgs__create_engine_config(
    self,
    usage_context: UsageContext | None = None,
    headless: bool = False,
) -> VllmConfig:
    """
    Create the VllmConfig.

    NOTE: If VllmConfig is incompatible, we raise an error.
    """
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: promote quantized decoder_attn_dtype to kv_cache_dtype, then
            validate MLU-specific options (custom TP split sizes, mlugraph,
            sequence parallel, disaggregated prefill/decode) on top of the
            config built by the original implementation.
    '''
    # Promote a quantized decoder_attn_dtype to the engine-wide KV cache
    # dtype before the upstream config is constructed.
    if self.mlu_config:
        decoder_attn_dtype = self.mlu_config.get("decoder_attn_dtype")
        if decoder_attn_dtype in _VALID_QUANT_ATTN_QKV_DTYPE:
            self.kv_cache_dtype = decoder_attn_dtype

    # Delegate to the original (pre-hijack) implementation for the baseline
    # config; MLU-specific validation is layered on the result below.
    engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(
        self, usage_context, headless)

    world_size = engine_config.parallel_config.world_size_across_dp
    tensor_parallel_size = engine_config.parallel_config.tensor_parallel_size

    # Custom embedding/logit TP size must lie between the base TP size and
    # the global (cross-DP) world size.
    embedding_tp_size = engine_config.mlu_config.layer_embedding_logit_tp_size
    if embedding_tp_size:
        assert tensor_parallel_size <= embedding_tp_size <= world_size, (
            f"embedding_tp_size = {embedding_tp_size} out of bounds. "
            f"Require {tensor_parallel_size} ≤ size ≤ {world_size}")

    # Custom dense-MLP TP split checks.
    dense_mlp_tp_size = engine_config.mlu_config.layer_dense_mlp_tp_size
    if dense_mlp_tp_size:
        assert 1 <= dense_mlp_tp_size <= world_size, (
            f"dense_mlp_tp_size = {dense_mlp_tp_size} out of bounds. Require 1 ≤ size ≤ {world_size}")
        if dense_mlp_tp_size != world_size:
            # A custom (non-world-size) split conflicts with dpsk mcc.
            assert not engine_config.mlu_config.is_dpsk_mcc_enabled, (
                "dense_mlp_tp_size is not supported when dpsk mcc is enabled.")
            if engine_config.model_config.is_longcat_flash and tensor_parallel_size > 1:
                raise ValueError("For now, for longcat model, custom dense mlp tp split in data parallel requires dpXtp1. "
                                 "Necessity of this constraint requires further investigation.")
        if engine_config.model_config.is_longcat_flash and dense_mlp_tp_size < tensor_parallel_size:
            raise ValueError(f"For longcat model, custom dense mlp tp_size {dense_mlp_tp_size} "
                             f"must be greater than or equal to tensor_parallel_size {tensor_parallel_size}")
        if engine_config.model_config.is_deepseek_mla and dense_mlp_tp_size % tensor_parallel_size != 0:
            raise ValueError(f"For deepseek mla model, custom mlp tp size {dense_mlp_tp_size} must "
                             f"be divisible by {tensor_parallel_size}")

    # Context (prefill) mlugraph is incompatible with DP, speculative
    # decoding and sequence-parallel prefill — force it off.
    if ((engine_config.parallel_config.data_parallel_size > 1 or engine_config.speculative_config is not None
            or engine_config.mlu_config.prefill_use_sequence_parallel) and engine_config.mlu_config.prefill_enable_mlugraph):
        logger.info("Data parallel or sequence parallel or speculative is enabled, forcing context mlugraph to be disabled.")
        engine_config.mlu_config.prefill_enable_mlugraph = False

    # decoder_attn_dtype checks: must be a known CacheDType and is only
    # supported for deepseek-MLA and glm4_moe models.
    if engine_config.mlu_config.decoder_attn_dtype:
        if engine_config.mlu_config.decoder_attn_dtype not in get_args(CacheDType):
            raise ValueError(f"MLU backend does not support {engine_config.mlu_config.decoder_attn_dtype} "
                             "decoder_attn_dtype for now")
        is_glm4_moe = (hasattr(engine_config.model_config.hf_text_config, "model_type") and
                       engine_config.model_config.hf_text_config.model_type == "glm4_moe")
        if (not (engine_config.model_config.is_deepseek_mla or is_glm4_moe)
                and engine_config.mlu_config.decoder_attn_dtype != "auto"):
            raise ValueError("mlu_config.decoder_attn_dtype only support deepseek_mla and glm4_moe model")

    # sequence parallel checks
    if (engine_config.mlu_config.prefill_use_sequence_parallel
            and engine_config.model_config.hf_text_config.model_type not in ["deepseek_v32", "deepseek_v3"]):
        raise ValueError("Prefill sequence parallel can only use in deepseek model.")
    if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.scheduler_config.enable_chunked_prefill:
        raise ValueError("Prefill sequence parallel can not use with chunked prefill for now.")
    if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.mlu_config.is_dpsk_mcc_enabled:
        raise ValueError("Prefill sequence parallel can not use with mcc.")
    if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.parallel_config.data_parallel_size > 1:
        raise ValueError("Prefill sequence parallel can not use with data parallel.")
    if (engine_config.mlu_config.prefill_use_sequence_parallel
            and engine_config.model_config.hf_text_config.model_type == "deepseek_v3"
            and engine_config.quant_config.get_name() != "SmoothQuant"):
        raise ValueError("Prefill sequence parallel can only use SmoothQuant for deepseek_v3.")

    # Disaggregated prefill/decode constraints:
    #   1. only DeepSeek-V3/R1 models are supported
    #   2. int8 KV cache (kv8) is unsupported
    if self.kv_transfer_config is not None:
        if engine_config.model_config.hf_config.model_type != "deepseek_v3":
            raise ValueError("Disagg only support DeepSeek-V3/R1")
        if engine_config.cache_config.cache_dtype == "int8":
            raise ValueError("Disagg does not support KV cache dtype is int8")
        if engine_config.cache_config.enable_prefix_caching:
            raise ValueError("Disagg does not support prefix caching")

        # kv_transfer_config may arrive as a raw dict (CLI/JSON) or as a
        # parsed config object; read the fields either way.
        if isinstance(self.kv_transfer_config, dict):
            kv_connector = self.kv_transfer_config.get("kv_connector")
            kv_role = self.kv_transfer_config.get("kv_role")
        else:
            kv_connector = self.kv_transfer_config.kv_connector
            kv_role = self.kv_transfer_config.kv_role

        if kv_connector != "LMCacheConnectorV1":
            raise ValueError("Disagg only support LMCacheConnectorV1 connector")
        if kv_role == "kv_consumer":
            if not self.enable_chunked_prefill:
                raise ValueError("Disagg decoder only support chunk scheduler")

    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return engine_config
|
||||
|
||||
|
||||
# Install the MLU overrides onto vllm's EngineArgs. Each entry maps the
# upstream method being replaced to its vllm_mlu replacement; the original
# is passed so the hijack machinery can keep a reference to it.
for _original, _replacement in (
    (EngineArgs._set_default_args,
     vllm__engine__arg_utils__EngineArgs___set_default_args),
    (EngineArgs.create_engine_config,
     vllm__engine__arg_utils__EngineArgs__create_engine_config),
    (EngineArgs.get_chunked_prefill_prefix_caching_defaults,
     vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults),
):
    MluHijackObject.apply_hijack(EngineArgs, _original, _replacement)
|
||||
Reference in New Issue
Block a user