[Feat] Unquantized Linear to nz and control all nz-cast (#3356)
### What this PR does / why we need it? Currently, when executing to the Linear layer of models in vLLM-Ascend, the weights format is ND in unquantized case and skipped ascend case. This PR supplements the execution logic for Linear layer. We use a new global variable: VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and CANN version is 8.3, the weights of the Linear layer will be converted to FRACTAL_NZ, in both unquantized case and skipped ascend case. We also use VLLM_ASCEND_ENABLE_NZ to control the existing NZ conversion, such as w8a8-quantized case. ### Does this PR introduce _any_ user-facing change? Add a new global variable VLLM_ASCEND_ENABLE_NZ. If you want to use NZ format, you should set VLLM_ASCEND_ENABLE_NZ=1. ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -27,6 +27,8 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
is_enable_nz)
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -595,6 +597,10 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
del eye
|
||||
# standardize to (output, input)
|
||||
return dequant_weights.T
|
||||
# Weight will be reshaped next. To be on the safe side, the format
|
||||
# of the weight should be reverted to FRACTAL_AND.
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_ND)
|
||||
return layer.weight
|
||||
|
||||
# we currently do not have quantized bmm's which are needed for
|
||||
@@ -623,6 +629,12 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
# Convert from (L, N, P) to (N, P, L)
|
||||
self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
|
||||
|
||||
# Function `get_and_maybe_dequant_weights` will cast the weights to
|
||||
# FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
|
||||
if is_enable_nz():
|
||||
self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
|
||||
self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
# Waiting for BMM NZ support
|
||||
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
|
||||
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
|
||||
|
||||
@@ -169,6 +169,9 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
|
||||
"VLLM_ASCEND_ENABLE_MLAPO":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
|
||||
# Whether to enable transpose weight and cast format to FRACTAL_NZ.
|
||||
"VLLM_ASCEND_ENABLE_NZ":
|
||||
lambda: int(os.getenv("VLLM_ASCEND_ENABLE_NZ", 1)),
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
||||
|
||||
@@ -32,13 +32,15 @@ from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
from vllm.distributed import (divide, get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
|
||||
ColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
@@ -57,16 +59,81 @@ from vllm.model_executor.models.deepseek_v2 import (
|
||||
from vllm.model_executor.models.utils import (PPMissingLayer,
|
||||
is_pp_missing_parameter,
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.models.layers.mla import AscendMLAModules
|
||||
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
|
||||
AscendSparseFlashAttention, Indexer)
|
||||
from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.ops.linear import AscendLinearBase
|
||||
|
||||
|
||||
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.tp_rank = (get_tensor_model_parallel_rank()
|
||||
if not disable_tp else 0)
|
||||
self.tp_size = (get_tensor_model_parallel_world_size()
|
||||
if not disable_tp else 1)
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=(
|
||||
self.weight_loader_v2 if self.quant_method.__class__.__name__
|
||||
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self.update_param_tp_status()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
|
||||
@@ -37,7 +37,8 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
|
||||
npu_stream_switch)
|
||||
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
@@ -83,7 +84,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
if not is_310p():
|
||||
if not is_310p() and is_enable_nz():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
|
||||
@@ -24,17 +24,29 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import divide
|
||||
from vllm.model_executor.layers.linear import ( # noqa
|
||||
WEIGHT_LOADER_V2_SUPPORTED, ColumnParallelLinear, LinearBase,
|
||||
MergedColumnParallelLinear, QKVParallelLinear, QuantizeMethodBase,
|
||||
RowParallelLinear, UnquantizedLinearMethod)
|
||||
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
from vllm_ascend.ops.linear_op import get_parallel_op
|
||||
from vllm_ascend.ops.linear_op import get_parallel_op, get_replicated_op
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
|
||||
class AscendUnquantizedLinearMethod(UnquantizedLinearMethod):
|
||||
"""Linear method without quantization"""
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
if is_enable_nz() and torch.version.cann.startswith("8.3"):
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
|
||||
# TODO(realliujiaxu): Remove this class after linear of vllm supports custom comm group
|
||||
@@ -65,7 +77,7 @@ class AscendLinearBase(LinearBase):
|
||||
self.prefix = prefix
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[
|
||||
QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||
QuantizeMethodBase] = AscendUnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self,
|
||||
prefix=prefix)
|
||||
@@ -364,3 +376,81 @@ class AscendColumnParallelLinear(ColumnParallelLinear):
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
|
||||
|
||||
class AscendReplicatedLinear(ReplicatedLinear):
|
||||
"""Ascend Replicated linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_size: output dimension of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
prefix: The name of the layer in the state dict, including all parents
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
disable_tp: Take no effect for replicated linear layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
disable_tp: bool = False,
|
||||
):
|
||||
self.custom_op = get_replicated_op(disable_tp, prefix, self)
|
||||
# If MergedReplicatedLinear, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = self.output_sizes
|
||||
else:
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
AscendLinearBase.__init__(self,
|
||||
input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix,
|
||||
return_bias=return_bias,
|
||||
disable_tp=disable_tp)
|
||||
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size, [self.output_size],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=self.params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
if self.custom_op is not None:
|
||||
self.custom_op.update_attrs()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.custom_op is not None:
|
||||
return self.custom_op.apply(input_)
|
||||
|
||||
return super().forward(input_)
|
||||
|
||||
@@ -17,16 +17,16 @@ This file extends the functionality of linear operations by encapsulating custom
|
||||
communication groups and forward functions into classes (linear ops).
|
||||
|
||||
Current class inheritance structure:
|
||||
CustomTensorParallelOp
|
||||
CustomLinearOp
|
||||
├── CustomColumnParallelOp
|
||||
│ ├── MLPColumnParallelOp
|
||||
│ ├── SequenceColumnParallelOp
|
||||
└── CustomRowParallelOp
|
||||
├── MLPRowParallelOp
|
||||
├── OProjRowParallelOp
|
||||
├── MatmulAllreduceRowParallelOp
|
||||
└── SequenceRowParallelOp
|
||||
|
||||
│ ├── MLPRowParallelOp
|
||||
│ ├── OProjRowParallelOp
|
||||
│ ├── MatmulAllreduceRowParallelOp
|
||||
│ └── SequenceRowParallelOp
|
||||
└── CustomReplicatedOp
|
||||
How to extend a new linear op? Taking column parallel op as an example:
|
||||
1. Inherit from CustomColumnParallelOp and create a new class MyColumnParallelOp
|
||||
2. [Optional] The default communication group is the TP group. If a custom communication group is needed, override the comm_group method
|
||||
@@ -52,7 +52,7 @@ from vllm_ascend.utils import (dense_optim_enable, enable_sp,
|
||||
oproj_tp_enable)
|
||||
|
||||
|
||||
class CustomTensorParallelOp:
|
||||
class CustomLinearOp:
|
||||
|
||||
def __init__(self, layer):
|
||||
self.layer = layer
|
||||
@@ -95,7 +95,7 @@ class CustomTensorParallelOp:
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomColumnParallelOp(CustomTensorParallelOp):
|
||||
class CustomColumnParallelOp(CustomLinearOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
@@ -106,7 +106,7 @@ class CustomColumnParallelOp(CustomTensorParallelOp):
|
||||
self.gather_output = self.layer.gather_output
|
||||
|
||||
|
||||
class CustomRowParallelOp(CustomTensorParallelOp):
|
||||
class CustomRowParallelOp(CustomLinearOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
super().__init__(layer)
|
||||
@@ -129,6 +129,18 @@ class CustomRowParallelOp(CustomTensorParallelOp):
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class CustomReplicatedOp(CustomLinearOp):
|
||||
|
||||
def apply_impl(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
|
||||
output = self.quant_method.apply(self.layer, input_, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MLPColumnParallelOp(CustomColumnParallelOp):
|
||||
|
||||
def __init__(self, layer):
|
||||
@@ -422,3 +434,11 @@ def get_parallel_op(disable_tp, prefix, layer, direct):
|
||||
return custom_op, custom_op.tp_rank, custom_op.tp_size
|
||||
|
||||
return None, get_tp_group().rank_in_group, get_tp_group().world_size
|
||||
|
||||
|
||||
def get_replicated_op(disable_tp, prefix,
|
||||
layer) -> Optional[Union[CustomReplicatedOp]]:
|
||||
if disable_tp:
|
||||
return None
|
||||
|
||||
return CustomReplicatedOp(layer)
|
||||
|
||||
@@ -24,8 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import \
|
||||
register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@@ -39,6 +38,7 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
|
||||
get_otp_group)
|
||||
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
|
||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
|
||||
oproj_tp_enable)
|
||||
|
||||
@@ -101,7 +101,7 @@ class AscendQuantConfig(QuantizationConfig):
|
||||
if isinstance(layer, LinearBase):
|
||||
if self.is_layer_skipped_ascend(prefix,
|
||||
self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
return AscendUnquantizedLinearMethod()
|
||||
return AscendLinearMethod(self, prefix,
|
||||
self.packed_modules_mapping)
|
||||
elif isinstance(layer, Attention) and \
|
||||
|
||||
@@ -27,7 +27,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
|
||||
class AscendW4A8DynamicLinearMethod:
|
||||
@@ -393,9 +393,10 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
|
||||
self.update_bias(layer, w13_bias, w2_bias)
|
||||
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
if is_enable_nz():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight.data = self.pack_to_int32(layer.w13_weight.data)
|
||||
layer.w2_weight.data = self.pack_to_int32(layer.w2_weight.data)
|
||||
|
||||
@@ -25,7 +25,7 @@ from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz
|
||||
|
||||
|
||||
def quant_per_tensor(in_tensor: torch.Tensor,
|
||||
@@ -156,8 +156,9 @@ class AscendW8A8LinearMethod:
|
||||
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
if is_enable_nz():
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
|
||||
@@ -340,7 +341,7 @@ class AscendW8A8FusedMoEMethod:
|
||||
# converting ACL_FORMAT_FRACTAL_NZ.
|
||||
# npu_quant_grouped_matmul_dequant in eager mode does not accept
|
||||
# ACL_FORMAT_FRACTAL_NZ.
|
||||
if not is_310p():
|
||||
if not is_310p() and is_enable_nz():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
|
||||
@@ -26,7 +26,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
|
||||
class AscendW8A8DynamicLinearMethod:
|
||||
@@ -101,8 +101,9 @@ class AscendW8A8DynamicLinearMethod:
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
# cast quantized weight tensors in NZ format for higher inference speed
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
|
||||
ACL_FORMAT_FRACTAL_NZ)
|
||||
if is_enable_nz():
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
@@ -267,8 +268,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
if is_enable_nz():
|
||||
torch_npu.npu_format_cast_(layer.w13_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
||||
|
||||
@@ -29,6 +29,7 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||
dispose_tensor, get_ascend_soc_version,
|
||||
is_enable_nz,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
@@ -829,7 +830,9 @@ class TorchairAscendW8A8DynamicLinearMethod:
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
if is_enable_nz():
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, 29)
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
@@ -1048,7 +1051,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
if is_enable_nz():
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.utils import is_enable_nz
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -841,7 +842,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
|
||||
wd_qkv = wd_qkv.t().contiguous()
|
||||
wd_qkv = transdata(wd_qkv,
|
||||
block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||
if is_enable_nz():
|
||||
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||
|
||||
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
|
||||
@@ -874,7 +876,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
|
||||
-1)
|
||||
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
||||
if is_enable_nz():
|
||||
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
||||
|
||||
qb_deq_scl = self.q_proj.deq_scale.data.clone()
|
||||
qb_deq_scl = qb_deq_scl.reshape(
|
||||
|
||||
@@ -14,6 +14,7 @@ try:
|
||||
except ImportError:
|
||||
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
||||
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
||||
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
|
||||
@@ -141,6 +142,9 @@ def converting_weight_acl_format(model, format):
|
||||
if isinstance(module, FusedMoE):
|
||||
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
||||
return
|
||||
if format == ACL_FORMAT_FRACTAL_NZ \
|
||||
and not is_enable_nz():
|
||||
return
|
||||
module.w13_weight.data = torch_npu.npu_format_cast(
|
||||
module.w13_weight.data, format)
|
||||
module.w2_weight.data = torch_npu.npu_format_cast(
|
||||
|
||||
@@ -65,6 +65,10 @@ def is_310p():
|
||||
return _IS_310P
|
||||
|
||||
|
||||
def is_enable_nz():
|
||||
return envs_ascend.VLLM_ASCEND_ENABLE_NZ
|
||||
|
||||
|
||||
def sleep_mode_enabled():
|
||||
global _SLEEP_MODE_ENABLED
|
||||
if _SLEEP_MODE_ENABLED is None:
|
||||
@@ -508,6 +512,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||
AscendMergedColumnParallelLinear,
|
||||
AscendQKVParallelLinear,
|
||||
AscendReplicatedLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.ops.rotary_embedding import (
|
||||
AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
|
||||
@@ -526,6 +531,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
"YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
|
||||
"MergedColumnParallelLinear": AscendMergedColumnParallelLinear,
|
||||
"QKVParallelLinear": AscendQKVParallelLinear,
|
||||
"ReplicatedLinear": AscendReplicatedLinear,
|
||||
"DeepseekScalingRotaryEmbedding": AscendDeepseekScalingRotaryEmbedding,
|
||||
"VocabParallelEmbedding": AscendVocabParallelEmbedding,
|
||||
"ParallelLMHead": AscendParallelLMHead,
|
||||
|
||||
@@ -97,6 +97,7 @@ from vllm.v1.worker.utils import (AttentionGroup, bind_kv_cache,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import (MoECommType,
|
||||
set_ascend_forward_context)
|
||||
@@ -125,7 +126,7 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
|
||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
AscendSocVersion, ProfileExecuteDuration,
|
||||
get_ascend_soc_version, is_310p,
|
||||
get_ascend_soc_version, is_310p, is_enable_nz,
|
||||
lmhead_tp_enable)
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
@@ -137,8 +138,6 @@ else:
|
||||
|
||||
import torch_npu
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
|
||||
# if true, allow tensor initialization and casting with internal format (e.g., NZ)
|
||||
torch.npu.config.allow_internal_format = True
|
||||
|
||||
@@ -2609,6 +2608,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
runtime_mode=CUDAGraphMode.FULL)
|
||||
|
||||
def _convert_torch_format(self, tensor):
|
||||
if ACL_FORMAT == ACL_FORMAT_FRACTAL_NZ \
|
||||
and not is_enable_nz():
|
||||
return tensor
|
||||
tensor = torch_npu.npu_format_cast(tensor, ACL_FORMAT)
|
||||
return tensor
|
||||
|
||||
|
||||
Reference in New Issue
Block a user