### What this PR does / why we need it?
Support MTP (Multi-Token Prediction) with:
- [x] V0 Scheduler
- [x] TorchAir
- [x] Single DP
- [x] Multi DP
- [x] Disaggregated PD
Known issues:
- [ ] The V1 Scheduler (chunked prefill) is not supported yet; support will land in a few weeks.
- [ ] vLLM v0.10.0 does not support metrics with `DP > 1` right now; as a workaround, comment out lines 171-175 in `vllm/vllm/v1/metrics/loggers.py`:
```
if (len(self.engine_indexes) > 1
        and vllm_config.speculative_config is not None):
    raise NotImplementedError("Prometheus metrics with Spec Decoding "
                              "with >1 EngineCore per AsyncLLM is not "
                              "supported yet.")
```
To start an online server with TorchAir enabled, here is an example (a sample client request follows the command):
```
python -m vllm.entrypoints.openai.api_server \
--model="/weights/DeepSeek-R1_w8a8/" \
--trust-remote-code \
--max-model-len 40000 \
--tensor-parallel-size 4 \
--data_parallel_size 4 \
--max-num-seqs 16 \
--no-enable-prefix-caching \
--enable_expert_parallel \
--served-model-name deepseekr1 \
--speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
--quantization ascend \
--host 0.0.0.0 \
--port 1234 \
--additional-config '{"ascend_scheduler_config":{"enabled":true,"enable_chunked_prefill":false},"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]},"enable_weight_nz_layout":true}' \
--gpu_memory_utilization 0.9
```
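Once the server is up, you can exercise it with a plain OpenAI-compatible completion request. The snippet below is a minimal sketch; the model name `deepseekr1` and port `1234` are taken from the example flags above:
```
# Minimal client sketch for the server started above (assumes
# --served-model-name deepseekr1 and --port 1234 as in the example).
import requests

resp = requests.post(
    "http://localhost:1234/v1/completions",
    json={
        "model": "deepseekr1",
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0,
    },
)
print(resp.json()["choices"][0]["text"])
```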
An offline example with TorchAir enabled:
```
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=16, temperature=0)

# Create an LLM.
llm = LLM(
    model="/home/data/DeepSeek-R1_w8a8/",
    tensor_parallel_size=16,
    max_num_seqs=16,
    gpu_memory_utilization=0.9,
    distributed_executor_backend="mp",
    enable_expert_parallel=True,
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    enforce_eager=False,
    max_model_len=2000,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "graph_batch_sizes": [16],
            "enable_multistream_shared_expert": False,
        },
        "ascend_scheduler_config": {
            "enabled": True,
        },
        # "expert_tensor_parallel_size": 16,
    },
)

# Generate texts from the prompts.
# llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
# llm.stop_profile()
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
- vLLM version: v0.10.0
- vLLM main: 302962e806
---------
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This PR also adds unit tests covering `CustomDeepSeekMTP`, `CustomDeepSeekMultiTokenPredictor`, and `CustomDeepSeekMultiTokenPredictorLayer` (180 lines, Python):
```
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.models.deepseek_mtp import (
    CustomDeepSeekMTP, CustomDeepSeekMultiTokenPredictor,
    CustomDeepSeekMultiTokenPredictorLayer)


class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):

    @pytest.fixture
    def setup_mtp_layer(self, mocker: MockerFixture):
        config = PretrainedConfig(vocab_size=1000,
                                  hidden_size=768,
                                  rms_norm_eps=1e-5)
        mocker.patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
                     return_value=None)
        mocker.patch(
            "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__",
            return_value=None)
        mocker_deepseek_v2_decode_layer = mocker.patch(
            "vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
            return_value=None)

        mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
        mocker_deepseek_v2_decode_layer.assert_called_once()
        return mtp_layer

    def test_init(self, mocker: MockerFixture, setup_mtp_layer):
        mtp_layer = setup_mtp_layer
        assert isinstance(mtp_layer, CustomDeepSeekMultiTokenPredictorLayer)

    def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
        mtp_layer = setup_mtp_layer
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch.object(mtp_layer,
                            'eh_proj',
                            return_value=torch.randn(2, 3, 768))
        mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
        mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
                                            torch.randn(2, 3, 768))

        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
        positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
        kv_cache = torch.randn(2, 3, 768)
        previous_hidden_states = torch.randn(2, 3, 768)
        inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])

        output = mtp_layer(input_ids, positions, kv_cache, None,
                           previous_hidden_states, inputs_embeds, 0)
        assert output.shape == (2, 3, 768)


class TestCustomDeepSeekMultiTokenPredictor(PytestBase):

    @pytest.fixture
    def setup_predictor(self, mocker: MockerFixture):
        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
        mock_model_config = mocker.MagicMock(spec=ModelConfig)
        mock_hf_config = mocker.MagicMock()
        mock_hf_config.num_hidden_layers = 12
        mock_hf_config.num_nextn_predict_layers = 3
        mock_hf_config.vocab_size = 30000
        mock_model_config.hf_config = mock_hf_config
        mock_vllm_config.model_config = mock_model_config
        mock_vllm_config.cache_config = CacheConfig()
        mock_vllm_config.quant_config = mocker.MagicMock()
        mocker.patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
            return_value=None)

        predictor = CustomDeepSeekMultiTokenPredictor(
            vllm_config=mock_vllm_config)
        return predictor

    def test_init(self, mocker: MockerFixture, setup_predictor):
        predictor = setup_predictor
        assert predictor.num_mtp_layers == 3
        assert isinstance(predictor, CustomDeepSeekMultiTokenPredictor)

    @pytest.mark.parametrize(
        'kv_caches, inputs_embeds',
        [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
    def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
                     inputs_embeds):
        predictor = setup_predictor
        mock_layer = mocker.MagicMock()
        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [mock_layer]

        # todo: need or not?
        # predictor.num_mtp_layers = 1
        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=torch.tensor([[1.0, 2.0, 3.0]]))
        output = predictor.forward(input_ids, positions, kv_caches, None, None,
                                   inputs_embeds, 0)
        mock_layer.assert_called_once()
        assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))

    def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
        hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
        predictor = setup_predictor

        mock_layer = mocker.MagicMock()
        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [mock_layer]
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch(
            "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
            return_value=None)
        predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])

        result_logits = predictor.compute_logits(hidden_states=hidden_states,
                                                 sampling_metadata=None)
        predictor.logits_processor.assert_called_once()
        assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))


class TestCustomDeepSeekMTP(PytestBase):

    @pytest.fixture
    def setup_mtp(self, mocker: MockerFixture):
        vllm_config = mocker.MagicMock()
        vllm_config.model_config.hf_config.num_hidden_layers = 12
        vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
        vllm_config.cache_config = mocker.MagicMock()
        vllm_config.quant_config = mocker.MagicMock()

        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        mocker.patch(
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            return_value=None)
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=None)
        mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
                     return_value=None)

        mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
        return mtp

    def test_init(self, mocker: MockerFixture, setup_mtp):
        mtp = setup_mtp
        assert isinstance(mtp, CustomDeepSeekMTP)

    def test_forward(self, mocker: MockerFixture, setup_mtp):
        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
        previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
        inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
        spec_step_idx = 0
        setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])

        output = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                   previous_hidden_states, inputs_embeds,
                                   spec_step_idx)
        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
```
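To run the new tests locally, something like the following should work; the `tests/ut` path is an assumption based on the `tests.ut.base` import above:
```
# In-process pytest run; "tests/ut" is an assumed path derived from the
# `tests.ut.base` import in the test file above.
import pytest

raise SystemExit(pytest.main(["-q", "tests/ut"]))
```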