### What this PR does / why we need it?
This pull request introduces significant enhancements for 310P device
support, primarily by enabling W8A8S quantization and facilitating the
saving of models with W8A8SC state outputs. It provides an example
script for saving sharded and compressed model states, implements the
core W8A8S quantization method, and integrates metadata generation
within the 310P worker to accurately describe the quantization types of
saved parameters. These changes aim to improve efficiency and
compatibility for quantized models on 310P hardware.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
W8A8S accuracy test and W8A8SC state saving.
<img width="886" height="184" alt="image"
src="https://github.com/user-attachments/assets/e9bcac54-1f69-4d3a-a5b8-221a147ef99d"
/>
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
283 lines
9.8 KiB
Python
283 lines
9.8 KiB
Python
#
|
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm-project/vllm/examples/offline_inference/save_sharded_state.py
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
"""
|
|
Saves each worker's model state dict directly to a checkpoint, which enables a
|
|
fast load path for large tensor-parallel models where each worker only needs to
|
|
read its own shard rather than the entire checkpoint.
|
|
A sparse-compressed quantized state dict (e.g. W8A8S -> W8A8SC) can also be saved with this script by passing --enable-compress.
|
|
|
|
Example usage:
|
|
|
|
python save_sharded_state.py \
|
|
--model /path/to/load \
|
|
--tensor-parallel-size 8 \
|
|
--output /path/to/save \
|
|
--enable-compress \
|
|
--compress-process-num 8
|
|
|
|
Then, the model can be loaded with
|
|
|
|
llm = LLM(
|
|
model="/path/to/save",
|
|
load_format="sharded_state",
|
|
tensor_parallel_size=8,
|
|
quantization="ascend",
|
|
)
|
|
"""
|
|
|
|
import dataclasses
|
|
import json
|
|
import multiprocessing as mp
|
|
import os
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from vllm import LLM, EngineArgs
|
|
from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
|
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
|
|
|
# File name pattern of the shards written by save_sharded_state; {rank} and {part}
# are filled in per file.
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

# Sparse quantization type -> sparse-compressed type it becomes after compression.
QUANTIZATION_UPDATE_MAP = {"W8A8S": "W8A8SC", "W16A16S": "W16A16SC"}

# Quantization types eligible for --enable-compress: exactly the keys of the map above.
SUPPORTED_COMPRESS_QUANT_TYPE = list(QUANTIZATION_UPDATE_MAP)
|
|
|
|
|
|
class FileHandler:
    """Filesystem helpers used to validate CLI paths and migrate model metadata."""

    @staticmethod
    def validate_path(path: str, must_exist: bool = True, check_writable: bool = False) -> Path:
        """
        Validate a filesystem path and return it as a ``Path``.

        Args:
            path: Path to validate.
            must_exist: If True, the path itself must already exist.
            check_writable: If True, verify the path can be written to. For a path
                that does not exist yet, the nearest existing ancestor is checked,
                because that is the directory in which it would be created.

        Returns:
            Path: ``path`` converted to a ``Path``.

        Raises:
            FileNotFoundError: If ``must_exist`` is True and the path is missing.
            PermissionError: If ``check_writable`` is True and the target is not writable.
        """
        p = Path(path)
        if must_exist and not p.exists():
            raise FileNotFoundError(f"Error: Path '{path}' does not exist.")

        if check_writable:
            # Walk up to the first existing path. A nested output such as /out/a/b is
            # created later with mkdir(parents=True), so only the deepest existing
            # ancestor has to be writable; checking a non-existent parent would make
            # os.access report False and raise a spurious PermissionError.
            target = p
            while not target.exists() and target != target.parent:
                target = target.parent
            if not os.access(target, os.W_OK):
                raise PermissionError(f"Permission Denied: No write access to '{target}'.")
        return p

    @staticmethod
    def safe_copy(src: Path, dst: Path) -> None:
        """
        Copy a file or a directory tree from ``src`` to ``dst`` on a best-effort basis.

        A failed copy is reported as a warning instead of raising, so a single
        unreadable metadata file does not abort the run.
        """
        try:
            if src.is_dir():
                # dirs_exist_ok=True merges into an already existing destination directory.
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                # copy2 preserves metadata (timestamps, permission bits).
                shutil.copy2(src, dst)
        except OSError as e:  # also covers PermissionError and shutil.Error
            print(f"Warning: Failed to copy {src} due to: {e}")
|
|
|
|
|
|
def clean_up():
    """Release the vLLM distributed state and the cached NPU memory."""
    # Tear down the model-parallel groups first, then the global distributed environment.
    for teardown in (destroy_model_parallel, destroy_distributed_environment):
        teardown()
    torch.npu.empty_cache()
