Adapt to ds3.2

This commit is contained in:
maxiao
2025-09-30 17:44:54 +08:00
parent 1237aa19ce
commit 8f7453e3af
9 changed files with 199 additions and 49 deletions

View File

@@ -4,7 +4,7 @@ import time
import torch
from sglang import ServerArgs
from sglang.srt.server_args import ServerArgs
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache

View File

@@ -127,21 +127,45 @@ class RMSNorm(CustomOp):
return output, residual_out
return rms_norm(x, self.weight.data, self.variance_epsilon)
# def forward_hip(
# self,
# x: torch.Tensor,
# residual: Optional[torch.Tensor] = None,
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# if not x.is_contiguous():
# # NOTE: Remove this if aiter kernel supports discontinuous input
# x = x.contiguous()
# if residual is not None:
# if _vllm_version < Version("0.9"):
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
# return x, residual
# else:
# residual_out = torch.empty_like(x)
# output = torch.empty_like(x)
# fused_add_rms_norm(
# output,
# x,
# residual_out,
# residual,
# self.weight.data,
# self.variance_epsilon,
# )
# return output, residual_out
# out = torch.empty_like(x)
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
# return out
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
if not x.is_contiguous():
# NOTE: Remove this if aiter kernel supports discontinuous input
x = x.contiguous()
if residual is not None:
if _vllm_version < Version("0.9"):
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
return x, residual
else:
residual_out = torch.empty_like(x)
try:
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
@@ -151,10 +175,21 @@ class RMSNorm(CustomOp):
self.variance_epsilon,
)
return output, residual_out
except TypeError:
fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out
def forward_native(
self,
x: torch.Tensor,

View File

@@ -61,7 +61,7 @@ def inplace_fused_experts(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
activation: int = 0,#0 silu 1 gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@@ -79,6 +79,8 @@ def inplace_fused_experts(
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> None:
if isinstance(activation, int):
activation = "silu" if activation == 0 else "gelu"
fused_experts_impl(
hidden_states,
w1,
@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
activation: int = 0,#0 silu 1 gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@@ -154,7 +156,7 @@ def outplace_fused_experts(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
activation: int = 0,#0 silu 1 gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@@ -173,6 +175,8 @@ def outplace_fused_experts(
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> torch.Tensor:
if isinstance(activation, int):
activation = "silu" if activation == 0 else "gelu"
return fused_experts_impl(
hidden_states,
w1,
@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
activation: int = 0,#0 silu 1 gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@@ -263,6 +267,13 @@ def fused_experts(
block_shape: Optional[List[int]] = None,
):
topk_weights, topk_ids, _ = topk_output
act_id = (
0 if (
moe_runner_config.activation == 0
or (isinstance(moe_runner_config.activation, str)
and moe_runner_config.activation.lower() == "silu")
) else 1
)
if moe_runner_config.inplace:
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
@@ -273,7 +284,7 @@ def fused_experts(
topk_ids,
b1,
b2,
moe_runner_config.activation,
act_id,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
@@ -301,7 +312,7 @@ def fused_experts(
topk_ids,
b1,
b2,
moe_runner_config.activation,
act_id,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
@@ -345,7 +356,7 @@ def fused_experts_impl(
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
activation: int = 0,#0 silu 1 gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
@@ -364,6 +375,9 @@ def fused_experts_impl(
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
):
if isinstance(activation, int):
activation = "silu" if activation == 0 else "gelu"
padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
padded_size = 0

View File

@@ -516,7 +516,7 @@ class ModelRunner:
):
server_args.attention_backend = "fa3"
elif _is_hip:
server_args.attention_backend = "aiter"
server_args.attention_backend = "triton"
elif _is_npu:
server_args.attention_backend = "ascend"
else: