Minor style fixes for sgl-kernel (#9289)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
|
||||
@@ -345,3 +345,19 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downcast_fp8(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
k_out: torch.Tensor,
|
||||
v_out: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
loc: torch.Tensor,
|
||||
mult: int = 1,
|
||||
offset: int = 0,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.downcast_fp8(
|
||||
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user