adapt to ds3.2
This commit is contained in:
@@ -4,7 +4,7 @@ import time
|
||||
|
||||
import torch
|
||||
|
||||
from sglang import ServerArgs
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
|
||||
@@ -127,21 +127,45 @@ class RMSNorm(CustomOp):
|
||||
return output, residual_out
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
# def forward_hip(
|
||||
# self,
|
||||
# x: torch.Tensor,
|
||||
# residual: Optional[torch.Tensor] = None,
|
||||
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# if not x.is_contiguous():
|
||||
# # NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
# x = x.contiguous()
|
||||
# if residual is not None:
|
||||
# if _vllm_version < Version("0.9"):
|
||||
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
# return x, residual
|
||||
# else:
|
||||
# residual_out = torch.empty_like(x)
|
||||
# output = torch.empty_like(x)
|
||||
# fused_add_rms_norm(
|
||||
# output,
|
||||
# x,
|
||||
# residual_out,
|
||||
# residual,
|
||||
# self.weight.data,
|
||||
# self.variance_epsilon,
|
||||
# )
|
||||
# return output, residual_out
|
||||
# out = torch.empty_like(x)
|
||||
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
# return out
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if not x.is_contiguous():
|
||||
# NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
x = x.contiguous()
|
||||
|
||||
if residual is not None:
|
||||
if _vllm_version < Version("0.9"):
|
||||
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
return x, residual
|
||||
else:
|
||||
residual_out = torch.empty_like(x)
|
||||
try:
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
@@ -151,10 +175,21 @@ class RMSNorm(CustomOp):
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -61,7 +61,7 @@ def inplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -79,6 +79,8 @@ def inplace_fused_experts(
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
if isinstance(activation, int):
|
||||
activation = "silu" if activation == 0 else "gelu"
|
||||
fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -154,7 +156,7 @@ def outplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -173,6 +175,8 @@ def outplace_fused_experts(
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(activation, int):
|
||||
activation = "silu" if activation == 0 else "gelu"
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -263,6 +267,13 @@ def fused_experts(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
):
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
act_id = (
|
||||
0 if (
|
||||
moe_runner_config.activation == 0
|
||||
or (isinstance(moe_runner_config.activation, str)
|
||||
and moe_runner_config.activation.lower() == "silu")
|
||||
) else 1
|
||||
)
|
||||
if moe_runner_config.inplace:
|
||||
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
||||
torch.ops.sglang.inplace_fused_experts(
|
||||
@@ -273,7 +284,7 @@ def fused_experts(
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
moe_runner_config.activation,
|
||||
act_id,
|
||||
moe_runner_config.apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
@@ -301,7 +312,7 @@ def fused_experts(
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
moe_runner_config.activation,
|
||||
act_id,
|
||||
moe_runner_config.apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
@@ -345,7 +356,7 @@ def fused_experts_impl(
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -364,6 +375,9 @@ def fused_experts_impl(
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
):
|
||||
if isinstance(activation, int):
|
||||
activation = "silu" if activation == 0 else "gelu"
|
||||
|
||||
padded_size = padding_size
|
||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||
padded_size = 0
|
||||
|
||||
@@ -516,7 +516,7 @@ class ModelRunner:
|
||||
):
|
||||
server_args.attention_backend = "fa3"
|
||||
elif _is_hip:
|
||||
server_args.attention_backend = "aiter"
|
||||
server_args.attention_backend = "triton"
|
||||
elif _is_npu:
|
||||
server_args.attention_backend = "ascend"
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user