From 8a6906127a81421e06c904273f8e06dff85039a7 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Tue, 7 Jan 2025 23:29:10 -0800
Subject: [PATCH] Improve linear.py to load sharded weights & remove the
 dependency of Parameters from vllm (#2784)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
---
 3rdparty/amd/tuning/benchmark_moe_rocm.py     |   5 +-
 .../sglang/srt/layers/attention/__init__.py   |   9 +-
 .../layers/attention/flashinfer_backend.py    |   6 +-
 python/sglang/srt/layers/linear.py            | 222 ++++++---
 .../srt/layers/moe/fused_moe_triton/layer.py  |   5 +-
 python/sglang/srt/layers/parameter.py         | 431 ++++++++++++++++++
 python/sglang/srt/layers/quantization/fp8.py  |   2 +-
 .../srt/layers/vocab_parallel_embedding.py    |   2 +-
 .../sglang/srt/managers/session_controller.py |   2 +-
 .../srt/model_executor/forward_batch_info.py  |   3 +
 .../sglang/srt/model_executor/model_runner.py |   3 +-
 python/sglang/srt/models/grok.py              |  41 +-
 python/sglang/srt/server.py                   |   9 +-
 python/sglang/srt/speculative/eagle_utils.py  |   2 +-
 scripts/killall_sglang.sh                     |   1 +
 15 files changed, 655 insertions(+), 88 deletions(-)
 create mode 100644 python/sglang/srt/layers/parameter.py

diff --git a/3rdparty/amd/tuning/benchmark_moe_rocm.py b/3rdparty/amd/tuning/benchmark_moe_rocm.py
index a3f26e8e5..5aff8c0d6 100644
--- a/3rdparty/amd/tuning/benchmark_moe_rocm.py
+++ b/3rdparty/amd/tuning/benchmark_moe_rocm.py
@@ -10,7 +10,10 @@ import triton.language as tl
 from tqdm import tqdm
 from transformers import AutoConfig
 
-from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+    fused_moe,
+    get_config_file_name,
+)
 
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py
index 140755ff5..745598643 100644
--- a/python/sglang/srt/layers/attention/__init__.py
+++ b/python/sglang/srt/layers/attention/__init__.py
@@ -66,7 +66,14 @@ class AttentionBackend(ABC):
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
+            return self.forward_extend(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+            )
 
     def forward_decode(
         self,
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 8b823cc5a..fc3455b60 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -347,6 +347,8 @@ class FlashInferAttnBackend(AttentionBackend):
             else forward_batch.encoder_out_cache_loc
         )
 
+        logits_soft_cap = layer.logit_cap
+
         if not self.forward_metadata.use_ragged:
             if k is not None:
                 assert v is not None
@@ -359,7 +361,7 @@ class FlashInferAttnBackend(AttentionBackend):
                 causal=not layer.is_cross_attention,
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -368,7 +370,7 @@
                 v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim),
                 causal=True,
                 sm_scale=layer.scaling,
-                logits_soft_cap=layer.logit_cap,
+                logits_soft_cap=logits_soft_cap,
             )
 
             if self.forward_metadata.extend_no_prefix:
diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index b828c0391..9edfa7394 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -18,14 +18,15 @@ from vllm.distributed import (
 
 # workaround
 from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.parameter import (
+
+from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
-
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -94,6 +95,62 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
     return param[shard_id], loaded_weight
 
 
+def load_column_qkv_weight(
+    self, loaded_weight, num_heads, shard_id, shard_offset, shard_size, tp_rank
+):
+    if (
+        isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+        and self.output_dim == self.packed_dim
+    ):
+        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+            shard_offset=shard_offset, shard_size=shard_size
+        )
+
+    param_data = self.data
+    shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+    param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(
+        self.output_dim, shard_id * shard_size, shard_size
+    )
+
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
+
+
+def load_column_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, _ColumnvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.output_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
+def load_row_parallel_weight(
+    self, loaded_weight: torch.Tensor, tp_rank, use_presharded_weights: bool = False
+):
+    if isinstance(self, RowvLLMParameter):
+        if not use_presharded_weights:
+            shard_size = self.data.shape[self.input_dim]
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
+
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+    else:
+        self.data.copy_(loaded_weight)
+
+
 class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
@@ -287,6 +344,8 @@ class ColumnParallelLinear(LinearBase):
         quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -295,7 +354,11 @@ class ColumnParallelLinear(LinearBase):
         self.gather_output = gather_output
 
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert self.quant_method is not None
         self.output_size_per_partition = divide(self.output_size, tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
@@ -336,7 +399,6 @@ class ColumnParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         output_dim = getattr(param, "output_dim", None)
 
         # Special case for GGUF
@@ -356,7 +418,7 @@ class ColumnParallelLinear(LinearBase):
         # no need to narrow here
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
@@ -364,7 +426,9 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -373,7 +437,7 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight=loaded_weight)
+        load_column_parallel_weight(param, loaded_weight, self.tp_rank)
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -393,7 +457,7 @@ class ColumnParallelLinear(LinearBase):
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s
 
@@ -431,10 +495,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         self.output_sizes = output_sizes
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
+        self.use_presharded_weights = use_presharded_weights
         super().__init__(
             input_size=input_size,
             output_size=sum(output_sizes),
@@ -444,6 +516,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
 
     def weight_loader(
@@ -463,12 +537,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
@@ -494,7 +565,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param_data, loaded_weight, 0
                 )
 
-            assert param_data.shape == loaded_weight.shape
+            assert (
+                param_data.shape == loaded_weight.shape
+            ), f"{param_data.shape=}, {loaded_weight.shape=}"
             param_data.copy_(loaded_weight)
             return
         current_shard_offset = 0
@@ -522,11 +595,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             return
 
         assert loaded_shard_id < len(self.output_sizes)
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         if output_dim is not None:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
             # Special case for quantization.
             # If quantized, we need to adjust the offset and size to account
             # for the packing.
@@ -545,10 +616,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             # bitsandbytes loads the weights of the specific portion
             # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
@@ -572,7 +643,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 "the same for all partitions."
             )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def _load_fused_module_from_checkpoint(
@@ -629,26 +702,27 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
         assert loaded_shard_id < len(self.output_sizes)
 
-        tp_size = get_tensor_model_parallel_world_size()
-
         if isinstance(param, BlockQuantScaleParameter):
             weight_block_size = self.quant_method.quant_config.weight_block_size
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n
-            ) // tp_size
+            ) // self.tp_size
             shard_size = (
-                (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size
+                (self.output_sizes[loaded_shard_id] + block_n - 1)
+                // block_n
+                // self.tp_size
             )
         else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
 
         param.load_merged_column_weight(
             loaded_weight=loaded_weight,
             shard_id=loaded_shard_id,
             shard_offset=shard_offset,
             shard_size=shard_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
@@ -689,6 +763,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -697,7 +773,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -724,6 +804,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -814,13 +896,24 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_offset = (shard_offset + block_n - 1) // block_n
                 shard_size = (shard_size + block_n - 1) // block_n
 
-        param.load_qkv_weight(
-            loaded_weight=loaded_weight,
-            num_heads=self.num_kv_head_replicas,
-            shard_id=loaded_shard_id,
-            shard_offset=shard_offset,
-            shard_size=shard_size,
-        )
+        if isinstance(param, _ColumnvLLMParameter):
+            load_column_qkv_weight(
+                param,
+                loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+                tp_rank=self.tp_rank,
+            )
+        else:
+            param.load_qkv_weight(
+                loaded_weight=loaded_weight,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=shard_offset,
+                shard_size=shard_size,
+            )
 
     def weight_loader(
         self,
@@ -840,12 +933,9 @@ class QKVParallelLinear(ColumnParallelLinear):
             return
 
         if is_gguf_weight:
-            tp_size = get_tensor_model_parallel_world_size()
-            tp_rank = get_tensor_model_parallel_rank()
-
             output_dim = getattr(param, "output_dim", None)
-            shard_size = loaded_weight.size(output_dim) // tp_size
-            start_idx = tp_rank * shard_size
+            shard_size = loaded_weight.size(output_dim) // self.tp_size
+            start_idx = self.tp_rank * shard_size
 
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
@@ -872,7 +962,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param_data, loaded_weight, 0
                 )
 
-            assert param_data.shape == loaded_weight.shape
+            assert (
+                param_data.shape == loaded_weight.shape
+            ), f"{param_data.shape=}, {loaded_weight.shape=}"
             param_data.copy_(loaded_weight)
             return
         shard_offsets = [
@@ -934,7 +1026,6 @@ class QKVParallelLinear(ColumnParallelLinear):
                     self.weight_loader(param, loaded_weight_shard, shard_id)
             return
 
-        tp_rank = get_tensor_model_parallel_rank()
         assert loaded_shard_id in ["q", "k", "v"]
 
         # If output dim is defined, use the default loading process.
@@ -984,9 +1075,9 @@ class QKVParallelLinear(ColumnParallelLinear):
 
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
-                shard_id = tp_rank
+                shard_id = self.tp_rank
             else:
-                shard_id = tp_rank // self.num_kv_head_replicas
+                shard_id = self.tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
 
             # bitsandbytes loads the weights of the specific portion
@@ -1014,7 +1105,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                 "for all partitions."
             )
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
@@ -1055,6 +1148,9 @@ class RowParallelLinear(LinearBase):
         reduce_results: bool = True,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
@@ -1064,10 +1160,14 @@ class RowParallelLinear(LinearBase):
         self.reduce_results = reduce_results
 
         # Divide the weight matrix along the last dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        if tp_rank is None:
+            tp_rank = get_tensor_model_parallel_rank()
+        if tp_size is None:
+            tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank, self.tp_size = tp_rank, tp_size
         self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
+        self.use_presharded_weights = use_presharded_weights
 
         self.quant_method.create_weights(
             layer=self,
@@ -1101,8 +1201,6 @@ class RowParallelLinear(LinearBase):
             self.register_parameter("bias", None)
 
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
         use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
@@ -1116,15 +1214,19 @@ class RowParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             weight_shape = list(loaded_weight.shape)
             if input_dim:
-                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
+                weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
         # bitsandbytes loads the weights of the specific portion
         # no need to narrow here
-        if input_dim is not None and not use_bitsandbytes_4bit:
+        if (
+            input_dim is not None
+            and not use_bitsandbytes_4bit
+            and not self.use_presharded_weights
+        ):
             shard_size = param_data.shape[input_dim]
-            start_idx = tp_rank * shard_size
+            start_idx = self.tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
@@ -1132,7 +1234,9 @@ class RowParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
-        assert param_data.shape == loaded_weight.shape
+        assert (
+            param_data.shape == loaded_weight.shape
+        ), f"{param_data.shape=}, {loaded_weight.shape=}"
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor):
@@ -1143,17 +1247,21 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        param.load_row_parallel_weight(loaded_weight=loaded_weight)
+        load_row_parallel_weight(
+            param,
+            loaded_weight,
+            self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
+        )
 
     def forward(self, input_):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size
             )
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()
 
         # Matrix multiply.
         assert self.quant_method is not None
diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 96eaf8566..8d0b7035e 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -204,6 +204,7 @@ class FusedMoE(torch.nn.Module):
         prefix: str = "",
         custom_routing_function: Optional[Callable] = None,
         correction_bias: Optional[torch.Tensor] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
 
@@ -243,6 +244,7 @@ class FusedMoE(torch.nn.Module):
             params_dtype=params_dtype,
             weight_loader=self.weight_loader,
         )
