From fb1f28cbbbd3e2abcbf40dc043e5b2556938abec Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 11 Aug 2024 22:54:37 -0700 Subject: [PATCH] Clean up the comments and names under python/sglang/srt/layers (#1047) --- python/sglang/srt/layers/activation.py | 2 + ...token_attention.py => decode_attention.py} | 14 +- python/sglang/srt/layers/extend_attention.py | 7 +- python/sglang/srt/layers/layernorm.py | 2 + python/sglang/srt/layers/linear.py | 884 ------------------ ...ttention_nopad.py => prefill_attention.py} | 5 + .../srt/layers/quantization/__init__.py | 64 -- python/sglang/srt/layers/quantization/fp8.py | 677 -------------- python/sglang/srt/layers/radix_attention.py | 4 +- 9 files changed, 26 insertions(+), 1633 deletions(-) rename python/sglang/srt/layers/{token_attention.py => decode_attention.py} (97%) delete mode 100644 python/sglang/srt/layers/linear.py rename python/sglang/srt/layers/{context_flashattention_nopad.py => prefill_attention.py} (98%) delete mode 100644 python/sglang/srt/layers/quantization/__init__.py delete mode 100644 python/sglang/srt/layers/quantization/fp8.py diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index c767327a6..64d391594 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -11,6 +11,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +"""Fused operators for activation layers.""" + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/python/sglang/srt/layers/token_attention.py b/python/sglang/srt/layers/decode_attention.py similarity index 97% rename from python/sglang/srt/layers/token_attention.py rename to python/sglang/srt/layers/decode_attention.py index ab6e7ba77..c868299ef 100644 --- a/python/sglang/srt/layers/token_attention.py +++ b/python/sglang/srt/layers/decode_attention.py @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. """ +""" +Memory-efficient attention for decoding. +""" + # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py @@ -194,7 +198,7 @@ def _fwd_kernel_stage2( tl.store(out_ptrs, acc) -def _token_att_m_fwd( +def _decode_att_m_fwd( q, k_buffer, att_out, @@ -254,7 +258,7 @@ def _token_att_m_fwd( ) -def _token_softmax_reducev_fwd( +def _decode_softmax_reducev_fwd( logics, v_buffer, o, @@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd( ) -def token_attention_fwd( +def decode_attention_fwd( q, k_buffer, v_buffer, @@ -312,7 +316,7 @@ def token_attention_fwd( (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda" ) - _token_att_m_fwd( + _decode_att_m_fwd( q, k_buffer, att_m, @@ -324,7 +328,7 @@ def token_attention_fwd( sm_scale, logit_cap, ) - _token_softmax_reducev_fwd( + _decode_softmax_reducev_fwd( att_m, v_buffer, o, diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 7398895d6..0a03f6562 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -13,11 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
""" +""" +Memory-efficient attention for prefill. +It supporst page size = 1 and prefill with KV cache (i.e. extend). +""" + import torch import triton import triton.language as tl -from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd +from sglang.srt.layers.prefill_attention import context_attention_fwd CUDA_CAPABILITY = torch.cuda.get_device_capability() diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 2a55c25e5..ac4d368d3 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. """ +"""Fused operators for normalization layers.""" + from typing import Optional, Tuple, Union import torch diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py deleted file mode 100644 index fb8891cb2..000000000 --- a/python/sglang/srt/layers/linear.py +++ /dev/null @@ -1,884 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -# temporarily adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/linear.py -# FIXME: refactor the linear abstraction -from abc import abstractmethod -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter -from vllm.distributed import ( - divide, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, -) -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) -from vllm.model_executor.utils import set_weight_attrs - -logger = init_logger(__name__) - - -def adjust_marlin_shard(param, shard_size, shard_offset): - marlin_tile_size = getattr(param, "marlin_tile_size", None) - if marlin_tile_size is None: - return shard_size, shard_offset - - return shard_size * marlin_tile_size, shard_offset * marlin_tile_size - - -def adjust_bitsandbytes_shard( - param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str -) -> Tuple[int, int]: - """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" - - total, _ = qkv_offsets["total"] - orig_offset, orig_size = qkv_offsets[loaded_shard_id] - - quantized_total = param.data.shape[0] - quantized_offset = orig_offset * quantized_total // total - quantized_size = orig_size * quantized_total // total - - return quantized_size, quantized_offset - - -def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): - """For fused modules (QKV and MLP) we have an array of length - N that holds 1 scale for each "logical" matrix. So the param - is an array of length N. The loaded_weight corresponds to - one of the shards on disk. 
Here, we slice the param based on - the shard_id for loading. - """ - qkv_idxs = {"q": 0, "k": 1, "v": 2} - - if isinstance(shard_id, str): - shard_id = qkv_idxs[shard_id] - elif not isinstance(shard_id, int): - raise ValueError(f"Unknown Shard Id {shard_id}") - - # AutoFP8 scales do not have a shape - # compressed-tensors scales do have a shape - if len(loaded_weight.shape) != 0: - assert loaded_weight.shape[0] == 1 - loaded_weight = loaded_weight[0] - - return param[shard_id], loaded_weight - - -class LinearMethodBase(QuantizeMethodBase): - """Base class for different (maybe quantized) linear methods.""" - - @abstractmethod - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - """Create weights for a linear layer. - The weights will be set as attributes of the layer. - - Args: - layer: The layer that is using the LinearMethodBase factory. - input_size_per_partition: Size of the weight input dim on rank X. - output_partition_sizes: Sizes of the output dim of each logical - weight on rank X. E.g., output_partition_sizes for QKVLinear - is a list contains the width of Wq, Wk, Wv on rank X. - input_size: Size of the input dim of the weight across all ranks. - output_size: Size of the output dim of the weight across all ranks. - params_dtype: Datatype of the parameters. - """ - raise NotImplementedError - - @abstractmethod - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Apply the weights in layer to the input tensor. - Expects create_weights to have been called before on the layer.""" - raise NotImplementedError - - -class UnquantizedLinearMethod(LinearMethodBase): - """Linear method without quantization. - - Args: - separate_bias_add: If true, add bias separately after matrix - multiplication. - """ - - def __init__(self, separate_bias_add: bool = False): - self.separate_bias_add = separate_bias_add - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight = Parameter( - torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - weight = layer.weight - if self.separate_bias_add: - if bias is not None: - return F.linear(x, weight) + bias - return F.linear(x, weight) - return F.linear(x, weight, bias) - - -class LinearBase(torch.nn.Module): - """Base linear layer. - - Args: - input_size: input dimension of the linear layer. - output_size: output dimension of the linear layer. - bias: If true, add bias. - skip_bias_add: If true, skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. 
- """ - - def __init__( - self, - input_size: int, - output_size: int, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.skip_bias_add = skip_bias_add - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - if quant_config is None: - self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() - else: - self.quant_method = quant_config.get_quant_method(self) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - raise NotImplementedError - - -class ReplicatedLinear(LinearBase): - """Replicated linear layer. - - Args: - input_size: input dimension of the linear layer. - output_size: output dimension of the linear layer. - bias: If true, add bias. - skip_bias_add: If true, skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - """ - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__( - input_size, output_size, skip_bias_add, params_dtype, quant_config - ) - - # All the linear layer supports quant method. - assert self.quant_method is not None - self.quant_method.create_weights( - self, - self.input_size, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - ) - - if bias: - self.bias = Parameter( - torch.empty(self.output_size, dtype=self.params_dtype) - ) - set_weight_attrs(self.bias, {"output_dim": 0}) - else: - self.register_parameter("bias", None) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - bias = self.bias if not self.skip_bias_add else None - assert self.quant_method is not None - output = self.quant_method.apply(self, x, bias) - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - def extra_repr(self) -> str: - s = f"in_features={self.input_size}" - s += f", output_features={self.output_size}" - s += f", bias={self.bias is not None}" - return s - - -class ColumnParallelLinear(LinearBase): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - - Args: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - bias: If true, add bias. - gather_output: If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - output_sizes: list of output sizes packed into one output, like for QKV - the list would be size 3. 
- """ - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[List[int]] = None, - ): - super().__init__( - input_size, output_size, skip_bias_add, params_dtype, quant_config - ) - - self.gather_output = gather_output - - # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() - assert self.quant_method is not None - self.output_size_per_partition = divide(self.output_size, tp_size) - self.output_partition_sizes = [self.output_size_per_partition] - # If QKV or MergedColumn, use output size of each partition. - if hasattr(self, "output_sizes"): - self.output_partition_sizes = [ - divide(output_size, tp_size) for output_size in self.output_sizes - ] - - if output_sizes is None: - output_sizes = [output_size] - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=self.weight_loader, - ) - if bias: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, dtype=params_dtype) - ) - set_weight_attrs( - self.bias, - { - "output_dim": 0, - "weight_loader": self.weight_loader, - }, - ) - else: - self.register_parameter("bias", None) - - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - if param.data.dtype != loaded_weight.dtype: - param.data = torch.empty_like( - param.data, dtype=loaded_weight.dtype, device="cuda" - ) - - tp_rank = get_tensor_model_parallel_rank() - output_dim = getattr(param, "output_dim", None) - param_data = param.data - if output_dim is not None: - shard_size = param_data.shape[output_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - - # Special case for loading scales off disk, which often do not - # have a shape (such as in the case of AutoFP8). - if len(loaded_weight.shape) == 0: - loaded_weight = loaded_weight.reshape(1) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - - def forward(self, input_): - bias = self.bias if not self.skip_bias_add else None - - # Matrix multiply. - assert self.quant_method is not None - output_parallel = self.quant_method.apply(self, input_, bias) - if self.gather_output: - # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - def extra_repr(self) -> str: - s = f"in_features={self.input_size}" - s += f", output_features={self.output_size_per_partition}" - s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" - s += f", gather_output={self.gather_output}" - return s - - -class MergedColumnParallelLinear(ColumnParallelLinear): - """Packed linear layers with column parallelism. - - Similar to ColumnParallelLinear, but the weight matrix is concatenated - along the output dimension. When the weight matrix is loaded, the - different partitions are sharded separately. - - Args: - input_size: input dimension of the linear layer. - output_sizes: list of output dimensions of the linear layer. - bias: If true, add bias. 
- gather_output: If true, call all-gather on output and make the output - available to all GPUs, otherwise, every GPU will have - its own output. - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - """ - - def __init__( - self, - input_size: int, - output_sizes: List[int], - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): - self.output_sizes = output_sizes - tp_size = get_tensor_model_parallel_world_size() - assert all(output_size % tp_size == 0 for output_size in output_sizes) - super().__init__( - input_size=input_size, - output_size=sum(output_sizes), - bias=bias, - gather_output=gather_output, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - ) - - def weight_loader( - self, - param: Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[int] = None, - ): - if param.data.dtype != loaded_weight.dtype: - param.data = torch.empty_like( - param.data, dtype=loaded_weight.dtype, device="cuda" - ) - - param_data = param.data - output_dim = getattr(param, "output_dim", None) - # Special case for AQLM codebooks. - is_metadata = getattr(param, "is_metadata", False) - # Special case for per-tensor scale to load scalar into fused array. - needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) - - if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). - if output_dim is None: - if needs_scalar_to_array is not None: - param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, 0 - ) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - return - current_shard_offset = 0 - shard_offsets: List[Tuple[int, int, int]] = [] - for i, output_size in enumerate(self.output_sizes): - shard_offsets.append((i, current_shard_offset, output_size)) - current_shard_offset += output_size - packed_dim = getattr(param, "packed_dim", None) - for shard_id, shard_offset, shard_size in shard_offsets: - # Special case for Quantization. - # If quantized, we need to adjust the offset and size to account - # for the packing. - if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor - # Special case for Marlin. - shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset - ) - - loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size - ) - self.weight_loader(param, loaded_weight_shard, shard_id) - return - - assert loaded_shard_id < len(self.output_sizes) - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size - # Special case for quantization. - # If quantized, we need to adjust the offset and size to account - # for the packing. - packed_dim = getattr(param, "packed_dim", None) - if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor - # Special case for Marlin. 
- shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset - ) - - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: - shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id - - param_data = param_data.narrow(output_dim, shard_offset, shard_size) - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - # Special case for AQLM codebooks. - elif is_metadata: - # metadata indicates fixed size concatenated along dim 0 - shard_size = loaded_weight.shape[0] - shard_offset = loaded_shard_id * shard_size - param_data = param_data.narrow(0, shard_offset, shard_size) - - # Special case for per-tensor scales in fused case. - elif needs_scalar_to_array: - param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, loaded_shard_id - ) - - else: - ignore_warning = getattr(param, "ignore_warning", False) - if not ignore_warning: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions." - ) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - - -class QKVParallelLinear(ColumnParallelLinear): - """Linear layers for the attention's QKV transformation. - - Linear layers for the linear transformation of the query, key, and value - vectors in the attention layer. The weight matrix is concatenated along - the output dimension. The layer is parallelized along the head dimension. - When the number of key/value heads is smaller than the number of query - heads (e.g., multi-query/grouped-query attention), the key/value head may - be replicated while the query heads are partitioned. - - Args: - hidden_size: input hidden state size of the transformer. - head_size: size of each attention head. - total_num_heads: total number of attention query heads. - total_num_kv_heads: total number of attention key/value heads. If - None, assume total_num_kv_heads = total_num_heads. - bias: If true, add bias. - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - """ - - def __init__( - self, - hidden_size: int, - head_size: int, - total_num_heads: int, - total_num_kv_heads: Optional[int] = None, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): - self.hidden_size = hidden_size - self.head_size = head_size - self.total_num_heads = total_num_heads - if total_num_kv_heads is None: - total_num_kv_heads = total_num_heads - self.total_num_kv_heads = total_num_kv_heads - # Divide the weight matrix along the last dimension. 
- tp_size = get_tensor_model_parallel_world_size() - self.num_heads = divide(self.total_num_heads, tp_size) - if tp_size >= self.total_num_kv_heads: - self.num_kv_heads = 1 - self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads) - else: - self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) - self.num_kv_head_replicas = 1 - input_size = self.hidden_size - output_size = ( - (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - ) - self.output_sizes = [ - self.num_heads * self.head_size * tp_size, # q_proj - self.num_kv_heads * self.head_size * tp_size, # k_proj - self.num_kv_heads * self.head_size * tp_size, # v_proj - ] - - super().__init__( - input_size=input_size, - output_size=output_size, - bias=bias, - gather_output=False, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - ) - - def weight_loader( - self, - param: Parameter, - loaded_weight: torch.Tensor, - loaded_shard_id: Optional[str] = None, - ): - if param.data.dtype != loaded_weight.dtype: - param.data = torch.empty_like( - param.data, dtype=loaded_weight.dtype, device="cuda" - ) - - param_data = param.data - output_dim = getattr(param, "output_dim", None) - # Special case for AQLM codebooks. - is_metadata = getattr(param, "is_metadata", False) - - # Special case for per-tensor scales in fused case. - needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) - - if loaded_shard_id is None: - # Loaded weight is already fused on disk (qkv/mlp). - if output_dim is None: - if needs_scalar_to_array is not None: - param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, 0 - ) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - return - shard_offsets = [ - # (shard_id, shard_offset, shard_size) - ("q", 0, self.total_num_heads * self.head_size), - ( - "k", - self.total_num_heads * self.head_size, - self.total_num_kv_heads * self.head_size, - ), - ( - "v", - (self.total_num_heads + self.total_num_kv_heads) * self.head_size, - self.total_num_kv_heads * self.head_size, - ), - ] - packed_dim = getattr(param, "packed_dim", None) - for shard_id, shard_offset, shard_size in shard_offsets: - # Special case for Quantized Weights. - # If quantized, we need to adjust the offset and size to account - # for the packing. - if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor - - # Special case for Marlin. - shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset - ) - - loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size - ) - self.weight_loader(param, loaded_weight_shard, shard_id) - return - - tp_rank = get_tensor_model_parallel_rank() - assert loaded_shard_id in ["q", "k", "v"] - - # If output dim is defined, use the default loading process. - if output_dim is not None: - if loaded_shard_id == "q": - shard_offset = 0 - shard_size = self.num_heads * self.head_size - elif loaded_shard_id == "k": - shard_offset = self.num_heads * self.head_size - shard_size = self.num_kv_heads * self.head_size - elif loaded_shard_id == "v": - shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size - shard_size = self.num_kv_heads * self.head_size - # Special case for Quantized Weights. - # If quantized, we need to adjust the offset and size to account - # for the packing. 
- packed_dim = getattr(param, "packed_dim", None) - if packed_dim == output_dim: - shard_size = shard_size // param.pack_factor - shard_offset = shard_offset // param.pack_factor - - # Special case for Marlin. - shard_size, shard_offset = adjust_marlin_shard( - param, shard_size, shard_offset - ) - - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: - orig_qkv_offsets = { - "q": (0, self.num_heads * self.head_size), - "k": ( - self.num_heads * self.head_size, - self.num_kv_heads * self.head_size, - ), - "v": ( - (self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size, - ), - "total": ( - (self.num_heads + 2 * self.num_kv_heads) * self.head_size, - 0, - ), - } - shard_size, shard_offset = adjust_bitsandbytes_shard( - param, orig_qkv_offsets, loaded_shard_id - ) - - param_data = param_data.narrow(output_dim, shard_offset, shard_size) - if loaded_shard_id == "q": - shard_id = tp_rank - else: - shard_id = tp_rank // self.num_kv_head_replicas - start_idx = shard_id * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - # Special case for for AQLM codebooks. - elif is_metadata: - # metadata indicates fixed size concatenated along dim 0 - shard_size = loaded_weight.shape[0] - shard_index = ["q", "k", "v"].index(loaded_shard_id) - param_data = param_data.narrow(0, shard_index * shard_size, shard_size) - # Special case for per-tensor scales in fused case. - elif needs_scalar_to_array: - param_data, loaded_weight = adjust_scalar_to_fused_array( - param_data, loaded_weight, loaded_shard_id - ) - else: - ignore_warning = getattr(param, "ignore_warning", False) - if not ignore_warning: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "QKVParallelLinear, assume the weight is the same " - "for all partitions." - ) - - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - - -class RowParallelLinear(LinearBase): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - bias: If true, add bias. Note that bias is not parallelized. - input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - skip_bias_add: This was added to enable performance optimization where - bias can be fused with other element-wise operations. - We skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configure. - """ - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__( - input_size, output_size, skip_bias_add, params_dtype, quant_config - ) - - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - - # Divide the weight matrix along the last dimension. 
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.tp_size)
-        assert self.quant_method is not None
-        self.quant_method.create_weights(
-            layer=self,
-            input_size_per_partition=self.input_size_per_partition,
-            output_partition_sizes=[self.output_size],
-            input_size=self.input_size,
-            output_size=self.output_size,
-            params_dtype=self.params_dtype,
-            weight_loader=self.weight_loader,
-        )
-        if not reduce_results and (bias and not skip_bias_add):
-            raise ValueError(
-                "When not reduce the results, adding bias to the "
-                "results can lead to incorrect results"
-            )
-
-        if bias:
-            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": 0,
-                    "weight_loader": self.weight_loader,
-                },
-            )
-        else:
-            self.register_parameter("bias", None)
-
-    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        if param.data.dtype != loaded_weight.dtype:
-            param.data = torch.empty_like(
-                param.data, dtype=loaded_weight.dtype, device="cuda"
-            )
-
-        param_data = param.data
-        tp_rank = get_tensor_model_parallel_rank()
-        input_dim = getattr(param, "input_dim", None)
-        if input_dim is not None:
-            shard_size = param.data.shape[input_dim]
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
-
-        # Special case for loading scales off disk, which often do not
-        # have a shape (such as in the case of AutoFP8).
-        if len(loaded_weight.shape) == 0:
-            loaded_weight = loaded_weight.reshape(1)
-
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-
-    def forward(self, input_):
-        # Set up backprop all-reduce.
-        if self.input_is_parallel:
-            input_parallel = input_
-        else:
-            tp_rank = get_tensor_model_parallel_rank()
-            splitted_input = split_tensor_along_last_dim(
-                input_, num_partitions=self.tp_size
-            )
-            input_parallel = splitted_input[tp_rank].contiguous()
-
-        # Matrix multiply.
-        assert self.quant_method is not None
-        output_parallel = self.quant_method.apply(self, input_parallel)
-        if self.reduce_results and self.tp_size > 1:
-            output_ = tensor_model_parallel_all_reduce(output_parallel)
-        else:
-            output_ = output_parallel
-
-        if not self.skip_bias_add:
-            output = output_ + self.bias if self.bias is not None else output_
-            output_bias = None
-        else:
-            output = output_
-            output_bias = self.bias
-        return output, output_bias
-
-    def extra_repr(self) -> str:
-        s = f"input_features={self.input_size_per_partition}"
-        s += f", output_features={self.output_size}"
-        s += f", bias={self.bias is not None}"
-        s += f", tp_size={self.tp_size}"
-        s += f", reduce_results={self.reduce_results}"
-        return s
diff --git a/python/sglang/srt/layers/context_flashattention_nopad.py b/python/sglang/srt/layers/prefill_attention.py
similarity index 98%
rename from python/sglang/srt/layers/context_flashattention_nopad.py
rename to python/sglang/srt/layers/prefill_attention.py
index a2dc2ff31..99343a4df 100644
--- a/python/sglang/srt/layers/context_flashattention_nopad.py
+++ b/python/sglang/srt/layers/prefill_attention.py
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for prefill.
+It supports page size = 1. 
+""" + # Adapted from # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1 import torch diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py deleted file mode 100644 index 564a696b0..000000000 --- a/python/sglang/srt/layers/quantization/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -# temporarily adapted from vLLM -# FIXME: in progress of refactoring the model loader - -from typing import Dict, Type - -from vllm.model_executor.layers.quantization.aqlm import AQLMConfig -from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsConfig, -) -from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig -from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config -from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig - -from sglang.srt.layers.quantization.fp8 import Fp8Config - -QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { - "aqlm": AQLMConfig, - "awq": AWQConfig, - "deepspeedfp": DeepSpeedFPConfig, - "fp8": Fp8Config, - # The order of gptq methods is important for config.py iteration over - # override_quantization_method(..) - "marlin": MarlinConfig, - "gptq_marlin_24": GPTQMarlin24Config, - "gptq_marlin": GPTQMarlinConfig, - "gptq": GPTQConfig, - "squeezellm": SqueezeLLMConfig, - "compressed-tensors": CompressedTensorsConfig, - "bitsandbytes": BitsAndBytesConfig, -} - - -def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: - if quantization not in QUANTIZATION_METHODS: - raise ValueError(f"Invalid quantization method: {quantization}") - return QUANTIZATION_METHODS[quantization] - - -__all__ = [ - "QuantizationConfig", - "get_quantization_config", - "QUANTIZATION_METHODS", -] diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py deleted file mode 100644 index 12378d506..000000000 --- a/python/sglang/srt/layers/quantization/fp8.py +++ /dev/null @@ -1,677 +0,0 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" - -# adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py -# FIXME refactor in progress -from typing import Any, Dict, List, Optional, Union - -import torch -from torch.nn import Module -from torch.nn.parameter import Parameter -from vllm import _custom_ops as ops -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase, fused_moe -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, - QuantizeMethodBase, -) -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQ_MARLIN_MAX_PARALLEL, - GPTQ_MARLIN_MIN_THREAD_N, - GPTQMarlinState, - marlin_permute_scales, -) -from vllm.model_executor.layers.quantization.utils.marlin_utils import pack_fp8_to_int32 -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform -from vllm.utils import print_warning_once - -from sglang.srt.layers.linear import LinearBase, LinearMethodBase - -ACTIVATION_SCHEMES = ["static", "dynamic"] - -logger = init_logger(__name__) - - -def cutlass_fp8_supported() -> bool: - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - - return ops.cutlass_scaled_mm_supports_fp8(capability) - - -class Fp8Config(QuantizationConfig): - """Config class for FP8.""" - - def __init__( - self, - is_checkpoint_fp8_serialized: bool = False, - activation_scheme: str = "dynamic", - ) -> None: - self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized - if is_checkpoint_fp8_serialized: - logger.warning( - "Detected fp8 checkpoint. Please note that the " - "format is experimental and subject to change." - ) - if activation_scheme not in ACTIVATION_SCHEMES: - raise ValueError(f"Unsupported activation scheme {activation_scheme}") - self.activation_scheme = activation_scheme - - @classmethod - def get_name(cls) -> str: - return "fp8" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.bfloat16, torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 80 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return [] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": - quant_method = cls.get_from_keys(config, ["quant_method"]) - is_checkpoint_fp8_serialized = "fp8" in quant_method - activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) - return cls( - is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, - activation_scheme=activation_scheme, - ) - - def get_quant_method( - self, layer: torch.nn.Module - ) -> Optional["QuantizeMethodBase"]: - - if isinstance(layer, LinearBase): - return Fp8LinearMethod(self) - elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self) - return None - - def get_scaled_act_names(self) -> List[str]: - return [] - - -class Fp8LinearMethod(LinearMethodBase): - """Linear method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. 
- - Also supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. - - Limitations: - 1. Only support per-tensor quantization due to torch._scaled_mm support. - 2. Only support float8_e4m3fn data type due to the limitation of - torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - self.cutlass_fp8_supported = cutlass_fp8_supported() - - # For GPUs that lack FP8 hardware support, we can leverage the Marlin - # kernel for fast weight-only FP8 quantization - capability = current_platform.get_device_capability() - capability = capability[0] * 10 + capability[1] - self.use_marlin = capability < 89 - - def _create_scale_param( - self, - scale_name: str, - layer: torch.nn.Module, - output_partition_sizes: List[int], - **extra_weight_attrs, - ) -> None: - scale = Parameter( - torch.empty(len(output_partition_sizes), dtype=torch.float32), - requires_grad=False, - ) - scale[:] = torch.finfo(torch.float8_e4m3fn).min - layer.register_parameter(scale_name, scale) - set_weight_attrs( - scale, - { - **extra_weight_attrs, - "needs_scalar_to_array": True, - }, - ) - - def create_weights( - self, - layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - del input_size, output_size - output_size_per_partition = sum(output_partition_sizes) - - layer.process_after_load = True - layer.logical_widths = output_partition_sizes - - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.orig_dtype = params_dtype - - # WEIGHT - # weight_dtype = (torch.float8_e4m3fn - # if self.quant_config.is_checkpoint_fp8_serialized else - # params_dtype) - weight_dtype = torch.float8_e4m3fn - weight = Parameter( - torch.empty( - output_size_per_partition, input_size_per_partition, dtype=weight_dtype - ), - requires_grad=False, - ) - layer.register_parameter("weight", weight) - set_weight_attrs( - weight, - { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }, - ) - - # If checkpoint is serialized fp8, load them. - # Otherwise, wait until process_weights_after_loading. - if self.quant_config.is_checkpoint_fp8_serialized: - # WEIGHT SCALE - self._create_scale_param( - scale_name="weight_scale", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs, - ) - - # INPUT ACTIVATION SCALE - if self.quant_config.activation_scheme == "static": - self._create_scale_param( - scale_name="input_scale", - layer=layer, - output_partition_sizes=output_partition_sizes, - **extra_weight_attrs, - ) - - # For GPUs without FP8 hardware support, we use Marlin for fast - # fused dequantization - if self.use_marlin: - layer.marlin_state = GPTQMarlinState.REPACK - - def prepare_layer_for_marlin(self, layer: Module) -> None: - print_warning_once( - "Your GPU does not have native support for FP8 computation but " - "FP8 quantization is being used. Weight-only FP8 compression will " - "be used leveraging the Marlin kernel. This may degrade " - "performance for compute-heavy workloads." 
- ) - - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - assert layer.marlin_state == GPTQMarlinState.REPACK - layer.marlin_state = GPTQMarlinState.READY - - device = layer.weight.device - - # WEIGHTS - # Repack weights to gptq format (packed int32 elements) - packed_gptq_qweight = pack_fp8_to_int32(layer.weight) - - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=packed_gptq_qweight, - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - # Currently Marlin doesn't support per-tensor scales, so we - # expand it to channelwise - scales = ( - layer.weight_scale.repeat(1, part_size_n).to(layer.orig_dtype).to(device) - ) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, - size_k=part_size_k, - size_n=part_size_n, - group_size=-1, - num_bits=8, - ) - layer.weight_scale = Parameter(marlin_scales, requires_grad=False) - - # Allocate marlin workspace - max_workspace_size = ( - part_size_n // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - workspace = torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - layer.workspace = workspace - - def process_weights_after_loading(self, layer: Module) -> None: - if not hasattr(layer, "process_after_load") or not layer.process_after_load: - return - - # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights. - if not self.quant_config.is_checkpoint_fp8_serialized: - qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None) - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - layer.logical_widths = None - layer.input_scale = None - if self.use_marlin: - self.prepare_layer_for_marlin(layer) - return - - # If checkpoint is fp8, requantize the separately quantized logical - # weights into a single fp8 weight with a single weight scale. - else: - # WEIGHT_SCALE / WEIGHT - # Loop over logical weights, requantizing with single scale. - max_w_scale = layer.weight_scale.max() - - # QKV / MLP is fused in the on disk checkpoint if any of the - # weight scales are still set to the default since we initialize - # N weight scales for N shards but we only load 1 weight scale - # from disk in this case. As a result, we skip dequant -> requant - # since we already have quantized QKV together. - # Sample Model with fused checkpoint: - # * nm-testing/Phi-3-mini-128k-instruct-FP8 - unfused_module_in_checkpoint = ( - layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min - ) - - if unfused_module_in_checkpoint: - start = 0 - for idx, logical_width in enumerate(layer.logical_widths): - end = start + logical_width - weight_dq = per_tensor_dequantize( - layer.weight[start:end, :], layer.weight_scale[idx] - ) - - layer.weight[start:end, :] = per_tensor_quantize( - weight_dq, layer.weight_scale.max() - ) - start = end - layer.weight_scale = Parameter(max_w_scale, requires_grad=False) - - # WEIGHT - # Transpose weight for passing to torch._scaled_mm - weight = layer.weight - layer.weight = Parameter(weight.t(), requires_grad=False) - - # INPUT ACTIVATION SCALE - # Dynamic: set to None (required input to ops.scaled_fp8_quant). - # Static: set to max of the input_scales (since they are equal). 
- if self.quant_config.activation_scheme == "dynamic": - layer.input_scale = None - elif self.quant_config.activation_scheme == "static": - layer.input_scale = Parameter( - layer.input_scale.max(), requires_grad=False - ) - else: - raise ValueError( - f"Unknown scheme {self.quant_config.activation_scheme}" - ) - - if self.use_marlin: - self.prepare_layer_for_marlin(layer) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - - if self.use_marlin: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (layer.output_size_per_partition,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=layer.weight, - b_scales=layer.weight_scale, - workspace=layer.workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - else: - - # ops.scaled_fp8_quant supports both dynamic and static quant. - # If dynamic, layer.input_scale is None and x_scale computed from x - # If static, layer.input_scale is scalar and x_scale is input_scale - - if bias is None and self.cutlass_fp8_supported: - qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) - - # Fused GEMM_DQ - output = ops.cutlass_scaled_mm( - qinput, - layer.weight, - out_dtype=x.dtype, - scale_a=x_scale, - scale_b=layer.weight_scale, - ) - - else: - qinput, x_scale = ops.scaled_fp8_quant( - x, layer.input_scale, batch_dim_padding=17 - ) - - # Fused GEMM_DQ -- note we padded the input above because - # torch._scaled_mm is more performant for matrices with - # batch dimension > 16. Note that this could change - # in the future. - output, _ = torch._scaled_mm( - qinput, - layer.weight, - out_dtype=x.dtype, - scale_a=x_scale, - scale_b=layer.weight_scale, - bias=bias, - ) - - return torch.narrow(output, 0, 0, x.shape[0]) - - -class Fp8MoEMethod(FusedMoEMethodBase): - """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. - - Also supports loading quantized FP16/BF16 model checkpoints with dynamic - activation scaling. The weight scaling factor will be initialized after - the model weights are loaded. - - Args: - quant_config: The quantization config. - """ - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - def create_weights( - self, - layer: Module, - num_experts: int, - hidden_size: int, - intermediate_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - - layer.process_after_load = True - - if self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype - ), - requires_grad=False, - ) - layer.register_parameter("w13_weight", w13_weight) - set_weight_attrs(w13_weight, extra_weight_attrs) - - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, hidden_size, intermediate_size, dtype=params_dtype - ), - requires_grad=False, - ) - layer.register_parameter("w2_weight", w2_weight) - set_weight_attrs(w2_weight, extra_weight_attrs) - - # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
- w13_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_scale", w13_scale) - - w2_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_scale", w2_scale) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_scale, extra_weight_attrs) - set_weight_attrs(w2_scale, extra_weight_attrs) - - # INPUT_SCALES - if self.quant_config.activation_scheme == "static": - if not self.quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." - ) - - a13_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("a13_scale", a13_scale) - set_weight_attrs(a13_scale, extra_weight_attrs) - - a2_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("a2_scale", a2_scale) - set_weight_attrs(a2_scale, extra_weight_attrs) - else: - layer.a13_scale = None - layer.a2_scale = None - - def process_weights_after_loading(self, layer: Module) -> None: - if not hasattr(layer, "process_after_load") or not layer.process_after_load: - return - - # If checkpoint is fp16, quantize in place. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like( - layer.w13_weight.data, dtype=torch.float8_e4m3fn - ) - w2_weight = torch.empty_like( - layer.w2_weight.data, dtype=torch.float8_e4m3fn - ) - - # Re-initialize w13_scale because we directly quantize - # merged w13 weights and generate a single scaling factor. - layer.w13_scale = torch.nn.Parameter( - torch.ones( - layer.num_experts, dtype=torch.float32, device=w13_weight.device - ), - requires_grad=False, - ) - for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant( - layer.w2_weight.data[expert, :, :] - ) - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - return - - # If checkpoint is fp8, we need to handle that the - # MoE kernels require single activation scale and single weight - # scale for w13 per expert. - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if self.quant_config.activation_scheme == "static": - if layer.a13_scale is None or layer.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - if not all_close_1d(layer.a13_scale) or not all_close_1d( - layer.a2_scale - ): - print_warning_once( - "Found input_scales that are not equal for " - "fp8 MoE layer. Using the maximum across experts " - "for each layer. " - ) - layer.a13_scale = torch.nn.Parameter( - layer.a13_scale.max(), requires_grad=False - ) - layer.a2_scale = torch.nn.Parameter( - layer.a2_scale.max(), requires_grad=False - ) - - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. 
- assert layer.w13_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_scale.max(dim=1).values - for expert_id in range(layer.num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_scale[expert_id][shard_id], - ) - layer.w13_weight[expert_id][start : start + shard_size, :] = ( - per_tensor_quantize(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size - - layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) - return - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - ) -> torch.Tensor: - - return fused_moe( - x, - layer.w13_weight, - layer.w2_weight, - router_logits, - top_k, - renormalize=renormalize, - inplace=True, - use_fp8=True, - w1_scale=layer.w13_scale, - w2_scale=layer.w2_scale, - a1_scale=layer.a13_scale, - a2_scale=layer.a2_scale, - ) - - -# FIXME: not used -class Fp8KVCacheMethod(QuantizeMethodBase): - """Supports loading kv-cache scaling factors from FP8 checkpoints.""" - - def __init__(self, quant_config: Fp8Config): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module): - """Create "weight" (aka kv_scale) for an attention layer. - - Args: - layer: The layer that is using the QuantizeMethodBase factory. - """ - # Initialize the KV cache scale to 1.0 as the default value. - # If the kv_scale appears in the checkpoint, it will be - # overwritten when loading weights. - layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False) - - def apply(self, layer: torch.nn.Module) -> torch.Tensor: - raise RuntimeError("Fp8KVCacheMethod.apply should not be called.") - - def process_weights_after_loading(self, layer: Module) -> None: - # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. - if layer.kv_cache_dtype != "auto": - kv_scale = layer.kv_scale.to("cpu").tolist() - if not isinstance(kv_scale, float): - raise ValueError( - "Only support per-tensor scaling factor " "for fp8 KV cache" - ) - layer._kv_scale = kv_scale - if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: - print_warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This may " - "cause accuracy issues. Please make sure kv-cache scaling " - "factor is available in the fp8 checkpoint." 
- ) - del layer.kv_scale - - -def per_tensor_quantize( - tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] -) -> torch.Tensor: - finfo = torch.finfo(torch.float8_e4m3fn) - qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) - return qweight.to(torch.float8_e4m3fn) - - -def per_tensor_dequantize( - tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor] -) -> torch.Tensor: - fake_qweight = tensor.to(torch.float16) - dq_weight = fake_qweight * inv_scale - return dq_weight - - -def all_close_1d(x: torch.Tensor) -> bool: - assert len(x.shape) == 1 - return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 2afd329f9..1568cf6d9 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state from torch import nn from sglang.global_config import global_config +from sglang.srt.layers.decode_attention import decode_attention_fwd from sglang.srt.layers.extend_attention import extend_attention_fwd -from sglang.srt.layers.token_attention import token_attention_fwd from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata from sglang.srt.model_executor.model_runner import global_server_args_dict @@ -95,7 +95,7 @@ class RadixAttention(nn.Module): o = torch.empty_like(q) self.store_kv_cache(k, v, input_metadata) - token_attention_fwd( + decode_attention_fwd( q.view(-1, self.tp_q_head_num, self.qk_head_dim), input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),