[refactor] refactor weight trans nz and transpose (#4878)

### What this PR does / why we need it?

Now `VLLM_ASCEND_ENABLE_NZ` will have three options:
0: disable nz;
1: only quant case enable nz;
2: enable nz as long as possible;

And `VLLM_ASCEND_ENABLE_NZ`=1 by default.

All cases are shown in the table below:

|  | W4A4 | W4A8 | W8A8 | fp16/bf16 | fp32 |
|---|---|---|---|---|---|
| trans nz | can't support nz | trans nz by default | trans nz by
default | trans nz when VLLM_ASCEND_ENABLE_NZ is 2 | can't support nz |
| transpose | only support not transpose case | only support transpose
case | only support transpose case | linear: only support not transpose
case<br>gmm: only support transpose case | same to fp16/bf16 |

Some exceptional cases:
1. MLAPO op need to do some additional processing on the weights,
including trans nz. If use MLAPO op, some weight will be transformed to
nz forcely;
2. MLA/SFA's weight `W_UV` will be used by op
`torch.ops._C_ascend.batch_matmul_transpose`, and this op can't support
nz currently;

### Does this PR introduce _any_ user-facing change?
Now fp16/bf16 weight will not trans nz by default.

### How was this patch tested?

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-12-19 14:27:24 +08:00
committed by GitHub
parent ea8f544ce7
commit cc23067f1e
19 changed files with 156 additions and 255 deletions

View File

@@ -4,7 +4,8 @@ from unittest.mock import MagicMock, patch
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import init_ascend_config
@@ -972,16 +973,13 @@ class TestAscendMLAImpl(TestBase):
def test_process_weights_after_loading(self, mock_format_cast):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
apply = MagicMock()
quant_method.apply = apply
quant_method = MagicMock(spec=UnquantizedLinearMethod)
layer.quant_method = quant_method
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
self.impl.v_head_dim)
shape_1 = self.impl.kv_lora_rank
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
mock_format_cast.return_value = layer.weight
self.impl.process_weights_after_loading(torch.bfloat16)

View File

@@ -1,3 +1,4 @@
import os
import unittest
from unittest import mock
from unittest.mock import MagicMock, patch
@@ -61,22 +62,24 @@ class TestAscendUnquantizedLinearMethod(TestBase):
mock_dtype = mock.PropertyMock(return_value=torch.float16)
type(self.layer.weight.data).dtype = mock_dtype
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
@mock.patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_enable_nz(self, mock_format_cast,
mock_is_nz):
mock_is_nz.return_value = 1
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_called_once()
@mock.patch("vllm_ascend.ops.linear.is_enable_nz")
@mock.patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_disable_nz(self, mock_format_cast,
mock_is_nz):
mock_is_nz.return_value = 0
def test_process_weights_after_loading_with_nz0(self, mock_format_cast):
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_not_called()
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
@mock.patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_with_nz1(self, mock_format_cast):
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_not_called()
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
@mock.patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_with_nz2(self, mock_format_cast):
self.method.process_weights_after_loading(self.layer)
mock_format_cast.assert_called_once()
class TestAscendRowParallelLinear(BaseLinearTest):

View File

