2025-01-23 00:39:38 +08:00
|
|
|
from typing import Dict, Tuple
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_cuda_stream(device: torch.device) -> int:
    """Return the raw handle of the CUDA stream currently active on ``device``.

    Args:
        device: The CUDA device whose current stream is queried.

    Returns:
        The integer ``cuda_stream`` attribute of
        ``torch.cuda.current_stream(device)``.
    """
    active_stream = torch.cuda.current_stream(device)
    handle = active_stream.cuda_stream
    return handle
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level cache of uint8 scratch buffers, keyed by (name, device).
_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}


def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
    """Return a cached ``uint8`` buffer holding at least ``bytes`` elements.

    Buffers are stored in ``_cache_buf`` under the key ``(name, device)``.

    - The first request for a key allocates a new buffer.
    - A later request that asks for more bytes than the cached buffer holds
      allocates a larger buffer and replaces the cache entry.
    - Otherwise the cached buffer is returned unchanged, so it may be larger
      than ``bytes``.

    Args:
        name: Identifier that distinguishes independent cached buffers.
        bytes: Minimum number of bytes the returned buffer must hold.
            (This name shadows the builtin ``bytes``; it is kept so
            keyword callers keep working.)
        device: Device on which the buffer is allocated.

    Returns:
        A 1-D ``torch.uint8`` tensor with ``numel() >= bytes``.
    """
    key = (name, device)
    buf = _cache_buf.get(key)
    # Grow on demand: handing back a smaller cached buffer would leave a
    # caller that assumes ``bytes`` of space with an undersized workspace.
    if buf is None or buf.numel() < bytes:
        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
        _cache_buf[key] = buf
    return buf
|
2025-01-24 01:54:47 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def _to_tensor_scalar_tuple(x):
    """Split ``x`` into a ``(tensor, scalar)`` pair.

    If ``x`` is a ``torch.Tensor``, the result is ``(x, 0)``. For any other
    value, ``x`` is placed in the scalar slot and the result is ``(None, x)``.
    """
    return (x, 0) if isinstance(x, torch.Tensor) else (None, x)
|