### What this PR does / why we need it?
New quantization method: added support for the W8A8SC static linear
quantization scheme for 310P hardware, enabling more efficient model
compression.
Refactored save_sharded_state_310.py to avoid a multiprocessing issue.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
W8A8SC quant E2E test.
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
288 lines
10 KiB
Python
288 lines
10 KiB
Python
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/save_sharded_state.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
|
"""
|
|
Saves each worker's model state dict directly to a checkpoint, which enables a
|
|
fast load path for large tensor-parallel models where each worker only needs to
|
|
read its own shard rather than the entire checkpoint.
|
|
Sparse-Compress-Quantization state dict could also be saved via this script.
|
|
|
|
Example usage:
|
|
|
|
python save_sharded_state_310.py \
|
|
--model /path/to/load \
|
|
--tensor-parallel-size 8 \
|
|
--output /path/to/save \
|
|
--enable-compress \
|
|
--compress-process-num 8 \
|
|
--enforce-eager \
|
|
--dtype float16 \
|
|
--quantization ascend
|
|
|
|
Then, the model can be loaded with
|
|
|
|
llm = LLM(
|
|
model="/path/to/save",
|
|
load_format="sharded_state",
|
|
tensor_parallel_size=8,
|
|
quantization="ascend",
|
|
)
|
|
"""
|
|
|
|
import dataclasses
|
|
import json
|
|
import multiprocessing as mp
|
|
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from vllm import LLM, EngineArgs
|
|
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
|
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
|
|
|
# Sparse quantization type -> type name recorded in the quantization
# description once the shards have been compressed.
QUANTIZATION_UPDATE_MAP = {"W8A8S": "W8A8SC", "W16A16S": "W16A16SC"}

# Quantization types eligible for compression: exactly the keys of the map
# above, in the same order.
SUPPORTED_COMPRESS_QUANT_TYPE = list(QUANTIZATION_UPDATE_MAP)

# File-name template of the per-rank shard files located for compression.
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
|
|
|
|
|
class FileHandler:
    """Small filesystem helpers used by this script."""

    @staticmethod
    def validate_path(path: str, must_exist: bool = True, check_writable: bool = False) -> Path:
        """
        Validate a filesystem path and return it as a ``Path``.

        Args:
            path: Path to validate.
            must_exist: If True, raise when ``path`` does not exist.
            check_writable: If True, require write access to ``path`` when it
                exists, otherwise to its nearest existing ancestor (the
                directory in which the missing components would be created).

        Returns:
            Path: ``path`` converted to a ``Path`` object.

        Raises:
            FileNotFoundError: If ``must_exist`` is True and ``path`` is missing.
            PermissionError: If ``check_writable`` is True and the checked
                location is not writable.
        """
        p = Path(path)
        if must_exist and not p.exists():
            raise FileNotFoundError(f"Error: Path '{path}' does not exist.")

        if check_writable:
            # Walk up to the closest existing directory. Checking only the
            # immediate parent wrongly rejects nested outputs such as
            # "out/a/b" when "out/a" does not exist yet but could be created
            # with mkdir(parents=True).
            target = p
            while not target.exists() and target != target.parent:
                target = target.parent
            if not os.access(target, os.W_OK):
                raise PermissionError(f"Permission Denied: No write access to '{target}'.")
        return p

    @staticmethod
    def safe_copy(src: Path, dst: Path) -> None:
        """
        Best-effort copy of a file or directory tree from ``src`` to ``dst``.

        Directories are merged into an existing destination
        (``dirs_exist_ok=True``). Files are copied with ``shutil.copy2`` so
        timestamps and permission bits are preserved. Any failure is printed
        as a warning instead of being raised.
        """
        try:
            if src.is_dir():
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                shutil.copy2(src, dst)
        except OSError as e:  # PermissionError and shutil.Error are OSError subclasses
            print(f"Warning: Failed to copy {src} due to: {e}")
|
|
|
|
|
|
def clean_up():
    """Tear down vLLM's parallel/distributed state and free cached NPU memory.

    Model-parallel groups are destroyed before the distributed environment,
    and the NPU allocator cache is emptied last.
    """
    for teardown in (destroy_model_parallel, destroy_distributed_environment):
        teardown()
    torch.npu.empty_cache()
