[2/N][Pangu][MoE] Remove Pangu Related Code (#5130)
### What this PR does / why we need it?
Remove Pangu Related Code
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e & ut
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: weichen <calvin_zhu0210@outlook.com>
This commit is contained in:
@@ -30,8 +30,7 @@ The following table lists additional configuration options available in vLLM Asc
|
||||
| `finegrained_tp_config` | dict | `{}` | Configuration options for module tensor parallelism |
|
||||
| `weight_prefetch_config` | dict | `{}` | Configuration options for weight prefetch |
|
||||
| `refresh` | bool | `false` | Whether to refresh global Ascend configuration content. This is usually used by rlhf or ut/e2e test case. |
|
||||
| `expert_map_path` | str | `None` | When using expert load balancing for an MoE model, an expert map path needs to be passed in. |
|
||||
| `kv_cache_dtype` | str | `None` | When using the KV cache quantization method, KV cache dtype needs to be set, currently only int8 is supported. |
|
||||
| `expert_map_path` | str | `None` | When using expert load balancing for an MoE model, an expert map path needs to be passed in. |
|
||||
| `enable_shared_expert_dp` | bool | `False` | When the expert is shared in DP, it delivers better performance but consumes more memory. Currently only DeepSeek series models are supported. |
|
||||
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. Restriction: Can only be used when tensor_parallel=1 |
|
||||
| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
|
||||
|
||||
@@ -31,10 +31,6 @@ class TestAscendAttentionBackend(TestBase):
|
||||
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
|
||||
self.assertEqual(result, (2, 10, 20, 30, 40))
|
||||
|
||||
def test_get_bsh_kv_cache_shape(self):
    """BSH cache layout: heads and head size collapse into one axis."""
    expected = (2, 10, 20, 30 * 40)
    shape = AscendAttentionBackend.get_bsh_kv_cache_shape(10, 20, 30, 40)
    self.assertEqual(shape, expected)
|
||||
|
||||
def test_swap_blocks(self):
|
||||
src_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
|
||||
dst_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
@@ -8,8 +7,7 @@ from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
|
||||
AscendQuantConfig)
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
|
||||
@@ -19,7 +17,6 @@ class TestAscendQuantConfig(TestBase):
|
||||
self.sample_config = {
|
||||
"weight": "INT8",
|
||||
"fa_quant_type": "C8",
|
||||
"kv_quant_type": "C8",
|
||||
"layer1.weight": "INT8",
|
||||
"layer2.weight": "FLOAT",
|
||||
"fused_layer.weight": "FLOAT",
|
||||
@@ -115,16 +112,6 @@ class TestAscendQuantConfig(TestBase):
|
||||
attention_layer, ".attn")
|
||||
self.assertIs(method, mock_ascend_kvcache.return_value)
|
||||
|
||||
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
|
||||
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
|
||||
return_value=MagicMock()) as mock_ascend_kvcache:
|
||||
# Test with kv_quant_type
|
||||
modified_config = {"kv_quant_type": "C8"}
|
||||
config = AscendQuantConfig(modified_config)
|
||||
config.packed_modules_mapping = None
|
||||
method = config.get_quant_method(attention_layer, "attn")
|
||||
self.assertIs(method, mock_ascend_kvcache.return_value)
|
||||
|
||||
def test_get_quant_method_for_fused_moe(self):
|
||||
fused_moe_layer = MagicMock(spec=FusedMoE)
|
||||
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
|
||||
@@ -169,73 +156,3 @@ class TestAscendQuantConfig(TestBase):
|
||||
|
||||
def test_get_scaled_act_names(self):
    """The Ascend quant config reports no activations that need scaling."""
    names = self.ascend_config.get_scaled_act_names()
    self.assertEqual(names, [])
|
||||
|
||||
|
||||
class TestAscendKVCacheMethod(TestBase):
    """Tests that AscendKVCacheMethod forwards every operation to the
    quant method resolved through get_quant_method."""

    def setUp(self):
        # Setup common test fixtures
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
        self.prefix = "layer.attn"

        # Mock quant_method so the constructor picks it up via the patched
        # get_quant_method lookup.
        self.mock_quant_method = MagicMock()
        self.patcher = patch(
            'vllm_ascend.quantization.quant_config.get_quant_method')
        self.mock_get_quant_method = self.patcher.start()
        self.mock_get_quant_method.return_value = self.mock_quant_method

        # Create instance under test.
        self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
                                                   self.prefix)

    def tearDown(self):
        # Undo the module-level patch installed in setUp.
        self.patcher.stop()

    def test_create_weights(self):
        """Test create_weights delegates to quant_method."""
        mock_layer = MagicMock()
        self.kv_cache_method.create_weights(mock_layer)
        self.mock_quant_method.create_weights.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_with_method(self):
        """Test process_weights when quant_method has the method."""
        mock_layer = MagicMock()
        self.kv_cache_method.process_weights_after_loading(mock_layer)
        self.mock_quant_method.process_weights_after_loading.assert_called_once_with(
            mock_layer)

    def test_process_weights_after_loading_without_method(self):
        """Test process_weights when quant_method lacks the method."""
        # Reset mock to remove the method so hasattr()-style checks fail.
        del self.mock_quant_method.process_weights_after_loading
        mock_layer = MagicMock()

        # Should not raise exception
        self.kv_cache_method.process_weights_after_loading(mock_layer)

    def test_apply_delegation(self):
        """Test apply properly delegates to quant_method."""
        mock_layer = MagicMock()
        mock_query = torch.randn(1, 32, 128)
        mock_key = torch.randn(1, 32, 128)
        mock_value = torch.randn(1, 32, 128)
        mock_kv_cache = MagicMock()
        mock_attn_metadata = MagicMock()
        mock_scale = 1.0
        mock_output = torch.zeros(1, 32, 128)
        mock_attn_type = MagicMock()
        expected_result = torch.randn(1, 32, 128)
        self.mock_quant_method.apply.return_value = expected_result

        result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key,
                                            mock_value, mock_kv_cache,
                                            mock_attn_metadata, mock_attn_type,
                                            mock_scale, mock_output)

        # The wrapper must pass all arguments through unchanged and return
        # the delegate's result as-is.
        self.mock_quant_method.apply.assert_called_once_with(
            mock_layer, mock_query, mock_key, mock_value, mock_kv_cache,
            mock_attn_metadata, mock_attn_type, mock_scale, mock_output)
        self.assertTrue(torch.equal(result, expected_result))
|
||||
|
||||
@@ -39,18 +39,6 @@ class TestGetQuantMethod(TestBase):
|
||||
"moe")
|
||||
self.assertIsInstance(method, cls)
|
||||
|
||||
def test_with_fa_quant_type(self):
    """fa_quant_type C8 selects the C8 attention quant method."""
    description = {"fa_quant_type": "C8"}
    expected_cls = ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"]
    method = get_quant_method(description, ".attn", "attention")
    self.assertIsInstance(method, expected_cls)
|
||||
|
||||
def test_with_kv_quant_type(self):
    """kv_quant_type C8 also resolves to the C8 attention quant method."""
    description = {"kv_quant_type": "C8"}
    expected_cls = ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"]
    method = get_quant_method(description, ".attn", "attention")
    self.assertIsInstance(method, expected_cls)
|
||||
|
||||
def test_invalid_layer_type(self):
|
||||
quant_description = {"linear_layer.weight": "W8A8"}
|
||||
with self.assertRaises(NotImplementedError):
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
fused_experts, fused_experts_310p,
|
||||
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
|
||||
quant_per_tensor)
|
||||
from vllm_ascend.utils import AscendDeviceType
|
||||
|
||||
@@ -194,791 +187,3 @@ class TestAscendW8A8LinearMethod(TestBase):
|
||||
self.assertEqual(layer.weight_scale.data.shape, (128, ))
|
||||
self.assertEqual(layer.weight_offset.data.shape, (128, ))
|
||||
mock_npu_format_cast.assert_called_once()
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
    """Tests for AscendW8A8FusedMoEMethod: weight creation, dynamic quant
    parameter creation, and the apply() dispatch (default vs. 310P SoC)."""

    def setUp(self):
        # Shared fixture dimensions used by every test in this class.
        self.moe_method = AscendW8A8FusedMoEMethod()
        self.num_experts = 4
        self.intermediate_size = 64
        self.hidden_size = 128
        self.dtype = torch.float32

    def test_init(self):
        # transpose_weight is expected to default to True.
        self.assertTrue(self.moe_method.transpose_weight)

    def test_get_weight(self):
        """w13/w2 weights get int8 dtype, expected shapes, and no grad."""
        weights = self.moe_method.get_weight(
            num_experts=self.num_experts,
            intermediate_size_per_partition=self.intermediate_size,
            hidden_sizes=self.hidden_size,
            params_dtype=self.dtype)

        assert "w13_weight" in weights, f"w13_weight not in {weights}"
        assert "w2_weight" in weights, f"w2_weight not in {weights}"
        # w13 fuses gate+up projections, hence the factor of 2.
        self.assertEqual(
            weights["w13_weight"].shape,
            (self.num_experts, 2 * self.intermediate_size, self.hidden_size))
        self.assertEqual(
            weights["w2_weight"].shape,
            (self.num_experts, self.hidden_size, self.intermediate_size))
        self.assertEqual(weights["w13_weight"].dtype, torch.int8)
        self.assertEqual(weights["w2_weight"].dtype, torch.int8)
        self.assertFalse(weights["w13_weight"].requires_grad)
        self.assertFalse(weights["w2_weight"].requires_grad)

    def test_get_dynamic_quant_param(self):
        """All expected per-expert quant parameters exist with sane shapes."""
        quant_params = self.moe_method.get_dynamic_quant_param(
            num_experts=self.num_experts,
            intermediate_size_per_partition=self.intermediate_size,
            hidden_sizes=self.hidden_size,
            params_dtype=self.dtype)

        expected_params = [
            "w13_weight_scale", "w13_weight_offset", "w2_weight_scale",
            "w2_weight_offset", "w2_deq_scale", "w13_deq_scale",
            "w2_input_scale", "w13_input_scale", "w2_input_offset",
            "w13_input_offset", "quant_bias"
        ]

        for param in expected_params:
            assert param in quant_params, f"{param} not in {quant_params}"

        # Check some sample shapes
        self.assertEqual(quant_params["w13_weight_scale"].shape,
                         (self.num_experts, 2 * self.intermediate_size, 1))
        self.assertEqual(quant_params["w2_input_offset"].shape,
                         (self.num_experts, 1))
        self.assertEqual(quant_params["quant_bias"].shape,
                         (self.num_experts, self.hidden_size))

    @patch('vllm_ascend.quantization.w8a8.select_experts')
    @patch('vllm_ascend.quantization.w8a8.fused_experts')
    def test_apply_with_other_expert_count(self, mock_fused_experts,
                                           mock_select_experts):
        """On the default device, apply() goes through select_experts and
        the regular fused_experts kernel."""
        # Setup
        mock_layer = MagicMock()
        x = torch.randn(32, self.hidden_size)
        router_logits = torch.randn(32, 128)  # 128 experts
        top_k = 2

        # Mock return values
        mock_select_experts.return_value = (torch.randn(32, top_k),
                                            torch.randint(0, 128, (32, top_k)))
        mock_fused_experts.return_value = torch.randn(32, self.hidden_size)

        # Test
        result = self.moe_method.apply(layer=mock_layer,
                                       x=x,
                                       router_logits=router_logits,
                                       top_k=top_k,
                                       renormalize=True,
                                       global_num_experts=128)

        # Assertions
        mock_select_experts.assert_called_once()
        mock_fused_experts.assert_called_once()
        self.assertEqual(result.shape, (32, self.hidden_size))

    @patch('vllm_ascend.quantization.w8a8.get_ascend_device_type',
           return_value=AscendDeviceType._310P)
    @patch('vllm_ascend.quantization.w8a8.select_experts')
    @patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
    def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
                           mock_soc_version):
        """On 310P hardware, apply() must dispatch to fused_experts_310p."""
        # Setup
        mock_layer = MagicMock()
        x = torch.randn(32, self.hidden_size)
        router_logits = torch.randn(32, 128)  # 128 experts
        top_k = 2

        # Mock return values
        mock_select_experts.return_value = (torch.randn(32, top_k),
                                            torch.randint(0, 128, (32, top_k)))
        mock_fused_experts_310p.return_value = torch.randn(
            32, self.hidden_size)

        # Test
        result = self.moe_method.apply(layer=mock_layer,
                                       x=x,
                                       router_logits=router_logits,
                                       top_k=top_k,
                                       renormalize=True,
                                       global_num_experts=128)

        # Assertions
        mock_select_experts.assert_called_once()
        mock_fused_experts_310p.assert_called_once()
        self.assertEqual(result.shape, (32, self.hidden_size))
|
||||
|
||||
|
||||
class TestAscendC8KVCacheMethod(TestBase):
    """Tests for AscendC8KVCacheMethod: scale-weight registration, scale
    post-processing, and the apply() attention paths (decode, prefill,
    and the unsupported type/state error branches).

    Fix vs. previous revision: the second assertion message in
    test_apply_unsupported_attention_type was missing a space
    (``not in{...}``), producing a garbled failure message.
    """

    def setUp(self):
        # A fake attention layer carrying only the attributes the method
        # reads; real layer construction needs NPU hardware.
        self.layer = MagicMock()
        self.layer.num_kv_heads = 4
        self.layer.head_size = 64
        self.layer.num_heads = 8
        self.layer._k_scale_float = 1.0
        self.layer._v_scale_float = 1.0
        self.method = AscendC8KVCacheMethod()

        # Stand-in for vLLM's AttentionType enum values.
        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"

    def test_create_weights(self):
        """create_weights registers flat K/V antiquant scale parameters."""
        AscendC8KVCacheMethod.create_weights(self.layer)

        self.layer.register_parameter.assert_any_call("key_antiquant_scale",
                                                      unittest.mock.ANY)
        self.layer.register_parameter.assert_any_call("value_antiquant_scale",
                                                      unittest.mock.ANY)

        calls = self.layer.register_parameter.call_args_list

        for call in calls:
            args, kwargs = call
            # register_parameter may be called positionally or by keyword.
            param = kwargs.get('parameter', args[1] if len(args) > 1 else None)

            # Every registered scale is flat: num_kv_heads * head_size.
            expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
            self.assertEqual(param.shape, expected_shape)

    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType.A3)
    def test_process_weights_after_loading_not_310p(self, mock_soc_version):
        """K/V scales combine into a (2, kv_heads*head_size) tensor on A3."""
        key_data = torch.ones(4 * 64)
        value_data = torch.ones(4 * 64) * 2

        self.layer.key_antiquant_scale.data = key_data
        self.layer.value_antiquant_scale.data = value_data

        self.method.process_weights_after_loading(self.layer)

        # Row 0 is the key scale, row 1 the value scale.
        self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
        self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
        self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))

    @patch('vllm_ascend.utils.get_ascend_device_type',
           return_value=AscendDeviceType._310P)
    def test_process_weights_after_loading_is_310p(self, mock_soc_version):
        """The combined-scale layout is the same on 310P devices."""
        key_data = torch.ones(4 * 64)
        value_data = torch.ones(4 * 64) * 2

        self.layer.key_antiquant_scale.data = key_data
        self.layer.value_antiquant_scale.data = value_data

        self.method.process_weights_after_loading(self.layer)

        self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
        self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
        self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))

    @patch('torch_npu.npu_scatter_nd_update_')
    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_decode_only(self, mock_quant, mock_scatter):
        """Decode path: K/V are quantized, scattered into the cache, and the
        incre-flash-attention output is returned unchanged."""
        num_tokens = 2
        query = torch.randn(num_tokens,
                            self.layer.num_heads * self.layer.head_size)
        key = torch.randn(num_tokens,
                          self.layer.num_kv_heads * self.layer.head_size)
        value = torch.randn(num_tokens,
                            self.layer.num_kv_heads * self.layer.head_size)
        output = torch.empty_like(query)

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.DecodeOnly
        attn_metadata.seq_lens = [10, 10]
        attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
        attn_metadata.slot_mapping = torch.tensor([0, 1])
        attn_metadata.attn_mask = None

        block_size = 16
        key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
                                self.layer.head_size)
        value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
                                  self.layer.head_size)
        kv_cache = (key_cache, value_cache)

        # First quant call handles key, second handles value.
        mock_quant.side_effect = [key, value]

        self.layer.key_antiquant_scale.data = torch.ones(
            self.layer.num_kv_heads * self.layer.head_size)
        self.layer.value_antiquant_scale.data = torch.ones(
            self.layer.num_kv_heads * self.layer.head_size)
        self.method.process_weights_after_loading(self.layer)

        expected_output = torch.randn(
            num_tokens, self.layer.num_heads * self.layer.head_size)
        with patch('torch_npu.npu_incre_flash_attention',
                   return_value=expected_output):
            result = self.method.apply(self.layer, query, key, value, kv_cache,
                                       attn_metadata,
                                       self.attention_type.DECODER, 1.0,
                                       output)

        self.assertEqual(mock_quant.call_count, 2)
        self.assertEqual(mock_scatter.call_count, 2)
        self.assertTrue(torch.equal(result, expected_output))

    @patch('torch_npu.npu_scatter_nd_update_')
    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_attn_metadata_without_decode(self, mock_quant,
                                                mock_scatter):
        """Same decode flow, but the metadata mock is spec-restricted so any
        access to an unexpected attribute fails loudly."""
        num_tokens = 2
        query = torch.randn(num_tokens,
                            self.layer.num_heads * self.layer.head_size)
        key = torch.randn(num_tokens,
                          self.layer.num_kv_heads * self.layer.head_size)
        value = torch.randn(num_tokens,
                            self.layer.num_kv_heads * self.layer.head_size)
        output = torch.empty_like(query)

        # Restrict the mock to exactly the attributes the method may touch.
        attn_metadata = MagicMock(spec=[
            'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
            'attn_mask'
        ])
        attn_metadata.attn_state = AscendAttentionState.DecodeOnly
        attn_metadata.seq_lens = [10, 10]
        attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
        attn_metadata.slot_mapping = torch.tensor([0, 1])
        attn_metadata.attn_mask = None

        block_size = 16
        key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
                                self.layer.head_size)
        value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
                                  self.layer.head_size)
        kv_cache = (key_cache, value_cache)

        mock_quant.side_effect = [key, value]

        self.layer.key_antiquant_scale.data = torch.ones(
            self.layer.num_kv_heads * self.layer.head_size)
        self.layer.value_antiquant_scale.data = torch.ones(
            self.layer.num_kv_heads * self.layer.head_size)
        self.method.process_weights_after_loading(self.layer)

        expected_output = torch.randn(
            num_tokens, self.layer.num_heads * self.layer.head_size)
        with patch('torch_npu.npu_incre_flash_attention',
                   return_value=expected_output):
            result = self.method.apply(self.layer, query, key, value, kv_cache,
                                       attn_metadata,
                                       self.attention_type.DECODER, 1.0,
                                       output)

        self.assertEqual(mock_quant.call_count, 2)
        self.assertEqual(mock_scatter.call_count, 2)
        self.assertTrue(torch.equal(result, expected_output))

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch('torch_npu._npu_flash_attention')
    def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
        """Test apply method in prefill no-cache mode"""
        num_tokens = 2
        query = torch.randn(num_tokens,
                            self.layer.num_heads * self.layer.head_size)
        key = torch.randn(num_tokens,
                          self.layer.num_kv_heads * self.layer.head_size)
        value = torch.randn(num_tokens,
                            self.layer.num_kv_heads * self.layer.head_size)
        output = torch.empty_like(query)

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
        attn_metadata.seq_lens = [10, 10]
        attn_metadata.attn_mask = torch.ones(2, 2)

        # Empty caches: nothing to scatter in the no-cache prefill path.
        kv_cache = (torch.tensor([]), torch.tensor([]))
        mock_quant.return_value = key

        result = self.method.apply(self.layer, query, key, value, kv_cache,
                                   attn_metadata, self.attention_type.DECODER,
                                   1.0, output)

        # Check that flash attention was called
        mock_flash.assert_called_once()

        # Check output shape
        self.assertEqual(
            result.shape,
            (num_tokens, self.layer.num_heads * self.layer.head_size))

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_unsupported_attention_type(self, mock_quant):
        """Encoder self-attention is rejected before any quantization."""
        query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
        key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
        value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
        output = torch.empty_like(query)

        mock_quant.return_value = key

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.PrefillNoCache

        with self.assertRaises(NotImplementedError) as cm:
            self.method.apply(self.layer, query, key, value, (None, None),
                              attn_metadata, self.attention_type.ENCODER, 1.0,
                              output)

        assert "Encoder self-attention" in str(
            cm.exception), f"Encoder self-attention not in {str(cm.exception)}"
        # Fixed: added the missing space before the interpolated exception.
        assert "not implemented" in str(
            cm.exception), f"not implemented not in {str(cm.exception)}"

        mock_quant.assert_not_called()

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_unsupported_attention_state(self, mock_quant):
        """Test apply with unsupported attention state"""
        query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
        key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
        value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
        output = torch.empty_like(query)

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.PrefillCacheHit
        mock_quant.return_value = key
        kv_cache = (torch.tensor([]), torch.tensor([]))

        with self.assertRaises(NotImplementedError):
            self.method.apply(self.layer, query, key, value, kv_cache,
                              attn_metadata, self.attention_type.DECODER, 1.0,
                              output)
|
||||
|
||||
|
||||
class TestFusedExperts(TestBase):
    """Tests for the fused_experts kernel wrapper: the expert-map path and
    the (currently unsupported) no-expert-map path."""

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch('vllm_ascend.quantization.w8a8.get_ep_group')
    @patch('torch_npu.npu_moe_init_routing_v2')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_moe_finalize_routing')
    def test_fused_experts_with_expert_map(self, mock_finalize, mock_swiglu,
                                           mock_group_matmul,
                                           mock_init_routing,
                                           mock_get_ep_group,
                                           mock_quant_per_tensor):
        """With an expert_map, the routing -> matmul -> swiglu -> matmul ->
        finalize pipeline runs once and preserves the token/hidden shape."""
        num_tokens = 32
        hidden_size = 128
        intermediate_size = 256
        num_experts = 4
        top_k = 2

        hidden_states = torch.randn(num_tokens, hidden_size)

        # w13 fuses gate+up projections (factor 2 on intermediate dim).
        w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
        w1_scale = torch.tensor([0.1])
        w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
        w1_input_offset = torch.tensor([0])

        w2 = torch.randn(num_experts, hidden_size, intermediate_size)
        w2_scale = torch.tensor([0.1])
        w2_input_scale = torch.tensor([0.2])
        w2_input_offset = torch.tensor([0])

        topk_weights = torch.rand(num_tokens, top_k)
        topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
        expert_map = torch.arange(num_experts)

        mock_get_ep_group.return_value.world_size = 8

        mock_quant_per_tensor.return_value = torch.randint(-128,
                                                           127,
                                                           hidden_states.shape,
                                                           dtype=torch.int8)

        # Fake routing output: expanded states, row indices, per-expert token
        # counts, and a scale (shapes mirror npu_moe_init_routing_v2 —
        # TODO confirm against the kernel docs).
        mock_init_routing.return_value = (torch.randn(num_tokens * top_k,
                                                      hidden_size),
                                          torch.arange(num_tokens * top_k),
                                          torch.tensor([num_tokens // 2] * 2),
                                          torch.tensor(1.0))

        # First grouped matmul produces gate+up, second the down projection.
        mock_group_matmul.side_effect = [[
            torch.randn(num_tokens * top_k, intermediate_size * 2)
        ], [torch.randn(num_tokens * top_k, hidden_size)]]

        mock_swiglu.return_value = torch.randn(num_tokens * top_k,
                                               intermediate_size)

        expected_output = torch.randn(num_tokens, hidden_size)
        mock_finalize.return_value = expected_output

        output = fused_experts(
            hidden_states=hidden_states,
            w1=w1,
            w1_scale=w1_scale,
            w1_input_scale=w1_input_scale,
            w1_input_offset=w1_input_offset,
            w2=w2,
            w2_scale=w2_scale,
            w2_input_scale=w2_input_scale,
            w2_input_offset=w2_input_offset,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=top_k,
            global_num_experts=num_experts,
            expert_map=expert_map,
        )

        mock_init_routing.assert_called_once()

        self.assertEqual(mock_group_matmul.call_count, 2)

        self.assertEqual(output.shape, (num_tokens, hidden_size))

        mock_finalize.assert_called_once()

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch('vllm_ascend.quantization.w8a8.get_ep_group')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    def test_fused_experts_without_expert_map(self, mock_swiglu,
                                              mock_group_matmul,
                                              mock_get_ep_group,
                                              mock_quant_per_tensor):
        """expert_map=None is not supported and must raise
        NotImplementedError."""
        num_tokens = 16
        hidden_size = 64
        intermediate_size = 128
        num_experts = 8
        top_k = 1

        hidden_states = torch.randn(num_tokens, hidden_size)
        w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
        w2 = torch.randn(num_experts, hidden_size, intermediate_size)
        topk_weights = torch.rand(num_tokens, top_k)
        topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))

        mock_get_ep_group.return_value.world_size = 8

        mock_quant_per_tensor.return_value = torch.randint(-128,
                                                           127,
                                                           hidden_states.shape,
                                                           dtype=torch.int8)
        mock_group_matmul.side_effect = [[
            torch.randn(num_tokens * top_k, intermediate_size * 2)
        ], [torch.randn(num_tokens * top_k, hidden_size)]]
        mock_swiglu.return_value = torch.randn(num_tokens * top_k,
                                               intermediate_size)
        with self.assertRaises(NotImplementedError):
            fused_experts(
                hidden_states=hidden_states,
                w1=w1,
                w1_scale=torch.tensor([0.1]),
                w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]),
                w1_input_offset=torch.tensor([0]),
                w2=w2,
                w2_scale=torch.tensor([0.1]),
                w2_input_scale=torch.tensor([0.1]),
                w2_input_offset=torch.tensor([0]),
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                global_num_experts=num_experts,
                expert_map=None,
            )
