diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 1ed7b8f7d..088d94a78 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -34,12 +34,11 @@ class LogitProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-
-    # For logprobs
     return_logprob: bool
-    top_logprobs_nums: List[int]
+
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    top_logprobs_nums: List[int] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py
index 2a9a0af6d..0066f92b8 100644
--- a/python/sglang/srt/managers/controller/cuda_graph_runner.py
+++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -6,6 +6,7 @@ import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
+from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
@@ -16,8 +17,28 @@ from sglang.srt.managers.controller.infer_batch import (
 )
 
 
+def _to_torch(model: torch.nn.Module, reverse=False):
+    for sub in model._modules.values():
+        if isinstance(sub, CustomOp):
+            if reverse:
+                sub._forward_method = sub.forward_cuda
+            else:
+                sub._forward_method = sub.forward_native
+        if isinstance(sub, torch.nn.Module):
+            _to_torch(sub, reverse)
+
+
+def get_forward(model: torch.nn.Module, use_torch: bool):
+    if use_torch:
+        _to_torch(model, reverse=False)
+        return torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
+    else:
+        _to_torch(model, reverse=True)
+        return model.forward
+
+
 class CudaGraphRunner:
-    def __init__(self, model_runner, max_batch_size_to_capture):
+    def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
         self.graphs = {}
         self.input_buffers = {}
@@ -55,6 +76,8 @@ class CudaGraphRunner:
             (self.max_bs,), dtype=torch.int32, device="cuda"
         )
 
+        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs
 
@@ -63,18 +86,19 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
+                forward = get_forward(self.model_runner.model, bs in self.compile_bs)
                 (
                     graph,
                     input_buffers,
                     output_buffers,
                     flashinfer_handler,
-                ) = self.capture_one_batch_size(bs)
+                ) = self.capture_one_batch_size(bs, forward)
                 self.graphs[bs] = graph
                 self.input_buffers[bs] = input_buffers
                 self.output_buffers[bs] = output_buffers
                 self.flashinfer_handlers[bs] = flashinfer_handler
 
-    def capture_one_batch_size(self, bs):
+    def capture_one_batch_size(self, bs, forward):
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
 
@@ -127,9 +151,8 @@ class CudaGraphRunner:
                 skip_flashinfer_init=True,
             )
             input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
-            return self.model_runner.model.forward(
-                input_ids, input_metadata.positions, input_metadata
-            )
+
+            return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
             run_once()
diff --git a/python/sglang/srt/managers/controller/model_runner.py b/python/sglang/srt/managers/controller/model_runner.py
index 13ccc6041..9fd0f19a3 100644
--- a/python/sglang/srt/managers/controller/model_runner.py
+++ b/python/sglang/srt/managers/controller/model_runner.py
@@ -244,7 +244,9 @@ class ModelRunner:
         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
         batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
         self.cuda_graph_runner = CudaGraphRunner(
-            self, max_batch_size_to_capture=max(batch_size_list)
+            self,
+            max_batch_size_to_capture=max(batch_size_list),
+            use_torch_compile=self.server_args.enable_torch_compile,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py
new file mode 100644
index 000000000..4592d1f60
--- /dev/null
+++ b/python/sglang/srt/models/gpt_bigcode.py
@@ -0,0 +1,282 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
+"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GPTBigCodeConfig
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.infer_batch import InputMetadata
+
+
+class GPTBigCodeAttention(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        total_num_heads = config.num_attention_heads
+        self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+        assert total_num_heads % self.tensor_model_parallel_world_size == 0
+        self.num_heads = total_num_heads // self.tensor_model_parallel_world_size
+        self.head_dim = self.hidden_size // total_num_heads
+        self.scale = self.head_dim**-0.5
+
+        self.multi_query = config.multi_query
+        if self.multi_query:
+            total_num_kv_heads = 1
+            self.num_kv_heads = 1
+        else:
+            total_num_kv_heads = total_num_heads
+            self.num_kv_heads = self.num_heads
+        self.kv_dim = self.head_dim * self.num_kv_heads
+        self.c_attn = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            total_num_heads,
+            total_num_kv_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            scaling=self.scale,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.c_attn(hidden_states)
+        q, k, v = qkv.split(
+            [
+                self.hidden_size // self.tensor_model_parallel_world_size,
+                self.kv_dim,
+                self.kv_dim,
+            ],
+            dim=-1,
+        )
+        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output, _ = self.c_proj(attn_output)
+        return attn_output
+
+
+class GPTBigMLP(nn.Module):
+
+    def __init__(
+        self,
+        intermediate_size: int,
+        config: GPTBigCodeConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.act = get_act_fn(
+            config.activation_function, quant_config, intermediate_size
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states, _ = self.c_proj(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeBlock(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        hidden_size = config.hidden_size
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.ln_1(hidden_states)
+        attn_output = self.attn(
+            hidden_states=hidden_states, input_metadata=input_metadata
+        )
+        # residual connection
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+        # residual connection
+        hidden_states = residual + feed_forward_hidden_states
+        return hidden_states
+
+
+class GPTBigCodeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        assert not config.add_cross_attention
+
+        self.embed_dim = config.hidden_size
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.wte = VocabParallelEmbedding(
+            self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
+        )
+        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
+        self.h = nn.ModuleList(
+            [
+                GPTBigCodeBlock(i, config, cache_config, quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
+        for i in range(len(self.h)):
+            layer = self.h[i]
+            hidden_states = layer(hidden_states, input_metadata)
+
+        hidden_states = self.ln_f(hidden_states)
+        return hidden_states
+
+
+class GPTBigCodeForCausalLM(nn.Module):
+    packed_modules_mapping = {"c_attn": ["c_attn"]}
+
+    supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
+
+    embedding_modules = {
+        "wte": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        super().__init__()
+
+        self.config = config
+        self.lora_config = lora_config
+
+        self.quant_config = quant_config
+        self.transformer = GPTBigCodeModel(
+            config, cache_config, quant_config, lora_config
+        )
+        self.lm_head = self.transformer.wte
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            if "lm_head.weight" in name:
+                continue
+            if ".attn.bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, "q")
+                weight_loader(param, loaded_weight, "k")
+                weight_loader(param, loaded_weight, "v")
+            else:
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = GPTBigCodeForCausalLM
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index c65186bf5..3c6a79e30 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -157,6 +157,19 @@ def _set_global_server_args(server_args: ServerArgs):
     }
 
 
+def _set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 128
+
+
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
@@ -190,6 +203,10 @@
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
+
+    if server_args.enable_torch_compile:
+        _set_torch_compile_config()
+
     _set_global_server_args(server_args)
 
     # Allocate ports
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 68dd90025..d77ea8782 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -55,6 +55,7 @@ class ServerArgs:
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
     disable_disk_cache: bool = False
+    enable_torch_compile: bool = False
     attention_reduce_in_fp32: bool = False
     enable_p2p_check: bool = False
     efficient_weight_load: bool = False
@@ -317,6 +318,11 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--enable-torch-compile",
+            action="store_true",
+            help="Optimize the model with torch.compile, experimental feature.",
+        )
        parser.add_argument(
            "--attention-reduce-in-fp32",
            action="store_true",
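
Usage note (not part of the patch): a minimal sketch of how the new flag could be exercised from Python, using the ServerArgs and launch_server entry points touched above. The model_path keyword and the checkpoint name are illustrative assumptions, not taken from this diff.

# Hypothetical usage sketch: enable the experimental torch.compile path.
# Assumes ServerArgs accepts a model_path field and that launch_server is
# importable as in the patched python/sglang/srt/server.py.
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="bigcode/gpt_bigcode-santacoder",  # placeholder checkpoint
    enable_torch_compile=True,  # new flag; triggers _set_torch_compile_config()
)
# The decode batch sizes listed in CudaGraphRunner.compile_bs (1, 2, 4, 8, 16,
# 24, 32) are captured through torch.compile; other captured sizes fall back
# to the plain CUDA-graph forward.
launch_server(server_args)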