[Model] Support DeepSeek-V4

chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions


@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project


@@ -0,0 +1,294 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import get_args
from vllm.platforms import current_platform
from vllm.config import (
ModelConfig,
VllmConfig,
SchedulerConfig,
)
from vllm.config.cache import CacheDType
from vllm.engine.arg_utils import (
EngineArgs,
_raise_unsupported_error,
)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
import vllm_mlu._mlu_utils as mlu_envs
from vllm_mlu.mlu_hijack_utils import MluHijackObject
logger = init_logger(__name__)
@classmethod
def vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults(
cls,
model_config: ModelConfig,
) -> tuple[bool, bool]:
if model_config.runner_type != "pooling":
'''
=============================
Modify by vllm_mlu
=============================
@brief: mlu-v1 defaults to the unchunked scheduler when VLLM_V1_USE_UNCHUNK_SCHED is set
'''
if mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED:
default_chunked_prefill = False
else:
default_chunked_prefill = True
'''
==================
End of MLU Hijack
==================
'''
# Disable prefix caching default for hybrid models
# since the feature is still experimental.
default_prefix_caching = not model_config.is_hybrid
else:
assert model_config.pooler_config is not None
pooling_type = model_config.pooler_config.pooling_type
incremental_prefill_supported = (
pooling_type is not None
and pooling_type.lower() == "last"
and getattr(model_config.hf_config, "is_causal", True)
)
default_chunked_prefill = incremental_prefill_supported
default_prefix_caching = incremental_prefill_supported
return default_chunked_prefill, default_prefix_caching
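# Illustrative sketch (not used by the hijack): for non-pooling models the
# branch above reduces to two independent defaults. `use_unchunk_sched` and
# `is_hybrid` below are hypothetical stand-ins for
# mlu_envs.VLLM_V1_USE_UNCHUNK_SCHED and model_config.is_hybrid.
def _example_non_pooling_defaults(use_unchunk_sched: bool,
                                  is_hybrid: bool) -> tuple[bool, bool]:
    # Chunked prefill is on by default unless the unchunked scheduler is requested.
    default_chunked_prefill = not use_unchunk_sched
    # Prefix caching stays off for hybrid models while the feature is experimental.
    default_prefix_caching = not is_hybrid
    return default_chunked_prefill, default_prefix_caching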
def vllm__engine__arg_utils__EngineArgs___set_default_args(
self, usage_context: UsageContext, model_config: ModelConfig
) -> None:
"""Set Default Arguments for V1 Engine."""
(
default_chunked_prefill,
default_prefix_caching,
) = self.get_chunked_prefill_prefix_caching_defaults(model_config)
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = default_chunked_prefill
logger.debug(
"%s chunked prefill by default",
"Enabling" if default_chunked_prefill else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_chunked_prefill
and not default_chunked_prefill
):
logger.warning(
"This model does not officially support chunked prefill. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
if self.enable_prefix_caching is None:
self.enable_prefix_caching = default_prefix_caching
logger.debug(
"%s prefix caching by default",
"Enabling" if default_prefix_caching else "Disabling",
)
elif (
model_config.runner_type == "pooling"
and self.enable_prefix_caching
and not default_prefix_caching
):
logger.warning(
"This model does not officially support prefix caching. "
"Enabling this manually may cause the engine to crash "
"or produce incorrect outputs.",
)
world_size = self.pipeline_parallel_size * self.tensor_parallel_size
(
default_max_num_batched_tokens,
default_max_num_seqs,
) = self.get_batch_defaults(world_size)
orig_max_num_batched_tokens = self.max_num_batched_tokens
orig_max_num_seqs = self.max_num_seqs
if self.max_num_seqs is None:
self.max_num_seqs = default_max_num_seqs.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_SEQS,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: only set max_num_batched_tokens when chunked prefill is enabled
'''
if self.max_num_batched_tokens is None:
self.max_num_batched_tokens = default_max_num_batched_tokens.get(
usage_context,
SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS,
)
if orig_max_num_batched_tokens is None:
if not self.enable_chunked_prefill:
# If max_model_len is too short, use the default for higher throughput.
self.max_num_batched_tokens = max(
model_config.max_model_len,
self.max_num_batched_tokens,
)
# When using default settings,
# Ensure max_num_batched_tokens does not exceed model limit.
# Some models (e.g., Whisper) have embeddings tied to max length.
self.max_num_batched_tokens = min(
self.max_num_seqs * model_config.max_model_len,
self.max_num_batched_tokens,
)
logger.debug(
"Defaulting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens,
usage_context.value if usage_context else None,
)
if orig_max_num_seqs is None:
if self.max_num_batched_tokens is not None: # For type checking
self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
logger.debug(
"Defaulting max_num_seqs to %d for %s usage context.",
self.max_num_seqs,
usage_context.value if usage_context else None,
)
'''
==================
End of MLU Hijack
==================
'''
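# Worked sketch (hypothetical numbers, not executed by the hijack): with the
# default path above, enable_chunked_prefill=False, max_model_len=32768,
# max_num_seqs=256 and a usage-context default of 8192 tokens, the budget is
# first raised to max(32768, 8192) = 32768, then clamped to
# min(256 * 32768, 32768) = 32768.
def _example_default_max_num_batched_tokens(context_default: int,
                                            max_model_len: int,
                                            max_num_seqs: int,
                                            enable_chunked_prefill: bool) -> int:
    max_num_batched_tokens = context_default
    if not enable_chunked_prefill:
        # Without chunked prefill, a full-length prompt must fit in one batch.
        max_num_batched_tokens = max(max_model_len, max_num_batched_tokens)
    # Never exceed what max_num_seqs full-length sequences could occupy.
    return min(max_num_seqs * max_model_len, max_num_batched_tokens)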
_VALID_QUANT_ATTN_QKV_DTYPE = ['int8', 'fp8', 'fp8_e4m3']
def vllm__engine__arg_utils__EngineArgs__create_engine_config(
self,
usage_context: UsageContext | None = None,
headless: bool = False,
) -> VllmConfig:
"""
Create the VllmConfig.
NOTE: If VllmConfig is incompatible, we raise an error.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: map mlu_config.decoder_attn_dtype to kv_cache_dtype and add MLU-specific config checks.
'''
if self.mlu_config and "decoder_attn_dtype" in self.mlu_config:
if self.mlu_config.get("decoder_attn_dtype") in _VALID_QUANT_ATTN_QKV_DTYPE:
self.kv_cache_dtype = self.mlu_config.get("decoder_attn_dtype")
engine_config = vllm__engine__arg_utils__EngineArgs__create_engine_config_org(
self, usage_context, headless)
world_size = engine_config.parallel_config.world_size_across_dp
tensor_parallel_size = engine_config.parallel_config.tensor_parallel_size
embedding_tp_size = engine_config.mlu_config.layer_embedding_logit_tp_size
if embedding_tp_size:
assert tensor_parallel_size <= embedding_tp_size <= world_size, (
f"embedding_tp_size = {embedding_tp_size} is out of bounds. "
f"Require {tensor_parallel_size} ≤ embedding_tp_size ≤ {world_size}")
dense_mlp_tp_size = engine_config.mlu_config.layer_dense_mlp_tp_size
if dense_mlp_tp_size:
assert 1 <= dense_mlp_tp_size <= world_size, (
f"dense_mlp_tp_size = {dense_mlp_tp_size} is out of bounds. Require 1 ≤ dense_mlp_tp_size ≤ {world_size}")
if dense_mlp_tp_size != world_size:
assert not engine_config.mlu_config.is_dpsk_mcc_enabled, (
"A dense_mlp_tp_size different from world_size is not supported when dpsk mcc is enabled.")
if engine_config.model_config.is_longcat_flash and tensor_parallel_size > 1:
raise ValueError("For now, for the longcat model, a custom dense mlp tp split under data parallel requires dpXtp1 "
"(i.e. tensor_parallel_size == 1). Whether this constraint is necessary needs further investigation.")
if engine_config.model_config.is_longcat_flash and dense_mlp_tp_size < tensor_parallel_size:
raise ValueError(f"For the longcat model, custom dense mlp tp_size {dense_mlp_tp_size} "
f"must be greater than or equal to tensor_parallel_size {tensor_parallel_size}")
if engine_config.model_config.is_deepseek_mla and dense_mlp_tp_size % tensor_parallel_size != 0:
raise ValueError(f"For the deepseek mla model, custom dense mlp tp size {dense_mlp_tp_size} must "
f"be divisible by tensor_parallel_size {tensor_parallel_size}")
if ((engine_config.parallel_config.data_parallel_size > 1 or engine_config.speculative_config is not None
or engine_config.mlu_config.prefill_use_sequence_parallel) and engine_config.mlu_config.prefill_enable_mlugraph):
logger.info("Data parallel, prefill sequence parallel, or speculative decoding is enabled; forcing prefill (context) mlugraph to be disabled.")
engine_config.mlu_config.prefill_enable_mlugraph = False
if engine_config.mlu_config.decoder_attn_dtype:
if engine_config.mlu_config.decoder_attn_dtype not in get_args(CacheDType):
raise ValueError(f"MLU backend does not support {engine_config.mlu_config.decoder_attn_dtype} "
f"decoder_attn_dtype for now")
is_glm4_moe = (hasattr(engine_config.model_config.hf_text_config, "model_type") and
engine_config.model_config.hf_text_config.model_type == "glm4_moe")
if (not (engine_config.model_config.is_deepseek_mla or is_glm4_moe)
and engine_config.mlu_config.decoder_attn_dtype != "auto"):
raise ValueError("mlu_config.decoder_attn_dtype is only supported for deepseek_mla and glm4_moe models")
# sequence parallel checks
if (engine_config.mlu_config.prefill_use_sequence_parallel
and engine_config.model_config.hf_text_config.model_type not in ["deepseek_v32", "deepseek_v3"]):
raise ValueError("Prefill sequence parallel can only be used with deepseek models.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.scheduler_config.enable_chunked_prefill:
raise ValueError("Prefill sequence parallel cannot be used with chunked prefill for now.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.mlu_config.is_dpsk_mcc_enabled:
raise ValueError("Prefill sequence parallel cannot be used with mcc.")
if engine_config.mlu_config.prefill_use_sequence_parallel and engine_config.parallel_config.data_parallel_size > 1:
raise ValueError("Prefill sequence parallel cannot be used with data parallel.")
if (engine_config.mlu_config.prefill_use_sequence_parallel
and engine_config.model_config.hf_text_config.model_type == "deepseek_v3"
and engine_config.quant_config.get_name() != "SmoothQuant"):
raise ValueError("Prefill sequence parallel requires SmoothQuant quantization for deepseek_v3.")
# disagg constraints
# 1. only supports deepseek-v3/r1
# 2. does not support int8 kv cache (kv8)
# 3. does not support prefix caching
if self.kv_transfer_config is not None:
if engine_config.model_config.hf_config.model_type != "deepseek_v3":
raise ValueError("Disagg only supports DeepSeek-V3/R1")
if engine_config.cache_config.cache_dtype == "int8":
raise ValueError("Disagg does not support int8 KV cache dtype")
if engine_config.cache_config.enable_prefix_caching:
raise ValueError("Disagg does not support prefix caching")
if isinstance(self.kv_transfer_config, dict):
kv_connector = self.kv_transfer_config.get("kv_connector")
kv_role = self.kv_transfer_config.get("kv_role")
else:
kv_connector = self.kv_transfer_config.kv_connector
kv_role = self.kv_transfer_config.kv_role
if kv_connector != "LMCacheConnectorV1":
raise ValueError("Disagg only supports the LMCacheConnectorV1 connector")
if kv_role == "kv_consumer":
if not self.enable_chunked_prefill:
raise ValueError("Disagg decoder only supports the chunked scheduler")
'''
==================
End of MLU Hijack
==================
'''
return engine_config
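# Illustrative sketch (standalone, not called above): the TP-size bounds the
# checks above enforce for the two optional mlu_config overrides.
# `world_size` here stands for parallel_config.world_size_across_dp.
def _example_check_mlu_tp_sizes(embedding_tp_size: int | None,
                                dense_mlp_tp_size: int | None,
                                tensor_parallel_size: int,
                                world_size: int) -> None:
    if embedding_tp_size:
        # layer_embedding_logit_tp_size must sit between TP size and world size.
        assert tensor_parallel_size <= embedding_tp_size <= world_size
    if dense_mlp_tp_size:
        # layer_dense_mlp_tp_size must sit between 1 and world size.
        assert 1 <= dense_mlp_tp_size <= world_size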
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs._set_default_args,
vllm__engine__arg_utils__EngineArgs___set_default_args)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.create_engine_config,
vllm__engine__arg_utils__EngineArgs__create_engine_config)
MluHijackObject.apply_hijack(EngineArgs,
EngineArgs.get_chunked_prefill_prefix_caching_defaults,
vllm__engine__arg_utils__EngineArgs__get_chunked_prefill_prefix_caching_defaults)
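# Conceptual sketch only: MluHijackObject.apply_hijack is defined in
# vllm_mlu.mlu_hijack_utils and is not shown in this diff. Assuming it simply
# rebinds the class attribute to the override, a minimal equivalent would be
# the hypothetical helper below.
def _example_apply_hijack(cls, orig_method, new_method):
    # Rebind the attribute named after the original method to the override.
    setattr(cls, orig_method.__name__, new_method)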