From badf3fa02011f9e1af9a043033a41ff8c25dfbec Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 27 Jun 2024 23:30:39 -0700
Subject: [PATCH] Expose dtype argument (#569)

---
 .../srt/managers/controller/model_runner.py | 12 +++--
 .../srt/managers/controller/tp_worker.py    |  2 +-
 python/sglang/srt/server_args.py            | 46 +++++++++++++------
 3 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 4b4add62b..1450abd1d 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -6,7 +6,7 @@ import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List, Optional, Type, Any
+from typing import List, Optional, Type
 
 import numpy as np
 import torch
@@ -119,7 +119,7 @@ class InputMetadata:
             head_dim,
             1,
             pos_encoding_mode="NONE",
-            data_type="float16",
+            data_type=self.token_to_kv_pool.kv_data[0].dtype
         )
 
     def init_extend_args(self):
@@ -287,10 +287,11 @@ class ModelRunner:
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=torch.float16,
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype
 
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -307,6 +308,7 @@ class ModelRunner:
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
@@ -316,7 +318,7 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -337,7 +339,7 @@ class ModelRunner:
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=torch.float16,
+            dtype=self.dtype,
             head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 98cdf4bc7..ba19142da 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -120,7 +120,7 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"context_len={self.model_config.context_len}, "
+            f"context_len={self.model_config.context_len}"
         )
         if self.tp_rank == 0:
             logger.info(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 75d9033d6..4b7daf5f9 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -11,12 +11,13 @@ class ServerArgs:
     # Model and tokenizer
     model_path: str
     tokenizer_path: Optional[str] = None
-    load_format: str = "auto"
     tokenizer_mode: str = "auto"
-    chat_template: Optional[str] = None
+    load_format: str = "auto"
+    dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
+    chat_template: Optional[str] = None
 
     # Port
     host: str = "127.0.0.1"
@@ -107,6 +108,15 @@ class ServerArgs:
             default=[],
             help="The additional ports specified for the server.",
         )
+        parser.add_argument(
+            "--tokenizer-mode",
+            type=str,
+            default=ServerArgs.tokenizer_mode,
+            choices=["auto", "slow"],
+            help="Tokenizer mode. 'auto' will use the fast "
+            "tokenizer if available, and 'slow' will "
+            "always use the slow tokenizer.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -124,20 +134,20 @@ class ServerArgs:
             "which is mainly for profiling.",
         )
         parser.add_argument(
-            "--tokenizer-mode",
+            "--dtype",
             type=str,
-            default=ServerArgs.tokenizer_mode,
-            choices=["auto", "slow"],
-            help="Tokenizer mode. 'auto' will use the fast "
-            "tokenizer if available, and 'slow' will "
-            "always use the slow tokenizer.",
-        )
-        parser.add_argument(
-            "--chat-template",
-            type=str,
-            default=ServerArgs.chat_template,
-            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server",
-        )
+            default=ServerArgs.dtype,
+            choices=[
+                "auto", "half", "float16", "bfloat16", "float", "float32"
+            ],
+            help='Data type for model weights and activations.\n\n'
+            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+            'BF16 precision for BF16 models.\n'
+            '* "half" for FP16. Recommended for AWQ quantization.\n'
+            '* "float16" is the same as "half".\n'
+            '* "bfloat16" for a balance between precision and range.\n'
+            '* "float" is shorthand for FP32 precision.\n'
+            '* "float32" for FP32 precision.')
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -155,6 +165,12 @@ class ServerArgs:
             default=ServerArgs.quantization,
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--chat-template",
+            type=str,
+            default=ServerArgs.chat_template,
+            help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
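
For reference, a minimal sketch of the KV-cache sizing arithmetic this patch generalizes: the hard-coded `* 2 * 2` (two tensors per layer, two bytes per fp16 element) becomes `* 2 * torch._utils._element_size(self.dtype)`, so the byte count follows the configured dtype. The head and layer dimensions below are illustrative values, not taken from the patch:

    import torch

    # torch._utils._element_size maps a dtype to its per-element byte size:
    # float16 -> 2, bfloat16 -> 2, float32 -> 4.
    for dtype in (torch.float16, torch.bfloat16, torch.float32):
        print(dtype, torch._utils._element_size(dtype))

    # Bytes of KV cache per token, mirroring the patched cell_size formula.
    # 32 heads x 128 head_dim x 32 layers are made-up, Llama-7B-like numbers.
    head_num, head_dim, layer_num = 32, 128, 32
    cell_size = (
        head_num * head_dim * layer_num
        * 2  # one K tensor and one V tensor per layer
        * torch._utils._element_size(torch.bfloat16)
    )
    print(f"KV cache per token: {cell_size} bytes")  # 524288 bytes = 512 KiB

With the flag exposed, a launch such as `python -m sglang.launch_server --model-path <model> --dtype bfloat16` (model path is a placeholder) flows through `ServerArgs.dtype` into both the weight loading and the KV-cache allocation above.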