[main][Feature] Support deepseek w4a8 quantization (#2172)
### What this PR does / why we need it?
Supports Deepseek-R1 w4a8 quantization.
Since R1 w4a8 uses mixed quantization, only the MOE layer uses
w4a8_dynamic quantization, so we added the w4a8_dynamic.py file, which
includes the AscendW4A8DynamicFusedMoEMethod class.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Adding ut case in `tests/ut/quantization/test_w4a8_dynamic.py` and
`tests/ut/quantization/test_quantizer.py`
Adding e2e case in
`tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC`
to test deepseek w4a8_dynamic quantized model
#### 1. How to get weights using ModelSlim
##### Installation steps
Use the master branch; the commit id is:
298e175d69b3b855111a1e09bbe2fcd12fdb4e24
git clone https://gitee.com/ascend/msit.git
cd msit/msmodelslim
bash install.sh
##### The required transformers environment
transformers>=4.48.2
##### Generate w4a8 weights
cd /example/DeepSeek
Command reference: msmodelslim/example/DeepSeek/README.md. Execute the
[pre-check](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#%E8%BF%90%E8%A1%8C%E5%89%8D%E5%BF%85%E6%A3%80)
and [DeepSeek-R1 w4a8 mix
quantization](https://gitee.com/ascend/msit/blob/master/msmodelslim/example/DeepSeek/README.md#deepseek-r1-w4a8-%E6%B7%B7%E5%90%88%E9%87%8F%E5%8C%96%E5%89%8D%E4%B8%89%E5%B1%82-mlpw8a8-dynamic-%E9%87%8F%E5%8C%96mla%E5%85%B1%E4%BA%AB%E4%B8%93%E5%AE%B6w8a8%E9%87%8F%E5%8C%96%E8%B7%AF%E7%94%B1%E4%B8%93%E5%AE%B6w4a8-dynamic%E9%87%8F%E5%8C%96)
chapter
Reference command: python3 quant_deepseek_w4a8.py --model_path {original
weight path} --save_path {generated weight path} --mindie_format
##### Adapt to vllm-ascend
Since --mindie_format produces weights in MindIE format, a few
adaptations are needed before vllm-ascend can use them:
`quant_model_description_w8a8_dynamic.json` rename to
`quant_model_description.json`, and add `"group_size": 256`
Modification in `config.json`:`"model_type":deepseekv2` is changed to
`"model_type":deepseek_v3`; `quantization_config` is removed;
Tip: the group_size must match the weights. If the w4a8 weights were not
generated using msmodelslim, you can check the group_size in the
quantization_config section of config.json.
#### 2. How to run w4a8
##### a.How to run eager mode
export VLLM_USE_V1=1 # v1
python -m vllm.entrypoints.openai.api_server --model=$1
--trust-remote-code -tp $2 -dp $3 --enable_expert_parallel
--quantization ascend --port $4 --max-model-len $5 --max-num-seqs $6
--enforce-eager
eg: python -m vllm.entrypoints.openai.api_server
--model=/weightpath/w4a8_4_layer --trust-remote-code -tp 4 -dp 4
--enable_expert_parallel --quantization ascend --port 8002
--max-model-len 5120 --max-num-seqs 128 --enforce-eager
##### b.How to run graph mode
export VLLM_USE_V1=1 # v1
export HCCL_BUFFSIZE=1024
python -m vllm.entrypoints.openai.api_server --model=$1
--trust-remote-code -tp $2 -dp $3 --enable_expert_parallel
--quantization ascend --port $4 --max-model-len $5
--additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
eg: python -m vllm.entrypoints.openai.api_server
--model=/weight/dsr1_w4a8_vllm --trust-remote-code -tp 4 -dp 4
--enable_expert_parallel --quantization ascend --port 8002
--max-model-len 5120
--additional_config='{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":true}}'
- vLLM version: v0.10.0
- vLLM main:
c494f96fbc
---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test.yaml
vendored
1
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -283,6 +283,7 @@ jobs:
|
|||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeekV3_dbo
|
||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_alltoallv
|
||||||
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Qwen3_W4A8DYNAMIC
|
||||||
|
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
|
||||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||||
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
|
pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
|
||||||
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
|
--ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
|
||||||
|
|||||||
@@ -209,3 +209,28 @@ def test_models_distributed_Qwen3_W4A8DYNAMIC():
|
|||||||
quantization="ascend",
|
quantization="ascend",
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
@patch.dict(os.environ, {"VLLM_ASCEND_MLA_PA": "1"})
|
||||||
|
def test_models_distributed_DeepSeek_W4A8DYNAMIC():
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is",
|
||||||
|
]
|
||||||
|
max_tokens = 5
|
||||||
|
with VllmRunner(
|
||||||
|
snapshot_download("vllm-ascend/DeepSeek-R1-w4a8-pruning"),
|
||||||
|
dtype="auto",
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
quantization="ascend",
|
||||||
|
enforce_eager=True,
|
||||||
|
enable_expert_parallel=True,
|
||||||
|
additional_config={
|
||||||
|
"torchair_graph_config": {
|
||||||
|
"enabled": False,
|
||||||
|
},
|
||||||
|
"ascend_scheduler_config": {
|
||||||
|
"enabled": True,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
) as vllm_model:
|
||||||
|
vllm_model.generate_greedy(prompts, max_tokens)
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||||
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
||||||
|
W4A8DYNAMICQuantizer,
|
||||||
W8A8Quantizer)
|
W8A8Quantizer)
|
||||||
|
|
||||||
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
||||||
@@ -120,3 +121,25 @@ class TestW8A8Quantizer(TestBase):
|
|||||||
result = self.quantizer.build_attention_method()
|
result = self.quantizer.build_attention_method()
|
||||||
mock_linear.assert_called_once_with()
|
mock_linear.assert_called_once_with()
|
||||||
self.assertIsInstance(result, MagicMock)
|
self.assertIsInstance(result, MagicMock)
|
||||||
|
|
||||||
|
|
||||||
|
class TestW4A8DYNAMICQuantizer(TestBase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.quantizer = W4A8DYNAMICQuantizer(quant_description={})
|
||||||
|
|
||||||
|
def test_build_linear_method(self):
|
||||||
|
with patch(
|
||||||
|
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod',
|
||||||
|
return_value=MagicMock()) as mock_linear:
|
||||||
|
result = self.quantizer.build_linear_method()
|
||||||
|
mock_linear.assert_called_once_with()
|
||||||
|
self.assertIsInstance(result, MagicMock)
|
||||||
|
|
||||||
|
def test_build_moe_method(self):
|
||||||
|
with patch(
|
||||||
|
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod',
|
||||||
|
return_value=MagicMock()) as mock_fused_moe:
|
||||||
|
result = self.quantizer.build_moe_method()
|
||||||
|
mock_fused_moe.assert_called_once_with()
|
||||||
|
self.assertIsInstance(result, MagicMock)
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend.quantization.w4a8_dynamic import AscendW4A8DynamicLinearMethod
|
from vllm_ascend.quantization.w4a8_dynamic import (
|
||||||
|
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
|
||||||
|
|
||||||
|
|
||||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
|
class TestAscendW4A8DynamicLinearMethod(TestBase):
|
||||||
@@ -25,3 +28,82 @@ class TestAscendW4A8DynamicLinearMethod(TestBase):
|
|||||||
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
|
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
|
||||||
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
|
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
|
||||||
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
|
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
|
||||||
|
|
||||||
|
|
||||||
|
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
|
||||||
|
|
||||||
|
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
|
||||||
|
@patch("vllm_ascend.ascend_config.get_ascend_config")
|
||||||
|
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
|
||||||
|
@patch('torch.distributed.get_rank', return_value=0)
|
||||||
|
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
|
||||||
|
mock_get_ep_group):
|
||||||
|
mock_ascend_config = Mock()
|
||||||
|
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
|
||||||
|
mock_get_ascend_config.return_value = mock_ascend_config
|
||||||
|
self.quant_method = AscendW4A8DynamicFusedMoEMethod()
|
||||||
|
|
||||||
|
def test_get_weight(self):
|
||||||
|
param_dict = self.quant_method.get_weight(8, 4, 14, torch.bfloat16)
|
||||||
|
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
|
||||||
|
self.assertEqual(param_dict["w13_weight"].shape, (8, 8, 14))
|
||||||
|
|
||||||
|
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
|
||||||
|
def test_get_dynamic_quant_param(self, mock_get_current_vllm_config):
|
||||||
|
mock_vllm_config = Mock()
|
||||||
|
mock_vllm_config.quant_config = Mock(
|
||||||
|
quant_description={"group_size": 2})
|
||||||
|
mock_get_current_vllm_config.return_value = mock_vllm_config
|
||||||
|
param_dict = self.quant_method.get_dynamic_quant_param(
|
||||||
|
8, 4, 14, torch.bfloat16)
|
||||||
|
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
|
||||||
|
self.assertEqual(param_dict["w13_weight_scale"].shape, (8, 8, 1))
|
||||||
|
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
|
||||||
|
torch.bfloat16)
|
||||||
|
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
|
||||||
|
(8, 8, 7))
|
||||||
|
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
|
||||||
|
self.assertEqual(param_dict["w2_weight_scale"].shape, (8, 14, 1))
|
||||||
|
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
|
||||||
|
torch.bfloat16)
|
||||||
|
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
|
||||||
|
(8, 14, 2))
|
||||||
|
|
||||||
|
@patch('torch_npu.npu_quantize')
|
||||||
|
@patch('torch.Tensor.npu')
|
||||||
|
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
|
||||||
|
layer = torch.nn.Module()
|
||||||
|
layer.w13_weight = torch.nn.Parameter(torch.zeros((8, 8, 14),
|
||||||
|
dtype=torch.int8),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight = torch.nn.Parameter(torch.zeros((8, 14, 4),
|
||||||
|
dtype=torch.int8),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
|
(8, 8, 1), dtype=torch.bfloat16),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
|
||||||
|
(8, 8, 1), dtype=torch.bfloat16),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||||
|
(8, 8, 7), dtype=torch.bfloat16),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
|
||||||
|
(8, 14, 1), dtype=torch.bfloat16),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
|
||||||
|
(8, 14, 1), dtype=torch.bfloat16),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
|
||||||
|
(8, 14, 2), dtype=torch.bfloat16),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
mock_npu.return_value = torch.Tensor()
|
||||||
|
mock_npu_quantize.return_value = torch.Tensor()
|
||||||
|
self.quant_method.process_weights_after_loading(layer)
|
||||||
|
self.assertTrue(hasattr(layer, "w13_scale_bias"))
|
||||||
|
self.assertEqual(layer.w13_scale_bias.data.shape, (8, 8))
|
||||||
|
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
|
||||||
|
self.assertTrue(hasattr(layer, "w2_scale_bias"))
|
||||||
|
self.assertEqual(layer.w2_scale_bias.data.shape, (8, 14))
|
||||||
|
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
|
||||||
|
|||||||
@@ -905,6 +905,8 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
|||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
if "module" in name:
|
||||||
|
continue
|
||||||
|
|
||||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||||
if spec_layer is not None:
|
if spec_layer is not None:
|
||||||
|
|||||||
@@ -302,6 +302,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
|||||||
param = torch.nn.Parameter(param_value, requires_grad=False)
|
param = torch.nn.Parameter(param_value, requires_grad=False)
|
||||||
layer.register_parameter(param_key, param)
|
layer.register_parameter(param_key, param)
|
||||||
set_weight_attrs(param, extra_weight_attrs)
|
set_weight_attrs(param, extra_weight_attrs)
|
||||||
|
if "weight_scale_second" in param_key or "weight_offset_second" in param_key:
|
||||||
|
setattr(param, "quant_method",
|
||||||
|
FusedMoeWeightScaleSupported.GROUP.value)
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
@@ -348,4 +351,4 @@ class AscendEmbeddingMethod(AscendLinearMethod):
|
|||||||
packed_modules_mapping: Dict[str, Any]) -> None:
|
packed_modules_mapping: Dict[str, Any]) -> None:
|
||||||
self.quantizer = AscendQuantizer.get_quantizer(
|
self.quantizer = AscendQuantizer.get_quantizer(
|
||||||
quant_config.quant_description, prefix, packed_modules_mapping)
|
quant_config.quant_description, prefix, packed_modules_mapping)
|
||||||
self.quant_method = self.quantizer.build_linear_method()
|
self.quant_method = self.quantizer.build_linear_method()
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ from vllm.logger import logger
|
|||||||
|
|
||||||
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
|
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
|
||||||
wrapper_vocab_parallel_embedding_init)
|
wrapper_vocab_parallel_embedding_init)
|
||||||
from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
|
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
|
||||||
|
AscendW4A8DynamicLinearMethod)
|
||||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||||
AscendW8A8LinearMethod)
|
AscendW8A8LinearMethod)
|
||||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||||
@@ -97,12 +98,15 @@ class VLLMAscendQuantizer:
|
|||||||
if target_function is not None:
|
if target_function is not None:
|
||||||
setattr(original_module, target_function, candidate)
|
setattr(original_module, target_function, candidate)
|
||||||
|
|
||||||
for key, value in sys.modules.copy().items():
|
for _, value in sys.modules.copy().items():
|
||||||
if (target_function is not None
|
if target_function is None:
|
||||||
and hasattr(value, target_function)
|
continue
|
||||||
and id(getattr(value,
|
try:
|
||||||
target_function)) == original_function_id):
|
attr = getattr(value, target_function, None)
|
||||||
setattr(value, target_function, candidate)
|
if attr is not None and id(attr) == original_function_id:
|
||||||
|
setattr(value, target_function, candidate)
|
||||||
|
except ImportError:
|
||||||
|
continue
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_path(module_path, function_name, create_dummy):
|
def parse_path(module_path, function_name, create_dummy):
|
||||||
@@ -268,6 +272,10 @@ class W4A8DYNAMICQuantizer(VLLMAscendQuantizer):
|
|||||||
def build_linear_method():
|
def build_linear_method():
|
||||||
return AscendW4A8DynamicLinearMethod()
|
return AscendW4A8DynamicLinearMethod()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_moe_method():
|
||||||
|
return AscendW4A8DynamicFusedMoEMethod()
|
||||||
|
|
||||||
|
|
||||||
class W8A8Quantizer(VLLMAscendQuantizer):
|
class W8A8Quantizer(VLLMAscendQuantizer):
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,22 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
from vllm.distributed import get_ep_group
|
||||||
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||||
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
|
from vllm_ascend.ops.fused_moe import select_experts
|
||||||
|
from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all,
|
||||||
|
fused_experts_with_mc2)
|
||||||
|
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||||
|
|
||||||
|
|
||||||
class AscendW4A8DynamicLinearMethod:
|
class AscendW4A8DynamicLinearMethod:
|
||||||
@@ -111,3 +122,275 @@ class AscendW4A8DynamicLinearMethod:
|
|||||||
layer.register_parameter("weight_scale_bias", param)
|
layer.register_parameter("weight_scale_bias", param)
|
||||||
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
layer.weight.data = torch_npu.npu_convert_weight_to_int4pack(
|
||||||
layer.weight.data.to(torch.int32))
|
layer.weight.data.to(torch.int32))
|
||||||
|
|
||||||
|
|
||||||
|
class AscendW4A8DynamicFusedMoEMethod:
|
||||||
|
"""FusedMoe method for Ascend W4A8_DYNAMIC.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.transpose_weight = True
|
||||||
|
|
||||||
|
self.ep_group = get_ep_group()
|
||||||
|
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
|
|
||||||
|
try:
|
||||||
|
device_group = get_mc2_group().device_group
|
||||||
|
# TODO: Try local_rank = ep_group.rank_in_group
|
||||||
|
local_rank = torch.distributed.get_rank(group=device_group)
|
||||||
|
backend = device_group._get_backend(torch.device("npu"))
|
||||||
|
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
||||||
|
local_rank)
|
||||||
|
except AttributeError:
|
||||||
|
self.moe_all_to_all_group_name = ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
||||||
|
hidden_sizes: int,
|
||||||
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||||
|
param_dict = {}
|
||||||
|
param_dict["w13_weight"] = torch.empty(num_experts,
|
||||||
|
2 *
|
||||||
|
intermediate_size_per_partition,
|
||||||
|
hidden_sizes,
|
||||||
|
dtype=torch.int8)
|
||||||
|
param_dict["w2_weight"] = torch.empty(num_experts,
|
||||||
|
hidden_sizes,
|
||||||
|
intermediate_size_per_partition,
|
||||||
|
dtype=torch.int8)
|
||||||
|
return param_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_dynamic_quant_param(num_experts: int,
|
||||||
|
intermediate_size_per_partition: int,
|
||||||
|
hidden_sizes: int,
|
||||||
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||||
|
param_dict = {}
|
||||||
|
config = get_current_vllm_config()
|
||||||
|
group_size = config.quant_config.quant_description.get(
|
||||||
|
"group_size", 256)
|
||||||
|
|
||||||
|
param_dict["w13_weight_scale"] = torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
1,
|
||||||
|
dtype=params_dtype)
|
||||||
|
|
||||||
|
param_dict["w13_weight_offset"] = torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
1,
|
||||||
|
dtype=params_dtype)
|
||||||
|
|
||||||
|
param_dict["w13_weight_scale_second"] = torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_sizes // group_size,
|
||||||
|
dtype=params_dtype)
|
||||||
|
|
||||||
|
param_dict["w13_weight_offset_second"] = torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
hidden_sizes // group_size,
|
||||||
|
dtype=params_dtype)
|
||||||
|
|
||||||
|
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
||||||
|
hidden_sizes,
|
||||||
|
1,
|
||||||
|
dtype=params_dtype)
|
||||||
|
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
||||||
|
hidden_sizes,
|
||||||
|
1,
|
||||||
|
dtype=params_dtype)
|
||||||
|
param_dict["w2_weight_scale_second"] = torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_sizes,
|
||||||
|
intermediate_size_per_partition // group_size,
|
||||||
|
dtype=params_dtype)
|
||||||
|
param_dict["w2_weight_offset_second"] = torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_sizes,
|
||||||
|
intermediate_size_per_partition // group_size,
|
||||||
|
dtype=params_dtype)
|
||||||
|
|
||||||
|
return param_dict
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
is_prefill: bool = True,
|
||||||
|
enable_force_load_balance: bool = True,
|
||||||
|
log2phy: torch.Tensor = None,
|
||||||
|
global_redundant_expert_num: int = 0,
|
||||||
|
shared_experts: Optional[Any] = None,
|
||||||
|
quantized_x_for_share: Optional[Any] = None,
|
||||||
|
dynamic_scale_for_share: Optional[Any] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert router_logits.shape[
|
||||||
|
1] == global_num_experts, "Number of global experts mismatch"
|
||||||
|
|
||||||
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||||
|
if global_num_experts == 256:
|
||||||
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||||
|
router_logits,
|
||||||
|
k=top_k, # topk currently is 8
|
||||||
|
bias=e_score_correction_bias,
|
||||||
|
k_group=topk_group, # fix: 4
|
||||||
|
group_count=num_expert_group, # fix 8
|
||||||
|
group_select_mode=
|
||||||
|
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||||
|
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||||
|
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||||
|
# out_flag=False, # todo new api; should the third output be output
|
||||||
|
# y2_flag=False, # old api; should the third output be output
|
||||||
|
routed_scaling_factor=1,
|
||||||
|
eps=float(1e-20))
|
||||||
|
else:
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
top_k=top_k,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
fused_moe_state = get_forward_context().fused_moe_state
|
||||||
|
shared_gate_up, shared_dequant_scale = None, None
|
||||||
|
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||||
|
with npu_stream_switch("moe_secondary", 0):
|
||||||
|
npu_wait_tensor(quantized_x_for_share, router_logits)
|
||||||
|
share_up_out, _ = shared_experts.gate_up_proj(
|
||||||
|
(quantized_x_for_share, dynamic_scale_for_share))
|
||||||
|
shared_gate_up, shared_dequant_scale = share_up_out[
|
||||||
|
0], share_up_out[1]
|
||||||
|
|
||||||
|
# this is a naive implementation for experts load balance so as
|
||||||
|
# to avoid accumulating too much tokens on a single rank.
|
||||||
|
# currently it is only activated when doing profile runs.
|
||||||
|
if enable_force_load_balance:
|
||||||
|
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||||
|
|
||||||
|
topk_weights = topk_weights.to(x.dtype)
|
||||||
|
if fused_moe_state == FusedMoEState.MC2:
|
||||||
|
return fused_experts_with_mc2(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
w1_scale=layer.w13_weight_scale_second,
|
||||||
|
w2_scale=layer.w2_weight_scale_second,
|
||||||
|
w1_scale_bias=layer.w13_scale_bias,
|
||||||
|
w2_scale_bias=layer.w2_scale_bias,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
top_k=top_k,
|
||||||
|
expert_map=expert_map,
|
||||||
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||||
|
log2phy=log2phy,
|
||||||
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
|
shared_experts=shared_experts,
|
||||||
|
is_torchair=self.torchair_graph_enabled,
|
||||||
|
quantized_x_for_share=shared_gate_up,
|
||||||
|
dynamic_scale_for_share=shared_dequant_scale,
|
||||||
|
mc2_mask=kwargs.get("mc2_mask", None))
|
||||||
|
else:
|
||||||
|
# The current implementation of deepseek moe splits hidden_states
|
||||||
|
# according to tp_size before they are feed into fused_moe module.
|
||||||
|
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||||
|
# dispatch/combine tokens.
|
||||||
|
return fused_experts_with_all2all(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
w1_scale=layer.w13_weight_scale_second,
|
||||||
|
w2_scale=layer.w2_weight_scale_second,
|
||||||
|
w1_scale_bias=layer.w13_scale_bias,
|
||||||
|
w2_scale_bias=layer.w2_scale_bias,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
top_k=top_k,
|
||||||
|
expert_map=expert_map,
|
||||||
|
ep_group=self.ep_group,
|
||||||
|
log2phy=log2phy,
|
||||||
|
global_redundant_expert_num=global_redundant_expert_num,
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||||
|
group_num, k, n = weight.shape
|
||||||
|
per_group_scale = per_group_scale.reshape(group_num, -1, n)
|
||||||
|
group_num, quantgroup_num, n = per_group_scale.shape
|
||||||
|
weight_high = weight.to(torch.float32).reshape([group_num, quantgroup_num, -1, n]) * \
|
||||||
|
per_group_scale.reshape([group_num, quantgroup_num, 1, n])
|
||||||
|
weight_high = weight_high.reshape([group_num, k, n])
|
||||||
|
bias = 8 * (weight_high.to(torch.float32) * scale).sum(axis=1)
|
||||||
|
scale_fp32 = (scale * per_group_scale).to(torch.float16).to(
|
||||||
|
torch.float32)
|
||||||
|
scale_fp32_np = scale_fp32.cpu().numpy()
|
||||||
|
scale_fp32_np.dtype = np.uint32
|
||||||
|
sscale_uint64 = np.zeros((group_num, quantgroup_num, n * 2),
|
||||||
|
dtype=np.uint32)
|
||||||
|
|
||||||
|
sscale_uint64[..., ::2] = scale_fp32_np
|
||||||
|
|
||||||
|
sscale_uint64_buffer = np.frombuffer(sscale_uint64.tobytes(),
|
||||||
|
dtype=np.int64).copy()
|
||||||
|
sscale_uint64_tensor = torch.from_numpy(sscale_uint64_buffer).reshape(
|
||||||
|
group_num, quantgroup_num, n)
|
||||||
|
sscale_uint64_tensor = sscale_uint64_tensor.npu()
|
||||||
|
return sscale_uint64_tensor, bias
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer):
|
||||||
|
if self.transpose_weight:
|
||||||
|
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
||||||
|
layer.w13_weight_offset.data.shape[0], -1)
|
||||||
|
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||||
|
layer.w2_weight_offset.data.shape[0], -1)
|
||||||
|
layer.w13_weight_scale_second.data = layer.w13_weight_scale_second.data.transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
layer.w2_weight_scale_second.data = layer.w2_weight_scale_second.data.transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
|
||||||
|
layer.w13_weight_scale_second.data, bias = self.process_scale(
|
||||||
|
layer.w13_weight, layer.w13_weight_scale.data,
|
||||||
|
layer.w13_weight_scale_second.data)
|
||||||
|
param = torch.nn.Parameter(bias, requires_grad=False)
|
||||||
|
layer.register_parameter("w13_scale_bias", param)
|
||||||
|
layer.w2_weight_scale_second.data, bias1 = self.process_scale(
|
||||||
|
layer.w2_weight, layer.w2_weight_scale.data,
|
||||||
|
layer.w2_weight_scale_second.data)
|
||||||
|
param = torch.nn.Parameter(bias1, requires_grad=False)
|
||||||
|
layer.register_parameter("w2_scale_bias", param)
|
||||||
|
|
||||||
|
layer.w13_weight.data = torch_npu.npu_quantize(
|
||||||
|
layer.w13_weight.data.to(torch.float32),
|
||||||
|
torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False)
|
||||||
|
layer.w2_weight.data = torch_npu.npu_quantize(
|
||||||
|
layer.w2_weight.data.to(torch.float32),
|
||||||
|
torch.tensor([1.]).npu(), None, torch.quint4x2, -1, False)
|
||||||
|
|||||||
@@ -116,7 +116,9 @@ def apply_mlp(hidden_states: torch.Tensor,
|
|||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
group_list: torch.Tensor,
|
group_list: torch.Tensor,
|
||||||
dynamic_scale: torch.Tensor = None,
|
dynamic_scale: torch.Tensor = None,
|
||||||
group_list_type: int = 1) -> torch.Tensor:
|
group_list_type: int = 1,
|
||||||
|
w1_scale_bias: torch.Tensor = None,
|
||||||
|
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||||
|
|
||||||
@@ -150,17 +152,31 @@ def apply_mlp(hidden_states: torch.Tensor,
|
|||||||
else:
|
else:
|
||||||
pertoken_scale = dynamic_scale
|
pertoken_scale = dynamic_scale
|
||||||
|
|
||||||
|
bias1, bias2 = None, None
|
||||||
|
_output_dtype = w2_scale.dtype
|
||||||
|
|
||||||
|
if w1_scale_bias is not None:
|
||||||
|
if group_list_type == 0:
|
||||||
|
group_list = torch.cat(
|
||||||
|
[group_list[:1], torch.diff(group_list, dim=0)])
|
||||||
|
group_list_type = 1
|
||||||
|
bias1 = [w1_scale_bias]
|
||||||
|
bias2 = [w2_scale_bias]
|
||||||
|
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||||
|
_output_dtype = torch.bfloat16
|
||||||
|
|
||||||
# gmm1: gate_up_proj
|
# gmm1: gate_up_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[w1],
|
weight=[w1],
|
||||||
scale=[w1_scale],
|
scale=[w1_scale],
|
||||||
|
bias=bias1,
|
||||||
per_token_scale=[pertoken_scale],
|
per_token_scale=[pertoken_scale],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
output_dtype=w2_scale.dtype)[0]
|
output_dtype=_output_dtype)[0]
|
||||||
|
|
||||||
# act_fn: swiglu
|
# act_fn: swiglu
|
||||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||||
@@ -172,12 +188,13 @@ def apply_mlp(hidden_states: torch.Tensor,
|
|||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
weight=[w2],
|
weight=[w2],
|
||||||
scale=[w2_scale],
|
scale=[w2_scale],
|
||||||
|
bias=bias2,
|
||||||
per_token_scale=[swiglu_out_scale],
|
per_token_scale=[swiglu_out_scale],
|
||||||
split_item=2,
|
split_item=2,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
output_dtype=w2_scale.dtype)[0]
|
output_dtype=_output_dtype)[0]
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -202,6 +219,8 @@ def fused_experts_with_mc2(
|
|||||||
mc2_mask: Optional[torch.Tensor] = None,
|
mc2_mask: Optional[torch.Tensor] = None,
|
||||||
shared_gate_up: Optional[Any] = None,
|
shared_gate_up: Optional[Any] = None,
|
||||||
shared_dequant_scale: Optional[Any] = None,
|
shared_dequant_scale: Optional[Any] = None,
|
||||||
|
w1_scale_bias: torch.Tensor = None,
|
||||||
|
w2_scale_bias: torch.Tensor = None,
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
assert mc2_mask is not None
|
assert mc2_mask is not None
|
||||||
if log2phy is not None:
|
if log2phy is not None:
|
||||||
@@ -270,13 +289,25 @@ def fused_experts_with_mc2(
|
|||||||
shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
|
shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
|
||||||
|
|
||||||
# `expand_x` will be disposed in the `apply_mlp` function
|
# `expand_x` will be disposed in the `apply_mlp` function
|
||||||
down_out_list = apply_mlp_decode(expand_x,
|
if w1_scale_bias is None:
|
||||||
w1,
|
down_out_list = apply_mlp_decode(expand_x,
|
||||||
w1_scale,
|
w1,
|
||||||
w2,
|
w1_scale,
|
||||||
w2_scale,
|
w2,
|
||||||
expert_token_nums,
|
w2_scale,
|
||||||
dynamic_scale=dynamic_scale)
|
expert_token_nums,
|
||||||
|
dynamic_scale=dynamic_scale)
|
||||||
|
else:
|
||||||
|
# w4a8 scene, cannot use apply_mlp_decode because the operator is not supported
|
||||||
|
down_out_list = apply_mlp(expand_x,
|
||||||
|
w1,
|
||||||
|
w1_scale,
|
||||||
|
w2,
|
||||||
|
w2_scale,
|
||||||
|
expert_token_nums,
|
||||||
|
dynamic_scale=dynamic_scale,
|
||||||
|
w1_scale_bias=w1_scale_bias,
|
||||||
|
w2_scale_bias=w2_scale_bias)
|
||||||
|
|
||||||
# moeCombine
|
# moeCombine
|
||||||
kwargs_mc2 = {
|
kwargs_mc2 = {
|
||||||
@@ -372,6 +403,8 @@ def fused_experts_with_all2all(
|
|||||||
ep_group: GroupCoordinator = None,
|
ep_group: GroupCoordinator = None,
|
||||||
log2phy: torch.Tensor = None,
|
log2phy: torch.Tensor = None,
|
||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
|
w1_scale_bias: torch.Tensor = None,
|
||||||
|
w2_scale_bias: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
if log2phy is not None:
|
if log2phy is not None:
|
||||||
topk_ids = log2phy[topk_ids]
|
topk_ids = log2phy[topk_ids]
|
||||||
@@ -457,7 +490,9 @@ def fused_experts_with_all2all(
|
|||||||
w2_scale,
|
w2_scale,
|
||||||
expert_tokens, #16
|
expert_tokens, #16
|
||||||
dynamic_scale=dynamic_scale,
|
dynamic_scale=dynamic_scale,
|
||||||
group_list_type=group_list_type)
|
group_list_type=group_list_type,
|
||||||
|
w1_scale_bias=w1_scale_bias,
|
||||||
|
w2_scale_bias=w2_scale_bias)
|
||||||
|
|
||||||
if expert_map is not None:
|
if expert_map is not None:
|
||||||
reordered_outputs = torch.index_select(
|
reordered_outputs = torch.index_select(
|
||||||
|
|||||||
Reference in New Issue
Block a user