init support for KTransformers Heterogeneous Computing (#11487)
Co-authored-by: Jianwei Dong <1913953267@qq.com>
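
This commit wires hybrid CPU+GPU MoE execution into SGLang: experts whose id is >= num_gpu_experts run on the CPU through kt_kernel's AMXMoEWrapper (Intel AMX), while the remaining experts run on the GPU via the fused Marlin MoE kernel. The split is configured through new --kt-* server arguments, which are forwarded into SGLANG_KT_* environment variables. A minimal sketch of that wiring (editor's illustration; the values and the weight path are placeholders):

    from sglang.srt.environ import envs

    envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set("/path/to/amx-weights")  # placeholder path
    envs.SGLANG_KT_MOE_CPUINFER.set(60)        # illustrative CPUInfer thread count
    envs.SGLANG_KT_THREADPOOL_COUNT.set(2)     # one thread pool per NUMA node
    envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(8)  # experts 0..7 on GPU, the rest on CPU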
@@ -229,6 +229,14 @@ class Envs:
    SGLANG_IMAGE_MAX_PIXELS = EnvInt(16384 * 28 * 28)
    SGLANG_RESIZE_RESAMPLE = EnvStr("")

    # Ktransformers
    SGLANG_KT_MOE_NUM_GPU_EXPERTS = EnvInt(None)
    SGLANG_KT_MOE_CPUINFER = EnvInt(None)
    SGLANG_KT_THREADPOOL_COUNT = EnvInt(None)
    SGLANG_KT_MOE_AMX_WEIGHT_PATH = EnvStr(None)
    SGLANG_KT_AMX_METHOD = EnvStr(None)
    SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE = EnvInt(None)

    # fmt: on
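
Editor's note: each Env handle above exposes .is_set(), .value, and .set(), as used throughout this diff. A minimal read-side sketch:

    from sglang.srt.environ import envs

    if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
        amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value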
@@ -33,6 +33,11 @@ from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsWNA16AMXEPMoEMethod,
    CompressedTensorsWNA16AMXMoEMethod,
    CompressedTensorsWNA16MoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8MoEMethod
from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
@@ -150,7 +155,6 @@ class FusedMoE(torch.nn.Module):
        with_bias=False,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
@@ -227,6 +231,8 @@ class FusedMoE(torch.nn.Module):
                if not use_weight_loader_fused
                else self.weight_loader_fused
            ),
            intermediate_size_full=intermediate_size,
            top_k=top_k,
            with_bias=with_bias,
        )
@@ -542,6 +548,18 @@ class FusedMoE(torch.nn.Module):
        if expert_id == -1:
            return

        if isinstance(
            self.quant_method,
            (
                CompressedTensorsWNA16MoEMethod,
                CompressedTensorsWNA16AMXMoEMethod,
                CompressedTensorsWNA16AMXEPMoEMethod,
            ),
        ):
            if self.quant_method.num_gpu_experts != -1:
                if expert_id >= self.quant_method.num_gpu_experts:
                    return

        self._weight_loader_impl(
            param=param,
            loaded_weight=loaded_weight,
@@ -568,7 +586,12 @@ class FusedMoE(torch.nn.Module):
            loaded_weight.t().contiguous()
            if (
                self.quant_method.__class__.__name__
-               == "CompressedTensorsWNA16MoEMethod"
+               in [
+                   "CompressedTensorsWNA16MarlinMoEMethod",
+                   "CompressedTensorsWNA16MoEMethod",
+                   "CompressedTensorsWNA16AMXMoEMethod",
+                   "CompressedTensorsWNA16AMXEPMoEMethod",
+               ]
            )
            else loaded_weight
        )
@@ -827,7 +850,6 @@ class FusedMoE(torch.nn.Module):
            dispatch_output=dispatch_output,
            **kwargs,
        )

        final_hidden_states = self.dispatcher.combine(combine_input)

        # TODO: should we add some conditions here?
@@ -0,0 +1,7 @@
class scalar_types:
    uint4b8 = "uint4b8"
    uint8b128 = "uint8b128"


WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
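
# Editor's note: WNA16_SUPPORTED_BITS == [4, 8]; a 4-bit WNA16 scheme maps to
# scalar_types.uint4b8 and an 8-bit scheme to scalar_types.uint8b128.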
@@ -19,11 +19,13 @@ from compressed_tensors.quantization import (
)
from pydantic import BaseModel

from sglang.srt.environ import envs
from sglang.srt.layers.quantization.base_config import (
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
    CompressedTensorsMoEMethod,
)
@@ -38,6 +40,7 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format,
    should_ignore_layer,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

try:
@@ -76,6 +79,7 @@ class DeviceCapability(NamedTuple):


class CompressedTensorsConfig(QuantizationConfig):
    DeepSeekFP8Config = None

    def __init__(
        self,
@@ -129,6 +133,10 @@ class CompressedTensorsConfig(QuantizationConfig):
        ):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            if CompressedTensorsConfig.DeepSeekFP8Config is not None:
                return Fp8LinearMethod(CompressedTensorsConfig.DeepSeekFP8Config)
            if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
                return UnquantizedLinearMethod()
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedLinearMethod()
@@ -137,7 +145,8 @@ class CompressedTensorsConfig(QuantizationConfig):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, FusedMoE):
-           return CompressedTensorsMoEMethod.get_moe_method(self)
+           # KTransformers uses CompressedTensorsWNA16AMXEPMoEMethod if AMX weights are provided
+           return CompressedTensorsMoEMethod.get_moe_method(self, layer, prefix)
        return None

    @classmethod
@@ -4,16 +4,34 @@ from __future__ import annotations

import enum
import logging
import re
from enum import Enum
from typing import TYPE_CHECKING, List

try:
    from sgl_kernel import fused_marlin_moe

    FUSED_MARLIN_MOE_AVAILABLE = True
except ImportError:
    FUSED_MARLIN_MOE_AVAILABLE = False

try:
    from kt_kernel import AMXMoEWrapper

    KTRANSFORMERS_AVAILABLE = True
except ImportError:
    KTRANSFORMERS_AVAILABLE = False

import torch
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy

from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.environ import envs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.utils import (
@@ -21,7 +39,12 @@ from sglang.srt.layers.quantization.utils import (
    per_tensor_dequantize,
    replace_parameter,
)
-from sglang.srt.utils import get_bool_env_var, is_hip, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    get_compiler_backend,
+    is_hip,
+    set_weight_attrs,
+)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -51,6 +74,18 @@ except ImportError:
logger = logging.getLogger(__name__)


def _mask_topk_ids_cpu_experts(topk_ids: torch.Tensor, num_gpu_experts: int):
    """Mask topk_ids >= num_gpu_experts by setting them to -1."""
    topk_ids[topk_ids >= num_gpu_experts] = -1


@torch.compile(dynamic=True, backend=get_compiler_backend())
def mask_cpu_expert_ids(topk_ids: torch.Tensor, num_gpu_experts: int):
    """Mask CPU expert IDs (set ids >= num_gpu_experts to -1)."""
    _mask_topk_ids_cpu_experts(topk_ids, num_gpu_experts)
    return topk_ids
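

# Editor's worked example (illustrative values): with num_gpu_experts == 4,
# expert ids 0-3 stay on the GPU and all higher ids are masked for the CPU path:
#
#     topk_ids = torch.tensor([[2, 5], [7, 1]])
#     mask_cpu_expert_ids(topk_ids, 4)  # -> tensor([[ 2, -1], [-1,  1]])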


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()
@@ -60,6 +95,7 @@ __all__ = [
    "CompressedTensorsMoEMethod",
    "CompressedTensorsW8A8Fp8MoEMethod",
    "CompressedTensorsWNA16MoEMethod",
    "CompressedTensorsWNA16AMXEPMoEMethod",  # for Ktransformers
]
@@ -72,12 +108,24 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    @staticmethod
    def get_moe_method(
        quant_config: CompressedTensorsConfig,
        layer: torch.nn.Module,
        prefix: str,
    ) -> "CompressedTensorsMoEMethod":
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.

        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
            match = re.search(r"(\d+)\.mlp", prefix)
            if not match:
                raise ValueError(
                    f"Unable to extract layer number from prefix '{prefix}'. "
                    f"Expected format: '<layer_number>.mlp'"
                )
            layer_number = int(match.group(1))
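            # Editor's note: e.g. a prefix like "model.layers.3.mlp.experts"
            # (name shown for illustration) yields layer_number == 3.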
            return CompressedTensorsWNA16AMXEPMoEMethod(quant_config, layer_number)

        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
        input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")

        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
@@ -201,7 +249,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
        layer.w13_input_scale = None
        layer.w2_input_scale = None

-   def process_weights_after_loading(self, layer: FusedMoE) -> None:
+   def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
@@ -349,7 +397,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

-   def __init__(self, quant_config: CompressedTensorsConfig):
+   def __init__(self, quant_config: CompressedTensorsConfig, num_gpu_experts=-1):
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
@@ -371,6 +419,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
                "is supported for the following bits: ",
                f"{WNA16_SUPPORTED_BITS}",
            )
        self.num_gpu_experts = num_gpu_experts

    def create_weights(
        self,
@@ -381,10 +430,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
-
-       assert (
-           params_dtype == torch.float16
-       ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+       if self.num_gpu_experts != -1:
+           num_experts = self.num_gpu_experts
+       # assert (
+       #     params_dtype == torch.float16
+       # ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
@@ -683,3 +733,353 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            is_k_full=self.is_k_full,
        )
        return StandardCombineInput(hidden_states=output)


class CompressedTensorsWNA16AMXMoEMethod(CompressedTensorsMoEMethod):
    """AMX MoE method using AMXMoEWrapper for CPU inference."""

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        layer_idx,
        num_gpu_experts,
        cpuinfer,
        threadpool_count,
        amx_weight_path,
        chunked_prefill_size,
    ):

        if not KTRANSFORMERS_AVAILABLE:
            raise ImportError(
                "kt_kernel is not installed. To use CompressedTensorsWNA16AMXEPMoEMethod, please install kt_kernel."
            )

        if not FUSED_MARLIN_MOE_AVAILABLE:
            raise ImportError("fused_marlin_moe is not available")

        self.tp_rank = get_tensor_model_parallel_rank()
        self.layer_idx = layer_idx
        self.num_gpu_experts = num_gpu_experts
        self.amx_weight_path = amx_weight_path
        self.chunked_prefill_size = chunked_prefill_size
        self.cpuinfer = cpuinfer
        self.threadpool_count = threadpool_count
        self.amx_wrapper = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        self.experts_num = num_experts
        self.num_experts_per_tok = extra_weight_attrs.pop("top_k")
        self.hidden_size = hidden_size
        self.moe_intermediate_size = extra_weight_attrs.pop("intermediate_size_full")

        if self.tp_rank != 0:
            return
        self.amx_wrapper = AMXMoEWrapper(
            layer_idx=self.layer_idx,
            num_experts=num_experts,
            num_experts_per_tok=self.num_experts_per_tok,
            hidden_size=hidden_size,
            moe_intermediate_size=self.moe_intermediate_size,
            num_gpu_experts=self.num_gpu_experts,
            cpuinfer_threads=self.cpuinfer,
            threadpool_count=self.threadpool_count,
            amx_weight_path=self.amx_weight_path,
            chunked_prefill_size=self.chunked_prefill_size,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.tp_rank != 0:
            return

        if self.amx_wrapper is None:
            raise RuntimeError(
                "AMXMoEWrapper not initialized. Call create_weights first."
            )

        torch.cuda.synchronize()
        # Load weights using wrapper
        from sglang.srt.eplb.expert_location_dispatch import (
            get_global_expert_location_metadata,
        )

        physical_to_logical_map_cpu = (
            get_global_expert_location_metadata()
            .physical_to_logical_map_cpu[self.layer_idx]
            .contiguous()
        )
        self.amx_wrapper.load_weights(physical_to_logical_map_cpu)

    def submit(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> None:
        """Submit AMX inference task asynchronously."""
        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output

        if self.tp_rank != 0 or self.amx_wrapper is None:
            return None

        # Submit forward task using wrapper
        self.amx_wrapper.submit_forward(
            x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
        )
        return None

    def sync(self, x):
        """Synchronize and retrieve AMX inference results."""
        if self.tp_rank != 0 or self.amx_wrapper is None:
            return torch.zeros_like(x)

        # Sync forward task using wrapper
        return self.amx_wrapper.sync_forward(
            x, torch.cuda.current_stream(x.device).cuda_stream
        )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Execute AMX MoE forward pass synchronously."""
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output

        if self.tp_rank != 0 or self.amx_wrapper is None:
            return StandardCombineInput(hidden_states=torch.zeros_like(x))

        # Execute forward using wrapper (submit + sync)
        output = self.amx_wrapper.forward(
            x, topk_ids, topk_weights, torch.cuda.current_stream(x.device).cuda_stream
        )
        return StandardCombineInput(hidden_states=output)


def override_config(
    cls,
    num_gpu_experts,
    cpuinfer,
    threadpool_count,
    amx_weight_path,
    amx_method,
    chunked_prefill_size,
):
    """Override MoE configuration via environment variables."""
    # Set environment variables using envs utility class
    if num_gpu_experts is not None:
        envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.set(num_gpu_experts)
    if cpuinfer is not None:
        envs.SGLANG_KT_MOE_CPUINFER.set(cpuinfer)
    if threadpool_count is not None:
        envs.SGLANG_KT_THREADPOOL_COUNT.set(threadpool_count)
    if amx_weight_path is not None:
        envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.set(amx_weight_path)
    if amx_method is not None:
        envs.SGLANG_KT_AMX_METHOD.set(amx_method)
    if chunked_prefill_size is not None:
        envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.set(chunked_prefill_size)


class CompressedTensorsWNA16AMXEPMoEMethod(CompressedTensorsMoEMethod):

    def __init__(
        self,
        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
        layer_idx,
    ):
        self.tp_rank = get_tensor_model_parallel_rank()

        if (
            not envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.is_set()
            or not envs.SGLANG_KT_MOE_CPUINFER.is_set()
            or not envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set()
        ):
            raise RuntimeError(
                "The following arguments are required: --kt-amx-weight-path, --kt-cpuinfer, --kt-num-gpu-experts"
            )
        self.num_gpu_experts = envs.SGLANG_KT_MOE_NUM_GPU_EXPERTS.value
        cpuinfer = envs.SGLANG_KT_MOE_CPUINFER.value
        threadpool_count = envs.SGLANG_KT_THREADPOOL_COUNT.value
        amx_weight_path = envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.value
        chunked_prefill_size = envs.SGLANG_KT_MOE_CHUNKED_PREFILL_SIZE.value

        self.AMX_method = CompressedTensorsWNA16AMXMoEMethod(
            quant_config,
            layer_idx,
            self.num_gpu_experts,
            cpuinfer,
            threadpool_count,
            amx_weight_path,
            chunked_prefill_size,
        )
        self.marlin_method = CompressedTensorsWNA16MoEMethod(
            quant_config, self.num_gpu_experts
        )
        self.layer_id = layer_idx

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        self.global_num_experts = num_experts
        self.AMX_method.create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        self.marlin_method.create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.AMX_method.process_weights_after_loading(layer)
        self.marlin_method.process_weights_after_loading(layer)

    def submit(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Submit hybrid GPU+CPU MoE task (AMX submission + GPU execution)."""
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        topk_weights, topk_ids, router_logits = topk_output

        # Submit AMX task if on rank 0
        if self.tp_rank == 0:
            self.AMX_method.submit(layer, dispatch_output)

        # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU
        topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)

        # Execute GPU (Marlin) experts
        output = fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.marlin_method.num_bits,
            is_k_full=self.marlin_method.is_k_full,
            global_num_experts=self.global_num_experts,
            expert_map=torch.empty(1, device=x.device),
        )
        return StandardCombineInput(hidden_states=output)

    def sync(self, x):
        """Synchronize and retrieve AMX results."""
        if self.tp_rank != 0:
            return torch.zeros_like(x)
        return self.AMX_method.sync(x)

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Execute hybrid GPU+CPU MoE forward pass with parallelism."""
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, router_logits = topk_output

        # Step 1: Submit AMX task (non-blocking) if on rank 0
        # This starts CPU computation in parallel
        if self.tp_rank == 0:
            self.AMX_method.submit(layer, dispatch_output)

        # Step 2: Execute GPU (Marlin) experts in parallel with CPU

        # Mask CPU expert IDs (>= num_gpu_experts) as -1 so they won't be computed on GPU
        topk_ids = mask_cpu_expert_ids(topk_ids, self.num_gpu_experts)

        # While GPU computes, CPU is also computing
        output = fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.marlin_method.num_bits,
            is_k_full=self.marlin_method.is_k_full,
            global_num_experts=self.global_num_experts,
            expert_map=torch.empty(1, device=x.device),
        )

        # Step 3: Sync AMX results and combine with GPU results
        if self.tp_rank == 0:
            amx_output = self.AMX_method.sync(x)
            output += amx_output

        return StandardCombineInput(hidden_states=output)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config
        self.AMX_method.create_moe_runner(layer, moe_runner_config)
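
Editor's note: the apply() above overlaps CPU (AMX) and GPU (Marlin) expert computation. A minimal standalone sketch of that submit/compute/sync pattern, with hypothetical stand-ins for AMXMoEWrapper and fused_marlin_moe:

    def hybrid_moe_forward(x, cpu_backend, gpu_experts):
        cpu_backend.submit(x)        # non-blocking: CPU experts start now
        gpu_out = gpu_experts(x)     # GPU experts run while the CPU works
        return gpu_out + cpu_backend.sync(x)  # wait for CPU partial sums, add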
@@ -65,6 +65,13 @@ from sglang.srt.utils import (
)
from sglang.srt.utils.patch_torch import monkey_patch_torch_compile

try:
    from kt_kernel import AMXMoEWrapper

    KTRANSFORMERS_AVAILABLE = True
except ImportError:
    KTRANSFORMERS_AVAILABLE = False

_is_hip = is_hip()

logger = logging.getLogger(__name__)
@@ -248,6 +255,8 @@ class CudaGraphRunner:
        # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
        if KTRANSFORMERS_AVAILABLE:
            AMXMoEWrapper.set_capture_batch_sizes(self.capture_bs)
        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1
@@ -44,6 +44,7 @@ from sglang.srt.distributed import (
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -81,7 +82,12 @@ from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization import CompressedTensorsConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
    CompressedTensorsWNA16AMXEPMoEMethod,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
    is_fp8_fnuz,
    per_tensor_quant_mla_fp8,
@@ -707,6 +713,10 @@ class DeepseekV2MoE(nn.Module):
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
        topk_output = self.topk(hidden_states, router_logits)
        if isinstance(
            self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod
        ):
            topk_output.topk_weights.mul_(self.routed_scaling_factor)
        final_hidden_states = self.experts(hidden_states, topk_output)
        if not _is_cuda:
            final_hidden_states *= self.routed_scaling_factor
@@ -2837,6 +2847,10 @@ class DeepseekV2ForCausalLM(nn.Module):
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        if envs.SGLANG_KT_MOE_AMX_WEIGHT_PATH.is_set():
            CompressedTensorsConfig.DeepSeekFP8Config = Fp8Config(
                True, "dynamic", None, [128, 128]
            )
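            # Editor's note: with AMX expert weights enabled, non-MoE linear
            # layers fall back to DeepSeek's block-wise (128x128) dynamic FP8
            # scheme via Fp8LinearMethod (see CompressedTensorsConfig.get_quant_method).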
        self.determine_num_fused_shared_experts()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
@@ -2976,11 +2990,13 @@ class DeepseekV2ForCausalLM(nn.Module):
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
-               if (
-                   hasattr(self.quant_config, "weight_block_size")
-                   and self.quant_config.weight_block_size is not None
-               ):
-                   weight_block_size = self.quant_config.weight_block_size
+               selected_quant_config = getattr(
+                   self.quant_config, "DeepSeekFP8Config", self.quant_config
+               )
+               weight_block_size = getattr(
+                   selected_quant_config, "weight_block_size", None
+               )
+               if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
@@ -520,6 +520,7 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
                config=config,
                quant_config=quant_config,
                alt_stream=alt_stream,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen2MoeMLP(
@@ -673,6 +674,7 @@ class Qwen3HybridAttentionDecoderLayer(nn.Module):
                config=config,
                quant_config=quant_config,
                alt_stream=alt_stream,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            self.mlp = Qwen2MoeMLP(
@@ -91,6 +91,7 @@ QUANTIZATION_CHOICES = [
    "qoq",
    "w4afp8",
    "mxfp4",
    "compressed-tensors",  # for Ktransformers
]

ATTENTION_BACKEND_CHOICES = [
@@ -389,6 +390,13 @@ class ServerArgs:
    # LMCache
    enable_lmcache: bool = False

    # Ktransformers
    kt_amx_weight_path: Optional[str] = None
    kt_amx_method: Optional[str] = None
    kt_cpuinfer: Optional[int] = None
    kt_threadpool_count: Optional[int] = None
    kt_num_gpu_experts: Optional[int] = None

    # Double Sparsity
    enable_double_sparsity: bool = False
    ds_channel_config_path: Optional[str] = None
@@ -544,6 +552,9 @@ class ServerArgs:
        self._handle_amd_specifics()
        self._handle_grammar_backend()

        # Handle Ktransformers specific configs
        self._handle_ktransformers_configs()

        # Handle data parallelism.
        self._handle_data_parallelism()
@@ -595,6 +606,22 @@ class ServerArgs:
            )
            self.tool_call_parser = deprecated_tool_call_parsers[self.tool_call_parser]

    def _handle_ktransformers_configs(self):
        from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (
            CompressedTensorsWNA16AMXEPMoEMethod,
            override_config,
        )

        override_config(
            CompressedTensorsWNA16AMXEPMoEMethod,
            self.kt_num_gpu_experts,
            self.kt_cpuinfer,
            self.kt_threadpool_count,
            self.kt_amx_weight_path,
            self.kt_amx_method,
            self.chunked_prefill_size,
        )
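        # Editor's note: override_config skips None values, so --kt-* flags the
        # user did not pass leave the corresponding SGLANG_KT_* env vars unset.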

    def _handle_missing_default_values(self):
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path
@@ -1518,6 +1545,7 @@ class ServerArgs:

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):

        # Model and tokenizer
        parser.add_argument(
            "--model-path",
@@ -2675,6 +2703,35 @@ class ServerArgs:
            help="Using LMCache as an alternative hierarchical cache solution",
        )

        # Ktransformers server args
        parser.add_argument(
            "--kt-amx-weight-path",
            type=str,
            help="[ktransformers parameter] The path to the quantized expert weights for the AMX kernel (a local folder).",
        )
        parser.add_argument(
            "--kt-amx-method",
            type=str,
            default="AMXINT4",
            help="[ktransformers parameter] Quantization format for CPU execution.",
        )
        parser.add_argument(
            "--kt-cpuinfer",
            type=int,
            help="[ktransformers parameter] The number of CPUInfer threads.",
        )
        parser.add_argument(
            "--kt-threadpool-count",
            type=int,
            default=2,
            help="[ktransformers parameter] The number of thread pools, one per NUMA node.",
        )
        parser.add_argument(
            "--kt-num-gpu-experts",
            type=int,
            help="[ktransformers parameter] The number of experts to keep on the GPU.",
        )

        # Double Sparsity
        parser.add_argument(
            "--enable-double-sparsity",