init
This commit is contained in:
31
vllm_vacc/vllm/spec_decode/metrics.py
Normal file
31
vllm_vacc/vllm/spec_decode/metrics.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
class AsyncMetricsCollector:
    """Collects speculative-decoding sampler metrics asynchronously.

    Rejection/typical-acceptance sampler counters are copied from the
    device to CPU on a non-default Torch stream.
    """

    def _copy_rejsample_metrics_async(self):
        """Enqueue an asynchronous copy of the sampler metrics to CPU.

        The accepted-token and emitted-token counters of
        ``self.spec_decode_sampler`` are copied with ``non_blocking=True``
        into the aggregate buffers while ``self._copy_stream`` is the
        active stream. The draft-token count is computed on CPU, so it is
        rebound directly instead of being copied.

        Returns:
            A ``torch.vacc.Event`` recorded on ``self._copy_stream`` after
            the copies have been enqueued.

        Raises:
            AssertionError: If ``self._copy_stream`` is ``None``.
        """
        # Kept for its import side effects; presumably this is what makes
        # the ``torch.vacc`` namespace available -- TODO confirm.
        import torch_vacc  # noqa: F401

        side_stream = self._copy_stream
        assert side_stream is not None

        # Make the copy stream wait for work already queued on the
        # current stream before the copies below are issued.
        side_stream.wait_stream(torch.vacc.current_stream())

        with torch.vacc.stream(side_stream):
            sampler = self.spec_decode_sampler

            accepted_buffer = self._aggregate_num_accepted_tokens
            accepted_buffer.copy_(sampler.num_accepted_tokens,
                                  non_blocking=True)

            emitted_buffer = self._aggregate_num_emitted_tokens
            emitted_buffer.copy_(sampler.num_emitted_tokens,
                                 non_blocking=True)

            # The draft-token count is calculated on CPU, so a plain
            # rebind is enough; no device copy is needed.
            self._aggregate_num_draft_tokens = sampler.num_draft_tokens

        copy_done = torch.vacc.Event()
        copy_done.record(side_stream)
        return copy_done
|
||||
Reference in New Issue
Block a user