221 lines
7.0 KiB
Python
221 lines
7.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from collections.abc import Iterable, Sequence
|
|
from functools import partial
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.utils.platform_utils import is_uva_available
|
|
from vllm.utils.torch_utils import (
|
|
async_tensor_h2d,
|
|
get_accelerator_view_from_cpu_tensor,
|
|
)
|
|
|
|
|
|
def async_copy_to_gpu(
|
|
x: torch.Tensor | np.ndarray,
|
|
out: torch.Tensor | None = None,
|
|
device: torch.device | None = None,
|
|
) -> torch.Tensor:
|
|
if isinstance(x, np.ndarray):
|
|
x = torch.from_numpy(x)
|
|
assert x.is_cpu
|
|
|
|
if out is None:
|
|
assert device is not None
|
|
out = torch.empty_like(x, device=device)
|
|
|
|
# CPU-to-CPU copy
|
|
tmp = x.pin_memory()
|
|
assert tmp is not x
|
|
|
|
# CPU-to-GPU copy
|
|
return out.copy_(tmp, non_blocking=True)
|
|
|
|
|
|
class UvaBuffer:
    """Zero-initialized pinned host tensor with numpy and accelerator views.

    Attributes:
        cpu: The pinned host tensor, ``torch.zeros(size, dtype=dtype)``.
        np: ``cpu.numpy()``, a numpy array sharing ``cpu``'s memory.
        uva: The accelerator-side view of ``cpu`` returned by
            ``get_accelerator_view_from_cpu_tensor`` (presumably a zero-copy
            alias of the same pinned memory; confirm in that helper).

    Raises:
        RuntimeError: if ``is_uva_available()`` reports that UVA is unavailable.
    """

    def __init__(self, size: int | Sequence[int], dtype: torch.dtype):
        if not is_uva_available():
            raise RuntimeError("UVA is not available")
        host = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=True)
        self.cpu = host
        self.np = host.numpy()
        self.uva = get_accelerator_view_from_cpu_tensor(host)
