[300I][Bugfix] fix unquant model weight nd2nz error (#6851)
### What this PR does / why we need it?
- This PR fixes a weight format conversion error for unquantized models running on Ascend 310P devices.
- It refactors the logic for converting weights to the FRACTAL_NZ format. Previously, the conversion was handled by a 310P-specific linear layer implementation (`AscendUnquantizedLinearMethod310`). That implementation has been removed, and the logic is now centralized in the `maybe_trans_nz` utility function, which checks whether the device is a 310P and applies the NZ format cast accordingly for `float16`/`bfloat16` weights (a sketch of the helper follows this list).
- The refactoring removes platform-specific duplication and ensures correct weight handling for unquantized models on 310P.
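For reference, here is a minimal sketch of what the centralized helper could look like. It is an illustration under stated assumptions, not the actual implementation: the real helper lives in `vllm_ascend.utils`, and the `is_310p()` device check used below is an assumption.

```python
# Illustrative sketch only; the real helper lives in vllm_ascend.utils.
# Assumption: is_310p() is the SoC check used by vllm-ascend.
import torch
import torch_npu

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p


def maybe_trans_nz(data: torch.Tensor) -> torch.Tensor:
    """Cast a weight tensor to FRACTAL_NZ on 310P; return it unchanged elsewhere.

    The real helper may also gate on dtype: this PR applies the cast to
    float16/bfloat16 unquantized weights, while the W8A8 schemes in the
    diff below route their int8 weights through the same function.
    """
    if is_310p():
        # npu_format_cast changes the tensor's internal ACL storage format.
        return torch_npu.npu_format_cast(data, ACL_FORMAT_FRACTAL_NZ)
    return data
```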
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests and local testing.
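The commit message does not include the tests themselves; below is a hedged sketch of the kind of unit test this implies. The module path, the `is_310p` hook, and the no-op-off-310P behavior are all assumptions.

```python
# Hypothetical test sketch; the actual unit tests live in the repo's test suite.
from unittest import mock

import torch

import vllm_ascend.utils as utils


def test_maybe_trans_nz_noop_off_310p():
    # Off 310P the helper is expected to leave the weight untouched.
    w = torch.randn(8, 8, dtype=torch.float16)
    with mock.patch.object(utils, "is_310p", return_value=False):
        assert torch.equal(utils.maybe_trans_nz(w), w)
```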
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
```diff
@@ -21,7 +21,7 @@ import torch
 import torch_npu
 
 from vllm_ascend.quantization.methods.base import AscendLinearScheme
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.utils import maybe_trans_nz
 
 from .registry import register_scheme
 
@@ -105,7 +105,7 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
         ).to(layer.aclnn_input_scale.dtype)
 
         # ---- matmul stage tensor ----
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ).transpose(0, 1)
+        layer.weight.data = maybe_trans_nz(layer.weight.data).transpose(0, 1)
 
         # ---- dequant stage tensors ----
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
```
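For readers unfamiliar with ACL storage formats: `npu_format_cast` rearranges the tensor's internal NPU layout without changing its logical shape or values. A small illustration (requires an NPU device; which format id a fresh tensor starts in is an assumption, conventionally ND):

```python
import torch
import torch_npu

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

w = torch.randn(256, 512, dtype=torch.float16).npu()
print(torch_npu.get_npu_format(w))      # plain tensors typically start in ND layout

w_nz = torch_npu.npu_format_cast(w, ACL_FORMAT_FRACTAL_NZ)
print(torch_npu.get_npu_format(w_nz))   # now FRACTAL_NZ
print(w_nz.shape)                       # logical shape is unchanged
```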
```diff
@@ -21,7 +21,7 @@ import torch
 import torch_npu
 
 from vllm_ascend.quantization.methods.base import AscendLinearScheme
-from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+from vllm_ascend.utils import maybe_trans_nz
 
 from .registry import register_scheme
 
@@ -84,4 +84,4 @@ class AscendW8A8SLinearMethod310(AscendLinearScheme):
         layer.aclnn_input_scale = layer.input_scale.data.repeat(expanding_factor)
         layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data
         layer.aclnn_input_offset = layer.input_offset.data.repeat(expanding_factor).to(layer.aclnn_input_scale.dtype)
-        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
+        layer.weight.data = maybe_trans_nz(layer.weight.data)
```
```diff
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import register_quantization_config
 from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    UnquantizedEmbeddingMethod,
     VocabParallelEmbedding,
 )
 
@@ -104,9 +103,9 @@ class AscendModelSlimConfig310(AscendModelSlimConfig):
         if isinstance(layer, LinearBase):
             packed = getattr(self, "packed_modules_mapping", {})
             if self.is_layer_skipped_ascend(prefix, packed):
-                from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
+                from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 
-                return AscendUnquantizedLinearMethod310()
+                return AscendUnquantizedLinearMethod()
 
             scheme = create_scheme_for_layer(
                 quant_description=self.quant_description,
@@ -125,6 +124,8 @@ class AscendModelSlimConfig310(AscendModelSlimConfig):
             return AscendFusedMoEMethod(scheme, layer.moe_config)
 
         elif isinstance(layer, VocabParallelEmbedding):
-            return UnquantizedEmbeddingMethod()
+            from vllm_ascend._310p.ops.vocab_parallel_embedding import AscendUnquantizedEmbeddingMethod310
+
+            return AscendUnquantizedEmbeddingMethod310()
 
         return super().get_quant_method(layer, prefix)
```
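Taken together, the 310P config now dispatches unquantized linear layers to the shared method, which is assumed to route through `maybe_trans_nz` when post-processing weights. A simplified stand-in, not the actual `vllm_ascend/ops/linear.py` code:

```python
# Simplified stand-in for the shared unquantized linear method.
# Assumption: process_weights_after_loading routes through maybe_trans_nz,
# which is what makes a separate 310P-only class unnecessary.
import torch

from vllm_ascend.utils import maybe_trans_nz


class AscendUnquantizedLinearMethod:  # real class: vllm_ascend.ops.linear
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # On 310P this casts fp16/bf16 weights to FRACTAL_NZ;
        # on other SoCs it leaves the weight untouched.
        layer.weight.data = maybe_trans_nz(layer.weight.data)
```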