+        self.use_presharded_weights = use_presharded_weights
 
     def _load_per_tensor_weight_scale(
         self,
@@ -395,10 +397,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
-        use_presharded_weights: bool = False,
     ) -> None:
-        self.use_presharded_weights = use_presharded_weights
-
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
         # against known CompressionFormat enum values that have this quality
diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py
new file mode 100644
index 000000000..435cc69bb
--- /dev/null
+++ b/python/sglang/srt/layers/parameter.py
@@ -0,0 +1,431 @@
+"""
+Adapted from vLLM (0.6.4.post1).
+https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
+"""
+
+import logging
+from fractions import Fraction
+from typing import Callable, Optional, Union
+
+import torch
+from torch.nn import Parameter
+from vllm.distributed import get_tensor_model_parallel_rank
+
+__all__ = [
+    "BasevLLMParameter",
+    "PackedvLLMParameter",
+    "PerTensorScaleParameter",
+    "ModelWeightParameter",
+    "ChannelQuantScaleParameter",
+    "GroupQuantScaleParameter",
+    "PackedColumnParameter",
+    "RowvLLMParameter",
+]
+
+logger = logging.getLogger(__name__)
+
+
+class BasevLLMParameter(Parameter):
+    """
+    Base parameter for vLLM linear layers. Extends the torch.nn.Parameter
+    by taking in a linear weight loader. Will copy the loaded weight
+    into the parameter when the provided weight loader is called.
+    """
+
+    def __new__(cls, data: torch.Tensor, **kwargs):
+
+        return super().__new__(cls, data=data, requires_grad=False)
+
+    def __init__(self, data: torch.Tensor, weight_loader: Callable):
+        """
+        Initialize the BasevLLMParameter
+
+        :param data: torch tensor with the parameter data
+        :param weight_loader: weight loader callable
+
+        :returns: a torch.nn.Parameter
+        """
+
+        self._weight_loader = weight_loader
+
+    @property
+    def weight_loader(self):
+        return self._weight_loader
+
+    def _assert_and_load(self, loaded_weight: torch.Tensor):
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+
+    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+        self._assert_and_load(loaded_weight)
+
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
+        self._assert_and_load(loaded_weight)
+
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        self._assert_and_load(loaded_weight)
+
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        self._assert_and_load(loaded_weight)
+
+
+class _ColumnvLLMParameter(BasevLLMParameter):
+    """
+    Private class defining weight loading functionality
+    (load_merged_column_weight, load_qkv_weight)
+    for parameters being loaded into linear layers with column
+    parallelism. This includes QKV and MLP layers which are
+    not already fused on disk. Requires an output dimension
+    to be defined. Called within the weight loader of
+    each of the column parallel linear layers.
+    """
+
+    def __init__(self, output_dim: int, **kwargs):
+        self._output_dim = output_dim
+        super().__init__(**kwargs)
+
+    @property
+    def output_dim(self):
+        return self._output_dim
+
+    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.data.shape[self.output_dim]
+        loaded_weight = loaded_weight.narrow(
+            self.output_dim, tp_rank * shard_size, shard_size
+        )
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+
+        shard_offset = kwargs.get("shard_offset")
+        shard_size = kwargs.get("shard_size")
+        use_presharded_weights = kwargs.get("use_presharded_weights")
+        if (
+            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+            and self.packed_dim == self.output_dim
+        ):
+            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+                shard_offset=shard_offset, shard_size=shard_size
+            )
+
+        param_data = self.data
+
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+
+        shard_offset = kwargs.get("shard_offset")
+        shard_size = kwargs.get("shard_size")
+        shard_id = kwargs.get("shard_id")
+        num_heads = kwargs.get("num_heads")
+
+        if (
+            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+            and self.output_dim == self.packed_dim
+        ):
+            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+                shard_offset=shard_offset, shard_size=shard_size
+            )
+
+        param_data = self.data
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+
+        loaded_weight = loaded_weight.narrow(
+            self.output_dim, shard_id * shard_size, shard_size
+        )
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+class RowvLLMParameter(BasevLLMParameter):
+    """
+    Parameter class defining weight_loading functionality
+    (load_row_parallel_weight) for parameters being loaded
+    into linear layers with row parallel functionality.
+    Requires an input_dim to be defined.
+    """
+
+    def __init__(self, input_dim: int, **kwargs):
+        self._input_dim = input_dim
+        super().__init__(**kwargs)
+
+    @property
+    def input_dim(self):
+        return self._input_dim
+
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        use_presharded_weights = kwargs.get("use_presharded_weights")
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.data.shape[self.input_dim]
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
+
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+
+
+class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for linear layer weights. Uses both column and
+    row parallelism.
+    """
+
+    pass
+
+
+class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    grouped quantization. Uses both column and row parallelism.
+    """
+
+    pass
+
+
+class ChannelQuantScaleParameter(_ColumnvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    channel-wise quantization. Equivalent to _ColumnvLLMParameter.
+    """
+
+    pass
+
+
+class PerTensorScaleParameter(BasevLLMParameter):
+    """
+    Parameter class for scales where the number of scales is
+    equivalent to the number of logical matrices in fused linear
+    layers (e.g. for QKV, there are 3 scales loaded from disk).
+    This is relevant to weights with per-tensor quantization.
+    Adds functionality to map the scales to a shard during
+    weight loading.
+
+    Note: additional parameter manipulation may be handled
+    for each quantization config specifically, within
+    process_weights_after_loading
+    """
+
+    def __init__(self, **kwargs):
+        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        super().__init__(**kwargs)
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        # if not int, assume shard_id for qkv
+        # map to int and return
+        assert isinstance(shard_id, str)
+        assert shard_id in self.qkv_idxs
+        return self.qkv_idxs[shard_id]
+
+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
+    def load_merged_column_weight(self, *args, **kwargs):
+        self._load_into_shard_id(*args, **kwargs)
+
+    def load_qkv_weight(self, *args, **kwargs):
+        self._load_into_shard_id(*args, **kwargs)
+
+    def load_column_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
+    def _load_into_shard_id(
+        self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs
+    ):
+        """
+        Slice the parameter data based on the shard id for
+        loading.
+ """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedColumnParameter(_ColumnvLLMParameter): + """ + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. + """ + + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs + ): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size, + ) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. 
+ """ + + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs + ): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size, + ) + + +def permute_param_layout_( + param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2, ( + "permute_param_layout_ only supports 2D parameters when either " + "input_dim or output_dim is not set" + ) + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None, "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None, "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert ( + hasattr(param, "packed_dim") + and param.packed_dim == perm[kwargs["packed_dim"]] + ), "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def _adjust_shard_indexes_for_packing( + shard_size, shard_offset, packed_factor, marlin_tile_size +): + shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + if marlin_tile_size is not None: + return _adjust_shard_indexes_for_marlin( + shard_size=shard_size, + shard_offset=shard_offset, + marlin_tile_size=marlin_tile_size, + ) + return shard_size, 
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index a263cb236..f9e4a8a4f 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -25,9 +25,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py
index effea1c6c..21d973918 100644
--- a/python/sglang/srt/layers/vocab_parallel_embedding.py
+++ b/python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -12,8 +12,8 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.parameter import BasevLLMParameter
 
+from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py
index e3e94ce6b..e9c0c909d 100644
--- a/python/sglang/srt/managers/session_controller.py
+++ b/python/sglang/srt/managers/session_controller.py
@@ -99,7 +99,7 @@ class Session:
 
         if last_req is not None:
             # trim bos token if it is an append
-            if req.input_ids[0] == tokenizer.bos_token_id:
+            if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id:
                 req.input_ids = req.input_ids[1:]
 
             input_ids = (
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index fab8b15a3..354408ab3 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -106,6 +106,9 @@ class ForwardMode(IntEnum):
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
 
+    def is_decode_or_idle(self):
+        return self == ForwardMode.DECODE or self == ForwardMode.IDLE
+
 
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 7cd9e759a..719db19cd 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -205,7 +205,7 @@ class ModelRunner:
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # TODO(liangan1): Just use gloo to bypass the initialization failure
             # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
@@ -634,7 +634,6 @@ class ModelRunner:
         )
 
     def init_double_sparsity_channel_config(self, selected_channel):
-        selected_channel = "." + selected_channel + "_proj"
         self.sorted_channels = []
 
         # load channel config
diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py
index 0485b80fc..33a055a8f 100644
--- a/python/sglang/srt/models/grok.py
+++ b/python/sglang/srt/models/grok.py
@@ -57,6 +57,7 @@ class Grok1MLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -65,6 +66,7 @@ class Grok1MLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            use_presharded_weights=use_presharded_weights,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -73,6 +75,7 @@ class Grok1MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
             reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
         )
         self.act_fn = GeluAndMul(approximate="tanh")
 
