diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 6cae1fd9a..1ebb44d55 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -31,8 +31,9 @@ from vllm.distributed import ( get_tensor_model_parallel_world_size, ) from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.utils import set_weight_attrs + +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py new file mode 100644 index 000000000..e61dccd6a --- /dev/null +++ b/python/sglang/srt/layers/linear.py @@ -0,0 +1,1133 @@ +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py + +import logging +from abc import abstractmethod +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter, UninitializedParameter +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) + +# workaround +from vllm.model_executor.layers.linear import LinearBase +from vllm.model_executor.parameter import ( + BasevLLMParameter, + PackedvLLMParameter, + PerTensorScaleParameter, +) + +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.utils import set_weight_attrs + +logger = logging.getLogger(__name__) + +WEIGHT_LOADER_V2_SUPPORTED = [ + "CompressedTensorsLinearMethod", + "AWQMarlinLinearMethod", + "AWQLinearMethod", + "GPTQMarlinLinearMethod", + "Fp8LinearMethod", + "MarlinLinearMethod", +] + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def adjust_bitsandbytes_shard( + param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str +) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = qkv_offsets["total"] + orig_offset, orig_size = qkv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + +def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. 
+ """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return F.linear(x, layer.weight, bias) + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix, + ) + + # All the linear layer supports quant method. 
+ assert self.quant_method is not None + self.quant_method.create_weights( + self, + self.input_size, + [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader, + prefix=prefix, + ) + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=self.params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None, + prefix: str = "", + ): + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix + ) + + self.gather_output = gather_output + + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) for output_size in self.output_sizes + ] + + if output_sizes is None: + output_sizes = [output_size] + + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + prefix=prefix, + ) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + param.load_column_parallel_weight(loaded_weight=loaded_weight) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. 
+ gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.data[loaded_shard_id].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + return + + if is_gguf_weight and isinstance(param, UninitializedParameter): + from gguf.constants import GGML_QUANT_SIZES + + ori_shape = param.tensor_shape + weight_types = self.qweight_type.shard_weight_type.values() + row_size = [] + for weight_type in weight_types: + block_size, type_size = GGML_QUANT_SIZES[weight_type] + row_size.append(ori_shape[1] // block_size * type_size) + q_shape = (ori_shape[0], max(row_size)) + param.materialize(q_shape, dtype=loaded_weight.dtype) + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. 
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size + param.shard_id.append(loaded_shard_id) + param.shard_size[loaded_shard_id] = shard_shape + + input_dim = getattr(param, "input_dim", None) + input_size = loaded_weight.shape[input_dim] + param_data = param_data.narrow(input_dim, 0, input_size) + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ if ( + isinstance(param, PackedvLLMParameter) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): + if loaded_shard_id is None: + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) is BasevLLMParameter: + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = get_tensor_model_parallel_world_size() + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight( + loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + ) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
+ tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = ( + (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size + ) + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__( + input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix, + ) + + def _get_shard_offset_mapping(self, loaded_shard_id: str): + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str): + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ if ( + isinstance(param, PackedvLLMParameter) + and param.packed_dim == param.output_dim + ): + shard_size, shard_offset = param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2( + self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): + if loaded_shard_id is None: # special case for certain models + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) is BasevLLMParameter: + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + param.load_qkv_weight( + loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None, + ): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type and loaded_shard_id is not None: + idx_map = {"q": 0, "k": 1, "v": 2} + param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + return + + if is_gguf_weight and isinstance(param, UninitializedParameter): + from gguf.constants import GGML_QUANT_SIZES + + ori_shape = param.tensor_shape + weight_types = self.qweight_type.shard_weight_type.values() + row_size = [] + for weight_type in weight_types: + block_size, type_size = GGML_QUANT_SIZES[weight_type] + row_size.append(ori_shape[1] // block_size * type_size) + q_shape = (ori_shape[0], max(row_size)) + param.materialize(q_shape, dtype=loaded_weight.dtype) + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ( + "k", + self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ( + "v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size, + ), + ] + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": ( + self.num_heads * self.head_size, + self.num_kv_heads * self.head_size, + ), + "v": ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size, + ), + "total": ( + (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0, + ), + } + shard_size, shard_offset = adjust_bitsandbytes_shard( + param, orig_qkv_offsets, loaded_shard_id + ) + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", None) + shard_shape = list(loaded_weight.shape) + shard_shape[output_dim] = shard_shape[output_dim] // tp_size + param.shard_id.append(loaded_shard_id) + param.shard_size[loaded_shard_id] = shard_shape + + input_dim = getattr(param, "input_dim", None) + input_size = loaded_weight.shape[input_dim] + param_data = param_data.narrow(input_dim, 0, input_size) + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions." 
+ ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix + ) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + # Divide the weight matrix along the last dimension. + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 + if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED + else self.weight_loader + ), + prefix=prefix, + ) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError( + "When not reduce the results, adding bias to the " + "results can lead to incorrect results" + ) + + if bias: + self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + input_dim = getattr(param, "input_dim", None) + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) + + param_data = param.data + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # 
have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor): + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + + param.load_row_parallel_weight(loaded_weight=loaded_weight) + + def forward(self, input_): + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size + ) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py new file mode 100644 index 000000000..5f9619de1 --- /dev/null +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -0,0 +1,76 @@ +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py + +from typing import Dict, Type + +from vllm.model_executor.layers.quantization.aqlm import AQLMConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig +from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig, +) +from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig +from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config +from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.gguf import GGUFConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig +from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config +from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.qqq import QQQConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig + +from sglang.srt.layers.quantization.base_config import 
QuantizationConfig
+
+QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
+    "fp8": Fp8Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
+    # The order of gptq methods is important for config.py iteration over
+    # override_quantization_method(..)
+    "marlin": MarlinConfig,
+    "gguf": GGUFConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "squeezellm": SqueezeLLMConfig,
+    "compressed-tensors": CompressedTensorsConfig,
+    "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
+}
+
+
+def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
+    if quantization not in QUANTIZATION_METHODS:
+        raise ValueError(f"Invalid quantization method: {quantization}")
+    return QUANTIZATION_METHODS[quantization]
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quantization_config",
+    "QUANTIZATION_METHODS",
+]
+
+"""
+def fp8_get_quant_method(
+    self, layer: torch.nn.Module, prefix: str
+) -> Optional["QuantizeMethodBase"]:
+    if isinstance(layer, LinearBase):
+        if is_layer_skipped(prefix, self.ignored_layers):
+            return UnquantizedLinearMethod()
+        return Fp8LinearMethod(self)
+    elif isinstance(layer, FusedMoE):
+        return Fp8MoEMethod(self)
+    return None
+
+
+setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
+"""
diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py
new file mode 100644
index 000000000..aa83b4eea
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/base_config.py
@@ -0,0 +1,122 @@
+# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch import nn
+
+
+class QuantizeMethodBase(ABC):
+    """Base class for different quantized methods."""
+
+    @abstractmethod
+    def create_weights(
+        self, layer: torch.nn.Module, *weight_args, **extra_weight_attrs
+    ):
+        """Create weights for a layer.
+
+        The weights will be set as attributes of the layer."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
+        raise NotImplementedError
+
+    def process_weights_after_loading(self, layer: nn.Module) -> None:
+        """Process the weight after loading.
+
+        This can be used for example, to transpose weights for computation.
+        """
+        return
+
+
+class QuantizationConfig(ABC):
+    """Base class for quantization configs."""
+
+    @abstractmethod
+    def get_name(self) -> str:
+        """Name of the quantization method."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        """List of supported activation dtypes."""
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def get_min_capability(cls) -> int:
+        """Minimum GPU capability to support the quantization method.
+
+        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+        This requirement is due to the custom CUDA kernels used by the
+        quantization method.
+ """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + + @staticmethod + def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError( + f"Cannot find any of {keys} in the model's " "quantization config." + ) + + @staticmethod + def get_from_keys_or(config: Dict[str, Any], keys: List[str], default: Any) -> Any: + """Get a optional value from the model's quantization config.""" + try: + return QuantizationConfig.get_from_keys(config, keys) + except ValueError: + return default + + @abstractmethod + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + prefix: The full name of the layer in the state dict + Returns: + The quantize method. None if the given layer doesn't support quant + method. + """ + raise NotImplementedError + + @abstractmethod + def get_scaled_act_names(self) -> List[str]: + """Returns the activation function names that should be post-scaled. + + For now, this is only used by AWQ. 
+ """ + raise NotImplementedError diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 19d2a3384..f42627ad6 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import ( QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -45,6 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 783dbb4c5..7e8dd2adc 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -24,12 +24,6 @@ from torch import nn from torch.nn import LayerNorm from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -40,7 +34,13 @@ from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index f6d6f6e1f..9c93d1e41 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -50,21 +50,21 @@ from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.linear import ( +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs - -from sglang.srt.layers.activation import SiluAndMul 
from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.utils import set_weight_attrs @torch.compile diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 39ac4aefa..9fd0f335d 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -27,12 +27,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.linear import ( - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, @@ -40,12 +34,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs from vllm.transformers_utils.configs.dbrx import DbrxConfig +from sglang.srt.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata +from sglang.srt.utils import set_weight_attrs class DbrxRouter(nn.Module): diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 59fd1ec7e..590f8f0bf 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -28,13 +28,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -44,7 +37,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 22b434a85..2d46f3e8a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -27,13 +27,6 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.linear import ( - 
ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -43,7 +36,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 63d40be7a..4af7a9f47 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -23,12 +23,6 @@ import torch from torch import nn from vllm.config import CacheConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -38,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index ae3b1b194..50cbca1a1 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -23,19 +23,19 @@ from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention 
import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 3223424d7..b6a282bd8 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -22,12 +22,6 @@ from torch import nn from transformers import PretrainedConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -35,7 +29,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import GemmaRMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 94b7f6153..f063c732f 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -23,17 +23,17 @@ from torch import nn from transformers import GPTBigCodeConfig from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index daf6f25da..2eda52983 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -28,12 +28,6 @@ from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.linear import ( - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -44,7 +38,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.fused_moe import FusedMoE from sglang.srt.layers.layernorm import RMSNorm +from 
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py
index f2947e991..9eb73f1fc 100644
--- a/python/sglang/srt/models/internlm2.py
+++ b/python/sglang/srt/models/internlm2.py
@@ -23,12 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -38,7 +32,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index a01619b3a..447b548aa 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -24,12 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -39,7 +33,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py
index c17651064..536bec2f1 100644
--- a/python/sglang/srt/models/llama_classification.py
+++ b/python/sglang/srt/models/llama_classification.py
@@ -19,10 +19,10 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index 9c1e529f0..df62b39fc 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -32,9 +32,9 @@ from transformers import (
 )
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py
index 45f47cffc..408a90f19 100644
--- a/python/sglang/srt/models/llavavid.py
+++ b/python/sglang/srt/models/llavavid.py
@@ -23,9 +23,9 @@ from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.config import CacheConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py
index 49ff1926f..f796e881f 100644
--- a/python/sglang/srt/models/minicpm.py
+++ b/python/sglang/srt/models/minicpm.py
@@ -22,12 +22,6 @@ import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -37,7 +31,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py
index a025d074e..3bc094e5a 100644
--- a/python/sglang/srt/models/minicpm3.py
+++ b/python/sglang/srt/models/minicpm3.py
@@ -29,7 +29,6 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -40,6 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py
index d7c232ec1..73a7b8686 100644
--- a/python/sglang/srt/models/mixtral.py
+++ b/python/sglang/srt/models/mixtral.py
@@ -24,12 +24,6 @@ from transformers import MixtralConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.linear import (
-    QKVParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
@@ -39,7 +33,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py
index b02e925c5..2e7483771 100644
--- a/python/sglang/srt/models/mixtral_quant.py
+++ b/python/sglang/srt/models/mixtral_quant.py
@@ -29,12 +29,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.linear import (
-    QKVParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -43,7 +37,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py
index 6ada0e9f9..28eda6f4f 100644
--- a/python/sglang/srt/models/olmoe.py
+++ b/python/sglang/srt/models/olmoe.py
@@ -35,7 +35,6 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -47,6 +46,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py
index 8787d0d58..3a2d31001 100644
--- a/python/sglang/srt/models/qwen.py
+++ b/python/sglang/srt/models/qwen.py
@@ -22,12 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -37,7 +31,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 4e7d8e57b..5aac3f79c 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -22,12 +22,6 @@ import torch
 from torch import nn
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -37,8 +31,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index d47589ddc..27fcf6321 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -29,13 +29,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    ReplicatedLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -45,7 +38,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py
index 9e10f12f2..2e344ca50 100644
--- a/python/sglang/srt/models/stablelm.py
+++ b/python/sglang/srt/models/stablelm.py
@@ -24,12 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.linear import (
-    MergedColumnParallelLinear,
-    QKVParallelLinear,
-    RowParallelLinear,
-)
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -38,7 +32,13 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py
index 388bd66f7..88d3ea2c7 100644
--- a/python/sglang/srt/models/xverse.py
+++ b/python/sglang/srt/models/xverse.py
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -40,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.model_runner import InputMetadata
diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py
index 2f53682d1..666bf1d3d 100644
--- a/python/sglang/srt/models/xverse_moe.py
+++ b/python/sglang/srt/models/xverse_moe.py
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -43,6 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py
index 0f86206d8..52f930edc 100644
--- a/python/sglang/srt/models/yivl.py
+++ b/python/sglang/srt/models/yivl.py
@@ -21,9 +21,9 @@ import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
 from vllm.config import CacheConfig
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.models.llava import LlavaLlamaForCausalLM
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 92e670479..c2d2b5f11 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -26,7 +26,7 @@ import struct
 import time
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import psutil
@@ -682,3 +682,23 @@ def replace_submodule(
     target_name = module_name.split(".")[-1]
     setattr(parent, target_name, new_module)
     return new_module
+
+
+def set_weight_attrs(
+    weight: torch.Tensor,
+    weight_attrs: Optional[Dict[str, Any]],
+):
+    """Set attributes on a weight tensor.
+
+    This method is used to set attributes on a weight tensor. This method
+    will not overwrite existing attributes.
+
+    Args:
+        weight: The weight tensor.
+        weight_attrs: A dictionary of attributes to set on the weight tensor.
+    """
+    if weight_attrs is None:
+        return
+    for key, value in weight_attrs.items():
+        assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
+        setattr(weight, key, value)
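A minimal usage sketch of the relocated set_weight_attrs helper follows. It is illustrative only and not part of the patch; the tensor shape and attribute keys are assumptions that mirror how the linear methods above tag their weights.

import torch
from torch.nn import Parameter

from sglang.srt.utils import set_weight_attrs

# Hypothetical 1024x512 weight, roughly what a column-parallel linear layer might allocate.
weight = Parameter(torch.empty(1024, 512, dtype=torch.float16), requires_grad=False)

# Record the sharding dims once, as UnquantizedLinearMethod.create_weights does.
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
assert weight.input_dim == 1 and weight.output_dim == 0

# The helper refuses to overwrite an attribute that is already set.
try:
    set_weight_attrs(weight, {"output_dim": 0})
except AssertionError:
    print("existing tensor attribute is not overwritten")

Because the helper asserts that a key is not already present, each attribute (input_dim, output_dim, and any extra_weight_attrs passed through create_weights) can be attached to a given parameter only once.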