diff --git a/python/sglang/srt/layers/amx_utils.py b/python/sglang/srt/layers/amx_utils.py new file mode 100644 index 000000000..df2a05ba5 --- /dev/null +++ b/python/sglang/srt/layers/amx_utils.py @@ -0,0 +1,86 @@ +import logging + +import torch + +from sglang.srt.utils import cpu_has_amx_support + +logger = logging.getLogger(__name__) + + +def amx_process_weight_after_loading(weight): + if weight.device != torch.device("cpu"): + return weight + if not cpu_has_amx_support(): + return weight + + return torch.ops.sgl_kernel.convert_weight_packed(weight) + + +# TODO: currently gemm kernel has the below requirements: +# OC % TILE_N == 0, where TILE_N = 16 +# IC % TILE_K == 0, where TILE_K = 32 +def dim_is_supported(weight): + TILE_N = 16 + TILE_K = 32 + ndim = weight.ndim + OC = weight.size(1) if ndim == 3 else weight.size(0) + IC = weight.size(2) if ndim == 3 else weight.size(1) + return OC % TILE_N == 0 and IC % TILE_K == 0 + + +def _amx_process_weight_after_loading( + module, weight_names, transpose_dims=None +) -> None: + # Pack weight for get better performance on CPU + devices = {getattr(module, weight_name).device for weight_name in weight_names} + assert len(devices) == 1, f"Expects all weights to be on the same device" + device = devices.pop() + + if transpose_dims: + assert len(weight_names) == len( + transpose_dims + ), "len(weight_names) should be equal to len(transpose_dims)" + + for i, weight_name in enumerate(weight_names): + weight_tensor = getattr(module, weight_name) + + if transpose_dims and transpose_dims[i]: + weight_tensor = weight_tensor.transpose(*transpose_dims[i]) + + # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim. + if not dim_is_supported(weight_tensor): + logger.warning( + f"Unsupported dimension for prepacking for weight '{weight_name}' with shape {weight_tensor.shape} in {module}. " + f"The derived (OC, IC) dimensions must be divisible by (16, 32). " + ) + module.use_intel_amx_backend = False + return + + packed_weight = torch.nn.Parameter( + amx_process_weight_after_loading(weight_tensor), + requires_grad=False, + ) + packed_weight.__dict__ = weight_tensor.__dict__ + setattr(module, weight_name, packed_weight) + + module.use_intel_amx_backend = ( + device == torch.device("cpu") and cpu_has_amx_support() + ) + + if ( + module.use_intel_amx_backend + and hasattr(module, "bias") + and module.bias is not None + ): + module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False) + + +class PackWeightMethod: + def __init__(self, weight_names, transpose_dims=None): + self.weight_names = weight_names + self.transpose_dims = transpose_dims + + def process_weights_after_loading(self, module) -> None: + _amx_process_weight_after_loading( + module, self.weight_names, self.transpose_dims + ) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 3304340f1..1fc43b8b6 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -17,6 +17,7 @@ from sglang.srt.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) +from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.parameter import ( BasevLLMParameter, BlockQuantScaleParameter, @@ -31,10 +32,10 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.utils import ( - _process_weight_after_loading, cpu_has_amx_support, is_cpu, set_weight_attrs, + use_intel_amx_backend, ) logger = logging.getLogger(__name__) @@ -175,7 +176,7 @@ class UnquantizedLinearMethod(LinearMethodBase): def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if _is_cpu and _is_cpu_amx_available: - _process_weight_after_loading(layer, ["weight"]) + _amx_process_weight_after_loading(layer, ["weight"]) def apply( self, @@ -184,7 +185,7 @@ class UnquantizedLinearMethod(LinearMethodBase): bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if getattr(layer, "use_intel_amx_backend", False): + if use_intel_amx_backend(layer): return torch.ops.sgl_kernel.weight_packed_linear( x, layer.weight, bias, True # is_vnni ) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index e98295184..c01e10090 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -42,7 +42,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.utils import dump_to_file +from sglang.srt.utils import dump_to_file, use_intel_amx_backend logger = logging.getLogger(__name__) @@ -442,7 +442,7 @@ class LogitsProcessor(nn.Module): dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) if hasattr(lm_head, "weight"): - if getattr(lm_head, "use_intel_amx_backend", False): + if use_intel_amx_backend(lm_head): logits = torch.ops.sgl_kernel.weight_packed_linear( hidden_states.to(lm_head.weight.dtype), lm_head.weight, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 8cc068dbf..9147136e3 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -12,6 +12,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.moe.fused_moe_native import moe_forward_native from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( @@ -19,12 +20,12 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.utils import ( - _process_weight_after_loading, cpu_has_amx_support, get_bool_env_var, is_cpu, is_hip, set_weight_attrs, + use_intel_amx_backend, ) if torch.cuda.is_available(): @@ -129,7 +130,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): # Pack weight for get better performance on CPU if _is_cpu and _is_cpu_amx_available: - _process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) return @@ -264,10 +265,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) -> torch.Tensor: assert activation == "silu", f"activation = {activation} is not supported." - if ( - getattr(layer, "use_intel_amx_backend", False) - and not apply_router_weight_on_input - ): + if use_intel_amx_backend(layer) and not apply_router_weight_on_input: topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 358c152d3..f2c0d6139 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -27,6 +27,7 @@ except ImportError: from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.linear import ( LinearBase, LinearMethodBase, @@ -64,7 +65,6 @@ from sglang.srt.layers.quantization.utils import ( ) from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.utils import ( - _process_weight_after_loading, cpu_has_amx_support, get_bool_env_var, is_cpu, @@ -74,6 +74,7 @@ from sglang.srt.utils import ( log_info_on_rank0, print_warning_once, set_weight_attrs, + use_intel_amx_backend, ) _is_hip = is_hip() @@ -335,7 +336,7 @@ class Fp8LinearMethod(LinearMethodBase): assert ( _is_cpu_amx_available ), "Fp8LinearMethod on CPU requires that CPU has AMX support" - _process_weight_after_loading(layer, ["weight"]) + _amx_process_weight_after_loading(layer, ["weight"]) return else: weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data @@ -433,7 +434,7 @@ class Fp8LinearMethod(LinearMethodBase): ) if self.block_quant: - if getattr(layer, "use_intel_amx_backend", False): + if use_intel_amx_backend(layer): return torch.ops.sgl_kernel.fp8_scaled_mm_cpu( x, layer.weight, @@ -769,7 +770,7 @@ class Fp8MoEMethod: assert ( _is_cpu_amx_available ), "Fp8MoEMethod on CPU requires that CPU has AMX support" - _process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) return @@ -996,7 +997,7 @@ class Fp8MoEMethod: routed_scaling_factor=routed_scaling_factor, ) - if getattr(layer, "use_intel_amx_backend", False): + if use_intel_amx_backend(layer): return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 4e1d90a0e..db0351052 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -4,6 +4,7 @@ import torch from torch.nn.parameter import Parameter from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading from sglang.srt.layers.linear import LinearMethodBase from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( @@ -12,11 +13,11 @@ from sglang.srt.layers.quantization.base_config import ( ) from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.utils import ( - _process_weight_after_loading, cpu_has_amx_support, is_cpu, is_cuda, set_weight_attrs, + use_intel_amx_backend, ) _is_cuda = is_cuda() @@ -84,7 +85,7 @@ class W8A8Int8LinearMethod(LinearMethodBase): assert ( _is_cpu_amx_available ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support" - _process_weight_after_loading(layer, ["weight"]) + _amx_process_weight_after_loading(layer, ["weight"]) return layer.weight = Parameter(layer.weight.t(), requires_grad=False) @@ -127,7 +128,7 @@ class W8A8Int8LinearMethod(LinearMethodBase): x: torch.Tensor, bias: Optional[torch.Tensor] = None, ): - if getattr(layer, "use_intel_amx_backend", False): + if use_intel_amx_backend(layer): return torch.ops.sgl_kernel.int8_scaled_mm_with_quant( x, layer.weight, @@ -235,7 +236,7 @@ class W8A8Int8MoEMethod: assert ( _is_cpu_amx_available ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support" - _process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) + _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) return layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) @@ -284,7 +285,7 @@ class W8A8Int8MoEMethod: routed_scaling_factor=routed_scaling_factor, ) - if getattr(layer, "use_intel_amx_backend", False): + if use_intel_amx_backend(layer): return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index 19a281e48..8e31a621c 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -13,6 +13,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from sglang.srt.layers.amx_utils import PackWeightMethod from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( @@ -20,12 +21,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, method_has_implemented_embedding, ) -from sglang.srt.utils import ( - PackWeightMethod, - cpu_has_amx_support, - is_cpu, - set_weight_attrs, -) +from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index d969baf3c..9c618efa5 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -36,6 +36,7 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.amx_utils import PackWeightMethod from sglang.srt.layers.communicator import ( LayerCommunicator, LayerScatterModes, @@ -91,7 +92,6 @@ from sglang.srt.utils import ( BumpAllocator, DeepEPMode, LazyValue, - PackWeightMethod, add_prefix, bind_or_assign, cpu_has_amx_support, @@ -103,6 +103,7 @@ from sglang.srt.utils import ( is_hip, is_non_idle_and_non_empty, log_info_on_rank0, + use_intel_amx_backend, ) _is_hip = is_hip() @@ -224,7 +225,7 @@ class MoEGate(nn.Module): self.quant_method = PackWeightMethod(weight_names=["weight"]) def forward(self, hidden_states): - if getattr(self, "use_intel_amx_backend", False): + if use_intel_amx_backend(self): return torch.ops.sgl_kernel.weight_packed_linear( hidden_states, self.weight, @@ -437,8 +438,8 @@ class DeepseekV2MoE(nn.Module): return final_hidden_states def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: - if hasattr(self, "shared_experts") and getattr( - self.shared_experts.gate_up_proj, "use_intel_amx_backend", False + if hasattr(self, "shared_experts") and use_intel_amx_backend( + self.shared_experts.gate_up_proj ): return self.forward_cpu(hidden_states) @@ -464,9 +465,9 @@ class DeepseekV2MoE(nn.Module): hidden_states=hidden_states, router_logits=router_logits ) - assert getattr( - self.shared_experts.gate_up_proj, "use_intel_amx_backend", False - ) == getattr(self.shared_experts.down_proj, "use_intel_amx_backend", False) + assert use_intel_amx_backend( + self.shared_experts.gate_up_proj + ) == use_intel_amx_backend(self.shared_experts.down_proj) # [Note] inplace should be False in fused_experts. # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts # While hidden_states is still needed in shared_expert. @@ -928,15 +929,23 @@ class DeepseekV2AttentionMLA(nn.Module): ) self.weight_block_size = None - if self.qkv_proj_with_rope_is_fp8: - assert ( - self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size - == self.q_b_proj.quant_method.quant_config.weight_block_size - ) - self.weight_block_size = ( - self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size + if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available: + assert getattr( + self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False + ) == getattr(self.q_b_proj.quant_method, "block_quant", False) + use_block_quant = getattr( + self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False ) + if use_block_quant: + assert ( + self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size + == self.q_b_proj.quant_method.quant_config.weight_block_size + ) + self.weight_block_size = ( + self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size + ) + def dispatch_attn_forward_method( self, forward_batch: ForwardBatch ) -> AttnForwardMethod: @@ -950,8 +959,8 @@ class DeepseekV2AttentionMLA(nn.Module): else: return AttnForwardMethod.MLA else: - if hasattr(self, "fused_qkv_a_proj_with_mqa") and getattr( - self, "use_intel_amx_backend", False + if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend( + self ): return AttnForwardMethod.MLA_FUSED_ROPE_CPU else: @@ -1426,8 +1435,8 @@ class DeepseekV2AttentionMLA(nn.Module): forward_batch: ForwardBatch, zero_allocator: BumpAllocator, ): - assert self.q_lora_rank is not None and getattr( - self, "use_intel_amx_backend", False + assert self.q_lora_rank is not None and use_intel_amx_backend( + self ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend" q_input, k_input, v_input = ( @@ -1546,8 +1555,8 @@ class DeepseekV2AttentionMLA(nn.Module): def forward_absorb_fused_mla_rope_cpu_core( self, q_input, k_input, v_input, forward_batch, zero_allocator ): - assert self.q_lora_rank is not None and getattr( - self, "use_intel_amx_backend", False + assert self.q_lora_rank is not None and use_intel_amx_backend( + self ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend" attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 608eae654..5761aa8f4 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2416,75 +2416,8 @@ def cpu_has_amx_support(): return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available -def prepack_weight_if_needed(weight): - if weight.device != torch.device("cpu"): - return weight - if not cpu_has_amx_support(): - return weight - - return torch.ops.sgl_kernel.convert_weight_packed(weight) - - -# TODO: currently gemm kernel has the below requirements: -# OC % TILE_N == 0, where TILE_N = 16 -# IC % TILE_K == 0, where TILE_K = 32 -def dim_is_supported(weight): - return weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 - - -def _process_weight_after_loading(module, weight_names, transpose_dims=None) -> None: - # Pack weight for get better performance on CPU - devices = {getattr(module, weight_name).device for weight_name in weight_names} - assert len(devices) == 1, f"Expects all weights to be on the same device" - device = devices.pop() - - if transpose_dims: - assert len(weight_names) == len( - transpose_dims - ), "len(weight_names) should be equal to len(transpose_dims)" - - for i, weight_name in enumerate(weight_names): - weight_tensor = getattr(module, weight_name) - - # We don't pack weight or use intel amx backend if any weight of this module has unsupported dim. - if not dim_is_supported(weight_tensor): - logger.warning( - f"Expects weight.size(0) % 16 == 0 and weight.size(1) % 32 == 0 " - f"but {weight_tensor.size(0)=} and {weight_tensor.size(1)=} in {module}. " - f"{module} won't use intel amx backend." - ) - module.use_intel_amx_backend = False - return - - if transpose_dims and transpose_dims[i]: - weight_tensor = weight_tensor.transpose(*transpose_dims[i]) - - packed_weight = torch.nn.Parameter( - prepack_weight_if_needed(weight_tensor), - requires_grad=False, - ) - packed_weight.__dict__ = weight_tensor.__dict__ - setattr(module, weight_name, packed_weight) - - module.use_intel_amx_backend = ( - device == torch.device("cpu") and cpu_has_amx_support() - ) - - if ( - module.use_intel_amx_backend - and hasattr(module, "bias") - and module.bias is not None - ): - module.bias = torch.nn.Parameter(module.bias.data.float(), requires_grad=False) - - -class PackWeightMethod: - def __init__(self, weight_names, transpose_dims=None): - self.weight_names = weight_names - self.transpose_dims = transpose_dims - - def process_weights_after_loading(self, module) -> None: - _process_weight_after_loading(module, self.weight_names, self.transpose_dims) +def use_intel_amx_backend(layer): + return getattr(layer, "use_intel_amx_backend", False) class LazyValue: