# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Analytic flops/memory estimation module for transformer components,
to help derive MFU (Model Flops Utilization) stats for a running model.
"""

import json
import time
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import asdict, dataclass
from typing import Any, Protocol

import prometheus_client
import torch
from pydantic import BaseModel, Field, ValidationError, model_validator
from typing_extensions import Self

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.torch_utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    get_dtype_size,
    get_kv_cache_torch_dtype,
)
from vllm.v1.core.sched.output import SchedulerOutput

logger = init_logger(__name__)


class InvalidComponent(Exception):
    """
    Custom exception to indicate that a certain ComponentMetric is not
    applicable to the given VllmConfig.
    """

    pass


#### Basic Data Types ####
@dataclass
class DebugPerfStats:
    ## Stats for debugging the metrics calculation
    calc_duration: float = 0.0  # time spent calculating these stats
    num_prefill_requests: int = 0
    num_decode_requests: int = 0
    # Breakdowns below are populated only when debug stats are requested
    # (see ModelMetrics.get_step_perf_stats_per_gpu).
    context_breakdown: dict[str, int] | None = None
    num_flops_per_gpu_breakdown: dict[str, int] | None = None
    num_read_bytes_per_gpu_breakdown: dict[str, int] | None = None
    num_write_bytes_per_gpu_breakdown: dict[str, int] | None = None


@dataclass
class PerfStats:
    # Aggregated per-GPU estimates for a single scheduler step.
    num_flops_per_gpu: int = 0
    num_read_bytes_per_gpu: int = 0
    num_write_bytes_per_gpu: int = 0
    debug_stats: DebugPerfStats | None = None


@dataclass
class ExecutionContext:
    """
    Represents an execution context for a batch of requests.

    This class aggregates statistics across multiple requests in a batch,
    separately tracking prefill and decode phases.

    Example)
    - Batch with one full prefill (2048 tokens) and one decode
      (1 token, 8192 context):
        ctx = ExecutionContext()
        ctx.add(2048, 2048, is_prefill=True)
        ctx.add(1, 8192, is_prefill=False)
    """

    # Prefill phase statistics
    num_prefill_requests: int = 0
    prefill_num_tokens: int = 0  # sum of num_tokens for prefill requests
    prefill_context_len: int = 0  # sum of context_len for prefill requests
    prefill_token_context_product: int = 0  # sum of (num_tokens * context_len)

    # Decode phase statistics
    num_decode_requests: int = 0
    decode_num_tokens: int = 0  # sum of num_tokens for decode requests
    decode_context_len: int = 0  # sum of context_len for decode requests
    decode_token_context_product: int = 0  # sum of (num_tokens * context_len)

    def add(self, num_tokens: int, context_len: int, is_prefill: bool) -> None:
        """Add a single request's statistics to this batch context."""
        if is_prefill:
            self.num_prefill_requests += 1
            self.prefill_num_tokens += num_tokens
            self.prefill_context_len += context_len
            self.prefill_token_context_product += num_tokens * context_len
        else:
            self.num_decode_requests += 1
            self.decode_num_tokens += num_tokens
            self.decode_context_len += context_len
            self.decode_token_context_product += num_tokens * context_len

    def total_num_tokens(self) -> int:
        """Total number of tokens across all requests in the batch."""
        return self.prefill_num_tokens + self.decode_num_tokens

    def total_token_context_product(self) -> int:
        """Total sum of (num_tokens * context_len) across all requests."""
        return self.prefill_token_context_product + self.decode_token_context_product

    def num_logits_tokens(self) -> int:
        """Number of tokens that require logits computation (unembedding).

        For prefill, only the last token per request needs logits.
        For decode, all tokens need logits.
        """
        return self.num_prefill_requests + self.decode_num_tokens

    @classmethod
    def from_single_request(
        cls, num_tokens: int, context_len: int, is_prefill: bool
    ) -> "ExecutionContext":
        """Create an ExecutionContext from a single request.

        This is a convenience method primarily for testing.
        """
        ctx = cls()
        ctx.add(num_tokens, context_len, is_prefill)
        return ctx


