init v0.11.0rc0
@@ -7,8 +7,7 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
                                                 AscendAttentionBackendImpl,
                                                 AscendAttentionMetadataBuilder,
                                                 AscendAttentionState,
-                                                AscendMetadata,
-                                                CommonAttentionState)
+                                                AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata

@@ -25,10 +24,6 @@ class TestAscendAttentionBackend(TestBase):
         self.assertEqual(AscendAttentionBackend.get_metadata_cls(),
                          AscendMetadata)

-    def test_get_state_cls(self):
-        self.assertEqual(AscendAttentionBackend.get_state_cls(),
-                         CommonAttentionState)
-
     def test_get_builder_cls(self):
         self.assertEqual(AscendAttentionBackend.get_builder_cls(),
                          AscendAttentionMetadataBuilder)
@@ -72,7 +67,8 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
self.mock_vllm_config.model_config.max_model_len = 640
|
||||
self.mock_vllm_config.cache_config.block_size = 64
|
||||
self.mock_device = 'cpu:0'
|
||||
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
|
||||
self.builder = AscendAttentionMetadataBuilder(None, None,
|
||||
self.mock_vllm_config,
|
||||
self.mock_device)
|
||||
|
||||
def test_reorder_batch(self):
|
||||
@@ -100,19 +96,21 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
max_query_len=5,
|
||||
decode_token_per_req=torch.tensor([1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping_cpu=torch.tensor(range(20)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((10, 10)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.PrefillNoCache)
|
||||
attn_state=AscendAttentionState.PrefillNoCache,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
mock_nz_tensor = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_nd_to_nz_2d.return_value = mock_nz_tensor
|
||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
||||
|
||||
self.builder.build(common_attn_metadata, mock_model)
|
||||
self.builder.build(1, common_attn_metadata, mock_model)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
@@ -131,12 +129,14 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
max_query_len=6,
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping_cpu=torch.tensor(range(20)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.ChunkedPrefill)
|
||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
|
||||
mock_ascend_attention_state = MagicMock()
|
||||
mock_ascend_attention_state.PrefillNoCache = 0
|
||||
@@ -146,7 +146,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
mock_nd_to_nz_spec.return_value = mock_nz_tensor
|
||||
mock_npu_format_cast.return_value = mock_nz_tensor
|
||||
|
||||
self.builder.build(common_attn_metadata, mock_model)
|
||||
self.builder.build(1, common_attn_metadata, mock_model)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||
@@ -160,15 +160,17 @@ class TestAscendAttentionMetadataBuilder(TestBase):
|
||||
max_query_len=6,
|
||||
decode_token_per_req=torch.tensor([1, 1, 1]),
|
||||
block_table_tensor=torch.zeros((10, 10)),
|
||||
slot_mapping_cpu=torch.tensor(range(20)),
|
||||
slot_mapping=torch.tensor(range(20)),
|
||||
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
|
||||
positions=torch.tensor([10, 10]),
|
||||
attn_mask=torch.ones((15, 15)),
|
||||
spec_attn_mask=None,
|
||||
attn_state=AscendAttentionState.ChunkedPrefill)
|
||||
attn_state=AscendAttentionState.ChunkedPrefill,
|
||||
num_computed_tokens_cpu=None,
|
||||
seq_lens=None)
|
||||
mock_model = MagicMock()
|
||||
|
||||
self.builder.build(common_attn_metadata, mock_model)
|
||||
self.builder.build(1, common_attn_metadata, mock_model)
|
||||
|
||||
|
||||
class TestAscendAttentionBackendImpl(TestBase):
|
||||
@@ -341,36 +343,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_flash_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_flash_attention')
|
||||
def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
|
||||
mock_reshape_cache):
|
||||
"""Test forward pass in PrefillNoCache state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.PrefillNoCache
|
||||
metadata.attn_mask = torch.randn(1, 1, 10, 10)
|
||||
metadata.seq_lens = torch.tensor([10])
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
# layer.quant_method.apply.return_value = metadata
|
||||
print(self.layer_no_quant._v_scale_float)
|
||||
output = self.impl_swa.forward(layer,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_reshape_cache.assert_called_once()
|
||||
mock_flash_attention.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_flash_attention_qlens')
|
||||
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
|
||||
@@ -401,10 +373,12 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_flash_attention_qlens.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
def test_forward_decode_only(self, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache):
|
||||
mock_npu_reshape_and_cache,
|
||||
mock_get_forward_context):
|
||||
"""Test forward pass in DecodeOnly state"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
@@ -418,6 +392,8 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
layer = self.layer_no_quant
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
key,
|
||||
@@ -458,6 +434,44 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
mock_fused_infer_attention_score.assert_called_once()
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('torch_npu._npu_paged_attention')
|
||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||
def test_forward_decode_only_swa_seq_len_mismatch(
|
||||
self, mock_fused_infer_attention_score, mock_paged_attention,
|
||||
mock_npu_reshape_and_cache, mock_get_forward_context):
|
||||
"""Test forward pass in DecodeOnly state when seq)len_mismatch"""
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||
|
||||
metadata = self.attn_metadata
|
||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||
metadata.seq_lens = torch.tensor([10]) # len == 1 != query.size(0)==10
|
||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||
metadata.num_actual_tokens = 10
|
||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||
|
||||
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
|
||||
64), 1)
|
||||
|
||||
mock_get_forward_context.return_value = MagicMock(capturing=False)
|
||||
|
||||
output = self.impl_swa.forward(self.layer_no_quant,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
kv_cache,
|
||||
metadata,
|
||||
trace_flag=False)
|
||||
|
||||
mock_paged_attention.assert_called_once()
|
||||
mock_fused_infer_attention_score.assert_not_called()
|
||||
|
||||
assert output.shape == (10, 8 * 64)
|
||||
|
||||
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
|
||||
@patch('torch_npu._npu_reshape_and_cache')
|
||||
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
|
||||
|
||||
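A minimal, self-contained illustration (not part of the commit; the patch targets below are chosen arbitrarily) of the @patch stacking rule behind the updated signatures above: stacked @patch decorators are applied bottom-up, so the newly added top-most @patch('vllm_ascend.attention.attention_v1.get_forward_context') arrives as the last mock parameter (mock_get_forward_context) of test_forward_decode_only.

import unittest
from unittest.mock import MagicMock, patch


class PatchOrderExample(unittest.TestCase):

    @patch('os.path.exists')   # top decorator    -> injected last
    @patch('os.path.isfile')   # bottom decorator -> injected first
    def test_order(self, mock_isfile, mock_exists):
        # Both replacements are MagicMocks; their order mirrors the decorator stack.
        self.assertIsInstance(mock_isfile, MagicMock)
        self.assertIsInstance(mock_exists, MagicMock)


if __name__ == "__main__":
    unittest.main()
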
@@ -186,10 +186,39 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
ascend_config = MagicMock()
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
|
||||
self.assertEqual(builder.block_size,
|
||||
mock_vllm_config.cache_config.block_size)
|
||||
self.assertEqual(
|
||||
builder.chunked_prefill_enabled,
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
|
||||
|
||||
def test_ascend_mla_metadata_builder_spec_decode(self):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
mock_vllm_config.model_config.dtype = torch.float16
|
||||
mock_vllm_config.cache_config.block_size = 16
|
||||
mock_vllm_config.scheduler_config.max_num_seqs = 4
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_spec_config = MagicMock()
|
||||
mock_spec_config.num_speculative_tokens = 3
|
||||
mock_vllm_config.speculative_config = mock_spec_config
|
||||
|
||||
ascend_config = MagicMock()
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
|
||||
self.assertEqual(builder.block_size,
|
||||
mock_vllm_config.cache_config.block_size)
|
||||
@@ -207,9 +236,12 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
return_value=ascend_config):
|
||||
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
builder.decode_threshold = 1
|
||||
|
||||
input_batch = MagicMock()
|
||||
@@ -522,7 +554,11 @@ class TestAscendMLAImpl(TestBase):
|
||||
self.impl.num_kv_heads = self.impl.num_heads
|
||||
|
||||
decode_res, prefill_res = self.impl._mla_preprocess(
|
||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
|
||||
"mock_layer",
|
||||
hidden_states,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
need_gather_q_kv=False)
|
||||
|
||||
self.assertIsNotNone(decode_res)
|
||||
self.assertIsNotNone(prefill_res)
|
||||
|
||||
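The new test file that follows exercises ACLGraphEntry and the ACLGraphWrapper capture/replay path. The sketch below is an editorial simplification with made-up names (TinyGraphWrapper), not the real implementation; it only mirrors the behaviour those tests pin down: the first call for a given batch descriptor runs the wrapped callable and records the result, and later calls replay without invoking the callable again.

class TinyGraphWrapper:
    """Simplified stand-in for a graph-capturing wrapper (illustrative only)."""

    def __init__(self, runnable):
        self.runnable = runnable
        self.entries = {}  # batch descriptor -> recorded output ("captured graph")

    def __call__(self, descriptor, *args):
        if descriptor not in self.entries:
            # First call: "capture" by running eagerly and remembering the output.
            self.entries[descriptor] = self.runnable(*args)
        # Later calls: "replay" the recorded result without calling the runnable again.
        return self.entries[descriptor]


wrapper = TinyGraphWrapper(lambda x: x * 2)
assert wrapper("num_tokens=30", 21) == 42  # capture path
assert wrapper("num_tokens=30", 21) == 42  # replay path
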
tests/ut/compilation/test_acl_graph.py (new file, 720 lines added)
@@ -0,0 +1,720 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import torch
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor, ForwardContext
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.compilation.acl_graph import ACLGraphEntry, ACLGraphWrapper
|
||||
|
||||
|
||||
class TestACLGraphEntry(TestBase):
|
||||
|
||||
def test_aclgraph_entry_initialization(self):
|
||||
"""Test ACLGraphEntry initialization with default values"""
|
||||
batch_descriptor = BatchDescriptor(
|
||||
num_tokens=30,
|
||||
uniform_decode=False,
|
||||
)
|
||||
|
||||
entry = ACLGraphEntry(batch_descriptor=batch_descriptor)
|
||||
|
||||
self.assertEqual(entry.batch_descriptor, batch_descriptor)
|
||||
self.assertIsNone(entry.aclgraph)
|
||||
self.assertIsNone(entry.output)
|
||||
self.assertIsNone(entry.input_addresses)
|
||||
|
||||
def test_aclgraph_entry_with_values(self):
|
||||
"""Test ACLGraphEntry initialization with specified values"""
|
||||
batch_descriptor = BatchDescriptor(
|
||||
num_tokens=30,
|
||||
uniform_decode=False,
|
||||
)
|
||||
|
||||
mock_graph = MagicMock()
|
||||
mock_output = MagicMock()
|
||||
input_addresses = [12345, 67890]
|
||||
|
||||
entry = ACLGraphEntry(batch_descriptor=batch_descriptor,
|
||||
aclgraph=mock_graph,
|
||||
output=mock_output,
|
||||
input_addresses=input_addresses)
|
||||
|
||||
self.assertEqual(entry.batch_descriptor, batch_descriptor)
|
||||
self.assertEqual(entry.aclgraph, mock_graph)
|
||||
self.assertEqual(entry.output, mock_output)
|
||||
self.assertEqual(entry.input_addresses, input_addresses)
|
||||
|
||||
|
||||
class TestACLGraphWrapper(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
super().setUp()
|
||||
|
||||
# Mock VllmConfig
|
||||
self.mock_vllm_config = MagicMock(spec=VllmConfig)
|
||||
self.mock_vllm_config.compilation_config = MagicMock()
|
||||
|
||||
# Mock runnable function
|
||||
self.mock_runnable = MagicMock(return_value="test_output")
|
||||
|
||||
# Mock graph pool
|
||||
self.mock_graph_pool = MagicMock()
|
||||
|
||||
# Mock CUDAGraphOptions
|
||||
self.mock_cudagraph_options = MagicMock(spec=CUDAGraphOptions)
|
||||
self.mock_cudagraph_options.debug_log_enable = False
|
||||
self.mock_cudagraph_options.gc_disable = False
|
||||
self.mock_cudagraph_options.weak_ref_output = False
|
||||
|
||||
# Mock BatchDescriptor
|
||||
self.mock_batch_descriptor = BatchDescriptor(
|
||||
num_tokens=30,
|
||||
uniform_decode=False,
|
||||
)
|
||||
|
||||
# Mock ForwardContext
|
||||
self.mock_forward_context = MagicMock(spec=ForwardContext)
|
||||
self.mock_forward_context.batch_descriptor = self.mock_batch_descriptor
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_initialization_with_default_options(self, mock_envs,
|
||||
mock_current_platform):
|
||||
"""Test ACLGraphWrapper initialization with default CUDAGraphOptions"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
|
||||
wrapper = ACLGraphWrapper(runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool)
|
||||
|
||||
self.assertEqual(wrapper.runnable, self.mock_runnable)
|
||||
self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
|
||||
self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
|
||||
self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
|
||||
self.assertFalse(wrapper.is_debugging_mode)
|
||||
self.assertIsInstance(wrapper.aclgraph_options, CUDAGraphOptions)
|
||||
self.assertEqual(wrapper.concrete_aclgraph_entries, {})
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_initialization_with_custom_options(self, mock_envs,
|
||||
mock_current_platform):
|
||||
"""Test ACLGraphWrapper initialization with custom CUDAGraphOptions"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
self.assertEqual(wrapper.runnable, self.mock_runnable)
|
||||
self.assertEqual(wrapper.vllm_config, self.mock_vllm_config)
|
||||
self.assertEqual(wrapper.graph_pool, self.mock_graph_pool)
|
||||
self.assertEqual(wrapper.runtime_mode, CUDAGraphMode.FULL)
|
||||
self.assertTrue(wrapper.is_debugging_mode)
|
||||
self.assertEqual(wrapper.aclgraph_options, self.mock_cudagraph_options)
|
||||
self.assertEqual(wrapper.concrete_aclgraph_entries, {})
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_initialization_assertion_error(self, mock_envs,
|
||||
mock_current_platform):
|
||||
"""Test ACLGraphWrapper initialization raises AssertionError for NONE mode"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
ACLGraphWrapper(runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.NONE,
|
||||
graph_pool=self.mock_graph_pool)
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_call_with_none_runtime_mode(self, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
"""Test __call__ method when runtime mode is NONE"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.NONE
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
result = wrapper("arg1", "arg2")
|
||||
|
||||
# Should call the runnable directly without graph capture
|
||||
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
||||
self.assertEqual(result, "test_output")
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
def test_call_with_mismatched_runtime_mode(self, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
"""Test __call__ method when runtime mode doesn't match wrapper mode"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE # Different from FULL
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
result = wrapper("arg1", "arg2")
|
||||
|
||||
# Should call the runnable directly without graph capture
|
||||
self.mock_runnable.assert_called_once_with("arg1", "arg2")
|
||||
self.assertEqual(result, "test_output")
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||
@patch(
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_capture_graph_first_time(
|
||||
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph for the first time"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
mock_npu_graph = MagicMock()
|
||||
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||
|
||||
# Mock torch.npu.graph context manager
|
||||
mock_graph_context = MagicMock()
|
||||
mock_torch.npu.graph.return_value = mock_graph_context
|
||||
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock weak_ref_tensors to return the same output
|
||||
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||
|
||||
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||
mock_torch.Tensor = torch.Tensor
|
||||
|
||||
# Set up the compilation counter mock
|
||||
mock_compilation_counter.num_cudagraph_captured = 0
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# Create a real torch tensor for the test, not a mock
|
||||
test_tensor = torch.tensor([1, 2, 3])
|
||||
|
||||
# Call the wrapper
|
||||
result = wrapper(test_tensor, "arg2")
|
||||
|
||||
# Verify graph capture happened
|
||||
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||
mock_torch.npu.NPUGraph.assert_called_once()
|
||||
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
||||
pool=self.mock_graph_pool)
|
||||
self.mock_runnable.assert_called_once_with(test_tensor, "arg2")
|
||||
|
||||
# Verify the entry was created and updated
|
||||
self.assertIn(self.mock_batch_descriptor,
|
||||
wrapper.concrete_aclgraph_entries)
|
||||
entry = wrapper.concrete_aclgraph_entries[self.mock_batch_descriptor]
|
||||
self.assertEqual(entry.aclgraph, mock_npu_graph)
|
||||
self.assertEqual(entry.output, "weak_ref_output")
|
||||
|
||||
# Verify compilation counter was incremented
|
||||
self.assertEqual(mock_compilation_counter.num_cudagraph_captured, 1)
|
||||
|
||||
# Should return the original output (not weak ref)
|
||||
self.assertEqual(result, "test_output")
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||
@patch(
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_replay_graph(self, mock_weak_ref_tensors,
|
||||
mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_validate_cudagraph_capturing_enabled,
|
||||
mock_torch):
|
||||
"""Test __call__ method replays graph when already captured"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
mock_npu_graph = MagicMock()
|
||||
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||
|
||||
# Mock torch.npu.graph context manager
|
||||
mock_graph_context = MagicMock()
|
||||
mock_torch.npu.graph.return_value = mock_graph_context
|
||||
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock weak_ref_tensors to return the same output
|
||||
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||
|
||||
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||
mock_torch.Tensor = torch.Tensor
|
||||
|
||||
# Set up the compilation counter mock
|
||||
mock_compilation_counter.num_cudagraph_captured = 0
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# Create a real torch tensor for the test, not a mock
|
||||
test_tensor = torch.tensor([1, 2, 3])
|
||||
|
||||
# First call to capture the graph
|
||||
first_result = wrapper(test_tensor, "arg2")
|
||||
|
||||
# Verify graph capture happened during first call
|
||||
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||
mock_torch.npu.NPUGraph.assert_called_once()
|
||||
mock_torch.npu.graph.assert_called_once()
|
||||
|
||||
# Reset mock to track second call
|
||||
self.mock_runnable.reset_mock()
|
||||
mock_npu_graph.reset_mock()
|
||||
|
||||
# Second call should replay the graph
|
||||
second_result = wrapper(test_tensor, "arg2")
|
||||
|
||||
# Verify runnable was called only during capture (not during replay)
|
||||
self.mock_runnable.assert_not_called()
|
||||
|
||||
# Verify graph replay happened
|
||||
mock_npu_graph.replay.assert_called_once()
|
||||
|
||||
# Both calls should return the weak ref output
|
||||
self.assertEqual(first_result, "test_output") # Original output
|
||||
self.assertEqual(second_result, "weak_ref_output") # Weak ref output
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||
@patch(
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_with_debug_mode_input_address_check(
|
||||
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||
mock_get_forward_context,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method with debug mode input address checking"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
mock_npu_graph = MagicMock()
|
||||
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||
|
||||
# Mock torch.npu.graph context manager
|
||||
mock_graph_context = MagicMock()
|
||||
mock_torch.npu.graph.return_value = mock_graph_context
|
||||
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock weak_ref_tensors
|
||||
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||
|
||||
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||
mock_torch.Tensor = torch.Tensor
|
||||
|
||||
# Create a mock tensor as the output of runnable
|
||||
mock_output_tensor = torch.tensor([4, 5, 6])
|
||||
self.mock_runnable.return_value = mock_output_tensor
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# First call to capture the graph
|
||||
tensor = torch.tensor([1, 2, 3]) # Create tensor once
|
||||
_ = wrapper(tensor, "arg2")
|
||||
|
||||
# Second call with same tensor addresses should work
|
||||
_ = wrapper(tensor, "arg2") # Use the same tensor object
|
||||
|
||||
# Should not raise AssertionError
|
||||
self.assertTrue(True)
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||
@patch(
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_with_debug_mode_input_address_mismatch(
|
||||
self, mock_weak_ref_tensors, mock_envs, mock_current_platform,
|
||||
mock_get_forward_context,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method with debug mode input address mismatch raises AssertionError"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "DEBUG" # Enable debug mode
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
mock_npu_graph = MagicMock()
|
||||
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||
|
||||
# Mock torch.npu.graph context manager
|
||||
mock_graph_context = MagicMock()
|
||||
mock_torch.npu.graph.return_value = mock_graph_context
|
||||
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock weak_ref_tensors
|
||||
mock_weak_ref_tensors.return_value = "weak_ref_output"
|
||||
|
||||
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||
mock_torch.Tensor = torch.Tensor
|
||||
|
||||
# Create a mock tensor as the output of runnable
|
||||
mock_output_tensor = torch.tensor([4, 5, 6])
|
||||
self.mock_runnable.return_value = mock_output_tensor
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# First call to capture the graph
|
||||
tensor1 = torch.tensor([1, 2, 3])
|
||||
_ = wrapper(tensor1, "arg2")
|
||||
|
||||
# Second call with different tensor addresses should raise AssertionError
|
||||
tensor2 = torch.tensor([4, 5,
|
||||
6]) # Different values, different address
|
||||
|
||||
with self.assertRaises(AssertionError) as context:
|
||||
wrapper(tensor2, "arg2")
|
||||
|
||||
self.assertIn("Input addresses for aclgraphs are different",
|
||||
str(context.exception))
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||
@patch(
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
@patch('vllm_ascend.compilation.acl_graph.patch')
|
||||
def test_call_capture_graph_with_gc_disable(
|
||||
self, mock_patch, mock_weak_ref_tensors, mock_compilation_counter,
|
||||
mock_envs, mock_current_platform, mock_get_forward_context,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph with gc_disable option enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable gc_disable option
|
||||
self.mock_cudagraph_options.gc_disable = True
|
||||
# weak_ref_output is not enabled by default
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
mock_npu_graph = MagicMock()
|
||||
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||
|
||||
# Mock torch.npu.graph context manager
|
||||
mock_graph_context = MagicMock()
|
||||
mock_torch.npu.graph.return_value = mock_graph_context
|
||||
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock patch context manager
|
||||
mock_exit_stack = MagicMock()
|
||||
mock_patch.return_value = mock_exit_stack
|
||||
mock_exit_stack.enter_context = Mock()
|
||||
|
||||
# Mock weak_ref_tensors to simulate the actual behavior:
|
||||
# 1. First call (inside the graph context) should return "inner_output"
|
||||
# 2. Second call (for entry.output) should return "weak_ref_output"
|
||||
mock_weak_ref_tensors.side_effect = ["inner_output", "weak_ref_output"]
|
||||
|
||||
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||
mock_torch.Tensor = torch.Tensor
|
||||
|
||||
# Set up the compilation counter mock
|
||||
mock_compilation_counter.num_cudagraph_captured = 0
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# Create a real torch tensor for the test, not a mock
|
||||
test_tensor = torch.tensor([1, 2, 3])
|
||||
|
||||
# Call the wrapper
|
||||
result = wrapper(test_tensor, "arg2")
|
||||
|
||||
# Verify patch was called to disable gc
|
||||
self.assertTrue(mock_patch.called)
|
||||
|
||||
# Verify graph capture happened
|
||||
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||
mock_torch.npu.NPUGraph.assert_called_once()
|
||||
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
||||
pool=self.mock_graph_pool)
|
||||
|
||||
# Should return the original output (not weak ref) since weak_ref_output is not enabled
|
||||
self.assertEqual(result, "test_output")
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.torch')
|
||||
@patch(
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
)
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.compilation_counter')
|
||||
@patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors')
|
||||
def test_call_capture_graph_with_weak_ref_output(
|
||||
self, mock_weak_ref_tensors, mock_compilation_counter, mock_envs,
|
||||
mock_current_platform, mock_get_forward_context,
|
||||
mock_validate_cudagraph_capturing_enabled, mock_torch):
|
||||
"""Test __call__ method captures graph with weak_ref_output option enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable weak_ref_output option
|
||||
self.mock_cudagraph_options.weak_ref_output = True
|
||||
|
||||
# Mock torch.npu.NPUGraph
|
||||
mock_npu_graph = MagicMock()
|
||||
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||
|
||||
# Mock torch.npu.graph context manager
|
||||
mock_graph_context = MagicMock()
|
||||
mock_torch.npu.graph.return_value = mock_graph_context
|
||||
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Mock weak_ref_tensors to simulate the actual behavior:
|
||||
# 1. First call (inside the graph context with weak_ref_output=True) should return "weak_ref_output"
|
||||
# 2. Second call (for entry.output) should return "weak_ref_output"
|
||||
mock_weak_ref_tensors.side_effect = [
|
||||
"weak_ref_output", "weak_ref_output"
|
||||
]
|
||||
|
||||
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||
mock_torch.Tensor = torch.Tensor
|
||||
|
||||
# Set up the compilation counter mock
|
||||
mock_compilation_counter.num_cudagraph_captured = 0
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# Create a real torch tensor for the test, not a mock
|
||||
test_tensor = torch.tensor([1, 2, 3])
|
||||
|
||||
# Call the wrapper
|
||||
result = wrapper(test_tensor, "arg2")
|
||||
|
||||
# Verify weak_ref_tensors was called twice (once for inner output, once for final output)
|
||||
self.assertEqual(mock_weak_ref_tensors.call_count, 2)
|
||||
|
||||
# Verify graph capture happened
|
||||
mock_validate_cudagraph_capturing_enabled.assert_called_once()
|
||||
mock_torch.npu.NPUGraph.assert_called_once()
|
||||
mock_torch.npu.graph.assert_called_once_with(mock_npu_graph,
|
||||
pool=self.mock_graph_pool)
|
||||
|
||||
# Should return the weak ref output when weak_ref_output option is enabled
|
||||
self.assertEqual(result, "weak_ref_output")
|
||||
|
||||
@patch('vllm_ascend.compilation.acl_graph.get_forward_context')
|
||||
@patch('vllm_ascend.compilation.acl_graph.current_platform')
|
||||
@patch('vllm_ascend.compilation.acl_graph.envs')
|
||||
@patch('vllm_ascend.compilation.acl_graph.logger')
|
||||
def test_call_capture_graph_with_debug_log(self, mock_logger, mock_envs,
|
||||
mock_current_platform,
|
||||
mock_get_forward_context):
|
||||
"""Test __call__ method captures graph with debug logging enabled"""
|
||||
mock_envs.VLLM_LOGGING_LEVEL = "INFO"
|
||||
mock_current_platform.get_global_graph_pool.return_value = self.mock_graph_pool
|
||||
mock_get_forward_context.return_value = self.mock_forward_context
|
||||
self.mock_forward_context.cudagraph_runtime_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Enable debug logging
|
||||
self.mock_cudagraph_options.debug_log_enable = True
|
||||
# weak_ref_output is not enabled by default
|
||||
|
||||
# Mock torch
|
||||
with patch('vllm_ascend.compilation.acl_graph.torch') as mock_torch:
|
||||
# Mock torch.npu.NPUGraph
|
||||
mock_npu_graph = MagicMock()
|
||||
mock_torch.npu.NPUGraph.return_value = mock_npu_graph
|
||||
|
||||
# Mock torch.npu.graph context manager
|
||||
mock_graph_context = MagicMock()
|
||||
mock_torch.npu.graph.return_value = mock_graph_context
|
||||
mock_graph_context.__enter__ = Mock(return_value=None)
|
||||
mock_graph_context.__exit__ = Mock(return_value=None)
|
||||
|
||||
# Ensure torch.Tensor can be correctly identified by isinstance
|
||||
mock_torch.Tensor = torch.Tensor
|
||||
|
||||
# Mock weak_ref_tensors
|
||||
with patch('vllm_ascend.compilation.acl_graph.weak_ref_tensors'
|
||||
) as mock_weak_ref_tensors:
|
||||
# Mock weak_ref_tensors to simulate the actual behavior:
|
||||
# 1. First call (inside the graph context) should return "inner_output"
|
||||
# 2. Second call (for entry.output) should return "weak_ref_output"
|
||||
mock_weak_ref_tensors.side_effect = [
|
||||
"inner_output", "weak_ref_output"
|
||||
]
|
||||
|
||||
# Mock validate_cudagraph_capturing_enabled
|
||||
with patch(
|
||||
'vllm_ascend.compilation.acl_graph.validate_cudagraph_capturing_enabled'
|
||||
):
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# Create a real torch tensor for the test, not a mock
|
||||
test_tensor = torch.tensor([1, 2, 3])
|
||||
|
||||
# Call the wrapper
|
||||
_ = wrapper(test_tensor, "arg2")
|
||||
|
||||
# Verify debug log was called
|
||||
mock_logger.debug.assert_called_once()
|
||||
|
||||
def test_getattr_access_runnable_attributes(self):
|
||||
"""Test __getattr__ method accesses runnable attributes"""
|
||||
mock_runnable = MagicMock()
|
||||
mock_runnable.test_attr = "test_value"
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# Should be able to access attributes of the runnable
|
||||
self.assertEqual(wrapper.test_attr, "test_value")
|
||||
|
||||
def test_getattr_attribute_not_exists(self):
|
||||
"""Test __getattr__ method raises AttributeError for non-existent attributes"""
|
||||
|
||||
# Create a simple object without any attributes
|
||||
class EmptyRunnable:
|
||||
pass
|
||||
|
||||
mock_runnable = EmptyRunnable()
|
||||
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
# Should raise AttributeError for non-existent attributes
|
||||
with self.assertRaises(AttributeError) as context:
|
||||
_ = wrapper.non_existent_attr
|
||||
|
||||
self.assertIn("Attribute non_existent_attr not exists",
|
||||
str(context.exception))
|
||||
|
||||
def test_unwrap_method(self):
|
||||
"""Test unwrap method returns the original runnable"""
|
||||
wrapper = ACLGraphWrapper(
|
||||
runnable=self.mock_runnable,
|
||||
vllm_config=self.mock_vllm_config,
|
||||
runtime_mode=CUDAGraphMode.FULL,
|
||||
graph_pool=self.mock_graph_pool,
|
||||
cudagraph_options=self.mock_cudagraph_options)
|
||||
|
||||
unwrapped = wrapper.unwrap()
|
||||
self.assertEqual(unwrapped, self.mock_runnable)
|
||||
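One pattern worth noting from the tests above, shown here as a standalone example independent of torch_npu or vllm_ascend: MagicMock already implements the context-manager protocol, but the tests assign __enter__/__exit__ explicitly so the mocked `with torch.npu.graph(...)` block yields None and the calls can be asserted directly.

from unittest.mock import MagicMock, Mock

mock_graph_context = MagicMock()
mock_graph_context.__enter__ = Mock(return_value=None)
mock_graph_context.__exit__ = Mock(return_value=None)

with mock_graph_context:
    pass  # behaves as a no-op context manager

mock_graph_context.__enter__.assert_called_once()
mock_graph_context.__exit__.assert_called_once()
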
@@ -27,7 +27,6 @@ class TestAscendSchedulerConfig(TestBase):
|
||||
max_model_len=8192,
|
||||
is_multimodal_model=False,
|
||||
send_delta_data=False,
|
||||
scheduler_delay_factor=0,
|
||||
)
|
||||
|
||||
def test_initialize_from_config_with_default(self):
|
||||
@@ -36,7 +35,6 @@ class TestAscendSchedulerConfig(TestBase):
|
||||
self.basic_scheduler_config, {})
|
||||
self.assertEqual(ascend_config.enable_chunked_prefill, False)
|
||||
self.assertEqual(ascend_config.policy, "fcfs")
|
||||
self.assertEqual(ascend_config.num_scheduler_steps, 1)
|
||||
self.assertEqual(ascend_config.scheduler_cls,
|
||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
|
||||
@@ -49,19 +47,21 @@ class TestAscendSchedulerConfig(TestBase):
|
||||
AscendSchedulerConfig(
|
||||
enable_chunked_prefill=False,
|
||||
policy="fcfs",
|
||||
num_scheduler_steps=1,
|
||||
scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler",
|
||||
max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
max_long_partial_prefills=1,
|
||||
long_prefill_token_threshold=512,
|
||||
),
|
||||
)
|
||||
self.assertEqual(ascend_config.enable_chunked_prefill, False)
|
||||
self.assertEqual(ascend_config.policy, "fcfs")
|
||||
self.assertEqual(ascend_config.num_scheduler_steps, 1)
|
||||
self.assertEqual(ascend_config.scheduler_cls,
|
||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
|
||||
self.assertEqual(ascend_config.encoder_cache_size, 2048)
|
||||
self.assertEqual(ascend_config.max_long_partial_prefills, 1)
|
||||
self.assertEqual(ascend_config.long_prefill_token_threshold, 512)
|
||||
|
||||
def test_not_implemented_policy(self):
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
@@ -78,28 +78,6 @@ class TestAscendSchedulerConfig(TestBase):
|
||||
str(context.exception),
|
||||
)
|
||||
|
||||
def test_not_implemented_multimodal(self):
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
AscendSchedulerConfig.initialize_from_config(
|
||||
SchedulerConfig(is_multimodal_model=True), {})
|
||||
self.assertIn("currently AscendScheduler only supports LLM models",
|
||||
str(context.exception))
|
||||
|
||||
def test_not_implemented_multi_step(self):
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
AscendSchedulerConfig.initialize_from_config(
|
||||
self.basic_scheduler_config,
|
||||
AscendSchedulerConfig(
|
||||
num_scheduler_steps=2,
|
||||
max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
),
|
||||
)
|
||||
self.assertIn(
|
||||
"currently AscendScheduler doesn't support multi-step",
|
||||
str(context.exception),
|
||||
)
|
||||
|
||||
def test_not_implemented_send_delta_data(self):
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
AscendSchedulerConfig.initialize_from_config(
|
||||
@@ -115,27 +93,17 @@ class TestAscendSchedulerConfig(TestBase):
|
||||
str(context.exception),
|
||||
)
|
||||
|
||||
def test_not_implemented_delay_factor(self):
|
||||
with self.assertRaises(NotImplementedError) as context:
|
||||
AscendSchedulerConfig.initialize_from_config(
|
||||
self.basic_scheduler_config,
|
||||
AscendSchedulerConfig(
|
||||
delay_factor=1,
|
||||
max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
),
|
||||
)
|
||||
self.assertIn(
|
||||
"currently AscendScheduler doesn't support scheduler_delay_factor",
|
||||
str(context.exception),
|
||||
)
|
||||
|
||||
def test_no_override(self):
|
||||
ascend_config = AscendSchedulerConfig.initialize_from_config(
|
||||
self.basic_scheduler_config, {})
|
||||
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
|
||||
self.assertEqual(ascend_config.encoder_cache_size, 8192)
|
||||
|
||||
def test_valid_config_with_multimodal(self):
|
||||
config = AscendSchedulerConfig.initialize_from_config(
|
||||
SchedulerConfig(is_multimodal_model=True), {})
|
||||
self.assertTrue(config.is_multimodal_model)
|
||||
|
||||
def test_valid_config_with_chunked_prefill(self):
|
||||
ascend_config = AscendSchedulerConfig.initialize_from_config(
|
||||
self.basic_scheduler_config,
|
||||
@@ -165,3 +133,16 @@ class TestAscendSchedulerConfig(TestBase):
|
||||
)
|
||||
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
|
||||
self.assertIn("max_model_len (4096)", str(context.exception))
|
||||
|
||||
def test_initialize_from_config_with_pd_transfer(self):
|
||||
ascend_config = AscendSchedulerConfig.initialize_from_config(
|
||||
self.basic_scheduler_config,
|
||||
AscendSchedulerConfig(
|
||||
enable_pd_transfer=True,
|
||||
decode_max_num_seqs=48,
|
||||
max_num_batched_tokens=4096,
|
||||
max_model_len=4096,
|
||||
),
|
||||
)
|
||||
self.assertEqual(ascend_config.enable_pd_transfer, True)
|
||||
self.assertEqual(ascend_config.decode_max_num_seqs, 48)
|
||||
|
||||
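A small standalone example (arbitrary targets, not part of the commit) of the decorator form used by create_scheduler() below, @patch(target, new): when an explicit replacement object such as MagicMock() is passed to @patch, no extra mock argument is injected into the test method, which is why only mock_compute_encoder_budget appears in its signature.

import os
import unittest
from unittest.mock import MagicMock, patch


class PatchWithExplicitNew(unittest.TestCase):

    @patch('os.getcwd', MagicMock(return_value='/tmp'))  # explicit new: nothing injected
    @patch('os.path.exists')                             # default new: mock injected
    def test_signature(self, mock_exists):
        mock_exists.return_value = True
        self.assertEqual(os.getcwd(), '/tmp')
        self.assertTrue(os.path.exists('/nowhere'))


if __name__ == "__main__":
    unittest.main()
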
@@ -6,25 +6,21 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.core.scheduler import AscendScheduler
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||
from vllm.v1.outputs import DraftTokenIds
|
||||
else:
|
||||
DraftTokenIds = None
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
MODEL = "Qwen3-0.6B"
|
||||
@@ -44,7 +40,7 @@ def create_requests(
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
block_size: int = 3,
|
||||
hash_fn=hash,
|
||||
hash_fn=sha256,
|
||||
):
|
||||
init_none_hash(hash_fn)
|
||||
prompt_logprobs = PROMPT_LOGPROBS
|
||||
@@ -54,25 +50,25 @@ def create_requests(
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
request = Request(request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_kwargs=None,
|
||||
multi_modal_placeholders=None,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
block_hasher=get_request_block_hasher(
|
||||
block_size, hash_fn))
|
||||
else:
|
||||
request = Request(request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
block_hasher=get_request_block_hasher(
|
||||
block_size, hash_fn))
|
||||
mm_features = []
|
||||
if mm_positions is not None:
|
||||
mm_position = mm_positions[i]
|
||||
for j, position in enumerate(mm_position):
|
||||
identifier = f"hash{i}_{j}"
|
||||
mm_feature = MultiModalFeatureSpec(
|
||||
data=MultiModalKwargsItem.dummy("dummy_m"),
|
||||
mm_position=position,
|
||||
identifier=identifier,
|
||||
modality="image")
|
||||
mm_features.append(mm_feature)
|
||||
request = Request(request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
mm_features=mm_features if mm_features else None,
|
||||
block_hasher=get_request_block_hasher(
|
||||
block_size, hash_fn))
|
||||
requests.append(request)
|
||||
return requests
|
||||
|
||||
@@ -85,25 +81,15 @@ def make_output(scheduler):
|
||||
}
|
||||
sampled_token_ids = [[1000]] * len(scheduler.running)
|
||||
logprobs = None
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
spec_token_ids=None,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
else:
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
return modelrunner_output
|
||||
|
||||
|
||||
@@ -113,7 +99,7 @@ class TestAscendScheduler(TestBase):
|
||||
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
|
||||
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
|
||||
def create_scheduler(self, mock_compute_encoder_budget):
|
||||
mock_compute_encoder_budget.return_value = [10, 20]
|
||||
mock_compute_encoder_budget.return_value = [100, 100]
|
||||
use_kv_connector = False
|
||||
block_size = 16
|
||||
|
||||
@@ -235,7 +221,7 @@ class TestAscendScheduler(TestBase):
|
||||
len(requests) - i - 1)
|
||||
|
||||
def test_schedule(self):
|
||||
'''Test scheduling.
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
'''
|
||||
scheduler = self.create_scheduler()
|
||||
@@ -260,6 +246,60 @@ class TestAscendScheduler(TestBase):
|
||||
for i, request in enumerate(requests):
|
||||
self.assertEqual(scheduler.running[i], request)
|
||||
|
||||
def test_schedule_multimodal_requests(self):
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.scheduler_config.chunked_prefill_enabled = False
|
||||
mm_positions = [[PlaceholderRange(offset=i, length=10)]
|
||||
for i in range(10)]
|
||||
requests = create_requests(
|
||||
num_requests=10,
|
||||
mm_positions=mm_positions,
|
||||
)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
|
||||
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(output.finished_req_ids), 0)
|
||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
||||
|
||||
# Verify all requests are scheduled.
|
||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||
self.assertEqual(num_tokens,
|
||||
len(requests[int(req_id)].prompt_token_ids))
|
||||
self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
|
||||
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
|
||||
assert len(encoder_input) == 1
|
||||
|
||||
# Verify requests moved from waiting to running
|
||||
self.assertEqual(len(scheduler.waiting), 0)
|
||||
self.assertEqual(len(scheduler.running), len(requests))
|
||||
for i, request in enumerate(requests):
|
||||
self.assertEqual(scheduler.running[i], request)
|
||||
|
||||
def test_concurrent_partial_prefills_schedule(self):
|
||||
'''Test concurrent partial prefills scheduling.
|
||||
total requests = 10, each request has 20 tokens.
With long_prefill_token_threshold = 1, the scheduler can
only schedule max_long_partial_prefills long requests at a time.
|
||||
'''
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.scheduler_config.chunked_prefill_enabled = False
|
||||
scheduler.scheduler_config.max_long_partial_prefills = 2
|
||||
scheduler.scheduler_config.long_prefill_token_threshold = 1
|
||||
requests = create_requests(num_requests=10, num_tokens=20)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Test initial scheduling
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs),
|
||||
scheduler.scheduler_config.max_long_partial_prefills)
|
||||
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(output.finished_req_ids), 0)
|
||||
|
||||
def test_schedule_enable_prefix_caching(self):
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
@@ -304,69 +344,34 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID], [
|
||||
10, 11
|
||||
]], # First request hits EOS, second continues
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID], [
|
||||
10, 11
|
||||
]], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
|
||||
], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
@@ -391,67 +396,35 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id:
|
||||
[10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
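# A minimal sketch of the draft-token acceptance rule these spec-decode tests rely
# on, assuming a draft token is accepted only while it matches the sampled token at
# the same position (the real accounting lives in the scheduler, not here).
def count_accepted_drafts(draft_tokens, sampled_tokens):
    accepted = 0
    for draft, sampled in zip(draft_tokens, sampled_tokens):
        if draft != sampled:
            break
        accepted += 1
    return accepted

# Request 0 above: drafts [10, 42] vs samples [10, 42, 12] -> both drafts accepted.
assert count_accepted_drafts([10, 42], [10, 42, 12]) == 2
# Request 1 above: draft [13] vs samples [13, 14] -> the single draft is accepted.
assert count_accepted_drafts([13], [13, 14]) == 1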
@@ -475,67 +448,35 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id:
|
||||
[10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped due to length
|
||||
@@ -556,52 +497,27 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.requests[requests[0].request_id] = requests[0]
|
||||
scheduler.running.append(requests[0])
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
@@ -652,23 +568,13 @@ class TestAscendScheduler(TestBase):
|
||||
512)
|
||||
|
||||
# Model output of the first request.
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output0,
|
||||
model_runner_output)
|
||||
@@ -678,23 +584,13 @@ class TestAscendScheduler(TestBase):
|
||||
# request is still running.
|
||||
scheduler.schedule()
|
||||
# Model output of the second request.
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output1,
|
||||
model_runner_output)
|
||||
@@ -746,29 +642,19 @@ class TestAscendScheduler(TestBase):
|
||||
req_id = requests[i].request_id
|
||||
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
|
||||
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
spec_token_ids=spec_tokens,
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
output, model_runner_output)
|
||||
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||
scheduler.update_draft_token_ids(draft_token_ids)
|
||||
scheduler.update_draft_token_ids(draft_token_ids)
|
||||
|
||||
for i in range(len(requests)):
|
||||
running_req = scheduler.running[i]
|
||||
@@ -804,23 +690,14 @@ class TestAscendScheduler(TestBase):
|
||||
else:
|
||||
self.assertNotIn(req_id,
|
||||
output.scheduled_spec_decode_tokens)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
output, model_runner_output)
|
||||
@@ -896,3 +773,34 @@ class TestAscendScheduler(TestBase):
|
||||
|
||||
# Confirm no memory leak.
|
||||
self.assert_scheduler_empty(scheduler)
|
||||
|
||||
def test_scheduler_with_pd_transfer(self):
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.phase = "prefill"
|
||||
requests = create_requests(num_requests=32)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# 1st iteration, move 16 requests from waiting to running for prefill
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
first_iter_prefilled_req_num = len(scheduler.running)
|
||||
self.assertEqual(len(scheduler_output.scheduled_new_reqs),
|
||||
scheduler.max_num_running_reqs)
|
||||
self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(scheduler_output.finished_req_ids), 0)
|
||||
|
||||
# 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
|
||||
# and move 16 requests from waiting to running for prefill
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
self.assertEqual(len(scheduler.finished_prefill_reqs),
|
||||
first_iter_prefilled_req_num)
|
||||
|
||||
# 3rd iteration, all requests prefilled, change scheduler phase to decode
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
self.assertEqual(scheduler.phase, "decode")
|
||||
|
||||
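# make_output() comes from the scheduler test utilities; a rough sketch of what such
# a helper might look like, mirroring only the ModelRunnerOutput fields used in the
# tests above (field values here are assumptions, not the real helper).
from vllm.v1.outputs import ModelRunnerOutput

def make_output_sketch(scheduler):
    req_ids = [req.request_id for req in scheduler.running]
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index={req_id: i for i, req_id in enumerate(req_ids)},
        sampled_token_ids=[[0] for _ in req_ids],
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])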
@@ -1,139 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
_gather_along_first_dim, _gather_along_last_dim,
|
||||
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
|
||||
all_to_all_hp2sp, all_to_all_sp2hp)
|
||||
|
||||
|
||||
class TestDistributedCommunication(PytestBase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def context(self, mocker: MockerFixture):
|
||||
mocker.patch("torch.npu.current_device", return_value="cpu")
|
||||
mocker.patch("torch.distributed.get_world_size", return_value=4)
|
||||
|
||||
mocker.patch("torch.distributed.get_rank", return_value=0)
|
||||
|
||||
@pytest.mark.parametrize("world_size, test_tensor, expected",
|
||||
[(1, torch.randn(8, 16), (8, 16)),
|
||||
(4, torch.randn(8, 16), (32, 16))])
|
||||
def test_gather_along_first_dim(self, test_tensor, expected, world_size,
|
||||
mocker: MockerFixture):
|
||||
"""test _gather_along_first_dim"""
|
||||
mocker.patch("torch.distributed.get_world_size",
|
||||
return_value=world_size)
|
||||
|
||||
result = _gather_along_first_dim(test_tensor, mocker.MagicMock())
|
||||
|
||||
assert result.shape == expected
|
||||
|
||||
@pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
|
||||
(torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
|
||||
])
|
||||
def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
|
||||
output_split_sizes,
|
||||
mocker: MockerFixture):
|
||||
"""test _gather_along_first_dim"""
|
||||
|
||||
result = _gather_along_first_dim(test_tensor, mocker.MagicMock(),
|
||||
output_split_sizes)
|
||||
|
||||
assert result.shape == expected
|
||||
|
||||
@pytest.mark.parametrize("world_size, test_tensor, expected",
|
||||
[(1, torch.randn(8, 16, 32), (8, 16, 32)),
|
||||
(4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
|
||||
def test_gather_along_last_dim(self, test_tensor, expected, world_size,
|
||||
mocker: MockerFixture):
|
||||
"""test _gather_along_last_dim"""
|
||||
mocker.patch("torch.distributed.get_world_size",
|
||||
return_value=world_size)
|
||||
|
||||
result = _gather_along_last_dim(test_tensor, mocker.MagicMock())
|
||||
|
||||
assert result.shape == expected
|
||||
|
||||
@pytest.mark.parametrize("input_shape,expected_shape", [
|
||||
((32, 16), (8, 16)),
|
||||
((40, 10), (10, 10)),
|
||||
])
|
||||
def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = _reduce_scatter_along_first_dim(input_tensor,
|
||||
mocker.MagicMock())
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("input_shape,expected_shape", [
|
||||
((8, 16, 32), (8, 16, 8)),
|
||||
])
|
||||
def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = _reduce_scatter_along_last_dim(input_tensor,
|
||||
mocker.MagicMock())
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("func,input_shape,expected_shape", [
|
||||
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
|
||||
(8, 16, 128)),
|
||||
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
|
||||
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
|
||||
(8, 16, 8)),
|
||||
("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
|
||||
])
|
||||
def test_wrapper_functions(self, func, input_shape, expected_shape,
|
||||
mocker: MockerFixture):
|
||||
"""test wrapper funcs"""
|
||||
mod = importlib.import_module(
|
||||
'vllm_ascend.distributed.tensor_parallel')
|
||||
globals = mod.__dict__
|
||||
test_func = globals[func]
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = test_func(input_tensor, mocker.MagicMock())
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,output_shape",
|
||||
[
|
||||
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP]
|
||||
])
|
||||
def test_all_to_all_sp2hp(self, input_shape, output_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = all_to_all_sp2hp(input_tensor, mocker.MagicMock())
|
||||
assert result.shape == output_shape
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_shape,output_shape",
|
||||
[
|
||||
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H]
|
||||
])
|
||||
def test_all_to_all_hp2sp(self, input_shape, output_shape,
|
||||
mocker: MockerFixture):
|
||||
input_tensor = torch.randn(*input_shape)
|
||||
result = all_to_all_hp2sp(input_tensor, mocker.MagicMock())
|
||||
assert result.shape == output_shape
|
||||
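# A single-process illustration of the shape contract checked by the two
# all-to-all tests above (no collectives involved; this only reproduces the
# [num_tokens/TP, H] -> [num_tokens, H/TP] transformation for the sp2hp direction).
import torch

def all_to_all_sp2hp_shapes(x: torch.Tensor, tp_size: int) -> torch.Tensor:
    chunks = torch.chunk(x, tp_size, dim=-1)   # tp_size pieces of [num_tokens/TP, H/TP]
    return torch.cat(chunks, dim=0)            # [num_tokens, H/TP]

assert all_to_all_sp2hp_shapes(torch.randn(8, 16), tp_size=4).shape == (32, 4)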
@@ -4,8 +4,8 @@ import pytest
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
_LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
|
||||
get_mc2_group, init_ascend_model_parallel)
|
||||
_LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group,
|
||||
get_mc2_group, get_otp_group, init_ascend_model_parallel)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,16 +29,20 @@ def mock_distributed():
|
||||
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
|
||||
mock_ascend_config = MagicMock()
|
||||
mock_ascend_config.lmhead_tensor_parallel_size = 2
|
||||
mock_ascend_config.oproj_tensor_parallel_size = 2
|
||||
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
|
||||
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
|
||||
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
|
||||
init_ascend_model_parallel(parallel_config)
|
||||
|
||||
mc2_group = get_mc2_group()
|
||||
assert mc2_group is not None
|
||||
lmheadtp_group = get_lmhead_tp_group()
|
||||
otp_group = get_otp_group()
|
||||
assert mc2_group is not None
|
||||
assert otp_group is not None
|
||||
assert lmheadtp_group is not None
|
||||
|
||||
destroy_ascend_model_parallel()
|
||||
assert _MC2 is None
|
||||
assert _LMTP is None
|
||||
assert _OTP is None
|
||||
|
||||
73  tests/ut/eplb/adaptor/test_abstract_adaptor.py  Normal file
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
|
||||
|
||||
|
||||
class DummyAdaptor(EplbAdaptor):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.args = kwargs
|
||||
|
||||
def get_rank_expert_workload(self):
|
||||
return "workload"
|
||||
|
||||
def get_init_expert_map(self, num_moe_layers):
|
||||
return {"layers": num_moe_layers}
|
||||
|
||||
def do_update_expert_map(self, layer_id, updated_expert_map):
|
||||
return {"layer_id": layer_id, "map": updated_expert_map}
|
||||
|
||||
def do_update_expert_weight(self, layer_id, local_expert_to_replace,
|
||||
buffer_tensor_id):
|
||||
return {
|
||||
"layer_id": layer_id,
|
||||
"replace": local_expert_to_replace,
|
||||
"buffer": buffer_tensor_id,
|
||||
}
|
||||
|
||||
|
||||
def test_base_class_methods_raise():
|
||||
adaptor = EplbAdaptor()
|
||||
with pytest.raises(NotImplementedError):
|
||||
adaptor.get_rank_expert_workload()
|
||||
with pytest.raises(NotImplementedError):
|
||||
adaptor.get_init_expert_map(1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
adaptor.do_update_expert_map(1, {})
|
||||
with pytest.raises(NotImplementedError):
|
||||
adaptor.do_update_expert_weight(1, "x", "y")
|
||||
|
||||
|
||||
def test_dummy_adaptor_init_and_args():
|
||||
adaptor = DummyAdaptor(test_arg=123)
|
||||
assert adaptor.args["test_arg"] == 123
|
||||
|
||||
|
||||
def test_get_rank_expert_workload():
|
||||
adaptor = DummyAdaptor()
|
||||
result = adaptor.get_rank_expert_workload()
|
||||
assert result == "workload"
|
||||
|
||||
|
||||
def test_get_init_expert_map():
|
||||
adaptor = DummyAdaptor()
|
||||
result = adaptor.get_init_expert_map(5)
|
||||
assert isinstance(result, dict)
|
||||
assert result["layers"] == 5
|
||||
|
||||
|
||||
def test_do_update_expert_map():
|
||||
adaptor = DummyAdaptor()
|
||||
updated = {"expert": 1}
|
||||
result = adaptor.do_update_expert_map(2, updated)
|
||||
assert result["layer_id"] == 2
|
||||
assert result["map"] == updated
|
||||
|
||||
|
||||
def test_do_update_expert_weight():
|
||||
adaptor = DummyAdaptor()
|
||||
result = adaptor.do_update_expert_weight(1, "expertA", "bufferX")
|
||||
assert result["layer_id"] == 1
|
||||
assert result["replace"] == "expertA"
|
||||
assert result["buffer"] == "bufferX"
|
||||
31  tests/ut/eplb/core/policy/test_policy_abstract.py  Normal file
@@ -0,0 +1,31 @@
# test_policy_abstract.py
from vllm_ascend.eplb.core.policy.policy_abstract import (DynamicConfig,
                                                          EplbPolicy)


class DummyPolicy(EplbPolicy):

    def rebalance_experts(self, current_expert_table, expert_workload):
        return 1, current_expert_table


def test_dynamic_config_attributes():
    config = DynamicConfig()
    assert config.placement_policy is None
    assert config.max_transferred_expert_per_layer == 100
    assert config.ep_worldsize == 64
    assert config.num_die_per_host == 8


def test_eplb_policy_init_and_method():
    config = DynamicConfig()
    policy = DummyPolicy(config)

    assert policy.config == config

    expert_table = [[0, 1, 2]]
    workload = [10]
    res, new_table = policy.rebalance_experts(expert_table, workload)

    assert res == 1
    assert new_table == expert_table
98  tests/ut/eplb/core/policy/test_policy_dynamic_ep.py  Normal file
@@ -0,0 +1,98 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb
|
||||
|
||||
|
||||
class TestDynamicEplb:
|
||||
|
||||
def test_add_redundant_basic(self):
|
||||
current_expert_table = np.array([[[0, 1], [1, 0]]])
|
||||
expert_workload = np.array([[[2, 3], [4, 1]]])
|
||||
num_original_expert = 2
|
||||
result = DynamicEplb.add_redundant(current_expert_table,
|
||||
expert_workload,
|
||||
num_original_expert)
|
||||
expected = np.array([[2 + 1, 3 + 4]])
|
||||
assert np.array_equal(result, expected)
|
||||
|
||||
def test_get_redundant_num(self):
|
||||
counts = np.array([2, 1, 3])
|
||||
assert DynamicEplb.get_redundant_num(3, counts) == 3
|
||||
|
||||
def test_calculate_max_heat_per_layer(self):
|
||||
workload_table = np.array([[[1, 2], [3, 4]], [[2, 2], [1, 1]]])
|
||||
max_heat = DynamicEplb.calculate_max_heat_per_layer(workload_table, 2)
|
||||
assert max_heat == [7, 4]
|
||||
|
||||
def test_constraint_expert_local_exchange(self):
|
||||
current = [[[0, 1], [2, 3]]]
|
||||
global_dep = [[[1, 0], [3, 2]]]
|
||||
new_dep = DynamicEplb.constraint_expert_local_exchange(
|
||||
current, global_dep)
|
||||
assert new_dep == [[[0, 1], [2, 3]]]
|
||||
|
||||
def test_compute_balanced_pack_redundancy_normal(self):
|
||||
origin_weights = [(0, 10), (1, 20)]
|
||||
result, boxes = DynamicEplb.compute_balanced_pack_redundancy(
|
||||
origin_weights, 2, 1)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
|
||||
def test_compute_balanced_pack_redundancy_card0(self):
|
||||
origin_weights = [(0, 10)]
|
||||
with pytest.raises(RuntimeError):
|
||||
DynamicEplb.compute_balanced_pack_redundancy(origin_weights, 0, 0)
|
||||
|
||||
def test_compute_balanced_pack_normal(self):
|
||||
origin_weights = np.array([(0, 10), (1, 20)], dtype=object)
|
||||
result, boxes = DynamicEplb.compute_balanced_pack(origin_weights, 2)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
|
||||
def test_compute_balanced_pack_card0(self):
|
||||
origin_weights = np.array([(0, 10)], dtype=object)
|
||||
with pytest.raises(RuntimeError):
|
||||
DynamicEplb.compute_balanced_pack(origin_weights, 0)
|
||||
|
||||
def test_original_compute_balanced_pack_redundancy(self):
|
||||
origin_weights = [(0, 5), (1, 10)]
|
||||
result, boxes = DynamicEplb.original_compute_balanced_pack_redundancy(
|
||||
origin_weights, 2, 1)
|
||||
assert isinstance(result, list) and len(result) == 2
|
||||
|
||||
def test_rebalance_experts_normal(self):
|
||||
expert_table = np.array([[[0, 1], [1, 0]]])
|
||||
workload = np.array([[[2, 3], [4, 1]]])
|
||||
policy = DynamicEplb(config=None)
|
||||
change, priority, new_dep = policy.rebalance_experts(
|
||||
expert_table, workload)
|
||||
assert change in [0, 1]
|
||||
assert isinstance(priority, np.ndarray)
|
||||
assert isinstance(new_dep, list)
|
||||
assert np.array(new_dep).shape == expert_table.shape
|
||||
|
||||
def test_rebalance_experts_exceptions(self):
|
||||
policy = DynamicEplb(config=None)
|
||||
|
||||
# case1: num_original_expert != expert_num
|
||||
expert_table = np.array([[[0, 1], [1, 0]]])
|
||||
workload = np.array([[[2, 3], [4, 1]]])
|
||||
with patch.object(DynamicEplb,
|
||||
'add_redundant',
|
||||
return_value=np.array([[1, 2, 3]])):
|
||||
with pytest.raises(ValueError):
|
||||
policy.rebalance_experts(expert_table, workload)
|
||||
|
||||
# case2: num_npus <= 0
|
||||
expert_table_zero = np.array([[]]) # 1 layer, 0 NPU, 0 experts
|
||||
workload_zero = np.array([[]])
|
||||
with pytest.raises(ValueError):
|
||||
policy.rebalance_experts(expert_table_zero, workload_zero)
|
||||
|
||||
# case3: num_npus < num_redundancy_expert
|
||||
expert_table_small = np.array([[[0, 0]]]) # 1 layer, 1 NPU, 2 experts
|
||||
workload_small = np.array([[[1, 1]]])
|
||||
with patch.object(DynamicEplb, 'get_redundant_num', return_value=2):
|
||||
with pytest.raises(ValueError):
|
||||
policy.rebalance_experts(expert_table_small, workload_small)
|
||||
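# A small NumPy sketch of the workload aggregation that test_add_redundant_basic
# asserts above: per layer, the workload of every physical copy is summed into its
# logical expert id (this mirrors the expected values, not the library internals).
import numpy as np

def aggregate_expert_workload(placement, workload, num_experts):
    num_layers = placement.shape[0]
    out = np.zeros((num_layers, num_experts), dtype=workload.dtype)
    for layer in range(num_layers):
        np.add.at(out[layer], placement[layer].ravel(), workload[layer].ravel())
    return out

placement = np.array([[[0, 1], [1, 0]]])
workload = np.array([[[2, 3], [4, 1]]])
assert np.array_equal(aggregate_expert_workload(placement, workload, 2),
                      np.array([[3, 7]]))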
99  tests/ut/eplb/core/policy/test_policy_dynamic_ep_v2.py  Normal file
@@ -0,0 +1,99 @@
|
||||
from typing import Dict, Set
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import (DynamicConfig,
|
||||
DynamicEplbV2)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config():
|
||||
return DynamicConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def policy(config):
|
||||
return DynamicEplbV2(config)
|
||||
|
||||
|
||||
def test_safe_operations(policy):
|
||||
# safe_divide
|
||||
assert policy.safe_divide(10, 2) == 5
|
||||
assert policy.safe_divide(1, 0) == 0
|
||||
|
||||
# safe_exact_divide
|
||||
assert policy.safe_exact_divide(10, 3) == 3
|
||||
assert policy.safe_exact_divide(1, 0) == 0
|
||||
|
||||
# safe_mod
|
||||
assert policy.safe_mod(10, 3) == 1
|
||||
assert policy.safe_mod(1, 0) == 0
|
||||
|
||||
|
||||
def test_add_redundant():
|
||||
workload = np.array([[[1, 2], [3, 4]]])
|
||||
placement = np.array([[[0, 1], [0, 1]]])
|
||||
result = DynamicEplbV2.add_redundant(placement, workload, 2)
|
||||
assert result.shape == (1, 2)
|
||||
assert np.all(result[0] == [4, 6]) # 0:1+3, 1:2+4
|
||||
|
||||
|
||||
def test_get_redundant_num():
|
||||
counts = np.array([1, 2, 1])
|
||||
assert DynamicEplbV2.get_redundant_num(3, counts) == 1 # sum(counts-1)
|
||||
|
||||
|
||||
def test_calculate_max_heat_per_layer():
|
||||
workload = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
|
||||
result = DynamicEplbV2.calculate_max_heat_per_layer(workload, 2)
|
||||
assert result == [7, 15]
|
||||
|
||||
|
||||
def test_calculate_initial_imbalance(policy):
|
||||
deployment = np.array([[[0, 1], [0, 1]]])
|
||||
workloads = np.array([[1, 1]])
|
||||
result = policy.calculate_initial_imbalance(deployment, workloads)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_compute_redundant_assignments(policy):
|
||||
base_experts = [(0, 10), (1, 5)]
|
||||
redundant, sorted_weights = policy.compute_redundant_assignments(
|
||||
base_experts, num_redundant_experts=2, num_experts=2)
|
||||
assert len(redundant) == 2
|
||||
assert len(sorted_weights) == 2
|
||||
|
||||
|
||||
def test_prepare_expert_list():
|
||||
base_experts = [(0, 10), (1, 5)]
|
||||
redundant_assignments = [[2], []]
|
||||
result = DynamicEplbV2.prepare_expert_list(base_experts,
|
||||
redundant_assignments, 1)
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_non_redundant_expert_information():
|
||||
origin_deployment = np.array([[0, 1]])
|
||||
updated_weights = [(0, 10), (1, 5)]
|
||||
rendun_pos: Dict[int, Set[int]] = {0: set()}
|
||||
assignments, weights, loads, counts = DynamicEplbV2.non_redundant_expert_information(
|
||||
origin_deployment, updated_weights, rendun_pos)
|
||||
assert assignments[0] == [0, 1]
|
||||
assert loads[0] == 15
|
||||
|
||||
|
||||
def test_recomputing_initial_weight(policy):
|
||||
layer_workloads = [10, 5]
|
||||
device_assignments = [[0, 1]]
|
||||
cur_layer_workload, num_all_experts = policy.recomputing_initial_weight(
|
||||
layer_workloads, device_assignments)
|
||||
assert cur_layer_workload[0] == 10
|
||||
assert num_all_experts[0] == 1
|
||||
|
||||
|
||||
def test_safe_divide_zero_edge_case(policy):
|
||||
assert policy.safe_divide(0, 1) == 0
|
||||
assert policy.safe_divide(0, 5) == 0
|
||||
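# Minimal sketches of the "safe" arithmetic helpers exercised by
# test_safe_operations above, assuming the only special case is that a zero
# divisor returns 0 instead of raising.
def safe_divide_sketch(a, b):
    return a / b if b != 0 else 0

def safe_exact_divide_sketch(a, b):
    return a // b if b != 0 else 0

def safe_mod_sketch(a, b):
    return a % b if b != 0 else 0

assert safe_divide_sketch(10, 2) == 5
assert safe_exact_divide_sketch(10, 3) == 3
assert safe_mod_sketch(1, 0) == 0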
23  tests/ut/eplb/core/policy/test_policy_factor.py  Normal file
@@ -0,0 +1,23 @@
import pytest

from vllm_ascend.eplb.core.policy.policy_abstract import DynamicConfig
from vllm_ascend.eplb.core.policy.policy_dynamic_ep import DynamicEplb
from vllm_ascend.eplb.core.policy.policy_dynamic_ep_v2 import DynamicEplbV2
from vllm_ascend.eplb.core.policy.policy_factory import PolicyFactory
from vllm_ascend.eplb.core.policy.policy_random import RandomLoadBalance


@pytest.fixture
def dummy_config():
    return DynamicConfig()


@pytest.mark.parametrize("policy_type, expected_class", [
    (0, RandomLoadBalance),
    (1, DynamicEplb),
    (2, DynamicEplbV2),
    (999, RandomLoadBalance),
])
def test_generate_policy(policy_type, expected_class, dummy_config):
    policy_instance = PolicyFactory.generate_policy(policy_type, dummy_config)
    assert isinstance(policy_instance, expected_class)
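# A sketch of the factory behaviour the parametrized test above asserts, assuming
# the mapping is id -> policy class with the random balancer as the fallback for
# unknown ids (class names here are placeholders, not the real registry).
class _RandomPolicy:
    def __init__(self, config):
        self.config = config

class _DynamicPolicy(_RandomPolicy):
    pass

def generate_policy_sketch(policy_type, config):
    registry = {0: _RandomPolicy, 1: _DynamicPolicy}
    return registry.get(policy_type, _RandomPolicy)(config)

assert isinstance(generate_policy_sketch(999, config=None), _RandomPolicy)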
122  tests/ut/eplb/core/test_eplb_device_transfer_loader.py  Normal file
@@ -0,0 +1,122 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm_ascend.eplb.core.eplb_device_transfer_loader as loader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adaptor():
|
||||
adaptor = MagicMock()
|
||||
|
||||
adaptor.expert_map_per_layer_cpu = {
|
||||
0: {
|
||||
10: torch.tensor(1),
|
||||
20: torch.tensor(0)
|
||||
}
|
||||
}
|
||||
|
||||
adaptor.expert_param_per_layer = {
|
||||
0: {
|
||||
0: [[torch.tensor([1.0])]],
|
||||
1: [[torch.tensor([2.0])]]
|
||||
}
|
||||
}
|
||||
|
||||
adaptor.buffer_tensor_list = [[[torch.tensor([3.0])],
|
||||
[torch.tensor([4.0])]]]
|
||||
return adaptor
|
||||
|
||||
|
||||
def test_generate_task_and_state_flow(mock_adaptor):
|
||||
loader_obj = loader.D2DExpertWeightLoader()
|
||||
loader_obj.set_adator(mock_adaptor)
|
||||
|
||||
with patch("torch.distributed.P2POp") as mock_p2p, \
|
||||
patch("torch.distributed.isend", return_value="isend_op"), \
|
||||
patch("torch.distributed.irecv", return_value="irecv_op"):
|
||||
|
||||
mock_p2p.side_effect = lambda op, tensor, rank: (op, tensor, rank)
|
||||
|
||||
loader_obj.state = loader.ExpertWeightUpdateState.READY
|
||||
loader_obj.generate_expert_d2d_transfer_task([(1, 10)], [(2, 20)],
|
||||
{20: torch.tensor(0)}, 0)
|
||||
assert loader_obj.comm_op_list is None
|
||||
loader_obj.state = loader.ExpertWeightUpdateState.WAITING
|
||||
|
||||
loader_obj.generate_expert_d2d_transfer_task([], [], {}, 0)
|
||||
assert loader_obj.comm_op_list is None
|
||||
|
||||
updated_map = {20: torch.tensor(0)}
|
||||
loader_obj.generate_expert_d2d_transfer_task([(1, 10)], [(2, 20)],
|
||||
updated_map, 0)
|
||||
assert loader_obj.state == loader.ExpertWeightUpdateState.READY
|
||||
assert loader_obj.comm_op_list
|
||||
assert loader_obj.recv_expert_list
|
||||
|
||||
|
||||
def test_asyn_transfer_and_update(mock_adaptor):
|
||||
loader_obj = loader.D2DExpertWeightLoader()
|
||||
loader_obj.set_adator(mock_adaptor)
|
||||
|
||||
loader_obj.comm_op_list = ["fake_op"]
|
||||
loader_obj.state = loader.ExpertWeightUpdateState.READY
|
||||
|
||||
reqs: list[MagicMock] = []
|
||||
|
||||
with patch("torch.distributed.batch_isend_irecv",
|
||||
return_value=[MagicMock(), MagicMock()]):
|
||||
loader_obj.asyn_expert_weight_transfer(reqs)
|
||||
|
||||
assert loader_obj.state == loader.ExpertWeightUpdateState.TRANSFERRING
|
||||
assert len(reqs) > 0
|
||||
|
||||
mock_req = MagicMock()
|
||||
mock_req.wait.return_value = None
|
||||
reqs = [mock_req]
|
||||
|
||||
loader_obj.recv_expert_list = [(0, 0)]
|
||||
loader_obj.updated_expert_map = {20: torch.tensor(0)}
|
||||
loader_obj.updated_log2phy_map = {"dummy": 1}
|
||||
loader_obj.layer_id = 0
|
||||
loader_obj.comm_op_list = ["op"]
|
||||
|
||||
loader_obj.update_expert_map_and_weight(reqs)
|
||||
|
||||
mock_adaptor.do_update_expert_map.assert_called_once()
|
||||
mock_adaptor.do_update_log2phy_map.assert_called_once()
|
||||
mock_adaptor.do_update_expert_weight.assert_called_once()
|
||||
|
||||
assert loader_obj.state == loader.ExpertWeightUpdateState.WAITING
|
||||
assert loader_obj.recv_expert_list == []
|
||||
|
||||
|
||||
def test_set_log2phy_map(mock_adaptor):
|
||||
loader_obj = loader.D2DExpertWeightLoader()
|
||||
loader_obj.set_adator(mock_adaptor)
|
||||
loader_obj.set_log2phy_map({"a": 1})
|
||||
assert loader_obj.updated_log2phy_map == {"a": 1}
|
||||
|
||||
|
||||
def test_invalid_state_asyn_update(mock_adaptor):
|
||||
loader_obj = loader.D2DExpertWeightLoader()
|
||||
loader_obj.set_adator(mock_adaptor)
|
||||
|
||||
loader_obj.state = loader.ExpertWeightUpdateState.WAITING
|
||||
reqs: list[Any] = []
|
||||
loader_obj.asyn_expert_weight_transfer(reqs)
|
||||
assert reqs == []
|
||||
|
||||
loader_obj.state = loader.ExpertWeightUpdateState.READY
|
||||
loader_obj.update_expert_map_and_weight([])
|
||||
|
||||
assert not mock_adaptor.do_update_expert_map.called
|
||||
|
||||
|
||||
def test_load_impl_not_implemented(mock_adaptor):
|
||||
loader_obj = loader.D2DExpertWeightLoader()
|
||||
loader_obj.set_adator(mock_adaptor)
|
||||
with pytest.raises(NotImplementedError):
|
||||
loader_obj.load_impl({}, {})
|
||||
79  tests/ut/eplb/core/test_eplb_utils.py  Normal file
@@ -0,0 +1,79 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.eplb.core import eplb_utils
|
||||
|
||||
|
||||
def test_determine_default_expert_map_single_world():
|
||||
count, expert_map = eplb_utils.determine_default_expert_map(
|
||||
global_expert_num=4,
|
||||
world_size=1,
|
||||
rank_id=0,
|
||||
global_redundant_expert_num=0)
|
||||
assert count == 4
|
||||
assert torch.equal(expert_map, torch.arange(4, dtype=torch.int32))
|
||||
|
||||
|
||||
def test_determine_default_expert_map_multiple_worlds_no_redundant():
|
||||
count, expert_map = eplb_utils.determine_default_expert_map(
|
||||
global_expert_num=8,
|
||||
world_size=2,
|
||||
rank_id=0,
|
||||
global_redundant_expert_num=0)
|
||||
|
||||
assert count == 4
|
||||
assert torch.all(expert_map[:4] >= 0)
|
||||
assert torch.all(expert_map[4:] == -1)
|
||||
|
||||
|
||||
def test_determine_default_expert_map_multiple_worlds_with_redundant():
|
||||
count, expert_map = eplb_utils.determine_default_expert_map(
|
||||
global_expert_num=5,
|
||||
world_size=2,
|
||||
rank_id=0,
|
||||
global_redundant_expert_num=1)
|
||||
|
||||
assert count == 3
|
||||
assert torch.all(expert_map[0:3] >= 0)
|
||||
|
||||
|
||||
def test_generate_log2phy_map_single_rank_holding():
|
||||
|
||||
expert_map = torch.tensor([[0, -1], [-1, 0]], dtype=torch.int32)
|
||||
log2phy_map = eplb_utils.generate_log2phy_map(expert_map)
|
||||
|
||||
assert torch.all(log2phy_map[:, 0] == log2phy_map[0, 0])
|
||||
assert torch.all(log2phy_map[:, 1] == log2phy_map[1, 1])
|
||||
|
||||
|
||||
def test_generate_log2phy_map_multiple_rank_holding(monkeypatch):
|
||||
|
||||
expert_map = torch.tensor([[0], [0]], dtype=torch.int32)
|
||||
|
||||
monkeypatch.setattr(random, "choice", lambda x: x[0])
|
||||
|
||||
log2phy_map = eplb_utils.generate_log2phy_map(expert_map)
|
||||
|
||||
assert log2phy_map.shape == (2, 1)
|
||||
assert (log2phy_map >= 0).all()
|
||||
|
||||
|
||||
def test_determine_default_log2phy_map_world_size_1():
|
||||
log2phy = eplb_utils.determine_default_log2phy_map(
|
||||
global_expert_num=3,
|
||||
world_size=1,
|
||||
rank_id=0,
|
||||
global_redundant_expert_num=0)
|
||||
assert log2phy.shape == (3, )
|
||||
assert (log2phy >= 0).all()
|
||||
|
||||
|
||||
def test_determine_default_log2phy_map_world_size_multiple():
|
||||
log2phy = eplb_utils.determine_default_log2phy_map(
|
||||
global_expert_num=6,
|
||||
world_size=2,
|
||||
rank_id=1,
|
||||
global_redundant_expert_num=1)
|
||||
assert log2phy.shape == (6, )
|
||||
assert (log2phy >= 0).all()
|
||||
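# A simplified sketch of the default expert-map layout the first two tests above
# check, ignoring redundant experts (the real helper also accounts for
# global_redundant_expert_num): each rank owns a contiguous block of experts and
# marks every expert it does not own with -1.
import torch

def default_expert_map_sketch(global_expert_num, world_size, rank_id):
    local_num = global_expert_num // world_size
    expert_map = torch.full((global_expert_num, ), -1, dtype=torch.int32)
    start = rank_id * local_num
    expert_map[start:start + local_num] = torch.arange(local_num, dtype=torch.int32)
    return local_num, expert_map

count, expert_map = default_expert_map_sketch(8, 2, 0)
assert count == 4
assert torch.all(expert_map[:4] >= 0)
assert torch.all(expert_map[4:] == -1)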
@@ -7,6 +7,7 @@ import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from typing import OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
@@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase):
|
||||
tracker = KVCacheTaskTracker()
|
||||
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
|
||||
self.assertIsInstance(tracker.finished_requests, set)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, deque)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, OrderedDict)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
@@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
def test_update_done_task_count(self):
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
|
||||
self.assertEqual(len(self.tracker.record_finished_requests), 0)
|
||||
|
||||
current_time = time.time()
|
||||
self.tracker.add_delayed_request("req_1", current_time)
|
||||
result = self.tracker.delayed_free_requests
|
||||
result_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0], ("req_1", current_time))
|
||||
self.assertEqual(result["req_1"], current_time)
|
||||
self.assertEqual(len(result_record), 0)
|
||||
|
||||
self.tracker.update_done_task_count("req_1")
|
||||
result_finished = self.tracker.finished_requests
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
result_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(result_finished, {"req_1"})
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
self.assertEqual(len(result_record), 0)
|
||||
|
||||
self.tracker.update_done_task_count("req_2")
|
||||
result_finished = self.tracker.finished_requests
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
result_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(result_finished, {"req_1", "req_2"})
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
self.assertEqual(len(result_record), 1)
|
||||
self.assertEqual(result_record, {"req_2"})
|
||||
|
||||
def test_update_add_delayed_request(self) -> None:
|
||||
self.tracker.update_done_task_count("req2")
|
||||
result_start_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(len(result_start_record), 1)
|
||||
self.tracker.add_delayed_request("req2", time.time())
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
result_end_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
self.assertEqual(len(result_end_record), 0)
|
||||
|
||||
def test_retrieve_expired_requests(self):
|
||||
current_time = time.time()
|
||||
@@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
})
|
||||
result_delay = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result_delay), 1)
|
||||
self.assertEqual(result_delay[0], ("req_2", current_time))
|
||||
self.assertIn("req_2", result_delay)
|
||||
|
||||
def test_duplicate_task_update(self):
|
||||
self.tracker.update_done_task_count("req1")
|
||||
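# A sketch of why delayed_free_requests moves from a deque of (req_id, time) pairs
# to an OrderedDict keyed by req_id, as the updated assertions above expect:
# insertion order is preserved, so expired entries can still be popped from the
# front, while membership checks and per-request lookups become O(1). The 60 s
# expiry window below is an assumption for illustration only.
import time
from collections import OrderedDict

delayed_free_requests = OrderedDict()
delayed_free_requests["req_1"] = time.time() - 120.0   # stale entry
delayed_free_requests["req_2"] = time.time()           # fresh entry

def retrieve_expired_sketch(delayed, max_delay_s=60.0):
    now = time.time()
    expired = []
    while delayed:
        req_id, added_at = next(iter(delayed.items()))
        if now - added_at < max_delay_s:
            break
        delayed.popitem(last=False)
        expired.append(req_id)
    return expired

assert retrieve_expired_sketch(delayed_free_requests) == ["req_1"]
assert "req_2" in delayed_free_requests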
@@ -961,6 +986,46 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
|
||||
for p in self.patches:
|
||||
p.stop() # type: ignore
|
||||
|
||||
def test_worker_use_ascend_direct(self):
|
||||
test_case = [True, False]
|
||||
|
||||
for use_ascend_direct in test_case:
|
||||
with self.subTest(use_ascend_direct=use_ascend_direct):
|
||||
config = MagicMock()
|
||||
config.kv_transfer_config = MagicMock()
|
||||
config.kv_transfer_config.get_from_extra_config.side_effect = (
|
||||
lambda k, d: {
|
||||
"prefill": {
|
||||
"tp_size": 2,
|
||||
"dp_size": 1
|
||||
},
|
||||
"decode": {
|
||||
"tp_size": 2,
|
||||
"dp_size": 1
|
||||
},
|
||||
"use_ascend_direct": use_ascend_direct,
|
||||
}.get(k, d))
|
||||
|
||||
config.parallel_config = MagicMock()
|
||||
config.parallel_config.tensor_parallel_size = 2
|
||||
config.parallel_config.data_parallel_rank_local = 0
|
||||
config.parallel_config.data_parallel_size_local = 1
|
||||
config.kv_transfer_config.kv_port = 8000
|
||||
config.kv_transfer_config.kv_role = 'worker'
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank",
|
||||
return_value=0):
|
||||
with patch(
|
||||
"vllm_ascend.distributed.mooncake_connector.get_tp_group",
|
||||
return_value=None):
|
||||
with patch(
|
||||
"vllm_ascend.distributed.mooncake_connector.get_ip",
|
||||
return_value="127.0.0.1"):
|
||||
worker = MooncakeConnectorWorker(
|
||||
config, self.engine_id)
|
||||
self.assertIsNotNone(worker)
|
||||
|
||||
def test_register_kv_caches_producer(self):
|
||||
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
|
||||
worker.register_kv_caches(self.kv_caches)
|
||||
|
||||
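# A sketch of the extra-config lookup pattern mocked in test_worker_use_ascend_direct
# above, assuming get_from_extra_config(key, default) returns the configured value
# when the key is present and the caller-supplied default otherwise.
extra_config = {
    "prefill": {"tp_size": 2, "dp_size": 1},
    "decode": {"tp_size": 2, "dp_size": 1},
    "use_ascend_direct": True,
}

def get_from_extra_config_sketch(key, default):
    return extra_config.get(key, default)

assert get_from_extra_config_sketch("use_ascend_direct", False) is True
assert get_from_extra_config_sketch("unknown_key", 7) == 7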
@@ -10,6 +10,7 @@ import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
@@ -19,8 +20,6 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
os.environ["VLLM_USE_V1"] = "1"
|
||||
|
||||
@@ -131,10 +130,10 @@ def create_request(
|
||||
"""Make dummy request for testing."""
|
||||
global _none_hash_initialized
|
||||
if not _none_hash_initialized:
|
||||
init_none_hash(hash)
|
||||
init_none_hash(sha256)
|
||||
_none_hash_initialized = True
|
||||
|
||||
block_hasher = get_request_block_hasher(block_size, hash)
|
||||
block_hasher = get_request_block_hasher(block_size, sha256)
|
||||
|
||||
kv_transfer_params: Optional[dict[str, Any]] = None
|
||||
|
||||
@@ -160,27 +159,14 @@ def create_request(
|
||||
else:
|
||||
prompt_token_ids = [i * request_id for i in range(num_tokens)]
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_kwargs=None,
|
||||
multi_modal_placeholders=None,
|
||||
multi_modal_hashes=None,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
else:
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
req = Request(
|
||||
request_id=f"id-{request_id}",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=[],
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=block_hasher,
|
||||
)
|
||||
req.kv_transfer_params = kv_transfer_params
|
||||
return req
|
||||
|
||||
@@ -208,26 +194,15 @@ def create_model_runner_output(
|
||||
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
|
||||
finished_recving=finished_recving)
|
||||
extra_args = {"kv_connector_output": kv_connector_output}
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
**extra_args,
|
||||
)
|
||||
|
||||
return model_runner_output
|
||||
|
||||
114  tests/ut/models/conftest.py  Normal file
@@ -0,0 +1,114 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig, EPLBConfig, ParallelConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_config():
|
||||
config = PretrainedConfig(
|
||||
hidden_size=128,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=2,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
rms_norm_eps=1e-6,
|
||||
rope_theta=10000.0,
|
||||
max_position_embeddings=2048,
|
||||
n_routed_experts=4,
|
||||
n_shared_experts=1,
|
||||
moe_intermediate_size=256,
|
||||
num_experts_per_tok=2,
|
||||
routed_scaling_factor=1.0,
|
||||
first_k_dense_replace=0,
|
||||
moe_layer_freq=1,
|
||||
kv_lora_rank=16,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
topk_method="noaux_tc",
|
||||
scoring_func="softmax",
|
||||
norm_topk_prob=True,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
vocab_size=10000,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_config(base_config):
|
||||
model_config = SimpleNamespace(
|
||||
hf_config=base_config,
|
||||
tensor_parallel_size=1,
|
||||
dtype=torch.float32,
|
||||
use_mla=True,
|
||||
quant_config=None,
|
||||
max_model_len=2048,
|
||||
)
|
||||
parallel_config = MagicMock(spec=ParallelConfig)
|
||||
eplb_config = MagicMock(spec=EPLBConfig)
|
||||
eplb_config.num_redundant_experts = 0
|
||||
parallel_config.eplb_config = eplb_config
|
||||
|
||||
cache_config = CacheConfig()
|
||||
vllm_config = Mock()
|
||||
vllm_config.model_config = model_config
|
||||
vllm_config.cache_config = cache_config
|
||||
vllm_config.quant_config = None
|
||||
vllm_config.parallel_config = parallel_config
|
||||
return vllm_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_distributed():
|
||||
tp_group = Mock(spec=GroupCoordinator)
|
||||
tp_group.rank_in_group = 0
|
||||
tp_group.world_size = 1
|
||||
tp_group.device_group = Mock()
|
||||
|
||||
dp_group = Mock(spec=GroupCoordinator)
|
||||
dp_group.rank_in_group = 0
|
||||
dp_group.world_size = 1
|
||||
|
||||
ep_group = Mock(spec=GroupCoordinator)
|
||||
ep_group.rank_in_group = 0
|
||||
ep_group.world_size = 1
|
||||
ep_group.device_group = Mock()
|
||||
ep_group.device_group.rank.return_value = 0
|
||||
ep_group.device_group.size.return_value = 1
|
||||
|
||||
pp_group = Mock(spec=GroupCoordinator)
|
||||
pp_group.rank_in_group = 0
|
||||
pp_group.world_size = 1
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
|
||||
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
|
||||
|
||||
with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
|
||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \
|
||||
patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
|
||||
patch("torch.npu.current_device", return_value=0):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_forward_context():
|
||||
forward_context = Mock(in_profile_run=False, with_prefill=False)
|
||||
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
|
||||
return_value=forward_context):
|
||||
yield
|
||||
@@ -13,10 +13,13 @@ from vllm_ascend.models.deepseek_mtp import (
|
||||
class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
|
||||
|
||||
@pytest.fixture
|
||||
def setup_mtp_layer(self, mocker: MockerFixture):
|
||||
def setup_mtp_layer(self, mocker: MockerFixture, vllm_config: VllmConfig,
|
||||
mock_distributed):
|
||||
config = PretrainedConfig(vocab_size=1000,
|
||||
hidden_size=768,
|
||||
rms_norm_eps=1e-5)
|
||||
mocker.patch("vllm_ascend.models.deepseek_mtp.get_current_vllm_config",
|
||||
return_value=vllm_config)
|
||||
mocker.patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
@@ -29,15 +32,15 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
|
||||
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__",
|
||||
return_value=None)
|
||||
mocker_deepseek_v2_decode_layer = mocker.patch(
|
||||
"vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
|
||||
"vllm.model_executor.models.deepseek_v2.DeepseekV2DecoderLayer.__init__",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
mocker.patch("vllm_ascend.utils.get_ascend_config",
|
||||
mocker.patch("vllm_ascend.models.deepseek_v2.get_ascend_config",
|
||||
return_value=mocker.Mock())
|
||||
|
||||
mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
|
||||
mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "0", None)
|
||||
mocker_deepseek_v2_decode_layer.assert_called_once()
|
||||
return mtp_layer
|
||||
|
||||
@@ -165,8 +168,6 @@ class TestCustomDeepSeekMTP(PytestBase):
|
||||
mocker.patch(
|
||||
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
|
||||
return_value=None)
|
||||
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
|
||||
@@ -12,169 +12,23 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
from vllm_ascend.models.deepseek_v2 import (
|
||||
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
|
||||
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
|
||||
CustomDeepseekV2RowParallelLinear,
|
||||
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
|
||||
from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
|
||||
CustomDeepseekV2RowParallelLinear)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_config():
|
||||
config = PretrainedConfig(
|
||||
hidden_size=128,
|
||||
num_attention_heads=8,
|
||||
num_hidden_layers=2,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
rms_norm_eps=1e-6,
|
||||
rope_theta=10000.0,
|
||||
max_position_embeddings=2048,
|
||||
n_routed_experts=4,
|
||||
n_shared_experts=1,
|
||||
moe_intermediate_size=256,
|
||||
num_experts_per_tok=2,
|
||||
routed_scaling_factor=1.0,
|
||||
first_k_dense_replace=0,
|
||||
moe_layer_freq=1,
|
||||
kv_lora_rank=16,
|
||||
qk_nope_head_dim=16,
|
||||
qk_rope_head_dim=16,
|
||||
v_head_dim=32,
|
||||
topk_method="noaux_tc",
|
||||
scoring_func="softmax",
|
||||
norm_topk_prob=True,
|
||||
n_group=1,
|
||||
topk_group=1,
|
||||
vocab_size=10000,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_config(base_config):
|
||||
model_config = SimpleNamespace(
|
||||
hf_config=base_config,
|
||||
tensor_parallel_size=1,
|
||||
dtype=torch.float32,
|
||||
use_mla=False,
|
||||
quant_config=None,
|
||||
max_model_len=2048,
|
||||
)
|
||||
|
||||
cache_config = CacheConfig()
|
||||
vllm_config = Mock()
|
||||
vllm_config.model_config = model_config
|
||||
vllm_config.cache_config = cache_config
|
||||
vllm_config.quant_config = None
|
||||
return vllm_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_distributed():
|
||||
tp_group = Mock(spec=GroupCoordinator)
|
||||
tp_group.rank_in_group = 0
|
||||
tp_group.world_size = 1
|
||||
tp_group.device_group = Mock()
|
||||
|
||||
dp_group = Mock(spec=GroupCoordinator)
|
||||
dp_group.rank_in_group = 0
|
||||
dp_group.world_size = 1
|
||||
|
||||
ep_group = Mock(spec=GroupCoordinator)
|
||||
ep_group.rank_in_group = 0
|
||||
ep_group.world_size = 1
|
||||
|
||||
pp_group = Mock(spec=GroupCoordinator)
|
||||
pp_group.rank_in_group = 0
|
||||
pp_group.world_size = 1
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
|
||||
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
|
||||
|
||||
with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_ep_group", return_value=ep_group), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_dp_group", return_value=dp_group), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
|
||||
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
|
||||
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
|
||||
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
|
||||
_PP=pp_group), \
|
||||
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
|
||||
patch("torch.npu.current_device", return_value=0):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_forward_context():
|
||||
forward_context = Mock(in_profile_run=False, with_prefill=False)
|
||||
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
|
||||
return_value=forward_context):
|
||||
yield
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_silu_and_mul():
|
||||
torch.set_default_device("cpu")
|
||||
|
||||
silu = CustomDeepseekV2SiluAndMul()
|
||||
assert silu.weight_scale is None
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
output = silu.forward_oot(x)
|
||||
assert output.shape == (2, 2)
|
||||
|
||||
weight_scale = Mock(return_value=torch.tensor(0.1))
|
||||
silu = CustomDeepseekV2SiluAndMul(weight_scale=weight_scale)
|
||||
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
|
||||
dynamic_scale = torch.randn(2, 1)
|
||||
with patch("torch_npu.npu_dequant_swiglu_quant",
|
||||
return_value=torch.randn(2, 4)):
|
||||
output = silu.forward_oot((quant_x, dynamic_scale))
|
||||
assert output.shape == (2, 4)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_merged_replicated_linear(mock_distributed):
|
||||
linear = CustomDeepseekV2MergedReplicatedLinear(input_size=128,
|
||||
output_sizes=[64, 64],
|
||||
bias=False,
|
||||
quant_config=None)
|
||||
assert linear.output_sizes == [64, 64]
|
||||
|
||||
param = Mock()
|
||||
param.data = torch.zeros(128, 128)
|
||||
param.output_dim = 1
|
||||
param.is_gguf_weight = False
|
||||
param.is_gguf_weight_type = False
|
||||
loaded_weight = torch.randn(128, 64)
|
||||
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [
|
||||
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
CustomDeepseekV2RowParallelLinear
|
||||
])
|
||||
@pytest.mark.parametrize("cls", [CustomDeepseekV2RowParallelLinear])
|
||||
def test_row_parallel_linear(cls, mock_distributed):
|
||||
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
|
||||
linear.quant_method = Mock()
|
||||
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
|
||||
|
||||
input_ = torch.randn(2, 4, 128)
|
||||
with patch("vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim",
|
||||
return_value=[torch.randn(2, 4, 64)]):
|
||||
@@ -187,52 +41,10 @@ def test_row_parallel_linear(cls, mock_distributed):
|
||||
assert output[0].shape == (2, 4, 64)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
|
||||
mlp = CustomDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
quant_config=None)
|
||||
assert isinstance(mlp.act_fn, CustomDeepseekV2SiluAndMul)
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
output = mlp(x)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
with patch("vllm_ascend.models.deepseek_v2.QuantizationConfig"
|
||||
) as mock_quant_config:
|
||||
mock_quant_config.name = "w8a8dynamic"
|
||||
with pytest.raises(NotImplementedError):
|
||||
CustomDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="silu",
|
||||
quant_config=mock_quant_config,
|
||||
force_replicate=False)
|
||||
with pytest.raises(ValueError):
|
||||
CustomDeepseekV2MLP(hidden_size=128,
|
||||
intermediate_size=256,
|
||||
hidden_act="relu",
|
||||
quant_config=None)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
|
||||
mock_forward_context):
|
||||
base_config.n_shared_experts = 1
|
||||
moe = CustomDeepseekV2MoE(config=base_config,
|
||||
quant_config=None,
|
||||
prefix="mlp")
|
||||
assert moe.top_k == 2
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
attn_metadata = Mock(num_prefills=1)
|
||||
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
|
||||
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
|
||||
output = moe(x, attn_metadata)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
|
||||
@patch("torch.ops.vllm.mla_forward")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
base_config):
|
||||
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
|
||||
mock_distributed, base_config):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
|
||||
attn = CustomDeepseekV2MLAAttention(config=base_config,
|
||||
@@ -253,8 +65,8 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
with patch.object(attn.mla_attn,
|
||||
"__call__",
|
||||
return_value=torch.randn(2, 4, 128)):
|
||||
with pytest.raises(AssertionError):
|
||||
attn(positions, x)
|
||||
attn(positions, x)
|
||||
mock_mla_forward.assert_called_once()
|
||||
|
||||
attn = CustomDeepseekV2MLAAttention(config=base_config,
|
||||
hidden_size=128,
|
||||
|
||||
@@ -286,6 +286,22 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
|
||||
"vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_world_size",
|
||||
return_value=2,
|
||||
)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.linear.divide",
|
||||
return_value=2,
|
||||
)
|
||||
|
||||
mock_group = mocker.MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.world_size = 2
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.linear_op.get_tp_group",
|
||||
return_value=mock_group,
|
||||
)
|
||||
mocker.patch(
|
||||
"vllm.distributed.parallel_state.get_tp_group",
|
||||
return_value=mock_group,
|
||||
)
|
||||
|
||||
vision_transformer = AscendQwen2_5_VisionTransformer(
|
||||
vision_config,
|
||||
@@ -341,6 +357,46 @@ class TestAscendQwen2_5_VisionTransformer(PytestBase):
|
||||
cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
|
||||
assert cos_new.shape == (1, 32, 1, 2)
|
||||
|
||||
def test_pad_qkv_bias(self, mocker: MockerFixture):
|
||||
attention = self.init_vision_transformer(mocker)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
res = attention.pad_qkv_bias(torch.rand((300)))
|
||||
assert res.shape[0] == 384
|
||||
|
||||
def test_pad_qkv_weight(self, mocker: MockerFixture):
|
||||
attention = self.init_vision_transformer(mocker)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
res = attention.pad_qkv_weight(torch.rand((300, 300)))
|
||||
assert res.shape == (384, 300)
|
||||
|
||||
def test_pad_proj_weight(self, mocker: MockerFixture):
|
||||
attention = self.init_vision_transformer(mocker)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
res = attention.pad_proj_weight(torch.rand((300, 300)))
|
||||
assert res.shape == (300, 384)
|
||||
|
||||
def test_pad_qkv_weight_scale_offset(self, mocker: MockerFixture):
|
||||
attention = self.init_vision_transformer(mocker)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
res = attention.pad_qkv_weight_scale_offset(torch.rand((300, 1)))
|
||||
assert res.shape == (384, 1)
|
||||
|
||||
def test_pad_qkv_deq_scale_quant_bias(self, mocker: MockerFixture):
|
||||
attention = self.init_vision_transformer(mocker)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
res = attention.pad_qkv_deq_scale_quant_bias(torch.rand((300)))
|
||||
assert res.shape[0] == 384
|
||||
|
||||
def test_forward(self, mocker: MockerFixture):
|
||||
vision_transformer = self.init_vision_transformer(mocker)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
|
||||
@@ -15,41 +15,11 @@
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||
|
||||
from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
|
||||
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
|
||||
|
||||
|
||||
class TestCustomQwen3MoeForCausalLM:
|
||||
|
||||
def test_class_inheritance(self):
|
||||
assert issubclass(CustomQwen3MoeForCausalLM, Qwen3MoeForCausalLM)
|
||||
|
||||
@pytest.mark.parametrize("key, expected", [
|
||||
("qkv_proj", ["q_proj", "k_proj", "v_proj"]),
|
||||
("gate_up_proj", ["gate_proj", "up_proj"]),
|
||||
("experts",
|
||||
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]),
|
||||
])
|
||||
def test_packed_modules_mapping(self, key, expected):
|
||||
assert CustomQwen3MoeForCausalLM.packed_modules_mapping[
|
||||
key] == expected
|
||||
|
||||
def test_packed_modules_mapping_structure(self):
|
||||
expected_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
"experts": [
|
||||
"experts.0.gate_proj", "experts.0.up_proj",
|
||||
"experts.0.down_proj"
|
||||
]
|
||||
}
|
||||
assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
|
||||
|
||||
|
||||
class DummyRMSNorm:
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
|
||||
@@ -38,7 +38,12 @@ def test_QuickGELU_forward(mock_gelu, dummy_tensor):
|
||||
|
||||
@pytest.mark.parametrize("is_310p_return", [True, False])
|
||||
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
|
||||
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_prefetch_mlp_down_proj",
|
||||
side_effect=lambda x: None)
|
||||
def test_SiluAndMul_forward(mock_maybe_prefetch_mlp_down_proj,
|
||||
mock_maybe_wait_prefetch_done, mock_swiglu,
|
||||
is_310p_return, dummy_tensor):
|
||||
|
||||
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
||||
layer = SiluAndMul()
|
||||
@@ -49,9 +54,15 @@ def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
|
||||
else:
|
||||
expected_arg = dummy_tensor
|
||||
|
||||
# assert mock_maybe_prefetch_mlp_down_proj.call_count == 1
|
||||
mock_maybe_prefetch_mlp_down_proj.assert_called_once()
|
||||
|
||||
# assert mock_swiglu.call_count == 1
|
||||
mock_swiglu.assert_called_once()
|
||||
|
||||
# assert mock_maybe_wait_prefetch_done.call_count == 1
|
||||
mock_maybe_wait_prefetch_done.assert_called_once()
|
||||
|
||||
actual_arg = mock_swiglu.call_args[0][0]
|
||||
assert torch.allclose(
|
||||
actual_arg,
|
||||
|
||||
98
tests/ut/ops/test_comm_utils.py
Normal file
98
tests/ut/ops/test_comm_utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.ops.moe.comm_utils import (
|
||||
_gather_along_first_dim, async_all_to_all,
|
||||
gather_from_sequence_parallel_region)
|
||||
|
||||
|
||||
class TestDistributedCommunication(PytestBase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def context(self, mocker: MockerFixture):
|
||||
mocker.patch("torch.npu.current_device", return_value="cpu")
|
||||
mocker.patch("torch.distributed.get_world_size", return_value=4)
|
||||
|
||||
mocker.patch("torch.distributed.get_rank", return_value=0)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_tensor, output_split_sizes, input_split_sizes",
|
||||
[(torch.randn(8, 16), [2, 2, 2, 2], [2, 2, 2, 2]),
|
||||
(torch.randn(16, 32), None, None)])
|
||||
def test_async_all_to_all(self, input_tensor, output_split_sizes,
|
||||
input_split_sizes, mocker: MockerFixture):
|
||||
"""Test async_all_to_all"""
|
||||
mock_group = mocker.MagicMock()
|
||||
mocker.patch("torch.distributed.all_to_all_single",
|
||||
return_value=mocker.MagicMock())
|
||||
|
||||
_, a2a_out, handle = async_all_to_all(input_tensor, output_split_sizes,
|
||||
input_split_sizes, mock_group)
|
||||
|
||||
# Check if the output tensor is created properly
|
||||
if output_split_sizes is None:
|
||||
assert a2a_out.shape == input_tensor.shape
|
||||
else:
|
||||
total_output_size = sum(output_split_sizes)
|
||||
expected_shape = [total_output_size] + list(
|
||||
input_tensor.size())[1:]
|
||||
assert a2a_out.shape == torch.Size(expected_shape)
|
||||
|
||||
# Ensure handle is returned from async operation
|
||||
assert handle is not None
|
||||
assert isinstance(handle, mocker.MagicMock)
|
||||
|
||||
@pytest.mark.parametrize("world_size, test_tensor, expected",
|
||||
[(1, torch.randn(8, 16), (8, 16)),
|
||||
(4, torch.randn(8, 16), (32, 16))])
|
||||
def test_gather_along_first_dim(self, test_tensor, expected, world_size,
|
||||
mocker: MockerFixture):
|
||||
"""Test _gather_along_first_dim"""
|
||||
mocker.patch("torch.distributed.get_world_size",
|
||||
return_value=world_size)
|
||||
|
||||
result = _gather_along_first_dim(test_tensor, mocker.MagicMock())
|
||||
|
||||
assert result.shape == expected
|
||||
|
||||
@pytest.mark.parametrize("input_tensor, output_split_sizes",
|
||||
[(torch.randn(8, 16), None),
|
||||
(torch.randn(8, 16), [2, 2, 2, 2])])
|
||||
def test_gather_from_sequence_parallel_region(self, input_tensor,
|
||||
output_split_sizes,
|
||||
mocker: MockerFixture):
|
||||
"""Test gather_from_sequence_parallel_region"""
|
||||
mock_group = mocker.MagicMock()
|
||||
|
||||
result = gather_from_sequence_parallel_region(input_tensor, mock_group,
|
||||
output_split_sizes)
|
||||
|
||||
# If output_split_sizes is not provided, result should have expanded first dimension by world size
|
||||
if output_split_sizes is None:
|
||||
expected_shape = [input_tensor.shape[0] * 4] + list(
|
||||
input_tensor.shape[1:])
|
||||
assert result.shape == torch.Size(expected_shape)
|
||||
else:
|
||||
# If output_split_sizes is provided, result shape is dictated by sum of output_split_sizes
|
||||
expected_shape = [sum(output_split_sizes)] + list(
|
||||
input_tensor.shape[1:])
|
||||
assert result.shape == torch.Size(expected_shape)
|
||||
@@ -17,53 +17,40 @@ from unittest.mock import patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||
|
||||
|
||||
class TestFusedExpertsMoGE(TestBase):
|
||||
class TestLoadWeight(TestBase):
|
||||
|
||||
def test_fused_experts_moge(self):
|
||||
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
|
||||
patch('torch_npu.npu_swiglu') as mock_swiglu, \
|
||||
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
|
||||
def test_load_w13_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
|
||||
torch.randn(x[0].shape[0], weight[0].shape[1])
|
||||
]
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w1", loaded_weight, 0)
|
||||
|
||||
mock_swiglu.side_effect = lambda x: x
|
||||
expert_data = torch.randn(128, 8)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
hidden_states = torch.randn(4, 128)
|
||||
w1 = torch.randn(4, 256, 128)
|
||||
w2 = torch.randn(4, 128, 128)
|
||||
topk_weights = torch.rand(4, 1)
|
||||
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
|
||||
top_k = 1
|
||||
global_num_experts = 4
|
||||
expert_data = torch.randn(8, 128)
|
||||
loaded_weight = torch.randn(128, 4)
|
||||
moe._load_w13(expert_data, 1, "w3", loaded_weight, 0)
|
||||
|
||||
moe_parallel_config = type(
|
||||
'MockConfig', (), {
|
||||
'ep_size': 1,
|
||||
'tp_size': 1,
|
||||
'dp_size': 1,
|
||||
'tp_rank': 0,
|
||||
'dp_rank': 0,
|
||||
'ep_rank': 0,
|
||||
'use_ep': True
|
||||
})()
|
||||
def test_load_w2_transpose(self):
|
||||
with patch.object(AscendFusedMoE, "__init__",
|
||||
lambda self, *args, **kwargs: None):
|
||||
moe = AscendFusedMoE(num_experts=4, top_k=2, hidden_size=8)
|
||||
expert_data = torch.randn(128, 4)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
output = fused_experts_moge(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
self.assertEqual(output.shape, (4, 128))
|
||||
expert_data = torch.randn(4, 128)
|
||||
loaded_weight = torch.randn(128, 8)
|
||||
moe._load_w2(expert_data, 1, loaded_weight, 0)
|
||||
|
||||
289
tests/ut/ops/test_fused_moe_prepare_and_finalize.py
Normal file
289
tests/ut/ops/test_fused_moe_prepare_and_finalize.py
Normal file
@@ -0,0 +1,289 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
|
||||
FusedMoEPrepareAndFinalizeWithAll2All,
|
||||
FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2,
|
||||
FusedMoEPrepareAndFinalizeWithNaiveMulticast)
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class TestFusedMoEPrepareAndFinalize(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
# Mock FusedMoEConfig
|
||||
self.moe_config = MagicMock(spec=FusedMoEConfig)
|
||||
self.moe_config.tp_group = MagicMock()
|
||||
self.moe_config.tp_group.device_group = MagicMock()
|
||||
self.moe_config.dp_size = 1
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
self.moe_config.dp_group = MagicMock()
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_mc2_prepare_finalize(self, mock_get_forward_context, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
mock_context = MagicMock()
|
||||
mock_context.mc2_mask = torch.tensor([1, 0, 1])
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, mask = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Check padding and split
|
||||
self.assertEqual(h_out.shape[0], 4)
|
||||
self.assertEqual(r_out.shape[0], 4)
|
||||
self.assertEqual(mask.tolist(), [1, 0, 1])
|
||||
|
||||
# Finalize
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_mc2_tp_split_allgather(self, mock_all_gather,
|
||||
mock_get_forward_context, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
mock_context = MagicMock()
|
||||
mock_context.mc2_mask = torch.tensor([1, 0, 1, 0])
|
||||
mock_context.padded_num_tokens = 4
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithMC2(self.moe_config)
|
||||
hidden_states = torch.randn(4, 8)
|
||||
router_logits = torch.randn(4, 2)
|
||||
|
||||
h_out, r_out, mask = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# With TP=2, should split into 2 parts
|
||||
self.assertEqual(h_out.shape[0], 2)
|
||||
|
||||
# Mock all_gather behavior
|
||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
||||
tensor_list[0] = tensor
|
||||
tensor_list[1] = tensor.clone()
|
||||
|
||||
mock_all_gather.side_effect = mock_all_gather_func
|
||||
|
||||
layer.split_hidden_states = [
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back to original size
|
||||
self.assertEqual(final_result.shape[0], 4)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
def test_all2all_prepare_finalize(self, mock_tp_rank, mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states, router_logits)
|
||||
|
||||
# Pad to tp_size=1, so no change
|
||||
self.assertEqual(h_out.shape[0], 3)
|
||||
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_world_size",
|
||||
return_value=2)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_tensor_model_parallel_rank",
|
||||
return_value=0)
|
||||
@patch("torch.distributed.all_gather")
|
||||
def test_all2all_tp_split_allgather(self, mock_all_gather, mock_tp_rank,
|
||||
mock_tp_size):
|
||||
layer = FusedMoEPrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
hidden_states = torch.randn(2, 8)
|
||||
router_logits = torch.randn(2, 2)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp=False,
|
||||
replace_allreduce=False)
|
||||
|
||||
# Split due to TP=2
|
||||
self.assertEqual(h_out.shape[0], 1)
|
||||
|
||||
# Mock all_gather
|
||||
def mock_all_gather_func(tensor_list, tensor, group=None):
|
||||
tensor_list[0] = tensor
|
||||
tensor_list[1] = tensor.clone()
|
||||
|
||||
mock_all_gather.side_effect = mock_all_gather_func
|
||||
|
||||
layer.split_hidden_states = [
|
||||
torch.zeros_like(h_out),
|
||||
torch.zeros_like(h_out)
|
||||
]
|
||||
final_result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should concat back
|
||||
self.assertEqual(final_result.shape[0], 2)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_allgather_prepare_finalize(self, mock_get_forward_context,
|
||||
mock_tp_all_reduce, mock_get_dp_group):
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.max_tokens_across_dp = 6
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Create a proper mock for DP group with working all_gather
|
||||
mock_dp_group = MagicMock()
|
||||
|
||||
def mock_all_gather_func(tensor, dim):
|
||||
# Simulate DP=2: repeat the tensor along the specified dimension
|
||||
return torch.cat([tensor, tensor], dim=dim)
|
||||
|
||||
mock_dp_group.all_gather = mock_all_gather_func
|
||||
mock_get_dp_group.return_value = mock_dp_group
|
||||
|
||||
self.moe_config.dp_size = 2
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
self.moe_config.dp_group = mock_dp_group
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
# Mock the gate function for rm_router_logits=False case
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (router_logits.repeat(2, 1), None)
|
||||
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
rm_router_logits=False,
|
||||
gate=mock_gate)
|
||||
|
||||
# After all-gather with DP=2, should double the batch size
|
||||
self.assertEqual(h_out.shape[0], 12)
|
||||
self.assertEqual(r_out.shape[0], 12)
|
||||
|
||||
# Finalize with reduce_scatter
|
||||
def mock_reduce_scatter_func(tensor, dim):
|
||||
# Simulate reduce_scatter: take first half
|
||||
return tensor[:3]
|
||||
|
||||
mock_dp_group.reduce_scatter = mock_reduce_scatter_func
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
self.assertEqual(result.shape[0], 3)
|
||||
|
||||
# Test with TP all-reduce
|
||||
mock_tp_all_reduce.return_value = result
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape[0], 3)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_dp_group")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.tensor_model_parallel_all_reduce"
|
||||
)
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context"
|
||||
)
|
||||
def test_naive_multicast_prepare_finalize(self, mock_get_forward_context,
|
||||
mock_tp_all_reduce,
|
||||
mock_get_dp_group):
|
||||
# Mock forward context with DP metadata
|
||||
mock_context = MagicMock()
|
||||
if vllm_version_is("0.10.2"):
|
||||
mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor(
|
||||
[2, 5, 7])
|
||||
else:
|
||||
mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor(
|
||||
[2, 5, 7])
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Setup DP group mock
|
||||
mock_dp_group = MagicMock()
|
||||
mock_dp_group.broadcast = MagicMock()
|
||||
mock_dp_group.all_reduce = MagicMock()
|
||||
mock_get_dp_group.return_value = mock_dp_group
|
||||
|
||||
# Mock all_reduce to just return input (simulate sum)
|
||||
def mock_all_reduce(tensor):
|
||||
return tensor * 2
|
||||
|
||||
mock_dp_group.all_reduce.side_effect = mock_all_reduce
|
||||
|
||||
# Setup config
|
||||
self.moe_config.dp_size = 3
|
||||
self.moe_config.dp_rank = 1
|
||||
self.moe_config.tp_size = 1
|
||||
self.moe_config.ep_size = 1
|
||||
|
||||
layer = FusedMoEPrepareAndFinalizeWithNaiveMulticast(self.moe_config)
|
||||
|
||||
# Local inputs
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
|
||||
# Mock gate for router logits recomputation
|
||||
mock_gate = MagicMock()
|
||||
mock_gate.return_value = (torch.randn(7, 2), None)
|
||||
|
||||
# Run prepare
|
||||
h_out, r_out, _ = layer.prepare(hidden_states,
|
||||
router_logits,
|
||||
rm_router_logits=False,
|
||||
gate=mock_gate)
|
||||
|
||||
# Should be global tensor: [7, 8] and [7, 2]
|
||||
self.assertEqual(h_out.shape, (7, 8))
|
||||
self.assertEqual(r_out.shape, (7, 2))
|
||||
|
||||
# Run finalize
|
||||
result = layer.finalize(h_out, reduce_results=False)
|
||||
|
||||
# Should slice back to local: [3, 8]
|
||||
self.assertEqual(result.shape, (3, 8))
|
||||
|
||||
# Test with reduce_results=True and TP/EP > 1
|
||||
mock_tp_all_reduce.return_value = result
|
||||
result_with_tp = layer.finalize(h_out, reduce_results=True)
|
||||
self.assertEqual(result_with_tp.shape, (3, 8))
|
||||
@@ -22,15 +22,13 @@ import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import (FusedMoEState,
|
||||
_get_fused_moe_state)
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is
|
||||
|
||||
adapt_patch(True)
|
||||
|
||||
@@ -58,122 +56,94 @@ def mock_npu_format_cast(weight_data, format):
|
||||
return weight_data
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_vllm_config_mock(mocker: MockerFixture):
|
||||
mock_hf_config = MagicMock()
|
||||
mock_hf_config.model_type = "llama"
|
||||
|
||||
mock_model_config = MagicMock()
|
||||
mock_model_config.hf_config = mock_hf_config
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config = mock_model_config
|
||||
mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
|
||||
mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
|
||||
mock_vllm_config.model_config.max_model_len = 2048
|
||||
|
||||
mocker.patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
|
||||
return_value=mock_vllm_config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dist_env(mocker: MockerFixture):
|
||||
mock_setup_token_dispatchers = MagicMock()
|
||||
mock_token_dispatcher_with_allgather = MagicMock()
|
||||
mock_token_dispatcher_with_all2allv = MagicMock()
|
||||
mock_token_dispatcher_with_mc2 = MagicMock()
|
||||
mock_moe_comm_method = MagicMock()
|
||||
|
||||
mock_dispatch_result_allgather = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([8, 16], dtype=torch.int64),
|
||||
"group_list_type": 0,
|
||||
}
|
||||
mock_combine_result_allgather = torch.randn(16, 2)
|
||||
def mock_prepare(hidden_states, router_logits, **kwargs):
|
||||
return hidden_states, router_logits
|
||||
|
||||
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
|
||||
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
|
||||
mock_moe_comm_method.prepare.side_effect = mock_prepare
|
||||
|
||||
mock_dispatch_result_all2allv = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
}
|
||||
mock_combine_result_all2allv = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
|
||||
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
|
||||
mock_fused_experts_result = torch.randn(16, 2)
|
||||
mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result
|
||||
|
||||
mock_dispatch_result_mc2 = {
|
||||
"hidden_states": torch.randn(16, 2),
|
||||
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
|
||||
"group_list_type": 1,
|
||||
"dynamic_scale": None,
|
||||
"assist_info_for_combine": torch.randn(16, 2),
|
||||
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
|
||||
}
|
||||
mock_combine_result_mc2 = torch.randn(16, 2)
|
||||
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
|
||||
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
|
||||
def mock_finalize(hidden_states, **kwargs):
|
||||
return hidden_states
|
||||
|
||||
captured_dispatchers = {}
|
||||
mock_moe_comm_method.finalize.side_effect = mock_finalize
|
||||
|
||||
def capture_register(dispatcher_instance):
|
||||
key = dispatcher_instance.__class__.__name__
|
||||
captured_dispatchers[key] = dispatcher_instance
|
||||
if key == 'TokenDispatcherWithAllGather':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
|
||||
elif key == 'TokenDispatcherWithAll2AllV':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
|
||||
elif key == 'TokenDispatcherWithMC2':
|
||||
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
|
||||
|
||||
mock_register_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
|
||||
side_effect=capture_register)
|
||||
|
||||
mock_get_token_dispatcher_patcher = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
|
||||
side_effect=lambda name: captured_dispatchers.get(name))
|
||||
|
||||
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
|
||||
|
||||
mock_forward_context_obj = MagicMock(
|
||||
fused_moe_state=FusedMoEState.AllGather,
|
||||
token_dispatcher=default_mock_token_dispatcher,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
|
||||
mc2_mask=torch.zeros(16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False)
|
||||
if vllm_version_is("0.10.2"):
|
||||
dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
|
||||
else:
|
||||
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
|
||||
mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method,
|
||||
moe_comm_type=MoECommType.MC2,
|
||||
max_tokens_across_dp=10,
|
||||
dp_metadata=dp_metadata,
|
||||
mc2_mask=torch.zeros(
|
||||
16, dtype=torch.bool),
|
||||
padded_num_tokens=16,
|
||||
with_quant=False)
|
||||
|
||||
with patch('torch.distributed.get_rank', return_value=0), \
|
||||
patch('torch.distributed.get_world_size', return_value=4), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('torch.distributed.all_gather'), \
|
||||
patch('torch.distributed.all_to_all_single'), \
|
||||
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
|
||||
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
|
||||
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
|
||||
return_value=mock_dp_and_tp_group(mocker)), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
|
||||
return_value=MagicMock(
|
||||
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
|
||||
torchair_graph_config=MagicMock(enabled=False),
|
||||
enable_multistream_moe=False,
|
||||
expert_map_path=None
|
||||
)), \
|
||||
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
|
||||
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
|
||||
return_value=MagicMock(
|
||||
parallel_config=MagicMock(tensor_parallel_size=2),
|
||||
scheduler_config=MagicMock(max_num_seqs=4),
|
||||
model_config=MagicMock(max_model_len=2048)
|
||||
)), \
|
||||
patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
|
||||
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
|
||||
patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj):
|
||||
patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
|
||||
return_value=mock_forward_context_obj), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
|
||||
return_value=None), \
|
||||
patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
|
||||
return_value=None):
|
||||
|
||||
yield {
|
||||
'mock_forward_context_obj': mock_forward_context_obj,
|
||||
'mock_token_dispatcher_with_allgather':
|
||||
mock_token_dispatcher_with_allgather,
|
||||
'mock_token_dispatcher_with_all2allv':
|
||||
mock_token_dispatcher_with_all2allv,
|
||||
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
|
||||
'mock_moe_comm_method': mock_moe_comm_method,
|
||||
}
|
||||
|
||||
mock_register_token_dispatcher_patcher.stop()
|
||||
mock_get_token_dispatcher_patcher.stop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_moe_env(mocker: MockerFixture):
|
||||
@@ -235,6 +205,8 @@ def default_moe_config():
|
||||
def moe_method(mock_dist_env):
|
||||
moe = MagicMock()
|
||||
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
|
||||
moe.moe_parallel_config.use_ep = False
|
||||
moe.moe_parallel_config.dp_size = 1
|
||||
return AscendUnquantizedFusedMoEMethod(moe)
|
||||
|
||||
|
||||
@@ -280,6 +252,9 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
|
||||
expert_weights: torch.Tensor) -> torch.Tensor:
|
||||
pass
|
||||
|
||||
def get_fused_moe_quant_config(self, layer: torch.nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class TestAscendFusedMoe:
|
||||
|
||||
@@ -339,9 +314,7 @@ class TestAscendFusedMoe:
|
||||
moe.moe_parallel_config.ep_size = 1
|
||||
|
||||
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
|
||||
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
|
||||
dtype=torch.bool),
|
||||
padded_num_tokens=num_tokens)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
output = moe.forward(inputs,
|
||||
@@ -395,25 +368,10 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
[[256, 4], [128, 1], [128, 1], [128, 4]])
|
||||
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
|
||||
global_num_experts, ep_size = others_param
|
||||
is_prefill = False
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, is_deepseek_v3_r1),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
|
||||
return_value=forward_context):
|
||||
@@ -439,35 +397,22 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
global_num_experts=global_num_experts,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
expected_shape = (16, 2)
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
@pytest.mark.parametrize("others_param", [16, 1, 4])
|
||||
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
|
||||
mock_moe_env, others_param):
|
||||
|
||||
ep_size = others_param
|
||||
is_prefill = False
|
||||
|
||||
if ep_size == 1:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_allgather']
|
||||
elif ep_size < 16:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_all2allv']
|
||||
else:
|
||||
selected_token_dispatcher = mock_dist_env[
|
||||
'mock_token_dispatcher_with_mc2']
|
||||
|
||||
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
|
||||
ep_size, is_prefill, True),
|
||||
with_quant=False,
|
||||
token_dispatcher=selected_token_dispatcher)
|
||||
forward_context = mock_dist_env['mock_forward_context_obj']
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
|
||||
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
|
||||
|
||||
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
|
||||
moe_method.ep_size = ep_size
|
||||
x = torch.randn(8, 2, 2)
|
||||
@@ -494,8 +439,10 @@ class TestAscendUnquantizedFusedMoEMethod:
|
||||
expert_map=expert_map,
|
||||
is_prefill=is_prefill)
|
||||
|
||||
expected_shape = (16, 2)
|
||||
mock_moe_comm_method = mock_dist_env['mock_moe_comm_method']
|
||||
mock_moe_comm_method.fused_experts.assert_called_once()
|
||||
|
||||
expected_shape = (16, 2)
|
||||
assert result.shape == expected_shape
|
||||
|
||||
|
||||
@@ -524,10 +471,47 @@ class TestExpertsSelector:
|
||||
assert topk_ids.shape == (8, 2)
|
||||
|
||||
|
||||
class TestCumsumGroupList(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.active_num = 8
|
||||
self.expert_num = 128
|
||||
self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
|
||||
self.experts[:self.active_num] = 1
|
||||
self.experts = self.experts[torch.randperm(self.expert_num)]
|
||||
self.group_list = self.experts.cumsum(dim=0)
|
||||
|
||||
def test_cumsum_group_list_with_type_0(self):
|
||||
group_list = self.experts.cumsum(dim=0)
|
||||
group_list_type = 0
|
||||
result = cumsum_group_list(group_list, group_list_type)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
def test_cumsum_group_list_with_type_1(self):
|
||||
group_list = self.experts
|
||||
group_list_type = 1
|
||||
result = cumsum_group_list(group_list, group_list_type)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
def test_cumsum_group_list_with_type_2(self):
|
||||
tokens = torch.arange(self.expert_num, dtype=torch.int64)
|
||||
group_list = torch.cat([
|
||||
tokens.reshape(self.expert_num, 1),
|
||||
self.experts.reshape(self.expert_num, 1)
|
||||
],
|
||||
dim=1)
|
||||
group_list_type = 2
|
||||
result = cumsum_group_list(group_list,
|
||||
group_list_type,
|
||||
active_num=self.active_num,
|
||||
expert_num=self.expert_num)
|
||||
self.assertTrue(torch.equal(result, self.group_list))
|
||||
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@patch('torch_npu.npu_dequant_swiglu_quant')
|
||||
@@ -538,7 +522,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.fused_moe_state = FusedMoEState.MC2
|
||||
mock_forward_context.moe_comm_type = MoECommType.MC2
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_is_310p.return_value = False
|
||||
@@ -582,8 +566,6 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
with_quant=True)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
self.assertEqual(mock_forward_context.fused_moe_state,
|
||||
FusedMoEState.MC2)
|
||||
|
||||
mock_npu_dynamic_quant.assert_called()
|
||||
|
||||
@@ -593,7 +575,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -635,7 +617,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -695,7 +677,7 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
|
||||
@patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_dynamic_quant')
|
||||
@@ -739,3 +721,68 @@ class TestUnifiedApplyMLP(TestBase):
|
||||
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.float16)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
|
||||
@patch("torch_npu.npu_grouped_matmul")
|
||||
@patch("torch_npu.npu_swiglu")
|
||||
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
|
||||
@patch("torch_npu.npu_dynamic_quant")
|
||||
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
|
||||
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
|
||||
mock_npu_swiglu, mock_npu_grouped_matmul,
|
||||
mock_get_forward_context):
|
||||
|
||||
mock_forward_context = MagicMock()
|
||||
mock_forward_context.with_quant = True
|
||||
mock_forward_context.fused_moe_state = "NOT_MC2"
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
|
||||
-128, 127, (10, 40),
|
||||
dtype=torch.int8), torch.rand(
|
||||
10, 1,
|
||||
dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
|
||||
mock_npu_grouped_matmul.side_effect = [[
|
||||
torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
]]
|
||||
mock_npu_swiglu.return_value = torch.randn(10,
|
||||
40,
|
||||
dtype=torch.bfloat16)
|
||||
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
||||
127, (10, 40),
|
||||
dtype=torch.int8),
|
||||
torch.rand(10,
|
||||
1,
|
||||
dtype=torch.float32))
|
||||
|
||||
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
||||
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
||||
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
||||
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
|
||||
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
|
||||
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
||||
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
||||
|
||||
result = unified_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=provided_dynamic_scale,
|
||||
group_list_type=1,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
topk_scales=None,
|
||||
with_quant=True,
|
||||
fusion=True)
|
||||
|
||||
mock_get_forward_context.assert_called()
|
||||
mock_npu_grouped_matmul.assert_called_once()
|
||||
mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
|
||||
|
||||
self.assertTrue(mock_forward_context.with_quant)
|
||||
self.assertEqual(result.shape, hidden_states.shape)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
|
||||
@@ -1,13 +1,18 @@
from unittest.mock import patch
import unittest

import pytest
import torch
from pytest_mock import MockerFixture
from vllm.model_executor.layers.layernorm import RMSNorm

from tests.ut.base import PytestBase
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod

@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)

def mock_maybe_chunk_residual(x, residual):
if x.size(0) != residual.size(0):
return residual[:4]
return residual


def mock_rms_norm(x, weight, eps):
@@ -18,36 +23,139 @@ def mock_add_rms_norm(x, residual, weight, eps):
return 2 * x, None, 2 * residual

@pytest.mark.parametrize("is_310p_return", [True, False])
|
||||
@pytest.mark.parametrize("residual",
|
||||
[None, torch.randn(4, 8, dtype=torch.float32)])
|
||||
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
|
||||
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
|
||||
residual, dummy_tensor):
|
||||
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
|
||||
epsilon):
|
||||
x_out = 2 * x
|
||||
residual_out = 2 * residual
|
||||
x_out_quant = x_out.to(torch.int8)
|
||||
residual_out_quant = residual_out.to(torch.int8)
|
||||
return x_out_quant, None, residual_out_quant
|
||||
|
||||
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
|
||||
layer = RMSNorm(hidden_size=32, eps=1e-05)
|
||||
|
||||
class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def context(self, mocker: MockerFixture):
|
||||
mocker.patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=mock_maybe_chunk_residual)
|
||||
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||
mocker.patch("torch_npu.npu_add_rms_norm",
|
||||
side_effect=mock_add_rms_norm)
|
||||
mocker.patch("torch_npu.npu_add_rms_norm_quant",
|
||||
side_effect=mock_add_rms_norm_quant)
|
||||
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
|
||||
side_effect=lambda x: None)
|
||||
|
||||
# Test case for the most common and basic scenario
|
||||
@pytest.mark.parametrize(
|
||||
"residual", [None, torch.randn(4, 8, dtype=torch.float16)])
|
||||
def test_forward_oot_basic(self, residual):
|
||||
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
if residual is not None:
|
||||
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
if is_310p_return:
|
||||
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
|
||||
expected_out_x = expected_arg_x + 1
|
||||
expected_out_residual = expected_arg_x.to(residual.dtype)
|
||||
x_out_expected = 2 * x
|
||||
residual_out_expected = 2 * residual
|
||||
|
||||
mock_rmsnorm.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
else:
|
||||
expected_out_x = 2 * dummy_tensor
|
||||
expected_out_residual = 2 * residual
|
||||
mock_add_rmsnorm.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(out_residual, expected_out_residual)
|
||||
assert torch.allclose(x_out, x_out_expected)
|
||||
assert torch.allclose(residual_out, residual_out_expected)
|
||||
else:
|
||||
out_x = layer.forward(dummy_tensor, residual)
|
||||
expected_out_x = dummy_tensor + 1
|
||||
x_out = layer.forward(x, residual)
|
||||
x_out_expected = x + 1
|
||||
|
||||
mock_rmsnorm.assert_called_once()
|
||||
assert torch.allclose(out_x, expected_out_x)
|
||||
assert torch.allclose(x_out, x_out_expected)
|
||||
|
||||
# Test case for flashcomm_v1 scenario
|
||||
def test_forward_oot_with_flashcomm_v1(self):
|
||||
layer = RMSNorm(hidden_size=512, eps=1e-05)
|
||||
x = torch.randn(4, 512, dtype=torch.bfloat16)
|
||||
residual = torch.randn(16, 512, dtype=torch.bfloat16)
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
x_out_expected = 2 * x
|
||||
residual_out_expected = 2 * residual[:4]
|
||||
|
||||
assert residual_out.size(0) == 4
|
||||
assert torch.allclose(x_out, x_out_expected)
|
||||
assert torch.allclose(residual_out, residual_out_expected)
|
||||
|
||||
# Test case for addrmsnorm + w8a8 quant fusion
|
||||
def test_forward_oot_with_quant_fusion(self, mocker: MockerFixture):
|
||||
mock_is_310p = mocker.patch("vllm_ascend.utils.is_310p")
|
||||
mock_is_310p.return_value = False
|
||||
mock_get_forward_context = mocker.patch(
|
||||
"vllm_ascend.ops.layernorm.get_forward_context")
|
||||
|
||||
# Simulating a scenario with quant_fusion enabled
|
||||
mock_forward_context = mocker.MagicMock()
|
||||
|
||||
mock_model_instance = mocker.MagicMock()
|
||||
mock_forward_context.model_instance = mock_model_instance
|
||||
mock_model_instance.model.layers = [
|
||||
mocker.MagicMock() for _ in range(2)
|
||||
]
|
||||
|
||||
mock_layer_0 = mock_model_instance.model.layers[0]
|
||||
mock_layer_0.self_attn.qkv_proj = mocker.MagicMock()
|
||||
mock_layer_0.mlp.gate_up_proj = mocker.MagicMock()
|
||||
|
||||
mock_layer_1 = mock_model_instance.model.layers[1]
|
||||
mock_layer_1.self_attn.qkv_proj = mocker.MagicMock()
|
||||
mock_layer_1.mlp.gate_up_proj = mocker.MagicMock()
|
||||
|
||||
mock_quant_method_0_qkv = mocker.MagicMock()
|
||||
mock_quant_method_0_qkv.quant_method = AscendW8A8LinearMethod()
|
||||
mock_quant_method_0_gate_up = mocker.MagicMock()
|
||||
mock_quant_method_0_gate_up.quant_method = AscendW8A8LinearMethod()
|
||||
mock_layer_0.self_attn.qkv_proj.quant_method = mock_quant_method_0_qkv
|
||||
mock_layer_0.mlp.gate_up_proj.quant_method = mock_quant_method_0_gate_up
|
||||
|
||||
mock_quant_method_1_qkv = mocker.MagicMock()
|
||||
mock_quant_method_1_qkv.quant_method = AscendW8A8LinearMethod()
|
||||
mock_quant_method_1_gate_up = mocker.MagicMock()
|
||||
mock_quant_method_1_gate_up.quant_method = AscendW8A8LinearMethod()
|
||||
mock_layer_1.self_attn.qkv_proj.quant_method = mock_quant_method_1_qkv
|
||||
mock_layer_1.mlp.gate_up_proj.quant_method = mock_quant_method_1_gate_up
|
||||
|
||||
mock_get_forward_context.return_value = mock_forward_context
|
||||
|
||||
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
|
||||
mock_forward_context.prefetch_mlp_enabled = False
|
||||
mock_forward_context.layer_idx = 0
|
||||
mock_forward_context.num_hidden_layers = 2
|
||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||
|
||||
# Ensure fusion and layer_idx increment are handled correctly
|
||||
x = torch.randn(4, 8, dtype=torch.float16)
|
||||
residual = torch.randn(4, 8, dtype=torch.float16)
|
||||
layer = RMSNorm(hidden_size=8, eps=1e-05)
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 1
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 2
|
||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 3
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 4
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,363 +1,96 @@
import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch

import torch

from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
AscendMlpMergedColumnParallelLinear,
AscendMlpRowParallelLinear, LinearBase,
QuantizationConfig)
from vllm_ascend import ascend_config
from vllm_ascend.distributed import parallel_state
from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
AscendRowParallelLinear)


class TestAscendMlpRowParallelLinear(unittest.TestCase):
class BaseLinearTest(unittest.TestCase):

def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
self.tensor_parallel_world_size = 2
|
||||
self.tensor_parallel_rank = 0
|
||||
self.mlp_tensor_parallel_world_size = 2
|
||||
self.mlp_tensor_parallel_rank = 1
|
||||
self.mock_group = mock.MagicMock()
|
||||
self.mock_group.world_size = 2
|
||||
self.mock_group.rank_in_group = 0
|
||||
|
||||
self.get_tensor_model_parallel_world_size_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
|
||||
return_value=self.tensor_parallel_world_size)
|
||||
self.get_tensor_model_parallel_rank_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
|
||||
return_value=self.tensor_parallel_rank)
|
||||
self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
|
||||
return_value=self.mlp_tensor_parallel_world_size)
|
||||
self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
|
||||
return_value=self.mlp_tensor_parallel_rank)
|
||||
parallel_state._MLP_TP = self.mock_group
|
||||
parallel_state._OTP = self.mock_group
|
||||
|
||||
self.get_tensor_model_parallel_world_size_mock = \
|
||||
self.get_tensor_model_parallel_world_size_patch.start()
|
||||
self.get_tensor_model_parallel_rank_mock = \
|
||||
self.get_tensor_model_parallel_rank_patch.start()
|
||||
self.get_mlp_tensor_model_parallel_world_size_mock = \
|
||||
self.get_mlp_tensor_model_parallel_world_size_patch.start()
|
||||
self.get_mlp_tensor_model_parallel_rank_mock = \
|
||||
self.get_mlp_tensor_model_parallel_rank_patch.start()
|
||||
self.mock_ascend_config = MagicMock()
|
||||
self.mock_ascend_config.oproj_tensor_parallel_size = 2
|
||||
|
||||
self.split_tensor_along_last_dim_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.split_tensor_along_last_dim',
|
||||
return_value=(torch.randn(10, 8), torch.randn(10, 8)))
|
||||
self.tensor_model_parallel_all_reduce_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
|
||||
return_value=torch.randn(10, 8))
|
||||
self.tensor_model_parallel_all_reduce_mock = \
|
||||
self.tensor_model_parallel_all_reduce_patch.start()
|
||||
self.split_tensor_along_last_dim_mock = \
|
||||
self.split_tensor_along_last_dim_patch.start()
|
||||
self.get_mlp_tp_group_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
|
||||
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
|
||||
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
|
||||
self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
|
||||
mock.MagicMock()
|
||||
self.patches = [
|
||||
patch("vllm_ascend.ascend_config.get_ascend_config",
|
||||
return_value=self.mock_ascend_config),
|
||||
patch("vllm_ascend.distributed.parallel_state.get_otp_group",
|
||||
return_value=self.mock_group),
|
||||
patch("vllm_ascend.distributed.parallel_state.get_mlp_tp_group",
|
||||
return_value=self.mock_group),
|
||||
patch("vllm_ascend.ops.linear_op.get_tp_group",
|
||||
return_value=self.mock_group),
|
||||
patch(
|
||||
"vllm.distributed.parallel_state.get_tp_group",
|
||||
return_value=self.mock_group,
|
||||
),
|
||||
patch("vllm_ascend.utils.mlp_tp_enable", return_value=True),
|
||||
patch("vllm_ascend.utils.oproj_tp_enable", return_value=True)
|
||||
]
|
||||
|
||||
for p in self.patches:
|
||||
p.start()
|
||||
|
||||
def tearDown(self):
|
||||
self.get_tensor_model_parallel_world_size_patch.stop()
|
||||
self.get_tensor_model_parallel_rank_patch.stop()
|
||||
self.get_mlp_tensor_model_parallel_world_size_patch.stop()
|
||||
self.get_mlp_tensor_model_parallel_rank_patch.stop()
|
||||
self.split_tensor_along_last_dim_patch.stop()
|
||||
self.tensor_model_parallel_all_reduce_patch.stop()
|
||||
self.get_mlp_tp_group_patch.stop()
|
||||
for p in self.patches:
|
||||
p.stop()
|
||||
|
||||
def test_init_with_down_proj_prefix(self):
|
||||
layer = AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
prefix="down_proj")
|
||||
self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
|
||||
self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
|
||||
self.assertTrue(layer.enable_mlp_optimze)
|
||||
|
||||
def test_forward_with_mlp_optimize(self):
|
||||
layer = AscendMlpRowParallelLinear(
|
||||
class TestAscendRowParallelLinear(BaseLinearTest):
|
||||
|
||||
def test_mlp_optimize(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
|
||||
linear = AscendRowParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="down_proj",
|
||||
input_is_parallel=False,
|
||||
)
|
||||
input_tensor = torch.randn(16, 8) # (batch_size, input_size)
|
||||
layer(input_tensor)
|
||||
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
|
||||
|
||||
self.split_tensor_along_last_dim_mock.assert_called_once_with(
|
||||
input_tensor, num_partitions=layer.tp_size)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
linear(input_tensor)
|
||||
|
||||
def test_forward_without_mlp_optimize(self):
|
||||
layer = AscendMlpRowParallelLinear(
|
||||
def test_oproj_tp(self):
|
||||
ascend_config._ASCEND_CONFIG = MagicMock()
|
||||
ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
|
||||
|
||||
linear = AscendRowParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="other",
|
||||
input_is_parallel=False,
|
||||
prefix="o_proj",
|
||||
)
|
||||
self.assertEqual(linear.custom_op.comm_group, parallel_state._OTP)
|
||||
|
||||
input_tensor = torch.randn(16, 8)
|
||||
layer(input_tensor)
|
||||
linear(input_tensor)
|
||||
|
||||
self.split_tensor_along_last_dim_mock.assert_called_once_with(
|
||||
input_tensor, num_partitions=layer.tp_size)
|
||||
self.tensor_model_parallel_all_reduce_mock.assert_called_once()
|
||||
|
||||
def test_skip_bias_add(self):
|
||||
layer = AscendMlpRowParallelLinear(
|
||||
class TestAscendMergedColumnParallelLinear(BaseLinearTest):
|
||||
|
||||
def test_merged_mlp_tp_init(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
|
||||
linear = AscendMergedColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
skip_bias_add=True,
|
||||
output_sizes=[8, 8],
|
||||
prefix="gate_up_proj",
|
||||
)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
output, bias = layer(input_tensor)
|
||||
|
||||
self.assertIsNotNone(bias)
|
||||
|
||||
def test_no_reduce_results(self):
|
||||
layer = AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
reduce_results=False,
|
||||
bias=False)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
layer(input_tensor)
|
||||
|
||||
self.tensor_model_parallel_all_reduce_mock.assert_not_called()
|
||||
|
||||
def test_input_not_parallel(self):
|
||||
layer = AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
input_is_parallel=False)
|
||||
input_tensor = torch.randn(16, 8)
|
||||
layer(input_tensor)
|
||||
|
||||
self.split_tensor_along_last_dim_mock.assert_called_once()
|
||||
|
||||
def test_exception_when_reduce_false_and_bias(self):
|
||||
with self.assertRaises(ValueError):
|
||||
AscendMlpRowParallelLinear(input_size=16,
|
||||
output_size=8,
|
||||
reduce_results=False,
|
||||
bias=True,
|
||||
skip_bias_add=False)
|
||||
self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)
|
||||
|
||||
|
||||
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
# Mock distributed functions
|
||||
self.mlp_tp_size_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
|
||||
self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
|
||||
self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
|
||||
|
||||
self.mlp_tp_rank_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
|
||||
self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
|
||||
self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
|
||||
|
||||
self.tp_size_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
|
||||
self.tp_size_mock = self.tp_size_patch.start()
|
||||
self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
|
||||
|
||||
self.tp_rank_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
|
||||
self.tp_rank_mock = self.tp_rank_patch.start()
|
||||
self.tp_rank_mock.return_value = 1 # Current GPU rank
|
||||
|
||||
# Mock divide function (assumed to be in your module)
|
||||
self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
|
||||
self.divide_mock = self.divide_patch.start()
|
||||
self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
|
||||
|
||||
# Mock QuantizationConfig and QuantMethod
|
||||
self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
|
||||
|
||||
# Mock LinearBase initialization
|
||||
self.linear_base_init_patch = mock.patch.object(
|
||||
LinearBase, "__init__", side_effect=self.mock_linear_base_init)
|
||||
self.linear_base_init_patch.start()
|
||||
|
||||
self.quant_method_mock = mock.MagicMock()
|
||||
|
||||
def mock_linear_base_init(self, instance, *args, **kwargs):
|
||||
instance.quant_method = self.quant_method_mock
|
||||
instance.params_dtype = mock.MagicMock()
|
||||
|
||||
instance.input_size = 16
|
||||
instance.output_size = 8
|
||||
instance.output_size_per_partition = 4
|
||||
instance.params_dtype = torch.float32
|
||||
|
||||
def tearDown(self):
|
||||
self.mlp_tp_size_patch.stop()
|
||||
self.mlp_tp_rank_patch.stop()
|
||||
self.tp_size_patch.stop()
|
||||
self.tp_rank_patch.stop()
|
||||
self.divide_patch.stop()
|
||||
self.linear_base_init_patch.stop()
|
||||
|
||||
def test_mlp_optimize_initialization(self):
|
||||
# Test when prefix contains "gate_up_proj"
|
||||
with mock.patch.object(torch.nn.Module, 'register_parameter'):
|
||||
layer = AscendMlpColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="model.layers.0.gate_up_proj",
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Verify MLP optimization flags
|
||||
self.assertTrue(layer.enable_mlp_optimze)
|
||||
self.assertEqual(layer.tp_size, 2)
|
||||
self.assertEqual(layer.tp_rank, 0)
|
||||
self.assertEqual(layer.input_size_per_partition, 16)
|
||||
self.assertEqual(layer.output_size_per_partition, 4)
|
||||
|
||||
# Check quant_method.create_weights was called
|
||||
self.quant_method_mock.create_weights.assert_called_once()
|
||||
|
||||
def test_regular_parallel_initialization(self):
|
||||
# Test when prefix does NOT contain "gate_up_proj"
|
||||
with mock.patch.object(torch.nn.Module, 'register_parameter'):
|
||||
layer = AscendMlpColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
prefix="model.layers.0.q_proj",
|
||||
quant_config=self.quant_config_mock,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Verify regular TP flags
|
||||
self.assertFalse(layer.enable_mlp_optimze)
|
||||
self.assertEqual(layer.tp_size, 4)
|
||||
self.assertEqual(layer.tp_rank, 1)
|
||||
self.assertEqual(layer.input_size_per_partition, 16)
|
||||
self.assertEqual(layer.output_size_per_partition, 4)
|
||||
# Check quant_method.create_weights was called
|
||||
self.quant_method_mock.create_weights.assert_called_once()
|
||||
|
||||
def test_output_sizes_handling(self):
|
||||
# Test when output_sizes is provided
|
||||
with mock.patch.object(torch.nn.Module, 'register_parameter'):
|
||||
layer = AscendMlpColumnParallelLinear(
|
||||
input_size=16,
|
||||
output_size=8,
|
||||
output_sizes=[4, 4],
|
||||
prefix="model.layers.0.qkv_proj",
|
||||
quant_config=self.quant_config_mock,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
# Verify output_partition_sizes
|
||||
self.assertEqual(layer.output_partition_sizes, [2])
|
||||
|
||||
|
||||
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
|
||||
# Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
|
||||
self.mlp_world_size_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
|
||||
self.tensor_world_size_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
|
||||
self.mlp_world_size_patch.start()
|
||||
self.tensor_world_size_patch.start()
|
||||
|
||||
# Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
|
||||
self.mlp_rank_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
|
||||
self.tensor_rank_patch = \
|
||||
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
|
||||
self.mlp_rank_patch.start()
|
||||
self.tensor_rank_patch.start()
|
||||
|
||||
# Mock all_gather methods
|
||||
self.get_mlp_tp_group_patch = \
|
||||
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
|
||||
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
|
||||
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
|
||||
self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
|
||||
self.tensor_model_parallel_all_gather_patch = mock.patch(
|
||||
'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
|
||||
return_value=torch.randn(10, 8))
|
||||
self.tensor_model_parallel_all_gather_mock = \
|
||||
self.tensor_model_parallel_all_gather_patch.start()
|
||||
|
||||
# Mock AscendMlpColumnParallelLinear's __init__
|
||||
self.linear_init_patch = mock.patch.object(
|
||||
AscendMlpColumnParallelLinear,
|
||||
"__init__",
|
||||
side_effect=self.mock_linear_init)
|
||||
self.linear_init_patch.start()
|
||||
|
||||
# Create mock objects
|
||||
self.quant_method_mock = mock.MagicMock()
|
||||
self.apply_output = torch.randn(2, 8)
|
||||
|
||||
self.quant_method_mock.apply.return_value = self.apply_output
|
||||
|
||||
def mock_linear_init(self, instance, *args, **kwargs):
|
||||
torch.nn.Module.__init__(instance)
|
||||
# Set quant_method and other attributes
|
||||
instance.quant_method = self.quant_method_mock
|
||||
instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
|
||||
instance.input_size = 16
|
||||
instance.output_size = 8
|
||||
instance.gather_output = False
|
||||
instance.skip_bias_add = False
|
||||
instance.return_bias = True
|
||||
|
||||
def test_forward_with_enable_mlp_optimze(self):
|
||||
# Setup input
|
||||
input_tensor = torch.randn(1, 16)
|
||||
|
||||
# Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
|
||||
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
|
||||
output_sizes=[8],
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
skip_bias_add=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix="other_proj")
|
||||
|
||||
# Call forward
|
||||
output, bias = layer(input_tensor)
|
||||
|
||||
# Validate calls
|
||||
self.assertEqual(output.shape, self.apply_output.shape)
|
||||
|
||||
def test_forward_without_enable_mlp_optimze(self):
|
||||
# Setup input
|
||||
input_tensor = torch.randn(1, 16)
|
||||
|
||||
# Create instance with prefix not containing "gate_up_proj"
|
||||
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
|
||||
output_sizes=[8],
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
skip_bias_add=False,
|
||||
params_dtype=torch.float32,
|
||||
quant_config=None,
|
||||
prefix="other_proj")
|
||||
|
||||
# Call forward
|
||||
output, bias = layer(input_tensor)
|
||||
|
||||
# Validate calls
|
||||
self.quant_method_mock.apply.assert_called_once_with(
|
||||
layer, input_tensor, layer.bias)
|
||||
self.tensor_model_parallel_all_gather_mock.assert_not_called()
|
||||
self.assertEqual(output.shape, self.apply_output.shape)
|
||||
|
||||
def tearDown(self):
|
||||
self.linear_init_patch.stop()
|
||||
self.mlp_world_size_patch.stop()
|
||||
self.tensor_world_size_patch.stop()
|
||||
self.mlp_rank_patch.stop()
|
||||
self.tensor_rank_patch.stop()
|
||||
self.get_mlp_tp_group_mock.stop()
|
||||
self.tensor_model_parallel_all_gather_mock.stop()
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
232 tests/ut/ops/test_moe_comm_method.py Normal file
@@ -0,0 +1,232 @@
|
||||
from unittest.mock import MagicMock, patch

import torch
from vllm.model_executor.layers.fused_moe import FusedMoEConfig

from tests.ut.base import TestBase
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
AlltoAllCommImpl, MC2CommImpl)


class TestMoECommMethod(TestBase):

def setUp(self):
# Mock FusedMoEConfig
self.moe_config = MagicMock(spec=FusedMoEConfig)
self.moe_config.num_experts = 8
self.moe_config.num_local_experts = 2
self.moe_config.experts_per_token = 2
self.moe_config.tp_group = MagicMock()
self.moe_config.tp_group.device_group = MagicMock()
self.moe_config.dp_size = 1
self.moe_config.tp_size = 1
self.moe_config.ep_size = 1
self.moe_config.dp_group = MagicMock()
self.moe_config.num_global_redundant_experts = 0

@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
def test_all_gather_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "all_gather"
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2), None)
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
# Mock token dispatcher
|
||||
mock_td_instance = MagicMock()
|
||||
mock_token_dispatcher.return_value = mock_td_instance
|
||||
|
||||
# Create instance
|
||||
comm_impl = AllGatherCommImpl(self.moe_config)
|
||||
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithMC2"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithMC2")
|
||||
def test_mc2_comm_impl(self, mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "mc2"
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2),
|
||||
torch.tensor([1, 0, 1, 0]))
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
# Mock token dispatcher
|
||||
mock_td_instance = MagicMock()
|
||||
mock_token_dispatcher.return_value = mock_td_instance
|
||||
|
||||
# Create instance
|
||||
comm_impl = MC2CommImpl(self.moe_config)
|
||||
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
|
||||
# Test finalize method
|
||||
comm_impl.finalize(h_out, reduce_results=True)
|
||||
mock_pf_instance.finalize.assert_called_once_with(h_out, True)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAll2All"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAll2AllV")
|
||||
def test_alltoall_comm_impl(self, mock_token_dispatcher,
|
||||
mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "alltoall"
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2), None)
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
# Mock token dispatcher
|
||||
mock_td_instance = MagicMock()
|
||||
mock_token_dispatcher.return_value = mock_td_instance
|
||||
|
||||
# Create instance
|
||||
comm_impl = AlltoAllCommImpl(self.moe_config)
|
||||
|
||||
# Test prepare method
|
||||
hidden_states = torch.randn(3, 8)
|
||||
router_logits = torch.randn(3, 2)
|
||||
h_out, r_out = comm_impl.prepare(hidden_states, router_logits)
|
||||
|
||||
# Verify prepare was called with correct arguments
|
||||
mock_pf_instance.prepare.assert_called_once_with(
|
||||
hidden_states, router_logits, False, False, False, None)
|
||||
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.get_forward_context")
|
||||
@patch(
|
||||
"vllm_ascend.ops.moe.moe_comm_method.FusedMoEPrepareAndFinalizeWithAllGather"
|
||||
)
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.TokenDispatcherWithAllGather")
|
||||
@patch("vllm_ascend.ops.moe.moe_comm_method.unified_apply_mlp")
|
||||
def test_fused_experts_method(self, mock_unified_apply_mlp,
|
||||
mock_token_dispatcher, mock_prepare_finalize,
|
||||
mock_get_forward_context,
|
||||
mock_get_current_vllm_config):
|
||||
# Mock vLLM config
|
||||
mock_get_current_vllm_config.return_value = MagicMock()
|
||||
|
||||
# Mock forward context
|
||||
mock_context = MagicMock()
|
||||
mock_context.moe_comm_method = "all_gather"
|
||||
mock_get_forward_context.return_value = mock_context
|
||||
|
||||
# Mock prepare finalize
|
||||
mock_pf_instance = MagicMock()
|
||||
mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
|
||||
torch.randn(4, 2), None)
|
||||
mock_pf_instance.finalize.return_value = torch.randn(4, 8)
|
||||
mock_prepare_finalize.return_value = mock_pf_instance
|
||||
|
||||
# Mock token dispatcher
|
||||
mock_td_instance = MagicMock()
|
||||
mock_td_instance.token_dispatch.return_value = {
|
||||
"hidden_states": torch.randn(6, 8),
|
||||
"group_list": torch.tensor([2, 2, 2]),
|
||||
"group_list_type": 1
|
||||
}
|
||||
mock_td_instance.token_combine.return_value = torch.randn(4, 8)
|
||||
mock_token_dispatcher.return_value = mock_td_instance
|
||||
|
||||
# Mock unified_apply_mlp
|
||||
mock_unified_apply_mlp.return_value = torch.randn(6, 8)
|
||||
|
||||
# Create instance
|
||||
comm_impl = AllGatherCommImpl(self.moe_config)
|
||||
|
||||
# Test fused_experts method
|
||||
hidden_states = torch.randn(4, 8).contiguous()
|
||||
w1 = torch.randn(16, 8).contiguous()
|
||||
w2 = torch.randn(16, 8).contiguous()
|
||||
topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
|
||||
[0.6, 0.4]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
|
||||
row_idx = torch.arange(4)
|
||||
|
||||
# Make sure tensors are contiguous and have correct strides
|
||||
hidden_states = hidden_states.contiguous()
|
||||
w1 = w1.contiguous()
|
||||
w2 = w2.contiguous()
|
||||
|
||||
result = comm_impl.fused_experts(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
activation="silu")
|
||||
|
||||
# Verify result shape
|
||||
self.assertEqual(result.shape, (4, 8))
|
||||
|
||||
# Verify token_dispatch was called
|
||||
mock_td_instance.token_dispatch.assert_called_once()
|
||||
|
||||
# Verify unified_apply_mlp was called
|
||||
mock_unified_apply_mlp.assert_called_once()
|
||||
|
||||
# Verify token_combine was called
|
||||
mock_td_instance.token_combine.assert_called_once()
|
||||
@@ -3,12 +3,18 @@ import unittest
from unittest.mock import MagicMock, PropertyMock, patch

import torch
from transformers.configuration_utils import PretrainedConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)

from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled

MODEL = "Qwen3-0.6B"
MAX_NUM_BATCHED_TOKEND = 10000


class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):

@@ -88,11 +94,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
self.mock_self.cos_sin_cache = self.cos_sin_cache
|
||||
self.mock_self.is_neox_style = self.is_neox_style
|
||||
|
||||
@patch('torch.ops._C')
|
||||
@patch('torch.ops._C_ascend')
|
||||
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=True)
|
||||
@patch('torch.ops._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
|
||||
mock_custom_enabled, mock_is_310p,
|
||||
mock__c):
|
||||
@@ -102,9 +112,15 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
# Setup mock for custom kernel path
|
||||
|
||||
mock__c.rotary_embedding.return_value = self.query, self.key
|
||||
|
||||
result_q, result_k = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
|
||||
mock__c.rotary_embedding.assert_called_once()
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
@@ -113,6 +129,10 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
|
||||
mock_custom_enabled):
|
||||
mock_config = MagicMock()
|
||||
@@ -121,15 +141,25 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
# Test contiguous path when custom is disabled
|
||||
non_contig_query = self.query.transpose(0, 1)
|
||||
non_contig_key = self.key.transpose(0, 1)
|
||||
|
||||
result_q, result_k = self.layer.forward(self.positions,
|
||||
non_contig_query,
|
||||
non_contig_key)
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(self.positions,
|
||||
non_contig_query,
|
||||
non_contig_key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, non_contig_query.shape)
|
||||
self.assertEqual(result_k.shape, non_contig_key.shape)
|
||||
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_with_offsets(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
@@ -137,26 +167,78 @@ class TestAscendRotaryEmbedding(unittest.TestCase):
|
||||
# Test that NotImplementedError is raised when offsets is provided
|
||||
offsets = torch.tensor([1, 2, 3])
|
||||
with self.assertRaises(NotImplementedError):
|
||||
self.layer.forward(self.positions, self.query, self.key, offsets)
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
self.layer.forward(self.positions, self.query, self.key,
|
||||
offsets)
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
|
||||
mock_custom_enabled):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
# Test neox_style override
|
||||
result_q, result_k = self.layer.forward(self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
is_neox_style_override=False)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(
|
||||
self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
is_neox_style_override=False)
|
||||
# Check that neox_style=False was passed to the NPU function
|
||||
args, kwargs = mock_npu_rotary.call_args
|
||||
self.assertFalse(args[-1])
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
|
||||
return_value=False)
|
||||
@patch('torch_npu._npu_rotary_embedding')
|
||||
@patch('vllm.config.ModelConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.config.VllmConfig.__post_init__', MagicMock())
|
||||
@patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
|
||||
@patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
|
||||
def test_rope_forward_oot_rotary_dim_less_than_head_size(
|
||||
self, mock_npu_rotary, mock_custom_enabled):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
|
||||
# test case when rotary_dim < head_size
|
||||
org_rotary_dim = self.layer.rotary_dim
|
||||
self.layer.rotary_dim = self.layer.head_size // 2
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
model_config = ModelConfig(MODEL,
|
||||
tokenizer=MODEL,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKEND)
|
||||
model_config.hf_config = PretrainedConfig()
|
||||
vllm_config.model_config = model_config
|
||||
with set_ascend_forward_context(None, vllm_config):
|
||||
result_q, result_k = self.layer.forward(self.positions, self.query,
|
||||
self.key)
|
||||
|
||||
mock_npu_rotary.assert_called_once()
|
||||
self.assertEqual(result_q.shape, self.query.shape)
|
||||
self.assertEqual(result_k.shape, self.key.shape)
|
||||
|
||||
# restore rotary_dim
|
||||
self.layer.rotary_dim = org_rotary_dim
|
||||
|
||||
|
||||
class MockRopeModule:
|
||||
|
||||
@@ -207,28 +289,6 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == self.key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
|
||||
new_callable=PropertyMock)
|
||||
def test_native_rope_deepseek_forward_cache_handling(
|
||||
self, mock_npuplatform, mock_rope_forward_oot):
|
||||
mock_npuplatform.device_type = torch.device("cpu")
|
||||
self.layer = self._create_layer()
|
||||
self.layer.max_seq_len = 1024
|
||||
# Test cache situation is true
|
||||
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
|
||||
mock_rope_forward_oot.return_value = (self.query, self.key)
|
||||
|
||||
q_pe, k_pe = self.layer.forward(self.positions,
|
||||
self.query,
|
||||
self.key,
|
||||
max_seq_len=2048)
|
||||
mock_set_cache.assert_called_once()
|
||||
assert q_pe.shape == self.query.shape
|
||||
assert k_pe.shape == self.key.shape
|
||||
|
||||
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
|
||||
@patch("vllm.platforms.current_platform.device_type",
|
||||
new=torch.device("cpu"))
|
||||
|
||||
@@ -20,10 +20,10 @@ from unittest.mock import MagicMock, PropertyMock, patch
import torch

from tests.ut.base import TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (

from vllm_ascend.ops.moe.token_dispatcher import ( # isort: skip
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)


class TestTokenDispatcherWithMC2(TestBase):
@@ -34,7 +34,7 @@ class TestTokenDispatcherWithMC2(TestBase):
self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8
self.mc2_group_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
"vllm_ascend.ops.moe.token_dispatcher.get_mc2_group",
return_value=self.mc2_group)
self.mc2_group_patch.start()

@@ -52,7 +52,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
# Mock get_ascend_soc_version()
|
||||
self.ascend_soc_version_patch = patch(
|
||||
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
|
||||
"vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version",
|
||||
return_value=AscendSocVersion.A3)
|
||||
self.ascend_soc_version_patch.start()
|
||||
|
||||
@@ -98,7 +98,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.row_idx, expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output["group_list_type"],
|
||||
1) # group_list_type == 1
|
||||
0) # group_list_type == 0
|
||||
|
||||
def test_token_dispatch_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = MagicMock()
|
||||
@@ -171,32 +171,25 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
# Mock NPU functions
|
||||
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
|
||||
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
|
||||
self.mock_moe_init_routing.return_value = (
|
||||
self.patcher_npu_moe_init_routing_v2 = patch(
|
||||
'torch_npu.npu_moe_init_routing_v2')
|
||||
self.mock_npu_moe_init_routing_v2 = self.patcher_npu_moe_init_routing_v2.start(
|
||||
)
|
||||
self.mock_npu_moe_init_routing_v2.return_value = (
|
||||
torch.randn(6, 128), # sorted_hidden_states
|
||||
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
|
||||
)
|
||||
|
||||
self.patcher_moe_compute_expert_tokens = patch(
|
||||
'torch_npu.npu_moe_compute_expert_tokens')
|
||||
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
|
||||
)
|
||||
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
|
||||
[3, 3]) # expert_tokens
|
||||
|
||||
self.patcher_moe_finalize_routing = patch(
|
||||
'torch_npu.npu_moe_finalize_routing')
|
||||
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
|
||||
)
|
||||
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]), # expanded_expert_idx
|
||||
torch.tensor([0, 1, 0, 1, 0, 1]))
|
||||
self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
self.patcher_npu_moe_token_unpermute = patch(
|
||||
'torch_npu.npu_moe_token_unpermute')
|
||||
self.mock_npu_moe_token_unpermute = self.patcher_npu_moe_token_unpermute.start(
|
||||
)
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(6, 128)
|
||||
|
||||
def tearDown(self):
|
||||
self.patcher_moe_init_routing.stop()
|
||||
self.patcher_moe_compute_expert_tokens.stop()
|
||||
self.patcher_moe_finalize_routing.stop()
|
||||
self.patcher_npu_moe_init_routing_v2.stop()
|
||||
self.patcher_npu_moe_token_unpermute.stop()
|
||||
|
||||
def test_token_dispatch_without_expert_map(self):
|
||||
hidden_states = torch.randn(3, 128)
|
||||
@@ -207,12 +200,27 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_ids, self.row_idx, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_moe_init_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_init_routing.call_args
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
def test_token_dispatch_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
|
||||
topk_ids, self.row_idx, None)
|
||||
|
||||
# Verify npu_moe_init_routing is called
|
||||
self.mock_npu_moe_init_routing_v2.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_init_routing_v2.call_args
|
||||
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_without_quant(self):
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
@@ -230,7 +238,33 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, None)
|
||||
|
||||
self.assertEqual(results["group_list_type"], 0)
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_dispatch_with_quant(self):
|
||||
kwargs = {
|
||||
"apply_router_weight_on_input": False,
|
||||
"top_k": 2,
|
||||
"max_num_tokens": 100,
|
||||
"ep_size": 2,
|
||||
"num_experts": 128,
|
||||
}
|
||||
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
|
||||
|
||||
hidden_states = torch.randn(3, 128)
|
||||
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
|
||||
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
|
||||
|
||||
results = self.dispatcher_quant.token_dispatch(hidden_states,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
self.row_idx,
|
||||
None,
|
||||
with_quant=True)
|
||||
|
||||
self.assertIsNotNone(results["hidden_states"])
|
||||
self.assertIsNotNone(results["group_list"])
|
||||
self.assertIsNotNone(results["dynamic_scale"])
|
||||
self.assertEqual(results["group_list_type"], 1)
|
||||
|
||||
def test_token_combine_with_expert_map(self):
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
|
||||
@@ -242,9 +276,7 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
hidden_states = torch.randn(6, 128)
|
||||
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify index_add_ is applied correctly
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_combine_without_expert_map(self):
|
||||
self.dispatcher.with_quant = False
|
||||
@@ -260,10 +292,10 @@ class TestTokenDispatcherWithAllGather(TestBase):
|
||||
final_hidden_states = self.dispatcher.token_combine(hidden_states)
|
||||
|
||||
# Verify npu_moe_finalize_routing is called
|
||||
self.mock_moe_finalize_routing.assert_called_once()
|
||||
args, kwargs = self.mock_moe_finalize_routing.call_args
|
||||
self.mock_npu_moe_token_unpermute.assert_called_once()
|
||||
args, kwargs = self.mock_npu_moe_token_unpermute.call_args
|
||||
|
||||
self.assertEqual(final_hidden_states.shape, (3, 128))
|
||||
self.assertEqual(final_hidden_states.shape, (6, 128))
|
||||
|
||||
def test_token_dispatch_with_router_weight(self):
|
||||
self.dispatcher.apply_router_weight_on_input = True
|
||||
@@ -315,7 +347,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
|
||||
|
||||
# Mock async_all_to_all
|
||||
patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
|
||||
patcher6 = patch('vllm_ascend.ops.moe.comm_utils.async_all_to_all')
|
||||
self.mock_async_all_to_all = patcher6.start()
|
||||
self.addCleanup(patcher6.stop)
|
||||
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
|
||||
@@ -323,7 +355,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
|
||||
# Mock gather_from_sequence_parallel_region
|
||||
patcher7 = patch(
|
||||
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
'vllm_ascend.ops.moe.token_dispatcher.gather_from_sequence_parallel_region'
|
||||
)
|
||||
self.mock_gather_from_sequence_parallel_region = patcher7.start()
|
||||
self.addCleanup(patcher7.stop)
|
||||
@@ -488,119 +520,3 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
self.assertIsNotNone(result["hidden_states"])
|
||||
self.assertIsNotNone(result["group_list"])
|
||||
self.assertEqual(result["group_list_type"], 1)
|
||||
|
||||
|
||||
class TestDispatcherRegistry(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def tearDown(self):
|
||||
_Dispatchers.clear()
|
||||
|
||||
def test_register_and_get_token_dispatcher(self):
|
||||
mock_dispatcher = MagicMock()
|
||||
mock_dispatcher.__class__.__name__ = "MockDispatcher"
|
||||
|
||||
_register_token_dispatcher(mock_dispatcher)
|
||||
        self.assertIn("MockDispatcher", _Dispatchers)
        self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)

        retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
        self.assertIs(retrieved_dispatcher, mock_dispatcher)

        self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_1_creates_allgather(
            self, mock_register, mock_allgather_class):
        kwargs = {"top_k": 2, "num_experts": 8}
        mock_instance = MagicMock()
        mock_allgather_class.return_value = mock_instance

        self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)

        setup_token_dispatchers(ep_size=1, **kwargs)

        mock_allgather_class.assert_called_once_with(**kwargs)
        mock_register.assert_called_once_with(mock_instance)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
            self, mock_register, mock_all2allv_class):
        kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
        mock_instance = MagicMock()
        mock_all2allv_class.return_value = mock_instance

        self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)

        setup_token_dispatchers(ep_size=2, **kwargs)

        mock_all2allv_class.assert_called_once_with(**kwargs)
        mock_register.assert_called_once_with(mock_instance)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
            self, mock_register, mock_mc2_class, mock_all2allv_class):
        kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
        mock_all2allv_instance = MagicMock()
        mock_mc2_instance = MagicMock()
        mock_all2allv_class.return_value = mock_all2allv_instance
        mock_mc2_class.return_value = mock_mc2_instance

        self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
        self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)

        setup_token_dispatchers(ep_size=16, **kwargs)

        mock_all2allv_class.assert_called_once_with(**kwargs)
        mock_mc2_class.assert_called_once_with(**kwargs)
        self.assertEqual(mock_register.call_count, 2)
        mock_register.assert_any_call(mock_all2allv_instance)
        mock_register.assert_any_call(mock_mc2_instance)

    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
    )
    @patch(
        'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
    )
    def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
            self, mock_register, mock_mc2_class, mock_all2allv_class):
        kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
        mock_existing_all2allv = MagicMock()
        mock_existing_mc2 = MagicMock()
        _Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
        _Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2

        setup_token_dispatchers(ep_size=16, **kwargs)

        mock_all2allv_class.assert_not_called()
        mock_mc2_class.assert_not_called()
        mock_register.assert_not_called()
        self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
                      mock_existing_all2allv)
        self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
                      mock_existing_mc2)

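    # Taken together, these cases pin down the registration policy that
    # setup_token_dispatchers is expected to follow: ep_size == 1 registers
    # only TokenDispatcherWithAllGather, ep_size > 1 registers
    # TokenDispatcherWithAll2AllV, a large EP size (16 here) additionally
    # registers TokenDispatcherWithMC2, and dispatchers already present in
    # _Dispatchers are left untouched.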
@@ -18,6 +18,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import (
|
||||
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
|
||||
|
||||
@@ -31,6 +32,9 @@ class TestCustomVocabParallelEmbedding(unittest.TestCase):
|
||||
self.embedding_dim = 10
|
||||
self.org_num_embeddings = 40
|
||||
self.padding_size = 8
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.additional_config = {}
|
||||
init_ascend_config(mock_vllm_config)
|
||||
|
||||
def _create_layer(self):
|
||||
# Patch methods and dependencies for VocabParallelEmbedding
|
||||
@@ -206,7 +210,15 @@ class TestAscendLogitsProcessor(unittest.TestCase):
|
||||
return_value=True),
|
||||
patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
|
||||
return_value=torch.randn(1, self.vocab_size))
|
||||
return_value=torch.randn(1, self.vocab_size)),
|
||||
patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_gather",
|
||||
return_value=torch.randn(1, self.vocab_size)),
|
||||
patch(
|
||||
"vllm_ascend.core.schedule_config.AscendSchedulerConfig.initialize_from_config",
|
||||
return_value=MagicMock(max_num_batched_tokens=1000,
|
||||
max_model_len=512,
|
||||
enable_chunked_prefill=False))
|
||||
]
|
||||
|
||||
for p in self.patches:
|
||||
|
||||
@@ -1,167 +0,0 @@
|
||||
from importlib import reload
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import vllm
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.patch.worker.patch_common import patch_linear
|
||||
|
||||
|
||||
class TestAscendRowParallelLinear(PytestBase):
|
||||
|
||||
def init_row_parallel_linear(self, mocker: MockerFixture):
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
mocker.patch("torch.nn.Module.__setattr__")
|
||||
mocker.patch("torch.nn.Module.__getattr__")
|
||||
mocker.patch("torch.nn.Module.__delattr__")
|
||||
return patch_linear.AscendRowParallelLinear(
|
||||
input_size=128,
|
||||
output_size=256,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"version, expected",
|
||||
[
|
||||
("1.0.0", 1),
|
||||
("2.1.0", 1),
|
||||
],
|
||||
)
|
||||
def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
|
||||
mock_group = mocker.MagicMock()
|
||||
backend = mocker.MagicMock()
|
||||
backend.get_hccl_comm_name = lambda x: x
|
||||
mock_group._get_backend = lambda x: backend
|
||||
mock_group.get_hccl_comm_name = lambda x: x
|
||||
mocker.patch("torch.distributed.get_rank", return_value=1)
|
||||
mocker.patch(
|
||||
"torch.distributed.get_global_rank",
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch("torch.__version__", new=version)
|
||||
hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
|
||||
mock_group)
|
||||
assert hcomm_info == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"skip_bias_add, return_bias, bias, expected",
|
||||
[
|
||||
(True, False, torch.tensor(1.0), torch.tensor(14.0)),
|
||||
(False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
|
||||
(
|
||||
True,
|
||||
True,
|
||||
torch.tensor(1.0),
|
||||
(torch.tensor(14.0), torch.tensor(1.0)),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_forward(
|
||||
self,
|
||||
skip_bias_add,
|
||||
return_bias,
|
||||
bias,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
mocker_tp_group = mocker.MagicMock()
|
||||
mocker_tp_group.device_group = mocker.MagicMock()
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["tp_rank"] = 0
|
||||
row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
|
||||
row_parallel_linear.__dict__["return_bias"] = return_bias
|
||||
row_parallel_linear.__dict__["bias"] = bias
|
||||
row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock()
|
||||
row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa
|
||||
row_parallel_linear.__dict__[
|
||||
"calc_output"] = lambda x: x.matmul( # noqa
|
||||
torch.tensor([1.0, 2.0]))
|
||||
ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
|
||||
if isinstance(ret, tuple):
|
||||
assert torch.allclose(ret[0], expected[0])
|
||||
if ret[1] is None:
|
||||
assert ret[1] == expected[1]
|
||||
else:
|
||||
assert torch.allclose(ret[1], expected[1])
|
||||
else:
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_is_parallel, expected",
|
||||
[
|
||||
(True, torch.tensor([10.0, 2.0])),
|
||||
(False, torch.tensor([10.0])),
|
||||
],
|
||||
)
|
||||
def test_calc_input(
|
||||
self,
|
||||
input_is_parallel,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
|
||||
input_tensor = torch.Tensor([10, 2])
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa
|
||||
return_value=0,
|
||||
)
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa
|
||||
return_value=[torch.Tensor([10]),
|
||||
torch.Tensor([2])],
|
||||
)
|
||||
input_parallel = row_parallel_linear.calc_input(input_tensor)
|
||||
assert torch.allclose(input_parallel, expected)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"reduce_results, tp_size, expected",
|
||||
[
|
||||
(True, 2, torch.tensor(56.0)),
|
||||
(True, 1, torch.tensor(14.0)),
|
||||
(False, 2, torch.tensor(14.0)),
|
||||
],
|
||||
)
|
||||
def test_calc_output(
|
||||
self,
|
||||
reduce_results,
|
||||
tp_size,
|
||||
expected,
|
||||
mocker: MockerFixture,
|
||||
):
|
||||
quant_method = mocker.MagicMock()
|
||||
quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa
|
||||
torch.tensor([1.0, 2.0]))
|
||||
row_parallel_linear = self.init_row_parallel_linear(mocker)
|
||||
row_parallel_linear.__dict__["reduce_results"] = reduce_results
|
||||
row_parallel_linear.__dict__["tp_size"] = tp_size
|
||||
row_parallel_linear.__dict__["quant_method"] = quant_method
|
||||
row_parallel_linear.__dict__["tp_rank"] = 0
|
||||
row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa
|
||||
|
||||
mocker.patch(
|
||||
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
|
||||
return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
|
||||
)
|
||||
mocker.patch(
|
||||
"torch_npu.npu_mm_all_reduce_base",
|
||||
side_effect=lambda input_, weight, hccl_info, bias: input_.
|
||||
matmul( # noqa
|
||||
torch.tensor([4.0, 8.0])),
|
||||
) # noqa
|
||||
ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
def test_enable_allreduce_matmul(self, mocker: MockerFixture):
|
||||
mocker.patch.object(envs_ascend,
|
||||
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
|
||||
new=True)
|
||||
reload(patch_linear)
|
||||
assert envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
|
||||
assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
|
||||
patch_linear.AscendRowParallelLinear)
|
||||
@@ -1,134 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
|
||||
wrapper_rmsnorm_init)
|
||||
|
||||
|
||||
class MockRMSNorm:
|
||||
|
||||
def __init__(self, hidden_size: int, **extra_args):
|
||||
self.hidden_size = hidden_size
|
||||
self.weight = torch.ones(hidden_size)
|
||||
self.input_scale = 1.0
|
||||
self.input_offset = 0.0
|
||||
self.variance_epsilon = 1e-6
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
self.ignore_anti = extra_args.get('ignore_anti', True)
|
||||
|
||||
|
||||
class TestFuncWrapper(TestBase):
|
||||
|
||||
def test_wrapper_rmsnorm_init(self):
|
||||
|
||||
@wrapper_rmsnorm_init
|
||||
def init(self, hidden_size: int, **extra_args) -> None:
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
hidden_size = 128
|
||||
extra_args = {'arg1': 'value1'}
|
||||
|
||||
rms_norm = MockRMSNorm(hidden_size, **extra_args)
|
||||
init(rms_norm, hidden_size, **extra_args)
|
||||
|
||||
self.assertTrue(hasattr(rms_norm, 'ignore_anti'))
|
||||
self.assertTrue(rms_norm.ignore_anti)
|
||||
|
||||
self.assertTrue(hasattr(rms_norm, 'bias'))
|
||||
self.assertIsInstance(rms_norm.bias, torch.nn.Parameter)
|
||||
self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size]))
|
||||
self.assertFalse(rms_norm.bias.requires_grad)
|
||||
|
||||
@patch('torch_npu._npu_quant_rms_norm')
|
||||
def test_wrapper_rmsnorm_forward_oot_with_residual(
|
||||
self, mock_npu_quant_rms_norm):
|
||||
hidden_size = 128
|
||||
x = torch.randn(hidden_size)
|
||||
residual = torch.randn(hidden_size)
|
||||
expected_out = torch.randn(hidden_size)
|
||||
|
||||
mock_npu_quant_rms_norm.return_value = (expected_out, residual)
|
||||
|
||||
@wrapper_rmsnorm_forward_oot
|
||||
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
|
||||
return x, residual
|
||||
|
||||
rms_norm = MockRMSNorm(hidden_size)
|
||||
rms_norm.ignore_anti = False
|
||||
|
||||
output, res = forward_oot(rms_norm, x, residual)
|
||||
|
||||
mock_npu_quant_rms_norm.assert_called_once()
|
||||
|
||||
args, kwargs = mock_npu_quant_rms_norm.call_args
|
||||
self.assertTrue(torch.equal(args[1], rms_norm.weight))
|
||||
self.assertTrue(torch.equal(args[2], rms_norm.bias))
|
||||
self.assertEqual(args[3], rms_norm.input_scale)
|
||||
self.assertEqual(args[4], rms_norm.input_offset)
|
||||
self.assertEqual(args[5], rms_norm.variance_epsilon)
|
||||
self.assertTrue(torch.equal(res, residual))
|
||||
|
||||
@patch('torch_npu._npu_quant_rms_norm')
|
||||
def test_wrapper_rmsnorm_forward_oot_without_residual(
|
||||
self, mock_npu_quant_rms_norm):
|
||||
hidden_size = 128
|
||||
x = torch.randn(hidden_size)
|
||||
expected_out = torch.randn(hidden_size)
|
||||
|
||||
mock_npu_quant_rms_norm.return_value = expected_out
|
||||
|
||||
@wrapper_rmsnorm_forward_oot
|
||||
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
|
||||
return x
|
||||
|
||||
rms_norm = MockRMSNorm(hidden_size)
|
||||
rms_norm.ignore_anti = False
|
||||
|
||||
output = forward_oot(rms_norm, x)
|
||||
|
||||
mock_npu_quant_rms_norm.assert_called_once()
|
||||
|
||||
args, kwargs = mock_npu_quant_rms_norm.call_args
|
||||
self.assertTrue(torch.equal(args[0], x))
|
||||
self.assertTrue(torch.equal(args[1], rms_norm.weight))
|
||||
self.assertTrue(torch.equal(args[2], rms_norm.bias))
|
||||
self.assertEqual(args[3], rms_norm.input_scale)
|
||||
self.assertEqual(args[4], rms_norm.input_offset)
|
||||
self.assertEqual(args[5], rms_norm.variance_epsilon)
|
||||
|
||||
self.assertTrue(torch.equal(output, expected_out))
|
||||
|
||||
def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self):
|
||||
hidden_size = 128
|
||||
x = torch.randn(hidden_size)
|
||||
residual = torch.randn(hidden_size)
|
||||
|
||||
@wrapper_rmsnorm_forward_oot
|
||||
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
|
||||
return x, residual
|
||||
|
||||
rms_norm = MockRMSNorm(hidden_size)
|
||||
rms_norm.ignore_anti = True
|
||||
|
||||
output, res = forward_oot(rms_norm, x, residual)
|
||||
|
||||
self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
|
||||
self.assertTrue(torch.equal(res, residual))
|
||||
|
||||
def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self):
|
||||
hidden_size = 128
|
||||
x = torch.randn(hidden_size)
|
||||
|
||||
@wrapper_rmsnorm_forward_oot
|
||||
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
|
||||
return x
|
||||
|
||||
rms_norm = MockRMSNorm(hidden_size)
|
||||
rms_norm.ignore_anti = True
|
||||
|
||||
output = forward_oot(rms_norm, x)
|
||||
|
||||
self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
|
||||
@@ -73,9 +73,12 @@ class TestAscendQuantConfig(TestBase):
|
||||
self.assertIsNone(result)
|
||||
|
||||
def test_get_quant_method_for_linear(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.model_config.hf_config.model_type = None
|
||||
linear_layer = MagicMock(spec=LinearBase)
|
||||
# Test skipped layer
|
||||
with patch.object(self.ascend_config,
|
||||
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch.object(self.ascend_config, \
|
||||
'is_layer_skipped_ascend',
|
||||
return_value=True):
|
||||
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
|
||||
@@ -83,6 +86,7 @@ class TestAscendQuantConfig(TestBase):
|
||||
|
||||
# Test quantized layer
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
|
||||
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
|
||||
|
||||
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
|
||||
@@ -93,14 +97,18 @@ class TestAscendQuantConfig(TestBase):
|
||||
|
||||
def test_get_quant_method_for_attention(self):
|
||||
attention_layer = MagicMock(spec=Attention)
|
||||
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
|
||||
mock_config = MagicMock()
|
||||
mock_config.model_config.hf_config.model_type = None
|
||||
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
|
||||
return_value=MagicMock()) as mock_ascend_kvcache:
|
||||
# Test with fa_quant_type
|
||||
method = self.ascend_config.get_quant_method(
|
||||
attention_layer, ".attn")
|
||||
self.assertIs(method, mock_ascend_kvcache.return_value)
|
||||
|
||||
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
|
||||
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
|
||||
return_value=MagicMock()) as mock_ascend_kvcache:
|
||||
# Test with kv_quant_type
|
||||
modified_config = {"kv_quant_type": "C8"}
|
||||
@@ -113,9 +121,12 @@ class TestAscendQuantConfig(TestBase):
|
||||
fused_moe_layer = MagicMock(spec=FusedMoE)
|
||||
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
|
||||
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
|
||||
mock_config = MagicMock()
|
||||
mock_config.model_config.hf_config.model_type = None
|
||||
|
||||
# Test skipped layer
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
|
||||
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
method = self.ascend_config.get_quant_method(
|
||||
fused_moe_layer, "moe_layer")
|
||||
@@ -123,6 +134,7 @@ class TestAscendQuantConfig(TestBase):
|
||||
|
||||
# Test quantized layer
|
||||
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
|
||||
patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
|
||||
method = self.ascend_config.get_quant_method(
|
||||
fused_moe_layer, "moe_layer")
|
||||
@@ -156,33 +168,22 @@ class TestAscendKVCacheMethod(TestBase):
|
||||
def setUp(self):
|
||||
# Setup common test fixtures
|
||||
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
||||
self.mock_quant_config.quant_description = {"some_config": "value"}
|
||||
self.prefix = "attention_layer"
|
||||
self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
|
||||
self.prefix = "layer.attn"
|
||||
|
||||
# Mock the quantizer and quant_method
|
||||
self.mock_quantizer = MagicMock()
|
||||
# Mock quant_method
|
||||
self.mock_quant_method = MagicMock()
|
||||
|
||||
# Patch the AscendQuantizer
|
||||
self.quantizer_patcher = patch(
|
||||
'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
|
||||
return_value=self.mock_quantizer)
|
||||
self.mock_get_quantizer = self.quantizer_patcher.start()
|
||||
|
||||
self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method
|
||||
self.patcher = patch(
|
||||
'vllm_ascend.quantization.quant_config.get_quant_method')
|
||||
self.mock_get_quant_method = self.patcher.start()
|
||||
self.mock_get_quant_method.return_value = self.mock_quant_method
|
||||
|
||||
# Create instance
|
||||
self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
|
||||
self.prefix)
|
||||
|
||||
def tearDown(self):
|
||||
self.quantizer_patcher.stop()
|
||||
|
||||
def test_init(self):
|
||||
"""Test initialization with proper quantizer setup."""
|
||||
self.mock_get_quantizer.assert_called_once_with(
|
||||
self.mock_quant_config.quant_description, self.prefix)
|
||||
self.mock_quantizer.build_attention_method.assert_called_once()
|
||||
self.patcher.stop()
|
||||
|
||||
def test_create_weights(self):
|
||||
"""Test create_weights delegates to quant_method."""
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
||||
W4A8DYNAMICQuantizer,
|
||||
W8A8Quantizer)
|
||||
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
||||
|
||||
|
||||
class TestGetQuantizer(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
# Setup common test fixtures
|
||||
self.supported_types = {
|
||||
'INT8': MagicMock(_instance=None),
|
||||
'FP16': MagicMock(_instance=None),
|
||||
'C8': MagicMock(_instance=None)
|
||||
}
|
||||
self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)
|
||||
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
|
||||
self.mock_quant_config.quant_description = {"some_config": "value"}
|
||||
|
||||
def tearDown(self):
|
||||
# Restore original supported types
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)
|
||||
|
||||
def test_get_quantizer_fa(self):
|
||||
"""Test successful quantizer retrieval for different cases."""
|
||||
# Setup
|
||||
quant_description = {'fa_quant_type': 'C8'}
|
||||
prefix = '.attn'
|
||||
expected_type = 'C8'
|
||||
with patch.dict(
|
||||
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE):
|
||||
|
||||
result = VLLMAscendQuantizer.get_quantizer(
|
||||
quant_description,
|
||||
prefix,
|
||||
packed_modules_mapping={"some": "mapping"})
|
||||
|
||||
# Verify
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result,
|
||||
self.supported_types[expected_type]._instance)
|
||||
self.supported_types[expected_type].assert_called_once_with(
|
||||
quant_description)
|
||||
|
||||
def test_get_quantizer_kv(self):
|
||||
"""Test successful quantizer retrieval for different cases."""
|
||||
# Setup
|
||||
quant_description = {'kv_quant_type': 'C8'}
|
||||
prefix = '.attn'
|
||||
expected_type = 'C8'
|
||||
with patch.dict(
|
||||
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
|
||||
SUPPORT_ASCEND_QUANTIZER_TYPE):
|
||||
|
||||
result = VLLMAscendQuantizer.get_quantizer(
|
||||
quant_description,
|
||||
prefix,
|
||||
packed_modules_mapping={"some": "mapping"})
|
||||
|
||||
# Verify
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result,
|
||||
self.supported_types[expected_type]._instance)
|
||||
self.supported_types[expected_type].assert_called_once_with(
|
||||
quant_description)
|
||||
|
||||
def test_get_quantizer_linear(self):
|
||||
"""Test successful quantizer retrieval for different cases."""
|
||||
# Setup
|
||||
quant_description = {'linear_type': 'INT8'}
|
||||
prefix = 'nothing'
|
||||
expected_type = 'INT8'
|
||||
with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
|
||||
return_value=expected_type), \
|
||||
patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE):
|
||||
|
||||
result = VLLMAscendQuantizer.get_quantizer(
|
||||
quant_description,
|
||||
prefix,
|
||||
packed_modules_mapping={"some": "mapping"})
|
||||
|
||||
# Verify
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result,
|
||||
self.supported_types[expected_type]._instance)
|
||||
self.supported_types[expected_type].assert_called_once_with(
|
||||
quant_description)
|
||||
|
||||
|
||||
class TestW8A8Quantizer(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.quantizer = W8A8Quantizer(quant_description={})
|
||||
|
||||
def test_build_linear_method(self):
|
||||
with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
|
||||
return_value=MagicMock()) as mock_linear:
|
||||
result = self.quantizer.build_linear_method()
|
||||
mock_linear.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
def test_build_moe_method(self):
|
||||
with patch(
|
||||
'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
|
||||
return_value=MagicMock()) as mock_linear:
|
||||
result = self.quantizer.build_moe_method()
|
||||
mock_linear.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
def test_build_attention_method(self):
|
||||
with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
|
||||
return_value=MagicMock()) as mock_linear:
|
||||
result = self.quantizer.build_attention_method()
|
||||
mock_linear.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
|
||||
class TestW4A8DYNAMICQuantizer(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.quantizer = W4A8DYNAMICQuantizer(quant_description={})
|
||||
|
||||
def test_build_linear_method(self):
|
||||
with patch(
|
||||
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod',
|
||||
return_value=MagicMock()) as mock_linear:
|
||||
result = self.quantizer.build_linear_method()
|
||||
mock_linear.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
|
||||
def test_build_moe_method(self):
|
||||
with patch(
|
||||
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod',
|
||||
return_value=MagicMock()) as mock_fused_moe:
|
||||
result = self.quantizer.build_moe_method()
|
||||
mock_fused_moe.assert_called_once_with()
|
||||
self.assertIsInstance(result, MagicMock)
|
||||
62
tests/ut/quantization/test_utils.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import types
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.utils import (ASCEND_QUANTIZATION_METHOD_MAP,
|
||||
get_quant_method)
|
||||
|
||||
|
||||
class TestGetQuantMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.original_quantization_method_map = ASCEND_QUANTIZATION_METHOD_MAP.copy(
|
||||
)
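        # Each registered implementation class is swapped below for a
        # synthetic placeholder built with types.new_class(), so the
        # isinstance() checks in the tests only exercise the
        # quant-type/layer-type dispatch and never import real NPU method
        # implementations.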
|
||||
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||
for layer_type in layer_map.keys():
|
||||
ASCEND_QUANTIZATION_METHOD_MAP[quant_type][
|
||||
layer_type] = types.new_class(f"{quant_type}_{layer_type}")
|
||||
|
||||
def tearDown(self):
|
||||
# Restore original map
|
||||
ASCEND_QUANTIZATION_METHOD_MAP.clear()
|
||||
ASCEND_QUANTIZATION_METHOD_MAP.update(
|
||||
self.original_quantization_method_map)
|
||||
|
||||
def test_linear_quant_methods(self):
|
||||
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||
if "linear" in layer_map.keys():
|
||||
prefix = "linear_layer"
|
||||
cls = layer_map["linear"]
|
||||
method = get_quant_method({"linear_layer.weight": quant_type},
|
||||
prefix, "linear")
|
||||
self.assertIsInstance(method, cls)
|
||||
|
||||
def test_moe_quant_methods(self):
|
||||
for quant_type, layer_map in ASCEND_QUANTIZATION_METHOD_MAP.items():
|
||||
if "moe" in layer_map.keys():
|
||||
prefix = "layer"
|
||||
cls = layer_map["moe"]
|
||||
method = get_quant_method({"layer.weight": quant_type}, prefix,
|
||||
"moe")
|
||||
self.assertIsInstance(method, cls)
|
||||
|
||||
def test_with_fa_quant_type(self):
|
||||
quant_description = {"fa_quant_type": "C8"}
|
||||
method = get_quant_method(quant_description, ".attn", "attention")
|
||||
self.assertIsInstance(
|
||||
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
|
||||
|
||||
def test_with_kv_quant_type(self):
|
||||
quant_description = {"kv_quant_type": "C8"}
|
||||
method = get_quant_method(quant_description, ".attn", "attention")
|
||||
self.assertIsInstance(
|
||||
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
|
||||
|
||||
def test_invalid_layer_type(self):
|
||||
quant_description = {"linear_layer.weight": "W8A8"}
|
||||
with self.assertRaises(NotImplementedError):
|
||||
get_quant_method(quant_description, "linear_layer", "unsupported")
|
||||
|
||||
def test_invalid_quant_type(self):
|
||||
quant_description = {"linear_layer.weight": "UNKNOWN"}
|
||||
with self.assertRaises(NotImplementedError):
|
||||
get_quant_method(quant_description, "linear_layer", "linear")
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
@@ -11,8 +10,19 @@ from vllm_ascend.quantization.w4a8_dynamic import (
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
|
||||
|
||||
def setUp(self):
|
||||
self.method = AscendW4A8DynamicLinearMethod()
|
||||
self.method.group_size = 8
|
||||
with patch(
|
||||
'vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config'
|
||||
) as mock_get_current_vllm_config:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(
|
||||
quant_description={"group_size": 256})
|
||||
mock_vllm_config.scheduler_config = Mock(
|
||||
max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
enable_chunked_prefill=False)
|
||||
mock_get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.method = AscendW4A8DynamicLinearMethod()
|
||||
self.method.group_size = 8
|
||||
|
||||
def test_get_weight(self):
|
||||
weight = self.method.get_weight(8, 32, torch.bfloat16)
|
||||
@@ -37,18 +47,27 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
output_size = 56
|
||||
group_size = 2
|
||||
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ascend_config')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
|
||||
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
|
||||
get_current_vllm_config):
|
||||
get_current_vllm_config, mock_get_ascend_config):
|
||||
# Mock ascend config
|
||||
mock_ascend_config = Mock()
|
||||
mock_ascend_config.dynamic_eplb = False
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(quant_description={
|
||||
"group_size": self.group_size,
|
||||
"version": "0.0.0"
|
||||
})
|
||||
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
enable_chunked_prefill=False)
|
||||
get_current_vllm_config.return_value = mock_vllm_config
|
||||
self.quant_method = AscendW4A8DynamicFusedMoEMethod()
|
||||
|
||||
@@ -75,19 +94,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
# old quant version weight
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].shape,
|
||||
(self.experts, 2 * self.input_size, 1))
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size))
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale"].shape,
|
||||
(self.experts, self.output_size, 1))
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
|
||||
torch.bfloat16)
|
||||
torch.float32)
|
||||
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size))
|
||||
@@ -99,40 +118,87 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
self.assertEqual(
|
||||
param_dict["w2_scale_bias"].shape,
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.experts, self.input_size, self.output_size, torch.bfloat16)
|
||||
pergroup_param = [
|
||||
"w13_weight_scale_second", "w13_weight_offset_second",
|
||||
"w2_weight_scale_second", "w2_weight_offset_second"
|
||||
]
|
||||
is_contains = any(key in param_dict for key in pergroup_param)
|
||||
self.assertFalse(is_contains)
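    # The helper below fabricates a minimal FusedMoE-like module for both
    # weight layouts: with is_new_quant_version the w13/w2 weights are
    # half-sized along one dimension (int4 values presumably packed two per
    # int8) and the scale-bias tensors are pre-populated, while the old
    # layout keeps full-sized int8 weights plus per-group "second"
    # scales/offsets.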
|
||||
|
||||
def build_layer(self,
|
||||
is_new_quant_version=True,
|
||||
is_per_channel_weight=False):
|
||||
layer = torch.nn.Module()
|
||||
if is_new_quant_version:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
w13_scale_bias = torch.zeros(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
|
||||
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros((self.experts, self.output_size,
|
||||
16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
if not is_per_channel_weight:
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(
|
||||
torch.ones((self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_offset_second = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight_scale_second.data),
|
||||
requires_grad=False)
|
||||
return layer
|
||||
|
||||
@patch('torch_npu.npu_format_cast')
|
||||
@patch('torch_npu.npu_quantize')
|
||||
@patch('torch.Tensor.npu')
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
|
||||
# old quant version weight
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, 2 * self.input_size, self.output_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(torch.zeros(
|
||||
(self.experts, self.output_size, self.input_size),
|
||||
dtype=torch.int8),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, 2 * self.input_size,
|
||||
self.output_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||
(self.experts, self.output_size,
|
||||
self.input_size // self.group_size),
|
||||
dtype=torch.bfloat16),
|
||||
requires_grad=False)
|
||||
new_layer = copy.deepcopy(layer)
|
||||
|
||||
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize,
|
||||
mock_npu_format_cast):
|
||||
mock_npu.return_value = torch.Tensor()
|
||||
mock_npu_quantize.return_value = torch.Tensor()
|
||||
|
||||
def func_by_args(weight, num_format):
|
||||
return weight
|
||||
|
||||
mock_npu_format_cast.side_effect = func_by_args
|
||||
# old quant version weight
|
||||
layer = self.build_layer(is_new_quant_version=False)
|
||||
self.quant_method.process_weights_after_loading(layer)
|
||||
self.assertTrue(hasattr(layer, "w13_scale_bias"))
|
||||
self.assertEqual(layer.w13_scale_bias.data.shape,
|
||||
@@ -144,23 +210,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
|
||||
# new quant version weight
|
||||
self.quant_method.new_quant_version = True
|
||||
new_layer.w13_weight.data = torch.zeros(
|
||||
(self.experts, self.input_size, self.output_size),
|
||||
dtype=torch.int8)
|
||||
new_layer.w2_weight.data = torch.zeros(
|
||||
(self.experts, self.output_size // 2, self.input_size),
|
||||
dtype=torch.int8)
|
||||
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
|
||||
dtype=torch.float32)
|
||||
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
|
||||
requires_grad=False)
|
||||
w2_scale_bias = torch.zeros(
|
||||
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
|
||||
dtype=torch.float32)
|
||||
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
|
||||
requires_grad=False)
|
||||
new_layer = self.build_layer(is_new_quant_version=True)
|
||||
self.quant_method.process_weights_after_loading(new_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
self.assertEqual(new_layer.w2_scale_bias.data.shape,
|
||||
(self.experts, self.output_size))
|
||||
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
|
||||
# per-channel weight
|
||||
self.quant_method.is_per_channel_weight = True
|
||||
per_channel_layer = self.build_layer(is_new_quant_version=True,
|
||||
is_per_channel_weight=True)
|
||||
self.quant_method.process_weights_after_loading(per_channel_layer)
|
||||
self.assertEqual(new_layer.w13_scale_bias.data.shape,
|
||||
(self.experts, 2 * self.input_size))
|
||||
|
||||
@@ -5,8 +5,8 @@ import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.ops.moe.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
@@ -784,7 +784,7 @@ class TestSelectExperts(TestBase):
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
|
||||
@patch('vllm_ascend.ops.moe.experts_selector._native_grouped_topk')
|
||||
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
|
||||
"""Test grouped topk with expert score correction bias"""
|
||||
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
|
||||
|
||||
69
tests/ut/quantization/test_w8a8_dynamic.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w8a8_dynamic import \
|
||||
AscendW8A8DynamicFusedMoEMethod
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
num_experts = 8
|
||||
hidden_size = 128
|
||||
intermediate_size = 128
|
||||
|
||||
@patch("torch.distributed.get_rank")
|
||||
@patch("vllm_ascend.quantization.w8a8_dynamic.get_mc2_group")
|
||||
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ascend_config")
|
||||
@patch("vllm_ascend.quantization.w8a8_dynamic.get_ep_group")
|
||||
def setUp(self, mock_get_ep_group, mock_get_ascend_config,
|
||||
mock_get_mc2_group, mock_get_rank):
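        # get_current_vllm_config is patched only around construction: the
        # quant method is expected to read scheduler and quant settings from
        # the current vllm config while it is instantiated, so the mock has
        # to be active before the constructor runs.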
|
||||
with patch(
|
||||
'vllm_ascend.quantization.w8a8_dynamic.get_current_vllm_config'
|
||||
) as mock_get_current_vllm_config:
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.quant_config = Mock(
|
||||
quant_description={"group_size": 256})
|
||||
mock_vllm_config.scheduler_config = Mock(
|
||||
max_num_batched_tokens=2048,
|
||||
max_model_len=2048,
|
||||
enable_chunked_prefill=False)
|
||||
mock_get_current_vllm_config.return_value = mock_vllm_config
|
||||
mock_ep_group = Mock()
|
||||
mock_get_ep_group.return_value = mock_ep_group
|
||||
mock_ascend_config = Mock()
|
||||
|
||||
            # Create a Mock object with concrete attributes to represent ascend_scheduler_config
|
||||
mock_ascend_scheduler_config = Mock()
|
||||
mock_ascend_scheduler_config.enabled = False
|
||||
mock_ascend_scheduler_config.max_num_batched_tokens = 1024
|
||||
mock_ascend_scheduler_config.max_model_len = 2048
|
||||
mock_ascend_config.ascend_scheduler_config = mock_ascend_scheduler_config
|
||||
|
||||
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
|
||||
mock_ascend_config.enable_chunked_prefill = False
|
||||
mock_get_ascend_config.return_value = mock_ascend_config
|
||||
mock_mc2_group = Mock(device_group=0)
|
||||
mock_get_mc2_group.return_value = mock_mc2_group
|
||||
mock_rank = Mock()
|
||||
mock_get_rank.return_value = mock_rank
|
||||
|
||||
self.quant_method = AscendW8A8DynamicFusedMoEMethod()
|
||||
|
||||
def test_get_weight(self):
|
||||
param_dict = self.quant_method.get_weight(self.num_experts,
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||
self.assertEqual(
|
||||
param_dict["w13_weight"].shape,
|
||||
(self.num_experts, 2 * self.intermediate_size, self.hidden_size))
|
||||
|
||||
def test_get_dynamic_quant_param(self):
|
||||
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||
self.num_experts, self.intermediate_size, self.hidden_size,
|
||||
torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
|
||||
self.assertEqual(param_dict["w13_weight_scale"].shape,
|
||||
(self.num_experts, 2 * self.intermediate_size, 1))
|
||||
40
tests/ut/sample/logits_processor/test_builtin.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.config import SchedulerConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor
|
||||
|
||||
|
||||
class TestMinPLogitsProcessorInitFunc(PytestBase):
|
||||
|
||||
def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
|
||||
device_cpu = torch.device("cpu")
|
||||
device_npu = torch.device("npu")
|
||||
is_pin_memory = False
|
||||
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
|
||||
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
|
||||
mock_scheduler_config.decode_max_num_seqs = 0
|
||||
mock_scheduler_config.max_num_seqs = 128
|
||||
mock_vllm_config.scheduler_config = mock_scheduler_config
|
||||
        # torch.zeros/torch.empty raise errors on the online UT machine, so mock them
|
||||
mock_tensor = torch.zeros((256, ),
|
||||
dtype=torch.float32,
|
||||
pin_memory=False)
|
||||
mocker.patch("torch.zeros", return_value=mock_tensor)
|
||||
mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
|
||||
mocker.patch("torch.empty", return_value=mock_empty_tensor)
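        # With decode_max_num_seqs == 0 the mocked scheduler config
        # (max_num_seqs=128) is expected to yield a 256-slot min_p_cpu buffer
        # (presumably 2 * max_num_seqs), and use_double_tensor should only be
        # enabled when the target device is an NPU.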
|
||||
|
||||
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
|
||||
is_pin_memory)
|
||||
|
||||
assert processor_cpu.min_p is not None
|
||||
assert processor_cpu.use_double_tensor is False
|
||||
assert processor_cpu.min_p_cpu.shape[0] == 256
|
||||
|
||||
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
|
||||
is_pin_memory)
|
||||
|
||||
assert processor_cpu.min_p is not None
|
||||
assert processor_cpu.use_double_tensor is True
|
||||
assert processor_cpu.min_p_cpu.shape[0] == 256
|
||||
@@ -43,6 +43,7 @@ class TestAscendConfig(TestBase):
|
||||
# No additional config given, check the default value here.
|
||||
ascend_config = init_ascend_config(test_vllm_config)
|
||||
self.assertIsNone(ascend_config.expert_map_path)
|
||||
self.assertFalse(ascend_config.multistream_overlap_shared_expert)
|
||||
|
||||
torchair_graph_config = ascend_config.torchair_graph_config
|
||||
self.assertFalse(torchair_graph_config.enabled)
|
||||
@@ -51,8 +52,8 @@ class TestAscendConfig(TestBase):
|
||||
self.assertEqual(torchair_graph_config.graph_batch_sizes, [])
|
||||
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
|
||||
self.assertFalse(torchair_graph_config.enable_multistream_mla)
|
||||
self.assertFalse(torchair_graph_config.enable_multistream_moe)
|
||||
self.assertTrue(torchair_graph_config.enable_view_optimize)
|
||||
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
|
||||
self.assertFalse(torchair_graph_config.enable_kv_nz)
|
||||
|
||||
ascend_scheduler_config = ascend_config.ascend_scheduler_config
|
||||
@@ -68,10 +69,11 @@ class TestAscendConfig(TestBase):
|
||||
"graph_batch_sizes": [1, 2, 4],
|
||||
"graph_batch_sizes_init": False,
|
||||
"enable_multistream_mla": True,
|
||||
"enable_multistream_moe": True,
|
||||
"enable_view_optimize": True,
|
||||
"enable_frozen_parameter": True,
|
||||
"enable_kv_nz": True
|
||||
},
|
||||
"multistream_overlap_shared_expert": True,
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": True
|
||||
},
|
||||
@@ -80,6 +82,7 @@ class TestAscendConfig(TestBase):
|
||||
}
|
||||
ascend_config = init_ascend_config(test_vllm_config)
|
||||
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
|
||||
self.assertTrue(ascend_config.multistream_overlap_shared_expert)
|
||||
|
||||
torchair_graph_config = ascend_config.torchair_graph_config
|
||||
self.assertTrue(torchair_graph_config.enabled)
|
||||
@@ -87,8 +90,8 @@ class TestAscendConfig(TestBase):
|
||||
self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4])
|
||||
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
|
||||
self.assertTrue(torchair_graph_config.enable_multistream_mla)
|
||||
self.assertTrue(torchair_graph_config.enable_multistream_moe)
|
||||
self.assertTrue(torchair_graph_config.enable_view_optimize)
|
||||
self.assertTrue(torchair_graph_config.enable_frozen_parameter)
|
||||
self.assertTrue(torchair_graph_config.enable_kv_nz)
|
||||
|
||||
ascend_scheduler_config = ascend_config.ascend_scheduler_config
|
||||
@@ -215,21 +218,6 @@ class TestAscendConfig(TestBase):
|
||||
test_vllm_config.model_config = fake_model_config
|
||||
init_ascend_config(test_vllm_config)
|
||||
check_ascend_config(test_vllm_config, False)
|
||||
# aclgraph + deepseek model
|
||||
with self.assertRaises(NotImplementedError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": False,
|
||||
},
|
||||
"refresh": True
|
||||
}
|
||||
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
|
||||
fake_model_config = ModelConfig(model=model_path)
|
||||
fake_model_config.hf_config = PretrainedConfig()
|
||||
fake_model_config.hf_config.model_type = "deepseek"
|
||||
test_vllm_config.model_config = fake_model_config
|
||||
init_ascend_config(test_vllm_config)
|
||||
check_ascend_config(test_vllm_config, False)
|
||||
|
||||
def test_check_torchair_supported(self):
|
||||
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
|
||||
@@ -318,17 +306,6 @@ class TestAscendConfig(TestBase):
|
||||
}
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
# enable_multistream_moe should not be enabled without torchair graph mode
|
||||
with self.assertRaises(RuntimeError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": False,
|
||||
"enable_multistream_moe": True,
|
||||
},
|
||||
"refresh": True
|
||||
}
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
# mode should not be configured without torchair graph mode
|
||||
with self.assertRaises(RuntimeError):
|
||||
test_vllm_config.additional_config = {
|
||||
@@ -359,3 +336,27 @@ class TestAscendConfig(TestBase):
|
||||
test_vllm_config.parallel_config = ParallelConfig(
|
||||
data_parallel_size=4, tensor_parallel_size=2)
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
},
|
||||
"oproj_tensor_parallel_size": 2,
|
||||
"refresh": True
|
||||
}
|
||||
test_vllm_config.parallel_config = ParallelConfig(
|
||||
data_parallel_size=4, tensor_parallel_size=2)
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
test_vllm_config.additional_config = {
|
||||
"torchair_graph_config": {
|
||||
"enabled": False,
|
||||
},
|
||||
"oproj_tensor_parallel_size": 2,
|
||||
"refresh": True
|
||||
}
|
||||
test_vllm_config.parallel_config = ParallelConfig(
|
||||
data_parallel_size=4, tensor_parallel_size=1)
|
||||
init_ascend_config(test_vllm_config)
|
||||
|
||||
@@ -36,6 +36,7 @@ class TestNPUPlatform(TestBase):
|
||||
mock_ascend_config = MagicMock()
|
||||
mock_ascend_config.torchair_graph_config.enabled = False
|
||||
mock_ascend_config.ascend_scheduler_config.enabled = False
|
||||
mock_ascend_config.enable_shared_expert_dp = False
|
||||
return mock_ascend_config
|
||||
|
||||
def setUp(self):
|
||||
@@ -363,36 +364,6 @@ class TestNPUPlatform(TestBase):
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
def test_check_and_update_config_disable_aclgraph_when_ray_enabled(
|
||||
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
|
||||
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
|
||||
)
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.model_config.enforce_eager = False
|
||||
vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
|
||||
vllm_config.parallel_config.distributed_executor_backend = "ray"
|
||||
|
||||
with self.assertLogs(logger="vllm", level="WARNING") as cm:
|
||||
from vllm_ascend import platform
|
||||
|
||||
importlib.reload(platform)
|
||||
self.platform.check_and_update_config(vllm_config)
|
||||
print(30 * "=", f"cm.output: {cm.output}")
|
||||
self.assertTrue(
|
||||
"Ray distributed executor backend is not compatible with ACL Graph mode"
|
||||
in cm.output[0])
|
||||
self.assertEqual(
|
||||
vllm_config.compilation_config.level,
|
||||
CompilationLevel.NO_COMPILATION,
|
||||
)
|
||||
self.assertEqual(
|
||||
vllm_config.compilation_config.cudagraph_mode,
|
||||
CUDAGraphMode.NONE,
|
||||
)
|
||||
|
||||
@patch("vllm_ascend.utils.is_310p", return_value=False)
|
||||
@patch("vllm_ascend.ascend_config.check_ascend_config")
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
@@ -509,6 +480,7 @@ class TestNPUPlatform(TestBase):
|
||||
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
|
||||
mock_config = MagicMock()
|
||||
mock_config.torchair_graph_config.enabled = False
|
||||
mock_config.enable_shared_expert_dp = False
|
||||
|
||||
mock_get_ascend_config.return_value = mock_config
|
||||
|
||||
@@ -589,9 +561,8 @@ class TestNPUPlatform(TestBase):
|
||||
|
||||
def test_get_punica_wrapper(self):
|
||||
result = self.platform.get_punica_wrapper()
|
||||
self.assertEqual(
|
||||
result,
|
||||
"vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")
|
||||
self.assertEqual(result,
|
||||
"vllm_ascend.lora.punica_npu.PunicaWrapperNPU")
|
||||
|
||||
@patch("torch.npu.reset_peak_memory_stats")
|
||||
@patch("torch.npu.max_memory_allocated")
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig,
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend import utils
|
||||
from vllm_ascend.utils import REGISTERED_ASCEND_OPS
|
||||
|
||||
|
||||
class TestUtils(TestBase):
|
||||
@@ -259,8 +260,22 @@ class TestUtils(TestBase):
|
||||
utils.update_aclgraph_sizes(test_vllm_config)
|
||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||
self.assertEqual(
|
||||
147,
|
||||
137,
|
||||
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
|
||||
test_vllm_config.speculative_config = mock.MagicMock()
|
||||
test_vllm_config.speculative_config.draft_model_config = mock.MagicMock(
|
||||
)
|
||||
test_vllm_config.speculative_config.draft_model_config.hf_config = mock.MagicMock(
|
||||
)
|
||||
test_vllm_config.speculative_config.draft_model_config.hf_config.num_hidden_layers = 2
|
||||
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
|
||||
utils.update_aclgraph_sizes(test_vllm_config)
|
||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||
self.assertEqual(
|
||||
111,
|
||||
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
|
||||
# max_num_batch_sizes >= len(original_sizes)
|
||||
test_compilation_config = CompilationConfig(
|
||||
cudagraph_capture_sizes=[1, 2, 3])
|
||||
@@ -288,14 +303,14 @@ class TestUtils(TestBase):
|
||||
|
||||
# ascend custom op is not registered
|
||||
utils.register_ascend_customop()
|
||||
# should call register_oot three
|
||||
self.assertEqual(mock_customop.register_oot.call_count, 12)
|
||||
self.assertEqual(mock_customop.register_oot.call_count,
|
||||
len(REGISTERED_ASCEND_OPS))
|
||||
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
|
||||
|
||||
# ascend custom op is already registered
|
||||
utils.register_ascend_customop()
|
||||
# should not register_oot again, thus only called three in this ut
|
||||
self.assertEqual(mock_customop.register_oot.call_count, 12)
|
||||
self.assertEqual(mock_customop.register_oot.call_count,
|
||||
len(REGISTERED_ASCEND_OPS))
|
||||
|
||||
|
||||
class TestProfileExecuteDuration(TestBase):
|
||||
|
||||
@@ -165,8 +165,6 @@ class TestTorchairDeepSeekMTP(PytestBase):
|
||||
mocker.patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
|
||||
return_value=None)
|
||||
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
|
||||
return_value=None)
|
||||
mocker.patch(
|
||||
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
|
||||
return_value=None)
|
||||
|
||||
@@ -100,6 +100,11 @@ def mock_distributed():
|
||||
pp_group.rank_in_group = 0
|
||||
pp_group.world_size = 1
|
||||
|
||||
mlp_tp_group = Mock(spec=GroupCoordinator)
|
||||
mlp_tp_group.rank_in_group = 0
|
||||
mlp_tp_group.world_size = 1
|
||||
mlp_tp_group.all_gather = Mock(return_value=torch.randn(2, 4, 128))
|
||||
|
||||
mock_vllm_config = Mock()
|
||||
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
|
||||
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
|
||||
@@ -196,10 +201,6 @@ def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
|
||||
quant_config=None)
|
||||
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
|
||||
|
||||
x = torch.randn(2, 4, 128)
|
||||
output = mlp(x)
|
||||
assert output.shape == (2, 4, 128)
|
||||
|
||||
with patch(
|
||||
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
|
||||
) as mock_quant_config:
|
||||
@@ -274,7 +275,12 @@ def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
|
||||
|
||||
@patch("torch_npu.npu_add_rms_norm")
|
||||
@patch("torch_npu.npu_rms_norm")
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
|
||||
@patch("torch.ops.vllm.maybe_wait_prefetch_done", side_effect=lambda x: None)
|
||||
@patch("torch.ops.vllm.maybe_chunk_residual",
|
||||
side_effect=lambda x, residual: residual)
|
||||
def test_torchair_deepseek_v2_decoder_layer(mock_maybe_chunk_residual,
|
||||
mock_maybe_wait_prefetch_done,
|
||||
mock_rms_norm, mock_add_norm,
|
||||
mock_distributed, base_config,
|
||||
vllm_config):
|
||||
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
|
||||
|
||||
@@ -24,10 +24,10 @@ from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.quantization.quantizer import W8A8Quantizer
from vllm_ascend.torchair.ops.torchair_fused_moe import (
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
from vllm_ascend.utils import adapt_patch # noqa E402
from vllm_ascend.utils import AscendSocVersion, vllm_version_is

adapt_patch(True)

@@ -54,6 +54,10 @@ def mock_dp_and_tp_group(mocker):
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
# init dist env patch
if vllm_version_is("0.10.2"):
dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10])
else:
dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])

with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
@@ -67,13 +71,13 @@ def mock_dist_env(mocker: MockerFixture):
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce',
return_value=torch.randn(5, 32)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.data_parallel_reduce_scatter',
return_value=torch.randn(5, 32)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
torchair_graph_config=MagicMock(enabled=False),
enable_multistream_moe=False,
enable_shared_expert_dp=False,
expert_map_path=None
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
@@ -81,7 +85,7 @@ def mock_dist_env(mocker: MockerFixture):
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
return_value=MagicMock(
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
dp_metadata=dp_metadata,
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
return_value=MagicMock(
@@ -154,6 +158,8 @@ def default_moe_config():
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
moe.moe_parallel_config.use_ep = False
moe.moe_parallel_config.dp_size = 1
return TorchairAscendUnquantizedFusedMoEMethod(moe)


@@ -199,6 +205,9 @@ class MockFusedMoEMethod(FusedMoEMethodBase):
expert_weights: torch.Tensor) -> torch.Tensor:
pass

def get_fused_moe_quant_config(self, layer: torch.nn.Module):
pass


class TestTorchairAscendFusedMoe:

@@ -236,12 +245,9 @@ class TestTorchairAscendFusedMoe:
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = False
with patch(
'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer',
return_value=W8A8Quantizer):
with patch("vllm_ascend.quantization.quant_config.get_quant_method"):
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)

assert moe.quant_method is not None
assert isinstance(moe.quant_method, AscendFusedMoEMethod)


@@ -5,8 +5,9 @@ import torch

from tests.ut.base import TestBase
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
custom_rotary_embedding_enabled, native_rope_deepseek_forward,
rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale)
_set_cos_sin_cache, custom_rotary_embedding_enabled,
native_rope_deepseek_forward, rope_forward_oot, rotate_half,
yarn_find_correction_dim, yarn_get_mscale)


class TestCustomRotaryEmbeddingEnabled(TestBase):
@@ -103,7 +104,7 @@ class TestRopeForwardOot(TestBase):
self.assertTrue(torch.equal(result_q, self.query))
self.assertTrue(torch.equal(result_k, self.key))

@patch('torch.ops._C')
@patch('torch.ops._C_ascend')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',
@@ -200,6 +201,28 @@ class MockRopeModule:
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
self.beta_fast = 32
self.beta_slow = 1
self.max_position_embeddings = 4096
self.mscale = 1.0
self.scaling_factor = 40

def register_buffer(self):
pass


class TestSetSinCosCache(TestBase):

def test_set_cos_sin_cache(self):
module = MockRopeModule()

with patch.object(module, "register_buffer") as mock_register_buffer:
_set_cos_sin_cache(module,
1024,
device="cpu",
dtype=torch.bfloat16)

mock_register_buffer.assert_called()


class TestNativeRopeDeepseekForward(TestBase):
@@ -220,30 +243,6 @@ class TestNativeRopeDeepseekForward(TestBase):
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape

@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache'
)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_cache_handling(
self, mock_rope_forward_oot, mock_set_cache):
# Test cache situation is true
module = MockRopeModule(max_seq_len=1024)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)

mock_rope_forward_oot.return_value = (query, key)

q_pe, k_pe = native_rope_deepseek_forward(module,
positions,
query,
key,
max_seq_len=2048)

assert q_pe.shape == query.shape
assert k_pe.shape == key.shape

@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_key_reshaping(

@@ -1,4 +1,3 @@
import copy
from unittest.mock import Mock, patch

import torch
@@ -85,19 +84,19 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
# old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.bfloat16)
torch.float32)
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(self.experts, 2 * self.input_size,
self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.bfloat16)
torch.float32)
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(self.experts, self.output_size,
self.input_size // self.group_size))
@@ -109,40 +108,80 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
self.assertEqual(
param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
# per-channel weight
self.quant_method.is_per_channel_weight = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
pergroup_param = [
"w13_weight_scale_second", "w13_weight_offset_second",
"w2_weight_scale_second", "w2_weight_offset_second"
]
is_contains = any(key in param_dict for key in pergroup_param)
self.assertFalse(is_contains)

def build_layer(self,
is_new_quant_version=True,
is_per_channel_weight=False):
layer = torch.nn.Module()
if is_new_quant_version:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8),
requires_grad=False)
w13_scale_bias = torch.zeros(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32)
layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros((self.experts, self.output_size,
16 // self.quant_method.tp_size),
dtype=torch.float32)
layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
else:
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size, self.input_size),
dtype=torch.int8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.float32),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.float32),
requires_grad=False)
if not is_per_channel_weight:
layer.w13_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w13_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w13_weight_scale_second.data),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(
torch.ones((self.experts, self.output_size,
self.input_size // self.group_size),
dtype=torch.float32),
requires_grad=False)
layer.w2_weight_offset_second = torch.nn.Parameter(
torch.empty_like(layer.w2_weight_scale_second.data),
requires_grad=False)
return layer

@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
# old quant version weight
layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size, self.input_size),
dtype=torch.int8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size,
self.input_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
new_layer = copy.deepcopy(layer)

mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor()
# old quant version weight
layer = self.build_layer(is_new_quant_version=False)
self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape,
@@ -154,23 +193,17 @@ class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.quant_method.new_quant_version = True
new_layer.w13_weight.data = torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8)
new_layer.w2_weight.data = torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8)
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros(
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
dtype=torch.float32)
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
new_layer = self.build_layer(is_new_quant_version=True)
self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertFalse(hasattr(new_layer, "w13_weight_scale_second"))
# per-channel weight
self.quant_method.is_per_channel_weight = True
per_channel_layer = self.build_layer(is_new_quant_version=True,
is_per_channel_weight=True)
self.quant_method.process_weights_after_loading(per_channel_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))

tests/ut/torchair/test_torchair_attention.py
@@ -0,0 +1,95 @@
from unittest.mock import MagicMock, patch

import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import GroupCoordinator

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.torchair.torchair_attention import \
AscendAttentionTorchairBackendImpl


class TestAscendAttentionTorchairBackendImpl(TestBase):

@patch("torch.zeros")
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator)) # TODO
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2) # TODO
@patch("vllm.config.get_current_vllm_config") # TODO
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config") # TODO
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp,
mock_zeros):
mock_tp.world_size = 2 # TODO
ascend_config.torchair_graph_config.enabled = True # TODO
ascend_config.torchair_graph_config.enable_kv_nz = False # TODO
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config

num_heads = 32
head_size = 128 # TODO
scale = 0.1 # TODO
num_kv_heads = 4
kv_cache_dtype = "auto"
attn_type = AttentionType.DECODER
mock_zeros.return_value = torch.ones((),
device='cpu',
dtype=torch.int32)

self.impl = AscendAttentionTorchairBackendImpl(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=attn_type,
kv_sharing_target_layer_name=None)

@patch("torch_npu.npu_scatter_nd_update_")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_with_decode_only(self, mock_fused, _):
layer = MagicMock()
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0

seq_len = 1
num_tokens = 100
num_blocks = 256
block_size = 4

query = torch.randn(num_tokens, seq_len,
self.impl.num_heads * self.impl.head_size)
key = torch.randn(num_tokens, seq_len,
self.impl.num_kv_heads * self.impl.head_size)
value = torch.randn(num_tokens, seq_len,
self.impl.num_kv_heads * self.impl.head_size)
kv_cache = (torch.randn(num_blocks, block_size,
self.impl.num_heads * self.impl.head_size),
torch.randn(num_blocks, block_size,
self.impl.num_heads * self.impl.head_size))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.head_size)

decode = MagicMock() # TODO
decode.seq_lens_list = [2] * num_tokens
decode.block_table = torch.ones(num_tokens, 8, dtype=torch.int32)
decode.attn_mask = None

metadata = MagicMock()
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.slot_mapping = torch.arange(num_tokens, dtype=torch.int32)
metadata.decode = decode

mock_fused.return_value = (torch.ones(num_tokens, self.impl.num_heads,
self.impl.head_size),
torch.ones(1))

result = self.impl.forward(layer, query, key, value, kv_cache,
metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)
@@ -190,12 +190,15 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

mock_vllm_config.speculative_config = None

ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

self.assertEqual(builder.block_size,
@@ -216,7 +219,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_vllm_config.speculative_config = None

builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

input_batch = MagicMock()
@@ -250,9 +256,12 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

mock_vllm_config.speculative_config = None

with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

input_batch = MagicMock()
@@ -285,7 +294,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_vllm_config.speculative_config = None

builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

@@ -305,7 +317,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_vllm_config.speculative_config = None

builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

@@ -326,7 +341,10 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'

builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_vllm_config.speculative_config = None

builder = AscendMLATorchairMetadataBuilder(None, None,
mock_vllm_config,
mock_device)

block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
@@ -351,7 +369,11 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
mock_vllm_config.model_config.dtype = torch.float16
mock_device = 'cpu'

mock_vllm_config.speculative_config = None

builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
@@ -416,7 +438,11 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
model = MagicMock(spec=nn.Module)
model.model = MagicMock(spec=nn.Module)

mock_vllm_config.speculative_config = None

builder = AscendMLATorchairMetadataBuilder(
None,
None,
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
@@ -437,14 +463,16 @@ class TestAscendMLATorchairMetadataBuilder(TestBase):
max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
slot_mapping=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
attn_state=AscendAttentionState.ChunkedPrefill,
num_computed_tokens_cpu=None,
seq_lens=None)

metadata = builder.build(common_attn_metadata, model)
metadata = builder.build(1, common_attn_metadata, model)

self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 0)

@@ -6,7 +6,6 @@ from unittest.mock import MagicMock, patch
import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
from vllm_ascend.torchair import utils


@@ -135,15 +134,3 @@ class TestTorchairUtils(TestBase):

utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()

def test_torchair_quant_method_register(self):

TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
"W8A8_DYNAMIC"]
TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
"W4A8_DYNAMIC"]
utils.torchair_quant_method_register()
self.assertNotEqual(TorchairW8A8DYNAMICQuantizer,
SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"])
self.assertNotEqual(TorchairW4A8DYNAMICQuantizer,
SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"])

@@ -24,8 +24,8 @@ from vllm.utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable

from vllm_ascend.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

VOCAB_SIZE = 1024

tests/ut/worker/test_model_runner_v1.py
@@ -0,0 +1,107 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

from unittest.mock import MagicMock, patch

import pytest

from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import AscendSocVersion
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


# yapf: disable
@pytest.mark.parametrize(
"soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method",
[
# Case 1: Expert parallel is disabled, should always be 'allgather'
(AscendSocVersion.A2, False, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendSocVersion.A3, False, 16, 500, 256, None, MoECommType.ALLGATHER),

# Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2
(AscendSocVersion.A2, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendSocVersion.A2, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL),
(AscendSocVersion.A2, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition

# Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather
(AscendSocVersion.A2, True, 8, 100, 256, None, MoECommType.ALLGATHER),
(AscendSocVersion.A2, True, 16, 257, 256, None, MoECommType.ALLGATHER),

# Case 4: A3 SOC
(AscendSocVersion.A3, True, 8, 100, 256, None, MoECommType.MC2),
(AscendSocVersion.A3, True, 8, 257, 256, None, MoECommType.ALLTOALL),
])
# yapf: enable
def test_select_moe_comm_method(soc_version, enable_expert_parallel,
world_size, num_tokens, mc2_tokens_capacity,
quant_type, expected_method):
"""
Tests the _select_moe_comm_method with various configurations including quant_type.
"""
# Mock the NPUModelRunner instance and its dependencies
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.parallel_config = MagicMock()
mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel
mock_runner.parallel_config.world_size_across_dp = world_size
mock_runner.mc2_tokens_capacity = mc2_tokens_capacity

# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = quant_type
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config

# Patch the helper functions
with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
return_value=soc_version), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True):

# Bind the real method to the mock object
method = NPUModelRunner._select_moe_comm_method(
mock_runner, num_tokens, False)

# Assert the result
assert method == expected_method


def test_select_moe_comm_method_unsupported_soc():
"""
Tests that _select_moe_comm_method raises ValueError for an unsupported SOC.
"""
mock_runner = MagicMock(spec=NPUModelRunner)
mock_runner.parallel_config = MagicMock()
mock_runner.parallel_config.enable_expert_parallel = True
mock_runner.mc2_tokens_capacity = 256

# Add vllm_config.model_config.hf_config mock with moe_quantize
mock_hf_config = MagicMock()
mock_hf_config.moe_quantize = None
mock_model_config = MagicMock()
mock_model_config.hf_config = mock_hf_config
mock_vllm_config = MagicMock()
mock_vllm_config.model_config = mock_model_config
mock_runner.vllm_config = mock_vllm_config

unsupported_soc = "UnsupportedSOC"

with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
return_value=unsupported_soc), \
patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
return_value=True), \
pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):

NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)
@@ -258,7 +258,7 @@ class TestNPUWorker(TestBase):
# Create worker mock
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()

worker._sleep_saved_buffers = {}
# Test wake_up method
worker.wake_up(tags=["test_tag"])

@@ -355,6 +355,28 @@ class TestNPUWorker(TestBase):

self.assertIn("Profiler is not enabled", str(cm.exception))

@patch("vllm_ascend.worker.worker_v1.envs_vllm")
@patch("vllm_ascend.worker.worker_v1.envs_ascend")
def test_profile_and_msmonitor_both_enabled_raises_error(
self, mock_envs_vllm, mock_envs_ascend):
"""Test profile method raises exception when both profiler and msmonitor are enabled"""
from vllm_ascend.worker.worker_v1 import NPUWorker

mock_envs_vllm.VLLM_TORCH_PROFILER_DIR = "/path/to/traces"
mock_envs_ascend.MSMONITOR_USE_DAEMON = 1

# Create worker mock
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()

# Test should raise exception
with self.assertRaises(RuntimeError) as cm:
_ = worker._init_profiler()

self.assertIn(
"MSMONITOR_USE_DAEMON and VLLM_TORCH_PROFILER_DIR cannot be both set at the same time.",
str(cm.exception))

def test_lora_methods(self):
"""Test LoRA related methods"""
from vllm_ascend.worker.worker_v1 import NPUWorker
@@ -828,6 +850,7 @@ class TestNPUWorker(TestBase):

# Mock scheduler_output and return result
mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1
# Create a real ModelRunnerOutput instance or mock
mock_model_output = MagicMock(spec=ModelRunnerOutput)
worker.model_runner.execute_model.return_value = mock_model_output
@@ -842,9 +865,8 @@ class TestNPUWorker(TestBase):

@patch("vllm_ascend.worker.worker_v1.get_pp_group")
@patch("vllm_ascend.worker.worker_v1.get_tp_group")
@patch("vllm_ascend.worker.worker_v1.has_kv_transfer_group")
def test_execute_model_middle_rank(self, mock_has_kv_transfer_group,
mock_get_tp_group, mock_get_pp_group):
def test_execute_model_middle_rank(self, mock_get_tp_group,
mock_get_pp_group):
"""Test execute_model method - middle rank case"""
from vllm.sequence import IntermediateTensors

@@ -875,10 +897,8 @@ class TestNPUWorker(TestBase):
)
worker.model_runner.execute_model.return_value = mock_intermediate_output

# Set has_kv_transfer_group returns False
mock_has_kv_transfer_group.return_value = False

mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1

# Test execute_model
result = worker.execute_model(mock_scheduler_output)
@@ -926,6 +946,7 @@ class TestNPUWorker(TestBase):

# Mock return result
mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1
mock_model_output = MagicMock(spec=ModelRunnerOutput)
worker.model_runner.execute_model.return_value = mock_model_output

@@ -1009,7 +1030,9 @@ class TestNPUWorker(TestBase):

@patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything")
@patch("vllm_ascend.worker.worker_v1.logger")
def test_compile_or_warm_up_model_with_eager_mode(self, mock_logger,
@patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb")
def test_compile_or_warm_up_model_with_eager_mode(self, mock_warm_up_atb,
mock_logger,
mock_seed_everything):
"""Test compile_or_warm_up_model method - eager mode"""
from vllm_ascend.worker.worker_v1 import NPUWorker
@@ -1051,10 +1074,14 @@ class TestNPUWorker(TestBase):
# Verify seed setting
mock_seed_everything.assert_called_once_with(12345)

# Verify atb warm up
mock_warm_up_atb.assert_called_once()

@patch("vllm_ascend.worker.worker_v1.NPUPlatform.seed_everything")
@patch("vllm_ascend.worker.worker_v1.logger")
@patch("vllm_ascend.worker.worker_v1.NPUWorker._warm_up_atb")
def test_compile_or_warm_up_model_with_graph_capture(
self, mock_logger, mock_seed_everything):
self, mock_warm_up_atb, mock_logger, mock_seed_everything):
"""Test compile_or_warm_up_model method - with graph capture enabled"""
from vllm_ascend.worker.worker_v1 import NPUWorker

@@ -1087,6 +1114,9 @@ class TestNPUWorker(TestBase):
# Verify seed setting
mock_seed_everything.assert_called_once_with(67890)

# Verify atb warm up
mock_warm_up_atb.assert_called_once()

@patch("vllm_ascend.worker.worker_v1.CaMemAllocator")
def test_initialize_from_config_with_sleep_mode(self,
mock_allocator_class):
@@ -1141,3 +1171,55 @@ class TestNPUWorker(TestBase):
# Verify calls
worker.model_runner.initialize_kv_cache.assert_called_once_with(
mock_kv_cache_config)

@patch("vllm_ascend.worker.worker_v1.get_pp_group")
@patch("vllm_ascend.worker.worker_v1.get_tp_group")
@patch("vllm_ascend.worker.worker_v1.EMPTY_MODEL_RUNNER_OUTPUT")
def test_execute_model_kv_connector_not_finished(self, mock_empty_output,
mock_get_tp_group,
mock_get_pp_group):
"""Test execute_model method - kv_connector_output not finished sending/recving case"""
from vllm.sequence import IntermediateTensors

from vllm_ascend.worker.worker_v1 import NPUWorker

# Create worker mock
with patch.object(NPUWorker, "__init__", lambda x, **kwargs: None):
worker = NPUWorker()
worker.model_runner = MagicMock()
worker.vllm_config = MagicMock()
worker.vllm_config.parallel_config = MagicMock()
worker.vllm_config.parallel_config.distributed_executor_backend = "ray"

# Set as middle rank (not first, not last)
mock_pp_group = MagicMock()
mock_pp_group.is_first_rank = False
mock_pp_group.is_last_rank = False
mock_get_pp_group.return_value = mock_pp_group

# Setup tensor reception data
mock_pp_group.recv_tensor_dict.return_value = {"tensor": "data"}

# Create mock kv_connector_output - both finished_sending and finished_recving are False
mock_kv_connector_output = MagicMock()
mock_kv_connector_output.finished_sending = False
mock_kv_connector_output.finished_recving = False

# Mock return IntermediateTensors with kv_connector_output
mock_intermediate_output = MagicMock(spec=IntermediateTensors)
mock_intermediate_output.tensors = {"output_tensor": "data"}
mock_intermediate_output.kv_connector_output = mock_kv_connector_output
worker.model_runner.execute_model.return_value = mock_intermediate_output

mock_scheduler_output = MagicMock()
mock_scheduler_output.total_num_scheduled_tokens = 1

# Test execute_model
result = worker.execute_model(mock_scheduler_output)

# Verify tensor reception and sending
mock_pp_group.recv_tensor_dict.assert_called_once()
mock_pp_group.send_tensor_dict.assert_called_once()

# When both finished_sending and finished_recving are False, should return EMPTY_MODEL_RUNNER_OUTPUT directly
self.assertEqual(result, mock_empty_output)
