Support loading of larger models with on-the-fly quantization (#3061)

This commit is contained in:
Ke Wen
2025-01-22 21:33:17 -08:00
committed by GitHub
parent 8b84e69f25
commit 862bcff833
6 changed files with 116 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ Common utilities for torchao.
import logging
import os
import pwd
from typing import Callable, Optional
import torch
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
return True
def proj_filter(
    module: torch.nn.Module,
    fqn: str,
):
    """Module filter selecting projection layers.

    Returns True when the module's fully-qualified name contains the
    substring "proj" (e.g. q_proj, gate_proj); the module object itself
    is not inspected.
    """
    return fqn.find("proj") >= 0
def apply_torchao_config_to_model(
model: torch.nn.Module, torchao_config: str, filter_fn=None
model: torch.nn.Module,
torchao_config: str,
filter_fn: Optional[Callable] = proj_filter,
):
"""Quantize a model with torchao quantization specified by torchao_config
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
)
from torchao.quantization.observer import PerRow, PerTensor
if filter_fn is None:
def filter_fn(module, fqn):
return "proj" in fqn
if torchao_config == "" or torchao_config is None:
return model
elif "int8wo" in torchao_config: