Support loading of larger models with on-the-fly quantization (#3061)
This commit is contained in:
@@ -5,6 +5,7 @@ Common utilities for torchao.
|
||||
import logging
|
||||
import os
|
||||
import pwd
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def proj_filter(
    module: torch.nn.Module,
    fqn: str,
) -> bool:
    """Filter function for quantizing projection layers.

    Args:
        module: The candidate module. It is accepted so the function fits
            the two-argument ``(module, fqn)`` filter signature, but it is
            not inspected here.
        fqn: Name of the candidate module (presumably its fully qualified
            name within the model -- confirm against the caller).

    Returns:
        True when the substring ``"proj"`` occurs anywhere in ``fqn``,
        False otherwise.
    """
    # Selection is purely name-based: any name containing "proj" matches.
    return "proj" in fqn
|
||||
|
||||
|
||||
def apply_torchao_config_to_model(
|
||||
model: torch.nn.Module, torchao_config: str, filter_fn=None
|
||||
model: torch.nn.Module,
|
||||
torchao_config: str,
|
||||
filter_fn: Optional[Callable] = proj_filter,
|
||||
):
|
||||
"""Quantize a modelwith torchao quantization specified by torchao_config
|
||||
|
||||
@@ -49,11 +60,6 @@ def apply_torchao_config_to_model(
|
||||
)
|
||||
from torchao.quantization.observer import PerRow, PerTensor
|
||||
|
||||
if filter_fn is None:
|
||||
|
||||
def filter_fn(module, fqn):
|
||||
return "proj" in fqn
|
||||
|
||||
if torchao_config == "" or torchao_config is None:
|
||||
return model
|
||||
elif "int8wo" in torchao_config:
|
||||
|
||||
Reference in New Issue
Block a user