Clean up the comments and names under python/sglang/srt/layers (#1047)
This commit is contained in:
@@ -11,6 +11,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""Fused operators for activation layers."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Memory-efficient attention for decoding.
|
||||||
|
"""
|
||||||
|
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
|
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
|
||||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
|
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
|
||||||
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
|
|||||||
tl.store(out_ptrs, acc)
|
tl.store(out_ptrs, acc)
|
||||||
|
|
||||||
|
|
||||||
def _token_att_m_fwd(
|
def _decode_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
att_out,
|
att_out,
|
||||||
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _token_softmax_reducev_fwd(
|
def _decode_softmax_reducev_fwd(
|
||||||
logics,
|
logics,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def token_attention_fwd(
|
def decode_attention_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
@@ -312,7 +316,7 @@ def token_attention_fwd(
|
|||||||
(q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
|
(q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
_token_att_m_fwd(
|
_decode_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
att_m,
|
att_m,
|
||||||
@@ -324,7 +328,7 @@ def token_attention_fwd(
|
|||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
)
|
)
|
||||||
_token_softmax_reducev_fwd(
|
_decode_softmax_reducev_fwd(
|
||||||
att_m,
|
att_m,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
o,
|
o,
|
||||||
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Memory-efficient attention for prefill.
|
||||||
|
It supporst page size = 1 and prefill with KV cache (i.e. extend).
|
||||||
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
from sglang.srt.layers.prefill_attention import context_attention_fwd
|
||||||
|
|
||||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""Fused operators for normalization layers."""
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|||||||
@@ -1,884 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# temporarily adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/linear.py
|
|
||||||
# FIXME: refactor the linear abstraction
|
|
||||||
from abc import abstractmethod
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
from vllm.distributed import (
|
|
||||||
divide,
|
|
||||||
get_tensor_model_parallel_rank,
|
|
||||||
get_tensor_model_parallel_world_size,
|
|
||||||
split_tensor_along_last_dim,
|
|
||||||
tensor_model_parallel_all_gather,
|
|
||||||
tensor_model_parallel_all_reduce,
|
|
||||||
)
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
|
||||||
QuantizationConfig,
|
|
||||||
QuantizeMethodBase,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_marlin_shard(param, shard_size, shard_offset):
|
|
||||||
marlin_tile_size = getattr(param, "marlin_tile_size", None)
|
|
||||||
if marlin_tile_size is None:
|
|
||||||
return shard_size, shard_offset
|
|
||||||
|
|
||||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_bitsandbytes_shard(
|
|
||||||
param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
|
||||||
|
|
||||||
total, _ = qkv_offsets["total"]
|
|
||||||
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
|
|
||||||
|
|
||||||
quantized_total = param.data.shape[0]
|
|
||||||
quantized_offset = orig_offset * quantized_total // total
|
|
||||||
quantized_size = orig_size * quantized_total // total
|
|
||||||
|
|
||||||
return quantized_size, quantized_offset
|
|
||||||
|
|
||||||
|
|
||||||
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
|
||||||
"""For fused modules (QKV and MLP) we have an array of length
|
|
||||||
N that holds 1 scale for each "logical" matrix. So the param
|
|
||||||
is an array of length N. The loaded_weight corresponds to
|
|
||||||
one of the shards on disk. Here, we slice the param based on
|
|
||||||
the shard_id for loading.
|
|
||||||
"""
|
|
||||||
qkv_idxs = {"q": 0, "k": 1, "v": 2}
|
|
||||||
|
|
||||||
if isinstance(shard_id, str):
|
|
||||||
shard_id = qkv_idxs[shard_id]
|
|
||||||
elif not isinstance(shard_id, int):
|
|
||||||
raise ValueError(f"Unknown Shard Id {shard_id}")
|
|
||||||
|
|
||||||
# AutoFP8 scales do not have a shape
|
|
||||||
# compressed-tensors scales do have a shape
|
|
||||||
if len(loaded_weight.shape) != 0:
|
|
||||||
assert loaded_weight.shape[0] == 1
|
|
||||||
loaded_weight = loaded_weight[0]
|
|
||||||
|
|
||||||
return param[shard_id], loaded_weight
|
|
||||||
|
|
||||||
|
|
||||||
class LinearMethodBase(QuantizeMethodBase):
|
|
||||||
"""Base class for different (maybe quantized) linear methods."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
"""Create weights for a linear layer.
|
|
||||||
The weights will be set as attributes of the layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer: The layer that is using the LinearMethodBase factory.
|
|
||||||
input_size_per_partition: Size of the weight input dim on rank X.
|
|
||||||
output_partition_sizes: Sizes of the output dim of each logical
|
|
||||||
weight on rank X. E.g., output_partition_sizes for QKVLinear
|
|
||||||
is a list contains the width of Wq, Wk, Wv on rank X.
|
|
||||||
input_size: Size of the input dim of the weight across all ranks.
|
|
||||||
output_size: Size of the output dim of the weight across all ranks.
|
|
||||||
params_dtype: Datatype of the parameters.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Apply the weights in layer to the input tensor.
|
|
||||||
Expects create_weights to have been called before on the layer."""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class UnquantizedLinearMethod(LinearMethodBase):
|
|
||||||
"""Linear method without quantization.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
separate_bias_add: If true, add bias separately after matrix
|
|
||||||
multiplication.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, separate_bias_add: bool = False):
|
|
||||||
self.separate_bias_add = separate_bias_add
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
weight = Parameter(
|
|
||||||
torch.empty(
|
|
||||||
sum(output_partition_sizes),
|
|
||||||
input_size_per_partition,
|
|
||||||
dtype=params_dtype,
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
|
||||||
layer.register_parameter("weight", weight)
|
|
||||||
set_weight_attrs(weight, extra_weight_attrs)
|
|
||||||
|
|
||||||
def apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
bias: Optional[torch.Tensor] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
weight = layer.weight
|
|
||||||
if self.separate_bias_add:
|
|
||||||
if bias is not None:
|
|
||||||
return F.linear(x, weight) + bias
|
|
||||||
return F.linear(x, weight)
|
|
||||||
return F.linear(x, weight, bias)
|
|
||||||
|
|
||||||
|
|
||||||
class LinearBase(torch.nn.Module):
|
|
||||||
"""Base linear layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_size: input dimension of the linear layer.
|
|
||||||
output_size: output dimension of the linear layer.
|
|
||||||
bias: If true, add bias.
|
|
||||||
skip_bias_add: If true, skip adding bias but instead return it.
|
|
||||||
params_dtype: Data type for the parameters.
|
|
||||||
quant_config: Quantization configure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
skip_bias_add: bool = False,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
# Keep input parameters
|
|
||||||
self.input_size = input_size
|
|
||||||
self.output_size = output_size
|
|
||||||
self.skip_bias_add = skip_bias_add
|
|
||||||
if params_dtype is None:
|
|
||||||
params_dtype = torch.get_default_dtype()
|
|
||||||
self.params_dtype = params_dtype
|
|
||||||
if quant_config is None:
|
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
|
|
||||||
else:
|
|
||||||
self.quant_method = quant_config.get_quant_method(self)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class ReplicatedLinear(LinearBase):
|
|
||||||
"""Replicated linear layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_size: input dimension of the linear layer.
|
|
||||||
output_size: output dimension of the linear layer.
|
|
||||||
bias: If true, add bias.
|
|
||||||
skip_bias_add: If true, skip adding bias but instead return it.
|
|
||||||
params_dtype: Data type for the parameters.
|
|
||||||
quant_config: Quantization configure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
bias: bool = True,
|
|
||||||
skip_bias_add: bool = False,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
input_size, output_size, skip_bias_add, params_dtype, quant_config
|
|
||||||
)
|
|
||||||
|
|
||||||
# All the linear layer supports quant method.
|
|
||||||
assert self.quant_method is not None
|
|
||||||
self.quant_method.create_weights(
|
|
||||||
self,
|
|
||||||
self.input_size,
|
|
||||||
[self.output_size],
|
|
||||||
self.input_size,
|
|
||||||
self.output_size,
|
|
||||||
self.params_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if bias:
|
|
||||||
self.bias = Parameter(
|
|
||||||
torch.empty(self.output_size, dtype=self.params_dtype)
|
|
||||||
)
|
|
||||||
set_weight_attrs(self.bias, {"output_dim": 0})
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
|
||||||
assert self.quant_method is not None
|
|
||||||
output = self.quant_method.apply(self, x, bias)
|
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
|
||||||
return output, output_bias
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
s = f"in_features={self.input_size}"
|
|
||||||
s += f", output_features={self.output_size}"
|
|
||||||
s += f", bias={self.bias is not None}"
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
class ColumnParallelLinear(LinearBase):
|
|
||||||
"""Linear layer with column parallelism.
|
|
||||||
|
|
||||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
|
||||||
its second dimension as A = [A_1, ..., A_p].
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_size: first dimension of matrix A.
|
|
||||||
output_size: second dimension of matrix A.
|
|
||||||
bias: If true, add bias.
|
|
||||||
gather_output: If true, call all-gather on output and make Y available
|
|
||||||
to all GPUs, otherwise, every GPU will have its output
|
|
||||||
which is Y_i = XA_i
|
|
||||||
skip_bias_add: This was added to enable performance optimizations where
|
|
||||||
bias can be fused with other element-wise operations. we
|
|
||||||
skip adding bias but instead return it.
|
|
||||||
params_dtype: Data type for the parameters.
|
|
||||||
quant_config: Quantization configure.
|
|
||||||
output_sizes: list of output sizes packed into one output, like for QKV
|
|
||||||
the list would be size 3.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
bias: bool = True,
|
|
||||||
gather_output: bool = False,
|
|
||||||
skip_bias_add: bool = False,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
output_sizes: Optional[List[int]] = None,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
input_size, output_size, skip_bias_add, params_dtype, quant_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self.gather_output = gather_output
|
|
||||||
|
|
||||||
# Divide the weight matrix along the last dimension.
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
assert self.quant_method is not None
|
|
||||||
self.output_size_per_partition = divide(self.output_size, tp_size)
|
|
||||||
self.output_partition_sizes = [self.output_size_per_partition]
|
|
||||||
# If QKV or MergedColumn, use output size of each partition.
|
|
||||||
if hasattr(self, "output_sizes"):
|
|
||||||
self.output_partition_sizes = [
|
|
||||||
divide(output_size, tp_size) for output_size in self.output_sizes
|
|
||||||
]
|
|
||||||
|
|
||||||
if output_sizes is None:
|
|
||||||
output_sizes = [output_size]
|
|
||||||
self.quant_method.create_weights(
|
|
||||||
layer=self,
|
|
||||||
input_size_per_partition=self.input_size,
|
|
||||||
output_partition_sizes=self.output_partition_sizes,
|
|
||||||
input_size=self.input_size,
|
|
||||||
output_size=self.output_size,
|
|
||||||
params_dtype=self.params_dtype,
|
|
||||||
weight_loader=self.weight_loader,
|
|
||||||
)
|
|
||||||
if bias:
|
|
||||||
self.bias = Parameter(
|
|
||||||
torch.empty(self.output_size_per_partition, dtype=params_dtype)
|
|
||||||
)
|
|
||||||
set_weight_attrs(
|
|
||||||
self.bias,
|
|
||||||
{
|
|
||||||
"output_dim": 0,
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
|
||||||
if param.data.dtype != loaded_weight.dtype:
|
|
||||||
param.data = torch.empty_like(
|
|
||||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
output_dim = getattr(param, "output_dim", None)
|
|
||||||
param_data = param.data
|
|
||||||
if output_dim is not None:
|
|
||||||
shard_size = param_data.shape[output_dim]
|
|
||||||
start_idx = tp_rank * shard_size
|
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
|
||||||
|
|
||||||
# Special case for loading scales off disk, which often do not
|
|
||||||
# have a shape (such as in the case of AutoFP8).
|
|
||||||
if len(loaded_weight.shape) == 0:
|
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
|
||||||
param_data.copy_(loaded_weight)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
|
||||||
|
|
||||||
# Matrix multiply.
|
|
||||||
assert self.quant_method is not None
|
|
||||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
|
||||||
if self.gather_output:
|
|
||||||
# All-gather across the partitions.
|
|
||||||
output = tensor_model_parallel_all_gather(output_parallel)
|
|
||||||
else:
|
|
||||||
output = output_parallel
|
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
|
||||||
return output, output_bias
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
s = f"in_features={self.input_size}"
|
|
||||||
s += f", output_features={self.output_size_per_partition}"
|
|
||||||
s += f", bias={self.bias is not None}"
|
|
||||||
s += f", tp_size={get_tensor_model_parallel_world_size()}"
|
|
||||||
s += f", gather_output={self.gather_output}"
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
|
||||||
"""Packed linear layers with column parallelism.
|
|
||||||
|
|
||||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
|
||||||
along the output dimension. When the weight matrix is loaded, the
|
|
||||||
different partitions are sharded separately.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_size: input dimension of the linear layer.
|
|
||||||
output_sizes: list of output dimensions of the linear layer.
|
|
||||||
bias: If true, add bias.
|
|
||||||
gather_output: If true, call all-gather on output and make the output
|
|
||||||
available to all GPUs, otherwise, every GPU will have
|
|
||||||
its own output.
|
|
||||||
skip_bias_add: This was added to enable performance optimizations where
|
|
||||||
bias can be fused with other element-wise operations. we
|
|
||||||
skip adding bias but instead return it.
|
|
||||||
params_dtype: Data type for the parameters.
|
|
||||||
quant_config: Quantization configure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
output_sizes: List[int],
|
|
||||||
bias: bool = True,
|
|
||||||
gather_output: bool = False,
|
|
||||||
skip_bias_add: bool = False,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
):
|
|
||||||
self.output_sizes = output_sizes
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
|
||||||
super().__init__(
|
|
||||||
input_size=input_size,
|
|
||||||
output_size=sum(output_sizes),
|
|
||||||
bias=bias,
|
|
||||||
gather_output=gather_output,
|
|
||||||
skip_bias_add=skip_bias_add,
|
|
||||||
params_dtype=params_dtype,
|
|
||||||
quant_config=quant_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def weight_loader(
|
|
||||||
self,
|
|
||||||
param: Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
loaded_shard_id: Optional[int] = None,
|
|
||||||
):
|
|
||||||
if param.data.dtype != loaded_weight.dtype:
|
|
||||||
param.data = torch.empty_like(
|
|
||||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
param_data = param.data
|
|
||||||
output_dim = getattr(param, "output_dim", None)
|
|
||||||
# Special case for AQLM codebooks.
|
|
||||||
is_metadata = getattr(param, "is_metadata", False)
|
|
||||||
# Special case for per-tensor scale to load scalar into fused array.
|
|
||||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
|
||||||
|
|
||||||
if loaded_shard_id is None:
|
|
||||||
# Loaded weight is already fused on disk (qkv/mlp).
|
|
||||||
if output_dim is None:
|
|
||||||
if needs_scalar_to_array is not None:
|
|
||||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
|
||||||
param_data, loaded_weight, 0
|
|
||||||
)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
|
||||||
param_data.copy_(loaded_weight)
|
|
||||||
return
|
|
||||||
current_shard_offset = 0
|
|
||||||
shard_offsets: List[Tuple[int, int, int]] = []
|
|
||||||
for i, output_size in enumerate(self.output_sizes):
|
|
||||||
shard_offsets.append((i, current_shard_offset, output_size))
|
|
||||||
current_shard_offset += output_size
|
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
|
||||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
|
||||||
# Special case for Quantization.
|
|
||||||
# If quantized, we need to adjust the offset and size to account
|
|
||||||
# for the packing.
|
|
||||||
if packed_dim == output_dim:
|
|
||||||
shard_size = shard_size // param.pack_factor
|
|
||||||
shard_offset = shard_offset // param.pack_factor
|
|
||||||
# Special case for Marlin.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
|
||||||
param, shard_size, shard_offset
|
|
||||||
)
|
|
||||||
|
|
||||||
loaded_weight_shard = loaded_weight.narrow(
|
|
||||||
output_dim, shard_offset, shard_size
|
|
||||||
)
|
|
||||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
assert loaded_shard_id < len(self.output_sizes)
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
if output_dim is not None:
|
|
||||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
|
||||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
|
||||||
# Special case for quantization.
|
|
||||||
# If quantized, we need to adjust the offset and size to account
|
|
||||||
# for the packing.
|
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
|
||||||
if packed_dim == output_dim:
|
|
||||||
shard_size = shard_size // param.pack_factor
|
|
||||||
shard_offset = shard_offset // param.pack_factor
|
|
||||||
# Special case for Marlin.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
|
||||||
param, shard_size, shard_offset
|
|
||||||
)
|
|
||||||
|
|
||||||
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
|
|
||||||
if use_bitsandbytes:
|
|
||||||
shard_size = loaded_weight.shape[output_dim]
|
|
||||||
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
|
|
||||||
|
|
||||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
|
||||||
start_idx = tp_rank * shard_size
|
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
|
||||||
# Special case for AQLM codebooks.
|
|
||||||
elif is_metadata:
|
|
||||||
# metadata indicates fixed size concatenated along dim 0
|
|
||||||
shard_size = loaded_weight.shape[0]
|
|
||||||
shard_offset = loaded_shard_id * shard_size
|
|
||||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
|
||||||
|
|
||||||
# Special case for per-tensor scales in fused case.
|
|
||||||
elif needs_scalar_to_array:
|
|
||||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
|
||||||
param_data, loaded_weight, loaded_shard_id
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
ignore_warning = getattr(param, "ignore_warning", False)
|
|
||||||
if not ignore_warning:
|
|
||||||
logger.warning(
|
|
||||||
"Loading a weight without `output_dim` attribute in "
|
|
||||||
"MergedColumnParallelLinear, assume the weight is "
|
|
||||||
"the same for all partitions."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
|
||||||
param_data.copy_(loaded_weight)
|
|
||||||
|
|
||||||
|
|
||||||
class QKVParallelLinear(ColumnParallelLinear):
|
|
||||||
"""Linear layers for the attention's QKV transformation.
|
|
||||||
|
|
||||||
Linear layers for the linear transformation of the query, key, and value
|
|
||||||
vectors in the attention layer. The weight matrix is concatenated along
|
|
||||||
the output dimension. The layer is parallelized along the head dimension.
|
|
||||||
When the number of key/value heads is smaller than the number of query
|
|
||||||
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
|
||||||
be replicated while the query heads are partitioned.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
hidden_size: input hidden state size of the transformer.
|
|
||||||
head_size: size of each attention head.
|
|
||||||
total_num_heads: total number of attention query heads.
|
|
||||||
total_num_kv_heads: total number of attention key/value heads. If
|
|
||||||
None, assume total_num_kv_heads = total_num_heads.
|
|
||||||
bias: If true, add bias.
|
|
||||||
skip_bias_add: This was added to enable performance optimizations where
|
|
||||||
bias can be fused with other element-wise operations. we
|
|
||||||
skip adding bias but instead return it.
|
|
||||||
params_dtype: Data type for the parameters.
|
|
||||||
quant_config: Quantization configure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
head_size: int,
|
|
||||||
total_num_heads: int,
|
|
||||||
total_num_kv_heads: Optional[int] = None,
|
|
||||||
bias: bool = True,
|
|
||||||
skip_bias_add: bool = False,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
):
|
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.head_size = head_size
|
|
||||||
self.total_num_heads = total_num_heads
|
|
||||||
if total_num_kv_heads is None:
|
|
||||||
total_num_kv_heads = total_num_heads
|
|
||||||
self.total_num_kv_heads = total_num_kv_heads
|
|
||||||
# Divide the weight matrix along the last dimension.
|
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
|
||||||
if tp_size >= self.total_num_kv_heads:
|
|
||||||
self.num_kv_heads = 1
|
|
||||||
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
|
|
||||||
else:
|
|
||||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
|
||||||
self.num_kv_head_replicas = 1
|
|
||||||
input_size = self.hidden_size
|
|
||||||
output_size = (
|
|
||||||
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
|
|
||||||
)
|
|
||||||
self.output_sizes = [
|
|
||||||
self.num_heads * self.head_size * tp_size, # q_proj
|
|
||||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
|
||||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
|
||||||
]
|
|
||||||
|
|
||||||
super().__init__(
|
|
||||||
input_size=input_size,
|
|
||||||
output_size=output_size,
|
|
||||||
bias=bias,
|
|
||||||
gather_output=False,
|
|
||||||
skip_bias_add=skip_bias_add,
|
|
||||||
params_dtype=params_dtype,
|
|
||||||
quant_config=quant_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def weight_loader(
|
|
||||||
self,
|
|
||||||
param: Parameter,
|
|
||||||
loaded_weight: torch.Tensor,
|
|
||||||
loaded_shard_id: Optional[str] = None,
|
|
||||||
):
|
|
||||||
if param.data.dtype != loaded_weight.dtype:
|
|
||||||
param.data = torch.empty_like(
|
|
||||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
param_data = param.data
|
|
||||||
output_dim = getattr(param, "output_dim", None)
|
|
||||||
# Special case for AQLM codebooks.
|
|
||||||
is_metadata = getattr(param, "is_metadata", False)
|
|
||||||
|
|
||||||
# Special case for per-tensor scales in fused case.
|
|
||||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
|
||||||
|
|
||||||
if loaded_shard_id is None:
|
|
||||||
# Loaded weight is already fused on disk (qkv/mlp).
|
|
||||||
if output_dim is None:
|
|
||||||
if needs_scalar_to_array is not None:
|
|
||||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
|
||||||
param_data, loaded_weight, 0
|
|
||||||
)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
|
||||||
param_data.copy_(loaded_weight)
|
|
||||||
return
|
|
||||||
shard_offsets = [
|
|
||||||
# (shard_id, shard_offset, shard_size)
|
|
||||||
("q", 0, self.total_num_heads * self.head_size),
|
|
||||||
(
|
|
||||||
"k",
|
|
||||||
self.total_num_heads * self.head_size,
|
|
||||||
self.total_num_kv_heads * self.head_size,
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"v",
|
|
||||||
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
|
||||||
self.total_num_kv_heads * self.head_size,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
|
||||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
|
||||||
# Special case for Quantized Weights.
|
|
||||||
# If quantized, we need to adjust the offset and size to account
|
|
||||||
# for the packing.
|
|
||||||
if packed_dim == output_dim:
|
|
||||||
shard_size = shard_size // param.pack_factor
|
|
||||||
shard_offset = shard_offset // param.pack_factor
|
|
||||||
|
|
||||||
# Special case for Marlin.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
|
||||||
param, shard_size, shard_offset
|
|
||||||
)
|
|
||||||
|
|
||||||
loaded_weight_shard = loaded_weight.narrow(
|
|
||||||
output_dim, shard_offset, shard_size
|
|
||||||
)
|
|
||||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
assert loaded_shard_id in ["q", "k", "v"]
|
|
||||||
|
|
||||||
# If output dim is defined, use the default loading process.
|
|
||||||
if output_dim is not None:
|
|
||||||
if loaded_shard_id == "q":
|
|
||||||
shard_offset = 0
|
|
||||||
shard_size = self.num_heads * self.head_size
|
|
||||||
elif loaded_shard_id == "k":
|
|
||||||
shard_offset = self.num_heads * self.head_size
|
|
||||||
shard_size = self.num_kv_heads * self.head_size
|
|
||||||
elif loaded_shard_id == "v":
|
|
||||||
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
|
|
||||||
shard_size = self.num_kv_heads * self.head_size
|
|
||||||
# Special case for Quantized Weights.
|
|
||||||
# If quantized, we need to adjust the offset and size to account
|
|
||||||
# for the packing.
|
|
||||||
packed_dim = getattr(param, "packed_dim", None)
|
|
||||||
if packed_dim == output_dim:
|
|
||||||
shard_size = shard_size // param.pack_factor
|
|
||||||
shard_offset = shard_offset // param.pack_factor
|
|
||||||
|
|
||||||
# Special case for Marlin.
|
|
||||||
shard_size, shard_offset = adjust_marlin_shard(
|
|
||||||
param, shard_size, shard_offset
|
|
||||||
)
|
|
||||||
|
|
||||||
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
|
|
||||||
if use_bitsandbytes:
|
|
||||||
orig_qkv_offsets = {
|
|
||||||
"q": (0, self.num_heads * self.head_size),
|
|
||||||
"k": (
|
|
||||||
self.num_heads * self.head_size,
|
|
||||||
self.num_kv_heads * self.head_size,
|
|
||||||
),
|
|
||||||
"v": (
|
|
||||||
(self.num_heads + self.num_kv_heads) * self.head_size,
|
|
||||||
self.num_kv_heads * self.head_size,
|
|
||||||
),
|
|
||||||
"total": (
|
|
||||||
(self.num_heads + 2 * self.num_kv_heads) * self.head_size,
|
|
||||||
0,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
shard_size, shard_offset = adjust_bitsandbytes_shard(
|
|
||||||
param, orig_qkv_offsets, loaded_shard_id
|
|
||||||
)
|
|
||||||
|
|
||||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
|
||||||
if loaded_shard_id == "q":
|
|
||||||
shard_id = tp_rank
|
|
||||||
else:
|
|
||||||
shard_id = tp_rank // self.num_kv_head_replicas
|
|
||||||
start_idx = shard_id * shard_size
|
|
||||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
|
||||||
# Special case for for AQLM codebooks.
|
|
||||||
elif is_metadata:
|
|
||||||
# metadata indicates fixed size concatenated along dim 0
|
|
||||||
shard_size = loaded_weight.shape[0]
|
|
||||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
|
||||||
param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
|
|
||||||
# Special case for per-tensor scales in fused case.
|
|
||||||
elif needs_scalar_to_array:
|
|
||||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
|
||||||
param_data, loaded_weight, loaded_shard_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
ignore_warning = getattr(param, "ignore_warning", False)
|
|
||||||
if not ignore_warning:
|
|
||||||
logger.warning(
|
|
||||||
"Loading a weight without `output_dim` attribute in "
|
|
||||||
"QKVParallelLinear, assume the weight is the same "
|
|
||||||
"for all partitions."
|
|
||||||
)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
|
||||||
param_data.copy_(loaded_weight)
|
|
||||||
|
|
||||||
|
|
||||||
class RowParallelLinear(LinearBase):
|
|
||||||
"""Linear layer with row parallelism.
|
|
||||||
|
|
||||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
|
||||||
its first dimension and X along its second dimension as:
|
|
||||||
- -
|
|
||||||
| A_1 |
|
|
||||||
| . |
|
|
||||||
A = | . | X = [X_1, ..., X_p]
|
|
||||||
| . |
|
|
||||||
| A_p |
|
|
||||||
- -
|
|
||||||
Arguments:
|
|
||||||
input_size: first dimension of matrix A.
|
|
||||||
output_size: second dimension of matrix A.
|
|
||||||
bias: If true, add bias. Note that bias is not parallelized.
|
|
||||||
input_is_parallel: If true, we assume that the input is already
|
|
||||||
split across the GPUs and we do not split
|
|
||||||
again.
|
|
||||||
skip_bias_add: This was added to enable performance optimization where
|
|
||||||
bias can be fused with other element-wise operations.
|
|
||||||
We skip adding bias but instead return it.
|
|
||||||
params_dtype: Data type for the parameters.
|
|
||||||
quant_config: Quantization configure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
bias: bool = True,
|
|
||||||
input_is_parallel: bool = True,
|
|
||||||
skip_bias_add: bool = False,
|
|
||||||
params_dtype: Optional[torch.dtype] = None,
|
|
||||||
reduce_results: bool = True,
|
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
|
||||||
):
|
|
||||||
super().__init__(
|
|
||||||
input_size, output_size, skip_bias_add, params_dtype, quant_config
|
|
||||||
)
|
|
||||||
|
|
||||||
self.input_is_parallel = input_is_parallel
|
|
||||||
self.reduce_results = reduce_results
|
|
||||||
|
|
||||||
# Divide the weight matrix along the last dimension.
|
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
|
||||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
|
||||||
assert self.quant_method is not None
|
|
||||||
self.quant_method.create_weights(
|
|
||||||
layer=self,
|
|
||||||
input_size_per_partition=self.input_size_per_partition,
|
|
||||||
output_partition_sizes=[self.output_size],
|
|
||||||
input_size=self.input_size,
|
|
||||||
output_size=self.output_size,
|
|
||||||
params_dtype=self.params_dtype,
|
|
||||||
weight_loader=self.weight_loader,
|
|
||||||
)
|
|
||||||
if not reduce_results and (bias and not skip_bias_add):
|
|
||||||
raise ValueError(
|
|
||||||
"When not reduce the results, adding bias to the "
|
|
||||||
"results can lead to incorrect results"
|
|
||||||
)
|
|
||||||
|
|
||||||
if bias:
|
|
||||||
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
|
||||||
set_weight_attrs(
|
|
||||||
self.bias,
|
|
||||||
{
|
|
||||||
"output_dim": 0,
|
|
||||||
"weight_loader": self.weight_loader,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.register_parameter("bias", None)
|
|
||||||
|
|
||||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
|
||||||
if param.data.dtype != loaded_weight.dtype:
|
|
||||||
param.data = torch.empty_like(
|
|
||||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
param_data = param.data
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
input_dim = getattr(param, "input_dim", None)
|
|
||||||
if input_dim is not None:
|
|
||||||
shard_size = param.data.shape[input_dim]
|
|
||||||
start_idx = tp_rank * shard_size
|
|
||||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
|
||||||
|
|
||||||
# Special case for loading scales off disk, which often do not
|
|
||||||
# have a shape (such as in the case of AutoFP8).
|
|
||||||
if len(loaded_weight.shape) == 0:
|
|
||||||
loaded_weight = loaded_weight.reshape(1)
|
|
||||||
|
|
||||||
assert param_data.shape == loaded_weight.shape
|
|
||||||
param_data.copy_(loaded_weight)
|
|
||||||
|
|
||||||
def forward(self, input_):
|
|
||||||
# Set up backprop all-reduce.
|
|
||||||
if self.input_is_parallel:
|
|
||||||
input_parallel = input_
|
|
||||||
else:
|
|
||||||
tp_rank = get_tensor_model_parallel_rank()
|
|
||||||
splitted_input = split_tensor_along_last_dim(
|
|
||||||
input_, num_partitions=self.tp_size
|
|
||||||
)
|
|
||||||
input_parallel = splitted_input[tp_rank].contiguous()
|
|
||||||
|
|
||||||
# Matrix multiply.
|
|
||||||
assert self.quant_method is not None
|
|
||||||
output_parallel = self.quant_method.apply(self, input_parallel)
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
|
||||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
|
||||||
else:
|
|
||||||
output_ = output_parallel
|
|
||||||
|
|
||||||
if not self.skip_bias_add:
|
|
||||||
output = output_ + self.bias if self.bias is not None else output_
|
|
||||||
output_bias = None
|
|
||||||
else:
|
|
||||||
output = output_
|
|
||||||
output_bias = self.bias
|
|
||||||
return output, output_bias
|
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
|
||||||
s = f"input_features={self.input_size_per_partition}"
|
|
||||||
s += f", output_features={self.output_size}"
|
|
||||||
s += f", bias={self.bias is not None}"
|
|
||||||
s += f", tp_size={self.tp_size}"
|
|
||||||
s += f", reduce_results={self.reduce_results}"
|
|
||||||
return s
|
|
||||||
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
Memory-efficient attention for prefill.
|
||||||
|
It supporst page size = 1.
|
||||||
|
"""
|
||||||
|
|
||||||
# Adapted from
|
# Adapted from
|
||||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
||||||
import torch
|
import torch
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# temporarily adapted from vLLM
|
|
||||||
# FIXME: in progress of refactoring the model loader
|
|
||||||
|
|
||||||
from typing import Dict, Type
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
|
||||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
|
||||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
|
||||||
CompressedTensorsConfig,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
|
|
||||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
|
||||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
|
||||||
|
|
||||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|
||||||
"aqlm": AQLMConfig,
|
|
||||||
"awq": AWQConfig,
|
|
||||||
"deepspeedfp": DeepSpeedFPConfig,
|
|
||||||
"fp8": Fp8Config,
|
|
||||||
# The order of gptq methods is important for config.py iteration over
|
|
||||||
# override_quantization_method(..)
|
|
||||||
"marlin": MarlinConfig,
|
|
||||||
"gptq_marlin_24": GPTQMarlin24Config,
|
|
||||||
"gptq_marlin": GPTQMarlinConfig,
|
|
||||||
"gptq": GPTQConfig,
|
|
||||||
"squeezellm": SqueezeLLMConfig,
|
|
||||||
"compressed-tensors": CompressedTensorsConfig,
|
|
||||||
"bitsandbytes": BitsAndBytesConfig,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|
||||||
if quantization not in QUANTIZATION_METHODS:
|
|
||||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
|
||||||
return QUANTIZATION_METHODS[quantization]
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"QuantizationConfig",
|
|
||||||
"get_quantization_config",
|
|
||||||
"QUANTIZATION_METHODS",
|
|
||||||
]
|
|
||||||
@@ -1,677 +0,0 @@
|
|||||||
"""
|
|
||||||
Copyright 2023-2024 SGLang Team
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py
|
|
||||||
# FIXME refactor in progress
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn import Module
|
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase, fused_moe
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
|
||||||
QuantizationConfig,
|
|
||||||
QuantizeMethodBase,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
|
||||||
GPTQ_MARLIN_MAX_PARALLEL,
|
|
||||||
GPTQ_MARLIN_MIN_THREAD_N,
|
|
||||||
GPTQMarlinState,
|
|
||||||
marlin_permute_scales,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import pack_fp8_to_int32
|
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.utils import print_warning_once
|
|
||||||
|
|
||||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
|
||||||
|
|
||||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def cutlass_fp8_supported() -> bool:
|
|
||||||
capability = current_platform.get_device_capability()
|
|
||||||
capability = capability[0] * 10 + capability[1]
|
|
||||||
|
|
||||||
return ops.cutlass_scaled_mm_supports_fp8(capability)
|
|
||||||
|
|
||||||
|
|
||||||
class Fp8Config(QuantizationConfig):
|
|
||||||
"""Config class for FP8."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
is_checkpoint_fp8_serialized: bool = False,
|
|
||||||
activation_scheme: str = "dynamic",
|
|
||||||
) -> None:
|
|
||||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
|
||||||
if is_checkpoint_fp8_serialized:
|
|
||||||
logger.warning(
|
|
||||||
"Detected fp8 checkpoint. Please note that the "
|
|
||||||
"format is experimental and subject to change."
|
|
||||||
)
|
|
||||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
|
||||||
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
|
|
||||||
self.activation_scheme = activation_scheme
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_name(cls) -> str:
|
|
||||||
return "fp8"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
|
||||||
return [torch.bfloat16, torch.half]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_min_capability(cls) -> int:
|
|
||||||
return 80
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_config_filenames(cls) -> List[str]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
|
||||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
|
||||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
|
||||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
|
||||||
return cls(
|
|
||||||
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
|
||||||
activation_scheme=activation_scheme,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_quant_method(
|
|
||||||
self, layer: torch.nn.Module
|
|
||||||
) -> Optional["QuantizeMethodBase"]:
|
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
|
||||||
return Fp8LinearMethod(self)
|
|
||||||
elif isinstance(layer, FusedMoE):
|
|
||||||
return Fp8MoEMethod(self)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class Fp8LinearMethod(LinearMethodBase):
|
|
||||||
"""Linear method for FP8.
|
|
||||||
Supports loading FP8 checkpoints with static weight scale and
|
|
||||||
dynamic/static activation scale.
|
|
||||||
|
|
||||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
|
||||||
activation scaling. The weight scaling factor will be initialized after
|
|
||||||
the model weights are loaded.
|
|
||||||
|
|
||||||
Limitations:
|
|
||||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
|
||||||
2. Only support float8_e4m3fn data type due to the limitation of
|
|
||||||
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
quant_config: The quantization config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, quant_config: Fp8Config):
|
|
||||||
self.quant_config = quant_config
|
|
||||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
|
||||||
|
|
||||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
|
||||||
# kernel for fast weight-only FP8 quantization
|
|
||||||
capability = current_platform.get_device_capability()
|
|
||||||
capability = capability[0] * 10 + capability[1]
|
|
||||||
self.use_marlin = capability < 89
|
|
||||||
|
|
||||||
def _create_scale_param(
|
|
||||||
self,
|
|
||||||
scale_name: str,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
**extra_weight_attrs,
|
|
||||||
) -> None:
|
|
||||||
scale = Parameter(
|
|
||||||
torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
scale[:] = torch.finfo(torch.float8_e4m3fn).min
|
|
||||||
layer.register_parameter(scale_name, scale)
|
|
||||||
set_weight_attrs(
|
|
||||||
scale,
|
|
||||||
{
|
|
||||||
**extra_weight_attrs,
|
|
||||||
"needs_scalar_to_array": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_weights(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
input_size_per_partition: int,
|
|
||||||
output_partition_sizes: List[int],
|
|
||||||
input_size: int,
|
|
||||||
output_size: int,
|
|
||||||
params_dtype: torch.dtype,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
):
|
|
||||||
del input_size, output_size
|
|
||||||
output_size_per_partition = sum(output_partition_sizes)
|
|
||||||
|
|
||||||
layer.process_after_load = True
|
|
||||||
layer.logical_widths = output_partition_sizes
|
|
||||||
|
|
||||||
layer.input_size_per_partition = input_size_per_partition
|
|
||||||
layer.output_size_per_partition = output_size_per_partition
|
|
||||||
layer.orig_dtype = params_dtype
|
|
||||||
|
|
||||||
# WEIGHT
|
|
||||||
# weight_dtype = (torch.float8_e4m3fn
|
|
||||||
# if self.quant_config.is_checkpoint_fp8_serialized else
|
|
||||||
# params_dtype)
|
|
||||||
weight_dtype = torch.float8_e4m3fn
|
|
||||||
weight = Parameter(
|
|
||||||
torch.empty(
|
|
||||||
output_size_per_partition, input_size_per_partition, dtype=weight_dtype
|
|
||||||
),
|
|
||||||
requires_grad=False,
|
|
||||||
)
|
|
||||||
layer.register_parameter("weight", weight)
|
|
||||||
set_weight_attrs(
|
|
||||||
weight,
|
|
||||||
{
|
|
||||||
**extra_weight_attrs,
|
|
||||||
"input_dim": 1,
|
|
||||||
"output_dim": 0,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# If checkpoint is serialized fp8, load them.
|
|
||||||
# Otherwise, wait until process_weights_after_loading.
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
|
||||||
# WEIGHT SCALE
|
|
||||||
self._create_scale_param(
|
|
||||||
scale_name="weight_scale",
|
|
||||||
layer=layer,
|
|
||||||
output_partition_sizes=output_partition_sizes,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# INPUT ACTIVATION SCALE
|
|
||||||
if self.quant_config.activation_scheme == "static":
|
|
||||||
self._create_scale_param(
|
|
||||||
scale_name="input_scale",
|
|
||||||
layer=layer,
|
|
||||||
output_partition_sizes=output_partition_sizes,
|
|
||||||
**extra_weight_attrs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# For GPUs without FP8 hardware support, we use Marlin for fast
|
|
||||||
# fused dequantization
|
|
||||||
if self.use_marlin:
|
|
||||||
layer.marlin_state = GPTQMarlinState.REPACK
|
|
||||||
|
|
||||||
def prepare_layer_for_marlin(self, layer: Module) -> None:
|
|
||||||
print_warning_once(
|
|
||||||
"Your GPU does not have native support for FP8 computation but "
|
|
||||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
|
||||||
"be used leveraging the Marlin kernel. This may degrade "
|
|
||||||
"performance for compute-heavy workloads."
|
|
||||||
)
|
|
||||||
|
|
||||||
part_size_n = layer.output_size_per_partition
|
|
||||||
part_size_k = layer.input_size_per_partition
|
|
||||||
|
|
||||||
assert layer.marlin_state == GPTQMarlinState.REPACK
|
|
||||||
layer.marlin_state = GPTQMarlinState.READY
|
|
||||||
|
|
||||||
device = layer.weight.device
|
|
||||||
|
|
||||||
# WEIGHTS
|
|
||||||
# Repack weights to gptq format (packed int32 elements)
|
|
||||||
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
|
|
||||||
|
|
||||||
# Repack weights to marlin format
|
|
||||||
marlin_qweight = ops.gptq_marlin_repack(
|
|
||||||
b_q_weight=packed_gptq_qweight,
|
|
||||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
|
||||||
size_k=part_size_k,
|
|
||||||
size_n=part_size_n,
|
|
||||||
num_bits=8,
|
|
||||||
)
|
|
||||||
layer.weight = Parameter(marlin_qweight, requires_grad=False)
|
|
||||||
|
|
||||||
# WEIGHT SCALES
|
|
||||||
# Currently Marlin doesn't support per-tensor scales, so we
|
|
||||||
# expand it to channelwise
|
|
||||||
scales = (
|
|
||||||
layer.weight_scale.repeat(1, part_size_n).to(layer.orig_dtype).to(device)
|
|
||||||
)
|
|
||||||
# Permute scales
|
|
||||||
marlin_scales = marlin_permute_scales(
|
|
||||||
s=scales,
|
|
||||||
size_k=part_size_k,
|
|
||||||
size_n=part_size_n,
|
|
||||||
group_size=-1,
|
|
||||||
num_bits=8,
|
|
||||||
)
|
|
||||||
layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
|
|
||||||
|
|
||||||
# Allocate marlin workspace
|
|
||||||
max_workspace_size = (
|
|
||||||
part_size_n // GPTQ_MARLIN_MIN_THREAD_N
|
|
||||||
) * GPTQ_MARLIN_MAX_PARALLEL
|
|
||||||
workspace = torch.zeros(
|
|
||||||
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
|
|
||||||
)
|
|
||||||
|
|
||||||
layer.workspace = workspace
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
    if not hasattr(layer, "process_after_load") or not layer.process_after_load:
        return

    # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
    if not self.quant_config.is_checkpoint_fp8_serialized:
        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
        layer.weight = Parameter(qweight.t(), requires_grad=False)
        layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        layer.logical_widths = None
        layer.input_scale = None
        if self.use_marlin:
            self.prepare_layer_for_marlin(layer)
        return

    # If checkpoint is fp8, requantize the separately quantized logical
    # weights into a single fp8 weight with a single weight scale.
    else:
        # WEIGHT_SCALE / WEIGHT
        # Loop over logical weights, requantizing with a single scale.
        max_w_scale = layer.weight_scale.max()

        # QKV / MLP is fused in the on-disk checkpoint if any of the
        # weight scales are still set to the default, since we initialize
        # N weight scales for N shards but only load 1 weight scale
        # from disk in that case. As a result, we skip dequant -> requant
        # since we already have quantized QKV together.
        # Sample model with a fused checkpoint:
        #   * nm-testing/Phi-3-mini-128k-instruct-FP8
        unfused_module_in_checkpoint = (
            layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
        )

        if unfused_module_in_checkpoint:
            start = 0
            for idx, logical_width in enumerate(layer.logical_widths):
                end = start + logical_width
                weight_dq = per_tensor_dequantize(
                    layer.weight[start:end, :], layer.weight_scale[idx]
                )

                layer.weight[start:end, :] = per_tensor_quantize(
                    weight_dq, layer.weight_scale.max()
                )
                start = end
        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # WEIGHT
        # Transpose weight for passing to torch._scaled_mm
        weight = layer.weight
        layer.weight = Parameter(weight.t(), requires_grad=False)

        # INPUT ACTIVATION SCALE
        # Dynamic: set to None (required input to ops.scaled_fp8_quant).
        # Static: set to max of the input_scales (since they are equal).
        if self.quant_config.activation_scheme == "dynamic":
            layer.input_scale = None
        elif self.quant_config.activation_scheme == "static":
            layer.input_scale = Parameter(
                layer.input_scale.max(), requires_grad=False
            )
        else:
            raise ValueError(
                f"Unknown scheme {self.quant_config.activation_scheme}"
            )

        if self.use_marlin:
            self.prepare_layer_for_marlin(layer)

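# Illustrative sketch, not part of this module: the dequant -> requant loop
# above folds per-shard weight scales into one shared scale. The same idea in
# isolation, using the per_tensor_quantize / per_tensor_dequantize helpers
# defined at the bottom of this file (shard boundaries are made up; assumes an
# fp8-capable PyTorch build):
def _sketch_requantize_with_shared_scale():
    import torch

    weight = torch.randn(4, 8).to(torch.float8_e4m3fn)  # stand-in fp8 weight
    scales = torch.tensor([0.01, 0.03])  # one scale per logical shard
    max_scale = scales.max()
    for idx, (start, end) in enumerate([(0, 2), (2, 4)]):
        w_dq = per_tensor_dequantize(weight[start:end, :], scales[idx])
        weight[start:end, :] = per_tensor_quantize(w_dq, max_scale)
    return weight, max_scale
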
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

    if self.use_marlin:
        # For GPUs that lack FP8 hardware support, we can leverage the
        # Marlin kernel for fast weight-only FP8 quantization.

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (layer.output_size_per_partition,)

        output = ops.fp8_marlin_gemm(
            a=reshaped_x,
            b_q_weight=layer.weight,
            b_scales=layer.weight_scale,
            workspace=layer.workspace,
            num_bits=8,
            size_m=reshaped_x.shape[0],
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)

    else:
        # ops.scaled_fp8_quant supports both dynamic and static quant.
        # If dynamic, layer.input_scale is None and x_scale is computed from x.
        # If static, layer.input_scale is a scalar and x_scale is input_scale.

        if bias is None and self.cutlass_fp8_supported:
            qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)

            # Fused GEMM_DQ
            output = ops.cutlass_scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
            )

        else:
            qinput, x_scale = ops.scaled_fp8_quant(
                x, layer.input_scale, batch_dim_padding=17
            )

            # Fused GEMM_DQ -- note we padded the input above because
            # torch._scaled_mm is more performant for matrices with a
            # batch dimension > 16. Note that this could change
            # in the future.
            output, _ = torch._scaled_mm(
                qinput,
                layer.weight,
                out_dtype=x.dtype,
                scale_a=x_scale,
                scale_b=layer.weight_scale,
                bias=bias,
            )

        return torch.narrow(output, 0, 0, x.shape[0])

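# Illustrative sketch, not part of this module: why the non-cutlass path above
# pads the batch and then narrows the output. torch._scaled_mm is faster when
# the M (batch) dimension is greater than 16, so the input is padded to at
# least 17 rows and the extra rows are dropped afterwards.
def _sketch_narrow_padded_output():
    import torch

    original_batch = 5
    padded_out = torch.zeros(17, 3)  # stand-in for the padded GEMM output
    out = torch.narrow(padded_out, 0, 0, original_batch)  # keep rows [0, 5)
    assert out.shape == (5, 3)
    return out
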
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.

Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.

Args:
    quant_config: The quantization config.
"""

def __init__(self, quant_config: Fp8Config):
    self.quant_config = quant_config

def create_weights(
    self,
    layer: Module,
    num_experts: int,
    hidden_size: int,
    intermediate_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):

    layer.process_after_load = True

    if self.quant_config.is_checkpoint_fp8_serialized:
        params_dtype = torch.float8_e4m3fn

    # WEIGHTS
    w13_weight = torch.nn.Parameter(
        torch.empty(
            num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
        ),
        requires_grad=False,
    )
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    w2_weight = torch.nn.Parameter(
        torch.empty(
            num_experts, hidden_size, intermediate_size, dtype=params_dtype
        ),
        requires_grad=False,
    )
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)

    # WEIGHT_SCALES
    # Allocate 2 scales for w1 and w3 respectively.
    # They will be combined to a single scale after weight loading.
    w13_scale = torch.nn.Parameter(
        torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
    )
    layer.register_parameter("w13_scale", w13_scale)

    w2_scale = torch.nn.Parameter(
        torch.ones(num_experts, dtype=torch.float32), requires_grad=False
    )
    layer.register_parameter("w2_scale", w2_scale)

    # If loading an fp8 checkpoint, pass the weight loaders.
    # If loading an fp16 checkpoint, do not (we will quantize in
    # process_weights_after_loading()).
    if self.quant_config.is_checkpoint_fp8_serialized:
        set_weight_attrs(w13_scale, extra_weight_attrs)
        set_weight_attrs(w2_scale, extra_weight_attrs)

    # INPUT_SCALES
    if self.quant_config.activation_scheme == "static":
        if not self.quant_config.is_checkpoint_fp8_serialized:
            raise ValueError(
                "Found static activation scheme for checkpoint that "
                "was not serialized fp8."
            )

        a13_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("a13_scale", a13_scale)
        set_weight_attrs(a13_scale, extra_weight_attrs)

        a2_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("a2_scale", a2_scale)
        set_weight_attrs(a2_scale, extra_weight_attrs)
    else:
        layer.a13_scale = None
        layer.a2_scale = None

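# Illustrative sketch, not part of this module: the shapes created above for a
# hypothetical MoE layer with 8 experts, hidden_size=4096 and
# intermediate_size=14336 (per partition):
#   w13_weight: (8, 2 * 14336, 4096)  gate and up projections stacked
#   w2_weight:  (8, 4096, 14336)      down projection
#   w13_scale:  (8, 2)                one scale each for w1 and w3, merged later
#   w2_scale:   (8,)                  one scale per expert
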
def process_weights_after_loading(self, layer: Module) -> None:
    if not hasattr(layer, "process_after_load") or not layer.process_after_load:
        return

    # If checkpoint is fp16, quantize in place.
    if not self.quant_config.is_checkpoint_fp8_serialized:
        w13_weight = torch.empty_like(
            layer.w13_weight.data, dtype=torch.float8_e4m3fn
        )
        w2_weight = torch.empty_like(
            layer.w2_weight.data, dtype=torch.float8_e4m3fn
        )

        # Re-initialize w13_scale because we directly quantize the
        # merged w13 weights and generate a single scaling factor.
        layer.w13_scale = torch.nn.Parameter(
            torch.ones(
                layer.num_experts, dtype=torch.float32, device=w13_weight.device
            ),
            requires_grad=False,
        )
        for expert in range(layer.num_experts):
            w13_weight[expert, :, :], layer.w13_scale[expert] = (
                ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
            )
            w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
                layer.w2_weight.data[expert, :, :]
            )
        layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
        layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
        return

    # If checkpoint is fp8, we need to handle the fact that the
    # MoE kernels require a single activation scale and a single weight
    # scale for w13 per expert.
    else:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.quant_config.activation_scheme == "static":
            if layer.a13_scale is None or layer.a2_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            if not all_close_1d(layer.a13_scale) or not all_close_1d(
                layer.a2_scale
            ):
                print_warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )
            layer.a13_scale = torch.nn.Parameter(
                layer.a13_scale.max(), requires_grad=False
            )
            layer.a2_scale = torch.nn.Parameter(
                layer.a2_scale.max(), requires_grad=False
            )

        # Fp8 moe kernels need a single weight scale for w13 per expert.
        # We take the max, then dequant and requant each expert.
        assert layer.w13_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start : start + shard_size, :],
                    layer.w13_scale[expert_id][shard_id],
                )
                layer.w13_weight[expert_id][start : start + shard_size, :] = (
                    per_tensor_quantize(dq_weight, max_w13_scales[expert_id])
                )
                start += shard_size

        layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
        return

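# Illustrative sketch, not part of this module: collapsing the two w13 scales
# (one for w1, one for w3) into a single per-expert scale, as done above with
# `.max(dim=1)`:
def _sketch_merge_w13_scales():
    import torch

    w13_scale = torch.tensor(
        [
            [0.010, 0.030],  # expert 0: scales for w1 and w3
            [0.020, 0.015],  # expert 1
        ]
    )
    max_w13_scales = w13_scale.max(dim=1).values  # tensor([0.030, 0.020])
    return max_w13_scales
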
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool = True,
) -> torch.Tensor:

    return fused_moe(
        x,
        layer.w13_weight,
        layer.w2_weight,
        router_logits,
        top_k,
        renormalize=renormalize,
        inplace=True,
        use_fp8=True,
        w1_scale=layer.w13_scale,
        w2_scale=layer.w2_scale,
        a1_scale=layer.a13_scale,
        a2_scale=layer.a2_scale,
    )


# FIXME: not used
class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints."""

def __init__(self, quant_config: Fp8Config):
    self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module):
    """Create "weight" (aka kv_scale) for an attention layer.

    Args:
        layer: The layer that is using the QuantizeMethodBase factory.
    """
    # Initialize the KV cache scale to 1.0 as the default value.
    # If the kv_scale appears in the checkpoint, it will be
    # overwritten when loading weights.
    layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
    raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

def process_weights_after_loading(self, layer: Module) -> None:
    # If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
    # regardless of whether the kv-scale is available in the checkpoint.
    if layer.kv_cache_dtype != "auto":
        kv_scale = layer.kv_scale.to("cpu").tolist()
        if not isinstance(kv_scale, float):
            raise ValueError(
                "Only support per-tensor scaling factor for fp8 KV cache"
            )
        layer._kv_scale = kv_scale
        if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
            print_warning_once(
                "Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
                "cause accuracy issues. Please make sure the kv-cache scaling "
                "factor is available in the fp8 checkpoint."
            )
    del layer.kv_scale

def per_tensor_quantize(
    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
    finfo = torch.finfo(torch.float8_e4m3fn)
    qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
    return qweight.to(torch.float8_e4m3fn)


def per_tensor_dequantize(
    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

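# Illustrative sketch, not part of this module: a round trip through the
# helpers above (assumes an fp8-capable PyTorch build). The reconstruction
# error is small but nonzero because fp8 is lossy.
def _sketch_fp8_round_trip():
    import torch

    w = torch.randn(16, 32, dtype=torch.float16)
    scale = w.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    w_fp8 = per_tensor_quantize(w, scale)
    w_back = per_tensor_dequantize(w_fp8, scale)
    return (w.float() - w_back.float()).abs().max()
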
@@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state
 from torch import nn

 from sglang.global_config import global_config
+from sglang.srt.layers.decode_attention import decode_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.model_executor.model_runner import global_server_args_dict

@@ -95,7 +95,7 @@ class RadixAttention(nn.Module):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

-        token_attention_fwd(
+        decode_attention_fwd(
             q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),