[Minor] Code style improvements (#2355)
This commit is contained in:
@@ -2,12 +2,10 @@
|
|||||||
Common utilities for torchao.
|
Common utilities for torchao.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, Set
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def apply_torchao_config_to_model_(
|
def apply_torchao_config_to_model(
|
||||||
model: torch.nn.Module, torchao_config: str, filter_fn=None
|
model: torch.nn.Module, torchao_config: str, filter_fn=None
|
||||||
):
|
):
|
||||||
"""Quantize a model with torchao quantization specified by torchao_config
|
"""Quantize a model with torchao quantization specified by torchao_config
|
||||||
@@ -21,6 +19,7 @@ def apply_torchao_config_to_model_(
|
|||||||
# Lazy import to suppress some warnings
|
# Lazy import to suppress some warnings
|
||||||
from torchao.quantization import (
|
from torchao.quantization import (
|
||||||
float8_dynamic_activation_float8_weight,
|
float8_dynamic_activation_float8_weight,
|
||||||
|
float8_weight_only,
|
||||||
int4_weight_only,
|
int4_weight_only,
|
||||||
int8_dynamic_activation_int8_weight,
|
int8_dynamic_activation_int8_weight,
|
||||||
int8_weight_only,
|
int8_weight_only,
|
||||||
@@ -28,6 +27,11 @@ def apply_torchao_config_to_model_(
|
|||||||
)
|
)
|
||||||
from torchao.quantization.observer import PerRow, PerTensor
|
from torchao.quantization.observer import PerRow, PerTensor
|
||||||
|
|
||||||
|
if filter_fn is None:
|
||||||
|
|
||||||
|
def filter_fn(module, fqn):
|
||||||
|
return "proj" in fqn
|
||||||
|
|
||||||
if torchao_config == "" or torchao_config is None:
|
if torchao_config == "" or torchao_config is None:
|
||||||
return model
|
return model
|
||||||
elif "int8wo" in torchao_config:
|
elif "int8wo" in torchao_config:
|
||||||
@@ -44,8 +48,6 @@ def apply_torchao_config_to_model_(
|
|||||||
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
|
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
|
||||||
quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
|
quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
|
||||||
elif "fp8wo" in torchao_config:
|
elif "fp8wo" in torchao_config:
|
||||||
from torchao.quantization import float8_weight_only
|
|
||||||
|
|
||||||
# this requires newer hardware
|
# this requires newer hardware
|
||||||
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
# [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
|
||||||
quantize_(model, float8_weight_only(), filter_fn=filter_fn)
|
quantize_(model, float8_weight_only(), filter_fn=filter_fn)
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
|
|||||||
if "FusedMoE" in sub.__class__.__name__:
|
if "FusedMoE" in sub.__class__.__name__:
|
||||||
if batch_size == 1:
|
if batch_size == 1:
|
||||||
# The performance of torch.compile on this layer is not always good when bs > 1,
|
# The performance of torch.compile on this layer is not always good when bs > 1,
|
||||||
# so we decide to skip it for now.
|
# so we decide to only use torch.compile when bs = 1.
|
||||||
sub._forward_method = fused_moe_forward_native
|
sub._forward_method = fused_moe_forward_native
|
||||||
else:
|
else:
|
||||||
sub._forward_method = sub.forward_native
|
sub._forward_method = sub.forward_native
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from vllm.distributed import (
|
|||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
set_custom_all_reduce,
|
set_custom_all_reduce,
|
||||||
)
|
)
|
||||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
|
||||||
|
|
||||||
from sglang.srt.configs.device_config import DeviceConfig
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
@@ -38,7 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
|
|||||||
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.layers.sampler import Sampler
|
from sglang.srt.layers.sampler import Sampler
|
||||||
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
|
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
|
||||||
from sglang.srt.lora.lora_manager import LoRAManager
|
from sglang.srt.lora.lora_manager import LoRAManager
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
@@ -112,11 +111,13 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.is_multimodal:
|
if self.is_multimodal:
|
||||||
logger.info(
|
|
||||||
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
|
|
||||||
)
|
|
||||||
server_args.chunked_prefill_size = -1
|
server_args.chunked_prefill_size = -1
|
||||||
self.mem_fraction_static *= 0.95
|
self.mem_fraction_static *= 0.95
|
||||||
|
logger.info(
|
||||||
|
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
|
||||||
|
f"and turn off chunked prefill "
|
||||||
|
f"because this is a multimodal model."
|
||||||
|
)
|
||||||
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
|
||||||
if self.model_config.hf_config.architectures == [
|
if self.model_config.hf_config.architectures == [
|
||||||
"Qwen2VLForConditionalGeneration"
|
"Qwen2VLForConditionalGeneration"
|
||||||
@@ -160,11 +161,8 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
self.torch_tp_applied = False
|
self.torch_tp_applied = False
|
||||||
|
|
||||||
def filter_fn(module, fqn):
|
apply_torchao_config_to_model(
|
||||||
return "proj" in fqn
|
self.model, global_server_args_dict["torchao_config"]
|
||||||
|
|
||||||
apply_torchao_config_to_model_(
|
|
||||||
self.model, global_server_args_dict["torchao_config"], filter_fn
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init memory pool and attention backends
|
# Init memory pool and attention backends
|
||||||
|
|||||||
Reference in New Issue
Block a user