|
||||
|
||||
|
||||
class TestFusedExperts310(TestBase):
    """Tests the 310P fused-experts path, which is built on
    npu_quant_grouped_matmul_dequant instead of the grouped matmul kernels."""

    @patch('torch_npu.npu_quant_grouped_matmul_dequant')
    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch('vllm_ascend.quantization.w8a8.get_ep_group')
    @patch('torch_npu.npu_swiglu')
    def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
                                                mock_get_ep_group,
                                                mock_quant_per_tensor,
                                                mock_matmul_dequant):
        """Shape is preserved and the quant-matmul-dequant kernel is hit
        twice (up projection and down projection)."""
        num_tokens = 32
        hidden_size = 128
        intermediate_size = 256
        num_experts = 4
        top_k = 1

        hidden_states = torch.randn(num_tokens, hidden_size)

        w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
        w1_scale = torch.tensor([0.1])
        w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])

        w2 = torch.randn(num_experts, hidden_size, intermediate_size)
        w2_scale = torch.tensor([0.1])
        w2_input_scale = torch.tensor([0.2])

        topk_weights = torch.rand(num_tokens, top_k)
        topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
        expert_map = torch.arange(num_experts)

        # Single-rank expert parallelism for the 310P path.
        mock_get_ep_group.return_value.world_size = 1

        mock_quant_per_tensor.return_value = torch.randint(-128,
                                                           127,
                                                           hidden_states.shape,
                                                           dtype=torch.int8)

        mock_swiglu.return_value = torch.randn(num_tokens * top_k,
                                               intermediate_size)

        mock_matmul_dequant.return_value = hidden_states

        output = fused_experts_310p(
            hidden_states=hidden_states,
            w1=w1,
            w1_scale=w1_scale,
            w1_input_scale=w1_input_scale,
            w2=w2,
            w2_scale=w2_scale,
            w2_input_scale=w2_input_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            top_k=top_k,
            global_num_experts=num_experts,
            expert_map=expert_map,
        )

        self.assertEqual(output.shape, (num_tokens, hidden_size))
        self.assertEqual(mock_matmul_dequant.call_count, 2)
|
||||
|
||||
|
||||
class TestSelectExperts(TestBase):
|
||||
|
||||
def setUp(self):
    """Build shared routing fixtures and patch get_forward_context.

    Fix vs. previous revision: addCleanup(patcher.stop) was registered
    BEFORE patcher.start(); if start() had raised, the cleanup would call
    stop() on a never-started patcher and mask the real error. The patch
    is now started first, then its stop is registered for cleanup.
    """
    # Common test data
    self.num_tokens = 10
    self.hidden_size = 32
    self.num_experts = 8
    self.top_k = 2

    self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
    self.router_logits = torch.randn(self.num_tokens, self.num_experts)

    # Mock custom routing: returns (weights, ids) with the right shapes.
    self.mock_custom_routing = MagicMock()
    self.mock_custom_routing.return_value = (torch.ones(
        self.num_tokens, self.top_k),
                                             torch.zeros(
                                                 self.num_tokens,
                                                 self.top_k,
                                                 dtype=torch.int32))

    # select_experts reads weight_prefetch_method off the forward context.
    self.mock_ctx = MagicMock()
    self.mock_ctx.weight_prefetch_method = MagicMock()
    patcher = patch(
        'vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
        return_value=self.mock_ctx)
    patcher.start()
    self.addCleanup(patcher.stop)
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
def test_softmax_scoring(self, mock_topk):
    """Test softmax scoring function"""
    # Fake the NPU top-k kernel with a deterministic triple; presumably
    # (weights, expert ids, row indices) — verify against the kernel docs.
    mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
                              torch.zeros(self.num_tokens,
                                          self.top_k,
                                          dtype=torch.long),
                              torch.arange(0,
                                           self.num_tokens * self.top_k,
                                           dtype=torch.int32).view(
                                               self.top_k,
                                               -1).permute(1,
                                                           0).contiguous())

    weights, ids = select_experts(hidden_states=self.hidden_states,
                                  router_logits=self.router_logits,
                                  top_k=self.top_k,
                                  use_grouped_topk=False,
                                  renormalize=False,
                                  scoring_func="softmax")

    # Both outputs are (num_tokens, top_k).
    self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
    self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
