[Feat] [310p] Support w8a8sc quantization method (#7075)
### What this PR does / why we need it?
- New quantization method: adds support for the W8A8SC static linear quantization scheme for 310P hardware, enabling more efficient model compression.
- Refactors `save_sharded_state_310.py` to avoid a multi-process issue; a hedged loading sketch follows below.
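For reference, a minimal loading sketch for a model saved and compressed by `save_sharded_state_310.py`: it assumes vLLM's `sharded_state` load format and the `ascend` quantization backend registered by vllm-ascend, and the exact arguments may vary between releases.

```python
# Hedged sketch: load the pre-sharded, W8A8SC-compressed checkpoint directly.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/save",        # directory written by save_sharded_state_310.py
    load_format="sharded_state",  # reuse the already-sharded weights
    quantization="ascend",        # matches --quantization ascend used when saving
    tensor_parallel_size=8,       # must match the TP size used when saving
    enforce_eager=True,
    dtype="float16",
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```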
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
W8A8SC quant E2E test.
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
@@ -24,12 +24,15 @@ Sparse-Compress-Quantization state dict could also be saved via this script.
 
 Example usage:
 
-python save_sharded_state.py \
+python save_sharded_state_310.py \
     --model /path/to/load \
     --tensor-parallel-size 8 \
     --output /path/to/save \
     --enable-compress \
-    --compress-process-num 8
+    --compress-process-num 8 \
+    --enforce-eager \
+    --dtype float16 \
+    --quantization ascend
 
 Then, the model can be loaded with
 
@@ -140,29 +143,30 @@ def get_quant_description(json_file: str) -> dict:
     return quant_desc
 
 
-def update_quant_description(json_file: str) -> None:
+def update_quant_description(ori_json_file: str, target_json_file: str) -> None:
     """
     Update quantization types in JSON configuration file based on update mapping.
 
     Args:
-        json_file: Path to the JSON configuration file
+        ori_json_file: Path to the JSON configuration file
+        target_json_file: Path to the JSON configuration file to be saved
 
     Raises:
         FileNotFoundError: If the JSON file does not exist
         RuntimeError: If JSON parsing fails or required keys are missing
     """
-    config_path = Path(json_file)
+    config_path = Path(ori_json_file)
     try:
         with config_path.open("r", encoding="utf-8") as file:
             json_data = json.load(file)
     except (FileNotFoundError, json.JSONDecodeError) as e:
-        raise RuntimeError(f"Failed to read configuration file {json_file}: {e}")
+        raise RuntimeError(f"Failed to read configuration file {ori_json_file}: {e}")
 
     original_quant_type = json_data.get("model_quant_type")
     if not original_quant_type or original_quant_type not in QUANTIZATION_UPDATE_MAP:
         raise RuntimeError(
             f"Cannot update quantization type. "
-            f"Original type '{original_quant_type}' not found or not supported for update in {json_file}."
+            f"Original type '{original_quant_type}' not found or not supported for update in {ori_json_file}."
         )
     updated_quant_type = QUANTIZATION_UPDATE_MAP[original_quant_type]
 
@@ -175,12 +179,12 @@ def update_quant_description(json_file: str) -> None:
         updated_config[key] = value
 
     try:
-        new_file_path = config_path.parent / "quant_model_description.json"
+        new_file_path = Path(target_json_file)
         with new_file_path.open("w", encoding="utf-8") as file:
             json.dump(updated_config, file, indent=2, ensure_ascii=False)
-        os.remove(json_file)
+        os.remove(ori_json_file)
     except OSError as e:
-        raise RuntimeError(f"Failed to write updated configuration to {json_file}: {e}")
+        raise RuntimeError(f"Failed to write updated configuration to {target_json_file}: {e}")
 
 
 def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
@@ -214,9 +218,6 @@ def weight_compress_worker(file_path: str, quant_desc: dict, process_num: int) -> bool:
         compressor.run()
         if p.exists():
             os.remove(p)
-        ori_quant_desc_file = p.parent / "quant_model_description.json"
-        if ori_quant_desc_file.exists():
-            os.rename(str(ori_quant_desc_file), str(ori_quant_desc_file.parent / "ori_quant_model_description.json"))
         compressor.export_safetensors(str(p.parent), safetensors_name=p.name)
         return True
     except Exception as e:
@@ -248,6 +249,10 @@ def main(args):
     # 4. Compression Logic
     parameters_map_fpath = output_dir / "parameters_type_map.json"
     if args.enable_compress:
+        quant_desc_file = output_dir / "quant_model_description.json"
+        backup_quant_desc_file = output_dir / "ori_quant_model_description.json"
+        if quant_desc_file.exists():
+            os.rename(str(quant_desc_file), str(backup_quant_desc_file))
         quant_desc = get_quant_description(str(parameters_map_fpath))
         quant_type = quant_desc["model_quant_type"]
         if quant_type in SUPPORTED_COMPRESS_QUANT_TYPE:
@@ -269,7 +274,7 @@ def main(args):
             for p in tasks:
                 p.join()
 
-            update_quant_description(os.path.join(args.output, "ori_quant_model_description.json"))
+            update_quant_description(str(backup_quant_desc_file), str(quant_desc_file))
             print("Compression completed successfully.")
         else:
             print(f"Skipping compression: Unsupported type {quant_type}")
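To make the renaming steps in the hunks above easier to follow, here is a condensed, hedged sketch of the flow `main()` takes when `--enable-compress` is set. The `QUANTIZATION_UPDATE_MAP` contents and the helper name are illustrative assumptions, and the real `update_quant_description` also rewrites per-layer quant types, which is omitted here.

```python
# Illustrative sketch only: condensed view of the compress-and-rewrite flow.
import json
import os
from pathlib import Path

QUANTIZATION_UPDATE_MAP = {"w8a8s": "w8a8sc"}  # assumed example mapping

def rewrite_quant_description(backup_file: Path, target_file: Path) -> None:
    """Write the compressed quant type back under the original file name."""
    desc = json.loads(backup_file.read_text(encoding="utf-8"))
    desc["model_quant_type"] = QUANTIZATION_UPDATE_MAP[desc["model_quant_type"]]
    target_file.write_text(json.dumps(desc, indent=2, ensure_ascii=False),
                           encoding="utf-8")
    os.remove(backup_file)

output_dir = Path("/path/to/save")
quant_desc_file = output_dir / "quant_model_description.json"
backup_file = output_dir / "ori_quant_model_description.json"

# 1. Back up the description once, in the parent process, instead of racing
#    over it inside every compression worker (the multi-process issue).
if quant_desc_file.exists():
    os.rename(quant_desc_file, backup_file)

# 2. ... compression workers rewrite each *.safetensors shard in place ...

# 3. After all workers join, emit the updated description under its old name.
if backup_file.exists():
    rewrite_quant_description(backup_file, quant_desc_file)
```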