[300I][Bugfix] fix unquant model weight nd2nz error (#6851)
### What this PR does / why we need it?
- This PR fixes an issue with weight format conversion for unquantized
models running on Ascend 310P devices.
- The changes refactor the logic for converting weights to the
FRACTAL_NZ format. Previously, this was handled in a 310P-specific
linear layer implementation (`AscendUnquantizedLinearMethod310`). This
implementation has been removed, and the logic is now centralized in the
`maybe_trans_nz` utility function. This function now checks if the
device is a 310P and applies the NZ format cast accordingly for
`float16`/`bfloat16` weights.
- This refactoring simplifies the code by removing platform-specific
duplication and ensures correct weight handling for unquantized models
on 310P.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests and local testing.
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
@@ -249,3 +249,91 @@ class TestUtils(TestBase):
|
||||
utils.register_ascend_customop()
|
||||
self.assertEqual(mock_customop.register_oot.call_count,
|
||||
len(REGISTERED_ASCEND_OPS))
|
||||
|
||||
@mock.patch("torch_npu.npu_format_cast")
|
||||
def test_maybe_trans_nz(self, mock_npu_format_cast):
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||
mock_npu_format_cast.side_effect = lambda weight, fmt: weight
|
||||
|
||||
def assert_nz_cast(weight):
|
||||
mock_npu_format_cast.assert_called_once()
|
||||
args, kwargs = mock_npu_format_cast.call_args
|
||||
self.assertIs(args[0], weight)
|
||||
self.assertEqual(args[1], ACL_FORMAT_FRACTAL_NZ)
|
||||
self.assertEqual(kwargs, {})
|
||||
|
||||
# Test case 1: non-310P, NZ is disabled
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}),
|
||||
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
|
||||
):
|
||||
weight = torch.randn(32, 64, dtype=torch.float16)
|
||||
result = utils.maybe_trans_nz(weight)
|
||||
self.assertIs(result, weight)
|
||||
mock_npu_format_cast.assert_not_called()
|
||||
|
||||
# Test case 2: 310P always converts non-fp32 weights, even when NZ=0
|
||||
mock_npu_format_cast.reset_mock()
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}),
|
||||
mock.patch("vllm_ascend.utils.is_310p", return_value=True),
|
||||
):
|
||||
weight = torch.randn(32, 64, dtype=torch.float16)
|
||||
result = utils.maybe_trans_nz(weight)
|
||||
self.assertIs(result, weight)
|
||||
assert_nz_cast(weight)
|
||||
|
||||
# Test case 3: fp32 never converts, including on 310P
|
||||
mock_npu_format_cast.reset_mock()
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}),
|
||||
mock.patch("vllm_ascend.utils.is_310p", return_value=True),
|
||||
):
|
||||
weight = torch.randn(32, 64, dtype=torch.float32)
|
||||
result = utils.maybe_trans_nz(weight)
|
||||
self.assertIs(result, weight)
|
||||
mock_npu_format_cast.assert_not_called()
|
||||
|
||||
# Test case 4: non-310P fp16 converts only when NZ=2
|
||||
mock_npu_format_cast.reset_mock()
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}),
|
||||
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
|
||||
):
|
||||
weight = torch.randn(32, 64, dtype=torch.float16)
|
||||
result = utils.maybe_trans_nz(weight)
|
||||
self.assertIs(result, weight)
|
||||
mock_npu_format_cast.assert_not_called()
|
||||
|
||||
# Test case 5: non-310P fp16 converts when NZ=2
|
||||
mock_npu_format_cast.reset_mock()
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}),
|
||||
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
|
||||
):
|
||||
weight = torch.randn(32, 64, dtype=torch.float16)
|
||||
result = utils.maybe_trans_nz(weight)
|
||||
self.assertIs(result, weight)
|
||||
assert_nz_cast(weight)
|
||||
|
||||
# Test case 6: non-310P bf16 converts when NZ=2
|
||||
mock_npu_format_cast.reset_mock()
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}),
|
||||
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
|
||||
):
|
||||
weight = torch.randn(32, 64, dtype=torch.bfloat16)
|
||||
result = utils.maybe_trans_nz(weight)
|
||||
self.assertIs(result, weight)
|
||||
assert_nz_cast(weight)
|
||||
|
||||
# Test case 7: non-310P quantized weights still convert by default
|
||||
mock_npu_format_cast.reset_mock()
|
||||
with (
|
||||
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}),
|
||||
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
|
||||
):
|
||||
weight = torch.zeros(32, 64, dtype=torch.int8)
|
||||
result = utils.maybe_trans_nz(weight)
|
||||
self.assertIs(result, weight)
|
||||
assert_nz_cast(weight)
|
||||
|
||||
Reference in New Issue
Block a user