Sync from v0.13
This commit is contained in:
0
vllm/v1/worker/gpu/metrics/__init__.py
Normal file
0
vllm/v1/worker/gpu/metrics/__init__.py
Normal file
42
vllm/v1/worker/gpu/metrics/logits.py
Normal file
42
vllm/v1/worker/gpu/metrics/logits.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from torch._inductor.runtime.triton_helpers import libdevice
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _num_nans_kernel(
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
num_nans_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
req_idx = tl.program_id(0)
|
||||
num_nans = 0
|
||||
for i in range(0, vocab_size, BLOCK_SIZE):
|
||||
block = i + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
is_nan = libdevice.isnan(logits).to(tl.int1)
|
||||
num_nans += tl.sum(is_nan).to(tl.int32)
|
||||
tl.store(num_nans_ptr + req_idx, num_nans)
|
||||
|
||||
|
||||
def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
|
||||
num_reqs, vocab_size = logits.shape
|
||||
BLOCK_SIZE = 8192
|
||||
num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
|
||||
_num_nans_kernel[(num_reqs,)](
|
||||
logits,
|
||||
logits.stride(0),
|
||||
num_nans,
|
||||
vocab_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
return num_nans
|
||||
Reference in New Issue
Block a user