Small fixes for torchao quant (#2476)
@@ -26,11 +26,12 @@ def apply_torchao_config_to_model(
         quantize_,
     )
     from torchao.quantization.observer import PerRow, PerTensor
+    from torchao.quantization.quant_api import _is_linear

     if filter_fn is None:

         def filter_fn(module, fqn):
-            return "proj" in fqn
+            return _is_linear(module) and "proj" in fqn

     if torchao_config == "" or torchao_config is None:
         return model
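The change above tightens the default filter: previously any module whose fully qualified name contained "proj" was quantized, which could also catch non-linear layers. A minimal sketch of the intent, assuming torchao's public quantize_ / int8_weight_only entry points (the toy model and its module names are made up for illustration):

import torch.nn as nn
from torchao.quantization import int8_weight_only, quantize_
from torchao.quantization.quant_api import _is_linear

def filter_fn(module, fqn):
    # Quantize only linear layers whose name contains "proj".
    return _is_linear(module) and "proj" in fqn

model = nn.Sequential()
model.add_module("q_proj", nn.Linear(16, 16))    # quantized: linear and "proj" in fqn
model.add_module("proj_norm", nn.LayerNorm(16))  # skipped: not a linear layer
model.add_module("gate", nn.Linear(16, 16))      # skipped: no "proj" in fqn

quantize_(model, int8_weight_only(), filter_fn=filter_fn)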
@@ -157,6 +157,10 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -165,10 +169,6 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False

-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
-
         # Init memory pool and attention backends
         if server_args.lora_paths is not None:
             self.init_lora_manager()
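The two hunks above move the quantization step: apply_torchao_config_to_model is now called right after load_model and before the optional torch tensor-parallel rewrite, rather than after it. A rough sketch of the resulting order in ModelRunner.__init__, based only on the lines shown in the diff (apply_torch_tp is a hypothetical name for the TP step, which the diff does not show):

self.sampler = Sampler()
self.load_model()

# Quantize the freshly loaded weights first...
apply_torchao_config_to_model(
    self.model, global_server_args_dict["torchao_config"]
)

# ...then apply torch TP on top of the quantized model, if supported.
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
    self.apply_torch_tp()  # hypothetical helper, not part of this diff
    self.torch_tp_applied = True
else:
    self.torch_tp_applied = False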