[UT] fix skip ut test and enable ut test run normally (#3410)
### What this PR does / why we need it? fix skip ut test and enable ut test run normally ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: hfadzxy <starmoon_zhang@163.com>
This commit is contained in:
@@ -133,6 +133,33 @@ def mock_forward_context():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_attention_init():
|
||||
try:
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
|
||||
DeepseekV2Attention
|
||||
original_init = DeepseekV2Attention.__init__
|
||||
|
||||
def patched_init(self, *args, **kwargs):
|
||||
kwargs.pop("decoder_layer", None)
|
||||
if 'vllm_config' not in kwargs:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.model_config = Mock()
|
||||
mock_vllm_config.model_config.hf_config = Mock()
|
||||
mock_vllm_config.model_config.hf_config.hidden_size = 128
|
||||
mock_vllm_config.model_config.dtype = torch.float32
|
||||
mock_vllm_config.model_config.quant_config = None
|
||||
mock_vllm_config.cache_config = CacheConfig()
|
||||
kwargs['vllm_config'] = mock_vllm_config
|
||||
return original_init(self, *args, **kwargs)
|
||||
|
||||
DeepseekV2Attention.__init__ = patched_init
|
||||
yield
|
||||
DeepseekV2Attention.__init__ = original_init
|
||||
except ImportError:
|
||||
yield
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_silu_and_mul():
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
@@ -276,10 +303,14 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual")
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done,
|
||||
mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config, mock_forward_context):
|
||||
vllm_config, mock_forward_context,
|
||||
patch_attention_init):
|
||||
mock_maybe_chunk_residual.return_value = torch.randn(2, 4, 128)
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
|
||||
torch.randn(2, 128))
|
||||
@@ -309,7 +340,8 @@ def test_torchair_deepseek_v2_decoder_layer(mock_maybe_wait_prefetch_done,
|
||||
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
|
||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config,
|
||||
patch_attention_init):
|
||||
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
|
||||
|
||||
input_ids = torch.randint(0, 10000, (2, 4))
|
||||
|
||||
Reference in New Issue
Block a user