class ParsedArgs:
    """
    Syntactic sugar so that Parsers can use dot notations
    to access/update the parsed arguments.
    e.g.)
        args = ParsedArgs()
        args.x = 3
        args.y = args.x + 1
    """

    def __getattr__(self, name: str) -> Any:
        # Only called for missing attributes; gives a clearer error than the
        # default object attribute lookup.
        raise AttributeError(f"'{type(self).__name__}' has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        object.__setattr__(self, name, value)

    def model_dump(self) -> dict[str, Any]:
        # Mirrors the pydantic model_dump() API so the result can be fed
        # directly into ComponentMetrics.model_validate().
        return vars(self).copy()


#### Abstract ####
class Parser(Protocol):
    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        """
        Parse the vllm config and update the current ParsedArgs and pass it on.
        If the parser isn't applicable to the vllm_config, it will do nothing.
        """
        ...


class ParserChain:
    """
    Applies chain of parser in a sequential order.
    Later parsers might overwrite results from previous parsers,
    so parsers should be chained in the appropriate order
    if they are not mutually exclusive.
    """

    def __init__(self, *parsers: Parser) -> None:
        self.parsers = list(parsers)

    def add_parser(self, parser: Parser) -> None:
        self.parsers.append(parser)

    def parse(self, vllm_config: VllmConfig) -> ParsedArgs:
        """Run every parser in order over a fresh ParsedArgs."""
        args = ParsedArgs()
        for parser in self.parsers:
            args = parser.parse(args, vllm_config)
        return args


# Registry of all concrete ComponentMetrics subclasses, keyed by
# component_type(); populated automatically via __init_subclass__.
_COMPONENT_METRICS_REGISTRY: dict[str, type["ComponentMetrics"]] = {}


class ComponentMetrics(BaseModel, ABC):
    """
    Each concrete ComponentMetrics class is associated with:
    - fields that are required for metric derivation
      (fields are specified/validated through pydantic model)
    - parser to parse VllmConfig into fields
    - metric methods that derive flops/bytes for a given execution context
    """

    @classmethod
    @abstractmethod
    def component_type(cls) -> str: ...

    @classmethod
    @abstractmethod
    def get_parser(cls) -> ParserChain:
        """
        Return a ParserChain that provides values for all required fields.

        The returned parser chain must populate ParsedArgs with values for
        every field defined on this ComponentMetrics class. Missing fields
        will cause a ValidationError when from_vllm_config() is called.

        See individual Parser docstrings for which args they provide, and
        field comments on ComponentMetrics subclasses for which parser
        provides each field.
        """
        ...

    def __init_subclass__(cls):
        # Auto-register every concrete subclass under its component_type().
        # NOTE(review): super().__init_subclass__() is not invoked here;
        # appears to work with pydantic BaseModel, but worth confirming.
        _COMPONENT_METRICS_REGISTRY[cls.component_type()] = cls

    @classmethod
    def from_vllm_config(cls, vllm_config: VllmConfig) -> Self:
        """
        Instantiate this class from VllmConfig.
        Raises InvalidComponent if parsing/validation fails.
        """
        parser = cls.get_parser()
        parsed_args = parser.parse(vllm_config)
        try:
            return cls.model_validate(parsed_args.model_dump())
        except ValidationError as e:
            raise InvalidComponent(f"Invalid {cls.component_type()} config: {e}") from e

    @classmethod
    def registered_metrics(cls) -> Iterable[type["ComponentMetrics"]]:
        return iter(_COMPONENT_METRICS_REGISTRY.values())

    @abstractmethod
    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]: ...

    @abstractmethod
    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]: ...

    @abstractmethod
    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]: ...

    def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        """Total flops: sum over the per-operation breakdown."""
        return sum(self.get_num_flops_breakdown(ctx, per_gpu).values())

    def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        """Total read bytes: sum over the per-operation breakdown."""
        return sum(self.get_read_bytes_breakdown(ctx, per_gpu).values())

    def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        """Total write bytes: sum over the per-operation breakdown."""
        return sum(self.get_write_bytes_breakdown(ctx, per_gpu).values())


#### parsers ####
class BaseConfigParser(Parser):
    """
    Parses base model configuration.

    Provides: vocab_size, hidden_size, num_attention_heads, num_hidden_layers,
        weight_byte_size, activation_byte_size, dp_size, tp_size, pp_size,
        enable_ep
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        model_config = vllm_config.model_config
        args.vocab_size = model_config.get_vocab_size()
        args.hidden_size = model_config.get_hidden_size()
        # NOTE: model_config.get_attention_heads() divide by TP
        # so we access field manually here to get total num_heads
        args.num_attention_heads = get_required(
            model_config.hf_text_config, "num_attention_heads"
        )
        args.num_hidden_layers = get_required(
            model_config.hf_text_config, "num_hidden_layers"
        )

        model_dtype = vllm_config.model_config.dtype
        if isinstance(model_dtype, torch.dtype):
            torch_dtype = model_dtype
        elif isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        else:
            # FIXME: handle this better
            logger.warning(
                "Unknown model_dtype %s, defaulting to bfloat16",
                model_dtype,
            )
            torch_dtype = torch.bfloat16
        args.weight_byte_size = get_dtype_size(torch_dtype)
        # FIXME: handle this better by parsing whether activations use
        # bf16, fp32, etc...
        args.activation_byte_size = 2

        args.dp_size = vllm_config.parallel_config.data_parallel_size
        args.tp_size = vllm_config.parallel_config.tensor_parallel_size
        args.pp_size = vllm_config.parallel_config.pipeline_parallel_size
        args.enable_ep = vllm_config.parallel_config.enable_expert_parallel
        return args


#### Attention ####
class BaseAttentionConfigParser(Parser):
    """
    Parses attention-specific configuration.

    Provides: num_key_value_heads, head_dim, cache_byte_size
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        model_config = vllm_config.model_config
        args.num_key_value_heads = model_config.get_total_num_kv_heads()
        args.head_dim = model_config.get_head_size()

        model_dtype = vllm_config.model_config.dtype
        cache_dtype = vllm_config.cache_config.cache_dtype
        kv_cache_torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
        args.cache_byte_size = get_dtype_size(kv_cache_torch_dtype)
        return args


