43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import torch
|
|
from torch._inductor.runtime.triton_helpers import libdevice
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
|
|
@triton.jit
def _num_nans_kernel(
    logits_ptr,
    logits_stride,
    num_nans_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Count the NaN entries in one row of a 2-D logits buffer.

    Each program handles the row selected by ``tl.program_id(0)`` and writes
    its count to ``num_nans_ptr`` at that same index.

    Args:
        logits_ptr: Base pointer of the logits buffer.
        logits_stride: Distance, in elements, between consecutive rows.
        num_nans_ptr: Output pointer; one count is stored per row.
        vocab_size: Number of valid elements in each row.
        BLOCK_SIZE: Compile-time number of elements processed per iteration.

    Elements within a row are addressed with an implicit stride of 1.
    """
    # Widen the row index before scaling it by the row stride: program ids are
    # 32-bit, and a 32-bit product wraps once the buffer holds more than 2**31
    # elements, which would make later rows read from the wrong address.
    req_idx = tl.program_id(0).to(tl.int64)
    row_ptr = logits_ptr + req_idx * logits_stride

    num_nans = 0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        # The last chunk may extend past the end of the row.
        mask = block < vocab_size
        # Masked-off lanes are filled with 0, which is never counted as NaN.
        logits = tl.load(row_ptr + block, mask=mask, other=0)
        logits = logits.to(tl.float32)
        is_nan = libdevice.isnan(logits).to(tl.int1)
        num_nans += tl.sum(is_nan).to(tl.int32)
    tl.store(num_nans_ptr + req_idx, num_nans)
|
|
|
|
|
|
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
    """Return the number of NaN values in each row of ``logits``.

    Args:
        logits: 2-D tensor, unpacked as ``(num_reqs, vocab_size)``.

    Returns:
        An int32 tensor of shape ``(num_reqs,)`` on ``logits.device`` whose
        i-th entry is the NaN count of row i.
    """
    num_reqs, vocab_size = logits.shape

    # The kernel only receives the row stride and steps through each row one
    # element at a time. A view whose last dimension is not unit-stride (e.g.
    # a transpose or a column slice) would be read at the wrong addresses, so
    # materialize a contiguous copy in that case.
    if logits.stride(1) != 1:
        logits = logits.contiguous()

    BLOCK_SIZE = 8192
    # torch.empty is sufficient: every launched program stores its own slot.
    num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
    # One program per row.
    _num_nans_kernel[(num_reqs,)](
        logits,
        logits.stride(0),
        num_nans,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return num_nans
|