Clean up the comments and names under python/sglang/srt/layers (#1047)
This commit is contained in:
@@ -11,6 +11,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Fused operators for activation layers."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Memory-efficient attention for decoding.
|
||||
"""
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py
|
||||
@@ -194,7 +198,7 @@ def _fwd_kernel_stage2(
|
||||
tl.store(out_ptrs, acc)
|
||||
|
||||
|
||||
def _token_att_m_fwd(
|
||||
def _decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
att_out,
|
||||
@@ -254,7 +258,7 @@ def _token_att_m_fwd(
|
||||
)
|
||||
|
||||
|
||||
def _token_softmax_reducev_fwd(
|
||||
def _decode_softmax_reducev_fwd(
|
||||
logics,
|
||||
v_buffer,
|
||||
o,
|
||||
@@ -292,7 +296,7 @@ def _token_softmax_reducev_fwd(
|
||||
)
|
||||
|
||||
|
||||
def token_attention_fwd(
|
||||
def decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
@@ -312,7 +316,7 @@ def token_attention_fwd(
|
||||
(q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
|
||||
)
|
||||
|
||||
_token_att_m_fwd(
|
||||
_decode_att_m_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
att_m,
|
||||
@@ -324,7 +328,7 @@ def token_attention_fwd(
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
)
|
||||
_token_softmax_reducev_fwd(
|
||||
_decode_softmax_reducev_fwd(
|
||||
att_m,
|
||||
v_buffer,
|
||||
o,
|
||||
@@ -13,11 +13,16 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Memory-efficient attention for prefill.
|
||||
It supporst page size = 1 and prefill with KV cache (i.e. extend).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
|
||||
from sglang.srt.layers.prefill_attention import context_attention_fwd
|
||||
|
||||
CUDA_CAPABILITY = torch.cuda.get_device_capability()
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""Fused operators for normalization layers."""
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,884 +0,0 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# temporarily adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/linear.py
|
||||
# FIXME: refactor the linear abstraction
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def adjust_marlin_shard(param, shard_size, shard_offset):
|
||||
marlin_tile_size = getattr(param, "marlin_tile_size", None)
|
||||
if marlin_tile_size is None:
|
||||
return shard_size, shard_offset
|
||||
|
||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||
|
||||
|
||||
def adjust_bitsandbytes_shard(
|
||||
param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
|
||||
) -> Tuple[int, int]:
|
||||
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
|
||||
|
||||
total, _ = qkv_offsets["total"]
|
||||
orig_offset, orig_size = qkv_offsets[loaded_shard_id]
|
||||
|
||||
quantized_total = param.data.shape[0]
|
||||
quantized_offset = orig_offset * quantized_total // total
|
||||
quantized_size = orig_size * quantized_total // total
|
||||
|
||||
return quantized_size, quantized_offset
|
||||
|
||||
|
||||
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
|
||||
"""For fused modules (QKV and MLP) we have an array of length
|
||||
N that holds 1 scale for each "logical" matrix. So the param
|
||||
is an array of length N. The loaded_weight corresponds to
|
||||
one of the shards on disk. Here, we slice the param based on
|
||||
the shard_id for loading.
|
||||
"""
|
||||
qkv_idxs = {"q": 0, "k": 1, "v": 2}
|
||||
|
||||
if isinstance(shard_id, str):
|
||||
shard_id = qkv_idxs[shard_id]
|
||||
elif not isinstance(shard_id, int):
|
||||
raise ValueError(f"Unknown Shard Id {shard_id}")
|
||||
|
||||
# AutoFP8 scales do not have a shape
|
||||
# compressed-tensors scales do have a shape
|
||||
if len(loaded_weight.shape) != 0:
|
||||
assert loaded_weight.shape[0] == 1
|
||||
loaded_weight = loaded_weight[0]
|
||||
|
||||
return param[shard_id], loaded_weight
|
||||
|
||||
|
||||
class LinearMethodBase(QuantizeMethodBase):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
"""Create weights for a linear layer.
|
||||
The weights will be set as attributes of the layer.
|
||||
|
||||
Args:
|
||||
layer: The layer that is using the LinearMethodBase factory.
|
||||
input_size_per_partition: Size of the weight input dim on rank X.
|
||||
output_partition_sizes: Sizes of the output dim of each logical
|
||||
weight on rank X. E.g., output_partition_sizes for QKVLinear
|
||||
is a list contains the width of Wq, Wk, Wv on rank X.
|
||||
input_size: Size of the input dim of the weight across all ranks.
|
||||
output_size: Size of the output dim of the weight across all ranks.
|
||||
params_dtype: Datatype of the parameters.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnquantizedLinearMethod(LinearMethodBase):
|
||||
"""Linear method without quantization.
|
||||
|
||||
Args:
|
||||
separate_bias_add: If true, add bias separately after matrix
|
||||
multiplication.
|
||||
"""
|
||||
|
||||
def __init__(self, separate_bias_add: bool = False):
|
||||
self.separate_bias_add = separate_bias_add
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
if self.separate_bias_add:
|
||||
if bias is not None:
|
||||
return F.linear(x, weight) + bias
|
||||
return F.linear(x, weight)
|
||||
return F.linear(x, weight, bias)
|
||||
|
||||
|
||||
class LinearBase(torch.nn.Module):
|
||||
"""Base linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_size: output dimension of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ReplicatedLinear(LinearBase):
|
||||
"""Replicated linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_size: output dimension of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__(
|
||||
input_size, output_size, skip_bias_add, params_dtype, quant_config
|
||||
)
|
||||
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
self,
|
||||
self.input_size,
|
||||
[self.output_size],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=self.params_dtype)
|
||||
)
|
||||
set_weight_attrs(self.bias, {"output_dim": 0})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
output = self.quant_method.apply(self, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"in_features={self.input_size}"
|
||||
s += f", output_features={self.output_size}"
|
||||
s += f", bias={self.bias is not None}"
|
||||
return s
|
||||
|
||||
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Args:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias.
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
output_sizes: list of output sizes packed into one output, like for QKV
|
||||
the list would be size 3.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[List[int]] = None,
|
||||
):
|
||||
super().__init__(
|
||||
input_size, output_size, skip_bias_add, params_dtype, quant_config
|
||||
)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert self.quant_method is not None
|
||||
self.output_size_per_partition = divide(self.output_size, tp_size)
|
||||
self.output_partition_sizes = [self.output_size_per_partition]
|
||||
# If QKV or MergedColumn, use output size of each partition.
|
||||
if hasattr(self, "output_sizes"):
|
||||
self.output_partition_sizes = [
|
||||
divide(output_size, tp_size) for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size,
|
||||
output_partition_sizes=self.output_partition_sizes,
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition, dtype=params_dtype)
|
||||
)
|
||||
set_weight_attrs(
|
||||
self.bias,
|
||||
{
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
if param.data.dtype != loaded_weight.dtype:
|
||||
param.data = torch.empty_like(
|
||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
||||
)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
param_data = param.data
|
||||
if output_dim is not None:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"in_features={self.input_size}"
|
||||
s += f", output_features={self.output_size_per_partition}"
|
||||
s += f", bias={self.bias is not None}"
|
||||
s += f", tp_size={get_tensor_model_parallel_world_size()}"
|
||||
s += f", gather_output={self.gather_output}"
|
||||
return s
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
along the output dimension. When the weight matrix is loaded, the
|
||||
different partitions are sharded separately.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_sizes: list of output dimensions of the linear layer.
|
||||
bias: If true, add bias.
|
||||
gather_output: If true, call all-gather on output and make the output
|
||||
available to all GPUs, otherwise, every GPU will have
|
||||
its own output.
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: List[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||
super().__init__(
|
||||
input_size=input_size,
|
||||
output_size=sum(output_sizes),
|
||||
bias=bias,
|
||||
gather_output=gather_output,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None,
|
||||
):
|
||||
if param.data.dtype != loaded_weight.dtype:
|
||||
param.data = torch.empty_like(
|
||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
||||
)
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
# Special case for per-tensor scale to load scalar into fused array.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
if output_dim is None:
|
||||
if needs_scalar_to_array is not None:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, 0
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
current_shard_offset = 0
|
||||
shard_offsets: List[Tuple[int, int, int]] = []
|
||||
for i, output_size in enumerate(self.output_sizes):
|
||||
shard_offsets.append((i, current_shard_offset, output_size))
|
||||
current_shard_offset += output_size
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset
|
||||
)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size
|
||||
)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if output_dim is not None:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
# Special case for quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset
|
||||
)
|
||||
|
||||
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
|
||||
if use_bitsandbytes:
|
||||
shard_size = loaded_weight.shape[output_dim]
|
||||
shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_offset = loaded_shard_id * shard_size
|
||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
||||
|
||||
# Special case for per-tensor scales in fused case.
|
||||
elif needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, loaded_shard_id
|
||||
)
|
||||
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"MergedColumnParallelLinear, assume the weight is "
|
||||
"the same for all partitions."
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layers for the attention's QKV transformation.
|
||||
|
||||
Linear layers for the linear transformation of the query, key, and value
|
||||
vectors in the attention layer. The weight matrix is concatenated along
|
||||
the output dimension. The layer is parallelized along the head dimension.
|
||||
When the number of key/value heads is smaller than the number of query
|
||||
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
||||
be replicated while the query heads are partitioned.
|
||||
|
||||
Args:
|
||||
hidden_size: input hidden state size of the transformer.
|
||||
head_size: size of each attention head.
|
||||
total_num_heads: total number of attention query heads.
|
||||
total_num_kv_heads: total number of attention key/value heads. If
|
||||
None, assume total_num_kv_heads = total_num_heads.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
self.num_kv_heads = 1
|
||||
self.num_kv_head_replicas = divide(tp_size, self.total_num_kv_heads)
|
||||
else:
|
||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||
self.num_kv_head_replicas = 1
|
||||
input_size = self.hidden_size
|
||||
output_size = (
|
||||
(self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
|
||||
)
|
||||
self.output_sizes = [
|
||||
self.num_heads * self.head_size * tp_size, # q_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
||||
]
|
||||
|
||||
super().__init__(
|
||||
input_size=input_size,
|
||||
output_size=output_size,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def weight_loader(
|
||||
self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None,
|
||||
):
|
||||
if param.data.dtype != loaded_weight.dtype:
|
||||
param.data = torch.empty_like(
|
||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
||||
)
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
|
||||
# Special case for per-tensor scales in fused case.
|
||||
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already fused on disk (qkv/mlp).
|
||||
if output_dim is None:
|
||||
if needs_scalar_to_array is not None:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, 0
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("q", 0, self.total_num_heads * self.head_size),
|
||||
(
|
||||
"k",
|
||||
self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
),
|
||||
(
|
||||
"v",
|
||||
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size,
|
||||
),
|
||||
]
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset
|
||||
)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size
|
||||
)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
|
||||
# If output dim is defined, use the default loading process.
|
||||
if output_dim is not None:
|
||||
if loaded_shard_id == "q":
|
||||
shard_offset = 0
|
||||
shard_size = self.num_heads * self.head_size
|
||||
elif loaded_shard_id == "k":
|
||||
shard_offset = self.num_heads * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
elif loaded_shard_id == "v":
|
||||
shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset
|
||||
)
|
||||
|
||||
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
|
||||
if use_bitsandbytes:
|
||||
orig_qkv_offsets = {
|
||||
"q": (0, self.num_heads * self.head_size),
|
||||
"k": (
|
||||
self.num_heads * self.head_size,
|
||||
self.num_kv_heads * self.head_size,
|
||||
),
|
||||
"v": (
|
||||
(self.num_heads + self.num_kv_heads) * self.head_size,
|
||||
self.num_kv_heads * self.head_size,
|
||||
),
|
||||
"total": (
|
||||
(self.num_heads + 2 * self.num_kv_heads) * self.head_size,
|
||||
0,
|
||||
),
|
||||
}
|
||||
shard_size, shard_offset = adjust_bitsandbytes_shard(
|
||||
param, orig_qkv_offsets, loaded_shard_id
|
||||
)
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
|
||||
if loaded_shard_id == "q":
|
||||
shard_id = tp_rank
|
||||
else:
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
|
||||
# Special case for for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
||||
param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
|
||||
# Special case for per-tensor scales in fused case.
|
||||
elif needs_scalar_to_array:
|
||||
param_data, loaded_weight = adjust_scalar_to_fused_array(
|
||||
param_data, loaded_weight, loaded_shard_id
|
||||
)
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"QKVParallelLinear, assume the weight is the same "
|
||||
"for all partitions."
|
||||
)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
skip_bias_add: This was added to enable performance optimization where
|
||||
bias can be fused with other element-wise operations.
|
||||
We skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__(
|
||||
input_size, output_size, skip_bias_add, params_dtype, quant_config
|
||||
)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(
|
||||
layer=self,
|
||||
input_size_per_partition=self.input_size_per_partition,
|
||||
output_partition_sizes=[self.output_size],
|
||||
input_size=self.input_size,
|
||||
output_size=self.output_size,
|
||||
params_dtype=self.params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
)
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError(
|
||||
"When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results"
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(
|
||||
self.bias,
|
||||
{
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
},
|
||||
)
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
if param.data.dtype != loaded_weight.dtype:
|
||||
param.data = torch.empty_like(
|
||||
param.data, dtype=loaded_weight.dtype, device="cuda"
|
||||
)
|
||||
|
||||
param_data = param.data
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
if input_dim is not None:
|
||||
shard_size = param.data.shape[input_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
|
||||
|
||||
# Special case for loading scales off disk, which often do not
|
||||
# have a shape (such as in the case of AutoFP8).
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size
|
||||
)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self, input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"input_features={self.input_size_per_partition}"
|
||||
s += f", output_features={self.output_size}"
|
||||
s += f", bias={self.bias is not None}"
|
||||
s += f", tp_size={self.tp_size}"
|
||||
s += f", reduce_results={self.reduce_results}"
|
||||
return s
|
||||
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Memory-efficient attention for prefill.
|
||||
It supporst page size = 1.
|
||||
"""
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
|
||||
import torch
|
||||
@@ -1,64 +0,0 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# temporarily adapted from vLLM
|
||||
# FIXME: in progress of refactoring the model loader
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
|
||||
from sglang.srt.layers.quantization.fp8 import Fp8Config
|
||||
|
||||
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"aqlm": AQLMConfig,
|
||||
"awq": AWQConfig,
|
||||
"deepspeedfp": DeepSpeedFPConfig,
|
||||
"fp8": Fp8Config,
|
||||
# The order of gptq methods is important for config.py iteration over
|
||||
# override_quantization_method(..)
|
||||
"marlin": MarlinConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
"compressed-tensors": CompressedTensorsConfig,
|
||||
"bitsandbytes": BitsAndBytesConfig,
|
||||
}
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in QUANTIZATION_METHODS:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return QUANTIZATION_METHODS[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
@@ -1,677 +0,0 @@
|
||||
"""
|
||||
Copyright 2023-2024 SGLang Team
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/e76466dde2bc9525d55165ceaa600d298c7bf773/vllm/model_executor/layers/quantization/fp8.py
|
||||
# FIXME refactor in progress
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase, fused_moe
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQMarlinState,
|
||||
marlin_permute_scales,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import pack_fp8_to_int32
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
|
||||
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def cutlass_fp8_supported() -> bool:
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
return ops.cutlass_scaled_mm_supports_fp8(capability)
|
||||
|
||||
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
) -> None:
|
||||
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
|
||||
if is_checkpoint_fp8_serialized:
|
||||
logger.warning(
|
||||
"Detected fp8 checkpoint. Please note that the "
|
||||
"format is experimental and subject to change."
|
||||
)
|
||||
if activation_scheme not in ACTIVATION_SCHEMES:
|
||||
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.bfloat16, torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
return cls(
|
||||
is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
activation_scheme=activation_scheme,
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module
|
||||
) -> Optional["QuantizeMethodBase"]:
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
return Fp8LinearMethod(self)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Fp8MoEMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Limitations:
|
||||
1. Only support per-tensor quantization due to torch._scaled_mm support.
|
||||
2. Only support float8_e4m3fn data type due to the limitation of
|
||||
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
self.use_marlin = capability < 89
|
||||
|
||||
def _create_scale_param(
|
||||
self,
|
||||
scale_name: str,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
scale = Parameter(
|
||||
torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
requires_grad=False,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float8_e4m3fn).min
|
||||
layer.register_parameter(scale_name, scale)
|
||||
set_weight_attrs(
|
||||
scale,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"needs_scalar_to_array": True,
|
||||
},
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del input_size, output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
layer.process_after_load = True
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
# weight_dtype = (torch.float8_e4m3fn
|
||||
# if self.quant_config.is_checkpoint_fp8_serialized else
|
||||
# params_dtype)
|
||||
weight_dtype = torch.float8_e4m3fn
|
||||
weight = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition, input_size_per_partition, dtype=weight_dtype
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(
|
||||
weight,
|
||||
{
|
||||
**extra_weight_attrs,
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
},
|
||||
)
|
||||
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
self._create_scale_param(
|
||||
scale_name="weight_scale",
|
||||
layer=layer,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
self._create_scale_param(
|
||||
scale_name="input_scale",
|
||||
layer=layer,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
|
||||
# For GPUs without FP8 hardware support, we use Marlin for fast
|
||||
# fused dequantization
|
||||
if self.use_marlin:
|
||||
layer.marlin_state = GPTQMarlinState.REPACK
|
||||
|
||||
def prepare_layer_for_marlin(self, layer: Module) -> None:
|
||||
print_warning_once(
|
||||
"Your GPU does not have native support for FP8 computation but "
|
||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads."
|
||||
)
|
||||
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
|
||||
assert layer.marlin_state == GPTQMarlinState.REPACK
|
||||
layer.marlin_state = GPTQMarlinState.READY
|
||||
|
||||
device = layer.weight.device
|
||||
|
||||
# WEIGHTS
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
|
||||
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight = Parameter(marlin_qweight, requires_grad=False)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = (
|
||||
layer.weight_scale.repeat(1, part_size_n).to(layer.orig_dtype).to(device)
|
||||
)
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(
|
||||
s=scales,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=-1,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
max_workspace_size = (
|
||||
part_size_n // GPTQ_MARLIN_MIN_THREAD_N
|
||||
) * GPTQ_MARLIN_MAX_PARALLEL
|
||||
workspace = torch.zeros(
|
||||
max_workspace_size, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
|
||||
layer.workspace = workspace
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if not hasattr(layer, "process_after_load") or not layer.process_after_load:
|
||||
return
|
||||
|
||||
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, scale=None)
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.logical_widths = None
|
||||
layer.input_scale = None
|
||||
if self.use_marlin:
|
||||
self.prepare_layer_for_marlin(layer)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, requantize the separately quantized logical
|
||||
# weights into a single fp8 weight with a single weight scale.
|
||||
else:
|
||||
# WEIGHT_SCALE / WEIGHT
|
||||
# Loop over logical weights, requantizing with single scale.
|
||||
max_w_scale = layer.weight_scale.max()
|
||||
|
||||
# QKV / MLP is fused in the on disk checkpoint if any of the
|
||||
# weight scales are still set to the default since we initialize
|
||||
# N weight scales for N shards but we only load 1 weight scale
|
||||
# from disk in this case. As a result, we skip dequant -> requant
|
||||
# since we already have quantized QKV together.
|
||||
# Sample Model with fused checkpoint:
|
||||
# * nm-testing/Phi-3-mini-128k-instruct-FP8
|
||||
unfused_module_in_checkpoint = (
|
||||
layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
|
||||
)
|
||||
|
||||
if unfused_module_in_checkpoint:
|
||||
start = 0
|
||||
for idx, logical_width in enumerate(layer.logical_widths):
|
||||
end = start + logical_width
|
||||
weight_dq = per_tensor_dequantize(
|
||||
layer.weight[start:end, :], layer.weight_scale[idx]
|
||||
)
|
||||
|
||||
layer.weight[start:end, :] = per_tensor_quantize(
|
||||
weight_dq, layer.weight_scale.max()
|
||||
)
|
||||
start = end
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# WEIGHT
|
||||
# Transpose weight for passing to torch._scaled_mm
|
||||
weight = layer.weight
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
|
||||
# Static: set to max of the input_scales (since they are equal).
|
||||
if self.quant_config.activation_scheme == "dynamic":
|
||||
layer.input_scale = None
|
||||
elif self.quant_config.activation_scheme == "static":
|
||||
layer.input_scale = Parameter(
|
||||
layer.input_scale.max(), requires_grad=False
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown scheme {self.quant_config.activation_scheme}"
|
||||
)
|
||||
|
||||
if self.use_marlin:
|
||||
self.prepare_layer_for_marlin(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self.use_marlin:
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the
|
||||
# Marlin kernel for fast weight-only FP8 quantization
|
||||
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (layer.output_size_per_partition,)
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=reshaped_x,
|
||||
b_q_weight=layer.weight,
|
||||
b_scales=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
num_bits=8,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
else:
|
||||
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale
|
||||
|
||||
if bias is None and self.cutlass_fp8_supported:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
)
|
||||
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(
|
||||
x, layer.input_scale, batch_dim_padding=17
|
||||
)
|
||||
|
||||
# Fused GEMM_DQ -- note we padded the input above because
|
||||
# torch._scaled_mm is more performant for matrices with
|
||||
# batch dimension > 16. Note that this could change
|
||||
# in the future.
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return torch.narrow(output, 0, 0, x.shape[0])
|
||||
|
||||
|
||||
class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
"""MoE method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
dynamic/static activation scale.
|
||||
|
||||
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
|
||||
activation scaling. The weight scaling factor will be initialized after
|
||||
the model weights are loaded.
|
||||
|
||||
Args:
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
|
||||
layer.process_after_load = True
|
||||
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty(
|
||||
num_experts, hidden_size, intermediate_size, dtype=params_dtype
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w13_scale", w13_scale)
|
||||
|
||||
w2_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("w2_scale", w2_scale)
|
||||
|
||||
# If loading fp8 checkpoint, pass the weight loaders.
|
||||
# If loading an fp16 checkpoint, do not (we will quantize in
|
||||
# process_weights_after_loading()
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
set_weight_attrs(w13_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
raise ValueError(
|
||||
"Found static activation scheme for checkpoint that "
|
||||
"was not serialized fp8."
|
||||
)
|
||||
|
||||
a13_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("a13_scale", a13_scale)
|
||||
set_weight_attrs(a13_scale, extra_weight_attrs)
|
||||
|
||||
a2_scale = torch.nn.Parameter(
|
||||
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("a2_scale", a2_scale)
|
||||
set_weight_attrs(a2_scale, extra_weight_attrs)
|
||||
else:
|
||||
layer.a13_scale = None
|
||||
layer.a2_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if not hasattr(layer, "process_after_load") or not layer.process_after_load:
|
||||
return
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
w13_weight = torch.empty_like(
|
||||
layer.w13_weight.data, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
w2_weight = torch.empty_like(
|
||||
layer.w2_weight.data, dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
# Re-initialize w13_scale because we directly quantize
|
||||
# merged w13 weights and generate a single scaling factor.
|
||||
layer.w13_scale = torch.nn.Parameter(
|
||||
torch.ones(
|
||||
layer.num_experts, dtype=torch.float32, device=w13_weight.device
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
for expert in range(layer.num_experts):
|
||||
w13_weight[expert, :, :], layer.w13_scale[expert] = (
|
||||
ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
|
||||
)
|
||||
w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
layer.w2_weight.data[expert, :, :]
|
||||
)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, we need to handle that the
|
||||
# MoE kernels require single activation scale and single weight
|
||||
# scale for w13 per expert.
|
||||
else:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
if layer.a13_scale is None or layer.a2_scale is None:
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None."
|
||||
)
|
||||
if not all_close_1d(layer.a13_scale) or not all_close_1d(
|
||||
layer.a2_scale
|
||||
):
|
||||
print_warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. "
|
||||
)
|
||||
layer.a13_scale = torch.nn.Parameter(
|
||||
layer.a13_scale.max(), requires_grad=False
|
||||
)
|
||||
layer.a2_scale = torch.nn.Parameter(
|
||||
layer.a2_scale.max(), requires_grad=False
|
||||
)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_scale.max(dim=1).values
|
||||
for expert_id in range(layer.num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :],
|
||||
layer.w13_scale[expert_id][shard_id],
|
||||
)
|
||||
layer.w13_weight[expert_id][start : start + shard_size, :] = (
|
||||
per_tensor_quantize(dq_weight, max_w13_scales[expert_id])
|
||||
)
|
||||
start += shard_size
|
||||
|
||||
layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
|
||||
return
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool = True,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return fused_moe(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_fp8=True,
|
||||
w1_scale=layer.w13_scale,
|
||||
w2_scale=layer.w2_scale,
|
||||
a1_scale=layer.a13_scale,
|
||||
a2_scale=layer.a2_scale,
|
||||
)
|
||||
|
||||
|
||||
# FIXME: not used
|
||||
class Fp8KVCacheMethod(QuantizeMethodBase):
|
||||
"""Supports loading kv-cache scaling factors from FP8 checkpoints."""
|
||||
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module):
|
||||
"""Create "weight" (aka kv_scale) for an attention layer.
|
||||
|
||||
Args:
|
||||
layer: The layer that is using the QuantizeMethodBase factory.
|
||||
"""
|
||||
# Initialize the KV cache scale to 1.0 as the default value.
|
||||
# If the kv_scale appears in the checkpoint, it will be
|
||||
# overwritten when loading weights.
|
||||
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
|
||||
|
||||
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
|
||||
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
|
||||
# regardless whether the kv-scale is available in the checkpoint.
|
||||
if layer.kv_cache_dtype != "auto":
|
||||
kv_scale = layer.kv_scale.to("cpu").tolist()
|
||||
if not isinstance(kv_scale, float):
|
||||
raise ValueError(
|
||||
"Only support per-tensor scaling factor " "for fp8 KV cache"
|
||||
)
|
||||
layer._kv_scale = kv_scale
|
||||
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
|
||||
print_warning_once(
|
||||
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
|
||||
"cause accuracy issues. Please make sure kv-cache scaling "
|
||||
"factor is available in the fp8 checkpoint."
|
||||
)
|
||||
del layer.kv_scale
|
||||
|
||||
|
||||
def per_tensor_quantize(
|
||||
tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return qweight.to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def per_tensor_dequantize(
|
||||
tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
fake_qweight = tensor.to(torch.float16)
|
||||
dq_weight = fake_qweight * inv_scale
|
||||
return dq_weight
|
||||
|
||||
|
||||
def all_close_1d(x: torch.Tensor) -> bool:
|
||||
assert len(x.shape) == 1
|
||||
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
|
||||
@@ -20,8 +20,8 @@ from flashinfer.cascade import merge_state
|
||||
from torch import nn
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.decode_attention import decode_attention_fwd
|
||||
from sglang.srt.layers.extend_attention import extend_attention_fwd
|
||||
from sglang.srt.layers.token_attention import token_attention_fwd
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
|
||||
from sglang.srt.model_executor.model_runner import global_server_args_dict
|
||||
|
||||
@@ -95,7 +95,7 @@ class RadixAttention(nn.Module):
|
||||
o = torch.empty_like(q)
|
||||
self.store_kv_cache(k, v, input_metadata)
|
||||
|
||||
token_attention_fwd(
|
||||
decode_attention_fwd(
|
||||
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
|
||||
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
|
||||
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
|
||||
|
||||
Reference in New Issue
Block a user