[code style] Clean dead triton kernel code in fused_moe and useless vllm_ops import (#8310)
This commit is contained in:
@@ -53,9 +53,7 @@ elif _is_hip:
|
|||||||
from aiter import moe_sum
|
from aiter import moe_sum
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
||||||
else:
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
|
||||||
|
|
||||||
if _is_cuda or _is_hip:
|
if _is_cuda or _is_hip:
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
|
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
|
||||||
enable_moe_align_block_size_triton = bool(
|
|
||||||
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@@ -533,190 +528,6 @@ def fused_moe_kernel(
|
|||||||
tl.store(c_ptrs, accumulator, mask=c_mask)
|
tl.store(c_ptrs, accumulator, mask=c_mask)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def moe_align_block_size_stage1(
|
|
||||||
topk_ids_ptr,
|
|
||||||
tokens_cnts_ptr,
|
|
||||||
num_experts: tl.constexpr,
|
|
||||||
numel: tl.constexpr,
|
|
||||||
tokens_per_thread: tl.constexpr,
|
|
||||||
):
|
|
||||||
pid = tl.program_id(0)
|
|
||||||
|
|
||||||
start_idx = pid * tokens_per_thread
|
|
||||||
|
|
||||||
off_c = (pid + 1) * num_experts
|
|
||||||
|
|
||||||
for i in range(tokens_per_thread):
|
|
||||||
if start_idx + i < numel:
|
|
||||||
idx = tl.load(topk_ids_ptr + start_idx + i)
|
|
||||||
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
|
||||||
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def moe_align_block_size_stage2(
|
|
||||||
tokens_cnts_ptr,
|
|
||||||
num_experts: tl.constexpr,
|
|
||||||
):
|
|
||||||
pid = tl.program_id(0)
|
|
||||||
|
|
||||||
last_cnt = 0
|
|
||||||
for i in range(1, num_experts + 1):
|
|
||||||
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
|
||||||
last_cnt = last_cnt + token_cnt
|
|
||||||
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def moe_align_block_size_stage3(
|
|
||||||
total_tokens_post_pad_ptr,
|
|
||||||
tokens_cnts_ptr,
|
|
||||||
cumsum_ptr,
|
|
||||||
num_experts: tl.constexpr,
|
|
||||||
block_size: tl.constexpr,
|
|
||||||
):
|
|
||||||
last_cumsum = 0
|
|
||||||
off_cnt = num_experts * num_experts
|
|
||||||
for i in range(1, num_experts + 1):
|
|
||||||
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
|
||||||
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
|
||||||
tl.store(cumsum_ptr + i, last_cumsum)
|
|
||||||
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def moe_align_block_size_stage4(
|
|
||||||
topk_ids_ptr,
|
|
||||||
sorted_token_ids_ptr,
|
|
||||||
expert_ids_ptr,
|
|
||||||
tokens_cnts_ptr,
|
|
||||||
cumsum_ptr,
|
|
||||||
num_experts: tl.constexpr,
|
|
||||||
block_size: tl.constexpr,
|
|
||||||
numel: tl.constexpr,
|
|
||||||
tokens_per_thread: tl.constexpr,
|
|
||||||
):
|
|
||||||
pid = tl.program_id(0)
|
|
||||||
start_idx = tl.load(cumsum_ptr + pid)
|
|
||||||
end_idx = tl.load(cumsum_ptr + pid + 1)
|
|
||||||
|
|
||||||
for i in range(start_idx, end_idx, block_size):
|
|
||||||
tl.store(expert_ids_ptr + i // block_size, pid)
|
|
||||||
|
|
||||||
start_idx = pid * tokens_per_thread
|
|
||||||
off_t = pid * num_experts
|
|
||||||
|
|
||||||
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
|
|
||||||
expert_id = tl.load(topk_ids_ptr + i)
|
|
||||||
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
|
||||||
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
|
||||||
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
|
||||||
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size_triton(
|
|
||||||
topk_ids: torch.Tensor,
|
|
||||||
num_experts: int,
|
|
||||||
block_size: int,
|
|
||||||
sorted_token_ids: torch.Tensor,
|
|
||||||
expert_ids: torch.Tensor,
|
|
||||||
num_tokens_post_pad: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
numel = topk_ids.numel()
|
|
||||||
grid = (num_experts,)
|
|
||||||
tokens_cnts = torch.zeros(
|
|
||||||
(num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
|
|
||||||
)
|
|
||||||
cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
|
|
||||||
tokens_per_thread = ceil_div(numel, num_experts)
|
|
||||||
|
|
||||||
moe_align_block_size_stage1[grid](
|
|
||||||
topk_ids,
|
|
||||||
tokens_cnts,
|
|
||||||
num_experts,
|
|
||||||
numel,
|
|
||||||
tokens_per_thread,
|
|
||||||
)
|
|
||||||
moe_align_block_size_stage2[grid](
|
|
||||||
tokens_cnts,
|
|
||||||
num_experts,
|
|
||||||
)
|
|
||||||
moe_align_block_size_stage3[(1,)](
|
|
||||||
num_tokens_post_pad,
|
|
||||||
tokens_cnts,
|
|
||||||
cumsum,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
)
|
|
||||||
moe_align_block_size_stage4[grid](
|
|
||||||
topk_ids,
|
|
||||||
sorted_token_ids,
|
|
||||||
expert_ids,
|
|
||||||
tokens_cnts,
|
|
||||||
cumsum,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
numel,
|
|
||||||
tokens_per_thread,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def init_sorted_ids_and_cumsum_buffer_kernel(
|
|
||||||
sorted_ids_ptr,
|
|
||||||
cumsum_buffer_ptr,
|
|
||||||
max_num_tokens_padded,
|
|
||||||
topk_ids_numel,
|
|
||||||
num_experts: tl.constexpr,
|
|
||||||
BLOCK_SIZE: tl.constexpr,
|
|
||||||
ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
|
|
||||||
):
|
|
||||||
pid = tl.program_id(0)
|
|
||||||
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
|
||||||
|
|
||||||
sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
|
|
||||||
|
|
||||||
if pid < sorted_ids_blocks:
|
|
||||||
mask = offsets < max_num_tokens_padded
|
|
||||||
tl.store(
|
|
||||||
sorted_ids_ptr + offsets,
|
|
||||||
tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
elif pid == sorted_ids_blocks:
|
|
||||||
offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
|
|
||||||
mask_e = offset_e < num_experts + 1
|
|
||||||
tl.store(
|
|
||||||
cumsum_buffer_ptr + offset_e,
|
|
||||||
tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
|
|
||||||
mask=mask_e,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def init_sorted_ids_and_cumsum_buffer(
|
|
||||||
max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
|
|
||||||
):
|
|
||||||
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
|
|
||||||
cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
BLOCK_SIZE = 1024
|
|
||||||
sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
|
|
||||||
grid = (sorted_ids_blocks + 1,)
|
|
||||||
|
|
||||||
init_sorted_ids_and_cumsum_buffer_kernel[grid](
|
|
||||||
sorted_ids,
|
|
||||||
cumsum_buffer,
|
|
||||||
max_num_tokens_padded,
|
|
||||||
topk_ids_numel,
|
|
||||||
num_experts,
|
|
||||||
BLOCK_SIZE,
|
|
||||||
next_power_of_2(num_experts + 1),
|
|
||||||
)
|
|
||||||
|
|
||||||
return sorted_ids, cumsum_buffer
|
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
|
def moe_align_block_size(
|
||||||
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
@@ -766,42 +577,32 @@ def moe_align_block_size(
|
|||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
if enable_moe_align_block_size_triton:
|
|
||||||
|
cumsum_buffer = torch.empty(
|
||||||
|
(num_experts + 1,), dtype=torch.int32, device=topk_ids.device
|
||||||
|
)
|
||||||
|
token_cnts_buffer = torch.empty(
|
||||||
|
(num_experts + 1) * num_experts,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=topk_ids.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Threshold based on benchmark results
|
||||||
|
fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
|
||||||
|
if not fuse_sorted_ids_padding:
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
moe_align_block_size_triton(
|
|
||||||
topk_ids,
|
|
||||||
num_experts,
|
|
||||||
block_size,
|
|
||||||
sorted_ids,
|
|
||||||
expert_ids,
|
|
||||||
num_tokens_post_pad,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
cumsum_buffer = torch.empty(
|
|
||||||
(num_experts + 1,), dtype=torch.int32, device=topk_ids.device
|
|
||||||
)
|
|
||||||
token_cnts_buffer = torch.empty(
|
|
||||||
(num_experts + 1) * num_experts,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=topk_ids.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Threshold based on benchmark results
|
sgl_moe_align_block_size(
|
||||||
fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
|
topk_ids,
|
||||||
if not fuse_sorted_ids_padding:
|
num_experts,
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
block_size,
|
||||||
|
sorted_ids,
|
||||||
sgl_moe_align_block_size(
|
expert_ids,
|
||||||
topk_ids,
|
num_tokens_post_pad,
|
||||||
num_experts,
|
token_cnts_buffer,
|
||||||
block_size,
|
cumsum_buffer,
|
||||||
sorted_ids,
|
fuse_sorted_ids_padding,
|
||||||
expert_ids,
|
)
|
||||||
num_tokens_post_pad,
|
|
||||||
token_cnts_buffer,
|
|
||||||
cumsum_buffer,
|
|
||||||
fuse_sorted_ids_padding,
|
|
||||||
)
|
|
||||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,15 +28,6 @@ if TYPE_CHECKING:
|
|||||||
CompressedTensorsConfig,
|
CompressedTensorsConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
|
||||||
_is_npu = is_npu()
|
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
|
||||||
_is_cpu = is_cpu()
|
|
||||||
_is_hip = is_hip()
|
|
||||||
|
|
||||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import vllm
|
import vllm
|
||||||
@@ -568,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
|
||||||
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
|
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
|
||||||
layer.w13_weight_packed,
|
layer.w13_weight_packed,
|
||||||
layer.w13_g_idx_sort_indices,
|
layer.w13_g_idx_sort_indices,
|
||||||
|
|||||||
@@ -17,15 +17,6 @@ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_np
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
|
||||||
_is_cuda = is_cuda()
|
|
||||||
_is_npu = is_npu()
|
|
||||||
_is_cpu_amx_available = cpu_has_amx_support()
|
|
||||||
_is_cpu = is_cpu()
|
|
||||||
_is_hip = is_hip()
|
|
||||||
|
|
||||||
if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
|
|
||||||
from vllm._custom_ops import scaled_fp8_quant
|
|
||||||
|
|
||||||
|
|
||||||
def is_layer_skipped(
|
def is_layer_skipped(
|
||||||
prefix: str,
|
prefix: str,
|
||||||
|
|||||||
Reference in New Issue
Block a user