|
|
|
|
|
|
def parse_args():
    """
    Build the command-line parser and parse the process arguments.

    All vLLM engine options (registered by ``EngineArgs.add_cli_args``) are
    accepted in addition to the script-specific options below.

    Returns:
        The parsed argument namespace.
    """
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.add_argument("--output", "-o", required=True, type=str, help="path to output checkpoint")
    parser.add_argument(
        "--enable-compress",
        action="store_true",
        help="after saving, compress sparse-quantized shards (e.g. W8A8S -> W8A8SC) with msmodelslim",
    )
    parser.add_argument(
        "--compress-process-num",
        type=int,
        default=1,
        help="number of processes the compressor may use for each shard (default: 1)",
    )
    args = parser.parse_args()
    # Reject unusable values now rather than after the (long) inference/save step.
    if args.compress_process_num < 1:
        parser.error("--compress-process-num must be a positive integer")
    return args
|
|
|
|
|
|
def get_quant_description(json_file: str) -> dict:
    """
    Load the quantization description from a JSON configuration file.

    Args:
        json_file: Path to the JSON configuration file.

    Returns:
        dict: Quantization description; guaranteed to contain ``"model_quant_type"``.

    Raises:
        FileNotFoundError: If the JSON file does not exist.
        RuntimeError: If JSON parsing fails, the top level is not a JSON object,
            or the required ``"model_quant_type"`` key is missing.
    """
    config_path = Path(json_file)
    if not config_path.exists():
        raise FileNotFoundError(f"Model configuration file not found: {json_file}")
    try:
        with config_path.open("r", encoding="utf-8") as file:
            quant_desc = json.load(file)
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Invalid JSON format in {json_file}: {e}") from e

    # Callers index "model_quant_type" directly; fail here with a clear message
    # instead of an opaque KeyError/TypeError later on.
    if not isinstance(quant_desc, dict):
        raise RuntimeError(f"Expected a JSON object in {json_file}, got {type(quant_desc).__name__}.")
    if "model_quant_type" not in quant_desc:
        raise RuntimeError(f"Missing required key 'model_quant_type' in {json_file}.")
    return quant_desc
|
|
|
|
|
|
def update_quant_description(json_file: str) -> None:
|
|
"""
|
|
Update quantization types in JSON configuration file based on update mapping.
|
|
|
|
Args:
|
|
json_file: Path to the JSON configuration file
|
|
|
|
Raises:
|
|
FileNotFoundError: If the JSON file does not exist
|
|
RuntimeError: If JSON parsing fails or required keys are missing
|
|
"""
|
|
config_path = Path(json_file)
|
|
try:
|
|
with config_path.open("r", encoding="utf-8") as file:
|
|
json_data = json.load(file)
|
|
except (FileNotFoundError, json.JSONDecodeError) as e:
|
|
raise RuntimeError(f"Failed to read configuration file {json_file}: {e}")
|
|
|
|
original_quant_type = json_data.get("model_quant_type")
|
|
if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
|
|
raise RuntimeError(
|
|
f"Cannot update quantization type. "
|
|
f"Original type '{original_quant_type}' not found or not supported for update in {json_file}."
|
|
)
|
|
updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]
|
|
|
|
updated_config = {"model_quant_type": updated_quant_type, "version": "1.0.0"}
|
|
|
|
for key, value in json_data.items():
|
|
if key.endswith(".weight") and value == original_quant_type:
|
|
updated_config[key] = updated_quant_type
|
|
elif key not in ("model_quant_type", "version"):
|
|
updated_config[key] = value
|
|
|
|
try:
|
|
new_file_path = config_path.parent / "quant_model_description.json"
|
|
with new_file_path.open("w", encoding="utf-8") as file:
|
|
json.dump(updated_config, file, indent=2, ensure_ascii=False)
|
|
os.remove(json_file)
|
|
except OSError as e:
|
|
raise RuntimeError(f"Failed to write updated configuration to {json_file}: {e}")
|
|
|
|
|
|
def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
    """
    Compress one saved shard in place (runs inside a child process).

    Loads the safetensors file at ``file_path`` and runs the msmodelslim weight
    compressor over it with ``quant_desc``. It then deletes the original file and
    exports the compressed weights under the same file name.

    Note: Imports are inside the worker to save memory in the main process.
    The original ``quant_model_description.json`` is preserved by the parent
    before the workers start, not here: all workers share one output directory,
    so renaming it from each worker would race.

    Args:
        file_path: Path to the shard's ``.safetensors`` file.
        quant_desc: Quantization description handed to the compressor.
        process_num: Number of processes the compressor may use.

    Returns:
        bool: True if processing succeeded, False otherwise.
    """
    import traceback

    import safetensors.torch
    from msmodelslim.pytorch.weight_compression import CompressConfig, Compressor

    p = Path(file_path)
    if not p.exists():
        print(f"Error: File not found, failed to compress: {file_path}")
        return False

    try:
        state_dict = safetensors.torch.load_file(str(p))

        compress_config = CompressConfig(
            do_pseudo_sparse=False,
            sparse_ratio=1,
            is_debug=True,
            record_detail_root=str(p.parent),
            multiprocess_num=process_num,
        )
        compressor = Compressor(compress_config, weight=state_dict, quant_model_description=quant_desc)
        compressor.run()
        # The export below reuses this shard's file name, so the source file is removed
        # first. NOTE(review): presumably to avoid writing over the file the weights were
        # loaded from — confirm.
        if p.exists():
            os.remove(p)
        compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
        return True
    except Exception as e:  # process boundary: report, then signal failure to the caller
        print(f"Error processing Rank file {file_path}: {e}")
        traceback.print_exc()
        return False


