[1/N][refactor] torchair deepseek modeling refactor (#2384)
### What this PR does / why we need it?
Move torchair related model arch into torchair moduel to make the code
clear. Next step we'll remove all torchair related code outside of
torchair moduel.
### Does this PR introduce _any_ user-facing change?
No.
- vLLM version: v0.10.0
- vLLM main:
08d5f7113a
Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
180
tests/ut/torchair/models/test_torchair_deepseek_mtp.py
Normal file
180
tests/ut/torchair/models/test_torchair_deepseek_mtp.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
|
||||
TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
|
||||
TorchairDeepSeekMultiTokenPredictorLayer)
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mtp_layer(self, mocker: MockerFixture):
|
||||
config = PretrainedConfig(vocab_size=1000,
|
||||
hidden_size=768,
|
||||
rms_norm_eps=1e-5)
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__",
|
||||
return_value=None)
|
||||
mocker_deepseek_v2_decode_layer = mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
|
||||
return_value=None)
|
||||
|
||||
mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
|
||||
mocker_deepseek_v2_decode_layer.assert_called_once()
|
||||
return mtp_layer
|
||||
|
||||
def test_init(self, mocker: MockerFixture, setup_mtp_layer):
|
||||
mtp_layer = setup_mtp_layer
|
||||
assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer)
|
||||
|
||||
def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
|
||||
mtp_layer = setup_mtp_layer
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
mocker.patch.object(mtp_layer,
|
||||
'eh_proj',
|
||||
return_value=torch.randn(2, 3, 768))
|
||||
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
|
||||
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
|
||||
torch.randn(2, 3, 768))
|
||||
|
||||
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
||||
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
|
||||
kv_cache = torch.randn(2, 3, 768)
|
||||
previous_hidden_states = torch.randn(2, 3, 768)
|
||||
inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
|
||||
output = mtp_layer(input_ids, positions, kv_cache, None,
|
||||
previous_hidden_states, inputs_embeds, 0)
|
||||
assert output.shape == (2, 3, 768)
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_predictor(self, mocker: MockerFixture):
|
||||
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
|
||||
mock_model_config = mocker.MagicMock(spec=ModelConfig)
|
||||
mock_hf_config = mocker.MagicMock()
|
||||
mock_hf_config.num_hidden_layers = 12
|
||||
mock_hf_config.num_nextn_predict_layers = 3
|
||||
mock_hf_config.vocab_size = 30000
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.cache_config = CacheConfig()
|
||||
mock_vllm_config.quant_config = mocker.MagicMock()
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
|
||||
return_value=None)
|
||||
|
||||
predictor = TorchairDeepSeekMultiTokenPredictor(
|
||||
vllm_config=mock_vllm_config)
|
||||
return predictor
|
||||
|
||||
def test_init(self, mocker: MockerFixture, setup_predictor):
|
||||
predictor = setup_predictor
|
||||
assert predictor.num_mtp_layers == 3
|
||||
assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'kv_caches, inputs_embeds',
|
||||
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
|
||||
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
|
||||
inputs_embeds):
|
||||
predictor = setup_predictor
|
||||
mock_layer = mocker.MagicMock()
|
||||
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
|
||||
predictor.layers_list = [mock_layer]
|
||||
|
||||
# todo: need or not?
|
||||
# predictor.num_mtp_layers = 1
|
||||
input_ids = torch.tensor([[1, 2, 3]])
|
||||
positions = torch.tensor([[0, 1, 2]])
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
|
||||
return_value=torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
output = predictor.forward(input_ids, positions, kv_caches, None, None,
|
||||
inputs_embeds, 0)
|
||||
mock_layer.assert_called_once()
|
||||
assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
|
||||
hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
|
||||
predictor = setup_predictor
|
||||
|
||||
mock_layer = mocker.MagicMock()
|
||||
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
|
||||
predictor.layers_list = [mock_layer]
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
|
||||
return_value=None)
|
||||
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
|
||||
|
||||
result_logits = predictor.compute_logits(hidden_states=hidden_states,
|
||||
sampling_metadata=None)
|
||||
predictor.logits_processor.assert_called_once()
|
||||
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMTP(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mtp(self, mocker: MockerFixture):
|
||||
vllm_config = mocker.MagicMock()
|
||||
vllm_config.model_config.hf_config.num_hidden_layers = 12
|
||||
vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
|
||||
vllm_config.cache_config = mocker.MagicMock()
|
||||
vllm_config.quant_config = mocker.MagicMock()
|
||||
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
|
||||
return_value=None)
|
||||
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
|
||||
return_value=None)
|
||||
|
||||
mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
|
||||
return mtp
|
||||
|
||||
def test_init(self, mocker: MockerFixture, setup_mtp):
|
||||
mtp = setup_mtp
|
||||
assert isinstance(mtp, TorchairDeepSeekMTP)
|
||||
|
||||
def test_forward(self, mocker: MockerFixture, setup_mtp):
|
||||
input_ids = torch.tensor([[1, 2, 3]])
|
||||
positions = torch.tensor([[0, 1, 2]])
|
||||
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
|
||||
previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
|
||||
inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
|
||||
spec_step_idx = 0
|
||||
setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
|
||||
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
|
||||
previous_hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
324
tests/ut/torchair/models/test_torchair_deepseek_v2.py
Normal file
324
tests/ut/torchair/models/test_torchair_deepseek_v2.py
Normal file
@@ -0,0 +1,324 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
|
||||
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
|
||||
TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
|
||||
TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
|
||||
TorchairDeepseekV2RowParallelLinear,
|
||||
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
TorchairDeepseekV2SiluAndMul)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_config():
|
||||
config = PretrainedConfig(
|
||||
hidden_size=128,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=2,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
rms_norm_eps=1e-6,
|
||||
rope_theta=10000.0,
|
||||
max_position_embeddings=2048,
|
||||
n_routed_experts=4,
|
||||
n_shared_experts=1,
|
||||
moe_intermediate_size=256,
|
||||
num_experts_per_tok=2,
|
||||
routed_scaling_factor=1.0,
|
||||
first_k_dense_replace=0,
|
||||
moe_layer_freq=1,
|
||||
kv_lora_rank=16,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
topk_method="noaux_tc",
|
||||
scoring_func="softmax",
|
||||
norm_topk_prob=True,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
vocab_size=10000,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_config(base_config):
|
||||
model_config = SimpleNamespace(
|
||||
hf_config=base_config,
|
||||
tensor_parallel_size=1,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
quant_config=None,
|
||||
max_model_len=2048,
|
||||
)
|
||||
|
||||
cache_config = CacheConfig()
|
||||
vllm_config = Mock()
|
||||
vllm_config.model_config = model_config
|
||||
vllm_config.cache_config = cache_config
|
||||
vllm_config.quant_config = None
|
||||
return vllm_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_distributed():
|
||||
tp_group = Mock(spec=GroupCoordinator)
|
||||
tp_group.rank_in_group = 0
|
||||
tp_group.world_size = 1
|
||||
tp_group.device_group = Mock()
|
||||
|
||||
dp_group = Mock(spec=GroupCoordinator)
|
||||
dp_group.rank_in_group = 0
|
||||
dp_group.world_size = 1
|
||||
|
||||
ep_group = Mock(spec=GroupCoordinator)
|
||||
ep_group.rank_in_group = 0
|
||||
ep_group.world_size = 1
|
||||
|
||||
pp_group = Mock(spec=GroupCoordinator)
|
||||
pp_group.rank_in_group = 0
|
||||
pp_group.world_size = 1
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
|
||||
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
|
||||
|
||||
with patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tp_group", return_value=tp_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_ep_group", return_value=ep_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_dp_group", return_value=dp_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
|
||||
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
|
||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_forward_context():
|
||||
forward_context = Mock(in_profile_run=False, with_prefill=False)
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context",
|
||||
return_value=forward_context):
|
||||
yield
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_silu_and_mul():
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
silu = TorchairDeepseekV2SiluAndMul()
|
||||
assert silu.weight_scale is None
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
output = silu.forward_oot(x)
|
||||
assert output.shape == (2, 2)
|
||||
|
||||
weight_scale = Mock(return_value=torch.tensor(0.1))
|
||||
silu = TorchairDeepseekV2SiluAndMul(weight_scale=weight_scale)
|
||||
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
|
||||
dynamic_scale = torch.randn(2, 1)
|
||||
with patch("torch_npu.npu_dequant_swiglu_quant",
|
||||
return_value=torch.randn(2, 4)):
|
||||
output = silu.forward_oot((quant_x, dynamic_scale))
|
||||
assert output.shape == (2, 4)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
|
||||
linear = TorchairDeepseekV2MergedReplicatedLinear(input_size=128,
|
||||
output_sizes=[64, 64],
|
||||
bias=False,
|
||||
quant_config=None)
|
||||
assert linear.output_sizes == [64, 64]
|
||||
|
||||
param = Mock()
|
||||
param.data = torch.zeros(128, 128)
|
||||
param.output_dim = 1
|
||||
param.is_gguf_weight = False
|
||||
param.is_gguf_weight_type = False
|
||||
loaded_weight = torch.randn(128, 64)
|
||||
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [
|
||||
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
TorchairDeepseekV2RowParallelLinear
|
||||
])
|
||||
def test_row_parallel_linear(cls, mock_distributed):
|
||||
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
|
||||
linear.quant_method = Mock()
|
||||
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
|
||||
|
||||
input_ = torch.randn(2, 4, 128)
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim",
|
||||
return_value=[torch.randn(2, 4, 64)]):
|
||||
linear.input_is_parallel = False
|
||||
output = linear(input_, is_prefill=True)
|
||||
assert output[0].shape == (2, 4, 64)
|
||||
|
||||
linear.input_is_parallel = True
|
||||
output = linear(input_, is_prefill=False)
|
||||
assert output[0].shape == (2, 4, 64)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
|
||||
mlp = TorchairDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
quant_config=None)
|
||||
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
output = mlp(x)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
|
||||
) as mock_quant_config:
|
||||
mock_quant_config.name = "w8a8dynamic"
|
||||
with pytest.raises(NotImplementedError):
|
||||
TorchairDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
quant_config=mock_quant_config,
|
||||
force_replicate=False)
|
||||
with pytest.raises(ValueError):
|
||||
TorchairDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="relu",
|
||||
quant_config=None)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
|
||||
mock_forward_context):
|
||||
base_config.n_shared_experts = 1
|
||||
moe = TorchairDeepseekV2MoE(config=base_config,
|
||||
quant_config=None,
|
||||
prefix="mlp")
|
||||
assert moe.top_k == 2
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
attn_metadata = Mock(num_prefills=1)
|
||||
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
|
||||
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
|
||||
output = moe(x, attn_metadata)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
base_config):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
|
||||
attn = TorchairDeepseekV2MLAAttention(config=base_config,
|
||||
hidden_size=128,
|
||||
num_heads=8,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
q_lora_rank=16,
|
||||
kv_lora_rank=16,
|
||||
cache_config=CacheConfig(),
|
||||
quant_config=None,
|
||||
prefix="layers.0.self_attn")
|
||||
assert attn.debug_layer_idx == 0
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
with patch.object(attn.mla_attn,
|
||||
"__call__",
|
||||
return_value=torch.randn(2, 4, 128)):
|
||||
with pytest.raises(AssertionError):
|
||||
attn(positions, x)
|
||||
|
||||
attn = TorchairDeepseekV2MLAAttention(config=base_config,
|
||||
hidden_size=128,
|
||||
num_heads=8,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
q_lora_rank=None,
|
||||
kv_lora_rank=16,
|
||||
prefix="layers.1.self_attn")
|
||||
assert hasattr(attn, "q_proj")
|
||||
|
||||
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
||||
torch.randn(2, 128))
|
||||
base_config.n_routed_experts = 4
|
||||
layer = TorchairDeepseekV2DecoderLayer(
|
||||
config=base_config,
|
||||
prefix="layers.0",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=CacheConfig(),
|
||||
quant_config=None)
|
||||
assert isinstance(layer.mlp, TorchairDeepseekV2MoE)
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
|
||||
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
|
||||
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
|
||||
hidden_states, residual = layer(positions, x, None)
|
||||
assert hidden_states.shape == (2, 4, 128)
|
||||
|
||||
base_config.n_routed_experts = None
|
||||
layer = TorchairDeepseekV2DecoderLayer(
|
||||
config=base_config,
|
||||
prefix="layers.0",
|
||||
model_config=vllm_config.model_config,
|
||||
quant_config=None)
|
||||
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
|
||||
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
|
||||
|
||||
input_ids = torch.randint(0, 10000, (2, 4))
|
||||
positions = torch.arange(4).repeat(2, 1)
|
||||
with patch.object(model.model,
|
||||
"forward",
|
||||
return_value=torch.randn(2, 4, 128)):
|
||||
output = model(input_ids, positions)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
|
||||
with patch(
|
||||
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
|
||||
):
|
||||
loaded = model.load_weights(weights)
|
||||
assert loaded is not None
|
||||
@@ -1,4 +1,6 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair import utils
|
||||
@@ -26,3 +28,46 @@ class TestTorchairUtils(TestBase):
|
||||
"Delete torchair cache dir failed")
|
||||
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
|
||||
"Delete kv cache bytes cache dir failed")
|
||||
|
||||
def test_torchair_cache_dir_multiple_ranks(self):
|
||||
ranks = [0, 1, 2, 3]
|
||||
values = [100, 200, 300, 400]
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
executor.map(utils.write_kv_cache_bytes_to_file, ranks, values)
|
||||
for rank, expected in zip(ranks, values):
|
||||
self.assertEqual(expected,
|
||||
utils.read_kv_cache_bytes_from_file(rank))
|
||||
utils.delete_torchair_cache_file()
|
||||
|
||||
self.assertFalse(utils.check_torchair_cache_exist(),
|
||||
"Delete torchair cache dir failed")
|
||||
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
|
||||
"Delete kv cache bytes cache dir failed")
|
||||
|
||||
@patch('vllm.ModelRegistry')
|
||||
def test_register_torchair_model(self, mock_model_registry):
|
||||
mock_registry = MagicMock()
|
||||
mock_model_registry.return_value = mock_registry
|
||||
utils.register_torchair_model()
|
||||
|
||||
self.assertEqual(mock_model_registry.register_model.call_count, 3)
|
||||
call_args_list = mock_model_registry.register_model.call_args_list
|
||||
|
||||
expected_registrations = [
|
||||
("DeepSeekMTPModel",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
|
||||
),
|
||||
("DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
|
||||
),
|
||||
("DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
||||
)
|
||||
]
|
||||
|
||||
for i, (expected_name,
|
||||
expected_path) in enumerate(expected_registrations):
|
||||
args, kwargs = call_args_list[i]
|
||||
self.assertEqual(args[0], expected_name)
|
||||
self.assertEqual(args[1], expected_path)
|
||||
|
||||
0
vllm_ascend/torchair/models/__init__.py
Normal file
0
vllm_ascend/torchair/models/__init__.py
Normal file
218
vllm_ascend/torchair/models/torchair_deepseek_mtp.py
Normal file
218
vllm_ascend/torchair/models/torchair_deepseek_mtp.py
Normal file
@@ -0,0 +1,218 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Adapted from vllm/model_executor/models/deepseek_mtp.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||
SharedHead)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2DecoderLayer
|
||||
|
||||
|
||||
class TorchairDeepSeekShareHead(SharedHead):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "") -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "head"))
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
model_config: ModelConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = TorchairDeepSeekShareHead(config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix,
|
||||
"shared_head"))
|
||||
self.mtp_block = TorchairDeepseekV2DecoderLayer(
|
||||
config, prefix, model_config, cache_config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
|
||||
torch.zeros_like(inputs_embeds),
|
||||
inputs_embeds)
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
TorchairDeepSeekMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
|
||||
# Note: torch._dynamo.exc.Unsupported: builtin: str
|
||||
self.layers_list = [
|
||||
self.layers[str(idx)]
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
]
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
step_kv_cache = kv_caches[
|
||||
current_step_idx] if kv_caches is not None else None
|
||||
return self.layers_list[current_step_idx](
|
||||
input_ids,
|
||||
positions,
|
||||
step_kv_cache,
|
||||
attn_metadata,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers_list[current_step_idx]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
class TorchairDeepSeekMTP(DeepSeekMTP):
|
||||
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
|
||||
# NOTE 2.The description file generated by the current msmodelslim tool does not have
|
||||
# MTP layer info. Please manually add it and set the value to FLOAT.
|
||||
packed_modules_mapping = {
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts":
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
nn.Module.__init__(self)
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = TorchairDeepSeekMultiTokenPredictor(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[List[torch.Tensor]] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, previous_hidden_states,
|
||||
inputs_embeds, spec_step_idx)
|
||||
return hidden_states
|
||||
1047
vllm_ascend/torchair/models/torchair_deepseek_v2.py
Normal file
1047
vllm_ascend/torchair/models/torchair_deepseek_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
28
vllm_ascend/torchair/models/torchair_deepseek_v3.py
Normal file
28
vllm_ascend/torchair/models/torchair_deepseek_v3.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
TorchairDeepseekV2ForCausalLM
|
||||
|
||||
|
||||
class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
|
||||
pass
|
||||
@@ -27,6 +27,7 @@ from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
|
||||
register_torchair_model,
|
||||
write_kv_cache_bytes_to_file)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
maybe_converting_weight_acl_format)
|
||||
@@ -37,6 +38,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||
super().__init__(vllm_config, device)
|
||||
register_torchair_model()
|
||||
|
||||
def _get_forward_metadata_across_dp_and_pad(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
|
||||
@@ -96,3 +96,22 @@ def npu_wait_tensor(self: torch.Tensor,
|
||||
*,
|
||||
enabled: bool = True):
|
||||
return _npu_wait_tensor(self, dependency) if enabled else self
|
||||
|
||||
|
||||
def register_torchair_model():
|
||||
from vllm import ModelRegistry
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepSeekMTPModel",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV2ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
|
||||
)
|
||||
|
||||
ModelRegistry.register_model(
|
||||
"DeepseekV3ForCausalLM",
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user