diff --git a/docs/test_process.md b/docs/test_process.md
index 90958ec62..99889f999 100644
--- a/docs/test_process.md
+++ b/docs/test_process.md
@@ -91,4 +91,12 @@ python3 run_all.py
 ```
 cd test/srt
 python test_openai_server.py
-```
\ No newline at end of file
+```
+
+## Format
+```
+pip3 install pre-commit
+cd sglang
+pre-commit install
+pre-commit run --all-files
+```
\ No newline at end of file
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 01450d8ac..34beebb3b 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -123,6 +123,15 @@ class ModelRunner:
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
+        if (
+            self.server_args.efficient_weight_load
+            and "llama" in self.server_args.model_path.lower()
+            and self.server_args.quantization == "fp8"
+        ):
+            from sglang.srt.model_loader.model_loader import get_model
+        else:
+            from vllm.model_executor.model_loader import get_model
+
         self.model = get_model(
             model_config=vllm_model_config,
             device_config=device_config,
@@ -237,7 +246,16 @@ class ModelRunner:
         self.cuda_graph_runner = CudaGraphRunner(
             self, max_batch_size_to_capture=max(batch_size_list)
         )
-        self.cuda_graph_runner.capture(batch_size_list)
+        logger.info(f"Capture cuda graph for batch sizes {batch_size_list}")
+        try:
+            self.cuda_graph_runner.capture(batch_size_list)
+        except Exception as e:
+            raise Exception(
+                f"CUDA graph capture failed. Possible solutions:\n"
+                f"1. disable cuda graph with --disable-cuda-graph\n"
+                f"2. set --mem-fraction-static to a smaller value\n"
+                f"Open an issue on GitHub with reproducible scripts if you need help.\n"
+            ) from e
 
     @torch.inference_mode()
     def forward_decode(self, batch: Batch):
diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py
index 183a7a786..ab7c7f9e9 100644
--- a/python/sglang/srt/managers/controller/tp_worker.py
+++ b/python/sglang/srt/managers/controller/tp_worker.py
@@ -304,6 +304,12 @@ class ModelTpServer:
             self.model_config.context_len - 1 - len(req.origin_input_ids),
             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
+        if req.sampling_params.max_new_tokens < 0:
+            req.origin_input_ids = req.origin_input_ids[
+                : self.max_total_num_tokens - 128
+            ]
+            logger.error("Request is longer than the memory pool size. The input has been truncated.")
+
         self.forward_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[Batch]:
diff --git a/python/sglang/srt/model_loader/model_loader.py b/python/sglang/srt/model_loader/model_loader.py
index ded68ae97..719fa9269 100644
--- a/python/sglang/srt/model_loader/model_loader.py
+++ b/python/sglang/srt/model_loader/model_loader.py
@@ -91,6 +91,7 @@ def _initialize_model(
         config=model_config.hf_config,
         cache_config=cache_config,
         quant_config=quant_config,
+        efficient_weight_load=True,
         **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config),
     )
 
diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py
index e6b3c1d19..f5da6dcb3 100644
--- a/python/sglang/srt/models/llama2.py
+++ b/python/sglang/srt/models/llama2.py
@@ -15,11 +15,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -32,6 +27,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
+MergedColumnParallelLinear = None
+QKVParallelLinear = None
+RowParallelLinear = None
+
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -267,7 +266,25 @@ class LlamaForCausalLM(nn.Module):
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
+        efficient_weight_load=False,
     ) -> None:
+        global MergedColumnParallelLinear
+        global QKVParallelLinear
+        global RowParallelLinear
+
+        if efficient_weight_load:
+            from sglang.srt.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+        else:
+            from vllm.model_executor.layers.linear import (
+                MergedColumnParallelLinear,
+                QKVParallelLinear,
+                RowParallelLinear,
+            )
+
         super().__init__()
         self.config = config
         self.quant_config = quant_config
@@ -288,7 +305,30 @@ class LlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def get_module_name(self, name):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id, num_shard)
+            ("qkv_proj", "q_proj", "q", 3),
+            ("qkv_proj", "k_proj", "k", 3),
+            ("qkv_proj", "v_proj", "v", 3),
+            ("gate_up_proj", "gate_proj", 0, 2),
+            ("gate_up_proj", "up_proj", 1, 2),
+        ]
+        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -298,15 +338,14 @@ class LlamaForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        if get_tensor_model_parallel_rank() == 0:
-            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
-        for name, loaded_weight in weights:
+
+        def load_weights_per_param(name, loaded_weight):
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                continue
+                return
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
-                continue
+                return
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -323,12 +362,21 @@ class LlamaForCausalLM(nn.Module):
             else:
                 # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: - continue + return if name.startswith("model.vision_tower") and name not in params_dict: - continue + return param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if name is None or loaded_weight is None: + if get_tensor_model_parallel_rank() == 0: + weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5)) + + for name, loaded_weight in weights: + load_weights_per_param(name, loaded_weight) + else: + load_weights_per_param(name, loaded_weight) + EntryClass = LlamaForCausalLM diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 264985fb5..68dd90025 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -57,6 +57,7 @@ class ServerArgs: disable_disk_cache: bool = False attention_reduce_in_fp32: bool = False enable_p2p_check: bool = False + efficient_weight_load: bool = False # Distributed args nccl_init_addr: Optional[str] = None @@ -327,6 +328,11 @@ class ServerArgs: action="store_true", help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.", ) + parser.add_argument( + "--efficient-weight-load", + action="store_true", + help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace):