diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py new file mode 100644 index 000000000..72677bbd5 --- /dev/null +++ b/python/sglang/srt/layers/linear.py @@ -0,0 +1,869 @@ +# temporarily adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/linear.py +# FIXME: refactor the linear abstraction +from abc import abstractmethod +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from vllm.distributed import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def adjust_bitsandbytes_shard( + param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str +) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + total, _ = qkv_offsets["total"] + orig_offset, orig_size = qkv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + +def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. 
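+            extra_weight_attrs: Extra attributes to attach to each created
+                parameter, e.g. the caller's weight_loader callback.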
+ """ + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization. + + Args: + separate_bias_add: If true, add bias separately after matrix + multiplication. + """ + + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + weight = layer.weight + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) + + +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config + ) + + # All the linear layer supports quant method. 
+ assert self.quant_method is not None + self.quant_method.create_weights( + self, + self.input_size, + [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + ) + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=self.params_dtype) + ) + set_weight_attrs(self.bias, {"output_dim": 0}) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None, + ): + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config + ) + + self.gather_output = gather_output + + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. 
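+        # E.g., QKVParallelLinear with output_sizes = [4096, 1024, 1024] and
+        # tp_size = 4 yields output_partition_sizes = [1024, 256, 256], the
+        # per-rank widths of Wq, Wk, and Wv. (Subclasses assign
+        # self.output_sizes before calling super().__init__().)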
+ if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) for output_size in self.output_sizes + ] + + if output_sizes is None: + output_sizes = [output_size] + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=self.weight_loader, + ) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, dtype=params_dtype) + ) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + if param.data.dtype != loaded_weight.dtype: + param.data = torch.empty_like( + param.data, dtype=loaded_weight.dtype, device="cuda" + ) + + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. 
+ """ + + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__( + input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + ) + + def weight_loader( + self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None, + ): + if param.data.dtype != loaded_weight.dtype: + param.data = torch.empty_like( + param.data, dtype=loaded_weight.dtype, device="cuda" + ) + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). + if output_dim is None: + if needs_scalar_to_array is not None: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0 + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. 
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
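+        # E.g., with total_num_heads=64, total_num_kv_heads=8 (grouped-query
+        # attention) and tp_size=8, each rank gets num_heads=8 and
+        # num_kv_heads=1; with tp_size=16, each KV head is instead
+        # replicated across num_kv_head_replicas=2 ranks.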
+        tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads = divide(self.total_num_heads, tp_size)
+        if tp_size >= self.total_num_kv_heads:
+            self.num_kv_heads = 1
+            self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
+        else:
+            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
+            self.num_kv_head_replicas = 1
+        input_size = self.hidden_size
+        output_size = (
+            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
+        )
+        self.output_sizes = [
+            self.num_heads * self.head_size * tp_size,  # q_proj
+            self.num_kv_heads * self.head_size * tp_size,  # k_proj
+            self.num_kv_heads * self.head_size * tp_size,  # v_proj
+        ]
+
+        super().__init__(
+            input_size=input_size,
+            output_size=output_size,
+            bias=bias,
+            gather_output=False,
+            skip_bias_add=skip_bias_add,
+            params_dtype=params_dtype,
+            quant_config=quant_config,
+        )
+
+    def weight_loader(
+        self,
+        param: Parameter,
+        loaded_weight: torch.Tensor,
+        loaded_shard_id: Optional[str] = None,
+    ):
+        if param.data.dtype != loaded_weight.dtype:
+            param.data = torch.empty_like(
+                param.data, dtype=loaded_weight.dtype, device="cuda"
+            )
+
+        param_data = param.data
+        output_dim = getattr(param, "output_dim", None)
+        # Special case for AQLM codebooks.
+        is_metadata = getattr(param, "is_metadata", False)
+
+        # Special case for per-tensor scales in fused case.
+        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
+
+        if loaded_shard_id is None:
+            # Loaded weight is already fused on disk (qkv/mlp).
+            if output_dim is None:
+                if needs_scalar_to_array:
+                    param_data, loaded_weight = adjust_scalar_to_fused_array(
+                        param_data, loaded_weight, 0
+                    )
+
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            shard_offsets = [
+                # (shard_id, shard_offset, shard_size)
+                ("q", 0, self.total_num_heads * self.head_size),
+                (
+                    "k",
+                    self.total_num_heads * self.head_size,
+                    self.total_num_kv_heads * self.head_size,
+                ),
+                (
+                    "v",
+                    (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
+                    self.total_num_kv_heads * self.head_size,
+                ),
+            ]
+            packed_dim = getattr(param, "packed_dim", None)
+            for shard_id, shard_offset, shard_size in shard_offsets:
+                # Special case for quantized weights.
+                # If quantized, we need to adjust the offset and size to account
+                # for the packing.
+                if packed_dim == output_dim:
+                    shard_size = shard_size // param.pack_factor
+                    shard_offset = shard_offset // param.pack_factor
+
+                    # Special case for Marlin.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset
+                    )
+
+                loaded_weight_shard = loaded_weight.narrow(
+                    output_dim, shard_offset, shard_size
+                )
+                self.weight_loader(param, loaded_weight_shard, shard_id)
+            return
+
+        tp_rank = get_tensor_model_parallel_rank()
+        assert loaded_shard_id in ["q", "k", "v"]
+
+        # If output dim is defined, use the default loading process.
+        if output_dim is not None:
+            if loaded_shard_id == "q":
+                shard_offset = 0
+                shard_size = self.num_heads * self.head_size
+            elif loaded_shard_id == "k":
+                shard_offset = self.num_heads * self.head_size
+                shard_size = self.num_kv_heads * self.head_size
+            elif loaded_shard_id == "v":
+                shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
+                shard_size = self.num_kv_heads * self.head_size
+            # Special case for quantized weights.
+            # If quantized, we need to adjust the offset and size to account
+            # for the packing.
+ packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset + ) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": ( + self.num_heads * self.head_size, + self.num_kv_heads * self.head_size, + ), + "v": ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size, + ), + "total": ( + (self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0, + ), + } + shard_size, shard_offset = adjust_bitsandbytes_shard( + param, orig_qkv_offsets, loaded_shard_id + ) + + param_data = param_data.narrow(output_dim, shard_offset, shard_size) + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id + ) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions." + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__( + input_size, output_size, skip_bias_add, params_dtype, quant_config + ) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + # Divide the weight matrix along the last dimension. 
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        assert self.quant_method is not None
+        self.quant_method.create_weights(
+            layer=self,
+            input_size_per_partition=self.input_size_per_partition,
+            output_partition_sizes=[self.output_size],
+            input_size=self.input_size,
+            output_size=self.output_size,
+            params_dtype=self.params_dtype,
+            weight_loader=self.weight_loader,
+        )
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError(
+                "When reduce_results=False, adding a bias to the output "
+                "can lead to incorrect results"
+            )
+
+        if bias:
+            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": 0,
+                    "weight_loader": self.weight_loader,
+                },
+            )
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        if param.data.dtype != loaded_weight.dtype:
+            param.data = torch.empty_like(
+                param.data, dtype=loaded_weight.dtype, device="cuda"
+            )
+
+        param_data = param.data
+        tp_rank = get_tensor_model_parallel_rank()
+        input_dim = getattr(param, "input_dim", None)
+        if input_dim is not None:
+            shard_size = param.data.shape[input_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            split_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size
+            )
+            input_parallel = split_input[tp_rank].contiguous()
+
+        # Matrix multiply.
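+        # NOTE: no bias is passed to apply() here; the bias (if any) is added
+        # once, after the all-reduce below. Adding it to each rank's partial
+        # product would count it tp_size times, which is why __init__ rejects
+        # reduce_results=False with a non-skipped bias.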
+ assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_parallel) + if self.reduce_results and self.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py new file mode 100644 index 000000000..cbfdcc476 --- /dev/null +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -0,0 +1,49 @@ +# temporarily adapted from vLLM +# FIXME: in progress of refactoring the model loader + +from typing import Dict, Type + +from vllm.model_executor.layers.quantization.aqlm import AQLMConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig, +) +from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig +from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config +from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig + +from sglang.srt.layers.quantization.fp8 import Fp8Config + +QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "fp8": Fp8Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) 
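+    # (GPTQ checkpoints can often be served by the faster Marlin-based
+    # kernels, which is why the Marlin variants are listed before plain
+    # "gptq" here.)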
+ "marlin": MarlinConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, + "gptq": GPTQConfig, + "squeezellm": SqueezeLLMConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, +} + + +def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + return QUANTIZATION_METHODS[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", + "QUANTIZATION_METHODS", +] diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py new file mode 100644 index 000000000..79ae531a5 --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -0,0 +1,662 @@ +# adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py +# FIXME refactor in progress +from typing import Any, Dict, List, Optional, Union + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase, fused_moe +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + GPTQMarlinState, + marlin_permute_scales, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils import pack_fp8_to_int32 +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import print_warning_once + +from sglang.srt.layers.linear import LinearBase, LinearMethodBase + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +def cutlass_fp8_supported() -> bool: + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + + return ops.cutlass_scaled_mm_supports_fp8(capability) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning( + "Detected fp8 checkpoint. Please note that the " + "format is experimental and subject to change." 
+ ) + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = "fp8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls( + is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ) + + def get_quant_method( + self, layer: torch.nn.Module + ) -> Optional["QuantizeMethodBase"]: + + if isinstance(layer, LinearBase): + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + self.use_marlin = capability < 89 + + def _create_scale_param( + self, + scale_name: str, + layer: torch.nn.Module, + output_partition_sizes: List[int], + **extra_weight_attrs, + ) -> None: + scale = Parameter( + torch.empty(len(output_partition_sizes), dtype=torch.float32), + requires_grad=False, + ) + scale[:] = torch.finfo(torch.float8_e4m3fn).min + layer.register_parameter(scale_name, scale) + set_weight_attrs( + scale, + { + **extra_weight_attrs, + "needs_scalar_to_array": True, + }, + ) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.process_after_load = True + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + # weight_dtype = (torch.float8_e4m3fn + # if self.quant_config.is_checkpoint_fp8_serialized else + # params_dtype) + weight_dtype = torch.float8_e4m3fn + weight = Parameter( + torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + requires_grad=False, + ) + layer.register_parameter("weight", weight) + set_weight_attrs( + weight, + { + **extra_weight_attrs, + "input_dim": 1, + "output_dim": 0, + }, + ) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + self._create_scale_param( + scale_name="weight_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs, + ) + + # INPUT ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + self._create_scale_param( + scale_name="input_scale", + layer=layer, + output_partition_sizes=output_partition_sizes, + **extra_weight_attrs, + ) + + # For GPUs without FP8 hardware support, we use Marlin for fast + # fused dequantization + if self.use_marlin: + layer.marlin_state = GPTQMarlinState.REPACK + + def prepare_layer_for_marlin(self, layer: Module) -> None: + print_warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." 
+ ) + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + assert layer.marlin_state == GPTQMarlinState.REPACK + layer.marlin_state = GPTQMarlinState.READY + + device = layer.weight.device + + # WEIGHTS + # Repack weights to gptq format (packed int32 elements) + packed_gptq_qweight = pack_fp8_to_int32(layer.weight) + + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_gptq_qweight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Currently Marlin doesn't support per-tensor scales, so we + # expand it to channelwise + scales = ( + layer.weight_scale.repeat(1, part_size_n).to(layer.orig_dtype).to(device) + ) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1, + num_bits=8, + ) + layer.weight_scale = Parameter(marlin_scales, requires_grad=False) + + # Allocate marlin workspace + max_workspace_size = ( + part_size_n // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + workspace = torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + layer.workspace = workspace + + def process_weights_after_loading(self, layer: Module) -> None: + if not hasattr(layer, "process_after_load") or not layer.process_after_load: + return + + # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.logical_widths = None + layer.input_scale = None + if self.use_marlin: + self.prepare_layer_for_marlin(layer) + return + + # If checkpoint is fp8, requantize the separately quantized logical + # weights into a single fp8 weight with a single weight scale. + else: + # WEIGHT_SCALE / WEIGHT + # Loop over logical weights, requantizing with single scale. + max_w_scale = layer.weight_scale.max() + + # QKV / MLP is fused in the on disk checkpoint if any of the + # weight scales are still set to the default since we initialize + # N weight scales for N shards but we only load 1 weight scale + # from disk in this case. As a result, we skip dequant -> requant + # since we already have quantized QKV together. + # Sample Model with fused checkpoint: + # * nm-testing/Phi-3-mini-128k-instruct-FP8 + unfused_module_in_checkpoint = ( + layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min + ) + + if unfused_module_in_checkpoint: + start = 0 + for idx, logical_width in enumerate(layer.logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize( + layer.weight[start:end, :], layer.weight_scale[idx] + ) + + layer.weight[start:end, :] = per_tensor_quantize( + weight_dq, layer.weight_scale.max() + ) + start = end + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # WEIGHT + # Transpose weight for passing to torch._scaled_mm + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # INPUT ACTIVATION SCALE + # Dynamic: set to None (required input to ops.scaled_fp8_quant). + # Static: set to max of the input_scales (since they are equal). 
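+            # (All logical shards of a fused module consume the same input
+            # tensor, so their input_scales should agree; max() is a safe
+            # reduction if a checkpoint stores slightly different values.)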
+ if self.quant_config.activation_scheme == "dynamic": + layer.input_scale = None + elif self.quant_config.activation_scheme == "static": + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) + else: + raise ValueError( + f"Unknown scheme {self.quant_config.activation_scheme}" + ) + + if self.use_marlin: + self.prepare_layer_for_marlin(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if self.use_marlin: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (layer.output_size_per_partition,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=layer.weight, + b_scales=layer.weight_scale, + workspace=layer.workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + else: + + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x + # If static, layer.input_scale is scalar and x_scale is input_scale + + if bias is None and self.cutlass_fp8_supported: + qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scale, + ) + + else: + qinput, x_scale = ops.scaled_fp8_quant( + x, layer.input_scale, batch_dim_padding=17 + ) + + # Fused GEMM_DQ -- note we padded the input above because + # torch._scaled_mm is more performant for matrices with + # batch dimension > 16. Note that this could change + # in the future. + output, _ = torch._scaled_mm( + qinput, + layer.weight, + out_dtype=x.dtype, + scale_a=x_scale, + scale_b=layer.weight_scale, + bias=bias, + ) + + return torch.narrow(output, 0, 0, x.shape[0]) + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + layer.process_after_load = True + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
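+        # (w1/gate_proj and w3/up_proj are stored fused as w13, but fp8
+        # checkpoints may carry one scale per logical weight, hence the
+        # [num_experts, 2] shape here and the max-reduction in
+        # process_weights_after_loading().)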
+ w13_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + a13_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + if not hasattr(layer, "process_after_load") or not layer.process_after_load: + return + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like( + layer.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like( + layer.w2_weight.data, dtype=torch.float8_e4m3fn + ) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :] + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.a13_scale) or not all_close_1d( + layer.a2_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.a13_scale = torch.nn.Parameter( + layer.a13_scale.max(), requires_grad=False + ) + layer.a2_scale = torch.nn.Parameter( + layer.a2_scale.max(), requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. 
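+            # Sketch of the requant loop below for one expert: rows
+            # [0 : shard_size) hold w1 (quantized with scale s1) and rows
+            # [shard_size : 2 * shard_size) hold w3 (scale s3); each half is
+            # dequantized with its own scale and requantized with
+            # max(s1, s3), so one scale covers the fused tensor.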
+ assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :] = ( + per_tensor_quantize(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + ) -> torch.Tensor: + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + ) + + +# FIXME: not used +class Fp8KVCacheMethod(QuantizeMethodBase): + """Supports loading kv-cache scaling factors from FP8 checkpoints.""" + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """Create "weight" (aka kv_scale) for an attention layer. + + Args: + layer: The layer that is using the QuantizeMethodBase factory. + """ + # Initialize the KV cache scale to 1.0 as the default value. + # If the kv_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") + + def process_weights_after_loading(self, layer: Module) -> None: + # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + if layer.kv_cache_dtype != "auto": + kv_scale = layer.kv_scale.to("cpu").tolist() + if not isinstance(kv_scale, float): + raise ValueError( + "Only support per-tensor scaling factor " "for fp8 KV cache" + ) + layer._kv_scale = kv_scale + if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " + "cause accuracy issues. Please make sure kv-cache scaling " + "factor is available in the fp8 checkpoint." 
+ ) + del layer.kv_scale + + +def per_tensor_quantize( + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: + finfo = torch.finfo(torch.float8_e4m3fn) + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) + return qweight.to(torch.float8_e4m3fn) + + +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] +) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/python/sglang/srt/model_loader/model_loader.py b/python/sglang/srt/model_loader/model_loader.py new file mode 100644 index 000000000..ded68ae97 --- /dev/null +++ b/python/sglang/srt/model_loader/model_loader.py @@ -0,0 +1,276 @@ +# temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py +# FIXME: in progress of refactoring the model loader + +import glob +import os +import re +from typing import Any, Dict, Generator, List, Optional, Tuple, Type + +import torch +from torch import nn +from tqdm import tqdm +from vllm.config import ( + CacheConfig, + DeviceConfig, + LoadConfig, + LoadFormat, + LoRAConfig, + ModelConfig, + MultiModalConfig, + ParallelConfig, + SchedulerConfig, +) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.model_loader.utils import ( + get_model_architecture, + set_default_torch_dtype, +) +from vllm.platforms import current_platform + +from sglang.srt.model_loader.utils import ( + download_safetensors_index_file_from_hf, + download_weights_from_hf, + filter_duplicate_safetensors_files, + get_quant_config, + safetensors_weights_iterator, +) + + +def _get_quantization_config( + model_config: ModelConfig, load_config: LoadConfig +) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability = current_platform.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}." + ) + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. 
Supported dtypes: " + f"{supported_dtypes}" + ) + return quant_config + return None + + +def _get_model_initialization_kwargs( + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], +) -> Dict[str, Any]: + """Get extra kwargs for model initialization.""" + extra_kwargs: Dict[str, Any] = {} + + assert lora_config is None + assert multimodal_config is None + + return extra_kwargs + + +def _initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig, +) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config(model_config, load_config) + + return model_class( + config=model_config.hf_config, + cache_config=cache_config, + quant_config=quant_config, + **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config), + ) + + +class ModelLoader: + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + def _prepare_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool + ) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. 
+ if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, self.load_config.download_dir, revision + ) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder + ) + else: + hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`" + ) + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt + ) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + model_name_or_path, + self.load_config.download_dir, + hf_folder, + hf_weights_files, + ) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + return weights_iterator + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, + self.load_config, + lora_config, + multimodal_config, + cache_config, + ) + weights = self._get_weights_iterator( + model_config.model, + model_config.revision, + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), + ) + + modules = {} + for name, module in model.named_modules(): + modules[name] = module + + def apply_quant_method(module): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # print("before apply quant", module.weight, module.weight.dtype) + quant_method.process_weights_after_loading(module) + # print("after apply quant", module.weight, module.weight.dtype) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. 
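+                # (Some older model implementations define their own
+                # module-level process_weights_after_loading() instead of
+                # going through a quant_method; both hooks are honored here.)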
+ if hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading() + + # Show a progress bar on the first device only. The number of on-disk + # tensors usually exceeds get_num_params() (fused modules load several + # shards), so 1.5x serves as a rough total for the bar. + if torch.cuda.current_device() == 0: + weights = tqdm( + weights, total=model.get_num_params() * 1.5, desc="load model" + ) + + # Track how many shards each module expects and how many have arrived; + # once a module is complete, run its post-loading quantization hook. + num_shard = {} + num_loaded = {} + for name, loaded_weight in weights: + model.load_weights(None, name, loaded_weight) + module_name, shard_num = model.get_module_name(name) + num_shard[module_name] = shard_num + if module_name not in num_loaded: + num_loaded[module_name] = 1 + else: + num_loaded[module_name] += 1 + if num_loaded[module_name] == num_shard[module_name]: + apply_quant_method(modules[module_name]) + + return model.eval() + + +def get_model( + *, + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + cache_config: CacheConfig, +) -> nn.Module: + loader = ModelLoader(load_config) + return loader.load_model( + model_config=model_config, + device_config=device_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config, + ) diff --git a/python/sglang/srt/model_loader/utils.py b/python/sglang/srt/model_loader/utils.py new file mode 100644 index 000000000..fdbf615f2 --- /dev/null +++ b/python/sglang/srt/model_loader/utils.py @@ -0,0 +1,260 @@ +# temporarily adapted from vLLM +# FIXME: in progress of refactoring the model loader +"""Utilities for selecting and loading models.""" +import contextlib +import fnmatch +import glob +import hashlib +import json +import logging +import os +import tempfile +from typing import Any, Generator, Iterable, List, Optional, Tuple, Type + +import filelock +import huggingface_hub.constants +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from torch import nn +from tqdm.auto import tqdm +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME +from vllm.config import LoadConfig, ModelConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.models import ModelRegistry + +from sglang.srt.layers.quantization import get_quantization_config + +logger = logging.getLogger("srt.model_loader") +temp_dir = tempfile.gettempdir() + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + if ( + model_config.quantization is not None + and model_config.quantization != "fp8" + and "MixtralForCausalLM" in architectures + ): + architectures = ["QuantMixtralForCausalLM"] + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. 
" + f"Supported architectures: {ModelRegistry.get_supported_archs()}" + ) + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) + return lock + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + + Returns: + str: The path to the downloaded model weights. + """ + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=SAFE_WEIGHTS_INDEX_NAME, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have SAFE_WEIGHTS_INDEX_NAME. 
+ except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both to the weight loader breaks loading. +# So, we use SAFE_WEIGHTS_INDEX_NAME to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files( + hf_weights_files: List[str], hf_folder: str +) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to the safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name -> safetensors file) + # to identify the weight files that we should use. + with open(index_file_name) as index_file: + weight_map = json.load(index_file)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any files that are not referenced by the index. + hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index] + return hf_weights_files + + +def safetensors_weights_iterator( + hf_weights_files: List[str], +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensors files.""" + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def get_quant_config( + model_config: ModelConfig, load_config: LoadConfig +) -> QuantizationConfig: + quant_cls = get_quantization_config(model_config.quantization) + # Read the quantization config from the HF model config, if available. + hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors stores it under "compression_config" instead. + hf_quant_config = getattr(model_config.hf_config, "compression_config", None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + # In case of bitsandbytes/QLoRA, get the quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if ( + not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config + ): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path" + ] + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the method declares no config filenames to look for, fall back to + # its default config.
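+ # For example, AWQ/GPTQ-style checkpoints typically ship a separate + # `quant_config.json` or `quantize_config.json`, while methods that keep + # all of their settings inside the HF config declare no filenames here.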
+ if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any(f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError(f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}" + ) + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + + return quant_cls.from_config(config)
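+ + +# A minimal usage sketch of the loader entry point defined above (the +# import path and the construction of each config object are assumptions +# for illustration, not part of this change): +# +# from sglang.srt.model_loader.loader import get_model +# +# model = get_model( +# model_config=model_config, # vllm.config.ModelConfig +# load_config=LoadConfig(), # defaults to LoadFormat.AUTO +# device_config=device_config, +# parallel_config=parallel_config, +# scheduler_config=scheduler_config, +# lora_config=None, # not supported by this loader +# multimodal_config=None, # not supported by this loader +# cache_config=cache_config, +# ) +# +# The loader initializes the model under the target dtype and device, +# streams the weights shard by shard, and runs each module's +# post-loading quantization hook once all of its shards have arrived.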