Minor style fixes for sgl-kernel (#9289)

This commit is contained in:
Lianmin Zheng
2025-08-18 09:38:35 -07:00
committed by GitHub
parent 6e316588f8
commit c480a3f6ea
17 changed files with 439 additions and 109 deletions

View File

@@ -31,11 +31,11 @@ from sgl_kernel.elementwise import (
rmsnorm,
silu_and_mul,
)
from sgl_kernel.fused_moe import fused_marlin_moe
if torch.version.hip is not None:
from sgl_kernel.elementwise import gelu_quick
from sgl_kernel.fused_moe import fused_marlin_moe
from sgl_kernel.gemm import (
awq_dequantize,
bmm_fp8,
@@ -114,7 +114,3 @@ from sgl_kernel.speculative import (
)
from sgl_kernel.top_k import fast_topk
from sgl_kernel.version import __version__
build_tree_kernel = (
None # TODO(ying): remove this after updating the sglang python code.
)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
else None
),
)
def downcast_fp8(
    k: torch.Tensor,
    v: torch.Tensor,
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    loc: torch.Tensor,
    mult: int = 1,
    offset: int = 0,
) -> None:
    """Dispatch the ``sgl_kernel.downcast_fp8`` custom op on the current CUDA stream.

    Thin Python wrapper: forwards the key/value tensors, their scale tensors,
    and the location index tensor straight to the registered kernel, writing
    results into ``k_out``/``v_out`` (in-place semantics are implied by the
    ``*_out`` parameters — confirm against the kernel implementation).

    Args:
        k, v: source tensors (presumably KV-cache entries to be downcast to
            FP8 — TODO confirm dtype expectations with the kernel).
        k_out, v_out: destination tensors receiving the downcast values.
        k_scale, v_scale: per-tensor scale factors used by the kernel.
        loc: index tensor selecting destination positions.
        mult: integer multiplier forwarded to the kernel (default 1).
        offset: integer offset forwarded to the kernel (default 0).
    """
    # Bind the op and stream handle once, then make the single dispatch call.
    kernel = torch.ops.sgl_kernel.downcast_fp8
    stream = get_cuda_stream()
    kernel(k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, stream)

View File

@@ -160,7 +160,7 @@ def fused_marlin_moe(
size_m=M,
size_n=2 * N,
size_k=K,
is_full_k=is_k_full,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,
@@ -192,7 +192,7 @@ def fused_marlin_moe(
size_m=M * topk,
size_n=K,
size_k=N,
is_full_k=is_k_full,
is_k_full=is_k_full,
use_atomic_add=use_atomic_add,
use_fp32_reduce=True,
is_zp_float=False,

View File

@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Optional, Union
import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple

View File

@@ -14,7 +14,6 @@
# ==============================================================================
import functools
import subprocess
from typing import Dict, Tuple
import torch