Minor style fixes for sgl-kernel (#9289)
This commit is contained in:
@@ -31,11 +31,11 @@ from sgl_kernel.elementwise import (
|
||||
rmsnorm,
|
||||
silu_and_mul,
|
||||
)
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
|
||||
if torch.version.hip is not None:
|
||||
from sgl_kernel.elementwise import gelu_quick
|
||||
|
||||
from sgl_kernel.fused_moe import fused_marlin_moe
|
||||
from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
@@ -114,7 +114,3 @@ from sgl_kernel.speculative import (
|
||||
)
|
||||
from sgl_kernel.top_k import fast_topk
|
||||
from sgl_kernel.version import __version__
|
||||
|
||||
# TODO(ying): remove this after updating the sglang python code.
# Placeholder binding: the name is kept importable but is bound to None.
build_tree_kernel = None
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
|
||||
@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downcast_fp8(
    k: torch.Tensor,
    v: torch.Tensor,
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    loc: torch.Tensor,
    mult: int = 1,
    offset: int = 0,
) -> None:
    """Call the ``sgl_kernel.downcast_fp8`` custom op with the given tensors.

    All arguments are forwarded positionally and unchanged, followed by the
    value returned by ``get_cuda_stream()``. This wrapper returns ``None``;
    any result of the op is discarded.

    NOTE(review): judging by the names, the op presumably writes FP8
    versions of ``k``/``v`` into ``k_out``/``v_out`` using ``k_scale``/
    ``v_scale`` at the positions in ``loc`` -- confirm against the kernel
    implementation, including the meaning of ``mult`` and ``offset``.
    """
    kernel = torch.ops.sgl_kernel.downcast_fp8
    forwarded = (k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset)
    # The stream handle is looked up last and passed as the final argument.
    kernel(*forwarded, get_cuda_stream())
|
||||
|
||||
@@ -160,7 +160,7 @@ def fused_marlin_moe(
|
||||
size_m=M,
|
||||
size_n=2 * N,
|
||||
size_k=K,
|
||||
is_full_k=is_k_full,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
@@ -192,7 +192,7 @@ def fused_marlin_moe(
|
||||
size_m=M * topk,
|
||||
size_n=K,
|
||||
size_k=N,
|
||||
is_full_k=is_k_full,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=True,
|
||||
is_zp_float=False,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _to_tensor_scalar_tuple
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# ==============================================================================
|
||||
|
||||
import functools
|
||||
import subprocess
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user