14 lines
376 B
Python
14 lines
376 B
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import torch
|
|
|
|
from vllm.triton_utils import triton
|
|
|
|
|
|
def set_triton_allocator(device: torch.device):
|
|
def alloc_fn(size: int, alignment: int, stream: int | None):
|
|
return torch.empty(size, device=device, dtype=torch.int8)
|
|
|
|
triton.set_allocator(alloc_fn)
|