import logging

import torch

from sglang.srt.utils import cpu_has_amx_support

logger = logging.getLogger(__name__)


def amx_process_weight_after_loading(weight):
    if weight.device != torch.device("cpu"):
        return weight
    if not cpu_has_amx_support():
        return weight

    return torch.ops.sgl_kernel.convert_weight_packed(weight)


# TODO: currently the gemm kernel has the following requirements:
# OC % TILE_N == 0, where TILE_N = 16
# IC % TILE_K == 0, where TILE_K = 32
def dim_is_supported(weight):
    TILE_N = 16
    TILE_K = 32
    ndim = weight.ndim
    OC = weight.size(1) if ndim == 3 else weight.size(0)
    IC = weight.size(2) if ndim == 3 else weight.size(1)
    return OC % TILE_N == 0 and IC % TILE_K == 0


def _amx_process_weight_after_loading(
    module, weight_names, transpose_dims=None
) -> None:
    # Pack weights to get better performance on CPU.
    devices = {getattr(module, weight_name).device for weight_name in weight_names}
    assert len(devices) == 1, f"Expected all weights to be on the same device, got: {devices}"
    device = devices.pop()

    if transpose_dims:
        assert len(weight_names) == len(
            transpose_dims
        ), "len(weight_names) should be equal to len(transpose_dims)"

    for i, weight_name in enumerate(weight_names):
        weight_tensor = getattr(module, weight_name)

        if transpose_dims and transpose_dims[i]:
            weight_tensor = weight_tensor.transpose(*transpose_dims[i])

        # Do not pack the weights or use the Intel AMX backend if any weight of
        # this module has an unsupported dim.
        if not dim_is_supported(weight_tensor):
            logger.warning(
                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
                "The derived (OC, IC) dimensions must be divisible by (16, 32)."
            )
            module.use_intel_amx_backend = False
            return

        packed_weight = torch.nn.Parameter(
            amx_process_weight_after_loading(weight_tensor),
            requires_grad=False,
        )
        packed_weight.__dict__ = weight_tensor.__dict__
        setattr(module, weight_name, packed_weight)

    module.use_intel_amx_backend = (
        device == torch.device("cpu") and cpu_has_amx_support()
    )

    if (
        module.use_intel_amx_backend
        and hasattr(module, "bias")
        and module.bias is not None
    ):
        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)


class PackWeightMethod:
    def __init__(self, weight_names, transpose_dims=None):
        self.weight_names = weight_names
        self.transpose_dims = transpose_dims

    def process_weights_after_loading(self, module) -> None:
        _amx_process_weight_after_loading(
            module, self.weight_names, self.transpose_dims
        )