@@ -103,6 +106,7 @@ class Grok1MoE(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         reduce_results=True,
+        use_presharded_weights: bool = False,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -129,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -156,6 +161,7 @@ class Grok1Attention(nn.Module):
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.config = config
@@ -194,6 +200,7 @@ class Grok1Attention(nn.Module):
             hidden_size,
             bias=False,
             quant_config=quant_config,
+            reduce_results=reduce_results,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -234,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.layer_id = layer_id
 
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
@@ -262,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
             ),
             quant_config=quant_config,
             reduce_results=True,
+            use_presharded_weights=use_presharded_weights,
         )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -299,6 +309,7 @@ class Grok1Model(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_presharded_weights: bool = False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -311,7 +322,12 @@ class Grok1Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Grok1DecoderLayer(config, i, quant_config=quant_config)
+                Grok1DecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    use_presharded_weights=use_presharded_weights,
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -347,11 +363,7 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
-        # Monkey patch _prepare_weights to load pre-sharded weights
 
         if (
             self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
@@ -361,6 +373,14 @@ class Grok1ForCausalLM(nn.Module):
         else:
             self.use_presharded_weights = False
 
+        self.model = Grok1Model(
+            config,
+            quant_config=quant_config,
+            use_presharded_weights=self.use_presharded_weights,
+        )
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -376,10 +396,7 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        use_presharded_weights: Optional[bool] = None,
     ):
-        if use_presharded_weights is None:
-            use_presharded_weights = self.use_presharded_weights
         num_experts = self.config.num_local_experts
 
         stacked_params_mapping = [
@@ -435,20 +452,12 @@ class Grok1ForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
 
-                    if use_presharded_weights:
-                        extra_kwargs = {
-                            "use_presharded_weights": use_presharded_weights
-                        }
-                    else:
-                        extra_kwargs = {}
-
                     load_weight_wrapper(
                         name,
                         loaded_weight,
                         name,
                         shard_id=shard_id,
                         expert_id=expert_id,
-                        **extra_kwargs,
                     )
                     break
             else:
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index f60af5d73..8fd902818 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -544,7 +544,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            tokenizer_manager.image_token_id,
+        ),
     )
     t.start()
 
@@ -614,7 +619,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 88c88c072..b804e7c6a 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -14,7 +14,7 @@ from sglang.srt.speculative.build_eagle_tree import build_tree_kernel
 from sglang.srt.speculative.spec_info import SpecInfo
 
 if TYPE_CHECKING:
-    from python.sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.server_args import ServerArgs
 
diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh
index 4057d2be2..53d08703e 100755
--- a/scripts/killall_sglang.sh
+++ b/scripts/killall_sglang.sh
@@ -7,6 +7,7 @@ nvidia-smi
 kill -9 $(ps aux | grep 'sglang::' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
 kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
 kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
+kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2}') 2>/dev/null
 
 # Clean all GPU processes if any argument is provided
 if [ $# -gt 0 ]; then