|
|
|
|
|
|
class UvaBufferPool:
|
|
def __init__(
|
|
self,
|
|
size: int | Sequence[int],
|
|
dtype: torch.dtype,
|
|
max_concurrency: int = 2,
|
|
):
|
|
self.size = size
|
|
self.dtype = dtype
|
|
self.max_concurrency = max_concurrency
|
|
|
|
# UVA buffers for concurrency
|
|
self._uva_bufs = [UvaBuffer(size, dtype) for _ in range(max_concurrency)]
|
|
# Current buffer index
|
|
self._curr = 0
|
|
|
|
def copy_to_uva(self, x: torch.Tensor | np.ndarray | list) -> torch.Tensor:
|
|
# Round robin to the next buffer.
|
|
self._curr = (self._curr + 1) % self.max_concurrency
|
|
buf = self._uva_bufs[self._curr]
|
|
# CPU-to-CPU copy
|
|
dst = buf.cpu if isinstance(x, torch.Tensor) else buf.np
|
|
n = len(x)
|
|
dst[:n] = x
|
|
return buf.uva[:n]
|
|
|
|
def copy_to_gpu(
|
|
self,
|
|
x: torch.Tensor | np.ndarray,
|
|
out: torch.Tensor | None = None,
|
|
) -> torch.Tensor:
|
|
uva = self.copy_to_uva(x)
|
|
# CPU-to-GPU copy
|
|
return uva.clone() if out is None else out.copy_(uva, non_blocking=True)
|
|
|
|
|
|
class UvaBackedTensor:
|
|
def __init__(
|
|
self, size: int | Sequence[int], dtype: torch.dtype, max_concurrency: int = 2
|
|
):
|
|
self.dtype = dtype
|
|
self.max_concurrency = max_concurrency
|
|
|
|
# Source of truth
|
|
self.cpu = torch.zeros(size, dtype=dtype, device="cpu", pin_memory=False)
|
|
self.np = self.cpu.numpy()
|
|
|
|
# Buffers for concurrency
|
|
self.pool = UvaBufferPool(size, dtype, max_concurrency)
|
|
self.gpu = self.pool.copy_to_uva(self.np)
|
|
|
|
def copy_to_uva(self, n: int | None = None) -> torch.Tensor:
|
|
# CPU-to-CPU copy
|
|
self.gpu = self.pool.copy_to_uva(self.np[:n] if n is not None else self.np)
|
|
return self.gpu
|
|
|
|
|
|
class StagedWriteTensor:
|
|
def __init__(
|
|
self,
|
|
size: int | Sequence[int],
|
|
dtype: torch.dtype,
|
|
device: torch.device,
|
|
max_concurrency: int = 2,
|
|
uva_instead_of_gpu: bool = False,
|
|
):
|
|
supported_dtypes = [torch.int32, torch.int64, torch.float32]
|
|
if dtype not in supported_dtypes:
|
|
raise ValueError(
|
|
f"Unsupported dtype {dtype}: should be one of {supported_dtypes}"
|
|
)
|
|
self.num_rows = size if isinstance(size, int) else size[0]
|
|
self.dtype = dtype
|
|
self.device = device
|
|
self.max_concurrency = max_concurrency
|
|
|
|
if not uva_instead_of_gpu:
|
|
# Create a GPU tensor (default)
|
|
self.gpu = torch.zeros(size, dtype=dtype, device=device)
|
|
else:
|
|
# For a large but not-frequently-accessed tensor, we can use UVA instead of
|
|
# GPU to save GPU memory
|
|
self._uva_buf = UvaBuffer(size, dtype)
|
|
self.gpu = self._uva_buf.uva
|
|
|
|
self._staged_write_indices: list[int] = []
|
|
self._staged_write_starts: list[int] = []
|
|
self._staged_write_contents: list[int | float] = []
|
|
self._staged_write_cu_lens: list[int] = []
|
|
|
|
new_buffer = partial(UvaBufferPool, max_concurrency=max_concurrency)
|
|
|
|
self.write_indices = new_buffer(self.num_rows, dtype=torch.int32)
|
|
self.write_starts = new_buffer(self.num_rows, dtype=torch.int32)
|
|
self.write_cu_lens = new_buffer(self.num_rows, dtype=torch.int32)
|
|
|
|
def stage_write(
|
|
self, index: int, start: int, x: Iterable[int] | Iterable[float]
|
|
) -> None:
|
|
assert index >= 0
|
|
assert start >= 0
|
|
if not x:
|
|
return
|
|
self._staged_write_indices.append(index)
|
|
self._staged_write_starts.append(start)
|
|
self._staged_write_contents.extend(x)
|
|
self._staged_write_cu_lens.append(len(self._staged_write_contents))
|
|
|
|
def stage_write_elem(self, index: int, x: int) -> None:
|
|
assert index >= 0
|
|
self._staged_write_indices.append(index)
|
|
self._staged_write_starts.append(0)
|
|
self._staged_write_contents.append(x)
|
|
self._staged_write_cu_lens.append(len(self._staged_write_contents))
|
|
|
|
def apply_write(self) -> None:
|
|
n = len(self._staged_write_indices)
|
|
if n == 0:
|
|
return
|
|
|
|
indices_uva = self.write_indices.copy_to_uva(self._staged_write_indices)
|
|
starts_uva = self.write_starts.copy_to_uva(self._staged_write_starts)
|
|
cu_lens_uva = self.write_cu_lens.copy_to_uva(self._staged_write_cu_lens)
|
|
|
|
# Special handling for write_contents
|
|
write_contents = async_tensor_h2d(
|
|
self._staged_write_contents, self.dtype, self.device, pin_memory=True
|
|
)
|
|
|
|
# Write diffs to the GPU buffer
|
|
_apply_write_kernel[(n,)](
|
|
self.gpu,
|
|
self.gpu.stride(0),
|
|
indices_uva,
|
|
starts_uva,
|
|
write_contents,
|
|
cu_lens_uva,
|
|
BLOCK_SIZE=1024,
|
|
)
|
|
# Clear the staged writes
|
|
self.clear_staged_writes()
|
|
|
|
def clear_staged_writes(self) -> None:
|
|
self._staged_write_indices.clear()
|
|
self._staged_write_starts.clear()
|
|
self._staged_write_contents.clear()
|
|
self._staged_write_cu_lens.clear()
|
|
|
|
|
|
# Scatter staged writes into `output_ptr`: program `pid` handles staged write
# `pid`, copying its slice of the flat contents buffer into row
# write_indices[pid], starting at offset write_starts[pid] within that row.
# write_cu_lens holds inclusive running totals of the contents, so write `pid`
# owns contents[cu_lens[pid - 1] : cu_lens[pid]] (with an implicit 0 before the
# first entry).
@triton.jit
def _apply_write_kernel(
    output_ptr,
    output_stride,
    write_indices_ptr,
    write_starts_ptr,
    write_contents_ptr,
    write_cu_lens_ptr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    # The indices buffer is int32; promote before multiplying by the row stride
    # so the element offset cannot overflow int32 for very large outputs.
    row_idx = tl.load(write_indices_ptr + pid).to(tl.int64)
    start_idx = tl.load(write_starts_ptr + pid)

    cu_start = tl.load(write_cu_lens_ptr + pid - 1) if pid > 0 else 0
    cu_end = tl.load(write_cu_lens_ptr + pid)
    content_len = cu_end - cu_start

    # First destination element of this write (loop-invariant).
    dst_ptr = output_ptr + row_idx * output_stride + start_idx
    for i in range(0, content_len, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        mask = block < content_len
        content = tl.load(write_contents_ptr + cu_start + block, mask=mask)
        tl.store(dst_ptr + block, content, mask=mask)
|