[CPU] refine CPU integration code (#7647)
python/sglang/srt/layers/amx_utils.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+import logging
+
+import torch
+
+from sglang.srt.utils import cpu_has_amx_support
+
+logger = logging.getLogger(__name__)
+
+
+def amx_process_weight_after_loading(weight):
+    if weight.device != torch.device("cpu"):
+        return weight
+    if not cpu_has_amx_support():
+        return weight
+
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
+
+
+# TODO: currently gemm kernel has the below requirements:
+# OC % TILE_N == 0, where TILE_N = 16
+# IC % TILE_K == 0, where TILE_K = 32
+def dim_is_supported(weight):
+    TILE_N = 16
+    TILE_K = 32
+    ndim = weight.ndim
+    OC = weight.size(1) if ndim == 3 else weight.size(0)
+    IC = weight.size(2) if ndim == 3 else weight.size(1)
+    return OC % TILE_N == 0 and IC % TILE_K == 0
+
+
+def _amx_process_weight_after_loading(
+    module, weight_names, transpose_dims=None
+) -> None:
+    # Pack weight for get better performance on CPU
+    devices = {getattr(module, weight_name).device for weight_name in weight_names}
+    assert len(devices) == 1, f"Expects all weights to be on the same device"
+    device = devices.pop()
+
+    if transpose_dims:
+        assert len(weight_names) == len(
+            transpose_dims
+        ), "len(weight_names) should be equal to len(transpose_dims)"
+
+    for i, weight_name in enumerate(weight_names):
+        weight_tensor = getattr(module, weight_name)
+
+        if transpose_dims and transpose_dims[i]:
+            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
+
+        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
+        if not dim_is_supported(weight_tensor):
+            logger.warning(
+                f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. "
+                f"The derived (OC, IC) dimensions must be divisible by (16, 32). "
+            )
+            module.use_intel_amx_backend = False
+            return
+
+        packed_weight = torch.nn.Parameter(
+            amx_process_weight_after_loading(weight_tensor),
+            requires_grad=False,
+        )
+        packed_weight.__dict__ = weight_tensor.__dict__
+        setattr(module, weight_name, packed_weight)
+
+    module.use_intel_amx_backend = (
+        device == torch.device("cpu") and cpu_has_amx_support()
+    )
+
+    if (
+        module.use_intel_amx_backend
+        and hasattr(module, "bias")
+        and module.bias is not None
+    ):
+        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
+
+
+class PackWeightMethod:
+    def __init__(self, weight_names, transpose_dims=None):
+        self.weight_names = weight_names
+        self.transpose_dims = transpose_dims
+
+    def process_weights_after_loading(self, module) -> None:
+        _amx_process_weight_after_loading(
+            module, self.weight_names, self.transpose_dims
+        )
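Note (editor's sketch, not part of the commit): a module opts in to CPU weight prepacking by attaching PackWeightMethod and letting the weight loader call its process_weights_after_loading hook. The TinyGate module and its sizes below are hypothetical; the snippet assumes an sglang build that includes this new file.

import torch

from sglang.srt.layers.amx_utils import PackWeightMethod

class TinyGate(torch.nn.Module):
    def __init__(self, hidden_size=64, num_experts=16):
        super().__init__()
        # OC=num_experts (divisible by 16) and IC=hidden_size (divisible by 32)
        # satisfy the gemm-kernel requirement checked by dim_is_supported above.
        self.weight = torch.nn.Parameter(
            torch.empty(num_experts, hidden_size, dtype=torch.bfloat16),
            requires_grad=False,
        )
        self.quant_method = PackWeightMethod(weight_names=["weight"])

gate = TinyGate()
# The loader invokes this hook after weights are loaded; on a CPU without AMX
# the weight is left unpacked and use_intel_amx_backend ends up False.
gate.quant_method.process_weights_after_loading(gate)
print(gate.use_intel_amx_backend)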
@@ -17,6 +17,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -31,10 +32,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 logger = logging.getLogger(__name__)
@@ -175,7 +176,7 @@ class UnquantizedLinearMethod(LinearMethodBase):

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])

     def apply(
         self,
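For reference, the _is_cpu and _is_cpu_amx_available guards used above are module-level flags; they are assumed (not shown in this diff) to be defined near the imports roughly as:

_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()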
@@ -184,7 +185,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 x, layer.weight, bias, True  # is_vnni
             )
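The apply() change above shows the call-site pattern this commit applies everywhere: check the per-module flag, then take either the packed-weight AMX kernel or the generic path. A minimal sketch of that pattern (the helper name linear_maybe_amx is hypothetical; the sgl_kernel op is the one shown above):

import torch
import torch.nn.functional as F

def linear_maybe_amx(layer, x, bias=None):
    if getattr(layer, "use_intel_amx_backend", False):
        # Weight was prepacked into VNNI layout at load time.
        return torch.ops.sgl_kernel.weight_packed_linear(x, layer.weight, bias, True)  # is_vnni
    # Generic fallback for modules that were not packed.
    return F.linear(x, layer.weight, bias)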
@@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import dump_to_file, use_intel_amx_backend

 logger = logging.getLogger(__name__)

@@ -442,7 +442,7 @@ class LogitsProcessor(nn.Module):
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)

         if hasattr(lm_head, "weight"):
-            if getattr(lm_head, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(lm_head):
                 logits = torch.ops.sgl_kernel.weight_packed_linear(
                     hidden_states.to(lm_head.weight.dtype),
                     lm_head.weight,
@@ -12,6 +12,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -19,12 +20,12 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
     is_hip,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 if torch.cuda.is_available():
@@ -129,7 +130,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

         # Pack weight for get better performance on CPU
         if _is_cpu and _is_cpu_amx_available:
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

         return

@@ -264,10 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     ) -> torch.Tensor:
         assert activation == "silu", f"activation = {activation} is not supported."

-        if (
-            getattr(layer, "use_intel_amx_backend", False)
-            and not apply_router_weight_on_input
-        ):
+        if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
             topk_weights, topk_ids = select_experts(
                 hidden_states=x,
                 router_logits=router_logits,
@@ -27,6 +27,7 @@ except ImportError:


 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -64,7 +65,6 @@ from sglang.srt.layers.quantization.utils import (
 )
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     get_bool_env_var,
     is_cpu,
@@ -74,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -335,7 +336,7 @@ class Fp8LinearMethod(LinearMethodBase):
            assert (
                _is_cpu_amx_available
            ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
            return
        else:
            weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
@@ -433,7 +434,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )

         if self.block_quant:
-            if getattr(layer, "use_intel_amx_backend", False):
+            if use_intel_amx_backend(layer):
                 return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
                     x,
                     layer.weight,
@@ -769,7 +770,7 @@ class Fp8MoEMethod:
            assert (
                _is_cpu_amx_available
            ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])

            return

@@ -996,7 +997,7 @@ class Fp8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )

-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
@@ -4,6 +4,7 @@ import torch
 from torch.nn.parameter import Parameter

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -12,11 +13,11 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.utils import (
-    _process_weight_after_loading,
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
     set_weight_attrs,
+    use_intel_amx_backend,
 )

 _is_cuda = is_cuda()
@@ -84,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
             assert (
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["weight"])
+            _amx_process_weight_after_loading(layer, ["weight"])
             return

         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
@@ -127,7 +128,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
                 x,
                 layer.weight,
@@ -235,7 +236,7 @@ class W8A8Int8MoEMethod:
             assert (
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
-            _process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
             return

         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
@@ -284,7 +285,7 @@ class W8A8Int8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )

-        if getattr(layer, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(layer):
             return torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -20,12 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import (
-    PackWeightMethod,
-    cpu_has_amx_support,
-    is_cpu,
-    set_weight_attrs,
-)
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs

 DEFAULT_VOCAB_PADDING_SIZE = 64

@@ -36,6 +36,7 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -91,7 +92,6 @@ from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
     LazyValue,
-    PackWeightMethod,
     add_prefix,
     bind_or_assign,
     cpu_has_amx_support,
@@ -103,6 +103,7 @@ from sglang.srt.utils import (
     is_hip,
     is_non_idle_and_non_empty,
     log_info_on_rank0,
+    use_intel_amx_backend,
 )

 _is_hip = is_hip()
@@ -224,7 +225,7 @@ class MoEGate(nn.Module):
             self.quant_method = PackWeightMethod(weight_names=["weight"])

     def forward(self, hidden_states):
-        if getattr(self, "use_intel_amx_backend", False):
+        if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
                 self.weight,
@@ -437,8 +438,8 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if hasattr(self, "shared_experts") and getattr(
-            self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
+        if hasattr(self, "shared_experts") and use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
         ):
             return self.forward_cpu(hidden_states)

@@ -464,9 +465,9 @@ class DeepseekV2MoE(nn.Module):
             hidden_states=hidden_states, router_logits=router_logits
         )

-        assert getattr(
-            self.shared_experts.gate_up_proj, "use_intel_amx_backend", False
-        ) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False)
+        assert use_intel_amx_backend(
+            self.shared_experts.gate_up_proj
+        ) == use_intel_amx_backend(self.shared_experts.down_proj)
         # [Note] inplace should be False in fused_experts.
         # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
         # While hidden_states is still needed in shared_expert.
@@ -928,15 +929,23 @@ class DeepseekV2AttentionMLA(nn.Module):
         )

         self.weight_block_size = None
-        if self.qkv_proj_with_rope_is_fp8:
-            assert (
-                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-                == self.q_b_proj.quant_method.quant_config.weight_block_size
-            )
-            self.weight_block_size = (
-                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
-            )
+        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
+            assert getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
+            use_block_quant = getattr(
+                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
+            )
+
+            if use_block_quant:
+                assert (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                    == self.q_b_proj.quant_method.quant_config.weight_block_size
+                )
+                self.weight_block_size = (
+                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
+                )

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
@@ -950,8 +959,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            else:
                return AttnForwardMethod.MLA
        else:
-            if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr(
-                self, "use_intel_amx_backend", False
+            if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
+                self
             ):
                 return AttnForwardMethod.MLA_FUSED_ROPE_CPU
             else:
@@ -1426,8 +1435,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
     ):
-        assert self.q_lora_rank is not None and getattr(
-            self, "use_intel_amx_backend", False
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

         q_input, k_input, v_input = (
@@ -1546,8 +1555,8 @@ class DeepseekV2AttentionMLA(nn.Module):
     def forward_absorb_fused_mla_rope_cpu_core(
         self, q_input, k_input, v_input, forward_batch, zero_allocator
     ):
-        assert self.q_lora_rank is not None and getattr(
-            self, "use_intel_amx_backend", False
+        assert self.q_lora_rank is not None and use_intel_amx_backend(
+            self
         ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
@@ -2416,75 +2416,8 @@ def cpu_has_amx_support():
     return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available


-def prepack_weight_if_needed(weight):
-    if weight.device != torch.device("cpu"):
-        return weight
-    if not cpu_has_amx_support():
-        return weight
-
-    return torch.ops.sgl_kernel.convert_weight_packed(weight)
-
-
-# TODO: currently gemm kernel has the below requirements:
-# OC % TILE_N == 0, where TILE_N = 16
-# IC % TILE_K == 0, where TILE_K = 32
-def dim_is_supported(weight):
-    return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0
-
-
-def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None:
-    # Pack weight for get better performance on CPU
-    devices = {getattr(module, weight_name).device for weight_name in weight_names}
-    assert len(devices) == 1, f"Expects all weights to be on the same device"
-    device = devices.pop()
-
-    if transpose_dims:
-        assert len(weight_names) == len(
-            transpose_dims
-        ), "len(weight_names) should be equal to len(transpose_dims)"
-
-    for i, weight_name in enumerate(weight_names):
-        weight_tensor = getattr(module, weight_name)
-
-        # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim.
-        if not dim_is_supported(weight_tensor):
-            logger.warning(
-                f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 "
-                f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. "
-                f"{module} won't use intel amx backend."
-            )
-            module.use_intel_amx_backend = False
-            return
-
-        if transpose_dims and transpose_dims[i]:
-            weight_tensor = weight_tensor.transpose(*transpose_dims[i])
-
-        packed_weight = torch.nn.Parameter(
-            prepack_weight_if_needed(weight_tensor),
-            requires_grad=False,
-        )
-        packed_weight.__dict__ = weight_tensor.__dict__
-        setattr(module, weight_name, packed_weight)
-
-    module.use_intel_amx_backend = (
-        device == torch.device("cpu") and cpu_has_amx_support()
-    )
-
-    if (
-        module.use_intel_amx_backend
-        and hasattr(module, "bias")
-        and module.bias is not None
-    ):
-        module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False)
-
-
-class PackWeightMethod:
-    def __init__(self, weight_names, transpose_dims=None):
-        self.weight_names = weight_names
-        self.transpose_dims = transpose_dims
-
-    def process_weights_after_loading(self, module) -> None:
-        _process_weight_after_loading(module, self.weight_names, self.transpose_dims)
+def use_intel_amx_backend(layer):
+    return getattr(layer, "use_intel_amx_backend", False)


 class LazyValue:
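Net effect in utils.py: the CPU prepacking helpers move to the new amx_utils.py (now with the 3D-aware dimension check), and only the small predicate above stays behind. A standalone illustration of its behavior, mirroring the added helper:

def use_intel_amx_backend(layer):
    return getattr(layer, "use_intel_amx_backend", False)

class _Dummy:
    pass

m = _Dummy()
assert use_intel_amx_backend(m) is False  # attribute not set yet
m.use_intel_amx_backend = True
assert use_intel_amx_backend(m) is True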