Files
sglang/sgl-kernel/src/sgl-kernel/ops/utils.py

27 lines
616 B
Python
Raw Normal View History

from typing import Dict, Tuple
import torch
def _get_cuda_stream(device: torch.device) -> int:
    """Return the ``cuda_stream`` handle of the current CUDA stream on *device*.

    Looks up ``torch.cuda.current_stream(device)`` and returns that stream's
    ``cuda_stream`` attribute (the raw stream handle as an int).
    """
    active_stream = torch.cuda.current_stream(device)
    return active_stream.cuda_stream
# Process-wide cache of scratch byte buffers, keyed by (buffer name, device).
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}


def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
    """Return a cached ``torch.uint8`` scratch buffer with at least ``bytes`` elements.

    Buffers are keyed by ``(name, device)`` and reused across calls. If no
    buffer exists for the key, or the cached one is smaller than ``bytes``,
    a new one of exactly ``bytes`` elements is allocated and replaces it in
    the cache. A cached buffer that is already large enough is returned
    unchanged, so the result may be larger than requested.

    The buffer comes from ``torch.empty``, so its contents are uninitialized
    and may still hold data written by an earlier user of the same key.

    Args:
        name: Identifier that distinguishes independent buffers.
        bytes: Minimum number of uint8 elements required. (This name shadows
            the builtin ``bytes``; it is kept so keyword callers keep working.)
        device: Device to allocate the buffer on.

    Returns:
        A 1-D ``torch.uint8`` tensor with ``numel() >= bytes``.
    """
    key = (name, device)
    buf = _cache_buf.get(key)
    # Without the size check, a later request for more bytes under the same
    # key would silently receive the smaller, previously cached buffer.
    if buf is None or buf.numel() < bytes:
        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
        _cache_buf[key] = buf
    return buf
def _to_tensor_scalar_tuple(x):
    """Split ``x`` into a ``(tensor, scalar)`` pair.

    Returns ``(x, 0)`` when ``x`` is a ``torch.Tensor`` and ``(None, x)``
    for anything else, so a single argument can carry either form.
    """
    is_tensor = isinstance(x, torch.Tensor)
    return (x, 0) if is_tensor else (None, x)