[Feat] Support Kimi-K2-Thinking native W4A16 quantized expert weights (#4516)
### What this PR does / why we need it?
Adds W4A16 quantization method for the Kimi-K2-Thinking model and
updates relevant modules to support the new quantization method.
- Implements complete W4A16 quantization method including weight
packing/unpacking, per-group quantization parameter generation,
post-processing logic and MoE method application.
- Adds parameters `use_int4_w4a16`, `w1_offset` and `w2_offset`, adjusts
`with_quant` conditional logic to support W4A16 matrix multiplication.
- Adds `packed_modules_model_mapping` for Kimi-K2-Thinking model and
processing logic for `weight_packed` field.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
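
As background: W4A16 here means the expert weights are stored as signed int4 values with a per-group bf16 scale and offset (eight int4 values packed into one int32), while activations stay in bf16. The snippet below is only an illustrative sketch of that layout; the real helpers added by this PR are `pack_to_int32`/`unpack_from_int32` in `vllm_ascend/quantization/w4a16.py`, and the `quantize_w4a16_per_group`/`pack_int4_to_int32` names below are hypothetical.

```python
import torch


def quantize_w4a16_per_group(weight: torch.Tensor, group_size: int = 32):
    """Per-group asymmetric 4-bit quantization of a [n, k] weight (sketch)."""
    n, k = weight.shape
    assert k % group_size == 0
    w = weight.reshape(n, k // group_size, group_size).float()
    w_min = w.amin(dim=-1, keepdim=True)
    w_max = w.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-6) / 15.0     # 16 levels for 4 bits
    zero = (-w_min / scale).round().clamp(0, 15)       # unsigned zero-point
    q = (w / scale + zero).round().clamp(0, 15) - 8    # signed int4 codes in [-8, 7]
    offset = zero - 8                                  # dequant: (q - offset) * scale
    return q.to(torch.int8).reshape(n, k), scale.squeeze(-1), offset.squeeze(-1)


def pack_int4_to_int32(q_int4: torch.Tensor) -> torch.Tensor:
    """Pack 8 signed int4 values (stored in int8) into one int32, low nibble first."""
    n, k = q_int4.shape
    assert k % 8 == 0
    nib = (q_int4.to(torch.uint8) & 0xF).reshape(n, k // 2, 2)
    packed_bytes = (nib[..., 0] | (nib[..., 1] << 4)).to(torch.uint8)  # two nibbles per byte
    # reinterpret every 4 bytes as one int32 (assumes a little-endian host)
    return packed_bytes.reshape(n, k // 8, 4).view(torch.int32).squeeze(-1)
```

On NPU the real path relies on `torch_npu.npu_convert_weight_to_int4pack` and `npu_grouped_matmul` with `antiquant_scale`/`antiquant_offset`, as shown in the diff below.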
---------
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Signed-off-by: Ruri <33858552+zhoux77899@users.noreply.github.com>
Signed-off-by: Ruri <zhouxiang100@huawei.com>
.github/workflows/_e2e_test.yaml
@@ -269,6 +269,7 @@ jobs:
run: |
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_multistream_moe
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_Kimi_K2_Thinking_W4A16
# pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP
# pytest -sv tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_W8A8_WITH_EP
pytest -sv tests/e2e/multicard/test_data_parallel_tp2.py
@@ -12,6 +12,7 @@ single_npu_qwen3_w4a4
single_node_pd_disaggregation_mooncake
multi_npu_qwen3_next
multi_npu
multi_npu_kimi-k2-thinking
multi_npu_moge
multi_npu_qwen3_moe
multi_npu_quantization
docs/source/tutorials/multi_npu_kimi-k2-thinking.md
@@ -0,0 +1,107 @@
# Multi-NPU (Kimi-K2-Thinking)

## Run with Docker

```{code-block} bash
:substitutions:

# Update the vllm-ascend image
export IMAGE=m.daocloud.io/quay.io/ascend/vllm-ascend:|vllm_ascend_version|
export NAME=vllm-ascend

# Run the container using the defined variables
# Note: If you are running Docker with a bridge network, expose the ports needed for multi-node communication in advance
docker run --rm \
--name $NAME \
--net=host \
--shm-size=1g \
--device /dev/davinci0 \
--device /dev/davinci1 \
--device /dev/davinci2 \
--device /dev/davinci3 \
--device /dev/davinci4 \
--device /dev/davinci5 \
--device /dev/davinci6 \
--device /dev/davinci7 \
--device /dev/davinci8 \
--device /dev/davinci9 \
--device /dev/davinci10 \
--device /dev/davinci11 \
--device /dev/davinci12 \
--device /dev/davinci13 \
--device /dev/davinci14 \
--device /dev/davinci15 \
--device /dev/davinci_manager \
--device /dev/devmm_svm \
--device /dev/hisi_hdc \
-v /usr/local/dcmi:/usr/local/dcmi \
-v /usr/local/Ascend/driver/tools/hccn_tool:/usr/local/Ascend/driver/tools/hccn_tool \
-v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
-v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
-v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
-v /etc/ascend_install.info:/etc/ascend_install.info \
-v /mnt/sfs_turbo/.cache:/home/cache \
-it $IMAGE bash
```

## Verify the Quantized Model

Edit `config.json` of the original model downloaded from [Hugging Face](https://huggingface.co/moonshotai/Kimi-K2-Thinking) and change the value of `"quantization_config.config_groups.group_0.targets"` from `["Linear"]` to `["MoE"]`:

```json
{
    "quantization_config": {
        "config_groups": {
            "group_0": {
                "targets": [
                    "MoE"
                ]
            }
        }
    }
}
```
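
If you prefer to script this change, the following is a minimal sketch; the local path `Kimi-K2-Thinking/config.json` is an assumption and should point at wherever you downloaded the model:

```python
import json
from pathlib import Path

# Hypothetical local path; adjust to your downloaded model directory.
config_path = Path("Kimi-K2-Thinking/config.json")
config = json.loads(config_path.read_text())
config["quantization_config"]["config_groups"]["group_0"]["targets"] = ["MoE"]
config_path.write_text(json.dumps(config, indent=2) + "\n")
```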

Your model directory should look like this:

```bash
.
|-- chat_template.jinja
|-- config.json
|-- configuration_deepseek.py
|-- configuration.json
|-- generation_config.json
|-- model-00001-of-000062.safetensors
|-- ...
|-- model-00062-of-000062.safetensors
|-- model.safetensors.index.json
|-- modeling_deepseek.py
|-- tiktoken.model
|-- tokenization_kimi.py
`-- tokenizer_config.json
```

## Online Inference on Multi-NPU

Run the following command to start the vLLM server on Multi-NPU.

For an Atlas 800 A3 (64G*16) node, `--tensor-parallel-size` should be at least 16.

```bash
vllm serve Kimi-K2-Thinking \
--served-model-name kimi-k2-thinking \
--tensor-parallel-size 16 \
--enable_expert_parallel \
--trust-remote-code \
--no-enable-prefix-caching
```

Once your server is started, you can query the model with input prompts.

```bash
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
    "model": "kimi-k2-thinking",
    "messages": [
        {"role": "user", "content": "Who are you?"}
    ],
    "temperature": 1.0
}'
```
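
Alternatively, you can use the OpenAI-compatible Python client (a sketch that assumes the `openai` package is installed and the server above is reachable on port 8000):

```python
from openai import OpenAI

# vLLM exposes an OpenAI-compatible endpoint; the API key is not validated.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="kimi-k2-thinking",
    messages=[{"role": "user", "content": "Who are you?"}],
    temperature=1.0,
)
print(response.choices[0].message.content)
```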
@@ -49,6 +49,10 @@ DEEPSEEK_W4A8_MODELS = [
    "vllm-ascend/DeepSeek-V3.1-W4A8-puring"
]

KIMI_W4A16_MODELS = [
    "vllm-ascend/Kimi-K2-Thinking-Pruning",
]


def test_models_distributed_QwQ():
    example_prompts = [
@@ -250,3 +254,24 @@ def test_models_distributed_Qwen_Dense_with_prefetch_mlp_weight(model):
        quantization="ascend",
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)


@pytest.mark.parametrize("model", KIMI_W4A16_MODELS)
def test_models_distributed_Kimi_K2_Thinking_W4A16(model):
    example_prompts = [
        "Hello, my name is",
    ]
    max_tokens = 5

    with VllmRunner(
            model,
            max_model_len=8192,
            dtype="auto",
            tensor_parallel_size=4,
            enable_expert_parallel=True,
            compilation_config={
                "cudagraph_mode": "FULL_DECODE_ONLY",
                "cudagraph_capture_sizes": [1],
            },
    ) as vllm_model:
        vllm_model.generate_greedy(example_prompts, max_tokens)
tests/ut/quantization/test_w4a16.py
@@ -0,0 +1,269 @@
from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a16 import (AscendW4A16FusedMoEMethod,
                                            pack_to_int32, unpack_from_int32)


class TestUnpackFromInt32(TestBase):

    def test_unpack_from_int32_packed_dim_1(self):
        weight = torch.tensor([[305419896, -1420531520]], dtype=torch.int32)
        shape = torch.Size([1, 8])
        num_bits = 4

        result = unpack_from_int32(weight, shape, num_bits, packed_dim=1)

        self.assertEqual(result.dtype, torch.int8)
        self.assertEqual(result.shape, shape)

    def test_unpack_from_int32_packed_dim_0(self):
        weight = torch.tensor([[305419896], [-1420531520]], dtype=torch.int32)
        shape = torch.Size([8, 1])
        num_bits = 4

        result = unpack_from_int32(weight, shape, num_bits, packed_dim=0)

        self.assertEqual(result.dtype, torch.int8)
        self.assertEqual(result.shape, shape)

    def test_unpack_from_int32_assertions(self):
        with self.assertRaises(AssertionError):
            weight = torch.tensor([[1, 2]], dtype=torch.int64)
            unpack_from_int32(weight, torch.Size([8, 1]), 4)

        with self.assertRaises(AssertionError):
            weight = torch.tensor([[1, 2]], dtype=torch.int32)
            unpack_from_int32(weight, torch.Size([8, 1]), 16)


class TestPackToInt32(TestBase):

    @patch(
        "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
    )
    def test_pack_to_int32_int8(self, mock_npu_convert_weight_to_int4pack):
        mock_npu_convert_weight_to_int4pack.return_value = torch.zeros(
            (2, 4), dtype=torch.int32)

        weight = torch.zeros((2, 8, 16), dtype=torch.int8)
        result = pack_to_int32(weight)

        self.assertEqual(result.dtype, torch.int32)
        mock_npu_convert_weight_to_int4pack.assert_not_called()

        self.assertEqual(result.shape, torch.Size([2, 8, 4]))

    @patch(
        "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
    )
    def test_pack_to_int32_int32(self, mock_npu_convert_weight_to_int4pack):

        def mock_convert_weight(weight):
            return weight

        mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight
        weight = torch.zeros((2, 8, 8), dtype=torch.int32)
        result = pack_to_int32(weight)

        self.assertEqual(result.dtype, torch.int32)
        self.assertEqual(result.shape, weight.shape)

    def test_pack_to_int32_assertion_dim(self):
        with self.assertRaises(AssertionError):
            weight = torch.zeros((8, 8), dtype=torch.int8)
            pack_to_int32(weight)

    def test_pack_to_int32_assertion_dtype(self):
        with self.assertRaises(AssertionError):
            weight = torch.zeros((2, 8, 8), dtype=torch.float32)
            pack_to_int32(weight)

    def test_pack_to_int32_assertion_divisible(self):
        with self.assertRaises(AssertionError):
            weight = torch.zeros((2, 8, 7), dtype=torch.int32)
            pack_to_int32(weight)

        with self.assertRaises(AssertionError):
            weight = torch.zeros((2, 8, 7), dtype=torch.int8)
            pack_to_int32(weight)


class TestAscendW4A16FusedMoEMethod(TestBase):
    experts = 8
    input_size = 32
    output_size = 128
    group_size = 32

    @patch("vllm_ascend.quantization.w4a16.get_ascend_config")
    @patch("vllm_ascend.quantization.w4a16.get_current_vllm_config")
    def setUp(self, mock_get_current_vllm_config, mock_get_ascend_config):
        mock_ascend_config = Mock()
        mock_ascend_config.dynamic_eplb = False
        mock_ascend_config.expert_map_record_path = None
        mock_get_ascend_config.return_value = mock_ascend_config

        mock_vllm_config = Mock()
        mock_vllm_config.quant_config = Mock(quant_description={
            "group_size": self.group_size,
        })
        mock_get_current_vllm_config.return_value = mock_vllm_config

        self.quant_method = AscendW4A16FusedMoEMethod()

    def test_init(self):
        self.assertTrue(self.quant_method.transpose_weight)
        self.assertEqual(self.quant_method.num_bits, 4)
        self.assertEqual(self.quant_method.pack_factor, 8)
        self.assertEqual(self.quant_method.group_size, self.group_size)
        self.assertFalse(self.quant_method.dynamic_eplb)

    def test_get_weight(self):
        param_dict = self.quant_method.get_weight(self.experts,
                                                  self.input_size,
                                                  self.output_size,
                                                  torch.bfloat16)

        self.assertEqual(param_dict["w13_weight_packed"].dtype, torch.int32)
        expected_w13_shape = (self.experts, 2 * self.input_size,
                              self.output_size //
                              self.quant_method.pack_factor)
        self.assertEqual(param_dict["w13_weight_packed"].shape,
                         expected_w13_shape)

        self.assertEqual(param_dict["w2_weight_packed"].dtype, torch.int32)
        expected_w2_shape = (self.experts, self.output_size,
                             self.input_size // self.quant_method.pack_factor)
        self.assertEqual(param_dict["w2_weight_packed"].shape,
                         expected_w2_shape)

    def test_get_dynamic_quant_param(self):
        param_dict = self.quant_method.get_dynamic_quant_param(
            self.experts, self.input_size, self.output_size, torch.bfloat16)

        self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
        expected_w13_scale_shape = (self.experts, 2 * self.input_size,
                                    self.output_size // self.group_size)
        self.assertEqual(param_dict["w13_weight_scale"].shape,
                         expected_w13_scale_shape)

        self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
        expected_w2_scale_shape = (self.experts, self.output_size,
                                   self.input_size // self.group_size)
        self.assertEqual(param_dict["w2_weight_scale"].shape,
                         expected_w2_scale_shape)

        self.assertEqual(param_dict["w13_weight_shape"].dtype, torch.int32)
        self.assertEqual(param_dict["w13_weight_shape"].shape,
                         (self.experts, 2))

        self.assertEqual(param_dict["w2_weight_shape"].dtype, torch.int32)
        self.assertEqual(param_dict["w2_weight_shape"].shape,
                         (self.experts, 2))

        self.assertEqual(param_dict["w13_weight_offset"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w13_weight_offset"].shape,
                         expected_w13_scale_shape)

        self.assertEqual(param_dict["w2_weight_offset"].dtype, torch.bfloat16)
        self.assertEqual(param_dict["w2_weight_offset"].shape,
                         expected_w2_scale_shape)

    def build_layer(self):
        """Build a mock layer for testing"""
        layer = torch.nn.Module()

        w13_shape = (self.experts, 2 * self.input_size,
                     self.output_size // self.quant_method.pack_factor)
        w2_shape = (self.experts, self.output_size,
                    self.input_size // self.quant_method.pack_factor)

        layer.w13_weight_packed = torch.nn.Parameter(torch.randint(
            -100, 100, w13_shape, dtype=torch.int32),
                                                     requires_grad=False)
        layer.w2_weight_packed = torch.nn.Parameter(torch.randint(
            -100, 100, w2_shape, dtype=torch.int32),
                                                    requires_grad=False)

        w13_scale_shape = (self.experts, 2 * self.input_size,
                           self.output_size // self.group_size)
        w2_scale_shape = (self.experts, self.output_size,
                          self.input_size // self.group_size)

        layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
            w13_scale_shape, dtype=torch.bfloat16),
                                                    requires_grad=False)
        layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
            w2_scale_shape, dtype=torch.bfloat16),
                                                   requires_grad=False)

        layer.w13_weight_offset = torch.nn.Parameter(torch.zeros(
            w13_scale_shape, dtype=torch.bfloat16),
                                                     requires_grad=False)
        layer.w2_weight_offset = torch.nn.Parameter(torch.zeros(
            w2_scale_shape, dtype=torch.bfloat16),
                                                    requires_grad=False)

        layer.w13_weight_shape = torch.nn.Parameter(torch.tensor(
            [[2 * self.input_size, self.output_size]] * self.experts,
            dtype=torch.int32),
                                                    requires_grad=False)
        layer.w2_weight_shape = torch.nn.Parameter(torch.tensor(
            [[self.output_size, self.input_size]] * self.experts,
            dtype=torch.int32),
                                                   requires_grad=False)

        return layer

    @patch(
        "vllm_ascend.quantization.w4a16.torch_npu.npu_convert_weight_to_int4pack"
    )
    def test_process_weights_after_loading_with_transpose(
            self, mock_npu_convert_weight_to_int4pack):

        def mock_convert_weight(weight):
            new_shape = list(weight.shape)
            new_shape[-1] = new_shape[-1] // 8
            return torch.zeros(new_shape, dtype=torch.int32)

        mock_npu_convert_weight_to_int4pack.side_effect = mock_convert_weight

        layer = self.build_layer()
        self.quant_method.transpose_weight = True

        self.quant_method.process_weights_after_loading(layer)

        self.assertEqual(layer.w13_weight_packed.data.shape,
                         torch.Size([8, 128, 8]))
        self.assertEqual(layer.w2_weight_packed.data.shape,
                         torch.Size([8, 32, 16]))

        self.assertEqual(layer.w13_weight_scale.data.shape,
                         torch.Size([8, 4, 64]))
        self.assertEqual(layer.w2_weight_scale.data.shape,
                         torch.Size([8, 1, 128]))
        self.assertEqual(layer.w13_weight_offset.data.shape,
                         torch.Size([8, 4, 64]))
        self.assertEqual(layer.w2_weight_offset.data.shape,
                         torch.Size([8, 1, 128]))

        self.assertTrue(layer.w13_weight_scale.data.is_contiguous())
        self.assertTrue(layer.w2_weight_scale.data.is_contiguous())
        self.assertTrue(layer.w13_weight_offset.data.is_contiguous())
        self.assertTrue(layer.w2_weight_offset.data.is_contiguous())

    def test_process_weights_after_loading_without_transpose(self):
        layer = self.build_layer()
        self.quant_method.transpose_weight = False

        original_w13_data = layer.w13_weight_packed.data.clone()
        original_w2_data = layer.w2_weight_packed.data.clone()

        self.quant_method.process_weights_after_loading(layer)

        self.assertTrue(
            torch.equal(layer.w13_weight_packed.data, original_w13_data))
        self.assertTrue(
            torch.equal(layer.w2_weight_packed.data, original_w2_data))
@@ -93,12 +93,15 @@ class MoECommMethod(ABC):
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
@@ -147,9 +150,11 @@ class MoECommMethod(ABC):
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
topk_scales=topk_scales,
with_quant=use_int8_w8a8
or use_int4_w4a8,
or use_int4_w4a8 or use_int4_w4a16,
fusion=use_int8_w8a8,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
@@ -275,12 +280,15 @@ class FusedAlltoAllCommImpl(MoECommMethod):
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
@@ -68,9 +68,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
dynamic_scale: torch.Tensor = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
fusion: bool = False,
dynamic_eplb: bool = False) -> torch.Tensor:
if dynamic_scale is None:
if w1_offset is not None:
unquantized_hidden_states = hidden_states
quantized_hidden_states = None
elif dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
@@ -79,6 +84,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
dispose_tensor(unquantized_hidden_states)
quantized_hidden_states = None
else:
unquantized_hidden_states = None
pertoken_scale = dynamic_scale
quantized_hidden_states = hidden_states
@@ -90,7 +96,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and is_mc2:
if w1_scale_bias is None and w1_offset is None and is_mc2:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = (
@@ -149,6 +155,32 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=w2_scale[0].dtype)[0]
elif w1_offset is not None:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[unquantized_hidden_states],
weight=[w1],
antiquant_scale=[w1_scale],
antiquant_offset=[w1_offset],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
dispose_tensor(unquantized_hidden_states)
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
antiquant_scale=[w2_scale],
antiquant_offset=[w2_offset],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
@@ -269,6 +301,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
topk_scales: Optional[torch.Tensor] = None,
with_quant: bool = False,
fusion: bool = False,
@@ -286,6 +320,8 @@ def unified_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
fusion=fusion,
dynamic_eplb=dynamic_eplb)
else:
@@ -65,8 +65,8 @@ def _rope_forward_oot(
raise NotImplementedError(
"Batched rotary embedding is currently not supported on NPU.")
else:
if self.cos is not None and \
self.sin is not None:
if hasattr(self, "cos") and hasattr(self, "sin") and \
self.cos is not None and self.sin is not None:
# If cos and sin are generated outside, use npu_apply_rotary_pos_emb to avoid redundant calculation.
# This method requires head_size and rotary_dim equal 128 and neox_style is True
query = query.contiguous().view(1, query.shape[0], -1,
@@ -4,7 +4,8 @@ import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config)
@@ -16,8 +17,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target, is_activation_quantization_format,
should_ignore_layer)

from vllm_ascend.quantization.quant_config import (AscendLinearMethod,
from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.quantization.quant_config import (AscendFusedMoEMethod,
AscendLinearMethod,
AscendQuantConfig)
from vllm_ascend.quantization.w4a16 import AscendW4A16FusedMoEMethod
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
@@ -142,7 +146,7 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)

# choose quantization method
quant_method: LinearMethodBase = UnquantizedLinearMethod()
quant_method = UnquantizedLinearMethod()
if quant_scheme is not None:
layer.scheme = quant_scheme
ascend_quant_config = AscendQuantConfig(self.quant_description
@@ -150,6 +154,21 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
quant_method = AscendLinearMethod(ascend_quant_config, prefix,
None, layer)
return quant_method
if isinstance(layer, FusedMoE):
layer.ascend_quant_method = COMPRESSED_TENSORS_METHOD
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)

# choose quantization method
quant_method = AscendUnquantizedFusedMoEMethod(layer.moe_config)
if quant_scheme is not None:
layer.scheme = quant_scheme
ascend_quant_config = AscendQuantConfig(self.quant_description
or {})
quant_method = AscendFusedMoEMethod(
ascend_quant_config, prefix,
ascend_quant_config.packed_modules_mapping, layer)
return quant_method
return None

def get_scheme(self,
@@ -215,6 +234,10 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return AscendW8A8DynamicLinearMethod()

if weight_quant is not None:
if self._is_w4a16(weight_quant):
return AscendW4A16FusedMoEMethod()

raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
@@ -246,6 +269,10 @@ class AscendCompressedTensorsConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return is_8_bits and is_token and is_symmetric and is_dynamic

def _is_w4a16(self, weight_quant: QuantizationArgs) -> bool:
is_4_bits = weight_quant.num_bits == 4
return is_4_bits

def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
self.target_scheme_map)
@@ -65,6 +65,9 @@ class AscendQuantConfig(QuantizationConfig):
if "shared_head" in k:
new_k = k.replace(".shared_head.", ".")
extra_quant_dict[new_k] = self.quant_description[k]
if "weight_packed" in k:
new_k = k.replace("weight_packed", "weight")
extra_quant_dict[new_k] = self.quant_description[k]
self.quant_description.update(extra_quant_dict)

def __repr__(self) -> str:
@@ -200,7 +203,8 @@ packed_modules_model_mapping = {
"kimi_k2": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
},
"deepseek_v32": {
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -439,7 +443,9 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
per_group_param = [
"weight_scale_second", "weight_offset_second", "scale_bias"
]
] + ["weight_scale", "weight_offset"] if hasattr(
self.quant_method,
"group_size") and self.quant_method.group_size > 0 else []
dynamic_quant_param = self.quant_method.get_dynamic_quant_param(
num_experts, intermediate_size_per_partition, hidden_size,
params_dtype)
@@ -8,6 +8,7 @@ from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD
from .w4a4_flatquant_dynamic import AscendW4A4FlatQuantDynamicLinearMethod
from .w4a8_dynamic import (AscendW4A8DynamicFusedMoEMethod,
AscendW4A8DynamicLinearMethod)
from .w4a16 import AscendW4A16FusedMoEMethod
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
@@ -16,6 +17,9 @@ from .w8a8_pdmix import (AscendW8A8PDMixFusedMoeMethod,
AscendW8A8PDMixLinearMethod)

ASCEND_QUANTIZATION_METHOD_MAP: Dict[str, Dict[str, Type[Any]]] = {
"W4A16": {
"moe": AscendW4A16FusedMoEMethod,
},
"W4A8_DYNAMIC": {
"linear": AscendW4A8DynamicLinearMethod,
"moe": AscendW4A8DynamicFusedMoEMethod,
vllm_ascend/quantization/w4a16.py
@@ -0,0 +1,284 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional

import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe.experts_selector import select_experts


def unpack_from_int32(
    weight: torch.Tensor,
    shape: torch.Size,
    num_bits: int,
    packed_dim: int = 1,
) -> torch.Tensor:
    """
    Unpacks quantized weights from int32 format back to original bits.

    :param weight: The packed int32 tensor containing quantized weights
    :param shape: The original shape to restore
    :param num_bits: The number of bits used for quantization (<= 8)
    :param packed_dim: Dimension along which weights are packed (0 or 1), defaults to 1
    :return: Unpacked tensor with int8 dtype after applying offset correction
    """
    assert weight.dtype == torch.int32, f"Expecting `weight.dtype` is torch.int32 but got {weight.dtype}."
    assert num_bits <= 8, f"Expecting `num_bits` should not be larger than 8 but got {num_bits}."

    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1

    if packed_dim == 1:
        unpacked_weight = torch.zeros(
            (weight.shape[0], weight.shape[1] * pack_factor),
            device=weight.device,
            dtype=torch.int32,
        )
        for i in range(pack_factor):
            unpacked_weight[:, i::pack_factor] = (weight >>
                                                  (num_bits * i)) & mask
        original_row_size = int(shape[1])
        unpacked_weight = unpacked_weight[:, :original_row_size]
    else:
        unpacked_weight = torch.zeros(
            (weight.shape[0] * pack_factor, weight.shape[1]),
            device=weight.device,
            dtype=torch.int32,
        )
        for i in range(pack_factor):
            unpacked_weight[i::pack_factor, :] = (weight >>
                                                  (num_bits * i)) & mask
        original_row_size = int(shape[0])
        unpacked_weight = unpacked_weight[:original_row_size, :]

    offset = pow(2, num_bits) // 2
    unpacked_weight = (unpacked_weight - offset).to(torch.int8)

    return unpacked_weight


def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
    """
    Packs quantized weights into int32 format for storage.

    :param weight: The 3D tensor to pack, must be int8 or int32 dtype
    :return: Packed tensor with int32 dtype optimized for storage
    """
    assert weight.dim(
    ) == 3, f"Expecting `weight.dim()` is 3 ([e, n, k] or [e, k, n]) but got {weight.dim()}."
    assert weight.dtype in [
        torch.int8, torch.int32
    ], f"Expecting `weight.dtype` is torch.int8 or torch.int32 but got {weight.dtype}."

    if weight.dtype == torch.int32:
        assert weight.shape[
            -1] % 8 == 0, "the last dim of weight needs to be divided by 8."
        packed_weight = torch_npu.npu_convert_weight_to_int4pack(
            weight.flatten(0, 1))
        packed_weight = packed_weight.view(weight.shape[0], weight.shape[1],
                                           -1)
    else:
        assert weight.shape[
            -1] % 4 == 0, "the last dim of weight needs to be divided by 4."
        packed_weight = weight.view(torch.int32).contiguous()

    return packed_weight


class AscendW4A16FusedMoEMethod:
    """FusedMoe method for Ascend W4A16.
    """

    def __init__(self) -> None:
        self.transpose_weight = True
        self.num_bits = 4  # dtype = torch.int4
        self.pack_factor = 8  # pack 8 of torch.int4 tensors to torch.int32

        vllm_config = get_current_vllm_config()
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 32)
        ascend_config = get_ascend_config()
        self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path

    def get_weight(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        assert intermediate_size_per_partition % self.pack_factor == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `pack_factor` {self.pack_factor}"
        assert hidden_sizes % self.pack_factor == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `pack_factor` {self.pack_factor}"

        param_dict = {}

        param_dict["w13_weight_packed"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.pack_factor,
            dtype=torch.int32)
        param_dict["w2_weight_packed"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.pack_factor,
            dtype=torch.int32)

        return param_dict

    def get_dynamic_quant_param(
        self,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_sizes: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        assert intermediate_size_per_partition % self.group_size == 0, f"Expecting `intermediate_size_per_partition` {intermediate_size_per_partition} can be divided by `group_size` {self.group_size}"
        assert hidden_sizes % self.group_size == 0, f"Expecting `hidden_sizes` {hidden_sizes} can be divided by `group_size` {self.group_size}"

        param_dict = {}

        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_scale"] = torch.empty(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w13_weight_shape"] = torch.empty(num_experts,
                                                     2,
                                                     dtype=torch.int32)
        param_dict["w2_weight_shape"] = torch.empty(num_experts,
                                                    2,
                                                    dtype=torch.int32)
        param_dict["w13_weight_offset"] = torch.zeros(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_sizes // self.group_size,
            dtype=torch.bfloat16)
        param_dict["w2_weight_offset"] = torch.zeros(
            num_experts,
            hidden_sizes,
            intermediate_size_per_partition // self.group_size,
            dtype=torch.bfloat16)

        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        topk_ids = topk_ids.to(torch.int32)
        topk_weights = topk_weights.to(x.dtype)

        moe_comm_method = get_forward_context().moe_comm_method
        return moe_comm_method.fused_experts(
            hidden_states=x,
            w1=layer.w13_weight_packed,
            w2=layer.w2_weight_packed,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            w1_offset=layer.w13_weight_offset,
            w2_offset=layer.w2_weight_offset,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int4_w4a16=True,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            quantized_x_for_share=quantized_x_for_share,
            dynamic_scale_for_share=dynamic_scale_for_share,
            dynamic_eplb=self.dynamic_eplb,
            mc2_mask=kwargs.get("mc2_mask", None))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.transpose_weight:
            w13_shape = layer.w13_weight_packed.data.shape
            w2_shape = layer.w2_weight_packed.data.shape
            unpacked_w13_weight = (unpack_from_int32(
                layer.w13_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w13_shape[0] * w13_shape[1],
                    w13_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w13_shape[0], w13_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            unpacked_w2_weight = (unpack_from_int32(
                layer.w2_weight_packed.data.flatten(0, 1),
                torch.Size([
                    w2_shape[0] * w2_shape[1], w2_shape[2] * self.pack_factor
                ]),
                self.num_bits,
            ).view(w2_shape[0], w2_shape[1],
                   -1).transpose(1, 2).contiguous().int())
            layer.w13_weight_packed.data = pack_to_int32(unpacked_w13_weight)
            layer.w2_weight_packed.data = pack_to_int32(unpacked_w2_weight)

        layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(
            1, 2).contiguous()
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(
            1, 2).contiguous()

        layer.w13_weight_offset.data = layer.w13_weight_offset.data.transpose(
            1, 2).contiguous()
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.transpose(
            1, 2).contiguous()
@@ -3471,13 +3471,13 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
# as it only support the 0-dim of kv_cache is `num_blocks`.
# For deepseek mla, we need to split cache tensor according to the nope head dim
# and rope head dim.
if self.model_config.is_deepseek_mla:
if self.model_config.use_mla:
head_size = self.model_config.hf_text_config.qk_rope_head_dim + \
self.model_config.hf_text_config.kv_lora_rank

dsa_k_cache_factor = None
dsa_k_cache_size = None
if not self.model_config.is_deepseek_mla:
if not self.model_config.use_mla:
# for non-mla model, use FullAttentionSpec
k_tensor_split_factor = 2
v_tensor_split_factor = 2
@@ -3627,7 +3627,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
if not self.model_config.is_deepseek_mla:
if not self.model_config.use_mla:
k_shape = kv_cache_shape[1:]
v_shape = k_shape
else: