1335 lines
44 KiB
Python
1335 lines
44 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
|
||
|
|
"""
|
||
|
|
Analytic flops/memory estimation module for transformer components,
|
||
|
|
to help derive MFU (Model Flops Utilization) stats for a running model.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import time
|
||
|
|
from abc import ABC, abstractmethod
|
||
|
|
from collections.abc import Iterable
|
||
|
|
from dataclasses import asdict, dataclass
|
||
|
|
from typing import Any, Protocol
|
||
|
|
|
||
|
|
import prometheus_client
|
||
|
|
import torch
|
||
|
|
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||
|
|
from typing_extensions import Self
|
||
|
|
|
||
|
|
import vllm.envs as envs
|
||
|
|
from vllm.config import VllmConfig
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
from vllm.utils.torch_utils import (
|
||
|
|
STR_DTYPE_TO_TORCH_DTYPE,
|
||
|
|
get_dtype_size,
|
||
|
|
get_kv_cache_torch_dtype,
|
||
|
|
)
|
||
|
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||
|
|
|
||
|
|
# Module-level logger; used e.g. to warn about unrecognized model dtypes.
logger = init_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class InvalidComponent(Exception):
    """Signal that a ComponentMetrics class does not apply to a VllmConfig.

    Raised when a config cannot be turned into the fields a ComponentMetrics
    subclass needs, e.g. a parser meets a quantization method it cannot model
    or the parsed values fail validation.
    """
|
||
|
|
|
||
|
|
|
||
|
|
#### Basic Data Types ####
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class DebugPerfStats:
|
||
|
|
## Stats for debugging the metrics calculation
|
||
|
|
calc_duration: float = 0.0 # time spent calculating these stats
|
||
|
|
num_prefill_requests: int = 0
|
||
|
|
num_decode_requests: int = 0
|
||
|
|
context_breakdown: dict[str, int] | None = None
|
||
|
|
num_flops_per_gpu_breakdown: dict[str, int] | None = None
|
||
|
|
num_read_bytes_per_gpu_breakdown: dict[str, int] | None = None
|
||
|
|
num_write_bytes_per_gpu_breakdown: dict[str, int] | None = None
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class PerfStats:
    """Estimated per-GPU flops and memory traffic, plus optional debug data."""

    # Estimated floating-point operations executed on one GPU.
    num_flops_per_gpu: int = 0
    # Estimated bytes read from / written to memory on one GPU.
    num_read_bytes_per_gpu: int = 0
    num_write_bytes_per_gpu: int = 0
    # Extra diagnostics; None when not collected.
    debug_stats: DebugPerfStats | None = None
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ExecutionContext:
    """
    Token/context totals for one batch of requests, split by phase.

    Prefill and decode requests are accumulated into separate counters so
    the cost models can treat the two phases differently.

    Example)
    - Batch with one full prefill (2048 tokens) and one decode (1 token, 8192 context):
      ctx = ExecutionContext()
      ctx.add(2048, 2048, is_prefill=True)
      ctx.add(1, 8192, is_prefill=False)
    """

    # Prefill phase: request count, then sums of num_tokens, context_len and
    # num_tokens * context_len over prefill requests.
    num_prefill_requests: int = 0
    prefill_num_tokens: int = 0
    prefill_context_len: int = 0
    prefill_token_context_product: int = 0

    # Decode phase: the same four aggregates over decode requests.
    num_decode_requests: int = 0
    decode_num_tokens: int = 0
    decode_context_len: int = 0
    decode_token_context_product: int = 0

    def add(self, num_tokens: int, context_len: int, is_prefill: bool) -> None:
        """Fold one request's token count and context length into the totals."""
        product = num_tokens * context_len
        if not is_prefill:
            self.num_decode_requests += 1
            self.decode_num_tokens += num_tokens
            self.decode_context_len += context_len
            self.decode_token_context_product += product
            return
        self.num_prefill_requests += 1
        self.prefill_num_tokens += num_tokens
        self.prefill_context_len += context_len
        self.prefill_token_context_product += product

    def total_num_tokens(self) -> int:
        """Return the number of tokens summed over both phases."""
        return self.decode_num_tokens + self.prefill_num_tokens

    def total_token_context_product(self) -> int:
        """Return sum(num_tokens * context_len) over every request."""
        return self.decode_token_context_product + self.prefill_token_context_product

    def num_logits_tokens(self) -> int:
        """Return how many tokens need logits (unembedding).

        A prefill request contributes one token (its last); every decode
        token contributes itself.
        """
        return self.decode_num_tokens + self.num_prefill_requests

    @classmethod
    def from_single_request(
        cls, num_tokens: int, context_len: int, is_prefill: bool
    ) -> "ExecutionContext":
        """Build a context that holds exactly one request (mainly for tests)."""
        single = cls()
        single.add(num_tokens, context_len, is_prefill)
        return single
|
||
|
|
|
||
|
|
|
||
|
|
class ParsedArgs:
    """
    Attribute bag that Parsers fill in and read via dot notation.

    e.g.)
    args = ParsedArgs()
    args.x = 3
    args.y = args.x + 1
    """

    def __getattr__(self, name: str) -> Any:
        # Only invoked when ordinary lookup fails, i.e. the arg was never set.
        message = f"'{type(self).__name__}' has no attribute '{name}'"
        raise AttributeError(message)

    def __setattr__(self, name: str, value: Any) -> None:
        # Store straight into the instance dict.
        object.__setattr__(self, name, value)

    def model_dump(self) -> dict[str, Any]:
        """Return a shallow copy of every attribute assigned so far."""
        return dict(self.__dict__)
|
||
|
|
|
||
|
|
|
||
|
|
#### Abstract ####
|
||
|
|
|
||
|
|
|
||
|
|
class Parser(Protocol):
    """Structural interface for one step of a ParserChain."""

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        """
        Parse the vllm config and update the current ParsedArgs and pass it on.

        If the parser isn't applicable to the vllm_config, it leaves `args`
        unchanged. A parser may instead raise InvalidComponent when the config
        is recognized but cannot be modeled (e.g. an unsupported quantization
        method), which makes the owning ComponentMetrics inapplicable.
        """
        ...
|
||
|
|
|
||
|
|
|
||
|
|
class ParserChain:
    """
    Run a sequence of parsers over one shared ParsedArgs.

    Parsers execute in insertion order and each may overwrite values written
    by earlier ones, so order matters whenever two parsers touch the same
    field.
    """

    def __init__(self, *parsers: Parser) -> None:
        self.parsers = [*parsers]

    def add_parser(self, parser: Parser) -> None:
        """Append `parser` so it runs after every existing parser."""
        self.parsers.append(parser)

    def parse(self, vllm_config: VllmConfig) -> ParsedArgs:
        """Thread a fresh ParsedArgs through every parser and return it."""
        result = ParsedArgs()
        for step in self.parsers:
            result = step.parse(result, vllm_config)
        return result
|
||
|
|
|
||
|
|
|
||
|
|
# Maps ComponentMetrics.component_type() -> concrete subclass. Filled in by
# ComponentMetrics.__init_subclass__ whenever a subclass is defined.
_COMPONENT_METRICS_REGISTRY: dict[str, type["ComponentMetrics"]] = {}
|
||
|
|
|
||
|
|
|
||
|
|
class ComponentMetrics(BaseModel, ABC):
    """
    Each concrete ComponentMetrics class is associated with:
    - fields that are required for metric derivation
      (fields are specified/validated through pydantic model)
    - parser to parse VllmConfig into fields
    - metric methods that derive flops/bytes for a given execution context

    Defining a subclass registers it in _COMPONENT_METRICS_REGISTRY under the
    key returned by its component_type().
    """

    @classmethod
    @abstractmethod
    def component_type(cls) -> str:
        """Return the unique key this component is registered under."""

    @classmethod
    @abstractmethod
    def get_parser(cls) -> ParserChain:
        """
        Return a ParserChain that provides values for all required fields.

        The returned parser chain must populate ParsedArgs with values for every
        field defined on this ComponentMetrics class. Missing or invalid fields
        make from_vllm_config() raise InvalidComponent (wrapping pydantic's
        ValidationError).
        See individual Parser docstrings for which args they provide, and field
        comments on ComponentMetrics subclasses for which parser provides each field.
        """

    def __init_subclass__(cls, **kwargs: Any) -> None:
        # Forward class keyword arguments so cooperative __init_subclass__
        # hooks further up the MRO still run, then register the subclass.
        super().__init_subclass__(**kwargs)
        _COMPONENT_METRICS_REGISTRY[cls.component_type()] = cls

    @classmethod
    def from_vllm_config(cls, vllm_config: VllmConfig) -> Self:
        """
        Instantiate this class from VllmConfig.

        Raises:
            InvalidComponent: raised directly by a parser that cannot handle
                the config, or raised here when the parsed args fail pydantic
                validation (the ValidationError is chained as __cause__).
        """
        parsed_args = cls.get_parser().parse(vllm_config)
        try:
            return cls.model_validate(parsed_args.model_dump())
        except ValidationError as e:
            raise InvalidComponent(f"Invalid {cls.component_type()} config: {e}") from e

    @classmethod
    def registered_metrics(cls) -> Iterable[type["ComponentMetrics"]]:
        """Iterate over every ComponentMetrics subclass registered so far."""
        return iter(_COMPONENT_METRICS_REGISTRY.values())

    @abstractmethod
    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Return flops per named sub-operation for the batch in `ctx`."""

    @abstractmethod
    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Return bytes read per named sub-operation for the batch in `ctx`."""

    @abstractmethod
    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Return bytes written per named sub-operation for the batch in `ctx`."""

    def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        """Total flops: the sum of get_num_flops_breakdown()."""
        return sum(self.get_num_flops_breakdown(ctx, per_gpu).values())

    def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        """Total bytes read: the sum of get_read_bytes_breakdown()."""
        return sum(self.get_read_bytes_breakdown(ctx, per_gpu).values())

    def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        """Total bytes written: the sum of get_write_bytes_breakdown()."""
        return sum(self.get_write_bytes_breakdown(ctx, per_gpu).values())
|
||
|
|
|
||
|
|
|
||
|
|
#### parsers ####
|
||
|
|
|
||
|
|
|
||
|
|
class BaseConfigParser(Parser):
    """
    Extract model-wide settings shared by every component.

    Provides: vocab_size, hidden_size, num_attention_heads, num_hidden_layers,
    weight_byte_size, activation_byte_size, dp_size, tp_size, pp_size, enable_ep
    """

    @staticmethod
    def _resolve_dtype(model_dtype: Any) -> torch.dtype:
        """Turn ModelConfig.dtype (a torch.dtype or a dtype name) into a torch.dtype."""
        if isinstance(model_dtype, torch.dtype):
            return model_dtype
        if isinstance(model_dtype, str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
            return STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
        # FIXME: handle this better
        logger.warning(
            "Unknown model_dtype %s, defaulting to bfloat16",
            model_dtype,
        )
        return torch.bfloat16

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        model_config = vllm_config.model_config
        parallel = vllm_config.parallel_config

        args.vocab_size = model_config.get_vocab_size()
        args.hidden_size = model_config.get_hidden_size()
        # NOTE: ModelConfig's attention-head helper divides by TP, so the
        # unsharded head count is read directly from the HF text config.
        args.num_attention_heads = get_required(
            model_config.hf_text_config, "num_attention_heads"
        )
        args.num_hidden_layers = get_required(
            model_config.hf_text_config, "num_hidden_layers"
        )

        args.weight_byte_size = get_dtype_size(
            self._resolve_dtype(vllm_config.model_config.dtype)
        )
        # FIXME: handle this better by parsing whether activations use
        # bf16, fp32, etc...
        args.activation_byte_size = 2

        args.dp_size = parallel.data_parallel_size
        args.tp_size = parallel.tensor_parallel_size
        args.pp_size = parallel.pipeline_parallel_size
        args.enable_ep = parallel.enable_expert_parallel
        return args
|
||
|
|
|
||
|
|
|
||
|
|
#### Attention ####
|
||
|
|
|
||
|
|
|
||
|
|
class BaseAttentionConfigParser(Parser):
    """
    Extract attention geometry and the KV-cache element size.

    Provides: num_key_value_heads, head_dim, cache_byte_size
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        model_config = vllm_config.model_config
        model_dtype = model_config.dtype
        cache_dtype = vllm_config.cache_config.cache_dtype

        args.num_key_value_heads = model_config.get_total_num_kv_heads()
        args.head_dim = model_config.get_head_size()

        # The KV cache element type is derived from both cache_config.cache_dtype
        # and the model dtype, so it may differ from the weights' dtype.
        args.cache_byte_size = get_dtype_size(
            get_kv_cache_torch_dtype(cache_dtype, model_dtype)
        )
        return args
|
||
|
|
|
||
|
|
|
||
|
|
class AttentionQuantizationConfigParser(Parser):
    """
    Parses quantization configuration for attention layers.

    Overrides: weight_byte_size
    Raises: InvalidComponent for quantization methods it cannot model.
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.quant_config
        if cfg is None:
            return args

        quant_method = cfg.get_name()
        if quant_method in ("fp8", "fbgemm_fp8"):
            # FIXME: This is a hacky coarse-grained fp8 quantization detection.
            # FIXME: These configs can also exclude some layers from
            # quantization ("ignored layers"); that is not modeled here.
            args.weight_byte_size = 1
        elif quant_method == "mxfp4":
            # FIXME: Same "ignored layers" limitation as fp8 above.
            args.weight_byte_size = 0.5
        else:
            # FIXME: Add more parsing logic for different quant methods.
            # Name the method so callers can tell why attention was skipped.
            raise InvalidComponent(
                "Unsupported quantization method for attention metrics: "
                f"{quant_method!r}"
            )
        return args
|
||
|
|
|
||
|
|
|
||
|
|
class AttentionMetrics(ComponentMetrics):
    """Analytic flops / memory-traffic model for the attention layers."""

    # From BaseConfigParser
    num_hidden_layers: int = Field(..., gt=0)
    hidden_size: int = Field(..., gt=0)
    num_attention_heads: int = Field(..., gt=0)
    activation_byte_size: int = Field(..., gt=0)
    tp_size: int = Field(..., gt=0)
    pp_size: int = Field(..., gt=0)

    # From BaseAttentionConfigParser
    num_key_value_heads: int = Field(..., gt=0)
    head_dim: int = Field(..., gt=0)
    cache_byte_size: int = Field(..., gt=0)

    # From BaseConfig Parser, overridden by AttentionQuantizationConfigParser
    weight_byte_size: int | float = Field(..., gt=0)

    # TODO: discern cases where we have mixture of different attention layer types
    # such as SWA, MLA, etc.

    @classmethod
    def component_type(cls) -> str:
        return "attn"

    @classmethod
    def get_parser(cls) -> ParserChain:
        return ParserChain(
            BaseConfigParser(),
            BaseAttentionConfigParser(),
            AttentionQuantizationConfigParser(),
        )

    def _layer_head_dims(self, per_gpu: bool) -> tuple[int, int, int, int, int]:
        """Return (layers, hidden, q_heads, kv_heads, head_dim).

        With per_gpu, layers are split across PP stages and heads across TP
        ranks, keeping at least one head of each kind per rank.
        """
        layers = self.num_hidden_layers
        q_heads = self.num_attention_heads
        kv_heads = self.num_key_value_heads
        if per_gpu:
            layers //= self.pp_size
            q_heads = max(1, q_heads // self.tp_size)
            kv_heads = max(1, kv_heads // self.tp_size)
        return layers, self.hidden_size, q_heads, kv_heads, self.head_dim

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Flops of the QKV/out projections and the two attention matmuls."""
        L, D, q, kv, d = self._layer_head_dims(per_gpu)
        tokens = ctx.total_num_tokens()
        token_ctx = ctx.total_token_context_product()
        # Q.K^T and P.V have identical cost: 2 * heads * sum(tokens*ctx) * d.
        attn_matmul = 2 * q * token_ctx * d * L
        return {
            "qkv_proj": 2 * tokens * D * (q + 2 * kv) * d * L,
            "attn_qk": attn_matmul,
            "attn_av": attn_matmul,
            "out_proj": 2 * tokens * D * q * d * L,
        }

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Bytes read by projections (weights + inputs) and attention inputs."""
        L, D, q, kv, d = self._layer_head_dims(per_gpu)
        tokens = ctx.total_num_tokens()
        act = self.activation_byte_size

        read_bytes = {
            "qkv_input": tokens * D * act * L,
            "qkv_weight": int(D * (q + 2 * kv) * d * self.weight_byte_size * L),
        }

        attn_parts = []
        if ctx.prefill_num_tokens > 0:
            # Prefill: Q, K and V are all read as activations.
            attn_parts.append(
                (ctx.prefill_num_tokens * q + 2 * ctx.prefill_context_len * kv)
                * d
                * act
                * L
            )
        if ctx.decode_num_tokens > 0:
            # Decode: Q as activations, K and V from the cache (cache dtype).
            attn_parts.append(
                ctx.decode_num_tokens * q * d * act * L
                + 2 * ctx.decode_context_len * kv * d * self.cache_byte_size * L
            )
        if attn_parts:
            read_bytes["attn_input"] = sum(attn_parts)

        read_bytes["out_input"] = tokens * q * d * act * L
        read_bytes["out_weight"] = int(q * d * D * self.weight_byte_size * L)
        return read_bytes

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate write memory traffic for attention layers."""
        L, D, q, kv, d = self._layer_head_dims(per_gpu)
        tokens = ctx.total_num_tokens()
        act = self.activation_byte_size
        return {
            "qkv_output": tokens * (q + 2 * kv) * d * act * L,
            "kv_cache": 2 * tokens * kv * d * self.cache_byte_size * L,
            "out_output": tokens * D * act * L,
        }
|
||
|
|
|
||
|
|
|
||
|
|
#### Ffn ####
|
||
|
|
|
||
|
|
|
||
|
|
class BaseFfnConfigParser(Parser):
    """
    Extract FFN and MoE sizes from the HF config.

    Provides: intermediate_size, num_experts, num_experts_per_tok,
    moe_intermediate_size, num_shared_experts, num_moe_layers
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        hf_config = vllm_config.model_config.hf_config
        # Some configs nest the language-model settings under text_config.
        text_config = getattr(hf_config, "text_config", None)
        cfg = hf_config if text_config is None else text_config

        args.intermediate_size = getattr(cfg, "intermediate_size", args.hidden_size * 4)

        # Different model families use different field names for the same value.
        args.num_experts = vllm_config.model_config.get_num_experts()
        args.num_experts_per_tok = getattr_from_list(
            cfg, ["num_experts_per_tok", "moe_topk"], 0
        )
        args.moe_intermediate_size = getattr_from_list(
            cfg, ["moe_intermediate_size", "intermediate_size"], 0
        )
        args.num_shared_experts = getattr_from_list(
            cfg, ["n_shared_experts", "num_shared_experts"], 0
        )

        # Default: every layer is MoE when the model has experts at all; later
        # parsers in the chain may narrow this down.
        args.num_moe_layers = 0 if args.num_experts == 0 else args.num_hidden_layers
        return args
|
||
|
|
|
||
|
|
|
||
|
|
class FfnParallelParser(Parser):
    """
    Derive how FFN weights are sharded.

    Provides: ffn_tp_size, ffn_ep_size
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        # NOTE: ffn tp_size does not equal the tp_size parameter directly.
        # e.g.) If we use DP2TP4, ffn will use TP8 (or EP8 if EP is enabled.)
        use_ep = args.enable_ep
        world = args.dp_size * args.tp_size
        args.ffn_tp_size, args.ffn_ep_size = (1, world) if use_ep else (world, 1)
        return args
|
||
|
|
|
||
|
|
|
||
|
|
class InterleaveMoeLayerStepParser(Parser):
    """
    Parses interleave_moe_layer_step field for models like Llama4.

    Layer i (0-indexed) counts as MoE when (i + 1) is a multiple of the step.

    Overrides: num_moe_layers (only when the field is set and positive)
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.model_config.hf_config
        if getattr(cfg, "text_config", None) is not None:
            cfg = cfg.text_config

        step = getattr(cfg, "interleave_moe_layer_step", None)
        # A config may define the field but leave it as None; comparing None
        # with 0 would raise TypeError, so treat it like a missing field.
        if step is not None and step > 0:
            args.num_moe_layers = sum(
                1
                for layer in range(args.num_hidden_layers)
                if (layer + 1) % step == 0
            )
        return args
|
||
|
|
|
||
|
|
|
||
|
|
class MoeLayerFreqParser(Parser):
    """
    Parses moe_layer_freq and first_k_dense_replace fields for models like Deepseek.

    Layer i (0-indexed) counts as MoE when i >= first_k_dense_replace and
    i is a multiple of moe_layer_freq.

    Overrides: num_moe_layers (only when both fields are set)
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.model_config.hf_config
        if getattr(cfg, "text_config", None) is not None:
            cfg = cfg.text_config

        freq = getattr(cfg, "moe_layer_freq", None)
        first_dense = getattr(cfg, "first_k_dense_replace", None)
        # Fields present but set to None would raise TypeError in the
        # comparisons below; treat them like missing fields instead.
        if freq is None or first_dense is None:
            return args

        args.num_moe_layers = sum(
            1
            for layer in range(args.num_hidden_layers)
            if layer >= first_dense and layer % freq == 0
        )
        return args
|
||
|
|
|
||
|
|
|
||
|
|
class FfnQuantizationConfigParser(Parser):
    """
    Parses quantization configuration for FFN layers.

    Overrides: weight_byte_size
    Raises: InvalidComponent for quantization methods it cannot model.
    """

    def parse(self, args: ParsedArgs, vllm_config: VllmConfig) -> ParsedArgs:
        cfg = vllm_config.quant_config
        if cfg is None:
            return args

        quant_method = cfg.get_name()
        if quant_method in ("fp8", "fbgemm_fp8"):
            # FIXME: This is a hacky coarse-grained fp8 quantization detection
            # (there might be more quantization methods for fp8).
            # FIXME: These configs can also exclude some layers from
            # quantization ("ignored layers"); that is not modeled here.
            args.weight_byte_size = 1
        elif quant_method == "mxfp4":
            # FIXME: Same "ignored layers" limitation as fp8 above.
            args.weight_byte_size = 0.5
        else:
            # FIXME: Add more parsing logic for different quant methods.
            # Name the method so callers can tell why FFN was skipped.
            raise InvalidComponent(
                f"Unsupported quantization method for FFN metrics: {quant_method!r}"
            )
        return args
|
||
|
|
|
||
|
|
|
||
|
|
class FfnMetrics(ComponentMetrics):
    """Analytic flops / memory-traffic model for dense and MoE FFN layers.

    Every FFN (dense, routed expert, shared expert) is modeled as SwiGLU:
    up and gate projections (D -> I), a SiLU activation step that reads both
    projections and writes one I-sized result, and a down projection (I -> D).
    """

    # From BaseConfigParser
    num_hidden_layers: int = Field(..., gt=0)
    hidden_size: int = Field(..., gt=0)
    activation_byte_size: int = Field(..., gt=0)
    pp_size: int = Field(..., gt=0)

    # From FfnParallelParser
    ffn_tp_size: int = Field(..., gt=0)
    ffn_ep_size: int = Field(..., gt=0)

    # From BaseFfnConfigParser
    intermediate_size: int = Field(..., gt=0)
    num_experts: int = Field(0)
    num_experts_per_tok: int = Field(1)
    moe_intermediate_size: int = Field(0)
    num_shared_experts: int = Field(0)

    # From BaseFfnConfigParser, can be overridden by InterleaveMoeLayerStepParser
    # or MoeLayerFreqParser
    num_moe_layers: int = Field(..., ge=0)

    # FIXME: might have to make this more granular
    # (i.e. dense_weight_byte_size, moe_routed_weight_byte_size,
    # moe_shared_weight_byte_size)
    # since it can differ from byte size of other components (e.g. attn)
    # and can differ even from each other.

    # From BaseConfigParser, can be overridden by FfnQuantizationConfigParser
    weight_byte_size: int | float = Field(..., gt=0)

    @model_validator(mode="after")
    def validate_moe_fields(self) -> Self:
        """Require nonzero expert settings whenever MoE layers exist.

        Raises ValueError (surfaced by pydantic as a ValidationError) instead
        of using assert, so the check still runs under `python -O`.
        """
        if self.num_moe_layers > 0:
            for name in ("num_experts", "num_experts_per_tok", "moe_intermediate_size"):
                value = getattr(self, name)
                if not value:
                    raise ValueError(
                        f"{name}={value!r} must be nonzero when "
                        f"num_moe_layers={self.num_moe_layers}"
                    )
        return self

    @classmethod
    def component_type(cls) -> str:
        return "ffn"

    @classmethod
    def get_parser(cls) -> ParserChain:
        return ParserChain(
            BaseConfigParser(),
            FfnParallelParser(),
            BaseFfnConfigParser(),
            InterleaveMoeLayerStepParser(),
            MoeLayerFreqParser(),
            FfnQuantizationConfigParser(),
        )

    def _scaled_dims(
        self, ctx: ExecutionContext, per_gpu: bool
    ) -> tuple[int, int, int, int, int, int]:
        """Return (T, Ld, Lm, DI, MI, routed_tokens) for the batch in `ctx`.

        T: batch tokens; Ld / Lm: dense / MoE layer counts; DI / MI: dense /
        MoE intermediate sizes; routed_tokens: (token, routed expert) pairs,
        i.e. T * num_experts_per_tok. With per_gpu, layers are split across
        PP stages, intermediate sizes across ffn_tp_size and routed tokens
        across ffn_ep_size.
        """
        T = ctx.total_num_tokens()
        E = self.num_experts_per_tok
        Lm = self.num_moe_layers
        Ld = self.num_hidden_layers - Lm
        DI, MI = self.intermediate_size, self.moe_intermediate_size
        routed_tokens = T * E if E else 0
        if per_gpu:
            Ld //= self.pp_size
            Lm //= self.pp_size
            DI //= self.ffn_tp_size
            MI //= self.ffn_tp_size
            routed_tokens //= self.ffn_ep_size
        return T, Ld, Lm, DI, MI, routed_tokens

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate flops breakdown for FFN layers.

        A SwiGLU FFN is 3 GEMMs of D x I, i.e. 2 * 3 * D * I flops per token.
        """
        T, Ld, Lm, DI, MI, routed_tokens = self._scaled_dims(ctx, per_gpu)
        D, E, S = self.hidden_size, self.num_experts_per_tok, self.num_shared_experts

        flops = {}
        # Dense FFN layers (SwiGLU: 3 linear layers: up, gate, down)
        if Ld:
            flops["dense_ffn"] = 2 * D * 3 * DI * T * Ld
        # MoE routed experts (each token activates E experts)
        if Lm and E:
            flops["routed_ffn"] = 2 * D * 3 * MI * routed_tokens * Lm
        # MoE shared experts (all S shared experts run for every token)
        if Lm and S:
            flops["shared_ffn"] = 2 * D * 3 * MI * S * T * Lm
        return flops

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate read memory traffic for FFN layers."""
        T, Ld, Lm, DI, MI, routed_tokens = self._scaled_dims(ctx, per_gpu)
        D, E, S = self.hidden_size, self.num_experts_per_tok, self.num_shared_experts
        act, wt = self.activation_byte_size, self.weight_byte_size
        num_experts = self.num_experts
        if per_gpu:
            num_experts //= self.ffn_ep_size

        read_bytes = {}

        # Dense FFN layers (3 GEMMs: up, gate, down projections + SiLU activation)
        if Ld:
            read_bytes["dense_up_gate_input"] = int(T * D * act * Ld)
            read_bytes["dense_up_gate_weights"] = int(2 * D * DI * wt * Ld)
            read_bytes["dense_silu_input"] = int(2 * T * DI * act * Ld)
            read_bytes["dense_down_input"] = int(T * DI * act * Ld)
            read_bytes["dense_down_weights"] = int(D * DI * wt * Ld)

        if Lm:
            # MoE routed expert reads
            if E:
                # FIXME: Assume perfect load balancing for now.
                num_activated_experts = min(routed_tokens, num_experts)
                read_bytes["routed_up_gate_input"] = int(routed_tokens * D * act * Lm)
                read_bytes["routed_up_gate_weights"] = int(
                    2 * D * MI * num_activated_experts * wt * Lm
                )
                read_bytes["routed_silu_input"] = int(
                    2 * routed_tokens * MI * act * Lm
                )
                read_bytes["routed_down_input"] = int(routed_tokens * MI * act * Lm)
                read_bytes["routed_down_weights"] = int(
                    D * MI * num_activated_experts * wt * Lm
                )

            # MoE shared expert reads
            if S:
                read_bytes["shared_up_gate_input"] = int(T * D * act * Lm)
                read_bytes["shared_up_gate_weights"] = int(2 * D * MI * S * wt * Lm)
                read_bytes["shared_silu_input"] = int(2 * T * MI * S * act * Lm)
                # Each of the S shared experts feeds its own SiLU output into
                # its down projection, so the input scales with S (matching
                # shared_silu_output written by get_write_bytes_breakdown).
                read_bytes["shared_down_input"] = int(T * MI * S * act * Lm)
                read_bytes["shared_down_weights"] = int(D * MI * S * wt * Lm)

        return read_bytes

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate write memory traffic for FFN layers."""
        T, Ld, Lm, DI, MI, routed_tokens = self._scaled_dims(ctx, per_gpu)
        D, E, S = self.hidden_size, self.num_experts_per_tok, self.num_shared_experts
        act = self.activation_byte_size

        write_bytes = {}

        # Dense FFN layers
        if Ld:
            write_bytes["dense_up_gate_output"] = int(2 * T * DI * act * Ld)
            write_bytes["dense_silu_output"] = int(T * DI * act * Ld)
            write_bytes["dense_down_output"] = int(T * D * act * Ld)

        # MoE outputs
        if Lm:
            if E:
                write_bytes["routed_up_gate_output"] = int(
                    2 * routed_tokens * MI * act * Lm
                )
                write_bytes["routed_silu_output"] = int(routed_tokens * MI * act * Lm)
                write_bytes["routed_down_output"] = int(routed_tokens * D * act * Lm)
            if S:
                write_bytes["shared_up_gate_output"] = int(2 * T * S * MI * act * Lm)
                write_bytes["shared_silu_output"] = int(T * S * MI * act * Lm)
                write_bytes["shared_down_output"] = int(T * S * D * act * Lm)

        return write_bytes
|
||
|
|
|
||
|
|
|
||
|
|
#### Unembed ####
|
||
|
|
|
||
|
|
|
||
|
|
class UnembedMetrics(ComponentMetrics):
    """Analytic cost model for the unembedding (hidden -> vocab) projection.

    Costs scale with the number of logits-producing tokens reported by the
    ExecutionContext. When ``per_gpu`` is True the vocab dimension is divided
    by ``tp_size``.
    """

    # From BaseConfigParser
    hidden_size: int = Field(..., gt=0)
    vocab_size: int = Field(..., gt=0)
    weight_byte_size: int = Field(..., gt=0)
    activation_byte_size: int = Field(..., gt=0)

    # Tensor-parallel size; divides the vocab dimension for per-GPU numbers.
    tp_size: int

    @classmethod
    def component_type(cls) -> str:
        return "unembed"

    @classmethod
    def get_parser(cls) -> ParserChain:
        # Only the base config fields are needed by this component.
        return ParserChain(BaseConfigParser())

    def _vocab_per_rank(self, per_gpu: bool) -> int:
        """Vocab columns handled by one GPU, or the full vocab otherwise."""
        return self.vocab_size // self.tp_size if per_gpu else self.vocab_size

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate flops breakdown for unembedding layer.

        Counts one multiply-add (2 flops) per (token, hidden, vocab) triple.
        """
        num_tokens = ctx.num_logits_tokens()
        vocab = self._vocab_per_rank(per_gpu)
        return {"unembed": 2 * num_tokens * self.hidden_size * vocab}

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate read memory traffic for unembedding layer.

        Reads the hidden-state activations and the (possibly sharded)
        projection weight.
        """
        num_tokens = ctx.num_logits_tokens()
        vocab = self._vocab_per_rank(per_gpu)
        hidden = self.hidden_size
        return {
            "input": num_tokens * hidden * self.activation_byte_size,
            "weight": hidden * vocab * self.weight_byte_size,
        }

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Calculate write memory traffic for unembedding layer.

        Writes one activation-sized value per (token, vocab column).
        """
        num_tokens = ctx.num_logits_tokens()
        vocab = self._vocab_per_rank(per_gpu)
        return {"output": num_tokens * vocab * self.activation_byte_size}
