32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
import torch
|
|
|
|
class AsyncMetricsCollector:
    """Collects rejection/typical-acceptance sampler metrics.

    The sampler's counters are copied from the device to CPU on a
    non-default Torch stream.
    """

    def _copy_rejsample_metrics_async(self):
        """Start an asynchronous copy of the sampler metrics to CPU.

        The accepted-token and emitted-token counters of
        ``self.spec_decode_sampler`` are copied with ``non_blocking=True``
        into the aggregate buffers while ``self._copy_stream`` is the
        active stream. The draft-token count is assigned directly.

        Returns:
            A ``torch.vacc.Event`` recorded on ``self._copy_stream`` after
            the copies were issued; it marks when the copy is complete.
        """
        # Imported at call time, exactly as the original method does.
        import torch_vacc

        copy_stream = self._copy_stream
        assert copy_stream is not None

        # The copy stream must not run ahead of work already queued on the
        # current stream.
        copy_stream.wait_stream(torch.vacc.current_stream())
        sampler = self.spec_decode_sampler

        with torch.vacc.stream(copy_stream):
            transfers = (
                (self._aggregate_num_accepted_tokens,
                 sampler.num_accepted_tokens),
                (self._aggregate_num_emitted_tokens,
                 sampler.num_emitted_tokens),
            )
            for destination, source in transfers:
                destination.copy_(source, non_blocking=True)

            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = sampler.num_draft_tokens

        copy_done = torch.vacc.Event()
        copy_done.record(copy_stream)
        return copy_done
|