Files

32 lines
1.2 KiB
Python
Raw Permalink Normal View History

2026-04-02 04:53:13 +00:00
import torch
class AsyncMetricsCollector:
    """Copies rejection/typical-acceptance sampler metrics to the CPU.

    The device-to-CPU transfer is issued on a dedicated, non-default Torch
    stream (``self._copy_stream``) rather than on the current stream.
    """

    def _copy_rejsample_metrics_async(self):
        """Enqueue an asynchronous copy of the sampler metrics to the CPU.

        The accepted-token and emitted-token counters are copied with
        ``non_blocking=True`` while ``self._copy_stream`` is the active
        stream. The draft-token count is taken over by plain assignment.

        Returns:
            A ``torch.vacc.Event`` recorded on ``self._copy_stream`` after
            the copies have been enqueued; callers can use it to tell when
            the copied values are ready.

        Raises:
            AssertionError: If ``self._copy_stream`` is ``None``.
        """
        # NOTE(review): imported but never referenced by name; presumably
        # this import is what makes the ``torch.vacc`` namespace available.
        # Confirm against the torch_vacc package.
        import torch_vacc

        copy_stream = self._copy_stream
        assert copy_stream is not None

        # Work already queued on the current stream must finish before the
        # copy stream starts reading the sampler's counters.
        copy_stream.wait_stream(torch.vacc.current_stream())

        sampler = self.spec_decode_sampler
        with torch.vacc.stream(copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                sampler.num_accepted_tokens, non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                sampler.num_emitted_tokens, non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = sampler.num_draft_tokens

        # Mark the point on the copy stream at which both copies are done.
        copy_done = torch.vacc.Event()
        copy_done.record(copy_stream)
        return copy_done