|
|
|
|
|
|
def parse_args():
    """
    Build the command-line parser and parse ``sys.argv``.

    All ``EngineArgs`` options are accepted, plus the script-specific
    ``--output``, ``--enable-compress`` and ``--compress-process-num`` flags.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.add_argument("--output", "-o", required=True, type=str, help="path to output checkpoint")
    parser.add_argument(
        "--enable-compress",
        action="store_true",
        help=(
            "After saving, compress each rank's shard when the model's quant type "
            "is supported (currently W8A8S) and rewrite quant_model_description.json "
            "with the compressed quant type."
        ),
    )
    parser.add_argument(
        "--compress-process-num",
        type=int,
        default=1,
        help="Number of processes the weight compressor uses inside each per-rank worker (default: 1).",
    )
    args = parser.parse_args()
    # Reject non-positive values up front, before the (potentially long)
    # model load and save.
    if args.compress_process_num < 1:
        parser.error("--compress-process-num must be a positive integer")
    return args
|
|
|
|
|
|
def get_quant_description(json_file: str) -> dict:
    """
    Load a quantization description from a JSON file.

    Args:
        json_file: Path to the JSON file (e.g. ``parameters_type_map.json``).

    Returns:
        dict: The parsed top-level JSON object.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
        RuntimeError: If the file is not valid JSON or its top-level value is
            not a JSON object.
    """
    config_path = Path(json_file)
    if not config_path.exists():
        raise FileNotFoundError(f"Model configuration file not found: {json_file}")
    try:
        with config_path.open("r", encoding="utf-8") as file:
            quant_desc = json.load(file)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Invalid JSON format in {json_file}: {e}") from e

    # Callers index this result by key, so anything but an object is unusable.
    if not isinstance(quant_desc, dict):
        raise RuntimeError(
            f"Expected a JSON object in {json_file}, got {type(quant_desc).__name__}."
        )
    return quant_desc
|
|
|
|
|
|
def update_quant_description(ori_json_file: str, target_json_file: str) -> None:
    """
    Rewrite a quantization description with its compressed quantization type.

    Reads ``ori_json_file`` and maps its ``model_quant_type`` through
    ``QUANTIZATION_UPDATE_MAP`` (e.g. ``W8A8S`` -> ``W8A8SC``). The result is
    written to ``target_json_file``, and then ``ori_json_file`` is deleted.

    In the written JSON:

    * ``model_quant_type`` is the mapped type and ``version`` is ``"1.0.0"``;
    * every ``*.weight`` entry whose value equals the original type is
      switched to the mapped type;
    * all other entries are copied unchanged.

    If both arguments refer to the same file, it is updated in place and not
    deleted.

    Args:
        ori_json_file: Path to the source JSON configuration file.
        target_json_file: Path the updated JSON configuration is written to.

    Raises:
        RuntimeError: If the source file cannot be read or parsed, is not a
            JSON object, or has a missing/unsupported ``model_quant_type``;
            also if writing the target or removing the source fails.
    """
    src_path = Path(ori_json_file)
    dst_path = Path(target_json_file)
    try:
        with src_path.open("r", encoding="utf-8") as file:
            json_data = json.load(file)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        raise RuntimeError(f"Failed to read configuration file {ori_json_file}: {e}") from e

    if not isinstance(json_data, dict):
        raise RuntimeError(f"Configuration file {ori_json_file} does not contain a JSON object.")

    original_quant_type = json_data.get("model_quant_type")
    # The isinstance check also guards the dict lookup against unhashable values.
    if not isinstance(original_quant_type, str) or original_quant_type not in QUANTIZATION_UPDATE_MAP:
        raise RuntimeError(
            f"Cannot update quantization type. "
            f"Original type '{original_quant_type}' not found or not supported for update in {ori_json_file}."
        )
    updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]

    updated_config = {"model_quant_type": updated_quant_type, "version": "1.0.0"}
    for key, value in json_data.items():
        if key in ("model_quant_type", "version"):
            continue  # already set above
        if key.endswith(".weight") and value == original_quant_type:
            updated_config[key] = updated_quant_type
        else:
            updated_config[key] = value

    try:
        with dst_path.open("w", encoding="utf-8") as file:
            json.dump(updated_config, file, indent=2, ensure_ascii=False)
    except OSError as e:
        raise RuntimeError(f"Failed to write updated configuration to {target_json_file}: {e}") from e

    # Remove the source only when it is a different file; for an in-place
    # update this would delete the configuration just written.
    try:
        if not os.path.samefile(src_path, dst_path):
            os.remove(src_path)
    except OSError as e:
        raise RuntimeError(f"Failed to remove original configuration {ori_json_file}: {e}") from e
|
|
|
|
|
|
def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
    """
    Compress one rank's safetensors shard, writing the result under the same name.

    Loads ``file_path``, runs the msmodelslim ``Compressor`` on its tensors,
    and exports the compressed tensors back to the same directory and file
    name. The original shard is moved to a ``.orig`` sibling while exporting.
    If the export fails, the original is restored, so the rank is never left
    without weights. Errors are printed, not raised.

    Heavy imports happen inside the worker (and only once the input file is
    known to exist) so the parent process does not pay for them.

    Args:
        file_path: Path of the shard to compress.
        quant_desc: Quantization description passed to the compressor as
            ``quant_model_description``.
        process_num: Value forwarded to ``CompressConfig.multiprocess_num``.

    Returns:
        bool: True if processing succeeded, False otherwise.
    """
    import traceback

    p = Path(file_path)
    if not p.exists():
        print(f"Error: File not found, failed to compress: {file_path}")
        return False

    import safetensors.torch
    from msmodelslim.pytorch.weight_compression import CompressConfig, Compressor

    backup = p.with_name(p.name + ".orig")
    try:
        state_dict = safetensors.torch.load_file(str(p))

        compress_config = CompressConfig(
            do_pseudo_sparse=False,
            sparse_ratio=1,
            is_debug=True,
            record_detail_root=str(p.parent),
            multiprocess_num=process_num,
        )
        compressor = Compressor(compress_config, weight=state_dict, quant_model_description=quant_desc)
        compressor.run()

        # The export targets the same file name, so move the original out of
        # the way and keep it until the export has succeeded.
        os.replace(p, backup)
        try:
            compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
        except Exception:
            # Put the uncompressed shard back, overwriting any partial output.
            os.replace(backup, p)
            raise

        try:
            backup.unlink()
        except OSError as e:
            # The compressed shard is in place; a leftover backup is only clutter.
            print(f"Warning: Failed to remove backup {backup}: {e}")
        return True
    except Exception as e:
        print(f"Error processing Rank file {file_path}: {e}")
        traceback.print_exc()
        return False
|
|
|
|
|
|
def _compress_worker_entry(file_path: str, quant_desc: dict, process_num: int) -> None:
    """Process target: run weight_compress_worker and exit non-zero on failure.

    weight_compress_worker reports failure through its return value, which
    multiprocessing.Process discards. Converting it into an exit code lets the
    parent detect failed ranks via ``Process.exitcode``.
    """
    if not weight_compress_worker(file_path, quant_desc, process_num):
        raise SystemExit(1)


def _compress_rank_shards(output_dir: Path, quant_desc: dict, tp_size: int, process_num: int) -> None:
    """Compress every rank's part-0 shard in parallel, one process per rank.

    Raises:
        RuntimeError: If any worker process exits with a non-zero code.
    """
    # save_sharded_state() below is called without a max_size, so each rank is
    # presumably written as a single part-0 file — TODO confirm if that changes.
    tasks = []
    for rank in range(tp_size):
        shard_path = output_dir / DEFAULT_PATTERN.format(rank=rank, part="0")
        proc = mp.Process(
            target=_compress_worker_entry,
            args=(str(shard_path), quant_desc, process_num),
        )
        proc.start()
        tasks.append((rank, proc))

    failed_ranks = []
    for rank, proc in tasks:
        proc.join()
        if proc.exitcode != 0:
            failed_ranks.append(rank)
    if failed_ranks:
        raise RuntimeError(f"Weight compression failed for rank(s) {failed_ranks}; see worker output above.")


def main(args):
    """
    Save per-rank sharded state and optionally compress it.

    Steps:
      1. Validate the output and model paths up front.
      2. Load the model with vLLM and save each rank's state dict to
         ``args.output``.
      3. Copy non-weight files (anything not .bin/.pt/.safetensors) from the
         model directory.
      4. With ``--enable-compress`` and a supported quant type, compress each
         rank's shard and rewrite ``quant_model_description.json`` with the
         compressed quant type.
    Finally, the intermediate ``parameters_type_map.json`` is removed.

    Raises:
        NotImplementedError: If compression is requested for W16A16S.
        FileNotFoundError: If compression is requested but
            ``quant_model_description.json`` is missing from the output.
        RuntimeError: If any per-rank compression worker fails.
    """
    # 1. Initial validation: fail early rather than after a long model load.
    output_dir = FileHandler.validate_path(args.output, must_exist=False, check_writable=True)
    model_dir = FileHandler.validate_path(args.model, must_exist=True)

    # 2. Run the vLLM engine and save sharded states.
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    output_dir.mkdir(parents=True, exist_ok=True)
    llm.llm_engine.engine_core.save_sharded_state(path=str(output_dir))

    del llm
    clean_up()

    # 3. Migrate metadata (excluding large weight files).
    for item in model_dir.iterdir():
        if item.suffix not in (".bin", ".pt", ".safetensors"):
            FileHandler.safe_copy(item, output_dir / item.name)

    # 4. Compression.
    parameters_map_fpath = output_dir / "parameters_type_map.json"
    if args.enable_compress:
        quant_desc = get_quant_description(str(parameters_map_fpath))
        quant_type = quant_desc.get("model_quant_type")
        if quant_type not in SUPPORTED_COMPRESS_QUANT_TYPE:
            # Nothing has been touched yet, so the saved model stays loadable.
            print(f"Skipping compression: Unsupported type {quant_type}")
        else:
            # TODO: Implement w16a16sc
            if quant_type == "W16A16S":
                raise NotImplementedError("W16A16SC is not supported yet.")

            quant_desc_file = output_dir / "quant_model_description.json"
            # Check before compressing so a missing description does not
            # surface only after all shards have been rewritten.
            if not quant_desc_file.exists():
                raise FileNotFoundError(f"Quantization description not found: {quant_desc_file}")

            _compress_rank_shards(output_dir, quant_desc, args.tensor_parallel_size, args.compress_process_num)

            # Move the original description aside only after every shard was
            # compressed; update_quant_description writes the compressed
            # version back under the original name and removes the backup.
            backup_quant_desc_file = output_dir / "ori_quant_model_description.json"
            os.replace(quant_desc_file, backup_quant_desc_file)
            update_quant_description(str(backup_quant_desc_file), str(quant_desc_file))
            print("Compression completed successfully.")

    if parameters_map_fpath.exists():
        os.remove(parameters_map_fpath)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = parse_args()
|
|
main(args)
|