@@ -199,7 +199,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
(self.output_size, self.input_size // 8),
dtype=torch.int32)
mock_pack_weights.return_value = mock_packed
self.method.transpose_weight = False
self.method.process_weights_after_loading(layer)
mock_pack_weights.assert_called_once()
self.assertFalse(hasattr(layer, 'weight'))
@@ -212,35 +211,6 @@ class TestW4A4FlatQuantDynamic(unittest.TestCase):
self.assertEqual(layer.left_trans.shape, (24, 24))
self.assertTrue(layer.left_trans.is_contiguous())
@patch('vllm_ascend.quantization.w4a4_flatquant_dynamic.pack_int4_weights')
def test_process_weights_after_loading_with_transpose(
self, mock_pack_weights):
"""Tests weight processing after loading, with transpose."""
layer = nn.Module()
layer.weight = torch.randint(-8,
7, (self.output_size, self.input_size),
dtype=torch.int8)
layer.weight_scale = torch.randn(self.output_size,
1,
dtype=torch.bfloat16)
layer.weight_offset = torch.randn(self.output_size,
1,
dtype=torch.bfloat16)
layer.left_trans = torch.randn(24, 24)
layer.right_trans = torch.randn(32, 32)
layer.clip_ratio = torch.tensor([0.9])
mock_packed = torch.randint(0,
100,
(self.output_size, self.input_size // 8),
dtype=torch.int32)
mock_pack_weights.return_value = mock_packed
self.method.transpose_weight = True
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, 'weight_packed'))
self.assertEqual(layer.weight_packed.shape,
(self.input_size // 8, self.output_size))
self.assertTrue(layer.weight_packed.is_contiguous())
if __name__ == '__main__':
unittest.main(argv=['first-arg-is-ignored'], exit=False)

View File

@@ -62,7 +62,8 @@ class TestAscendW4A8DynamicLinearMethod(TestBase):
@patch('torch_npu.npu_convert_weight_to_int4pack')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu,
@patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading(self, mock_format_cast, mock_npu,
mock_npu_convert_weight):
mock_npu.side_effect = lambda: torch.zeros(
(1, 32), dtype=torch.float32)
@@ -85,6 +86,8 @@ class TestAscendW4A8DynamicLinearMethod(TestBase):
layer.weight_offset_second = torch.nn.Parameter(torch.empty_like(
layer.weight_scale_second.data),
requires_grad=False)
mock_format_cast.return_value = layer.weight.data.transpose(
0, 1).contiguous()
self.method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "weight_scale_bias"))
self.assertEqual(layer.weight_scale_bias.data.shape, (32, ))
@@ -110,6 +113,8 @@ class TestAscendW4A8DynamicLinearMethod(TestBase):
new_layer.scale_bias = torch.nn.Parameter(torch.zeros(
(32, 1), dtype=torch.float32),
requires_grad=False)
mock_format_cast.return_value = new_layer.weight.data.transpose(
0, 1).contiguous()
self.method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.scale_bias.data.shape, (32, ))
self.assertTrue(hasattr(new_layer, "weight_scale_second"))

View File

@@ -1,3 +1,4 @@
import os
from unittest.mock import MagicMock, patch
import torch
@@ -132,20 +133,21 @@ class TestAscendW8A8LinearMethod(TestBase):
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"})
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
mock_is_nz):
def test_process_weights_after_loading_with_nz0(self,
mock_npu_format_cast):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
layer.weight.data = torch.randint(-127,
128, (128, 256),
dtype=torch.int8)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)
mock_is_nz.return_value = 0
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)
@@ -160,20 +162,50 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_not_called()
@patch("vllm_ascend.quantization.w8a8.is_enable_nz")
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"})
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
mock_is_nz):
def test_process_weights_after_loading_with_nz1(self,
mock_npu_format_cast):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
layer.weight.data = torch.randint(-127,
128, (128, 256),
dtype=torch.int8)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)
expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
self.assertTrue(
torch.equal(layer.aclnn_input_offset.data, expected_offset))
self.assertFalse(layer.aclnn_input_offset.requires_grad)
self.assertFalse(layer.deq_scale.requires_grad)
self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_called_once()
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"})
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading_with_nz2(self,
mock_npu_format_cast):
layer = MagicMock()
layer.weight.data = torch.randint(-127,
128, (128, 256),
dtype=torch.int8)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)
mock_is_nz.return_value = 1
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)

View File

@@ -35,14 +35,6 @@ class TestUtils(TestBase):
from vllm_ascend import platform
importlib.reload(platform)
def test_is_enable_nz(self):
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
1):
self.assertTrue(utils.is_enable_nz())
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
0):
self.assertFalse(utils.is_enable_nz())
def test_nd_to_nz_2d(self):
# can be divided by 16
input_tensor = torch.randn(32, 64)

View File

@@ -14,8 +14,7 @@ from vllm.distributed import (get_decode_context_model_parallel_rank,
get_pcp_group)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.utils.math_utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import MLAAttentionSpec
@@ -38,8 +37,8 @@ from vllm_ascend.ops.shared_weight_layer import (
register_layer_to_shared_weight_series)
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
flashcomm2_o_shared_enabled, is_enable_nz,
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND,
flashcomm2_o_shared_enabled, maybe_trans_nz,
weak_ref_tensors)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -796,40 +795,11 @@ class AscendMLAImpl(MLAAttentionImpl):
return ql_nope.transpose(0, 1), q_pe
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
try:
return getattr(layer, attr)
except AttributeError:
pass
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
# Weight will be reshaped next. To be on the safe side, the format
# of the weight should be reverted to FRACTAL_AND.
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_ND)
return layer.weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
# NOTE: Weight will be reshaped next, we need to revert and transpose it.
kv_b_proj_weight = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
@@ -852,15 +822,8 @@ class AscendMLAImpl(MLAAttentionImpl):
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
# Function `get_and_maybe_dequant_weights` will cast the weights to
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
if is_enable_nz():
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
# self.W_UV = maybe_trans_nz(self.W_UV)
if self.enable_mlapo:
# Currently mlapo only supports W8A8 quantization in MLA scenario
@@ -875,6 +838,9 @@ class AscendMLAImpl(MLAAttentionImpl):
"thus mlapo is disabled for these layers.")
if self.enable_mlapo:
self._process_weights_for_fused_mlapo(act_dtype)
else:
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
if self.fc2_o_shared_enable and is_hidden_layer(
self.vllm_config, self.o_proj):

