63 lines
1.7 KiB
Python
63 lines
1.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
|
|
from vllm.logger import init_logger
|
|
|
|
# Module logger; SpecDecodingMetrics.log() emits its aggregated summary through it.
logger = init_logger(__name__)
|
|
|
|
|
|
@dataclass
class SpecDecodingStats:
    """Running counters of drafted vs. accepted speculative tokens.

    Counts are accumulated with ``observe`` and can be drained with ``take``,
    which hands back a snapshot and zeroes this instance.
    """

    # Total tokens proposed by the draft model since the last reset.
    num_draft_tokens: int = 0
    # Total drafted tokens that were accepted since the last reset.
    num_accepted_tokens: int = 0

    def take(self) -> "SpecDecodingStats":
        """Return a snapshot of the current counters and reset them to zero."""
        snapshot = SpecDecodingStats(
            num_draft_tokens=self.num_draft_tokens,
            num_accepted_tokens=self.num_accepted_tokens,
        )
        self.reset()
        return snapshot

    def reset(self) -> None:
        """Zero both counters."""
        self.num_draft_tokens = self.num_accepted_tokens = 0

    def observe(self, num_draft_tokens: int, num_accepted_tokens: int) -> None:
        """Add one step's drafted and accepted token counts to the totals."""
        self.num_draft_tokens += num_draft_tokens
        self.num_accepted_tokens += num_accepted_tokens
|
|
|
|
|
|
class SpecDecodingMetrics:
    """Buffer per-step speculative decoding stats and log their totals.

    Every ``observe`` call appends one snapshot's counts; ``log`` sums all
    buffered counts, emits a single info line, and clears the buffers.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all buffered observations."""
        self.num_draft_tokens: list[int] = []
        self.num_accepted_tokens: list[int] = []

    def observe(self, spec_decoding_stats: SpecDecodingStats):
        """Buffer the drafted/accepted counts carried by one stats snapshot."""
        stats = spec_decoding_stats
        self.num_draft_tokens.append(stats.num_draft_tokens)
        self.num_accepted_tokens.append(stats.num_accepted_tokens)

    def log(self):
        """Log acceptance rate and token totals since the last reset, then reset."""
        drafted = np.sum(self.num_draft_tokens)
        accepted = np.sum(self.num_accepted_tokens)

        # With nothing drafted the rate is undefined, so NaN is reported.
        if drafted > 0:
            acceptance_rate = accepted / drafted * 100
        else:
            acceptance_rate = float("nan")

        logger.info(
            "SpecDecoding metrics: "
            "Draft acceptance rate: %.1f%%, "
            "Accepted: %d tokens, "
            "Drafted: %d tokens",
            acceptance_rate,
            accepted,
            drafted,
        )
        self.reset()
|