[Bugfix] fix env variable in dbo (#1284)

### What this PR does / why we need it?
Fix env variable in dbo to enable dbo in the DeepSeek-V3 model. Besides, we
have fixed a known issue in deepseek-dbo.


### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
This patch can be tested with newly added e2e tests:
[tests/multicard/test_offline_inference_distributed.py](https://github.com/vllm-project/vllm-ascend/pull/1285/files#diff-7cd2e6b1bda6b8ad1bedb3276971fe7064aeae4dc0efd41c301c4ede2158c57e).
It can be verified with pytest.

---------

Signed-off-by: zhuohuan <zxdu1997@gmail.com>
This commit is contained in:
zxdukki
2025-06-23 09:07:57 +08:00
committed by GitHub
parent 21fb68a03a
commit f04c6763d8
4 changed files with 41 additions and 7 deletions

View File

@@ -361,6 +361,8 @@ jobs:
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_dbo
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py --ignore=tests/e2e/multicard/test_offline_inference_distributed.py
- name: Run vllm-project/vllm-ascend test on V0 engine

View File

@@ -25,6 +25,7 @@ from unittest.mock import patch
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from tests.conftest import VllmRunner
@@ -94,6 +95,32 @@ def test_models_distributed_DeepSeek_dbo():
tensor_parallel_size=4,
distributed_executor_backend="mp",
) as vllm_model:
model_arch = 'DeepseekV2ForCausalLM'
registed_models = ModelRegistry.models
assert registed_models[
model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
assert registed_models[
model_arch].class_name == "CustomDeepseekDBOForCausalLM"
vllm_model.generate(example_prompts, sampling_params)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
def test_models_distributed_DeepSeekV3_dbo():
    """E2E check that DBO (dual-batch overlap) is wired up for DeepSeek-V3.

    With VLLM_ASCEND_ENABLE_DBO=1 patched into the environment, the plugin
    must register DeepseekV3ForCausalLM to the DBO implementation
    (vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM), and
    generation over TP=4 must run without error.
    """
    # 41 prompts so the batch is not a multiple of tp_size=4 — exercises
    # the token-padding path in the DBO decoder layer.
    example_prompts = ["The president of the United States is"] * 41
    dtype = "half"
    # temperature=0.0 keeps decoding deterministic for this smoke test.
    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
    with VllmRunner(
            "vllm-ascend/DeepSeek-V3-Pruning",
            dtype=dtype,
            tensor_parallel_size=4,
            distributed_executor_backend="mp",
    ) as vllm_model:
        model_arch = 'DeepseekV3ForCausalLM'
        registed_models = ModelRegistry.models
        # Assert the registry points at the DBO model class, i.e. the env
        # variable actually switched the registration at plugin load time.
        assert registed_models[
            model_arch].module_name == "vllm_ascend.models.deepseek_dbo"
        assert registed_models[
            model_arch].class_name == "CustomDeepseekDBOForCausalLM"
        vllm_model.generate(example_prompts, sampling_params)

View File

@@ -35,14 +35,19 @@ def register_model():
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_dbo:CustomDeepseekDBOForCausalLM")
else:
ModelRegistry.register_model(
"DeepseekV2ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"Qwen3MoeForCausalLM",

View File

@@ -641,7 +641,7 @@ class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
if self.mlp.tp_size > 1:
num_token, _ = hidden_states[i].shape
padded_num_tokens = (self.mlp.tp_size - num_token %
padded_num_tokens = (self.mlp.tp_size - num_tokens[i] %
self.mlp.tp_size) % self.mlp.tp_size
if padded_num_tokens > 0:
hidden_states[i] = nn.functional.pad(
@@ -851,7 +851,8 @@ class CustomDeepseekDBOModel(nn.Module):
if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms()
else self.end_layer - self.start_layer)
for i in range(self.start_layer, self.start_layer + num_normal_layers):
moe_start_layer = self.start_layer + num_normal_layers
for i in range(self.start_layer, min(moe_start_layer, self.end_layer)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, residual,
@@ -859,8 +860,7 @@ class CustomDeepseekDBOModel(nn.Module):
self.start_layer] if kv_caches is not None else None,
attn_metadata)
moe_start_layer = self.start_layer + num_normal_layers
if moe_start_layer != self.end_layer:
if moe_start_layer < self.end_layer:
# if we enable multistream/dbo, process sparse layers here
hidden_states, residual = self._forward_ms_layers(
positions=positions,