feat: integrate bmm_fp8 kernel into sgl-kernel (#3056)

This commit is contained in:
Yineng Zhang
2025-01-23 00:39:38 +08:00
committed by GitHub
parent b2bd8f444c
commit bf669606eb
6 changed files with 131 additions and 12 deletions

View File

@@ -0,0 +1,19 @@
from typing import Dict, Tuple
import torch
def _get_cuda_stream(device: torch.device) -> int:
    """Return the raw CUDA stream handle (an integer pointer) of the
    current stream associated with *device*, suitable for passing to
    C/C++ kernel launch APIs."""
    current = torch.cuda.current_stream(device)
    return current.cuda_stream
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}
def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
key = (name, device)
buf = _cache_buf.get(key)
if buf is None:
buf = torch.empty(bytes, dtype=torch.uint8, device=device)
_cache_buf[key] = buf
return buf