|
||||
def test_sigmoid_scoring(self):
|
||||
"""Test sigmoid scoring function"""
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="sigmoid",
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
|
||||
def test_invalid_scoring_func(self):
|
||||
"""Test invalid scoring function raises ValueError"""
|
||||
with self.assertRaises(ValueError):
|
||||
select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="invalid_func")
|
||||
|
||||
@patch('torch.topk')
|
||||
def test_grouped_topk(self, mock_topk):
|
||||
"""Test grouped topk functionality"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long))
|
||||
|
||||
weights, ids = select_experts(hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2)
|
||||
|
||||
mock_topk.assert_called()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk')
|
||||
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
|
||||
"""Test grouped topk with expert score correction bias"""
|
||||
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
|
||||
self.num_experts)
|
||||
|
||||
e_score_correction_bias = torch.randn(self.num_experts)
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=True,
|
||||
renormalize=False,
|
||||
topk_group=4,
|
||||
num_expert_group=2,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
mock_grouped_topk.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
|
||||
def test_custom_routing_function(self):
|
||||
"""Test custom routing function"""
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
custom_routing_function=self.mock_custom_routing)
|
||||
|
||||
self.mock_custom_routing.assert_called_once()
|
||||
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_renormalize(self, mock_topk):
|
||||
"""Test renormalization"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.long),
|
||||
torch.arange(0,
|
||||
self.num_tokens * self.top_k,
|
||||
dtype=torch.int32).view(
|
||||
self.top_k,
|
||||
-1).permute(1,
|
||||
0).contiguous())
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=True,
|
||||
)
|
||||
|
||||
# Check if weights are normalized (sum to 1 for each token)
|
||||
sums = weights.sum(dim=-1)
|
||||
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
|
||||
|
||||
@patch('torch_npu.npu_moe_gating_top_k')
|
||||
def test_output_dtypes(self, mock_topk):
|
||||
"""Test output dtypes"""
|
||||
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
|
||||
torch.zeros(self.num_tokens,
|
||||
self.top_k,
|
||||
dtype=torch.int32),
|
||||
torch.arange(0,
|
||||
self.num_tokens * self.top_k,
|
||||
dtype=torch.int32).view(
|
||||
self.top_k,
|
||||
-1).permute(1,
|
||||
0).contiguous())
|
||||
|
||||
weights, ids = select_experts(
|
||||
hidden_states=self.hidden_states,
|
||||
router_logits=self.router_logits,
|
||||
top_k=self.top_k,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
)
|
||||
|
||||
self.assertEqual(weights.dtype, self.hidden_states.dtype)
|
||||
self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
|
||||
class TestNativeGroupedTopkPartialMock(TestBase):
|
||||
|
||||
def test_basic_group_selection(self):
|
||||
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
|
||||
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1],
|
||||
[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
|
||||
[0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]],
|
||||
dtype=torch.float32)
|
||||
|
||||
expected_topk_indices = torch.tensor([[0, 1], [1, 0], [0, 1], [0, 1]])
|
||||
|
||||
with patch('torch.topk',
|
||||
return_value=(None, expected_topk_indices)) as mock_topk:
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=2)
|
||||
|
||||
mock_topk.assert_called_once()
|
||||
|
||||
expected_result = topk_weights
|
||||
self.assertTrue(torch.allclose(result, expected_result))
|
||||
|
||||
def test_partial_group_selection(self):
|
||||
|
||||
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
|
||||
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]])
|
||||
|
||||
expected_topk_indices = torch.tensor([[0], [1]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=2,
|
||||
topk_group=1)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.9, 0.1]])
|
||||
self.assertTrue(torch.allclose(result, expected_result))
|
||||
|
||||
def test_single_group(self):
|
||||
topk_weights = torch.tensor([[0.1, 0.9, 0.2], [0.8, 0.3, 0.7]])
|
||||
|
||||
expected_topk_indices = torch.tensor([[0], [0]])
|
||||
|
||||
with patch('torch.topk', return_value=(None, expected_topk_indices)):
|
||||
result = _native_grouped_topk(topk_weights=topk_weights,
|
||||
num_expert_group=1,
|
||||
topk_group=1)
|
||||
self.assertTrue(result.numel() > 0)
|
||||
|
||||
@@ -79,15 +79,6 @@ class AscendAttentionBackend(AttentionBackend):
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def get_bsh_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (2, num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: List[torch.Tensor],
|
||||
|
||||
@@ -166,10 +166,6 @@ class NPUPlatform(Platform):
|
||||
) if not isinstance(ascend_compilation_config, dict)
|
||||
else ascend_compilation_config)
|
||||
|
||||
kv_cache_dtype = vllm_config.additional_config.get(
|
||||
"kv_cache_dtype", None)
|
||||
if kv_cache_dtype is not None:
|
||||
vllm_config.cache_config.cache_dtype = kv_cache_dtype
|
||||
elif model_config and hasattr(model_config.hf_config, "index_topk"):
|
||||
vllm_config.cache_config.cache_dtype = str(
|
||||
model_config.dtype).replace("torch.", "")
|
||||
|
||||
@@ -134,9 +134,6 @@ class AscendQuantConfig(QuantizationConfig):
|
||||
'fa_quant_type' in self.quant_description.keys() and \
|
||||
self.quant_description['fa_quant_type'] is not None:
|
||||
return AscendKVCacheMethod(self, prefix)
|
||||
elif isinstance(layer, Attention) and self.quant_description.get(
|
||||
'kv_quant_type') == 'C8':
|
||||
return AscendKVCacheMethod(self, prefix)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
|
||||
@@ -9,8 +9,7 @@ from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
|
||||
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
|
||||
AscendW4A8DynamicLinearMethod)
|
||||
from .w4a16 import AscendW4A16FusedMoEMethod
|
||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod)
|
||||
from .w8a8 import AscendW8A8LinearMethod
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
|
||||
@@ -29,8 +28,6 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||
},
|
||||
"W8A8": {
|
||||
"linear": AscendW8A8LinearMethod,
|
||||
"moe": AscendW8A8FusedMoEMethod,
|
||||
"attention": AscendC8KVCacheMethod,
|
||||
},
|
||||
"W8A8_DYNAMIC": {
|
||||
"linear": AscendW8A8DynamicLinearMethod,
|
||||
@@ -39,10 +36,7 @@ ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
|
||||
"W8A8_MIX": {
|
||||
"linear": AscendW8A8PDMixLinearMethod,
|
||||
"moe": AscendW8A8PDMixFusedMoeMethod,
|
||||
},
|
||||
"C8": {
|
||||
"attention": AscendC8KVCacheMethod,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -100,9 +94,6 @@ def get_quant_method_modelslim(
|
||||
# Attention
|
||||
if '.attn' in prefix and 'fa_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['fa_quant_type']
|
||||
# Use KVCache int8
|
||||
elif '.attn' in prefix and 'kv_quant_type' in quant_description.keys():
|
||||
quant_type = quant_description['kv_quant_type']
|
||||
# Linear
|
||||
else:
|
||||
quant_type = get_linear_quant_type(quant_description, prefix,
|
||||
|
||||
@@ -15,16 +15,12 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
|
||||
COMPRESSED_TENSORS_METHOD, AscendDeviceType,
|
||||
get_ascend_device_type, is_enable_nz)
|
||||
@@ -205,509 +201,3 @@ class AscendW8A8LinearMethod:
|
||||
deq_scale = layer.input_scale.data * layer.weight_scale.data
|
||||
layer.deq_scale = torch.nn.Parameter(deq_scale,
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
class AscendW8A8FusedMoEMethod:
|
||||
"""FusedMoe method for Ascend W8A8.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
@staticmethod
|
||||
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||
2 *
|
||||
intermediate_size_per_partition,
|
||||
hidden_sizes,
|
||||
dtype=torch.int8,
|
||||
requires_grad=False)
|
||||
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
intermediate_size_per_partition,
|
||||
dtype=torch.int8,
|
||||
requires_grad=False)
|
||||
return param_dict
|
||||
|
||||
@staticmethod
|
||||
def get_dynamic_quant_param(num_experts: int,
|
||||
intermediate_size_per_partition: int,
|
||||
hidden_sizes: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
param_dict = {}
|
||||
param_dict["w13_weight_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_weight_offset"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
1,
|
||||
dtype=torch.float16)
|
||||
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
1,
|
||||
dtype=torch.float16)
|
||||
param_dict["w2_deq_scale"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_deq_scale"] = torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_input_scale"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w13_input_scale"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.float32)
|
||||
param_dict["w2_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
param_dict["w13_input_offset"] = torch.empty(num_experts,
|
||||
1,
|
||||
dtype=torch.int8)
|
||||
param_dict["quant_bias"] = torch.empty(num_experts,
|
||||
hidden_sizes,
|
||||
dtype=torch.int32)
|
||||
|
||||
return param_dict
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
return fused_experts_310p(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w1_input_scale=layer.w13_input_scale,
|
||||
w1_input_offset=layer.w13_input_offset,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w2_input_scale=layer.w2_input_scale,
|
||||
w2_input_offset=layer.w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if get_ascend_device_type() != AscendDeviceType._310P:
|
||||
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
||||
layer.w13_weight_offset.data.shape[0], -1)
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||
layer.w2_weight_scale.data.shape[0], -1)
|
||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||
layer.w2_weight_offset.data.shape[0], -1)
|
||||
expanding_factor_w13 = layer.w13_weight.data.shape[1]
|
||||
expanding_factor_w2 = layer.w2_weight.data.shape[1]
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
layer.w13_input_scale.data = torch.nn.Parameter(
|
||||
layer.w13_input_scale.data.max())
|
||||
layer.w2_input_scale.data = torch.nn.Parameter(
|
||||
layer.w2_input_scale.data.max())
|
||||
else:
|
||||
layer.w13_input_scale.data = torch.nn.Parameter(
|
||||
layer.w13_input_scale.data.repeat(1,
|
||||
expanding_factor_w13)[0:1])
|
||||
layer.w2_input_scale.data = torch.nn.Parameter(
|
||||
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
||||
|
||||
layer.w13_input_offset.data = torch.nn.Parameter(
|
||||
layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
|
||||
layer.w2_input_offset.data = torch.nn.Parameter(
|
||||
layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])
|
||||
|
||||
# converting ACL_FORMAT_FRACTAL_NZ.
|
||||
# npu_quant_grouped_matmul_dequant in eager mode does not accept
|
||||
# ACL_FORMAT_FRACTAL_NZ.
|
||||
if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
|
||||
):
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
||||
|
||||
|
||||
class AscendC8KVCacheMethod:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.antiquant_scale_comb = None
|
||||
|
||||
@staticmethod
|
||||
def create_weights(layer) -> None:
|
||||
param_dict = {} # num_kv_heads * head_size
|
||||
param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
||||
layer.head_size,
|
||||
dtype=torch.float16,
|
||||
requires_grad=False)
|
||||
param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
|
||||
layer.head_size,
|
||||
dtype=torch.float16,
|
||||
requires_grad=False)
|
||||
for weight_name, weight_param in param_dict.items():
|
||||
param = torch.nn.Parameter(weight_param, requires_grad=False)
|
||||
layer.register_parameter(weight_name, param)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
self.antiquant_scale_comb = torch.cat(
|
||||
(layer.key_antiquant_scale.data.unsqueeze(0),
|
||||
layer.value_antiquant_scale.data.unsqueeze(0)),
|
||||
dim=0).to(torch.float16).contiguous()
|
||||
|
||||
def apply(self, layer, query, key, value, kv_cache, attn_metadata,
|
||||
attn_type, scale, output) -> torch.Tensor:
|
||||
num_tokens = query.shape[0]
|
||||
if attn_metadata is None:
|
||||
return output.view(num_tokens, layer.num_heads * layer.head_size)
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
# C8
|
||||
quant_key = quant_per_tensor(
|
||||
key.view(-1, layer.num_kv_heads * layer.head_size),
|
||||
layer.key_antiquant_scale.data.view(-1), None, True)
|
||||
quant_value = quant_per_tensor(
|
||||
value.view(-1, layer.num_kv_heads * layer.head_size),
|
||||
layer.value_antiquant_scale.data.view(-1), None, True)
|
||||
|
||||
# View q k v to BSH.
|
||||
query = query.view(-1, layer.num_heads, layer.head_size)
|
||||
key = key.view(-1, layer.num_kv_heads, layer.head_size)
|
||||
value = value.view(-1, layer.num_kv_heads, layer.head_size)
|
||||
# TODO: Remove this contiguous in the future.
|
||||
value = value.contiguous()
|
||||
|
||||
if kv_cache[0].numel() > 0:
|
||||
# if key_cache is None:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
slots = attn_metadata.slot_mapping
|
||||
|
||||
block_size = key_cache.shape[1]
|
||||
slots_indices = slots.reshape(-1, 1)
|
||||
block_indices = slots_indices // block_size
|
||||
slots_indices = slots_indices % block_size
|
||||
indices = torch.cat((block_indices, slots_indices), dim=1)
|
||||
|
||||
# C8
|
||||
torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
|
||||
torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)
|
||||
|
||||
# V0-Style scheduler situation.
|
||||
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
|
||||
assert attn_metadata is not None
|
||||
assert attn_metadata.attn_mask is not None
|
||||
mask = attn_metadata.attn_mask
|
||||
torch_npu._npu_flash_attention(query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
mask=mask,
|
||||
seq_len=attn_metadata.seq_lens,
|
||||
scale_value=scale,
|
||||
num_heads=layer.num_heads,
|
||||
num_kv_heads=layer.num_kv_heads,
|
||||
out=output.reshape(query.shape))
|
||||
|
||||
elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
|
||||
raise NotImplementedError("kv cache int8 are not "
|
||||
"implemented for "
|
||||
"PrefillCacheHit")
|
||||
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly: # changed attn_metadata.attn_state == AscendAttentionState.DecodeOnly
|
||||
if hasattr(attn_metadata, "decode"):
|
||||
# torch_air
|
||||
decode_meta = attn_metadata.decode
|
||||
seq_lens = decode_meta.seq_lens_list
|
||||
else:
|
||||
seq_lens = attn_metadata.seq_lens
|
||||
block_size = key_cache.shape[1]
|
||||
query = query.view(num_tokens, 1, layer.num_heads *
|
||||
layer.head_size).contiguous() # changed
|
||||
|
||||
# [num_blocks, block_size, N, D] --> [num_blocks, N, block_size, D]
|
||||
key = key_cache
|
||||
value = value_cache
|
||||
|
||||
output = torch_npu.npu_incre_flash_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
num_key_value_heads=layer.num_kv_heads,
|
||||
num_heads=layer.num_heads,
|
||||
actual_seq_lengths=seq_lens,
|
||||
scale_value=scale,
|
||||
input_layout='BSH',
|
||||
block_size=block_size,
|
||||
block_table=attn_metadata.block_tables,
|
||||
antiquant_scale=self.antiquant_scale_comb,
|
||||
)
|
||||
|
||||
# Normal V1 situation.
|
||||
else:
|
||||
raise NotImplementedError("kv cache int8 are not "
|
||||
"implemented for "
|
||||
"other case")
|
||||
return output
|
||||
|
||||
|
||||
def fused_experts_310p(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w1_input_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
ep_size = get_ep_group().world_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
||||
x=sorted_hidden_states,
|
||||
quantized_weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
group_list=group_list,
|
||||
x_scale=w1_input_scale,
|
||||
quant_mode="pertensor")
|
||||
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out = torch_npu.npu_quant_grouped_matmul_dequant(
|
||||
x=gate_up_out,
|
||||
quantized_weight=w2,
|
||||
weight_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
x_scale=w2_input_scale,
|
||||
quant_mode="pertensor")
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w1_input_scale: torch.Tensor,
|
||||
w1_input_offset: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w2_input_scale: torch.Tensor,
|
||||
w2_input_offset: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fused experts with top-k routing.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
"""
|
||||
# Check constraints.
|
||||
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
|
||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
|
||||
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
|
||||
"""
|
||||
|
||||
original_dtype = hidden_states.dtype
|
||||
ep_size = get_ep_group().world_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
w1_input_scale, _ = w1_input_scale.max(0)
|
||||
quant_sorted_hidden_states = quant_per_tensor(
|
||||
hidden_states,
|
||||
w1_input_scale,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
if expert_map is not None:
|
||||
expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
|
||||
quant_sorted_hidden_states,
|
||||
topk_ids,
|
||||
scale=None,
|
||||
active_num=topk_ids.numel(),
|
||||
expert_capacity=-1,
|
||||
expert_num=local_num_experts,
|
||||
drop_pad_mode=0,
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
quant_mode=-1,
|
||||
active_expert_range=[0, local_num_experts],
|
||||
row_idx_type=0,
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"The quantified version of MOE class models "
|
||||
"currently does not support tensor parallelism")
|
||||
if expanded_x.dtype != w1.dtype:
|
||||
w1_input_scale, _ = w1_input_scale.max(0)
|
||||
quant_sorted_hidden_states = quant_per_tensor(
|
||||
expanded_x,
|
||||
w1_input_scale,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
quant_sorted_hidden_states = expanded_x
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[quant_sorted_hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale * w1_input_scale[0]],
|
||||
split_item=2,
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=expert_token_count,
|
||||
output_dtype=original_dtype,
|
||||
)[0]
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
if gate_up_out.dtype != w2.dtype:
|
||||
w2_input_scale, _ = w2_input_scale.max(0)
|
||||
quant_gate_up_out = quant_per_tensor(
|
||||
gate_up_out,
|
||||
w2_input_scale,
|
||||
None,
|
||||
True,
|
||||
)
|
||||
else:
|
||||
quant_gate_up_out = gate_up_out
|
||||
|
||||
down_out = torch_npu.npu_grouped_matmul(
|
||||
x=[quant_gate_up_out],
|
||||
weight=[w2],
|
||||
scale=[w2_scale * w2_input_scale[0]],
|
||||
split_item=2,
|
||||
group_list_type=1,
|
||||
group_type=0,
|
||||
group_list=expert_token_count,
|
||||
output_dtype=original_dtype,
|
||||
)[0]
|
||||
|
||||
if expert_map is not None:
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
down_out,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights.to(down_out.dtype),
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
drop_pad_mode=2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"The quantified version of MOE class models "
|
||||
"currently does not support tensor parallelism")
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
@@ -2504,14 +2504,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
# the min of all `num_blocks`. Verify it here.
|
||||
assert num_blocks >= kv_cache_config.num_blocks
|
||||
|
||||
if self.vllm_config.additional_config.get(
|
||||
"kv_cache_dtype", None) == 'int8':
|
||||
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads,
|
||||
kv_cache_spec.head_size)
|
||||
elif hasattr(attn_backend, "get_supported_block_size"
|
||||
) and self.use_hybrid_blocks:
|
||||
if hasattr(attn_backend, "get_supported_block_size"
|
||||
) and self.use_hybrid_blocks:
|
||||
block_size = attn_backend.get_supported_block_size()[0]
|
||||
|
||||
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||
|
||||
Reference in New Issue
Block a user