|
||
|
|
|
||
|
|
|
||
|
|
#### ModelMetrics ####
|
||
|
|
|
||
|
|
|
||
|
|
class ModelMetrics:
    """Aggregates analytic flops / memory-traffic estimates over components.

    One ComponentMetrics instance is built per registered metric class that
    applies to the given VllmConfig. Classes raising InvalidComponent are
    skipped. Breakdown keys are reported as "<component_type>.<key>".
    """

    def __init__(self, vllm_config: VllmConfig) -> None:
        """
        Parse vllm_config to instantiate metrics for each component.
        is_enabled() will return False if no component metrics could be instantiated.
        """
        self.vllm_config = vllm_config

        self.metrics: list[ComponentMetrics] = []
        for metric_cls in ComponentMetrics.registered_metrics():
            try:
                metric = metric_cls.from_vllm_config(vllm_config)
            except InvalidComponent as e:
                # The component is not applicable to this config; skip it.
                logger.debug(
                    "Failed to instantiate %s from %s",
                    metric_cls.component_type(),
                    str(e),
                )
            else:
                self.metrics.append(metric)
                logger.info(
                    "Instantiated ComponentMetrics [%s] with (%s)",
                    metric.component_type(),
                    str(metric),
                )

    def is_enabled(self) -> bool:
        """Return True if at least one component metric was instantiated."""
        return len(self.metrics) > 0

    def get_num_flops(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        return sum(metric.get_num_flops(ctx, per_gpu) for metric in self.metrics)

    def get_read_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        return sum(metric.get_read_bytes(ctx, per_gpu) for metric in self.metrics)

    def get_write_bytes(self, ctx: ExecutionContext, per_gpu: bool = True) -> int:
        return sum(metric.get_write_bytes(ctx, per_gpu) for metric in self.metrics)

    def _merged_breakdown(
        self, method_name: str, ctx: ExecutionContext, per_gpu: bool
    ) -> dict[str, int]:
        """Merge one breakdown method's output across all components.

        Keys are prefixed with the component type. If two components ever
        produce the same prefixed key, their values are summed rather than
        overwritten, so summing the merged values still counts every
        component's contribution.
        """
        total: dict[str, int] = {}
        for metric in self.metrics:
            breakdown = getattr(metric, method_name)(ctx, per_gpu)
            component = metric.component_type()
            for key, val in breakdown.items():
                full_key = f"{component}.{key}"
                total[full_key] = total.get(full_key, 0) + val
        return total

    def get_num_flops_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Flops per sub-operation, keyed "<component>.<key>"."""
        return self._merged_breakdown("get_num_flops_breakdown", ctx, per_gpu)

    def get_read_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Bytes read per sub-operation, keyed "<component>.<key>"."""
        return self._merged_breakdown("get_read_bytes_breakdown", ctx, per_gpu)

    def get_write_bytes_breakdown(
        self, ctx: ExecutionContext, per_gpu: bool = True
    ) -> dict[str, int]:
        """Bytes written per sub-operation, keyed "<component>.<key>"."""
        return self._merged_breakdown("get_write_bytes_breakdown", ctx, per_gpu)

    def _build_execution_context(
        self, scheduler_output: SchedulerOutput
    ) -> ExecutionContext:
        """Describe all requests scheduled this step as one ExecutionContext.

        Requests with zero scheduled tokens are ignored. For every request the
        context length is its already-computed tokens plus the tokens
        scheduled now.
        """
        ctx = ExecutionContext()
        num_scheduled = scheduler_output.num_scheduled_tokens

        # New requests are always counted as prefill.
        for new_req in scheduler_output.scheduled_new_reqs:
            num_tokens = num_scheduled.get(new_req.req_id, 0)
            if num_tokens == 0:
                continue
            context_len = new_req.num_computed_tokens + num_tokens
            ctx.add(num_tokens, context_len, is_prefill=True)

        # Continuing requests: a single scheduled token is treated as decode,
        # more than one as (chunked) prefill.
        # NOTE(review): presumably a decode step that also schedules
        # speculative tokens has num_tokens > 1 and is counted as prefill
        # here - verify whether that is intended.
        cached_reqs = scheduler_output.scheduled_cached_reqs
        for i, req_id in enumerate(cached_reqs.req_ids):
            num_tokens = num_scheduled.get(req_id, 0)
            if num_tokens == 0:
                continue
            context_len = cached_reqs.num_computed_tokens[i] + num_tokens
            is_prefill = num_tokens > 1
            ctx.add(num_tokens, context_len, is_prefill)

        return ctx

    def get_step_perf_stats_per_gpu(
        self, scheduler_output: SchedulerOutput
    ) -> PerfStats:
        """
        Calculate perf stats for the current step based on scheduled tokens.

        Totals are the sums of the per-GPU breakdowns. Debug breakdowns are
        attached only when VLLM_DEBUG_MFU_METRICS is set.
        """
        t0 = time.monotonic()

        ctx = self._build_execution_context(scheduler_output)

        num_flops_breakdown = self.get_num_flops_breakdown(ctx, True)
        read_bytes_breakdown = self.get_read_bytes_breakdown(ctx, True)
        write_bytes_breakdown = self.get_write_bytes_breakdown(ctx, True)
        perf_stats = PerfStats(
            sum(num_flops_breakdown.values()),
            sum(read_bytes_breakdown.values()),
            sum(write_bytes_breakdown.values()),
        )

        if envs.VLLM_DEBUG_MFU_METRICS:
            perf_stats.debug_stats = DebugPerfStats(
                time.monotonic() - t0,
                ctx.num_prefill_requests,
                ctx.num_decode_requests,
                asdict(ctx),
                num_flops_breakdown,
                read_bytes_breakdown,
                write_bytes_breakdown,
            )

        return perf_stats
|
||
|
|
|
||
|
|
|
||
|
|
#### Logging ####
|
||
|
|
|
||
|
|
|
||
|
|
class PerfMetricsDebugLogging:
    """Accumulates DebugPerfStats across steps and logs them as JSON.

    Created by PerfMetricsLogging when VLLM_DEBUG_MFU_METRICS is set.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated totals."""
        self.total_calc_duration: float = 0.0
        self.total_num_prefill_requests: int = 0
        self.total_num_decode_requests: int = 0
        self.total_num_batches: int = 0
        self.total_context_breakdown: dict[str, int] = {}
        self.total_num_flops_per_gpu_breakdown: dict[str, int] = {}
        self.total_read_bytes_per_gpu_breakdown: dict[str, int] = {}
        self.total_write_bytes_per_gpu_breakdown: dict[str, int] = {}

    def observe(self, debug_stats: DebugPerfStats) -> None:
        """Add one step's debug stats to the running totals.

        Scalar counters are summed. Each breakdown dict is merged key-wise by
        summation.
        """
        self.total_calc_duration += debug_stats.calc_duration
        self.total_num_prefill_requests += debug_stats.num_prefill_requests
        self.total_num_decode_requests += debug_stats.num_decode_requests
        self.total_num_batches += 1

        for dst, src in zip(
            [
                self.total_context_breakdown,
                self.total_num_flops_per_gpu_breakdown,
                self.total_read_bytes_per_gpu_breakdown,
                self.total_write_bytes_per_gpu_breakdown,
            ],
            [
                debug_stats.context_breakdown,
                debug_stats.num_flops_per_gpu_breakdown,
                debug_stats.num_read_bytes_per_gpu_breakdown,
                debug_stats.num_write_bytes_per_gpu_breakdown,
            ],
        ):
            assert isinstance(src, dict)
            for key, val in src.items():
                dst[key] = dst.get(key, 0) + val

    def log(self, log_fn, log_prefix: str, delta_time: float):
        """Log the accumulated breakdowns as pretty-printed JSON.

        Args:
            log_fn: Currently unused; details are always emitted through
                ``logger.debug``.
            log_prefix: Prepended to the log message.
            delta_time: Length of the logging window in seconds. Used for the
                reported duration and the calc-overhead ratio. A non-positive
                value reports the overhead as "n/a" instead of dividing by
                zero.
        """
        # Pretty-print breakdowns: flops in TF, bytes in GB.
        total_num_flops_per_gpu_breakdown = {
            k: f"{v / 1e12:.1f}TF"
            for k, v in self.total_num_flops_per_gpu_breakdown.items()
        }
        total_read_bytes_per_gpu_breakdown = {
            k: f"{v / 1e9:.1f}GB"
            for k, v in self.total_read_bytes_per_gpu_breakdown.items()
        }
        total_write_bytes_per_gpu_breakdown = {
            k: f"{v / 1e9:.1f}GB"
            for k, v in self.total_write_bytes_per_gpu_breakdown.items()
        }

        # A zero-length window (e.g. two monotonic() reads within the clock's
        # resolution) would otherwise raise ZeroDivisionError here.
        if delta_time > 0.0:
            mfu_calc_overhead = f"{self.total_calc_duration / delta_time:.1%}"
        else:
            mfu_calc_overhead = "n/a"

        logger.debug(
            "%sMFU details: %s",
            log_prefix,
            json.dumps(
                {
                    "prefill_reqs": self.total_num_prefill_requests,
                    "decode_reqs": self.total_num_decode_requests,
                    "num_batches": self.total_num_batches,
                    "context_breakdown": self.total_context_breakdown,
                    "flops_breakdown": total_num_flops_per_gpu_breakdown,
                    "num_read_bytes_breakdown": total_read_bytes_per_gpu_breakdown,
                    "num_write_bytes_breakdown": total_write_bytes_per_gpu_breakdown,
                    "duration": f"{delta_time:.1f}s",
                    "mfu_calc_overhead": mfu_calc_overhead,
                },
                indent=2,
            ),
        )
|
||
|
|
|
||
|
|
|
||
|
|
class PerfMetricsLogging:
    """Accumulates per-GPU PerfStats and periodically logs average throughput.

    Totals cover the window since the last reset(). log() reports them as
    TF/s/GPU and GB/s/GPU averaged over that window, then starts a new one.
    """

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.pp_size = vllm_config.parallel_config.pipeline_parallel_size

        # Detailed breakdowns are tracked only when VLLM_DEBUG_MFU_METRICS is set.
        self.debug_logging: PerfMetricsDebugLogging | None = None
        if envs.VLLM_DEBUG_MFU_METRICS:
            self.debug_logging = PerfMetricsDebugLogging()

        self.reset()

    def reset(self):
        """Start a new accumulation window at the current monotonic time."""
        self.last_log_time = time.monotonic()

        self.total_num_flops_per_gpu: int = 0
        self.total_read_bytes_per_gpu: int = 0
        self.total_write_bytes_per_gpu: int = 0

        if self.debug_logging:
            self.debug_logging.reset()

    def observe(self, perf_stats: PerfStats) -> None:
        """Add one step's per-GPU stats to the current window."""
        self.total_num_flops_per_gpu += perf_stats.num_flops_per_gpu
        self.total_read_bytes_per_gpu += perf_stats.num_read_bytes_per_gpu
        self.total_write_bytes_per_gpu += perf_stats.num_write_bytes_per_gpu

        if self.debug_logging:
            assert perf_stats.debug_stats is not None
            self.debug_logging.observe(perf_stats.debug_stats)

    def log(self, log_fn=logger.info, log_prefix: str = "") -> None:
        """Log average throughput for the current window, then reset it.

        Nothing is logged if the window saw no work, but the window is still
        restarted. Otherwise an idle stretch would be folded into the next
        active window's duration and understate its average throughput.
        """
        if not (
            self.total_num_flops_per_gpu
            or self.total_read_bytes_per_gpu
            or self.total_write_bytes_per_gpu
        ):
            self.reset()
            return

        now = time.monotonic()
        delta_time = now - self.last_log_time

        if delta_time <= 0.0:
            avg_tflops_per_gpu = 0.0
            avg_gbps_per_gpu = 0.0
        else:
            avg_tflops_per_gpu = self.total_num_flops_per_gpu / delta_time / 1e12
            avg_gbps_per_gpu = (
                (self.total_read_bytes_per_gpu + self.total_write_bytes_per_gpu)
                / delta_time
                / 1e9
            )

        log_fn(
            "%sMFU: %.1f TF/s/GPU %.1f GB/s/GPU",
            log_prefix,
            avg_tflops_per_gpu,
            avg_gbps_per_gpu,
        )

        if self.debug_logging:
            self.debug_logging.log(log_fn, log_prefix, delta_time)

        self.reset()
|
||
|
|
|
||
|
|
|
||
|
|
#### Prometheus Integration ####
|
||
|
|
|
||
|
|
|
||
|
|
class PerfMetricsProm:
    """Record performance metrics in Prometheus.

    Average TFLOPS (tera floating-point operations per second) can be
    calculated using a PromQL query:

      rate(vllm:estimated_flops_per_gpu_total[1m]) / 1e12

    Average memory bandwidth in GB/s can be calculated using:

      (rate(vllm:estimated_read_bytes_per_gpu_total[1m]) +
       rate(vllm:estimated_write_bytes_per_gpu_total[1m])) / 1e9
    """

    # Looked up through self, so a subclass can substitute another Counter type.
    _counter_cls = prometheus_client.Counter

    def __init__(
        self,
        vllm_config: VllmConfig,
        labelnames: list[str],
        per_engine_labelvalues: dict[int, list[object]],
    ):
        def per_engine_counter(name: str, documentation: str):
            # Create one labelled counter family, then bind it per engine.
            family = self._counter_cls(
                name=name,
                documentation=documentation,
                labelnames=labelnames,
            )
            return make_per_engine(family, per_engine_labelvalues)

        self.counter_flops = per_engine_counter(
            "vllm:estimated_flops_per_gpu_total",
            "Estimated number of floating point operations per GPU "
            "(for Model Flops Utilization calculations).",
        )
        self.counter_read_bytes = per_engine_counter(
            "vllm:estimated_read_bytes_per_gpu_total",
            "Estimated number of bytes read from memory per GPU "
            "(for Model Flops Utilization calculations).",
        )
        self.counter_write_bytes = per_engine_counter(
            "vllm:estimated_write_bytes_per_gpu_total",
            "Estimated number of bytes written to memory per GPU "
            "(for Model Flops Utilization calculations).",
        )

    def observe(self, perf_stats: PerfStats, engine_idx: int = 0):
        """Increment the engine's counters; a step with all-zero stats is skipped."""
        flops = perf_stats.num_flops_per_gpu
        read_bytes = perf_stats.num_read_bytes_per_gpu
        write_bytes = perf_stats.num_write_bytes_per_gpu
        if not flops and not read_bytes and not write_bytes:
            return
        self.counter_flops[engine_idx].inc(flops)
        self.counter_read_bytes[engine_idx].inc(read_bytes)
        self.counter_write_bytes[engine_idx].inc(write_bytes)
|
||
|
|
|
||
|
|
|
||
|
|
def make_per_engine(
    counter: prometheus_client.Counter, per_engine_labelvalues: dict[int, list[object]]
):
    """Bind ``counter`` to each engine's label values.

    Returns a dict mapping each engine index to ``counter.labels(*values)``
    for that engine's label values.
    """
    bound = {}
    for engine_idx, engine_labels in per_engine_labelvalues.items():
        bound[engine_idx] = counter.labels(*engine_labels)
    return bound
|
||
|
|
|
||
|
|
|
||
|
|
#### Util Functions ####
|
||
|
|
|
||
|
|
|
||
|
|
def get_required(obj: object, attr: str):
    """Return ``obj.attr``, or raise InvalidComponent if the attribute is absent.

    Only a missing attribute is an error. An attribute that exists with value
    ``None`` is returned as-is.

    Raises:
        InvalidComponent: if ``obj`` has no attribute named ``attr``.
    """
    try:
        return getattr(obj, attr)
    except AttributeError:
        raise InvalidComponent(f"Missing required attr {attr} in config") from None
|
||
|
|
|
||
|
|
|
||
|
|
def getattr_from_list(obj: object, attrs: list[str], default: object = None):
    """Return the value of the first attribute in ``attrs`` that ``obj`` has.

    Attributes are tried in order. An attribute that exists is returned even
    if its value is ``None``. If none of them exist, ``default`` is returned
    (``None`` unless another default is passed).
    """
    # Private sentinel distinguishes "missing" from an attribute set to None
    # while doing a single lookup per name.
    missing = object()
    for attr in attrs:
        value = getattr(obj, attr, missing)
        if value is not missing:
            return value
    return default
|