This commit is contained in:
@@ -42,22 +42,10 @@ from sglang.srt.layers.moe import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
|
||||||
get_bool_env_var,
|
|
||||||
is_cuda,
|
|
||||||
is_flashinfer_available,
|
|
||||||
is_gfx95_supported,
|
|
||||||
is_hip,
|
|
||||||
is_sm100_supported,
|
|
||||||
)
|
|
||||||
|
|
||||||
_is_flashinfer_available = is_flashinfer_available()
|
_is_flashinfer_available = is_flashinfer_available()
|
||||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
|
||||||
_is_gfx95_supported = is_gfx95_supported()
|
|
||||||
|
|
||||||
if _use_aiter and _is_gfx95_supported:
|
|
||||||
from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
|
|
||||||
|
|
||||||
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
|
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
|
||||||
|
|
||||||
@@ -213,7 +201,6 @@ class LayerCommunicator:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
qaunt_format: str = "",
|
|
||||||
):
|
):
|
||||||
if hidden_states.shape[0] == 0:
|
if hidden_states.shape[0] == 0:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -231,34 +218,11 @@ class LayerCommunicator:
|
|||||||
else:
|
else:
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
|
|
||||||
hidden_states = fused_rms_mxfp4_quant(
|
|
||||||
hidden_states,
|
|
||||||
self.input_layernorm.weight,
|
|
||||||
self.input_layernorm.variance_epsilon,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
|
||||||
else:
|
else:
|
||||||
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
|
hidden_states, residual = self.input_layernorm(
|
||||||
hidden_states, residual = fused_rms_mxfp4_quant(
|
hidden_states, residual
|
||||||
hidden_states,
|
)
|
||||||
self.input_layernorm.weight,
|
|
||||||
self.input_layernorm.variance_epsilon,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
residual,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states, residual = self.input_layernorm(
|
|
||||||
hidden_states, residual
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self._communicate_simple_fn(
|
hidden_states = self._communicate_simple_fn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import torch.nn.functional as F
|
|||||||
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
from aiter.ops.gemm_op_a4w4 import gemm_a4w4
|
||||||
from aiter.ops.shuffle import shuffle_weight
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||||
from aiter.ops.triton.gemm_afp4wfp4_pre_quant_atomic import gemm_afp4wfp4_pre_quant
|
|
||||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||||
from aiter.utility import dtypes
|
from aiter.utility import dtypes
|
||||||
from aiter.utility.fp4_utils import e8m0_shuffle
|
from aiter.utility.fp4_utils import e8m0_shuffle
|
||||||
@@ -39,6 +38,15 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
|||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# for aiter implement
|
||||||
|
# wshuffle = shuffle_weight(layer.weight.data, layout=(16, 16))
|
||||||
|
# w_scales_shuffle = e8m0_shuffle(layer.weight_scale.data).view(dtypes.fp8_e8m0)
|
||||||
|
|
||||||
|
# layer.weight = torch.nn.Parameter(wshuffle,
|
||||||
|
# requires_grad=False)
|
||||||
|
# layer.weight_scale = torch.nn.Parameter(w_scales_shuffle,
|
||||||
|
# requires_grad=False)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@@ -85,53 +93,26 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None,
|
bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# This path does not have support for bias currently
|
|
||||||
assert bias is None, "bias is not supported"
|
|
||||||
|
|
||||||
three_d = False
|
out_dtype = x.dtype
|
||||||
x_s = None
|
# M = x.shape[0]
|
||||||
y = None
|
# N = layer.weight.shape[0]
|
||||||
if isinstance(x, tuple):
|
|
||||||
assert len(x) in [
|
|
||||||
2,
|
|
||||||
3,
|
|
||||||
], "For tuple input, only (x, x_s) or (x, x_s, y) formats are accepted"
|
|
||||||
if len(x) == 2:
|
|
||||||
x, x_s = x
|
|
||||||
elif len(x) == 3:
|
|
||||||
x, x_s, y = x
|
|
||||||
|
|
||||||
use_fused_quant_gemm = (
|
# quant_func = aiter.get_triton_quant(aiter.QuantType.per_1x32)
|
||||||
x_s is None and y is not None and layer.weight.shape[0] == y.shape[1]
|
# x, x_scales_shuffle = quant_func(x, shuffle=True)
|
||||||
|
|
||||||
|
# y = torch.zeros((M + 255) // 256 * 256, N, device=x.device, dtype=self.out_dtype)
|
||||||
|
|
||||||
|
# out = gemm_a4w4(x, layer.weight.data, x_scales_shuffle, layer.weight_scale.data, y, bias=bias)
|
||||||
|
|
||||||
|
# return out[:M]
|
||||||
|
|
||||||
|
# triton implement
|
||||||
|
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||||
|
y = torch.empty(
|
||||||
|
x_q.shape[0], layer.weight.shape[0], device=x_q.device, dtype=out_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
if x.dim() == 3:
|
out = gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, out_dtype, y)
|
||||||
three_d = True
|
|
||||||
x = x.view(-1, x.shape[-1])
|
|
||||||
output_shape = [*x.shape[:-1], layer.weight.shape[0]]
|
|
||||||
|
|
||||||
# use_fused_quant_gemm = true, x_q is a bf16/fp16 num
|
return out
|
||||||
# x_s is not None = true, x_q is uint8 num
|
|
||||||
if use_fused_quant_gemm or x_s is not None:
|
|
||||||
x_q = x
|
|
||||||
else:
|
|
||||||
x_q, x_s = dynamic_mxfp4_quant(x)
|
|
||||||
|
|
||||||
if y is None:
|
|
||||||
y = torch.empty(
|
|
||||||
x_q.shape[0],
|
|
||||||
layer.weight.shape[0],
|
|
||||||
device=x_q.device,
|
|
||||||
dtype=self.out_dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_fused_quant_gemm:
|
|
||||||
gemm_afp4wfp4_pre_quant(x_q, layer.weight, layer.weight_scale, y.dtype, y)
|
|
||||||
y = y.to(x.dtype)
|
|
||||||
else:
|
|
||||||
gemm_afp4wfp4(x_q, layer.weight, x_s, layer.weight_scale, self.out_dtype, y)
|
|
||||||
|
|
||||||
if three_d:
|
|
||||||
return y.view(*output_shape)
|
|
||||||
|
|
||||||
return y
|
|
||||||
|
|||||||
@@ -5,10 +5,6 @@ from collections.abc import Iterable, Mapping
|
|||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||||
if type(dict1) is not type(dict2):
|
if type(dict1) is not type(dict2):
|
||||||
@@ -109,96 +105,3 @@ def _is_equal_or_regex_match(
|
|||||||
elif target == value:
|
elif target == value:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
# utility for tensor dims > 2 cases
|
|
||||||
def b_dynamic_mxfp4_quant(x):
|
|
||||||
h, b, d = x.shape
|
|
||||||
x, x_scales = dynamic_mxfp4_quant(x.reshape(-1, d))
|
|
||||||
return x.view(h, b, d // 2), x_scales.view(h, b, d // 32)
|
|
||||||
|
|
||||||
|
|
||||||
def mxfp4_to_f32(x, is_threed):
|
|
||||||
# 2 because we pack fp4 in uint8.
|
|
||||||
x = x.repeat_interleave(2, dim=-1)
|
|
||||||
if is_threed:
|
|
||||||
x[..., ::2] = x[..., ::2] & 0xF
|
|
||||||
x[..., 1::2] = x[..., 1::2] >> 4
|
|
||||||
else:
|
|
||||||
x[:, ::2] = x[:, ::2] & 0xF
|
|
||||||
x[:, 1::2] = x[:, 1::2] >> 4
|
|
||||||
|
|
||||||
mxfp4_list = [
|
|
||||||
0.0,
|
|
||||||
0.5,
|
|
||||||
1.0,
|
|
||||||
1.5,
|
|
||||||
2.0,
|
|
||||||
3.0,
|
|
||||||
4.0,
|
|
||||||
6.0,
|
|
||||||
-0.0,
|
|
||||||
-0.5,
|
|
||||||
-1.0,
|
|
||||||
-1.5,
|
|
||||||
-2.0,
|
|
||||||
-3.0,
|
|
||||||
-4.0,
|
|
||||||
-6.0,
|
|
||||||
]
|
|
||||||
mxfp4_in_f32 = torch.tensor(mxfp4_list, dtype=torch.float32, device="cuda")
|
|
||||||
return mxfp4_in_f32[x.long()]
|
|
||||||
|
|
||||||
|
|
||||||
def e8m0_to_f32(x):
|
|
||||||
# Convert the input tensor `x` (assumed to be in e8m0 format) to float32.
|
|
||||||
# e8m0 is a custom 8-bit floating point format with 8 bits for exponent, 0 for mantissa.
|
|
||||||
# This means the value is essentially 2^(exponent - 127), similar to how IEEE-754 stores floats.
|
|
||||||
|
|
||||||
# Convert x to float32 for computation, and compute the power of 2 by subtracting the bias (127).
|
|
||||||
x_f32 = 2 ** ((x.to(torch.float32)) - 127)
|
|
||||||
|
|
||||||
# If the exponent value was 255 (i.e., 2^(128)), this is a special case usually used to represent NaN or Inf.
|
|
||||||
# Since this custom format has no mantissa, treat 2^128 as NaN.
|
|
||||||
x_f32[x_f32 == 128] = float("nan")
|
|
||||||
return x_f32
|
|
||||||
|
|
||||||
|
|
||||||
def quark_post_load_weights(self_attn: nn.Module, w: torch.Tensor, quant_format: str):
|
|
||||||
if "mxfp4" in quant_format:
|
|
||||||
# when dtype is bf16, the processing flow is to dynamic quantize bf16 tensor to uint8 tensor
|
|
||||||
# do w_kc (bf16) first to get the w_kc(uint8) w_s_kc(uint8)
|
|
||||||
# and w_vc repeating the same procedure of w_kc to get w_vc(uint8) w_s_vc(uint8)
|
|
||||||
if w.dtype == torch.bfloat16:
|
|
||||||
w_kc, w_vc = w.unflatten(
|
|
||||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
|
||||||
w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
|
|
||||||
w_kc = w_kc.transpose(-2, -1)
|
|
||||||
w_s_kc = w_s_kc.transpose(-2, -1)
|
|
||||||
w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
|
|
||||||
w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
|
||||||
w_s_vc = w_s_vc.contiguous().transpose(1, 2)
|
|
||||||
elif w.dtype == torch.uint8: # static quant for mxfp4
|
|
||||||
# when dtype is uint8, it means the w has been quantized to mxfp4 format
|
|
||||||
# but we must separate it to w_kc and w_vc.
|
|
||||||
# The quantized tensor size is only half of original tensor size
|
|
||||||
# and the scaling factor is 1/32, the transpose behavior will be not correct
|
|
||||||
# need to upcast it to fp32 to separate w to w_kc and w_vc
|
|
||||||
# to ensure the following transpose behavior is correct
|
|
||||||
# and then do mxfp4 quant again
|
|
||||||
w = mxfp4_to_f32(w, True).to(torch.bfloat16)
|
|
||||||
w_scales = self_attn.kv_b_proj.weight_scale.repeat_interleave(32, dim=-1)
|
|
||||||
w_scales = e8m0_to_f32(w_scales).to(torch.bfloat16)
|
|
||||||
w = w * w_scales
|
|
||||||
w_kc, w_vc = w.unflatten(
|
|
||||||
0, (-1, (self_attn.qk_nope_head_dim + self_attn.v_head_dim))
|
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
|
||||||
w_kc, w_s_kc = b_dynamic_mxfp4_quant(w_kc.transpose(-2, -1))
|
|
||||||
w_kc = w_kc.transpose(-2, -1)
|
|
||||||
w_s_kc = w_s_kc.transpose(-2, -1)
|
|
||||||
w_vc, w_s_vc = b_dynamic_mxfp4_quant(w_vc)
|
|
||||||
w_s_kc = w_s_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
|
||||||
w_s_vc = w_s_vc.contiguous().transpose(1, 2)
|
|
||||||
|
|
||||||
return w_kc, w_s_kc, w_vc, w_s_vc
|
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
from aiter.ops.triton.batched_gemm_afp4wfp4_pre_quant import (
|
|
||||||
batched_gemm_afp4wfp4_pre_quant,
|
|
||||||
)
|
|
||||||
from aiter.ops.triton.fused_mxfp4_quant import (
|
|
||||||
fused_flatten_mxfp4_quant,
|
|
||||||
fused_rms_mxfp4_quant,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"fused_rms_mxfp4_quant",
|
|
||||||
"fused_flatten_mxfp4_quant",
|
|
||||||
"batched_gemm_afp4wfp4_pre_quant",
|
|
||||||
]
|
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
import torch
|
|
||||||
from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
|
|
||||||
from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
|
|
||||||
from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
|
|
||||||
|
|
||||||
from sglang.srt.utils import BumpAllocator
|
|
||||||
|
|
||||||
__all__ = ["fused_qk_rope_cat"]
|
|
||||||
|
|
||||||
|
|
||||||
def aiter_dsv3_router_gemm(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
weight: torch.Tensor,
|
|
||||||
gemm_output_zero_allocator: BumpAllocator = None,
|
|
||||||
):
|
|
||||||
M = hidden_states.shape[0]
|
|
||||||
N = weight.shape[0]
|
|
||||||
y = None
|
|
||||||
|
|
||||||
if M <= 256:
|
|
||||||
# TODO (cagri): convert to bfloat16 as part of another kernel to save time
|
|
||||||
# for now it is also coupled with zero allocator.
|
|
||||||
if gemm_output_zero_allocator != None:
|
|
||||||
y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
|
|
||||||
else:
|
|
||||||
y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
|
|
||||||
|
|
||||||
if y is not None:
|
|
||||||
logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
|
|
||||||
else:
|
|
||||||
logits = gemm_a16w16(hidden_states, weight)
|
|
||||||
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
def get_dsv3_gemm_output_zero_allocator_size(
|
|
||||||
n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
|
|
||||||
):
|
|
||||||
if embedding_dim != 7168 or n_routed_experts != 256:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
per_layer_size = 256 * (allocate_size + n_routed_experts)
|
|
||||||
|
|
||||||
return num_moe_layers * per_layer_size
|
|
||||||
@@ -112,7 +112,6 @@ from sglang.srt.utils import (
|
|||||||
is_cpu,
|
is_cpu,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
is_gfx95_supported,
|
|
||||||
is_hip,
|
is_hip,
|
||||||
is_non_idle_and_non_empty,
|
is_non_idle_and_non_empty,
|
||||||
is_npu,
|
is_npu,
|
||||||
@@ -130,22 +129,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
_is_cpu_amx_available = cpu_has_amx_support()
|
||||||
_is_cpu = is_cpu()
|
_is_cpu = is_cpu()
|
||||||
_device_sm = get_device_sm()
|
_device_sm = get_device_sm()
|
||||||
_is_gfx95_supported = is_gfx95_supported()
|
|
||||||
|
|
||||||
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
|
|
||||||
|
|
||||||
if _use_aiter_gfx95:
|
|
||||||
from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
|
|
||||||
from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
|
|
||||||
batched_gemm_afp4wfp4_pre_quant,
|
|
||||||
fused_flatten_mxfp4_quant,
|
|
||||||
fused_rms_mxfp4_quant,
|
|
||||||
)
|
|
||||||
from sglang.srt.layers.rocm_linear_utils import (
|
|
||||||
aiter_dsv3_router_gemm,
|
|
||||||
fused_qk_rope_cat,
|
|
||||||
get_dsv3_gemm_output_zero_allocator_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import (
|
from sgl_kernel import (
|
||||||
@@ -241,17 +224,10 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
forward_batch=None,
|
forward_batch=None,
|
||||||
should_allreduce_fusion: bool = False,
|
should_allreduce_fusion: bool = False,
|
||||||
use_reduce_scatter: bool = False,
|
use_reduce_scatter: bool = False,
|
||||||
gemm_output_zero_allocator: BumpAllocator = None,
|
|
||||||
):
|
):
|
||||||
if (self.tp_size == 1) and x.shape[0] == 0:
|
if (self.tp_size == 1) and x.shape[0] == 0:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
if gemm_output_zero_allocator != None and x.shape[0] <= 256:
|
|
||||||
y = gemm_output_zero_allocator.allocate(
|
|
||||||
x.shape[0] * self.gate_up_proj.output_size_per_partition
|
|
||||||
).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
|
|
||||||
x = (x, None, y)
|
|
||||||
|
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(
|
x, _ = self.down_proj(
|
||||||
@@ -281,7 +257,7 @@ class MoEGate(nn.Module):
|
|||||||
if _is_cpu and _is_cpu_amx_available:
|
if _is_cpu and _is_cpu_amx_available:
|
||||||
self.quant_method = PackWeightMethod(weight_names=["weight"])
|
self.quant_method = PackWeightMethod(weight_names=["weight"])
|
||||||
|
|
||||||
def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
|
def forward(self, hidden_states):
|
||||||
if use_intel_amx_backend(self):
|
if use_intel_amx_backend(self):
|
||||||
return torch.ops.sgl_kernel.weight_packed_linear(
|
return torch.ops.sgl_kernel.weight_packed_linear(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -300,10 +276,6 @@ class MoEGate(nn.Module):
|
|||||||
):
|
):
|
||||||
# router gemm output float32
|
# router gemm output float32
|
||||||
logits = dsv3_router_gemm(hidden_states, self.weight)
|
logits = dsv3_router_gemm(hidden_states, self.weight)
|
||||||
elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
|
|
||||||
logits = aiter_dsv3_router_gemm(
|
|
||||||
hidden_states, self.weight, gemm_output_zero_allocator
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logits = F.linear(hidden_states, self.weight, None)
|
logits = F.linear(hidden_states, self.weight, None)
|
||||||
|
|
||||||
@@ -467,7 +439,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
forward_batch: Optional[ForwardBatch] = None,
|
forward_batch: Optional[ForwardBatch] = None,
|
||||||
should_allreduce_fusion: bool = False,
|
should_allreduce_fusion: bool = False,
|
||||||
use_reduce_scatter: bool = False,
|
use_reduce_scatter: bool = False,
|
||||||
gemm_output_zero_allocator: BumpAllocator = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not self._enable_deepep_moe:
|
if not self._enable_deepep_moe:
|
||||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||||
@@ -481,14 +452,12 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
should_allreduce_fusion,
|
should_allreduce_fusion,
|
||||||
use_reduce_scatter,
|
use_reduce_scatter,
|
||||||
gemm_output_zero_allocator,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_normal(
|
return self.forward_normal(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
should_allreduce_fusion,
|
should_allreduce_fusion,
|
||||||
use_reduce_scatter,
|
use_reduce_scatter,
|
||||||
gemm_output_zero_allocator,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_deepep(hidden_states, forward_batch)
|
return self.forward_deepep(hidden_states, forward_batch)
|
||||||
@@ -498,7 +467,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
should_allreduce_fusion: bool = False,
|
should_allreduce_fusion: bool = False,
|
||||||
use_reduce_scatter: bool = False,
|
use_reduce_scatter: bool = False,
|
||||||
gemm_output_zero_allocator: BumpAllocator = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
current_stream = torch.cuda.current_stream()
|
current_stream = torch.cuda.current_stream()
|
||||||
@@ -507,7 +475,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
with torch.cuda.stream(self.alt_stream):
|
with torch.cuda.stream(self.alt_stream):
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
|
router_logits = self.gate(hidden_states)
|
||||||
topk_output = self.topk(hidden_states, router_logits)
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
final_hidden_states = self.experts(hidden_states, topk_output)
|
final_hidden_states = self.experts(hidden_states, topk_output)
|
||||||
if not _is_cuda:
|
if not _is_cuda:
|
||||||
@@ -534,7 +502,6 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
should_allreduce_fusion: bool = False,
|
should_allreduce_fusion: bool = False,
|
||||||
use_reduce_scatter: bool = False,
|
use_reduce_scatter: bool = False,
|
||||||
gemm_output_zero_allocator: BumpAllocator = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||||
self.shared_experts.gate_up_proj
|
self.shared_experts.gate_up_proj
|
||||||
@@ -544,7 +511,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if hidden_states.shape[0] > 0:
|
if hidden_states.shape[0] > 0:
|
||||||
shared_output = self._forward_shared_experts(hidden_states)
|
shared_output = self._forward_shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
|
router_logits = self.gate(hidden_states)
|
||||||
topk_output = self.topk(hidden_states, router_logits)
|
topk_output = self.topk(hidden_states, router_logits)
|
||||||
else:
|
else:
|
||||||
shared_output = None
|
shared_output = None
|
||||||
@@ -1130,19 +1097,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
if self.attn_mha.kv_b_proj is None:
|
if self.attn_mha.kv_b_proj is None:
|
||||||
self.attn_mha.kv_b_proj = self.kv_b_proj
|
self.attn_mha.kv_b_proj = self.kv_b_proj
|
||||||
|
|
||||||
# when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
|
if hidden_states.shape[0] == 0:
|
||||||
if isinstance(hidden_states, tuple):
|
assert (
|
||||||
if hidden_states[0].shape[0] == 0:
|
not self.o_proj.reduce_results
|
||||||
assert (
|
), "short-circuiting allreduce will lead to hangs"
|
||||||
not self.o_proj.reduce_results
|
return hidden_states, None, forward_batch, None
|
||||||
), "short-circuiting allreduce will lead to hangs"
|
|
||||||
return hidden_states[0]
|
|
||||||
else:
|
|
||||||
if hidden_states.shape[0] == 0:
|
|
||||||
assert (
|
|
||||||
not self.o_proj.reduce_results
|
|
||||||
), "short-circuiting allreduce will lead to hangs"
|
|
||||||
return hidden_states, None, forward_batch, None
|
|
||||||
|
|
||||||
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
|
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
|
||||||
|
|
||||||
@@ -1266,11 +1225,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
||||||
|
|
||||||
if self.q_lora_rank is not None:
|
if self.q_lora_rank is not None:
|
||||||
if (
|
if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
|
||||||
(not isinstance(hidden_states, tuple))
|
|
||||||
and hidden_states.shape[0] <= 16
|
|
||||||
and self.use_min_latency_fused_a_gemm
|
|
||||||
):
|
|
||||||
fused_qkv_a_proj_out = dsv3_fused_a_gemm(
|
fused_qkv_a_proj_out = dsv3_fused_a_gemm(
|
||||||
hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
|
hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
|
||||||
)
|
)
|
||||||
@@ -1290,18 +1245,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
k_nope = self.kv_a_layernorm(k_nope)
|
k_nope = self.kv_a_layernorm(k_nope)
|
||||||
current_stream.wait_stream(self.alt_stream)
|
current_stream.wait_stream(self.alt_stream)
|
||||||
else:
|
else:
|
||||||
if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
|
q = self.q_a_layernorm(q)
|
||||||
q, k_nope = fused_rms_mxfp4_quant(
|
k_nope = self.kv_a_layernorm(k_nope)
|
||||||
q,
|
|
||||||
self.q_a_layernorm.weight,
|
|
||||||
self.q_a_layernorm.variance_epsilon,
|
|
||||||
k_nope,
|
|
||||||
self.kv_a_layernorm.weight,
|
|
||||||
self.kv_a_layernorm.variance_epsilon,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
q = self.q_a_layernorm(q)
|
|
||||||
k_nope = self.kv_a_layernorm(k_nope)
|
|
||||||
|
|
||||||
k_nope = k_nope.unsqueeze(1)
|
k_nope = k_nope.unsqueeze(1)
|
||||||
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
|
||||||
@@ -1333,27 +1278,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
q_nope_out = q_nope_out[:, :expected_m, :]
|
q_nope_out = q_nope_out[:, :expected_m, :]
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||||
if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
|
q_nope_out = torch.bmm(
|
||||||
x = q_nope.transpose(0, 1)
|
q_nope.to(torch.bfloat16).transpose(0, 1),
|
||||||
q_nope_out = torch.empty(
|
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
||||||
x.shape[0],
|
)
|
||||||
x.shape[1],
|
|
||||||
self.w_kc.shape[2],
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
)
|
|
||||||
batched_gemm_afp4wfp4_pre_quant(
|
|
||||||
x,
|
|
||||||
self.w_kc.transpose(-2, -1),
|
|
||||||
self.w_scale_k.transpose(-2, -1),
|
|
||||||
torch.bfloat16,
|
|
||||||
q_nope_out,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
q_nope_out = torch.bmm(
|
|
||||||
q_nope.to(torch.bfloat16).transpose(0, 1),
|
|
||||||
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
|
||||||
)
|
|
||||||
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
elif self.w_kc.dtype == torch.float8_e4m3fn:
|
||||||
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
|
||||||
q_nope.transpose(0, 1),
|
q_nope.transpose(0, 1),
|
||||||
@@ -1367,15 +1295,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
|
|
||||||
q_nope_out = q_nope_out.transpose(0, 1)
|
q_nope_out = q_nope_out.transpose(0, 1)
|
||||||
|
|
||||||
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
|
if not self._fuse_rope_for_trtllm_mla(forward_batch):
|
||||||
not _use_aiter or not _is_gfx95_supported
|
|
||||||
):
|
|
||||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||||
|
|
||||||
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||||
|
|
||||||
def forward_absorb_core(
|
def forward_absorb_core(
|
||||||
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
|
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
self.current_attention_backend == "fa3"
|
self.current_attention_backend == "fa3"
|
||||||
@@ -1400,23 +1326,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
**extra_args,
|
**extra_args,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if _use_aiter_gfx95:
|
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
||||||
cos = self.rotary_emb.cos_cache
|
k = torch.cat([k_nope, k_pe], dim=-1)
|
||||||
sin = self.rotary_emb.sin_cache
|
|
||||||
q, k = fused_qk_rope_cat(
|
|
||||||
q_nope_out,
|
|
||||||
q_pe,
|
|
||||||
k_nope,
|
|
||||||
k_pe,
|
|
||||||
positions,
|
|
||||||
cos,
|
|
||||||
sin,
|
|
||||||
self.rotary_emb.is_neox_style,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
q = torch.cat([q_nope_out, q_pe], dim=-1)
|
|
||||||
k = torch.cat([k_nope, k_pe], dim=-1)
|
|
||||||
|
|
||||||
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
|
||||||
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
@@ -1441,34 +1352,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
# TODO(haishaw): add bmm_fp8 to ROCm
|
# TODO(haishaw): add bmm_fp8 to ROCm
|
||||||
if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
|
attn_bmm_output = torch.bmm(
|
||||||
x = attn_output.transpose(0, 1)
|
attn_output.to(torch.bfloat16).transpose(0, 1),
|
||||||
attn_bmm_output = torch.empty(
|
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
||||||
x.shape[0],
|
)
|
||||||
x.shape[1],
|
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
||||||
self.w_vc.shape[2],
|
|
||||||
device=x.device,
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
)
|
|
||||||
batched_gemm_afp4wfp4_pre_quant(
|
|
||||||
x,
|
|
||||||
self.w_vc.transpose(-2, -1),
|
|
||||||
self.w_scale_v.transpose(-2, -1),
|
|
||||||
torch.bfloat16,
|
|
||||||
attn_bmm_output,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attn_bmm_output = torch.bmm(
|
|
||||||
attn_output.to(torch.bfloat16).transpose(0, 1),
|
|
||||||
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.o_proj.weight.dtype == torch.uint8:
|
|
||||||
attn_bmm_output = attn_bmm_output.transpose(0, 1)
|
|
||||||
attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
|
|
||||||
else:
|
|
||||||
attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
|
|
||||||
|
|
||||||
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
elif self.w_vc.dtype == torch.float8_e4m3fn:
|
||||||
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
|
||||||
attn_output.transpose(0, 1),
|
attn_output.transpose(0, 1),
|
||||||
@@ -1976,21 +1864,10 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
residual: Optional[torch.Tensor],
|
residual: Optional[torch.Tensor],
|
||||||
zero_allocator: BumpAllocator,
|
zero_allocator: BumpAllocator,
|
||||||
gemm_output_zero_allocator: BumpAllocator = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
quant_format = (
|
|
||||||
"mxfp4"
|
|
||||||
if _is_gfx95_supported
|
|
||||||
and self.self_attn.fused_qkv_a_proj_with_mqa.weight == torch.uint8
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states, residual = self.layer_communicator.prepare_attn(
|
hidden_states, residual = self.layer_communicator.prepare_attn(
|
||||||
hidden_states,
|
hidden_states, residual, forward_batch
|
||||||
residual,
|
|
||||||
forward_batch,
|
|
||||||
quant_format,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
@@ -2159,37 +2036,6 @@ class DeepseekV2Model(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.norm = PPMissingLayer(return_tuple=True)
|
self.norm = PPMissingLayer(return_tuple=True)
|
||||||
|
|
||||||
self.gemm_output_zero_allocator_size = 0
|
|
||||||
if (
|
|
||||||
_use_aiter_gfx95
|
|
||||||
and config.n_routed_experts == 256
|
|
||||||
and self.embed_tokens.embedding_dim == 7168
|
|
||||||
):
|
|
||||||
num_moe_layers = sum(
|
|
||||||
[
|
|
||||||
1
|
|
||||||
for i in range(len(self.layers))
|
|
||||||
if isinstance(self.layers[i].mlp, DeepseekV2MoE)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
allocate_size = 0
|
|
||||||
for i in range(len(self.layers)):
|
|
||||||
if isinstance(self.layers[i].mlp, DeepseekV2MoE):
|
|
||||||
allocate_size = self.layers[
|
|
||||||
i
|
|
||||||
].mlp.shared_experts.gate_up_proj.output_size_per_partition
|
|
||||||
break
|
|
||||||
|
|
||||||
self.gemm_output_zero_allocator_size = (
|
|
||||||
get_dsv3_gemm_output_zero_allocator_size(
|
|
||||||
config.n_routed_experts,
|
|
||||||
num_moe_layers,
|
|
||||||
allocate_size,
|
|
||||||
self.embed_tokens.embedding_dim,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_input_embeddings(self) -> torch.Tensor:
|
def get_input_embeddings(self) -> torch.Tensor:
|
||||||
return self.embed_tokens
|
return self.embed_tokens
|
||||||
|
|
||||||
@@ -2209,16 +2055,6 @@ class DeepseekV2Model(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
gemm_output_zero_allocator = (
|
|
||||||
BumpAllocator(
|
|
||||||
buffer_size=self.gemm_output_zero_allocator_size,
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
if self.gemm_output_zero_allocator_size > 0
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.pp_group.is_first_rank:
|
if self.pp_group.is_first_rank:
|
||||||
if input_embeds is None:
|
if input_embeds is None:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
@@ -2245,12 +2081,7 @@ class DeepseekV2Model(nn.Module):
|
|||||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions,
|
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||||
hidden_states,
|
|
||||||
forward_batch,
|
|
||||||
residual,
|
|
||||||
zero_allocator,
|
|
||||||
gemm_output_zero_allocator,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if normal_end_layer != self.end_layer:
|
if normal_end_layer != self.end_layer:
|
||||||
@@ -2523,12 +2354,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
w_kc, w_vc = w.unflatten(
|
w_kc, w_vc = w.unflatten(
|
||||||
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||||
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||||
|
|
||||||
if _use_aiter_gfx95 and self.quant_config.get_name() == "quark":
|
|
||||||
w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
|
|
||||||
quark_post_load_weights(self_attn, w, "mxfp4")
|
|
||||||
)
|
|
||||||
|
|
||||||
if not use_deep_gemm_bmm:
|
if not use_deep_gemm_bmm:
|
||||||
self_attn.w_kc = bind_or_assign(
|
self_attn.w_kc = bind_or_assign(
|
||||||
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
||||||
|
|||||||
@@ -2900,18 +2900,6 @@ def mxfp_supported():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def is_gfx95_supported():
|
|
||||||
"""
|
|
||||||
Returns whether the current platform supports MX types.
|
|
||||||
"""
|
|
||||||
if torch.version.hip:
|
|
||||||
gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
|
|
||||||
return any(gfx in gcn_arch for gfx in ["gfx95"])
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# LoRA-related constants and utilities
|
# LoRA-related constants and utilities
|
||||||
SUPPORTED_LORA_TARGET_MODULES = [
|
SUPPORTED_LORA_TARGET_MODULES = [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
|
|||||||
Reference in New Issue
Block a user