[Minor] Code style improvements (#2355)

This commit is contained in:
Lianmin Zheng
2024-12-04 19:02:08 -08:00
committed by GitHub
parent 9cc733b38c
commit 2b0fc5941d
3 changed files with 16 additions and 16 deletions

View File

@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
if "FusedMoE" in sub.__class__.__name__:
if batch_size == 1:
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to skip it for now.
# so we decide to only use torch.compile when bs = 1.
sub._forward_method = fused_moe_forward_native
else:
sub._forward_method = sub.forward_native

View File

@@ -27,7 +27,6 @@ from vllm.distributed import (
initialize_model_parallel,
set_custom_all_reduce,
)
from vllm.distributed.parallel_state import in_the_same_node_as
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
@@ -38,7 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
@@ -112,11 +111,13 @@ class ModelRunner:
)
if self.is_multimodal:
logger.info(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
server_args.chunked_prefill_size = -1
self.mem_fraction_static *= 0.95
logger.info(
f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
f"and turn off chunked prefill "
f"because this is a multimodal model."
)
# TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
if self.model_config.hf_config.architectures == [
"Qwen2VLForConditionalGeneration"
@@ -160,11 +161,8 @@ class ModelRunner:
else:
self.torch_tp_applied = False
def filter_fn(module, fqn):
return "proj" in fqn
apply_torchao_config_to_model_(
self.model, global_server_args_dict["torchao_config"], filter_fn
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
)
# Init memory pool and attention backends