class AttentionQuantizationConfigParser(Parser):
    """
    Parses quantization configuration for attention layers.

    Overrides: weight_byte_size
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.quant_config
        if cfg is None:
            # No quantization: keep weight_byte_size from BaseConfigParser.
            return args

        quant_method = cfg.get_name()
        if quant_method in ["fp8", "fbgemm_fp8"]:
            # FIXME: This is a hacky coarse-grained fp8 quantization detection.
            # FIXME: These configs also have concept of "ignored layers" and we
            # need to solve the same problem as above.
            args.weight_byte_size = 1
        elif quant_method == "mxfp4":
            # FIXME: Also has "ignored layers" issue above
            args.weight_byte_size = 0.5
        else:
            # FIXME: Add more parsing logic for different quant methods.
            raise InvalidComponent
        return args


class AttentionMetrics(ComponentMetrics):
    # From BaseConfigParser
    num_hidden_layers: int = Field(..., gt=0)
    hidden_size: int = Field(..., gt=0)
    num_attention_heads: int = Field(..., gt=0)
    activation_byte_size: int = Field(..., gt=0)
    tp_size: int = Field(..., gt=0)
    pp_size: int = Field(..., gt=0)
    # From BaseAttentionConfigParser
    num_key_value_heads: int = Field(..., gt=0)
    head_dim: int = Field(..., gt=0)
    cache_byte_size: int = Field(..., gt=0)
    # From BaseConfig Parser, overridden by AttentionQuantizationConfigParser
    weight_byte_size: int | float = Field(..., gt=0)

    # TODO: discern cases where we have mixture of different attention layer types
    # such as SWA, MLA, etc.

    @classmethod
    def component_type(cls) -> str:
        return "attn"

    @classmethod
    def get_parser(cls) -> ParserChain:
        return ParserChain(
            BaseConfigParser(),
            BaseAttentionConfigParser(),
            AttentionQuantizationConfigParser(),
        )

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        # Symbols: L layers, D hidden size, q/kv head counts, d head dim.
        L, D, q, kv, d = (
            self.num_hidden_layers,
            self.hidden_size,
            self.num_attention_heads,
            self.num_key_value_heads,
            self.head_dim,
        )
        T = ctx.total_num_tokens()
        TC = ctx.total_token_context_product()
        if per_gpu:
            L //= self.pp_size
            # tensor parallel along heads
            q = max(1, q // self.tp_size)
            kv = max(1, kv // self.tp_size)
        # 2 flops per multiply-accumulate.
        return {
            "qkv_proj": 2 * T * D * (q + 2 * kv) * d * L,
            "attn_qk": 2 * q * TC * d * L,
            "attn_av": 2 * q * TC * d * L,
            "out_proj": 2 * T * D * q * d * L,
        }

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        L, D, q, kv, d = (
            self.num_hidden_layers,
            self.hidden_size,
            self.num_attention_heads,
            self.num_key_value_heads,
            self.head_dim,
        )
        T = ctx.total_num_tokens()
        if per_gpu:
            L //= self.pp_size
            # tensor parallel along heads
            q = max(1, q // self.tp_size)
            kv = max(1, kv // self.tp_size)
        read_bytes = {}
        read_bytes["qkv_input"] = T * D * self.activation_byte_size * L
        read_bytes["qkv_weight"] = int(D * (q + 2 * kv) * d * self.weight_byte_size * L)
        # Attention input reads differ between prefill and decode
        # Prefill: read Q, K, V activations (all in activation_byte_size)
        if ctx.prefill_num_tokens > 0:
            read_bytes["attn_input"] = (
                (ctx.prefill_num_tokens * q + 2 * ctx.prefill_context_len * kv)
                * d
                * self.activation_byte_size
                * L
            )
        # Decode: read Q activations + read K, V from cache (in cache_byte_size)
        if ctx.decode_num_tokens > 0:
            read_bytes["attn_input"] = read_bytes.get("attn_input", 0) + (
                ctx.decode_num_tokens * q * d * self.activation_byte_size * L
                + 2 * ctx.decode_context_len * kv * d * self.cache_byte_size * L
            )
        read_bytes["out_input"] = T * q * d * self.activation_byte_size * L
        read_bytes["out_weight"] = int(q * d * D * self.weight_byte_size * L)
        return read_bytes

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate write memory traffic for attention layers."""
        L, D, q, kv, d = (
            self.num_hidden_layers,
            self.hidden_size,
            self.num_attention_heads,
            self.num_key_value_heads,
            self.head_dim,
        )
        T = ctx.total_num_tokens()
        if per_gpu:
            L //= self.pp_size
            # tensor parallel along heads
            q = max(1, q // self.tp_size)
            kv = max(1, kv // self.tp_size)
        return {
            "qkv_output": T * (q + 2 * kv) * d * self.activation_byte_size * L,
            "kv_cache": 2 * T * kv * d * self.cache_byte_size * L,
            "out_output": T * D * self.activation_byte_size * L,
        }


#### Ffn ####
class BaseFfnConfigParser(Parser):
    """
    Parses FFN and MoE configuration.

    Provides: intermediate_size, num_experts, num_experts_per_tok,
        moe_intermediate_size, num_shared_experts, num_moe_layers
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.model_config.hf_config
        if hasattr(cfg, "text_config") and cfg.text_config is not None:
            cfg = cfg.text_config
        args.intermediate_size = getattr(cfg, "intermediate_size", args.hidden_size * 4)
        # Try different naming conventions.
        args.num_experts = vllm_config.model_config.get_num_experts()
        args.num_experts_per_tok = getattr_from_list(
            cfg, ["num_experts_per_tok", "moe_topk"], 0
        )
        args.moe_intermediate_size = getattr_from_list(
            cfg, ["moe_intermediate_size", "intermediate_size"], 0
        )
        args.num_shared_experts = getattr_from_list(
            cfg, ["n_shared_experts", "num_shared_experts"], 0
        )
        is_moe = args.num_experts != 0
        # Assume all MoE layers by default
        args.num_moe_layers = args.num_hidden_layers if is_moe else 0
        return args


class FfnParallelParser(Parser):
    """
    Parses FFN parallelism configuration.

    Provides: ffn_tp_size, ffn_ep_size
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        # NOTE: ffn tp_size does not equal the tp_size parameter directly.
        # e.g.) If we use DP2TP4, ffn will use TP8 (or EP8 if EP is enabled.)
        if args.enable_ep:
            ffn_tp_size, ffn_ep_size = 1, args.dp_size * args.tp_size
        else:
            ffn_tp_size, ffn_ep_size = args.dp_size * args.tp_size, 1
        args.ffn_tp_size = ffn_tp_size
        args.ffn_ep_size = ffn_ep_size
        return args


class InterleaveMoeLayerStepParser(Parser):
    """
    Parses interleave_moe_layer_step field for models like Llama4.

    Overrides: num_moe_layers
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.model_config.hf_config
        if hasattr(cfg, "text_config") and cfg.text_config is not None:
            cfg = cfg.text_config
        if (
            hasattr(cfg, "interleave_moe_layer_step")
            and cfg.interleave_moe_layer_step > 0
        ):
            # Every interleave_moe_layer_step-th layer (1-indexed) is MoE.
            args.num_moe_layers = len(
                [
                    layer
                    for layer in range(args.num_hidden_layers)
                    if (layer + 1) % cfg.interleave_moe_layer_step == 0
                ]
            )
        return args


class MoeLayerFreqParser(Parser):
    """
    Parses moe_layer_freq and first_k_dense_replace fields
    for models like Deepseek.

    Overrides: num_moe_layers
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.model_config.hf_config
        if hasattr(cfg, "text_config") and cfg.text_config is not None:
            cfg = cfg.text_config
        if hasattr(cfg, "moe_layer_freq") and hasattr(cfg, "first_k_dense_replace"):
            # First first_k_dense_replace layers are dense; afterwards every
            # moe_layer_freq-th layer is MoE.
            args.num_moe_layers = len(
                [
                    layer
                    for layer in range(args.num_hidden_layers)
                    if layer >= cfg.first_k_dense_replace
                    and layer % cfg.moe_layer_freq == 0
                ]
            )
        return args


class FfnQuantizationConfigParser(Parser):
    """
    Parses quantization configuration for FFN layers.

    Overrides: weight_byte_size
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.quant_config
        if cfg is None:
            return args

        quant_method = cfg.get_name()
        if quant_method in ["fp8", "fbgemm_fp8"]:
            # FIXME: This is a hacky coarse-grained fp8 quantization detection.
            # (there might be more quantization methods for fp8).
            # FIXME: These configs also have concept of "ignored layers" and we
            # need to solve the same problem as above.
            args.weight_byte_size = 1
            pass
        elif quant_method == "mxfp4":
            # FIXME: Also has "ignored layers" issue above
            args.weight_byte_size = 0.5
        else:
            # FIXME: Add more parsing logic for different quant methods.
            raise InvalidComponent
        return args


class FfnMetrics(ComponentMetrics):
    # From BaseConfigParser
    num_hidden_layers: int = Field(..., gt=0)
    hidden_size: int = Field(..., gt=0)
    activation_byte_size: int = Field(..., gt=0)
    pp_size: int = Field(..., gt=0)
    # From FfnParallelParser
    ffn_tp_size: int = Field(..., gt=0)
    ffn_ep_size: int = Field(..., gt=0)
    # From BaseFfnConfigParser
    intermediate_size: int = Field(..., gt=0)
    num_experts: int = Field(0)
    num_experts_per_tok: int = Field(1)
    moe_intermediate_size: int = Field(0)
    num_shared_experts: int = Field(0)
    # From BaseConfigParser, can be overridden InterleaveMoeLayerStep or MoeLayerFreq
    num_moe_layers: int = Field(..., ge=0)
    # FIXME: might have to make this more granular
    # (i.e. dense_weight_byte_size, moe_routed_weight_byte_size,
    # moe_shared_weight_byte_size)
    # since it can differ from byte size of other components (e.g. attn)
    # and can differ even from each other.
    # From BaseConfigParser, can be overridden by FfnQuantizationConfigParser
    weight_byte_size: int | float = Field(..., gt=0)

    @model_validator(mode="after")
    def validate_moe_fields(self) -> Self:
        """Validate that MoE-related fields are properly set when num_moe_layers > 0."""
        if self.num_moe_layers > 0:
            assert self.num_experts, f"{self.num_experts=}"
            assert self.num_experts_per_tok, f"{self.num_experts_per_tok=}"
            assert self.moe_intermediate_size, f"{self.moe_intermediate_size=}"
        return self

    @classmethod
    def component_type(cls) -> str:
        return "ffn"

    @classmethod
    def get_parser(cls) -> ParserChain:
        return ParserChain(
            BaseConfigParser(),
            FfnParallelParser(),
            BaseFfnConfigParser(),
            InterleaveMoeLayerStepParser(),
            MoeLayerFreqParser(),
            FfnQuantizationConfigParser(),
        )

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate flops breakdown for FFN layers."""
        # Symbols: L total layers, D hidden, DI dense intermediate,
        # Lm MoE layers, E experts per token, MI MoE intermediate, S shared.
        L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
        Lm, E, MI, S = (
            self.num_moe_layers,
            self.num_experts_per_tok,
            self.moe_intermediate_size,
            self.num_shared_experts,
        )
        T = ctx.total_num_tokens()
        Ld = L - Lm  # dense layer count
        num_activated_tokens = T * E if E else 0
        if per_gpu:
            Ld //= self.pp_size
            Lm //= self.pp_size
            DI //= self.ffn_tp_size
            if MI is not None:
                MI //= self.ffn_tp_size
            if E:
                num_activated_tokens //= self.ffn_ep_size
        flops = {}
        # Dense FFN layers (SwiGLU: 3 linear layers: up, gate, down)
        if Ld:
            flops["dense_ffn"] = 2 * D * 3 * DI * T * Ld
        # MoE routed experts (each token activates E experts)
        if Lm and E:
            flops["routed_ffn"] = 2 * D * 3 * MI * num_activated_tokens * Lm
        # MoE shared experts (all S shared experts run for every token)
        if Lm and S:
            flops["shared_ffn"] = 2 * D * 3 * MI * S * T * Lm
        return flops

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate read memory traffic for FFN layers."""
        L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
        Lm, E, MI, S = (
            self.num_moe_layers,
            self.num_experts_per_tok,
            self.moe_intermediate_size,
            self.num_shared_experts,
        )
        T = ctx.total_num_tokens()
        num_experts = self.num_experts
        Ld = L - Lm
        num_activated_tokens = T * E if E else 0
        if per_gpu:
            Ld //= self.pp_size
            Lm //= self.pp_size
            DI //= self.ffn_tp_size
            if MI is not None:
                MI //= self.ffn_tp_size
            if E:
                num_activated_tokens //= self.ffn_ep_size
            if num_experts is not None:
                num_experts //= self.ffn_ep_size
        read_bytes = {}
        # Dense FFN layers (3 GEMMs: up, gate, down projections + SiLU activation)
        if Ld:
            read_bytes["dense_up_gate_input"] = int(
                T * D * self.activation_byte_size * Ld
            )
            read_bytes["dense_up_gate_weights"] = int(
                2 * D * DI * self.weight_byte_size * Ld
            )
            read_bytes["dense_silu_input"] = int(
                2 * T * DI * self.activation_byte_size * Ld
            )
            read_bytes["dense_down_input"] = int(
                T * DI * self.activation_byte_size * Ld
            )
            read_bytes["dense_down_weights"] = int(D * DI * self.weight_byte_size * Ld)
        if Lm:
            # MoE routed expert reads
            if E:
                # FIXME: Assume perfect load balancing for now.
                num_activated_experts = min(num_activated_tokens, num_experts)
                read_bytes["routed_up_gate_input"] = int(
                    num_activated_tokens * D * self.activation_byte_size * Lm
                )
                read_bytes["routed_up_gate_weights"] = int(
                    2 * D * MI * num_activated_experts * self.weight_byte_size * Lm
                )
                read_bytes["routed_silu_input"] = int(
                    2 * num_activated_tokens * MI * self.activation_byte_size * Lm
                )
                read_bytes["routed_down_input"] = int(
                    num_activated_tokens * MI * self.activation_byte_size * Lm
                )
                read_bytes["routed_down_weights"] = int(
                    D * MI * num_activated_experts * self.weight_byte_size * Lm
                )
            # MoE shared expert reads
            if S:
                read_bytes["shared_up_gate_input"] = int(
                    T * D * self.activation_byte_size * Lm
                )
                read_bytes["shared_up_gate_weights"] = int(
                    2 * D * MI * S * self.weight_byte_size * Lm
                )
                read_bytes["shared_silu_input"] = int(
                    2 * T * MI * S * self.activation_byte_size * Lm
                )
                read_bytes["shared_down_input"] = int(
                    T * MI * self.activation_byte_size * Lm
                )
                read_bytes["shared_down_weights"] = int(
                    D * MI * S * self.weight_byte_size * Lm
                )
        return read_bytes

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate write memory traffic for FFN layers."""
        L, D, DI = self.num_hidden_layers, self.hidden_size, self.intermediate_size
        Lm, E, MI, S = (
            self.num_moe_layers,
            self.num_experts_per_tok,
            self.moe_intermediate_size,
            self.num_shared_experts,
        )
        T = ctx.total_num_tokens()
        Ld = L - Lm
        num_activated_tokens = T * E if E else 0
        if per_gpu:
            Ld //= self.pp_size
            Lm //= self.pp_size
            DI //= self.ffn_tp_size
            if MI is not None:
                MI //= self.ffn_tp_size
            if E:
                num_activated_tokens //= self.ffn_ep_size
        write_bytes = {}
        # Dense FFN layers
        if Ld:
            write_bytes["dense_up_gate_output"] = int(
                2 * T * DI * self.activation_byte_size * Ld
            )
            write_bytes["dense_silu_output"] = int(
                T * DI * self.activation_byte_size * Ld
            )
            write_bytes["dense_down_output"] = int(
                T * D * self.activation_byte_size * Ld
            )
        # MoE outputs
        if Lm:
            if E:
                write_bytes["routed_up_gate_output"] = int(
                    2 * num_activated_tokens * MI * self.activation_byte_size * Lm
                )
                write_bytes["routed_silu_output"] = int(
                    num_activated_tokens * MI * self.activation_byte_size * Lm
                )
                write_bytes["routed_down_output"] = int(
                    num_activated_tokens * D * self.activation_byte_size * Lm
                )
            if S:
                write_bytes["shared_up_gate_output"] = int(
                    2 * T * S * MI * self.activation_byte_size * Lm
                )
                write_bytes["shared_silu_output"] = int(
                    T * S * MI * self.activation_byte_size * Lm
                )
                write_bytes["shared_down_output"] = int(
                    T * S * D * self.activation_byte_size * Lm
                )
        return write_bytes


#### Unembed ####
class UnembedMetrics(ComponentMetrics):
    # From BaseConfigParser
    hidden_size: int = Field(..., gt=0)
    vocab_size: int = Field(..., gt=0)
    weight_byte_size: int = Field(..., gt=0)
    activation_byte_size: int = Field(..., gt=0)
    tp_size: int

    @classmethod
    def component_type(cls) -> str:
        return "unembed"

    @classmethod
    def get_parser(cls) -> ParserChain:
        return ParserChain(
            BaseConfigParser(),
        )

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate flops breakdown for unembedding layer."""
        D, V = self.hidden_size, self.vocab_size
        # Only logits-producing tokens hit the unembedding GEMM.
        T = ctx.num_logits_tokens()
        if per_gpu:
            V //= self.tp_size
        return {
            "unembed": 2 * T * D * V,
        }

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate read memory traffic for unembedding layer."""
        D, V = self.hidden_size, self.vocab_size
        T = ctx.num_logits_tokens()
        if per_gpu:
            V //= self.tp_size
        return {
            "input": T * D * self.activation_byte_size,
            "weight": D * V * self.weight_byte_size,
        }

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate write memory traffic for unembedding layer."""
        V = self.vocab_size
        T = ctx.num_logits_tokens()
        if per_gpu:
            V //= self.tp_size
        return {
            "output": T * V * self.activation_byte_size,
        }


#### ModelMetrics ####
class ModelMetrics:
    def __init__(self, vllm_config: VllmConfig) -> None:
        """
        Parse vllm_config to instantiate metrics for each component.
        is_enabled() will return False if no component metrics
        could be instantiated.
        """
        self.vllm_config = vllm_config
        self.metrics: list[ComponentMetrics] = []
        for metric_cls in ComponentMetrics.registered_metrics():
            try:
                metric = metric_cls.from_vllm_config(vllm_config)
                self.metrics.append(metric)
                logger.info(
                    "Instantiated ComponentMetrics [%s] with (%s)",
                    metric.component_type(),
                    str(metric),
                )
            except InvalidComponent as e:
                # Component not applicable to this config; skip it.
                logger.debug(
                    "Failed to instantiate %s from %s",
                    metric_cls.component_type(),
                    str(e),
                )

    def is_enabled(self) -> bool:
        return len(self.metrics) > 0

    def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        return sum(metric.get_num_flops(ctx, per_gpu) for metric in self.metrics)

    def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        return sum(metric.get_read_bytes(ctx, per_gpu) for metric in self.metrics)

    def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        return sum(metric.get_write_bytes(ctx, per_gpu) for metric in self.metrics)

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        # Keys are namespaced as "<component>.<op>".
        total = {}
        for metric in self.metrics:
            breakdown = metric.get_num_flops_breakdown(ctx, per_gpu)
            component = metric.component_type()
            prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
            total.update(prefixed)
        return total

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        total = {}
        for metric in self.metrics:
            breakdown = metric.get_read_bytes_breakdown(ctx, per_gpu)
            component = metric.component_type()
            prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
            total.update(prefixed)
        return total

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        total = {}
        for metric in self.metrics:
            breakdown = metric.get_write_bytes_breakdown(ctx, per_gpu)
            component = metric.component_type()
            prefixed = {f"{component}.{key}": val for key, val in breakdown.items()}
            total.update(prefixed)
        return total

    def get_step_perf_stats_per_gpu(
        self, scheduler_output: SchedulerOutput
    ) -> PerfStats:
        """
        Calculate perf stats for the current step based on scheduled tokens.
        """
        t0 = time.monotonic()
        # Build a single batch context
        ctx = ExecutionContext()

        # Process new requests (these are in prefill phase)
        for new_req in scheduler_output.scheduled_new_reqs:
            req_id = new_req.req_id
            num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
            if num_tokens == 0:
                continue
            # For new requests, context_len = num_computed_tokens + num_tokens
            # num_computed_tokens represents previously computed tokens in the sequence
            context_len = new_req.num_computed_tokens + num_tokens
            ctx.add(num_tokens, context_len, is_prefill=True)

        # Process cached requests (continuing requests)
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens.get(req_id, 0)
            if num_tokens == 0:
                continue
            # For cached requests, we have the current num_computed_tokens
            num_computed_tokens = cached_reqs.num_computed_tokens[i]
            context_len = num_computed_tokens + num_tokens
            # Cached requests are typically in decode phase (num_tokens == 1)
            # unless they're doing chunked prefill (num_tokens > 1)
            is_prefill = num_tokens > 1
            ctx.add(num_tokens, context_len, is_prefill)

        num_flops_breakdown = self.get_num_flops_breakdown(ctx, True)
        read_bytes_breakdown = self.get_read_bytes_breakdown(ctx, True)
        write_bytes_breakdown = self.get_write_bytes_breakdown(ctx, True)
        perf_stats = PerfStats(
            sum(num_flops_breakdown.values()),
            sum(read_bytes_breakdown.values()),
            sum(write_bytes_breakdown.values()),
        )
        if envs.VLLM_DEBUG_MFU_METRICS:
            perf_stats.debug_stats = DebugPerfStats(
                time.monotonic() - t0,
                ctx.num_prefill_requests,
                ctx.num_decode_requests,
                asdict(ctx),
                num_flops_breakdown,
                read_bytes_breakdown,
                write_bytes_breakdown,
            )
        return perf_stats


#### Logging ####
class PerfMetricsDebugLogging:
    def __init__(self):
        self.reset()

    def reset(self):
        # Accumulators since the last log()/reset().
        self.total_calc_duration: float = 0.0
        self.total_num_prefill_requests: int = 0
        self.total_num_decode_requests: int = 0
        self.total_num_batches: int = 0
        self.total_context_breakdown: dict[str, int] = {}
        self.total_num_flops_per_gpu_breakdown: dict[str, int] = {}
        self.total_read_bytes_per_gpu_breakdown: dict[str, int] = {}
        self.total_write_bytes_per_gpu_breakdown: dict[str, int] = {}

    def observe(self, debug_stats: DebugPerfStats) -> None:
        self.total_calc_duration += debug_stats.calc_duration
        self.total_num_prefill_requests += debug_stats.num_prefill_requests
        self.total_num_decode_requests += debug_stats.num_decode_requests
        self.total_num_batches += 1
        # Merge each per-step breakdown dict into its running total.
        for dst, src in zip(
            [
                self.total_context_breakdown,
                self.total_num_flops_per_gpu_breakdown,
                self.total_read_bytes_per_gpu_breakdown,
                self.total_write_bytes_per_gpu_breakdown,
            ],
            [
                debug_stats.context_breakdown,
                debug_stats.num_flops_per_gpu_breakdown,
                debug_stats.num_read_bytes_per_gpu_breakdown,
                debug_stats.num_write_bytes_per_gpu_breakdown,
            ],
        ):
            assert isinstance(src, dict)
            for key, val in src.items():
                dst[key] = dst.get(key, 0) + val

    def log(self, log_fn, log_prefix: str, delta_time: float):
        # NOTE(review): log_fn is accepted but unused; details are always
        # emitted via logger.debug — confirm whether that is intentional.
        # pretty print breakdowns
        total_num_flops_per_gpu_breakdown = {
            k: f"{v / 1e12:.1f}TF"
            for k, v in self.total_num_flops_per_gpu_breakdown.items()
        }
        total_read_bytes_per_gpu_breakdown = {
            k: f"{v / 1e9:.1f}GB"
            for k, v in self.total_read_bytes_per_gpu_breakdown.items()
        }
        total_write_bytes_per_gpu_breakdown = {
            k: f"{v / 1e9:.1f}GB"
            for k, v in self.total_write_bytes_per_gpu_breakdown.items()
        }
        logger.debug(
            "%sMFU details: %s",
            log_prefix,
            json.dumps(
                {
                    "prefill_reqs": self.total_num_prefill_requests,
                    "decode_reqs": self.total_num_decode_requests,
                    "num_batches": self.total_num_batches,
                    "context_breakdown": self.total_context_breakdown,
                    "flops_breakdown": total_num_flops_per_gpu_breakdown,
                    "num_read_bytes_breakdown": total_read_bytes_per_gpu_breakdown,
                    "num_write_bytes_breakdown": (total_write_bytes_per_gpu_breakdown),
                    "duration": f"{delta_time:.1f}s",
                    "mfu_calc_overhead": (
                        f"{self.total_calc_duration / delta_time:.1%}"
                    ),
                },
                indent=2,
            ),
        )


class PerfMetricsLogging:
    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
        self.debug_logging: PerfMetricsDebugLogging | None = None
        if envs.VLLM_DEBUG_MFU_METRICS:
            self.debug_logging = PerfMetricsDebugLogging()
        self.reset()

    def reset(self):
        self.last_log_time = time.monotonic()
        self.total_num_flops_per_gpu: int = 0
        self.total_read_bytes_per_gpu: int = 0
        self.total_write_bytes_per_gpu: int = 0
        if self.debug_logging:
            self.debug_logging.reset()

    def observe(self, perf_stats: PerfStats) -> None:
        self.total_num_flops_per_gpu += perf_stats.num_flops_per_gpu
        self.total_read_bytes_per_gpu += perf_stats.num_read_bytes_per_gpu
        self.total_write_bytes_per_gpu += perf_stats.num_write_bytes_per_gpu
        if self.debug_logging:
            assert perf_stats.debug_stats is not None
            self.debug_logging.observe(perf_stats.debug_stats)

    def log(self, log_fn=logger.info, log_prefix: str = "") -> None:
        # Skip entirely when nothing was observed since the last reset.
        if not (
            self.total_num_flops_per_gpu
            or self.total_read_bytes_per_gpu
            or self.total_write_bytes_per_gpu
        ):
            return
        now = time.monotonic()
        delta_time = now - self.last_log_time
        if delta_time <= 0.0:
            avg_tflops_per_gpu = 0.0
            avg_gbps_per_gpu = 0.0
        else:
            avg_tflops_per_gpu = self.total_num_flops_per_gpu / delta_time / 1e12
            avg_gbps_per_gpu = (
                (self.total_read_bytes_per_gpu + self.total_write_bytes_per_gpu)
                / delta_time
                / 1e9
            )
        log_fn(
            "%sMFU: %.1f TF/s/GPU %.1f GB/s/GPU",
            log_prefix,
            avg_tflops_per_gpu,
            avg_gbps_per_gpu,
        )
        if self.debug_logging:
            self.debug_logging.log(log_fn, log_prefix, delta_time)
        self.reset()


#### Prometheus Integration ####
class PerfMetricsProm:
    """Record performance metrics in Prometheus.

    Average TFLOPS (tera floating-point operations per second) can be
    calculated using a PromQL query:
        rate(vllm:estimated_flops_per_gpu_total[1m]) / 1e12

    Average memory bandwidth in GB/s can be calculated using:
        (rate(vllm:estimated_read_bytes_per_gpu_total[1m]) +
         rate(vllm:estimated_write_bytes_per_gpu_total[1m])) / 1e9
    """

    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        vllm_config: VllmConfig,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        counter_flops = self._counter_cls(
            name="vllm:estimated_flops_per_gpu_total",
            documentation=(
                "Estimated number of floating point operations per GPU "
                "(for Model Flops Utilization calculations)."
            ),
            labelnames=labelnames,
        )
        self.counter_flops = make_per_engine(counter_flops, per_engine_labelvalues)

        counter_read_bytes = self._counter_cls(
            name="vllm:estimated_read_bytes_per_gpu_total",
            documentation=(
                "Estimated number of bytes read from memory per GPU "
                "(for Model Flops Utilization calculations)."
            ),
            labelnames=labelnames,
        )
        self.counter_read_bytes = make_per_engine(
            counter_read_bytes, per_engine_labelvalues
        )

        counter_write_bytes = self._counter_cls(
            name="vllm:estimated_write_bytes_per_gpu_total",
            documentation=(
                "Estimated number of bytes written to memory per GPU "
                "(for Model Flops Utilization calculations)."
            ),
            labelnames=labelnames,
        )
        self.counter_write_bytes = make_per_engine(
            counter_write_bytes, per_engine_labelvalues
        )

    def observe(self, perf_stats: PerfStats, engine_idx: int = 0):
        # Avoid touching the counters when there is nothing to record.
        if not (
            perf_stats.num_flops_per_gpu
            or perf_stats.num_read_bytes_per_gpu
            or perf_stats.num_write_bytes_per_gpu
        ):
            return
        self.counter_flops[engine_idx].inc(perf_stats.num_flops_per_gpu)
        self.counter_read_bytes[engine_idx].inc(perf_stats.num_read_bytes_per_gpu)
        self.counter_write_bytes[engine_idx].inc(perf_stats.num_write_bytes_per_gpu)


def make_per_engine(
    counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[object]]
):
    """Create a counter for each label value."""
    return {
        idx: counter.labels(*labelvalues)
        for idx, labelvalues in per_engine_labelvalues.items()
    }


## util functions
def get_required(obj: object, attr: str):
    """Get an attr from an object, or raise InvalidComponent if it's not set."""
    if not hasattr(obj, attr):
        raise InvalidComponent(f"Missing required attr {attr} in config")
    return getattr(obj, attr)


def getattr_from_list(obj: object, attrs: list[str], default: object = None):
    """Try to get the first attr that exists in the object from a list of attrs.
    Otherwise return the default."""
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)
    return default