[Feat] Unquantized Linear to nz and control all nz-cast (#3356)
### What this PR does / why we need it? Currently, when executing to the Linear layer of models in vLLM-Ascend, the weights format is ND in unquantized case and skipped ascend case. This PR supplements the execution logic for Linear layer. We use a new global variable: VLLM_ASCEND_ENABLE_NZ. When VLLM_ASCEND_ENABLE_NZ=1 and CANN version is 8.3, the weights of the Linear layer will be converted to FRACTAL_NZ, in both unquantized case and skipped ascend case. We also use VLLM_ASCEND_ENABLE_NZ to control the existing NZ conversion, such as w8a8-quantized case. ### Does this PR introduce _any_ user-facing change? Add a new global variable VLLM_ASCEND_ENABLE_NZ. If you want to use NZ format, you should set VLLM_ASCEND_ENABLE_NZ=1. ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||
dispose_tensor, get_ascend_soc_version,
|
||||
is_enable_nz,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
@@ -829,7 +830,9 @@ class TorchairAscendW8A8DynamicLinearMethod:
|
||||
if self.transpose_weight:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
||||
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
||||
if is_enable_nz():
|
||||
layer.weight.data = torch_npu.npu_format_cast(
|
||||
layer.weight.data, 29)
|
||||
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
||||
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
||||
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
||||
@@ -1048,7 +1051,8 @@ class TorchairAscendW8A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
if is_enable_nz():
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
from vllm_ascend.utils import is_enable_nz
|
||||
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -841,7 +842,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
|
||||
wd_qkv = wd_qkv.t().contiguous()
|
||||
wd_qkv = transdata(wd_qkv,
|
||||
block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||
if is_enable_nz():
|
||||
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
|
||||
|
||||
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale.clone()
|
||||
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
|
||||
@@ -874,7 +876,8 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
|
||||
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
|
||||
-1)
|
||||
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
|
||||
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
||||
if is_enable_nz():
|
||||
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
|
||||
|
||||
qb_deq_scl = self.q_proj.deq_scale.data.clone()
|
||||
qb_deq_scl = qb_deq_scl.reshape(
|
||||
|
||||
@@ -14,6 +14,7 @@ try:
|
||||
except ImportError:
|
||||
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
|
||||
from torchair.ops import npu_wait_tensor as _npu_wait_tensor
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
|
||||
|
||||
KV_CACHE_BYTES_CACHE_PATH_NAME = ".kv_cache_bytes"
|
||||
KV_CACHE_BYTES_CACHE_FILE_NAME = "kv_cache_bytes"
|
||||
@@ -141,6 +142,9 @@ def converting_weight_acl_format(model, format):
|
||||
if isinstance(module, FusedMoE):
|
||||
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
||||
return
|
||||
if format == ACL_FORMAT_FRACTAL_NZ \
|
||||
and not is_enable_nz():
|
||||
return
|
||||
module.w13_weight.data = torch_npu.npu_format_cast(
|
||||
module.w13_weight.data, format)
|
||||
module.w2_weight.data = torch_npu.npu_format_cast(
|
||||
|
||||
Reference in New Issue
Block a user