diff --git a/examples/save_sharded_state_310.py b/examples/save_sharded_state_310.py
new file mode 100644
index 00000000..fb7acabe
--- /dev/null
+++ b/examples/save_sharded_state_310.py
@@ -0,0 +1,282 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/vllm/examples/offline_inference/save_sharded_state.py
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Saves each worker's model state dict directly to a checkpoint, which enables a
+fast load path for large tensor-parallel models where each worker only needs to
+read its own shard rather than the entire checkpoint.
+A Sparse-Compress-Quantization state dict can also be saved via this script.
+
+Example usage:
+
+python save_sharded_state_310.py \
+    --model /path/to/load \
+    --tensor-parallel-size 8 \
+    --output /path/to/save \
+    --enable-compress \
+    --compress-process-num 8
+
+Then, the model can be loaded with
+
+llm = LLM(
+    model="/path/to/save",
+    load_format="sharded_state",
+    tensor_parallel_size=8,
+    quantization="ascend",
+)
+"""
+
+import dataclasses
+import json
+import multiprocessing as mp
+import os
+import shutil
+from pathlib import Path
+
+import torch
+from vllm import LLM, EngineArgs
+from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+from vllm.utils.argparse_utils import FlexibleArgumentParser
+
+SUPPORTED_COMPRESS_QUANT_TYPE = ["W8A8S", "W16A16S"]
+DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
+QUANTIZATION_UPDATE_MAP = {"W8A8S": "W8A8SC", "W16A16S": "W16A16SC"}
+
+
+class FileHandler:
+    @staticmethod
+    def validate_path(path: str, must_exist: bool = True, check_writable: bool = False) -> Path:
+        """
+        Comprehensive path validation.
+        - Checks existence
+        - Checks write permissions for the target or its parent
+        """
+        p = Path(path)
+        if must_exist and not p.exists():
+            raise FileNotFoundError(f"Error: Path '{path}' does not exist.")
+
+        if check_writable:
+            # Check the directory itself if it exists, otherwise check the parent
+            target = p if p.exists() else p.parent
+            if not os.access(target, os.W_OK):
+                raise PermissionError(f"Permission Denied: No write access to '{target}'.")
+        return p
+
+    @staticmethod
+    def safe_copy(src: Path, dst: Path):
+        """Copies files or directories with permission handling."""
+        try:
+            if src.is_dir():
+                # dirs_exist_ok=True prevents errors if the destination directory exists
+                shutil.copytree(src, dst, dirs_exist_ok=True)
+            else:
+                # copy2 preserves metadata (timestamps, permissions)
+                shutil.copy2(src, dst)
+        except (PermissionError, OSError) as e:
+            print(f"Warning: Failed to copy {src} due to: {e}")
+
+
+def clean_up():
+    """Clean up vLLM resources."""
+    destroy_model_parallel()
+    destroy_distributed_environment()
+    torch.npu.empty_cache()
+
+
+def parse_args():
+    parser = FlexibleArgumentParser()
+    EngineArgs.add_cli_args(parser)
+    parser.add_argument("--output", "-o", required=True, type=str, help="path to output checkpoint")
+    parser.add_argument(
+        "--enable-compress",
+        action="store_true",
+        help="compress the saved quantized weights with msmodelslim after saving",
+    )
+    parser.add_argument(
+        "--compress-process-num",
+        type=int,
+        default=1,
+        help="number of processes used by the weight compressor for each rank's shard",
+    )
+    return parser.parse_args()
+
+
+def get_quant_description(json_file: str) -> dict:
+    """
+    Extract the quantization description from a JSON configuration file.
+
+    Args:
+        json_file: Path to the JSON configuration file
+
+    Returns:
+        dict: Quantization description dictionary
+
+    Raises:
+        FileNotFoundError: If the JSON file does not exist
+        RuntimeError: If JSON parsing fails
+    """
+    config_path = Path(json_file)
+    if not config_path.exists():
+        raise FileNotFoundError(f"Model configuration file not found: {json_file}")
+    try:
+        with config_path.open("r", encoding="utf-8") as file:
+            quant_desc = json.load(file)
+    except json.JSONDecodeError as e:
+        raise RuntimeError(f"Invalid JSON format in {json_file}: {e}")
+
+    return quant_desc
+
+
+def update_quant_description(json_file: str) -> None:
+    """
+    Update quantization types in the JSON configuration file based on the update mapping.
+
+    Args:
+        json_file: Path to the JSON configuration file
+
+    Raises:
+        RuntimeError: If the file cannot be read, JSON parsing fails, or the
+            quantization type is missing or not supported for update
+    """
+    config_path = Path(json_file)
+    try:
+        with config_path.open("r", encoding="utf-8") as file:
+            json_data = json.load(file)
+    except (FileNotFoundError, json.JSONDecodeError) as e:
+        raise RuntimeError(f"Failed to read configuration file {json_file}: {e}")
+
+    original_quant_type = json_data.get("model_quant_type")
+    if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
+        raise RuntimeError(
+            f"Cannot update quantization type. "
+            f"Original type '{original_quant_type}' not found or not supported for update in {json_file}."
+        )
+    updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]
+
+    updated_config = {"model_quant_type": updated_quant_type, "version": "1.0.0"}
+
+    for key, value in json_data.items():
+        if key.endswith(".weight") and value == original_quant_type:
+            updated_config[key] = updated_quant_type
+        elif key not in ("model_quant_type", "version"):
+            updated_config[key] = value
+
+    try:
+        new_file_path = config_path.parent / "quant_model_description.json"
+        with new_file_path.open("w", encoding="utf-8") as file:
+            json.dump(updated_config, file, indent=2, ensure_ascii=False)
+        os.remove(json_file)
+    except OSError as e:
+        raise RuntimeError(f"Failed to write updated configuration to {new_file_path}: {e}")
+
+
+def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
+    """
+    Worker logic for multiprocessing.
+    Note: Imports are inside the worker to save memory in the main process.
+
+    Returns:
+        bool: True if processing succeeded, False otherwise.
+    """
+    import safetensors
+    import safetensors.torch
+    from msmodelslim.pytorch.weight_compression import CompressConfig, Compressor
+
+    p = Path(file_path)
+    if not p.exists():
+        print(f"Error: File not found, failed to compress: {file_path}")
+        return False
+
+    try:
+        state_dict = safetensors.torch.load_file(str(p))
+
+        compress_config = CompressConfig(
+            do_pseudo_sparse=False,
+            sparse_ratio=1,
+            is_debug=True,
+            record_detail_root=str(p.parent),
+            multiprocess_num=process_num,
+        )
+        compressor = Compressor(compress_config, weight=state_dict, quant_model_description=quant_desc)
+        compressor.run()
+        if p.exists():
+            os.remove(p)
+        ori_quant_desc_file = p.parent / "quant_model_description.json"
+        if ori_quant_desc_file.exists():
+            os.rename(str(ori_quant_desc_file), str(ori_quant_desc_file.parent / "ori_quant_model_description.json"))
+        compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
+        return True
+    except Exception as e:
+        print(f"Error processing Rank file {file_path}: {e}")
+        return False
+
+
+def main(args):
+    # 1. Initial Validation
+    # Validate early so the script doesn't fail after hours of inference
+    output_dir = FileHandler.validate_path(args.output, must_exist=False, check_writable=True)
+    model_dir = FileHandler.validate_path(args.model, must_exist=True)
+
+    # 2. Run the vLLM engine and save sharded states
+    engine_args = EngineArgs.from_cli_args(args)
+    llm = LLM(**dataclasses.asdict(engine_args))
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+    llm.llm_engine.engine_core.save_sharded_state(path=str(output_dir))
+
+    del llm
+    clean_up()
+
+    # 3. Migrate metadata (excluding large weight files)
+    for item in model_dir.iterdir():
+        if item.suffix not in (".bin", ".pt", ".safetensors"):
+            FileHandler.safe_copy(item, output_dir / item.name)
+
+    # 4. Compression Logic
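+    # Compression flow: parameters_type_map.json (written by
+    # ShardedStateLoader310.generate_quant_description) records the quant type of every
+    # saved tensor. One process per tensor-parallel rank runs the msmodelslim Compressor
+    # over its shard and rewrites it in place; the quant description is then rewritten
+    # with the compressed types (e.g. W8A8S -> W8A8SC).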
+    parameters_map_fpath = output_dir / "parameters_type_map.json"
+    if args.enable_compress:
+        quant_desc = get_quant_description(str(parameters_map_fpath))
+        quant_type = quant_desc["model_quant_type"]
+        if quant_type in SUPPORTED_COMPRESS_QUANT_TYPE:
+            # TODO: Implement w16a16sc
+            if quant_type == "W16A16S":
+                raise NotImplementedError("W16A16SC is not supported yet.")
+
+            tasks = []
+            for i in range(args.tensor_parallel_size):
+                file_name = DEFAULT_PATTERN.format(rank=i, part="0")
+                full_path = output_dir / file_name
+
+                p = mp.Process(
+                    target=weight_compress_worker, args=(str(full_path), quant_desc, args.compress_process_num)
+                )
+                tasks.append(p)
+                p.start()
+
+            for p in tasks:
+                p.join()
+
+            update_quant_description(os.path.join(args.output, "ori_quant_model_description.json"))
+            print("Compression completed successfully.")
+        else:
+            print(f"Skipping compression: Unsupported type {quant_type}")
+    if parameters_map_fpath.exists():
+        os.remove(parameters_map_fpath)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/mypy.ini b/mypy.ini
index a86a89df..e9c2f8c4 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -34,3 +34,6 @@ ignore_missing_imports = True
 
 [mypy-ucm.*]
 ignore_missing_imports = True
+
+[mypy-msmodelslim.*]
+ignore_missing_imports = True
diff --git a/tests/ut/_310p/quantization/test_w8a8s_310.py b/tests/ut/_310p/quantization/test_w8a8s_310.py
new file mode 100644
index 00000000..f9c13809
--- /dev/null
+++ b/tests/ut/_310p/quantization/test_w8a8s_310.py
@@ -0,0 +1,93 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend._310p.quantization.methods.w8a8s import AscendW8A8SLinearMethod310
+
+
+class TestAscendW8A8SLinearMethod310(TestBase):
+    def setUp(self):
+        self.method = AscendW8A8SLinearMethod310()
+
+    def test_get_weight_310(self):
+        weight = self.method.get_weight(10, 20)
+        self.assertEqual(weight["weight"].dtype, torch.int8)
+        self.assertEqual(weight["weight"].shape, (20, 10))
+
+    def test_get_pertensor_param_310(self):
+        params = self.method.get_pertensor_param(torch.float16)
+        self.assertEqual(params["input_scale"].dtype, torch.float16)
+        self.assertEqual(params["input_offset"].dtype, torch.int8)
+        self.assertEqual(params["input_scale"].shape, (1,))
+        self.assertEqual(params["input_offset"].shape, (1,))
+
+    def test_get_perchannel_param_310(self):
+        params = self.method.get_perchannel_param(10, torch.float16)
+
+        self.assertEqual(params["quant_bias"].dtype, torch.int32)
+        self.assertEqual(params["deq_scale"].dtype, torch.int64)
+        self.assertEqual(params["quant_bias"].shape, (10,))
+        self.assertEqual(params["deq_scale"].shape, (10,))
+
+    @patch("torch.ops.vllm.quantize")
+    @patch("torch_npu.npu_quant_matmul")
+    def test_apply_with_x_not_int8_310(self, mock_npu_quant_matmul, mock_quantize):
+        layer = MagicMock()
+        layer.aclnn_input_scale = torch.randn(256)
+        layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale
+        layer.aclnn_input_offset = torch.randint(-128, 127, (256,), dtype=torch.int8)
+        layer.weight = torch.randn(128, 256)
+        layer.deq_scale = torch.randn(128)
+        layer.quant_bias = torch.randint(-128, 127, (256,))
+        layer.params_dtype = torch.float16
+
+        x = torch.randn(32, 128)
+        expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
+        mock_quantize.return_value = expect_x_output
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_quant_matmul.return_value = expected_y_output
+
+        output = self.method.apply(layer, x, tp_rank=0)
+
+        mock_quantize.assert_called_with(
+            x, layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset
+        )
+        self.assertTrue(torch.equal(output, expected_y_output))
+
+    @patch("torch.ops.vllm.quantize")
+    @patch("torch_npu.npu_quant_matmul")
+    def test_apply_with_x_is_int8_310(self, mock_npu_quant_matmul, mock_quantize):
+        layer = MagicMock()
+        layer.aclnn_input_scale = torch.randn(256)
+        layer.aclnn_input_offset = torch.randint(-128, 127, (256,), dtype=torch.int8)
+        layer.weight = torch.randn(128, 256)
+        layer.deq_scale = torch.randn(128)
+        layer.quant_bias = torch.randint(-128, 127, (256,))
+        layer.params_dtype = torch.float16
+
+        x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
+
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_quant_matmul.return_value = expected_y_output
+
+        output = self.method.apply(layer, x, tp_rank=0)
+
+        mock_quantize.assert_not_called()
+        self.assertTrue(torch.equal(output, expected_y_output))
diff --git a/tests/ut/_310p/test_sharded_state_loader_310p.py b/tests/ut/_310p/test_sharded_state_loader_310p.py
new file mode 100644
index 00000000..edd9295c
--- /dev/null
+++ b/tests/ut/_310p/test_sharded_state_loader_310p.py
@@ -0,0 +1,116 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from tests.ut.base import TestBase
+from vllm_ascend._310p.sharded_state_loader_310p import ShardedStateLoader310
+
+
+class MockQuantConfig:
+    """Mock quantization config for testing."""
+
+    def __init__(self, quant_type: str = "FLOAT"):
+        self.quant_description = {"model_quant_type": quant_type}
+
+
+class MockModel(torch.nn.Module):
+    """Mock model for testing."""
+
+    def __init__(self, quant_config=None, with_int_weights: bool = False):
+        super().__init__()
+        self.quant_config = quant_config
+        self.with_int_weights = with_int_weights
+        if with_int_weights:
+            self.linear = torch.nn.Linear(10, 10)
+            self.linear.weight = torch.nn.Parameter(
+                torch.randint(-127, 127, (10, 10), dtype=torch.int8), requires_grad=False
+            )
+            self.linear.bias = torch.nn.Parameter(torch.zeros(10, dtype=torch.int32), requires_grad=False)
+        else:
+            self.linear = torch.nn.Linear(10, 10)
+
+
+class TestShardedStateLoader310(TestBase):
+    """Test cases for ShardedStateLoader310."""
+
+    @patch("vllm.model_executor.model_loader.ShardedStateLoader._filter_subtensors")
+    @patch("vllm.distributed.get_tensor_model_parallel_rank")
+    @patch("safetensors.torch.save_file")
+    def test_save_model_with_nd_format_310(self, mock_save_file, mock_get_rank, mock_filter):
+        """Test save_model with ND format tensors (no conversion needed)."""
+        mock_get_rank.return_value = 0
+        mock_filter.side_effect = lambda x: x
+        mock_tensor = MagicMock(spec=torch.Tensor)
+
+        model = MockModel()
+        with (
+            patch.object(model, "state_dict", return_value={"linear.weight": mock_tensor}),
+            tempfile.TemporaryDirectory() as tmpdir,
+        ):
+            ShardedStateLoader310.save_model(model, tmpdir)
+
+        mock_save_file.assert_called_once()
+
+    @patch("vllm.model_executor.model_loader.ShardedStateLoader._filter_subtensors")
+    def test_generate_quant_description_float_model_310(self, mock_filter):
+        """Test generate_quant_description for float model."""
+        mock_filter.side_effect = lambda x: x
+        quant_config = MockQuantConfig(quant_type="FLOAT")
+        model = MockModel(quant_config=quant_config, with_int_weights=False)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ShardedStateLoader310.generate_quant_description(model, tmpdir)
+
+            json_path = Path(tmpdir) / "parameters_type_map.json"
+            self.assertTrue(json_path.exists())
+
+            with open(json_path, encoding="utf-8") as f:
+                quant_description = json.load(f)
+
+            self.assertEqual(quant_description["model_quant_type"], "FLOAT")
+            self.assertEqual(quant_description["version"], "1.0.0")
+            self.assertIn("linear.weight", quant_description)
+            self.assertEqual(quant_description["linear.weight"], "FLOAT")
+            self.assertIn("linear.bias", quant_description)
+            self.assertEqual(quant_description["linear.bias"], "FLOAT")
+
+    @patch("vllm.model_executor.model_loader.ShardedStateLoader._filter_subtensors")
+    def test_generate_quant_description_int_model_310(self, mock_filter):
+        """Test generate_quant_description for int8 quantized model."""
+        mock_filter.side_effect = lambda x: x
+        quant_config = MockQuantConfig(quant_type="W8A8")
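+        # with_int_weights=True gives the mock model an int8 weight and an int32 bias,
+        # which generate_quant_description should report as "W8A8" rather than "FLOAT".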
+        model = MockModel(quant_config=quant_config, with_int_weights=True)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ShardedStateLoader310.generate_quant_description(model, tmpdir)
+
+            json_path = Path(tmpdir) / "parameters_type_map.json"
+            self.assertTrue(json_path.exists())
+
+            with open(json_path, encoding="utf-8") as f:
+                quant_description = json.load(f)
+
+            self.assertEqual(quant_description["model_quant_type"], "W8A8")
+            self.assertEqual(quant_description["version"], "1.0.0")
+            self.assertIn("linear.weight", quant_description)
+            self.assertEqual(quant_description["linear.weight"], "W8A8")
+            self.assertIn("linear.bias", quant_description)
+            self.assertEqual(quant_description["linear.bias"], "W8A8")
diff --git a/vllm_ascend/_310p/quantization/methods/__init__.py b/vllm_ascend/_310p/quantization/methods/__init__.py
index 70fda173..4399207a 100644
--- a/vllm_ascend/_310p/quantization/methods/__init__.py
+++ b/vllm_ascend/_310p/quantization/methods/__init__.py
@@ -18,4 +18,5 @@
 from . import (
     w8a8_dynamic,  # noqa: F401
     w8a8_static,  # noqa: F401
+    w8a8s,  # noqa: F401
 )
diff --git a/vllm_ascend/_310p/quantization/methods/w8a8s.py b/vllm_ascend/_310p/quantization/methods/w8a8s.py
new file mode 100644
index 00000000..80e09472
--- /dev/null
+++ b/vllm_ascend/_310p/quantization/methods/w8a8s.py
@@ -0,0 +1,87 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+
+from typing import Any
+
+import torch
+import torch_npu
+
+from vllm_ascend.quantization.methods.base import AscendLinearScheme
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
+
+from .registry import register_scheme
+
+
+@register_scheme("W8A8S", "linear")
+class AscendW8A8SLinearMethod310(AscendLinearScheme):
+    """310P-only W8A8S sparse linear scheme.
+
+    Notes:
+        - This scheme is discovered via the 310P-local registry.
+    """
+
+    def get_weight(
+        self,
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.float16,
+    ) -> dict[str, Any]:
+        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+
+    def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
+        return {
+            "input_scale": torch.empty(1, dtype=params_dtype),
+            "input_offset": torch.empty(1, dtype=torch.int8),
+        }
+
+    def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
+        return {
+            "quant_bias": torch.empty(output_size, dtype=torch.int32),
+            "deq_scale": torch.empty(output_size, dtype=torch.int64),
+        }
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+        tp_rank: int | None = 0,
+    ) -> torch.Tensor:
+        if x.dtype != torch.int8:
+            x = torch.ops.vllm.quantize(
+                x,
+                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
+                layer.aclnn_input_offset,
+            )
+
+        quant_bias = layer.quant_bias if tp_rank == 0 else None
+
+        return torch_npu.npu_quant_matmul(
+            x,
+            layer.weight.data.transpose(0, 1),
+            layer.deq_scale,
+            bias=quant_bias,
+            output_dtype=layer.params_dtype,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        expanding_factor = layer.weight.data.shape[1]
+        layer.aclnn_input_scale = layer.input_scale.data.repeat(expanding_factor)
+        layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data
+        layer.aclnn_input_offset = layer.input_offset.data.repeat(expanding_factor).to(layer.aclnn_input_scale.dtype)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
diff --git a/vllm_ascend/_310p/sharded_state_loader_310p.py b/vllm_ascend/_310p/sharded_state_loader_310p.py
new file mode 100644
index 00000000..9200d82a
--- /dev/null
+++ b/vllm_ascend/_310p/sharded_state_loader_310p.py
@@ -0,0 +1,69 @@
+#
+# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import json
+import os
+from pathlib import Path
+
+import torch
+from vllm.config.load import LoadConfig
+from vllm.model_executor.model_loader import ShardedStateLoader
+
+
+class ShardedStateLoader310(ShardedStateLoader):
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        path: str,
+        pattern: str | None = None,
+        max_size: int | None = None,
+    ) -> None:
+        from safetensors.torch import save_file
+        from vllm.distributed import get_tensor_model_parallel_rank
+
+        rank = get_tensor_model_parallel_rank()
+        part_idx = 0
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+
+        filename = ShardedStateLoader.DEFAULT_PATTERN.format(rank=rank, part=part_idx)
+        save_file(
+            state_dict,
+            os.path.join(path, filename),
+        )
+
+    @staticmethod
+    def generate_quant_description(model: torch.nn.Module, path: str):
+        """Generate a mapping of parameter names to their corresponding quantization types."""
+        quant_description = {}
+        quantize_type = model.quant_config.quant_description.get("model_quant_type", "FLOAT")
+        quant_description["model_quant_type"] = quantize_type
+        quant_description["version"] = "1.0.0"
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for name, tensor in state_dict.items():
+            if name.endswith(".weight") or name.endswith(".bias"):
+                if tensor.dtype in [torch.int8, torch.int32, torch.int64]:
+                    quant_description[name] = quantize_type
+                else:
+                    quant_description[name] = "FLOAT"
+            else:
+                quant_description[name] = "FLOAT"
+
+        json_path = Path(path) / "parameters_type_map.json"
+        with json_path.open("w", encoding="utf-8") as f:
+            json.dump(quant_description, f, indent=2)
diff --git a/vllm_ascend/_310p/worker_310p.py b/vllm_ascend/_310p/worker_310p.py
index bb0fa28d..87fe3a81 100644
--- a/vllm_ascend/_310p/worker_310p.py
+++ b/vllm_ascend/_310p/worker_310p.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-
 import torch_npu
 from vllm.logger import logger
 
@@ -31,6 +30,23 @@ class NPUWorker310(NPUWorker):
 
         self.model_runner = NPUModelRunner310(self.vllm_config, self.device)
 
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: str | None = None,
+        max_size: int | None = None,
+    ) -> None:
+        from vllm_ascend._310p.sharded_state_loader_310p import ShardedStateLoader310
+
+        ShardedStateLoader310.save_model(
+            self.model_runner.model,
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
+        ShardedStateLoader310.generate_quant_description(self.model_runner.model, path)
+
     def _warm_up_atb(self):
         # 310p device do not support torch_npu._npu_matmul_add_fp32 atb ops
         logger.info("Skip warm-up atb ops for 310P device.")
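
Note: for orientation, the description files involved in the compression flow look roughly
as follows (layer names are illustrative only, not taken from a real checkpoint).
ShardedStateLoader310.generate_quant_description writes parameters_type_map.json next to
the per-rank shards:

    {
      "model_quant_type": "W8A8S",
      "version": "1.0.0",
      "model.layers.0.self_attn.qkv_proj.weight": "W8A8S",
      "model.layers.0.mlp.gate_up_proj.weight": "W8A8S",
      "model.layers.0.input_layernorm.weight": "FLOAT"
    }

After weight_compress_worker has rewritten each rank's shard, update_quant_description
produces quant_model_description.json with the compressed types: "model_quant_type"
becomes "W8A8SC" and every ".weight" entry that carried the original type is remapped
from "W8A8S" to "W8A8SC".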