Minor style fixes for sgl-kernel (#9289)

This commit is contained in:
Lianmin Zheng
2025-08-18 09:38:35 -07:00
committed by GitHub
parent 6e316588f8
commit c480a3f6ea
17 changed files with 439 additions and 109 deletions

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Optional
from typing import Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
else None
),
)
def downcast_fp8(
    k: torch.Tensor,
    v: torch.Tensor,
    k_out: torch.Tensor,
    v_out: torch.Tensor,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
    loc: torch.Tensor,
    mult: int = 1,
    offset: int = 0,
) -> None:
    """Invoke the ``sgl_kernel.downcast_fp8`` custom op.

    This is a thin pass-through: every argument is forwarded positionally,
    unchanged and in declaration order, followed by the value returned by
    ``get_cuda_stream()``. Nothing is returned.

    NOTE(review): the semantics below are inferred from parameter names only
    and should be confirmed against the kernel implementation.

    Args:
        k: Presumably the source key tensor.
        v: Presumably the source value tensor.
        k_out: Presumably the destination for the downcast keys.
        v_out: Presumably the destination for the downcast values.
        k_scale: Presumably the scale applied to ``k``.
        v_scale: Presumably the scale applied to ``v``.
        loc: Presumably the destination indices into the output tensors.
        mult: Integer forwarded to the op as-is (defaults to 1); likely a
            multiplier used together with ``loc`` -- TODO confirm.
        offset: Integer forwarded to the op as-is (defaults to 0); likely
            an offset used together with ``loc`` -- TODO confirm.
    """
    # Resolve the op first, then build the argument list; this keeps the
    # lookup-before-stream-query order of a direct call expression.
    kernel = torch.ops.sgl_kernel.downcast_fp8
    kernel(
        k,
        v,
        k_out,
        v_out,
        k_scale,
        v_scale,
        loc,
        mult,
        offset,
        get_cuda_stream(),
    )