[Feat] Unquantized Linear to nz and control all nz-cast (#3356)

### What this PR does / why we need it?
Currently, when the Linear layers of a model are executed in vLLM-Ascend,
the weight format is ND in both the unquantized case and the skipped-ascend
case (layers skipped by Ascend quantization). This PR supplements the
execution logic of the Linear layer with a new environment variable,
VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and the CANN version is
8.3, the weights of Linear layers are converted to FRACTAL_NZ in both of
these cases. VLLM_ASCEND_ENABLE_NZ also gates the existing NZ conversions,
such as the w8a8-quantized case.
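For reference, the new gate boils down to the following check (a minimal
sketch; in this PR the logic lives in
`AscendUnquantizedLinearMethod.process_weights_after_loading` — the helper
name `maybe_cast_to_nz` here is illustrative only):

```python
import torch
import torch_npu

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


def maybe_cast_to_nz(layer: torch.nn.Module) -> None:
    # Cast only when the user opted in via VLLM_ASCEND_ENABLE_NZ and the
    # installed CANN toolkit is an 8.3 release.
    if is_enable_nz() and torch.version.cann.startswith("8.3"):
        layer.weight.data = torch_npu.npu_format_cast(
            layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
```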

### Does this PR introduce _any_ user-facing change?
Adds a new environment variable, VLLM_ASCEND_ENABLE_NZ. If you want the NZ
format, set VLLM_ASCEND_ENABLE_NZ=1.
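For example (a hypothetical snippet; the variable only needs to be set
before vllm-ascend reads its environment variables):

```python
import os

# "1" enables the FRACTAL_NZ casts (currently the default), "0" disables them.
os.environ["VLLM_ASCEND_ENABLE_NZ"] = "1"

from vllm import LLM

llm = LLM(model="/path/to/model")
```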

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Commit: 07e39620ea (parent 5c45c227dc)
Author: anon189Ty
Committed: 2025-10-14 17:39:26 +08:00, via GitHub
22 changed files with 413 additions and 49 deletions

View File

@@ -376,7 +376,8 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(q_pe.shape[1], self.impl.num_heads)
         self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)

-    def test_process_weights_after_loading(self):
+    @patch('torch_npu.npu_format_cast')
+    def test_process_weights_after_loading(self, mock_format_cast):
         layer = MagicMock(spec=LinearBase)
         layer.input_size_per_partition = 10

         quant_method = MagicMock()
@@ -389,6 +390,7 @@ class TestAscendMLAImpl(TestBase):
         layer.weight = torch.randn(shape_0, shape_1)
         self.impl.kv_b_proj = layer
         apply.return_value = layer.weight.T
+        mock_format_cast.return_value = layer.weight
         self.impl.process_weights_after_loading(torch.bfloat16)

         self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)

View File

@@ -12,7 +12,7 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

 import pytest
 import torch
@@ -20,6 +20,7 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm_ascend import ascend_config
 from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention,
                                             CustomDeepseekV2RowParallelLinear)
@@ -46,6 +47,13 @@ def test_row_parallel_linear(cls, mock_distributed):
 def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
                                           mock_distributed, base_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
+    # Make a fake ascend config because of the AscendLinearBase
+    vllm_config = MagicMock()
+    vllm_config.additional_config = None
+    vllm_config.parallel_config.enable_expert_parallel = False
+    vllm_config.parallel_config.tensor_parallel_size = 1
+    vllm_config.kv_transfer_config = None
+    ascend_config.init_ascend_config(vllm_config)

     attn = CustomDeepseekV2MLAAttention(config=base_config,
                                         hidden_size=128,
@@ -78,6 +86,7 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
                                         kv_lora_rank=16,
                                         prefix="layers.1.self_attn")
     assert hasattr(attn, "q_proj")
+    ascend_config._ASCEND_CONFIG = None


 def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
@@ -90,6 +99,14 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
     config = SimpleConfig()

+    # Make a fake ascend config because of the AscendLinearBase
+    vllm_config = MagicMock()
+    vllm_config.additional_config = None
+    vllm_config.parallel_config.enable_expert_parallel = False
+    vllm_config.parallel_config.tensor_parallel_size = 1
+    vllm_config.kv_transfer_config = None
+    ascend_config.init_ascend_config(vllm_config)
+
     # Create the lmhead and logits_processor directly
     lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
     logits_processor = LogitsProcessor(config.vocab_size)
@@ -105,3 +122,4 @@ def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
                             return_value=mock_logits):
         logits = logits_processor(lmhead, mock_output)
     assert logits.shape == (2, 4, config.vocab_size)
+    ascend_config._ASCEND_CONFIG = None

View File

@@ -5,10 +5,13 @@ from unittest.mock import MagicMock, patch
 import torch

+from tests.ut.base import TestBase
 from vllm_ascend import ascend_config
 from vllm_ascend.distributed import parallel_state
-from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
-                                    AscendRowParallelLinear)
+from vllm_ascend.ops.linear import (AscendMergedColumnParallelLinear,
+                                    AscendReplicatedLinear,
+                                    AscendRowParallelLinear,
+                                    AscendUnquantizedLinearMethod)


 class BaseLinearTest(unittest.TestCase):
@@ -49,6 +52,47 @@ class BaseLinearTest(unittest.TestCase):
             p.stop()

+class TestAscendUnquantizedLinearMethod(TestBase):
+
+    def setUp(self):
+        self.method = AscendUnquantizedLinearMethod()
+
+    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
+    @mock.patch("torch_npu.npu_format_cast")
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_is_8_3_enable_nz(
+            self, mock_version, mock_format_cast, mock_is_nz):
+        layer = mock.MagicMock()
+        mock_version.cann = "8.3.RC1"
+        mock_is_nz.return_value = 1
+        self.method.process_weights_after_loading(layer)
+        mock_format_cast.assert_called_once()
+
+    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
+    @mock.patch("torch_npu.npu_format_cast")
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_is_8_3_disable_nz(
+            self, mock_version, mock_format_cast, mock_is_nz):
+        layer = mock.MagicMock()
+        mock_version.cann = "8.3.RC1"
+        mock_is_nz.return_value = 0
+        self.method.process_weights_after_loading(layer)
+        mock_format_cast.assert_not_called()
+
+    @mock.patch("vllm_ascend.ops.linear.is_enable_nz")
+    @mock.patch("torch.version")
+    def test_process_weights_after_loading_not_8_3(self, mock_version,
+                                                   mock_is_nz):
+        layer = mock.MagicMock()
+        mock_version.cann = "8.2.RC1"
+        mock_is_nz.return_value = 1
+        # Should not raise an exception
+        self.method.process_weights_after_loading(layer)
+
+
 class TestAscendRowParallelLinear(BaseLinearTest):

     def test_mlp_optimize(self):
@@ -92,5 +136,24 @@ class TestAscendMergedColumnParallelLinear(BaseLinearTest):
         self.assertEqual(linear.custom_op.comm_group, parallel_state._MLP_TP)

+class TestAscendReplicatedLinear(BaseLinearTest):
+
+    def test_init_disable_tp(self):
+        linear = AscendReplicatedLinear(
+            input_size=16,
+            output_size=8,
+        )
+        self.assertTrue(
+            isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
+
+    def test_init_without_disable_tp(self):
+        linear = AscendReplicatedLinear(
+            input_size=16,
+            output_size=8,
+        )
+        self.assertTrue(
+            isinstance(linear.quant_method, AscendUnquantizedLinearMethod))
+
+
 if __name__ == '__main__':
     unittest.main()

View File

@@ -4,10 +4,10 @@ import torch
 from vllm.attention.layer import Attention
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
-from vllm.model_executor.layers.linear import (LinearBase,
-                                               UnquantizedLinearMethod)
+from vllm.model_executor.layers.linear import LinearBase

 from tests.ut.base import TestBase
+from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
                                                    AscendQuantConfig)
 from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -82,7 +82,7 @@ class TestAscendQuantConfig(TestBase):
                 'is_layer_skipped_ascend',
                 return_value=True):
             method = self.ascend_config.get_quant_method(linear_layer, ".attn")
-            self.assertIsInstance(method, UnquantizedLinearMethod)
+            self.assertIsInstance(method, AscendUnquantizedLinearMethod)

         # Test quantized layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \

View File

@@ -137,8 +137,10 @@ class TestAscendW8A8LinearMethod(TestBase):
             expected_y_output += bias
         self.assertTrue(torch.equal(output, expected_y_output))

+    @patch("vllm_ascend.quantization.w8a8.is_enable_nz")
     @patch('torch_npu.npu_format_cast')
-    def test_process_weights_after_loading(self, mock_npu_format_cast):
+    def test_process_weights_after_loading_not_nz(self, mock_npu_format_cast,
+                                                  mock_is_nz):
         layer = MagicMock()
         layer.weight.data = torch.randn(128, 256)
@@ -148,6 +150,7 @@ class TestAscendW8A8LinearMethod(TestBase):
         layer.weight_scale.data = torch.randn(128, 1)
         layer.weight_offset.data = torch.randn(128, 1)
+        mock_is_nz.return_value = 0
         mock_npu_format_cast.return_value = MagicMock
         self.method.process_weights_after_loading(layer)
@@ -160,6 +163,35 @@ class TestAscendW8A8LinearMethod(TestBase):
         self.assertEqual(layer.weight_scale.data.shape, (128, ))
         self.assertEqual(layer.weight_offset.data.shape, (128, ))
+        mock_npu_format_cast.assert_not_called()
+
+    @patch("vllm_ascend.quantization.w8a8.is_enable_nz")
+    @patch('torch_npu.npu_format_cast')
+    def test_process_weights_after_loading_nz(self, mock_npu_format_cast,
+                                              mock_is_nz):
+        layer = MagicMock()
+        layer.weight.data = torch.randn(128, 256)
+        layer.input_scale.data = torch.tensor([0.1])
+        layer.input_offset.data = torch.tensor([0])
+        layer.deq_scale = torch.tensor([0.5])
+        layer.weight_scale.data = torch.randn(128, 1)
+        layer.weight_offset.data = torch.randn(128, 1)
+        mock_is_nz.return_value = 1
+        mock_npu_format_cast.return_value = MagicMock
+        self.method.process_weights_after_loading(layer)
+
+        expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
+        self.assertTrue(
+            torch.equal(layer.aclnn_input_offset.data, expected_offset))
+        self.assertFalse(layer.aclnn_input_offset.requires_grad)
+        self.assertFalse(layer.deq_scale.requires_grad)
+        self.assertEqual(layer.weight_scale.data.shape, (128, ))
+        self.assertEqual(layer.weight_offset.data.shape, (128, ))
+        mock_npu_format_cast.assert_called_once()
+
+
 class TestAscendW8A8FusedMoEMethod(TestBase):

View File

@@ -39,6 +39,14 @@ class TestUtils(TestBase):
"Ascend910P1"): "Ascend910P1"):
self.assertFalse(utils.is_310p()) self.assertFalse(utils.is_310p())
def test_is_enable_nz(self):
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
1):
self.assertTrue(utils.is_enable_nz())
with mock.patch("vllm_ascend.utils.envs_ascend.VLLM_ASCEND_ENABLE_NZ",
0):
self.assertFalse(utils.is_enable_nz())
def test_sleep_mode_enabled(self): def test_sleep_mode_enabled(self):
utils._SLEEP_MODE_ENABLED = None utils._SLEEP_MODE_ENABLED = None
with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__", with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",

View File

@@ -96,15 +96,17 @@ class TestTorchairUtils(TestBase):
         self.assertEqual(args[0], expected_name)
         self.assertEqual(args[1], expected_path)

+    @mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
     @mock.patch('torch_npu.get_npu_format')
     @mock.patch('torch_npu.npu_format_cast')
     @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                 new=mock.MagicMock)
-    def test_converting_weight_acl_format(self, mock_npu_cast,
-                                          mock_get_format):
+    def test_converting_weight_acl_format_to_nz(self, mock_npu_cast,
+                                                mock_get_format, mock_is_nz):
         ACL_FORMAT_FRACTAL_NZ = 29
         mock_get_format.return_value = 1
         mock_npu_cast.return_value = 1
+        mock_is_nz.return_value = 1

         fused_moe = mock.MagicMock()
         fused_moe.w13_weight = mock.MagicMock()
@@ -137,3 +139,26 @@ class TestTorchairUtils(TestBase):
         utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
         mock_npu_cast.assert_not_called()
+
+    @mock.patch('vllm_ascend.torchair.utils.is_enable_nz')
+    @mock.patch('torch_npu.get_npu_format')
+    @mock.patch('torch_npu.npu_format_cast')
+    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
+                new=mock.MagicMock)
+    def test_converting_weight_acl_format_no_nz(self, mock_npu_cast,
+                                                mock_get_format, mock_is_nz):
+        ACL_FORMAT_FRACTAL_NZ = 29
+        mock_get_format.return_value = 1
+        mock_npu_cast.return_value = 1
+        mock_is_nz.return_value = 0
+
+        fused_moe = mock.MagicMock()
+        fused_moe.w13_weight = mock.MagicMock()
+        fused_moe.w2_weight = mock.MagicMock()
+        fused_moe.w13_weight.data = torch.randn(128, 256)
+        fused_moe.w2_weight.data = torch.randn(256, 128)
+        model = mock.MagicMock()
+        model.modules.return_value = [fused_moe]
+
+        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
+        mock_npu_cast.assert_not_called()

View File

@@ -27,6 +27,8 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               is_enable_nz)
 from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:
@@ -595,6 +597,10 @@ class AscendMLAImpl(MLAAttentionImpl):
                 del eye
                 # standardize to (output, input)
                 return dequant_weights.T
+            # The weight will be reshaped next. To be on the safe side, the
+            # format of the weight should be reverted to FRACTAL_ND.
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_ND)
             return layer.weight

         # we currently do not have quantized bmm's which are needed for
@@ -623,6 +629,12 @@ class AscendMLAImpl(MLAAttentionImpl):
         # Convert from (L, N, P) to (N, P, L)
         self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()

+        # Function `get_and_maybe_dequant_weights` casts the weights to
+        # FRACTAL_ND, so we need to cast them back to FRACTAL_NZ here.
+        if is_enable_nz():
+            self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
+                self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
+
         # Waiting for BMM NZ support
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)

View File

@@ -169,6 +169,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)), lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
"VLLM_ASCEND_ENABLE_MLAPO": "VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))), lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
"VLLM_ASCEND_ENABLE_NZ":
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
} }
# end-env-vars-definition # end-env-vars-definition

View File

@@ -32,13 +32,15 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+from vllm.distributed import (divide, get_pp_group,
+                              get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               get_tp_group, split_tensor_along_last_dim,
                               tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
+                                               ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -57,16 +59,81 @@ from vllm.model_executor.models.deepseek_v2 import (
 from vllm.model_executor.models.utils import (PPMissingLayer,
                                               is_pp_missing_parameter,
                                               maybe_prefix)
+from vllm.model_executor.utils import set_weight_attrs

 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.models.layers.mla import AscendMLAModules
 from vllm_ascend.models.layers.sfa import (AscendSFAModules,
                                            AscendSparseFlashAttention, Indexer)
 from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
+from vllm_ascend.ops.linear import AscendLinearBase


 class CustomDeepseekV2RowParallelLinear(RowParallelLinear):

+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        disable_tp: bool = False,
+    ):
+        # Divide the weight matrix along the first dimension.
+        self.tp_rank = (get_tensor_model_parallel_rank()
+                        if not disable_tp else 0)
+        self.tp_size = (get_tensor_model_parallel_world_size()
+                        if not disable_tp else 1)
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.output_size_per_partition = output_size
+        self.output_partition_sizes = [output_size]
+
+        AscendLinearBase.__init__(self,
+                                  input_size,
+                                  output_size,
+                                  skip_bias_add,
+                                  params_dtype,
+                                  quant_config,
+                                  prefix,
+                                  return_bias=return_bias,
+                                  disable_tp=disable_tp)
+
+        self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
+
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=self.output_partition_sizes,
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=(
+                self.weight_loader_v2 if self.quant_method.__class__.__name__
+                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
+
+        if bias:
+            self.bias = nn.Parameter(
+                torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+        self.update_param_tp_status()
+
     def forward(
         self,
         input_,

View File

@@ -37,7 +37,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
+                               npu_stream_switch)


 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -83,7 +84,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         w2_data = self._maybe_pad_weight(layer.w2_weight.data)
         layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)

-        if not is_310p():
+        if not is_310p() and is_enable_nz():
             layer.w13_weight.data = torch_npu.npu_format_cast(
                 layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
             layer.w2_weight.data = torch_npu.npu_format_cast(

View File

@@ -24,17 +24,29 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
+import torch_npu
 from torch.nn.parameter import Parameter
 from vllm.distributed import divide
 from vllm.model_executor.layers.linear import (  # noqa
     WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
     MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
-    RowParallelLinear, UnquantizedLinearMethod)
+    ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs

-from vllm_ascend.ops.linear_op import get_parallel_op
+from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
+
+
+class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
+    """Linear method without quantization."""
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+        if is_enable_nz() and torch.version.cann.startswith("8.3"):
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)


 # TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
@@ -65,7 +77,7 @@ class AscendLinearBase(LinearBase):
         self.prefix = prefix
         if quant_config is None:
             self.quant_method: Optional[
-                QuantizeMethodBase] = UnquantizedLinearMethod()
+                QuantizeMethodBase] = AscendUnquantizedLinearMethod()
         else:
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
@@ -364,3 +376,81 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
             return self.custom_op.apply(input_)
         return super().forward(input_)
+
+
+class AscendReplicatedLinear(ReplicatedLinear):
+    """Ascend replicated linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization config.
+        prefix: The name of the layer in the state dict, including all
+            parents (e.g. model.layers.0.qkv_proj).
+        return_bias: If true, return bias together with outputs in the
+            forward pass.
+        disable_tp: Has no effect for replicated linear layers.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+        disable_tp: bool = False,
+    ):
+        self.custom_op = get_replicated_op(disable_tp, prefix, self)
+        # If MergedReplicatedLinear, use the output size of each partition.
+        if hasattr(self, "output_sizes"):
+            self.output_partition_sizes = self.output_sizes
+        else:
+            self.output_partition_sizes = [output_size]
+
+        AscendLinearBase.__init__(self,
+                                  input_size,
+                                  output_size,
+                                  skip_bias_add,
+                                  params_dtype,
+                                  quant_config,
+                                  prefix=prefix,
+                                  return_bias=return_bias,
+                                  disable_tp=disable_tp)
+
+        # All linear layers support a quant method.
+        assert self.quant_method is not None
+        self.quant_method.create_weights(self,
+                                         self.input_size, [self.output_size],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=self.params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+        if self.custom_op is not None:
+            self.custom_op.update_attrs()
+
+    def forward(
+        self,
+        input_,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
+        if self.custom_op is not None:
+            return self.custom_op.apply(input_)
+        return super().forward(input_)

View File

@@ -17,16 +17,16 @@ This file extends the functionality of linear operations by encapsulating custom
 communication groups and forward functions into classes (linear ops).

 Current class inheritance structure:

-CustomTensorParallelOp
+CustomLinearOp
 ├── CustomColumnParallelOp
 │   ├── MLPColumnParallelOp
 │   ├── SequenceColumnParallelOp
 └── CustomRowParallelOp
     ├── MLPRowParallelOp
     ├── OProjRowParallelOp
     ├── MatmulAllreduceRowParallelOp
     └── SequenceRowParallelOp
+└── CustomReplicatedOp

 How to extend a new linear op? Taking column parallel op as an example:
 1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
 2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
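For instance, a minimal sketch of that recipe (editor's illustration, not part of this commit; `MyColumnParallelOp` is hypothetical, and `comm_group` is shown as a property on the assumption that it matches the base-class definition in linear_op):

```python
from vllm_ascend.distributed.parallel_state import get_mlp_tp_group


class MyColumnParallelOp(CustomColumnParallelOp):
    """Step 1: inherit from CustomColumnParallelOp."""

    @property
    def comm_group(self):
        # Step 2 (optional): route communication through a custom
        # group instead of the default TP group.
        return get_mlp_tp_group()
```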
@@ -52,7 +52,7 @@ from vllm_ascend.utils import (dense_optim_enable, enable_sp,
                                oproj_tp_enable)


-class CustomTensorParallelOp:
+class CustomLinearOp:

     def __init__(self, layer):
         self.layer = layer
@@ -95,7 +95,7 @@ class CustomTensorParallelOp:
         return output, output_bias


-class CustomColumnParallelOp(CustomTensorParallelOp):
+class CustomColumnParallelOp(CustomLinearOp):

     def __init__(self, layer):
         super().__init__(layer)
@@ -106,7 +106,7 @@ class CustomColumnParallelOp(CustomTensorParallelOp):
         self.gather_output = self.layer.gather_output


-class CustomRowParallelOp(CustomTensorParallelOp):
+class CustomRowParallelOp(CustomLinearOp):

     def __init__(self, layer):
         super().__init__(layer)
@@ -129,6 +129,18 @@ class CustomRowParallelOp(CustomTensorParallelOp):
         return output, output_bias


+class CustomReplicatedOp(CustomLinearOp):
+
+    def apply_impl(self, input_):
+        bias = self.bias if not self.skip_bias_add else None
+        assert self.quant_method is not None
+
+        output = self.quant_method.apply(self.layer, input_, bias)
+        output_bias = self.bias if self.skip_bias_add else None
+
+        return output, output_bias
+
+
 class MLPColumnParallelOp(CustomColumnParallelOp):

     def __init__(self, layer):
@@ -422,3 +434,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
         return custom_op, custom_op.tp_rank, custom_op.tp_size

     return None, get_tp_group().rank_in_group, get_tp_group().world_size
+
+
+def get_replicated_op(disable_tp, prefix,
+                      layer) -> Optional[Union[CustomReplicatedOp]]:
+    if disable_tp:
+        return None
+
+    return CustomReplicatedOp(layer)

View File

@@ -24,8 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                    FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               RowParallelLinear)
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
 from vllm.model_executor.layers.quantization.base_config import (
@@ -39,6 +38,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                                oproj_tp_enable)
@@ -101,7 +101,7 @@ class AscendQuantConfig(QuantizationConfig):
         if isinstance(layer, LinearBase):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return UnquantizedLinearMethod()
+                return AscendUnquantizedLinearMethod()
             return AscendLinearMethod(self, prefix,
                                       self.packed_modules_mapping)
         elif isinstance(layer, Attention) and \

View File

@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


 class AscendW4A8DynamicLinearMethod:
@@ -393,9 +393,10 @@ class AscendW4A8DynamicFusedMoEMethod:
         self.update_bias(layer, w13_bias, w2_bias)

-        layer.w13_weight.data = torch_npu.npu_format_cast(
-            layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
-        layer.w2_weight.data = torch_npu.npu_format_cast(
-            layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            layer.w13_weight.data = torch_npu.npu_format_cast(
+                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
+            layer.w2_weight.data = torch_npu.npu_format_cast(
+                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)

         layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
         layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)

View File

@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz


 def quant_per_tensor(in_tensor: torch.Tensor,
@@ -156,8 +156,9 @@ class AscendW8A8LinearMethod:
             requires_grad=False).to(layer.aclnn_input_scale.dtype)
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
-                                                      ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
@@ -340,7 +341,7 @@ class AscendW8A8FusedMoEMethod:
         # converting ACL_FORMAT_FRACTAL_NZ.
         # npu_quant_grouped_matmul_dequant in eager mode does not accept
         # ACL_FORMAT_FRACTAL_NZ.
-        if not is_310p():
+        if not is_310p() and is_enable_nz():
             layer.w13_weight.data = torch_npu.npu_format_cast(
                 layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
             layer.w2_weight.data = torch_npu.npu_format_cast(

View File

@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz


 class AscendW8A8DynamicLinearMethod:
@@ -101,8 +101,9 @@ class AscendW8A8DynamicLinearMethod:
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         # cast quantized weight tensors in NZ format for higher inference speed
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
-                                                      ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -267,8 +268,9 @@ class AscendW8A8DynamicFusedMoEMethod:
             1, 2).contiguous()
         layer.w2_weight.data = layer.w2_weight.data.transpose(
             1, 2).contiguous()
-        torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
-        torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
+            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(

View File

@@ -29,6 +29,7 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
 from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
                                dispose_tensor, get_ascend_soc_version,
+                               is_enable_nz,
                                is_hierarchical_communication_enabled)
@@ -829,7 +830,9 @@ class TorchairAscendW8A8DynamicLinearMethod:
         if self.transpose_weight:
             layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         # cast quantized weight tensors in NZ format (29) for higher inference speed
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
+        if is_enable_nz():
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, 29)
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
@@ -1048,7 +1051,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
             1, 2).contiguous()
         layer.w2_weight.data = layer.w2_weight.data.transpose(
             1, 2).contiguous()
-        torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
+        if is_enable_nz():
+            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
         layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
             layer.w13_weight_scale.data.shape[0], -1)
         layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(

View File

@@ -24,6 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
+from vllm_ascend.utils import is_enable_nz
 from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:
@@ -841,7 +842,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
         wd_qkv = wd_qkv.t().contiguous()
         wd_qkv = transdata(wd_qkv,
                            block_size=(16, 32)).unsqueeze(0).contiguous()
-        self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
+        if is_enable_nz():
+            self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

         kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
         kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
@@ -874,7 +876,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
             self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
             -1)
         wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
-        self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
+        if is_enable_nz():
+            self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

         qb_deq_scl = self.q_proj.deq_scale.data.clone()
         qb_deq_scl = qb_deq_scl.reshape(

View File

@@ -14,6 +14,7 @@ try:
 except ImportError:
     from torchair.ops import NpuStreamSwitch as _npu_stream_switch
     from torchair.ops import npu_wait_tensor as _npu_wait_tensor

+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz

 KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
 KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
@@ -141,6 +142,9 @@ def converting_weight_acl_format(model, format):
         if isinstance(module, FusedMoE):
             if torch_npu.get_npu_format(module.w13_weight.data) == format:
                 return
+            if format == ACL_FORMAT_FRACTAL_NZ \
+                    and not is_enable_nz():
+                return
             module.w13_weight.data = torch_npu.npu_format_cast(
                 module.w13_weight.data, format)
             module.w2_weight.data = torch_npu.npu_format_cast(

View File

@@ -65,6 +65,10 @@ def is_310p():
     return _IS_310P


+def is_enable_nz():
+    return envs_ascend.VLLM_ASCEND_ENABLE_NZ
+
+
 def sleep_mode_enabled():
     global _SLEEP_MODE_ENABLED
     if _SLEEP_MODE_ENABLED is None:
@@ -508,6 +512,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
     from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
                                         AscendMergedColumnParallelLinear,
                                         AscendQKVParallelLinear,
+                                        AscendReplicatedLinear,
                                         AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
         AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
@@ -526,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
         "MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
         "QKVParallelLinear": AscendQKVParallelLinear,
+        "ReplicatedLinear": AscendReplicatedLinear,
         "DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,
         "VocabParallelEmbedding": AscendVocabParallelEmbedding,
         "ParallelLMHead": AscendParallelLMHead,

View File

@@ -97,6 +97,7 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
                                   sanity_check_mm_encoder_outputs,
                                   scatter_mm_placeholders)

+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import (MoECommType,
                                                 set_ascend_forward_context)
@@ -125,7 +126,7 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
-                               get_ascend_soc_version, is_310p,
+                               get_ascend_soc_version, is_310p, is_enable_nz,
                                lmhead_tp_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -137,8 +138,6 @@ else:
     import torch_npu

-import vllm_ascend.envs as envs_ascend
-
 # if true, allow tensor initialization and casting with internal format (e.g., NZ)
 torch.npu.config.allow_internal_format = True
@@ -2609,6 +2608,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 runtime_mode=CUDAGraphMode.FULL)

     def _convert_torch_format(self, tensor):
+        if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
+                and not is_enable_nz():
+            return tensor
         tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
         return tensor