def _compress_worker_entry(file_path: str, quant_desc: dict, process_num: int) -> None:
    """Process target wrapping ``weight_compress_worker``; a failure becomes exit code 1."""
    # multiprocessing.Process discards the target's return value, so the result has
    # to be reported through the exit code for the parent to see it.
    if not weight_compress_worker(file_path, quant_desc, process_num):
        raise SystemExit(1)


def _compress_shards(output_dir: Path, parameters_map_fpath: Path, process_num: int) -> None:
    """
    Compress every saved shard in ``output_dir`` and rewrite its quantization description.

    Raises:
        NotImplementedError: For W16A16S, whose compressed format is not implemented yet.
        RuntimeError: If no shard files are found or any shard fails to compress.
    """
    quant_desc = get_quant_description(str(parameters_map_fpath))
    quant_type = quant_desc["model_quant_type"]
    if quant_type not in SUPPORTED_COMPRESS_QUANT_TYPE:
        print(f"Skipping compression: Unsupported type {quant_type}")
        return
    # TODO: Implement w16a16sc
    if quant_type == "W16A16S":
        raise NotImplementedError("W16A16SC is not supported yet.")

    # Collect every part of every rank. A rank's state may be split into several parts,
    # and the rank count is not necessarily tensor_parallel_size (e.g. with pipeline
    # parallelism). Any shard left out would stay uncompressed while the description
    # below relabels the whole model.
    shard_files = sorted(output_dir.glob(DEFAULT_PATTERN.format(rank="*", part="*")))
    if not shard_files:
        raise RuntimeError(f"No sharded state files found in {output_dir} to compress.")

    # Preserve the original description exactly once, before any worker runs, so that
    # concurrent workers never race on (or overwrite) this shared file.
    ori_quant_desc_file = output_dir / "quant_model_description.json"
    if ori_quant_desc_file.exists():
        os.rename(ori_quant_desc_file, output_dir / "ori_quant_model_description.json")

    tasks = [
        mp.Process(target=_compress_worker_entry, args=(str(shard), quant_desc, process_num))
        for shard in shard_files
    ]
    for task in tasks:
        task.start()
    for task in tasks:
        task.join()

    # A non-zero exitcode covers both a reported failure and a worker killed by a signal.
    failed = [str(shard) for shard, task in zip(shard_files, tasks) if task.exitcode != 0]
    if failed:
        raise RuntimeError(
            f"Compression failed for {len(failed)} shard(s): {', '.join(failed)}. "
            f"The checkpoint in {output_dir} is incomplete."
        )

    update_quant_description(str(output_dir / "ori_quant_model_description.json"))
    print("Compression completed successfully.")


def main(args):
    """
    Save the sharded state of ``args.model`` into ``args.output``.

    Steps: validate paths, run the vLLM engine and save each worker's state,
    copy non-weight files from the model directory, and optionally
    (``--enable-compress``) compress the saved shards.
    """
    # 1. Initial Validation
    # Validate early so the script doesn't fail after hours of inference
    output_dir = FileHandler.validate_path(args.output, must_exist=False, check_writable=True)
    model_dir = FileHandler.validate_path(args.model, must_exist=True)

    # 2. Run VLLM Engine and save sharded states
    engine_args = EngineArgs.from_cli_args(args)
    llm = LLM(**dataclasses.asdict(engine_args))

    output_dir.mkdir(parents=True, exist_ok=True)
    llm.llm_engine.engine_core.save_sharded_state(path=str(output_dir))

    del llm
    clean_up()

    # 3. Migrate Metadata (Excluding large weights)
    for item in model_dir.iterdir():
        if item.suffix not in (".bin", ".pt", ".safetensors"):
            FileHandler.safe_copy(item, output_dir / item.name)

    # 4. Compression Logic
    parameters_map_fpath = output_dir / "parameters_type_map.json"
    if args.enable_compress:
        _compress_shards(output_dir, parameters_map_fpath, args.compress_process_num)

    # The parameter type map is an intermediate artifact; it is not part of the checkpoint.
    if parameters_map_fpath.exists():
        os.remove(parameters_map_fpath)
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse engine flags plus this script's own options, then run the save pipeline.
    main(parse_args())
|