View File

@@ -9,7 +9,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.linear import (LinearBase, ReplicatedLinear,
from vllm.model_executor.layers.linear import (ReplicatedLinear,
UnquantizedLinearMethod)
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backends.utils import AttentionCGSupport
@@ -29,9 +29,8 @@ from vllm_ascend.ops.shared_weight_layer import (
from vllm_ascend.ops.triton.rope import rope_forward_triton
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
_round_up, dispose_layer, enable_sp,
is_enable_nz, replace_layer)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, _round_up, dispose_layer,
enable_sp, maybe_trans_nz, replace_layer)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
if TYPE_CHECKING:
@@ -404,40 +403,11 @@ class AscendSFAImpl(MLAAttentionImpl):
self.cp_size = 1
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
try:
return getattr(layer, attr)
except AttributeError:
pass
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
# Weight will be reshaped next. To be on the safe side, the format
# of the weight should be reverted to FRACTAL_AND.
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_ND)
return layer.weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
# NOTE: We currently do not support quant kv_b_proj.
assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
# NOTE: Weight will be reshaped next, we need to revert and transpose it.
kv_b_proj_weight = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.local_num_heads *
(self.qk_nope_head_dim + self.v_head_dim)), (
@@ -460,15 +430,9 @@ class AscendSFAImpl(MLAAttentionImpl):
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
# Function `get_and_maybe_dequant_weights` will cast the weights to
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
if is_enable_nz():
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
# TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz
# self.W_UV = maybe_trans_nz(self.W_UV)
# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
# Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory
dispose_layer(self.kv_b_proj)
@@ -502,6 +466,9 @@ class AscendSFAImpl(MLAAttentionImpl):
logger.warning_once(msg)
else:
self._process_weights_for_fused_mlapo(act_dtype)
if not self.enable_mlapo:
# if mlapo, W_UK_T can't trans nz
self.W_UK_T = maybe_trans_nz(self.W_UK_T)
def _v_up_proj(self, x):
forward_context = get_forward_context()

View File

@@ -123,7 +123,10 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
"VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
# Whether to enable weight cast format to FRACTAL_NZ.
# 0: close nz;
# 1: only quant case enable nz;
# 2: enable nz as long as possible.
"VLLM_ASCEND_ENABLE_NZ":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
# Decide whether we should enable CP parallelism.

View File

@@ -19,7 +19,6 @@ from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
@@ -48,8 +47,8 @@ from vllm_ascend.quantization.w4a8_dynamic import \
AscendW4A8DynamicFusedMoEMethod
from vllm_ascend.quantization.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendDeviceType,
enable_sp, get_ascend_device_type, is_enable_nz,
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
get_ascend_device_type, maybe_trans_nz,
npu_stream_switch, shared_expert_dp_enabled,
shared_experts_calculation_stream)
@@ -73,12 +72,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
if get_ascend_device_type() != AscendDeviceType._310P and is_enable_nz(
):
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
if get_ascend_device_type() != AscendDeviceType._310P:
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
def apply(self,
layer: torch.nn.Module,

View File

@@ -24,7 +24,6 @@ from typing import Optional, Union
import torch
import torch.nn as nn
import torch_npu
from torch.nn.parameter import Parameter
from vllm.config import get_current_vllm_config
from vllm.distributed import divide
@@ -37,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import \
from vllm.model_executor.utils import set_weight_attrs
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
from vllm_ascend.utils import maybe_trans_nz
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
@@ -45,11 +44,8 @@ class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
if "conv1d" not in layer.prefix and (
is_enable_nz() and layer.weight.data.dtype
in [torch.float16, torch.bfloat16]):
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
if "conv1d" not in layer.prefix:
layer.weight.data = maybe_trans_nz(layer.weight.data)
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group

View File

@@ -86,7 +86,6 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
input_size = 0
def __init__(self):
self.transpose_weight = False
self.sym = True
@staticmethod
@@ -176,9 +175,8 @@ class AscendW4A4FlatQuantDynamicLinearMethod:
return output
def process_weights_after_loading(self, layer):
# NOTE: Currently, w4a4 can't support weight nz
weight_packed = pack_int4_weights(layer.weight.data)
if self.transpose_weight:
weight_packed = weight_packed.transpose(0, 1).contiguous()
layer.register_parameter(
'weight_packed',
torch.nn.Parameter(weight_packed, requires_grad=False))

View File

@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
from vllm_ascend.utils import maybe_trans_nz
class AscendW4A8DynamicLinearMethod:
@@ -35,8 +35,6 @@ class AscendW4A8DynamicLinearMethod:
"""
def __init__(self):
self.transpose_weight = True
vllm_config = get_current_vllm_config()
self.group_size = vllm_config.quant_config.quant_description.get(
"group_size", 256)
@@ -170,8 +168,8 @@ class AscendW4A8DynamicLinearMethod:
)
def process_weights_after_loading(self, layer: torch.nn.Module):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = maybe_trans_nz(layer.weight.data)
layer.weight_scale.data = layer.weight_scale.data.flatten().to(
torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -214,8 +212,6 @@ class AscendW4A8DynamicFusedMoEMethod:
"""
def __init__(self):
self.transpose_weight = True
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
@@ -462,11 +458,10 @@ class AscendW4A8DynamicFusedMoEMethod:
torch.quint4x2, -1, False)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
w13_weight_scale_second = layer.w13_weight_scale_second.data if hasattr(
layer, "w13_weight_scale_second") else None
@@ -487,10 +482,7 @@ class AscendW4A8DynamicFusedMoEMethod:
self.update_bias(layer, w13_bias, w2_bias)
if is_enable_nz():
layer.w13_weight.data = torch_npu.npu_format_cast(
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w2_weight.data = torch_npu.npu_format_cast(
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)

View File

@@ -21,9 +21,8 @@ import torch
import torch_npu
from vllm.forward_context import get_forward_context
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ,
COMPRESSED_TENSORS_METHOD, AscendDeviceType,
get_ascend_device_type, is_enable_nz)
from vllm_ascend.utils import (COMPRESSED_TENSORS_METHOD, AscendDeviceType,
get_ascend_device_type, maybe_trans_nz)
def quant_per_tensor(in_tensor: torch.Tensor,
@@ -42,9 +41,7 @@ class AscendW8A8LinearMethod:
"""
def __init__(self) -> None:
# aclnn quant matmul requires to transpose matrix B, set to true by default.
self.transpose_weight = get_ascend_device_type(
) != AscendDeviceType._310P
pass
@staticmethod
def get_weight(
@@ -189,11 +186,9 @@ class AscendW8A8LinearMethod:
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor),
requires_grad=False).to(layer.aclnn_input_scale.dtype)
if self.transpose_weight:
if get_ascend_device_type() != AscendDeviceType._310P:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight.data = maybe_trans_nz(layer.weight.data)
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
ascend_quant_method = getattr(layer, "ascend_quant_method", "")

View File

@@ -29,7 +29,7 @@ from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
from vllm_ascend.utils import maybe_trans_nz
class AscendW8A8DynamicLinearMethod:
@@ -37,7 +37,7 @@ class AscendW8A8DynamicLinearMethod:
"""
def __init__(self):
self.transpose_weight = True
pass
@staticmethod
def get_weight(input_size: int, output_size: int,
@@ -91,12 +91,9 @@ class AscendW8A8DynamicLinearMethod:
return output
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
# cast quantized weight tensors in NZ format for higher inference speed
if is_enable_nz():
layer.weight.data = torch_npu.npu_format_cast(
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight.data = maybe_trans_nz(layer.weight.data)
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -107,8 +104,6 @@ class AscendW8A8DynamicFusedMoEMethod:
"""
def __init__(self):
self.transpose_weight = True
self.ep_group = get_ep_group()
vllm_config = get_current_vllm_config()
@@ -270,14 +265,12 @@ class AscendW8A8DynamicFusedMoEMethod:
mc2_mask=kwargs.get("mc2_mask", None))
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
if is_enable_nz():
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight.data = layer.w13_weight.data.transpose(
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(1,
2).contiguous()
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(

View File

@@ -122,8 +122,22 @@ def _unregister_print_streams_on_exit():
atexit.register(_unregister_print_streams_on_exit)
def is_enable_nz():
return envs_ascend.VLLM_ASCEND_ENABLE_NZ
def maybe_trans_nz(weight: torch.Tensor):
if not envs_ascend.VLLM_ASCEND_ENABLE_NZ:
# NZ is not enabled
return weight
if weight.dtype == torch.float:
# fp32 can not support NZ
return weight
elif weight.dtype in {torch.bfloat16, torch.float16}:
# bf16/fp16 will trans nz when VLLM_ASCEND_ENABLE_NZ is 2
if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
else:
return weight
else:
# quant weight will trans nz by default
return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
def _round_up(x: int, align: int):

View File

@@ -114,10 +114,10 @@ from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_enable_nz,
is_moe_model, lmhead_tp_enable, vllm_version_is)
from vllm_ascend.utils import (AscendDeviceType, ProfileExecuteDuration,
enable_sp, get_ascend_device_type, is_moe_model,
lmhead_tp_enable, maybe_trans_nz,
vllm_version_is)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
from vllm_ascend.ascend_forward_context import ( # isort: skip
@@ -137,9 +137,6 @@ torch.npu.config.allow_internal_format = True
if get_ascend_device_type() == AscendDeviceType._310P:
torch_npu.npu.set_compile_mode(jit_compile=False)
ACL_FORMAT = ACL_FORMAT_FRACTAL_NZ
else:
ACL_FORMAT = ACL_FORMAT_FRACTAL_ND
@dataclass
@@ -2225,16 +2222,6 @@ class NPUModelRunner(GPUModelRunner):
self.model = get_model(vllm_config=self.vllm_config)
if self.dynamic_eplb:
model_register(self.model, self.model_config)
if get_ascend_device_type() == AscendDeviceType._310P:
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear, QKVParallelLinear,
RowParallelLinear)
for module in self.model.modules():
if isinstance(module,
(MergedColumnParallelLinear,
QKVParallelLinear, RowParallelLinear)):
module.weight.data = self._convert_torch_format(
module.weight.data)
if self.drafter:
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
@@ -2255,13 +2242,6 @@ class NPUModelRunner(GPUModelRunner):
self.vllm_config,
runtime_mode=CUDAGraphMode.FULL)
def _convert_torch_format(self, tensor):
if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
and not is_enable_nz():
return tensor
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
return tensor
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize KV cache based on `kv_cache_config`.
@@ -2534,9 +2514,10 @@ class NPUModelRunner(GPUModelRunner):
self.model_config.hf_text_config.qk_rope_head_dim
]
k_cache = raw_k_tensor.view(dtype).view(k_shape)
k_cache = self._convert_torch_format(k_cache)
v_cache = raw_v_tensor.view(dtype).view(v_shape)
v_cache = self._convert_torch_format(v_cache)
if get_ascend_device_type() == AscendDeviceType._310P:
k_cache = maybe_trans_nz(k_cache)
v_cache = maybe_trans_nz(v_cache)
if self.use_sparse and raw_dsa_k_tensor is not None:
dsa_k_cache_shape = (num_blocks,
kv_cache_spec.block_size, 1, 128)

View File

@@ -55,7 +55,7 @@ from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (AscendDeviceType, check_ascend_device_type,
enable_sp, get_ascend_device_type, is_enable_nz,
enable_sp, get_ascend_device_type,
register_ascend_customop)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
@@ -160,7 +160,7 @@ class NPUWorker(WorkerBase):
used_bytes / GiB_bytes)
def wake_up(self, tags: Optional[list[str]] = None) -> None:
if is_enable_nz():
if envs_ascend.VLLM_ASCEND_ENABLE_NZ:
raise ValueError(
"FRACTAL_NZ mode is enabled. This may cause model parameter precision issues "
"in the RL scenarios. Please set VLLM_ASCEND_ENABLE_NZ=0.")

View File

@@ -26,10 +26,10 @@ from vllm.logger import logger
from vllm.sequence import IntermediateTensors
from xlite._C import AttnMHA, Model, ModelAttnMeta, ModelConfig, Runtime
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.utils import is_enable_nz
class XliteModel:
@@ -134,7 +134,7 @@ class LlamaXliteModel(XliteModel):
config.moe_tp_size = 1
config.attn_type = AttnMHA
config.weight_nz = is_enable_nz()
config.weight_nz = envs_ascend.VLLM_ASCEND_ENABLE_NZ
scheduler_config = vllm_config.scheduler_config
max_batch_size = scheduler_config.max_num_seqs
max_seq_len = vllm_config.model_config.max_model_len