Tiny refactor ModelConfig.from_server_args (#5219)
This commit is contained in:
@@ -137,17 +137,7 @@ def load_model(server_args, port_args, tp_rank):
|
|||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||||
|
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig.from_server_args(server_args)
|
||||||
server_args.model_path,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
revision=server_args.revision,
|
|
||||||
context_length=server_args.context_length,
|
|
||||||
model_override_args=server_args.json_model_override_args,
|
|
||||||
is_embedding=server_args.is_embedding,
|
|
||||||
enable_multimodal=server_args.enable_multimodal,
|
|
||||||
dtype=server_args.dtype,
|
|
||||||
quantization=server_args.quantization,
|
|
||||||
)
|
|
||||||
model_runner = ModelRunner(
|
model_runner = ModelRunner(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
mem_fraction_static=server_args.mem_fraction_static,
|
mem_fraction_static=server_args.mem_fraction_static,
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
from sglang.srt.hf_transformers_utils import get_config, get_context_length
|
||||||
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import get_bool_env_var, is_hip
|
from sglang.srt.utils import get_bool_env_var, is_hip
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -210,6 +211,21 @@ class ModelConfig:
|
|||||||
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
self.hf_eos_token_id = self.get_hf_eos_token_id()
|
||||||
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
|
self.image_token_id = getattr(self.hf_config, "image_token_id", None)
|
||||||
|
|
||||||
|
@classmethod
def from_server_args(cls, server_args: "ServerArgs", model_path: str = None, **kwargs):
    """Alternate constructor: build a config from the fields of ``server_args``.

    Declared as a classmethod so that calling it on a subclass returns an
    instance of that subclass instead of always returning the base class.

    Args:
        server_args: Object providing ``model_path``, ``trust_remote_code``,
            ``revision``, ``context_length``, ``json_model_override_args``,
            ``is_embedding``, ``enable_multimodal``, ``dtype`` and
            ``quantization``.
        model_path: Optional override for the model path. When it is falsy
            (``None`` or an empty string) ``server_args.model_path`` is used.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            constructor (e.g. ``is_draft_model``).

    Returns:
        A new instance of ``cls``.
    """
    return cls(
        # An explicit path (e.g. a draft model's path) wins over the default.
        model_path=model_path or server_args.model_path,
        trust_remote_code=server_args.trust_remote_code,
        revision=server_args.revision,
        context_length=server_args.context_length,
        model_override_args=server_args.json_model_override_args,
        is_embedding=server_args.is_embedding,
        enable_multimodal=server_args.enable_multimodal,
        dtype=server_args.dtype,
        quantization=server_args.quantization,
        **kwargs,
    )
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||||
def get_total_num_kv_heads(self) -> int:
|
def get_total_num_kv_heads(self) -> int:
|
||||||
"""Returns the total number of KV heads."""
|
"""Returns the total number of KV heads."""
|
||||||
|
|||||||
@@ -455,17 +455,7 @@ class Scheduler(
|
|||||||
def init_tokenizer(self):
|
def init_tokenizer(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig.from_server_args(server_args)
|
||||||
server_args.model_path,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
revision=server_args.revision,
|
|
||||||
context_length=server_args.context_length,
|
|
||||||
model_override_args=server_args.json_model_override_args,
|
|
||||||
is_embedding=server_args.is_embedding,
|
|
||||||
enable_multimodal=server_args.enable_multimodal,
|
|
||||||
dtype=server_args.dtype,
|
|
||||||
quantization=server_args.quantization,
|
|
||||||
)
|
|
||||||
self.is_generation = self.model_config.is_generation
|
self.is_generation = self.model_config.is_generation
|
||||||
|
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
|
|||||||
@@ -165,17 +165,7 @@ class TokenizerManager:
|
|||||||
# Read model args
|
# Read model args
|
||||||
self.model_path = server_args.model_path
|
self.model_path = server_args.model_path
|
||||||
self.served_model_name = server_args.served_model_name
|
self.served_model_name = server_args.served_model_name
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig.from_server_args(server_args)
|
||||||
server_args.model_path,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
revision=server_args.revision,
|
|
||||||
context_length=server_args.context_length,
|
|
||||||
model_override_args=server_args.json_model_override_args,
|
|
||||||
is_embedding=server_args.is_embedding,
|
|
||||||
enable_multimodal=server_args.enable_multimodal,
|
|
||||||
dtype=server_args.dtype,
|
|
||||||
quantization=server_args.quantization,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.is_generation = self.model_config.is_generation
|
self.is_generation = self.model_config.is_generation
|
||||||
self.is_image_gen = self.model_config.is_image_gen
|
self.is_image_gen = self.model_config.is_image_gen
|
||||||
|
|||||||
@@ -65,20 +65,13 @@ class TpModelWorker:
|
|||||||
self.pp_rank = pp_rank
|
self.pp_rank = pp_rank
|
||||||
|
|
||||||
# Init model and tokenizer
|
# Init model and tokenizer
|
||||||
self.model_config = ModelConfig(
|
self.model_config = ModelConfig.from_server_args(
|
||||||
(
|
server_args,
|
||||||
|
model_path=(
|
||||||
server_args.model_path
|
server_args.model_path
|
||||||
if not is_draft_worker
|
if not is_draft_worker
|
||||||
else server_args.speculative_draft_model_path
|
else server_args.speculative_draft_model_path
|
||||||
),
|
),
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
revision=server_args.revision,
|
|
||||||
context_length=server_args.context_length,
|
|
||||||
model_override_args=server_args.json_model_override_args,
|
|
||||||
is_embedding=server_args.is_embedding,
|
|
||||||
enable_multimodal=server_args.enable_multimodal,
|
|
||||||
dtype=server_args.dtype,
|
|
||||||
quantization=server_args.quantization,
|
|
||||||
is_draft_model=is_draft_worker,
|
is_draft_model=is_draft_worker,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -43,16 +43,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
|
server_args = ServerArgs(model_path=model_path, dtype=torch.float16)
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig.from_server_args(server_args)
|
||||||
server_args.model_path,
|
|
||||||
trust_remote_code=server_args.trust_remote_code,
|
|
||||||
revision=server_args.revision,
|
|
||||||
context_length=server_args.context_length,
|
|
||||||
model_override_args=server_args.json_model_override_args,
|
|
||||||
is_embedding=server_args.is_embedding,
|
|
||||||
dtype=server_args.dtype,
|
|
||||||
quantization=server_args.quantization,
|
|
||||||
)
|
|
||||||
|
|
||||||
load_config = LoadConfig()
|
load_config = LoadConfig()
|
||||||
device_config = DeviceConfig("cuda")
|
device_config = DeviceConfig("cuda")
|
||||||
|
|||||||
Reference in New Issue
Block a user