Unify the model type checking (#1905)

This commit is contained in:
Lianmin Zheng
2024-11-03 12:25:39 -08:00
committed by GitHub
parent c17c578108
commit 0abbf289a8
13 changed files with 146 additions and 160 deletions

View File

@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
enable_show_time_cost,
get_available_gpu_memory,
is_attention_free_model,
is_embedding_model,
is_generation_model,
is_multimodal_model,
model_has_inner_state,
monkey_patch_vllm_dummy_weight_loader,
monkey_patch_vllm_p2p_access_check,
)
@@ -93,9 +88,8 @@ class ModelRunner:
self.tp_size = tp_size
self.dist_port = nccl_port
self.server_args = server_args
self.is_multimodal_model = is_multimodal_model(
self.model_config.hf_config.architectures
)
self.is_generation = model_config.is_generation
self.is_multimodal = model_config.is_multimodal
# Model-specific adjustment
if (
@@ -119,7 +113,7 @@ class ModelRunner:
self.server_args.ds_heavy_channel_type
)
if self.is_multimodal_model:
if self.is_multimodal:
logger.warning(
"Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
)
@@ -270,9 +264,6 @@ class ModelRunner:
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
self.model_config.hf_config.architectures, self.server_args.is_embedding
)
logger.info(
f"Load weight end. "
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
# Monkey patch model loader
setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)