From 938d0854a58f17d17a758daa698e4ed463541535 Mon Sep 17 00:00:00 2001
From: LZiBee <2736864745@qq.com>
Date: Wed, 29 Apr 2026 19:38:22 +0800
Subject: [PATCH] Upgrade to vLLM 0.17.0 corex v4.1 overlay
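
Replace the installed vLLM package inside the vendor base image
`registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1` with the
extracted `vllm 0.17.0+corex.20260420090923` sources and matching dist-info,
and rewrite the README around the overlay workflow.

As a quick reference, a minimal build-and-verify sketch, using the same
commands and the example image tag `bi_150_vllm:0.17.0` documented in the
updated README below:

    # Build the overlay image on top of the vendor base image
    # (the tag is the README's example; adjust as needed).
    docker build --pull=false \
      --build-arg BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1 \
      -t bi_150_vllm:0.17.0 \
      .

    # Confirm the container imports the overlaid package and reports the new version.
    docker run --rm -it bi_150_vllm:0.17.0 \
      python3 -c "import vllm; print(vllm.__file__); print(vllm.__version__)"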
---
Dockerfile | 11 +-
README.md | 69 +-
.../direct_url.json | 1 -
.../INSTALLER | 0
.../METADATA | 15 +-
.../RECORD | 838 +++---
.../REQUESTED | 0
.../WHEEL | 2 +-
.../direct_url.json | 1 +
.../entry_points.txt | 0
.../licenses/LICENSE | 0
.../top_level.txt | 0
vllm/.gitignore | 244 --
vllm/__init__.py | 6 -
vllm/_aiter_ops.py | 160 ++
vllm/_bc_linter.py | 54 -
vllm/_custom_ops.py | 1741 ++++++------
vllm/benchmarks/datasets.py | 31 +-
vllm/benchmarks/lib/__init__.py | 3 +
vllm/benchmarks/lib/endpoint_request_func.py | 802 ++++++
vllm/benchmarks/lib/ready_checker.py | 79 +
vllm/benchmarks/lib/utils.py | 131 +
vllm/benchmarks/plot.py | 316 +++
vllm/benchmarks/serve.py | 167 +-
vllm/benchmarks/sweep/cli.py | 6 +-
vllm/benchmarks/sweep/plot.py | 17 +-
vllm/benchmarks/sweep/plot_pareto.py | 5 +-
vllm/benchmarks/sweep/serve.py | 129 +-
vllm/benchmarks/sweep/serve_sla.py | 305 ---
vllm/benchmarks/sweep/serve_workload.py | 328 +++
vllm/benchmarks/sweep/startup.py | 103 +-
vllm/compilation/backends.py | 70 +-
vllm/compilation/caching.py | 21 +-
vllm/compilation/compiler_interface.py | 48 +-
vllm/compilation/counter.py | 2 +
vllm/compilation/cuda_graph.py | 7 +-
.../passes/fusion/collective_fusion.py | 6 +-
.../passes/fusion/rocm_aiter_fusion.py | 22 +-
.../passes/fusion/sequence_parallelism.py | 4 -
.../passes/utility/fix_functionalization.py | 10 +-
vllm/compilation/piecewise_backend.py | 24 +-
vllm/compilation/wrapper.py | 49 +
vllm/config/attention.py | 4 +-
vllm/config/compilation.py | 18 +-
vllm/config/model.py | 25 +-
vllm/config/parallel.py | 142 +-
vllm/config/speculative.py | 230 +-
vllm/config/vllm.py | 131 +-
vllm/config/weight_transfer.py | 2 +-
vllm/device_allocator/cumem.py | 17 +-
.../device_communicators/all2all.py | 128 +-
.../device_communicators/all_reduce_utils.py | 49 +-
.../base_device_communicator.py | 68 +-
.../device_communicators/cpu_communicator.py | 37 +-
.../device_communicators/cuda_communicator.py | 115 +-
.../device_communicators/pynccl.py | 36 +-
vllm/distributed/elastic_ep/__init__.py | 0
.../distributed/elastic_ep/elastic_execute.py | 529 ++++
vllm/distributed/elastic_ep/elastic_state.py | 563 ++++
vllm/distributed/elastic_ep/standby_state.py | 117 +
vllm/distributed/eplb/async_worker.py | 4 -
vllm/distributed/eplb/eplb_state.py | 314 +--
vllm/distributed/eplb/rebalance_execute.py | 80 +-
vllm/distributed/kv_events.py | 4 +
.../kv_transfer/kv_connector/factory.py | 6 +
.../kv_transfer/kv_connector/utils.py | 45 +-
.../kv_transfer/kv_connector/v1/base.py | 22 +
.../kv_connector/v1/example_connector.py | 19 +-
.../v1/example_hidden_states_connector.py | 354 +++
.../kv_connector/v1/lmcache_connector.py | 10 +
.../v1/mooncake/mooncake_connector.py | 31 +-
.../kv_connector/v1/multi_connector.py | 15 +
.../kv_connector/v1/nixl_connector.py | 1231 +++++----
vllm/distributed/parallel_state.py | 301 ++-
vllm/distributed/stateless_coordinator.py | 322 +++
vllm/distributed/utils.py | 93 +-
vllm/distributed/weight_transfer/base.py | 29 +-
vllm/distributed/weight_transfer/factory.py | 6 +
.../distributed/weight_transfer/ipc_engine.py | 291 ++
.../weight_transfer/nccl_engine.py | 85 +-
vllm/engine/arg_utils.py | 30 +-
vllm/entrypoints/anthropic/api_router.py | 66 +-
vllm/entrypoints/anthropic/protocol.py | 44 +-
vllm/entrypoints/anthropic/serving.py | 684 +++--
vllm/entrypoints/chat_utils.py | 8 +
vllm/entrypoints/cli/__init__.py | 25 +-
vllm/entrypoints/cli/benchmark/main.py | 28 -
vllm/entrypoints/cli/serve.py | 12 +-
vllm/entrypoints/grpc_server.py | 15 +-
vllm/entrypoints/llm.py | 158 +-
vllm/entrypoints/logger.py | 14 +
.../openai/chat_completion/protocol.py | 41 +-
.../openai/chat_completion/serving.py | 22 +-
vllm/entrypoints/openai/cli_args.py | 7 +-
.../entrypoints/openai/completion/protocol.py | 43 +-
vllm/entrypoints/openai/engine/serving.py | 21 +-
vllm/entrypoints/openai/responses/protocol.py | 19 +-
vllm/entrypoints/openai/responses/serving.py | 79 +-
.../openai/speech_to_text/speech_to_text.py | 21 +-
.../openai/translations/__init__.py | 12 -
.../openai/translations/api_router.py | 14 -
.../openai/translations/protocol.py | 14 -
.../openai/translations/serving.py | 14 -
.../openai/translations/speech_to_text.py | 15 -
vllm/entrypoints/pooling/__init__.py | 1 +
vllm/entrypoints/pooling/base/io_processor.py | 189 ++
vllm/entrypoints/pooling/base/protocol.py | 4 -
vllm/entrypoints/pooling/base/serving.py | 378 +++
.../pooling/classify/api_router.py | 31 +-
.../pooling/classify/io_processor.py | 50 +
vllm/entrypoints/pooling/classify/protocol.py | 2 -
vllm/entrypoints/pooling/classify/serving.py | 132 +-
vllm/entrypoints/pooling/embed/protocol.py | 19 -
.../pooling/io_processor_factories.py | 31 +
vllm/entrypoints/pooling/pooling/protocol.py | 19 -
vllm/entrypoints/pooling/score/protocol.py | 2 -
vllm/entrypoints/pooling/score/serving.py | 19 +-
vllm/entrypoints/pooling/score/utils.py | 88 +-
vllm/entrypoints/pooling/typing.py | 51 +
vllm/entrypoints/sagemaker/api_router.py | 3 +-
vllm/entrypoints/utils.py | 71 +-
vllm/env_override.py | 41 +
vllm/envs.py | 203 +-
vllm/forward_context.py | 3 +-
vllm/kernels/helion/register.py | 127 +-
vllm/lora/layers/fused_moe.py | 25 +-
vllm/lora/layers/logits_processor.py | 6 +-
vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 252 +-
vllm/lora/ops/triton_ops/lora_expand_op.py | 6 +-
.../ops/triton_ops/lora_kernel_metadata.py | 42 +-
vllm/lora/ops/triton_ops/lora_shrink_op.py | 9 +-
vllm/lora/ops/triton_ops/utils.py | 6 +
vllm/lora/punica_wrapper/punica_gpu.py | 13 +
.../model_executor/kernels/linear/__init__.py | 2 +-
.../kernels/linear/mixed_precision/machete.py | 39 +-
.../kernels/linear/mixed_precision/marlin.py | 314 ++-
.../kernels/linear/scaled_mm/cutlass.py | 56 +-
vllm/model_executor/layers/activation.py | 28 +-
.../layers/attention/attention.py | 123 +-
.../layers/attention/extra_cache.py | 131 +
.../layers/attention/mla_attention.py | 2039 +++++++-------
.../layers/attention/mm_encoder_attention.py | 209 +-
.../model_executor/layers/fla/ops/__init__.py | 8 +-
vllm/model_executor/layers/fla/ops/chunk.py | 10 +-
.../layers/fla/ops/chunk_delta_h.py | 2 +-
vllm/model_executor/layers/fla/ops/chunk_o.py | 4 +-
.../layers/fla/ops/chunk_scaled_dot_kkt.py | 4 +-
.../layers/fla/ops/fused_recurrent.py | 248 +-
.../layers/fla/ops/fused_sigmoid_gating.py | 279 ++
vllm/model_executor/layers/fla/ops/index.py | 10 +-
vllm/model_executor/layers/fla/ops/kda.py | 16 +-
.../layers/fla/ops/layernorm_guard.py | 35 +-
vllm/model_executor/layers/fla/ops/wy_fast.py | 2 +-
.../layers/fused_moe/__init__.py | 12 +-
.../layers/fused_moe/activation.py | 12 +-
.../layers/fused_moe/all2all_utils.py | 68 +-
.../layers/fused_moe/batched_deep_gemm_moe.py | 2 +-
.../model_executor/layers/fused_moe/config.py | 16 +-
...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 147 +
.../E=128,N=96,device_name=NVIDIA_H200.json | 147 +
.../layers/fused_moe/cutlass_moe.py | 17 +-
.../layers/fused_moe/deep_gemm_moe.py | 2 +-
.../fused_moe/deepep_ht_prepare_finalize.py | 9 +-
.../fused_moe/deepep_ll_prepare_finalize.py | 12 +-
.../layers/fused_moe/experts/__init__.py | 0
.../fused_moe/experts/trtllm_fp8_moe.py | 335 +++
.../fused_moe/experts/trtllm_nvfp4_moe.py | 326 +++
.../layers/fused_moe/fallback.py | 12 +-
.../flashinfer_a2a_prepare_finalize.py | 6 +-
.../fused_moe/flashinfer_cutedsl_moe.py | 2 +-
.../fused_moe/flashinfer_cutlass_moe.py | 2 +-
.../layers/fused_moe/flashinfer_trtllm_moe.py | 298 ---
.../layers/fused_moe/fused_batched_moe.py | 12 +-
.../layers/fused_moe/fused_marlin_moe.py | 2 +-
.../layers/fused_moe/fused_moe.py | 594 +++--
.../layers/fused_moe/fused_moe_method_base.py | 51 +-
.../fused_moe/fused_moe_modular_method.py | 18 +-
.../fused_moe/gpt_oss_triton_kernels_moe.py | 139 +-
vllm/model_executor/layers/fused_moe/layer.py | 138 +-
.../layers/fused_moe/modular_kernel.py | 701 +++--
.../layers/fused_moe/mori_prepare_finalize.py | 2 +-
.../layers/fused_moe/oracle/fp8.py | 132 +-
.../layers/fused_moe/oracle/nvfp4.py | 143 +-
.../layers/fused_moe/oracle/unquantized.py | 35 +-
.../layers/fused_moe/pplx_prepare_finalize.py | 373 ---
.../layers/fused_moe/prepare_finalize.py | 209 --
.../fused_moe/prepare_finalize/__init__.py | 22 +
.../fused_moe/prepare_finalize/naive_dp_ep.py | 253 ++
.../fused_moe/prepare_finalize/no_dp_ep.py | 141 +
.../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +-
.../fused_moe/routed_experts_capturer.py | 3 +-
.../layers/fused_moe/router/base_router.py | 3 +-
.../router/fused_topk_bias_router.py | 9 +-
.../fused_moe/router/fused_topk_router.py | 8 +-
.../layers/fused_moe/router/gate_linear.py | 115 +
.../fused_moe/router/grouped_topk_router.py | 83 +-
.../layers/fused_moe/router/router_factory.py | 2 +-
.../fused_moe/runner/default_moe_runner.py | 143 +-
.../fused_moe/topk_weight_and_reduce.py | 13 +-
.../layers/fused_moe/triton_cutlass_moe.py | 6 +-
.../layers/fused_moe/triton_deep_gemm_moe.py | 6 +-
.../layers/fused_moe/trtllm_moe.py | 2 +-
.../fused_moe/unquantized_fused_moe_method.py | 28 +-
.../layers/fused_moe/xpu_fused_moe.py | 2 +-
vllm/model_executor/layers/layernorm.py | 36 +-
vllm/model_executor/layers/linear.py | 260 +-
.../model_executor/layers/logits_processor.py | 5 +-
.../layers/mamba/linear_attn.py | 189 +-
.../layers/mamba/mamba_mixer.py | 4 +
.../layers/mamba/mamba_utils.py | 3 -
.../layers/mamba/ops/causal_conv1d.py | 2 +-
.../layers/mamba/ops/mamba_ssm.py | 4 +
vllm/model_executor/layers/mla.py | 98 +-
.../layers/quantization/__init__.py | 12 +-
.../layers/quantization/awq_marlin.py | 256 +-
.../compressed_tensors/compressed_tensors.py | 16 +-
.../compressed_tensors_moe.py | 2367 +++++++++++++----
.../compressed_tensors/schemes/__init__.py | 3 +-
.../schemes/compressed_tensors_w8a8_int8.py | 145 +-
.../model_executor/layers/quantization/fp8.py | 97 +-
.../layers/quantization/gguf.py | 14 +-
.../layers/quantization/gptq.py | 2 +-
.../layers/quantization/gptq_marlin.py | 560 +++-
.../layers/quantization/modelopt.py | 540 ++--
.../layers/quantization/mxfp4.py | 126 +-
.../layers/quantization/ptpc_fp8.py | 5 -
.../layers/quantization/quark/quark.py | 48 +-
.../layers/quantization/quark/quark_moe.py | 202 +-
.../quark/schemes/quark_ocp_mx.py | 93 +-
.../quantization/utils/flashinfer_fp4_moe.py | 296 +--
.../quantization/utils/flashinfer_utils.py | 119 +-
.../layers/quantization/utils/fp8_utils.py | 5 +-
.../layers/quantization/utils/gguf_utils.py | 373 +++
.../layers/quantization/utils/marlin_utils.py | 18 -
.../layers/quantization/utils/mxfp4_utils.py | 4 +-
.../layers/quantization/utils/quant_utils.py | 10 +-
.../layers/quantization/w8a16.py | 114 +
.../layers/rotary_embedding/base.py | 1 +
.../layers/rotary_embedding/common.py | 15 -
.../rotary_embedding/deepseek_scaling_rope.py | 74 +-
.../layers/rotary_embedding/mrope.py | 78 +-
.../phi3_long_rope_scaled_rope.py | 38 +-
.../shared_fused_moe/shared_fused_moe.py | 56 +
.../layers/sparse_attn_indexer.py | 349 ++-
vllm/model_executor/layers/utils.py | 100 +-
vllm/model_executor/model_loader/__init__.py | 4 +
.../model_loader/default_loader.py | 9 +-
.../model_loader/weight_utils.py | 229 +-
vllm/model_executor/models/AXK1.py | 1168 ++++++++
.../models/bailing_moe_linear.py | 1246 +++++++++
vllm/model_executor/models/config.py | 51 +-
vllm/model_executor/models/deepencoder.py | 7 +-
vllm/model_executor/models/deepseek_mtp.py | 136 +-
vllm/model_executor/models/deepseek_v2.py | 668 ++---
vllm/model_executor/models/ernie45_moe.py | 11 +-
vllm/model_executor/models/ernie45_vl.py | 21 +-
.../models/extract_hidden_states.py | 394 +++
vllm/model_executor/models/fireredasr2.py | 829 ++++++
vllm/model_executor/models/funaudiochat.py | 80 -
vllm/model_executor/models/gemma3n.py | 2 +-
vllm/model_executor/models/glm4_moe.py | 18 +-
vllm/model_executor/models/glm4_moe_lite.py | 316 ++-
vllm/model_executor/models/gpt_oss.py | 141 +-
vllm/model_executor/models/hunyuan_vision.py | 8 +-
.../models/hyperclovax_vision.py | 1 -
vllm/model_executor/models/isaac.py | 2 +-
vllm/model_executor/models/keye.py | 8 +-
vllm/model_executor/models/minicpm.py | 2 +-
vllm/model_executor/models/minicpm3.py | 2 +-
vllm/model_executor/models/minimax_m2.py | 11 +-
.../model_executor/models/nano_nemotron_vl.py | 372 ++-
vllm/model_executor/models/nemotron_h.py | 20 +-
vllm/model_executor/models/nemotron_vl.py | 359 ++-
vllm/model_executor/models/ovis2_5.py | 3 -
vllm/model_executor/models/paddleocr_vl.py | 21 +-
vllm/model_executor/models/parakeet.py | 145 +
vllm/model_executor/models/phimoe.py | 3 +-
vllm/model_executor/models/pixtral.py | 27 +-
vllm/model_executor/models/qwen.py | 2 +-
.../models/qwen2_5_omni_thinker.py | 92 +-
vllm/model_executor/models/qwen2_5_vl.py | 13 +-
vllm/model_executor/models/qwen2_audio.py | 20 +
vllm/model_executor/models/qwen2_vl.py | 18 +-
vllm/model_executor/models/qwen3.py | 30 +-
vllm/model_executor/models/qwen3_5.py | 30 +-
vllm/model_executor/models/qwen3_5_mtp.py | 2 +-
vllm/model_executor/models/qwen3_moe.py | 44 +-
vllm/model_executor/models/qwen3_next.py | 281 +-
.../models/qwen3_omni_moe_thinker.py | 62 +-
vllm/model_executor/models/qwen3_vl.py | 883 ++++--
vllm/model_executor/models/qwen3_vl_moe.py | 7 +-
vllm/model_executor/models/qwen_vl.py | 68 +-
vllm/model_executor/models/registry.py | 15 +
vllm/model_executor/models/step3_text.py | 2 +-
vllm/model_executor/models/step3p5.py | 67 +-
vllm/model_executor/models/step3p5_mtp.py | 63 +-
vllm/model_executor/models/swin.py | 500 ----
vllm/model_executor/models/teleflm.py | 3 +-
.../models/transformers/base.py | 14 +-
.../models/transformers/multimodal.py | 26 +-
vllm/model_executor/models/utils.py | 16 +-
vllm/model_executor/parameter.py | 29 +-
.../model_executor/warmup/deep_gemm_warmup.py | 2 +-
vllm/model_executor/warmup/kernel_warmup.py | 11 +-
vllm/multimodal/evs.py | 78 +-
vllm/multimodal/processing/processor.py | 12 +-
vllm/multimodal/utils.py | 18 -
vllm/platforms/__init__.py | 40 +-
vllm/platforms/cpu.py | 24 +
vllm/platforms/cuda.py | 37 +-
vllm/platforms/interface.py | 13 +-
vllm/platforms/rocm.py | 150 +-
vllm/platforms/xpu.py | 6 +
vllm/plugins/io_processors/__init__.py | 18 +-
vllm/plugins/io_processors/interface.py | 3 +-
vllm/pooling_params.py | 8 +-
vllm/reasoning/__init__.py | 4 +-
vllm/reasoning/kimi_k2_reasoning_parser.py | 228 ++
vllm/reasoning/minimax_m2_reasoning_parser.py | 7 +-
vllm/reasoning/qwen3_reasoning_parser.py | 25 +-
vllm/renderers/qwen_vl.py | 29 +
vllm/renderers/registry.py | 1 +
vllm/sampling_params.py | 87 +-
vllm/third_party/pynvml.py | 2 +-
vllm/tokenizers/deepseek_v32.py | 2 +-
vllm/tokenizers/qwen_vl.py | 67 +
vllm/tokenizers/registry.py | 5 +
vllm/tool_parsers/qwen3coder_tool_parser.py | 399 ++-
vllm/tool_parsers/utils.py | 15 -
vllm/transformers_utils/configs/AXK1.py | 215 ++
vllm/transformers_utils/configs/__init__.py | 2 +
.../configs/extract_hidden_states.py | 53 +
vllm/transformers_utils/configs/parakeet.py | 49 +
.../model_arch_config_convertor.py | 10 +-
vllm/transformers_utils/processor.py | 94 +-
.../transformers_utils/processors/__init__.py | 4 +
.../processors/fireredasr2_processor.py | 341 +++
vllm/transformers_utils/repo_utils.py | 2 +-
vllm/triton_utils/allocation.py | 13 +
vllm/utils/deep_gemm.py | 121 +
vllm/utils/flashinfer.py | 9 +-
vllm/utils/import_utils.py | 5 -
vllm/utils/math_utils.py | 5 +
vllm/utils/system_utils.py | 16 +
vllm/utils/torch_utils.py | 39 +-
vllm/v1/attention/backend.py | 71 +-
vllm/v1/attention/backends/fa_utils.py | 139 +-
vllm/v1/attention/backends/flash_attn.py | 616 +++--
vllm/v1/attention/backends/flashinfer.py | 295 +-
vllm/v1/attention/backends/mamba1_attn.py | 33 +-
vllm/v1/attention/backends/mamba2_attn.py | 112 +-
vllm/v1/attention/backends/mamba_attn.py | 121 +
.../attention/backends/mla/flashattn_mla.py | 6 +-
vllm/v1/attention/backends/mla/flashmla.py | 6 +-
.../attention/backends/mla/flashmla_sparse.py | 112 +-
vllm/v1/attention/backends/mla/indexer.py | 247 +-
.../attention/backends/mla/rocm_aiter_mla.py | 6 +-
vllm/v1/attention/backends/mla/triton_mla.py | 71 +-
.../backends/rocm_aiter_unified_attn.py | 31 +
vllm/v1/attention/backends/rocm_attn.py | 80 +-
vllm/v1/attention/ops/flashmla.py | 28 +
vllm/v1/attention/ops/vit_attn_wrappers.py | 92 +-
vllm/v1/core/kv_cache_manager.py | 38 +
vllm/v1/core/kv_cache_utils.py | 143 +-
vllm/v1/core/sched/output.py | 13 +-
vllm/v1/core/sched/scheduler.py | 583 +++-
vllm/v1/core/sched/utils.py | 66 +
vllm/v1/core/single_type_kv_cache_manager.py | 11 +
vllm/v1/cudagraph_dispatcher.py | 60 +-
vllm/v1/engine/__init__.py | 20 +-
vllm/v1/engine/async_llm.py | 42 +-
vllm/v1/engine/coordinator.py | 21 +-
vllm/v1/engine/core.py | 413 ++-
vllm/v1/engine/core_client.py | 336 ++-
vllm/v1/engine/input_processor.py | 11 -
vllm/v1/engine/llm_engine.py | 1 +
vllm/v1/engine/utils.py | 68 +-
vllm/v1/executor/abstract.py | 10 +-
vllm/v1/executor/multiproc_executor.py | 42 +-
vllm/v1/executor/ray_executor.py | 6 +-
vllm/v1/executor/ray_utils.py | 52 +-
vllm/v1/executor/uniproc_executor.py | 17 +-
vllm/v1/kv_cache_interface.py | 112 +-
vllm/v1/kv_offload/worker/cpu_gpu.py | 22 +-
vllm/v1/outputs.py | 70 +-
vllm/v1/request.py | 2 +
vllm/v1/sample/logits_processor/__init__.py | 7 +-
vllm/v1/sample/logits_processor/builtin.py | 54 +
vllm/v1/sample/logits_processor/state.py | 4 +-
vllm/v1/sample/rejection_sampler.py | 15 +-
vllm/v1/spec_decode/eagle.py | 596 +++--
vllm/v1/spec_decode/extract_hidden_states.py | 395 +++
vllm/v1/spec_decode/metadata.py | 42 +
vllm/v1/spec_decode/multi_layer_eagle.py | 504 ++++
vllm/v1/spec_decode/utils.py | 13 +-
vllm/v1/worker/cpu_model_runner.py | 7 +-
vllm/v1/worker/cpu_worker.py | 5 +-
vllm/v1/worker/dp_utils.py | 48 +-
vllm/v1/worker/gpu/async_utils.py | 36 +
vllm/v1/worker/gpu/block_table.py | 11 +
vllm/v1/worker/gpu/cudagraph_utils.py | 132 +-
vllm/v1/worker/gpu/input_batch.py | 54 +-
vllm/v1/worker/gpu/mm/encoder_cache.py | 40 +
vllm/v1/worker/gpu/mm/encoder_runner.py | 53 +-
vllm/v1/worker/gpu/model_runner.py | 524 ++--
vllm/v1/worker/gpu/model_states/__init__.py | 18 +
vllm/v1/worker/gpu/model_states/default.py | 161 ++
vllm/v1/worker/gpu/model_states/interface.py | 67 +
vllm/v1/worker/gpu/pool/__init__.py | 0
vllm/v1/worker/gpu/pool/pooling_runner.py | 45 +
vllm/v1/worker/gpu/sample/bad_words.py | 18 +-
vllm/v1/worker/gpu/sample/gumbel.py | 48 +-
vllm/v1/worker/gpu/sample/logit_bias.py | 32 +-
vllm/v1/worker/gpu/sample/min_p.py | 24 +-
vllm/v1/worker/gpu/sample/penalties.py | 30 +-
vllm/v1/worker/gpu/sample/sampler.py | 24 +-
vllm/v1/worker/gpu/sample/states.py | 14 +-
.../worker/gpu/spec_decode/eagle/cudagraph.py | 27 +-
.../gpu/spec_decode/eagle/speculator.py | 114 +-
vllm/v1/worker/gpu/warmup.py | 105 +
vllm/v1/worker/gpu_input_batch.py | 89 +
vllm/v1/worker/gpu_model_runner.py | 963 ++++---
vllm/v1/worker/gpu_worker.py | 326 +--
vllm/v1/worker/mamba_utils.py | 94 +-
vllm/v1/worker/utils.py | 366 ++-
vllm/v1/worker/worker_base.py | 17 +-
vllm/v1/worker/workspace.py | 28 +
vllm/version.py | 4 +-
vllm/vllm_flash_attn/__init__.py | 24 +
vllm/vllm_flash_attn/flash_attn_interface.py | 567 ++++
430 files changed, 35969 insertions(+), 14511 deletions(-)
delete mode 100644 vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/INSTALLER (100%)
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/METADATA (95%)
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/RECORD (87%)
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/REQUESTED (100%)
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/WHEEL (65%)
create mode 100644 vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/entry_points.txt (100%)
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/licenses/LICENSE (100%)
rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/top_level.txt (100%)
delete mode 100644 vllm/.gitignore
delete mode 100644 vllm/_bc_linter.py
create mode 100644 vllm/benchmarks/lib/__init__.py
create mode 100644 vllm/benchmarks/lib/endpoint_request_func.py
create mode 100644 vllm/benchmarks/lib/ready_checker.py
create mode 100644 vllm/benchmarks/lib/utils.py
create mode 100644 vllm/benchmarks/plot.py
delete mode 100644 vllm/benchmarks/sweep/serve_sla.py
create mode 100644 vllm/benchmarks/sweep/serve_workload.py
create mode 100644 vllm/distributed/elastic_ep/__init__.py
create mode 100644 vllm/distributed/elastic_ep/elastic_execute.py
create mode 100644 vllm/distributed/elastic_ep/elastic_state.py
create mode 100644 vllm/distributed/elastic_ep/standby_state.py
create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
create mode 100644 vllm/distributed/stateless_coordinator.py
create mode 100644 vllm/distributed/weight_transfer/ipc_engine.py
delete mode 100644 vllm/entrypoints/openai/translations/__init__.py
delete mode 100644 vllm/entrypoints/openai/translations/api_router.py
delete mode 100644 vllm/entrypoints/openai/translations/protocol.py
delete mode 100644 vllm/entrypoints/openai/translations/serving.py
delete mode 100644 vllm/entrypoints/openai/translations/speech_to_text.py
create mode 100644 vllm/entrypoints/pooling/base/io_processor.py
create mode 100644 vllm/entrypoints/pooling/base/serving.py
create mode 100644 vllm/entrypoints/pooling/classify/io_processor.py
create mode 100644 vllm/entrypoints/pooling/io_processor_factories.py
create mode 100644 vllm/entrypoints/pooling/typing.py
create mode 100644 vllm/model_executor/layers/attention/extra_cache.py
create mode 100644 vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json
create mode 100644 vllm/model_executor/layers/fused_moe/experts/__init__.py
create mode 100644 vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
create mode 100644 vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
delete mode 100644 vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
delete mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize.py
create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py
create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
create mode 100644 vllm/model_executor/layers/fused_moe/router/gate_linear.py
create mode 100644 vllm/model_executor/layers/quantization/utils/gguf_utils.py
create mode 100644 vllm/model_executor/layers/quantization/w8a16.py
create mode 100644 vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py
create mode 100644 vllm/model_executor/models/AXK1.py
create mode 100644 vllm/model_executor/models/bailing_moe_linear.py
create mode 100644 vllm/model_executor/models/extract_hidden_states.py
create mode 100644 vllm/model_executor/models/fireredasr2.py
create mode 100644 vllm/model_executor/models/parakeet.py
delete mode 100644 vllm/model_executor/models/swin.py
create mode 100644 vllm/reasoning/kimi_k2_reasoning_parser.py
create mode 100644 vllm/renderers/qwen_vl.py
create mode 100644 vllm/tokenizers/qwen_vl.py
create mode 100644 vllm/transformers_utils/configs/AXK1.py
create mode 100644 vllm/transformers_utils/configs/extract_hidden_states.py
create mode 100644 vllm/transformers_utils/configs/parakeet.py
create mode 100644 vllm/transformers_utils/processors/fireredasr2_processor.py
create mode 100644 vllm/triton_utils/allocation.py
create mode 100644 vllm/v1/spec_decode/extract_hidden_states.py
create mode 100644 vllm/v1/spec_decode/multi_layer_eagle.py
create mode 100644 vllm/v1/worker/gpu/mm/encoder_cache.py
create mode 100644 vllm/v1/worker/gpu/model_states/__init__.py
create mode 100644 vllm/v1/worker/gpu/model_states/default.py
create mode 100644 vllm/v1/worker/gpu/model_states/interface.py
create mode 100644 vllm/v1/worker/gpu/pool/__init__.py
create mode 100644 vllm/v1/worker/gpu/pool/pooling_runner.py
create mode 100644 vllm/v1/worker/gpu/warmup.py
create mode 100644 vllm/vllm_flash_attn/__init__.py
create mode 100644 vllm/vllm_flash_attn/flash_attn_interface.py
diff --git a/Dockerfile b/Dockerfile
index e83ace5..99280a6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,11 +1,12 @@
-FROM registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8
+ARG BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1
+FROM ${BASE_IMAGE}
-# Keep the runtime stack from the known-good v8 image, but replace the
-# installed Python package with the repository's patched 0.16.1rc0 sources.
-WORKDIR /tmp
+WORKDIR /home
RUN rm -rf /usr/local/lib/python3.12/dist-packages/vllm \
/usr/local/lib/python3.12/dist-packages/vllm-*.dist-info
COPY vllm /usr/local/lib/python3.12/dist-packages/vllm
-COPY vllm-0.16.1rc0+corex.4.4.0.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.16.1rc0+corex.4.4.0.dist-info
+COPY vllm-0.17.0+corex.20260420090923.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.17.0+corex.20260420090923.dist-info
+
+ENTRYPOINT ["/bin/bash"]
diff --git a/README.md b/README.md
index 8689007..81019bd 100644
--- a/README.md
+++ b/README.md
@@ -1,62 +1,37 @@
# bi_150-vllm
-A `vLLM 0.16.1rc0` build repository based on `registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8`,
-used to produce a ready-to-run image for the BI-V150 virtual machine environment.
+This repository contains the extracted `vLLM 0.17.0+corex.20260420090923`
+Python package used to overlay the vendor image
+`registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1`.
-## Changes
-
-This repository keeps only the minimum content needed to build the image:
+## Included files
- `vllm/`
- The current runtime code
-- `vllm-0.16.1rc0+corex.4.4.0.dist-info/`
- The corresponding package metadata
+ The Python package code copied from the image payload.
+- `vllm-0.17.0+corex.20260420090923.dist-info/`
+ The package metadata extracted from the image.
- `Dockerfile`
- Builds the final image
+ Builds a new image by replacing the installed `vllm` package in the vendor base image.
-Compared with the base image, the key code changes kept in this repository are:
+## Build
-- Fix the CUDA platform detection logic in `vllm/platforms/__init__.py`
-- When NVML is unavailable and an error like `NVML Shared Library Not Found` occurs,
- no longer treat the platform as non-CUDA outright
-- Instead, fall back to `torch.cuda.is_available()` and
- `torch.cuda.device_count()` to keep checking whether CUDA is available
-- Adjust the CLI initialization logic so missing optional benchmark dependencies do not block
- `vllm serve ...` startup
-
-This fix addresses the following startup failure:
-
-```text
-RuntimeError: Failed to infer device type
-```
-
-## Build the image
-
-Run the following from the repository root:
+Run the following command from the repository root:
```bash
-docker build -t bi_150_vllm:0.16.1 .
+docker build --pull=false \
+ --build-arg BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1 \
+ -t bi_150_vllm:0.17.0 \
+ .
```
-## Run the image
+## Verify
```bash
-docker run -dit \
- --name iluvatar_test \
- -p 38047:8000 \
- --privileged \
- -v /lib/modules:/lib/modules \
- -v /dev:/dev \
- -v /usr/src:/usr/src \
- -v /mnt/gpfs/leaderboard/modelHubXC/Amu/t1-1.5B:/model \
- -e CUDA_VISIBLE_DEVICES=0 \
- --entrypoint vllm \
- bi_150_vllm:0.16.1 \
- serve /model \
- --port 8000 \
- --served-model-name llm \
- --max-model-len 2048 \
- --enforce-eager \
- --trust-remote-code \
- -tp 1
+docker run --rm -it bi_150_vllm:0.17.0 \
+ python3 -c "import vllm; print(vllm.__file__); print(vllm.__version__)"
```
+
+## Notes
+
+- This is an overlay-style repository, not the original upstream git source tree.
+- The Docker image keeps the vendor runtime stack and only replaces the Python package files.
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json b/vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json
deleted file mode 100644
index 3a7409e..0000000
--- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json
+++ /dev/null
@@ -1 +0,0 @@
-{"archive_info": {"hash": "sha256=f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1", "hashes": {"sha256": "f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1"}}, "url": "file:///workspace/vllm-0.16.1rc0%2Bcorex.4.4.0-py3-none-any.whl"}
\ No newline at end of file
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER b/vllm-0.17.0+corex.20260420090923.dist-info/INSTALLER
similarity index 100%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER
rename to vllm-0.17.0+corex.20260420090923.dist-info/INSTALLER
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA b/vllm-0.17.0+corex.20260420090923.dist-info/METADATA
similarity index 95%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA
rename to vllm-0.17.0+corex.20260420090923.dist-info/METADATA
index 9eab02e..15c009c 100644
--- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA
+++ b/vllm-0.17.0+corex.20260420090923.dist-info/METADATA
@@ -1,9 +1,9 @@
Metadata-Version: 2.4
Name: vllm
-Version: 0.16.1rc0+corex.4.4.0
+Version: 0.17.0+corex.20260420090923
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
Author: vLLM Team
-License-Expression: Apache-2.0
+License: Apache-2.0
Project-URL: Homepage, https://github.com/vllm-project/vllm
Project-URL: Documentation, https://docs.vllm.ai/en/latest/
Project-URL: Slack, https://slack.vllm.ai/
@@ -23,7 +23,7 @@ Requires-Dist: regex
Requires-Dist: cachetools
Requires-Dist: psutil
Requires-Dist: sentencepiece
-Requires-Dist: numpy==1.26.4
+Requires-Dist: numpy
Requires-Dist: requests>=2.26.0
Requires-Dist: tqdm
Requires-Dist: blake3
@@ -33,7 +33,7 @@ Requires-Dist: tokenizers>=0.21.1
Requires-Dist: protobuf!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*,>=5.29.6
Requires-Dist: fastapi[standard]>=0.115.0
Requires-Dist: aiohttp>=3.13.3
-Requires-Dist: openai>=1.99.1
+Requires-Dist: openai<2.25.0,>=1.99.1
Requires-Dist: pydantic>=2.12.0
Requires-Dist: prometheus_client>=0.18.0
Requires-Dist: pillow
@@ -52,6 +52,7 @@ Requires-Dist: pyzmq>=25.0.0
Requires-Dist: msgspec
Requires-Dist: gguf>=0.17.0
Requires-Dist: mistral_common[image]>=1.9.1
+Requires-Dist: opencv-python-headless>=4.13.0
Requires-Dist: pyyaml
Requires-Dist: six>=1.16.0; python_version > "3.11"
Requires-Dist: setuptools<81.0.0,>=77.0.3; python_version > "3.11"
@@ -76,6 +77,7 @@ Requires-Dist: opentelemetry-sdk>=1.27.0
Requires-Dist: opentelemetry-api>=1.27.0
Requires-Dist: opentelemetry-exporter-otlp>=1.27.0
Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1
+Requires-Dist: kaldi-native-fbank>=1.18.7
Requires-Dist: numba==0.61.2
Requires-Dist: ray[cgraph]>=2.48.0
Provides-Extra: bench
@@ -84,6 +86,7 @@ Requires-Dist: matplotlib; extra == "bench"
Requires-Dist: seaborn; extra == "bench"
Requires-Dist: datasets; extra == "bench"
Requires-Dist: scipy; extra == "bench"
+Requires-Dist: plotly; extra == "bench"
Provides-Extra: tensorizer
Requires-Dist: tensorizer==2.10.1; extra == "tensorizer"
Provides-Extra: fastsafetensors
@@ -97,6 +100,10 @@ Requires-Dist: soundfile; extra == "audio"
Requires-Dist: mistral_common[audio]; extra == "audio"
Provides-Extra: video
Provides-Extra: flashinfer
+Provides-Extra: petit-kernel
+Requires-Dist: petit-kernel; extra == "petit-kernel"
+Provides-Extra: helion
+Requires-Dist: helion; extra == "helion"
Provides-Extra: otel
Requires-Dist: opentelemetry-sdk>=1.26.0; extra == "otel"
Requires-Dist: opentelemetry-api>=1.26.0; extra == "otel"
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD b/vllm-0.17.0+corex.20260420090923.dist-info/RECORD
similarity index 87%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD
rename to vllm-0.17.0+corex.20260420090923.dist-info/RECORD
index d7a8649..4b5fab1 100644
--- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD
+++ b/vllm-0.17.0+corex.20260420090923.dist-info/RECORD
@@ -1,17 +1,16 @@
../../../bin/vllm,sha256=X5a5Mk4d820sk-Ykhco4CdCc8Bp3-tHJuyboP7UXkgw,172
-vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
-vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA,sha256=Moq2RINlNwYeWq4WfWK1brgpmC_yY_sdEEEWz8is5uQ,9243
-vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD,,
-vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
-vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json,sha256=-_DIKKgt63x-HOusl6PX9Hvuz52Kdn6rtr73dXXPyUQ,265
-vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt,sha256=N3YeH_2RPW4YWbdFCRGvTmJ1qSzLsxllfCjpUf5e4UA,276
-vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt,sha256=fAgb8Pt4zQoKTUA3ZnKEIgcjh0L97_dwEjYDTL5MEEo,5
-vllm/__init__.py,sha256=fS_Nsr5gWhdic4rOVRUVgy1PvdSFt-iIxjh_wb1-KHY,3881
+vllm-0.17.0+corex.20260420090923.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+vllm-0.17.0+corex.20260420090923.dist-info/METADATA,sha256=QtgKJZW0hL48Lkk1ylapaZSJL0FeD6duFwiioo_yu40,9512
+vllm-0.17.0+corex.20260420090923.dist-info/RECORD,,
+vllm-0.17.0+corex.20260420090923.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+vllm-0.17.0+corex.20260420090923.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json,sha256=dfda9jaQTAyViphUkyn_78MzNasJkr2bA6plpqmvM-k,309
+vllm-0.17.0+corex.20260420090923.dist-info/entry_points.txt,sha256=N3YeH_2RPW4YWbdFCRGvTmJ1qSzLsxllfCjpUf5e4UA,276
+vllm-0.17.0+corex.20260420090923.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+vllm-0.17.0+corex.20260420090923.dist-info/top_level.txt,sha256=fAgb8Pt4zQoKTUA3ZnKEIgcjh0L97_dwEjYDTL5MEEo,5
+vllm/__init__.py,sha256=_TcI5aE6vpjFZq_XmmsZh873DAVIuxnBIY_2p_nUM0Y,3661
vllm/__pycache__/__init__.cpython-312.pyc,,
vllm/__pycache__/_aiter_ops.cpython-312.pyc,,
-vllm/__pycache__/_bc_linter.cpython-312.pyc,,
vllm/__pycache__/_custom_ops.cpython-312.pyc,,
vllm/__pycache__/_oink_ops.cpython-312.pyc,,
vllm/__pycache__/_xpu_ops.cpython-312.pyc,,
@@ -34,9 +33,8 @@ vllm/__pycache__/scripts.cpython-312.pyc,,
vllm/__pycache__/sequence.cpython-312.pyc,,
vllm/__pycache__/tasks.cpython-312.pyc,,
vllm/__pycache__/version.cpython-312.pyc,,
-vllm/_aiter_ops.py,sha256=giwrq3xY5REN8-dp_FuSq9yYPZ0uHz7kUo_mSHpVlzY,56984
-vllm/_bc_linter.py,sha256=2F7ZE_cba80XNDfbrZKVBunazE7ZDxVMtpOUdEFdH3w,1105
-vllm/_custom_ops.py,sha256=B3_ssd1AKErsUByn6nzMfwZttRGooGspgHKK76cbuNE,127805
+vllm/_aiter_ops.py,sha256=3m_hqOTaJaCJKHGCvWmmroU7GyVlzogqwxVdz2rlKFQ,61895
+vllm/_custom_ops.py,sha256=B1njqipvzncK3HpOspkwYyQ9-CL1tubXDbePbDR8YUs,128364
vllm/_oink_ops.py,sha256=AcvOGv-Yr-4MU9AkUU3G9B2kTs2c_TAHZGKTloSCfgM,3057
vllm/_xpu_ops.py,sha256=Mdpuen6oVkGmy0At4rrzdBHBEicEfquh0tqrIkGR8kw,5287
vllm/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -55,10 +53,11 @@ vllm/benchmarks/__pycache__/__init__.cpython-312.pyc,,
vllm/benchmarks/__pycache__/datasets.cpython-312.pyc,,
vllm/benchmarks/__pycache__/latency.cpython-312.pyc,,
vllm/benchmarks/__pycache__/mm_processor.cpython-312.pyc,,
+vllm/benchmarks/__pycache__/plot.cpython-312.pyc,,
vllm/benchmarks/__pycache__/serve.cpython-312.pyc,,
vllm/benchmarks/__pycache__/startup.cpython-312.pyc,,
vllm/benchmarks/__pycache__/throughput.cpython-312.pyc,,
-vllm/benchmarks/datasets.py,sha256=2AwzVH9a_SW3F6_bYtQFDU04TzVZFE3NqDZH9uY_2x0,128835
+vllm/benchmarks/datasets.py,sha256=hbTjt0B-81rTZMzUVO_c-poGdZFEWC1aOjdh4_haQeI,129511
vllm/benchmarks/latency.py,sha256=S-LAXVvNOdfRJRVj4KtX0WVRIdyD6-l7BOnzCfkB_1I,5837
vllm/benchmarks/lib/__init__.py,sha256=BlbNrv7ga0LXYeHLKKiey2woiGh5MkW36g7QG6L6BN0,142
vllm/benchmarks/lib/__pycache__/__init__.cpython-312.pyc,,
@@ -69,7 +68,8 @@ vllm/benchmarks/lib/endpoint_request_func.py,sha256=v3Hcx5fBVsjgVDDxzg1T3A7xQezc
vllm/benchmarks/lib/ready_checker.py,sha256=Hth9XB4BZsleV_vs0GiKBxb7bc-2gk0BJ4dR14rHtMA,2607
vllm/benchmarks/lib/utils.py,sha256=QZ45aYXPSbqa_V-WY2tOjd1b3KsEo8X7aOM6qSXNAPk,4248
vllm/benchmarks/mm_processor.py,sha256=gqIa5hIoYcBQ7Yg50gda0c24TIhGySKhKEw5Y2AbQH8,17883
-vllm/benchmarks/serve.py,sha256=jthwgqL4yjHY4AKU6BBLfGlpy61OfyvppgiPbWoey_Y,67322
+vllm/benchmarks/plot.py,sha256=UZF66p7jp-N2H53nwKQGeoTzO8085LyjvMAZEXWRmOY,9975
+vllm/benchmarks/serve.py,sha256=8qPEljDYvDJ6IStp0sUWrNhCEkr6eNv3BgJ6MHB5HtI,72360
vllm/benchmarks/startup.py,sha256=L7yRKATwX4CUL1DqlugNCiysFEhLA3c7BVaa4Bp8ldU,11743
vllm/benchmarks/sweep/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/benchmarks/sweep/__pycache__/__init__.cpython-312.pyc,,
@@ -78,18 +78,18 @@ vllm/benchmarks/sweep/__pycache__/param_sweep.cpython-312.pyc,,
vllm/benchmarks/sweep/__pycache__/plot.cpython-312.pyc,,
vllm/benchmarks/sweep/__pycache__/plot_pareto.cpython-312.pyc,,
vllm/benchmarks/sweep/__pycache__/serve.cpython-312.pyc,,
-vllm/benchmarks/sweep/__pycache__/serve_sla.cpython-312.pyc,,
+vllm/benchmarks/sweep/__pycache__/serve_workload.cpython-312.pyc,,
vllm/benchmarks/sweep/__pycache__/server.cpython-312.pyc,,
vllm/benchmarks/sweep/__pycache__/startup.cpython-312.pyc,,
vllm/benchmarks/sweep/__pycache__/utils.cpython-312.pyc,,
-vllm/benchmarks/sweep/cli.py,sha256=_dDV5dIvxKflDgaXURwvYj25pvMvjxwstgMJag3J4fQ,1455
+vllm/benchmarks/sweep/cli.py,sha256=55RWr5Nrln867HPRz5LMv7gCdOqi_mInZ_jIHlkOFbw,1485
vllm/benchmarks/sweep/param_sweep.py,sha256=7fOlYnrpnZc6kMxfAW203AFHADSugT-wlR3ajaRuWls,5594
-vllm/benchmarks/sweep/plot.py,sha256=ghtvWMH8OuHENd5cc82kzAIkSn17_BhiN2MosBLYMS0,19800
-vllm/benchmarks/sweep/plot_pareto.py,sha256=pMFyla6LX7dafUWSlHvuxa9fikilD91MG4Qx-uD07hE,11015
-vllm/benchmarks/sweep/serve.py,sha256=FQXjWRoK7mM5xjznHQJK4AeH_atheOx9hGq9oR6NhuE,14407
-vllm/benchmarks/sweep/serve_sla.py,sha256=IJIpj7W6aDaoN28xM4AZ5WD15nwMvecTfgFLMzHIVK8,9390
+vllm/benchmarks/sweep/plot.py,sha256=Whg0sFICzqu-j2YxGv9eylawYcsilCPkW0SbNFdu9IA,19833
+vllm/benchmarks/sweep/plot_pareto.py,sha256=NbDyL63B5XwF93injxzGNFm9tX6n0IPETSHy5IAO7VQ,10992
+vllm/benchmarks/sweep/serve.py,sha256=WLMVR3F6KM3Y7Nx2vy65BbuQm41bI2iCLwgnIQyIsPw,15634
+vllm/benchmarks/sweep/serve_workload.py,sha256=rw339djPijtbdLI5lTq4Hl_VYOrlfQiuyjRmCGFjZj0,10218
vllm/benchmarks/sweep/server.py,sha256=RBUnoU85t_a4yMTFbZhdz77IhJLhQbg5KkmQsBsiTSE,4619
-vllm/benchmarks/sweep/startup.py,sha256=rranpLrrfBMvwSj_r0jBLhQxFsmydyiHTrotezRuQFk,12302
+vllm/benchmarks/sweep/startup.py,sha256=gCHF4JNCz80wApQJvNvjKSdVPSKNKmgh_eJJpfYZeBo,13371
vllm/benchmarks/sweep/utils.py,sha256=YzhiqE4jmEN7TzUnzQXnlm8TLERkNjgoQ_SF6jJ8hB4,232
vllm/benchmarks/throughput.py,sha256=Kp14mTYV8cf1vcabXar15tWlCfRG5KdQ3HM3vJocwPE,34962
vllm/collect_env.py,sha256=lih4_OSz6OlOE6eWoAYhMCR5aTUs80ZziBLGEWlpQgo,27835
@@ -106,12 +106,12 @@ vllm/compilation/__pycache__/monitor.cpython-312.pyc,,
vllm/compilation/__pycache__/partition_rules.cpython-312.pyc,,
vllm/compilation/__pycache__/piecewise_backend.cpython-312.pyc,,
vllm/compilation/__pycache__/wrapper.cpython-312.pyc,,
-vllm/compilation/backends.py,sha256=sVkt0Ydf8ruosaV-KWAJXKoUUTkSifDXwIv4DimJTiI,43776
+vllm/compilation/backends.py,sha256=970-FYlOkVesieL04LjD-03xBMRTJzXuFMwn3aWLfwQ,41386
vllm/compilation/base_static_graph.py,sha256=XP6OBtuEqkLNQc3ZPr6TrwCzokLIARBCBXUw4pIrUjo,2007
-vllm/compilation/caching.py,sha256=Is0Ej-9Wu3tWXuLixlsgaXcw9FJIl1f238AWPiU9ggM,19889
-vllm/compilation/compiler_interface.py,sha256=El4c_sUAvKaqcK1pIMXxyhctjemvwU0sZnNT2OFhmoE,25450
-vllm/compilation/counter.py,sha256=WiKJRLbgqBchplITJ7FrpZ54f6A87829K8VJa1un6hA,1735
-vllm/compilation/cuda_graph.py,sha256=yhpcOz7IoRffPfBgVleLF7cKSbImH6McViTxcWGPhnk,13462
+vllm/compilation/caching.py,sha256=r4YuAboiTyvirfsKKREmVIQnmlaKNqsnZ_fS1LBvx4Q,20301
+vllm/compilation/compiler_interface.py,sha256=5nJ-WxRMyK_6BMWlRxrCrAJXJfTbi86C56yrtxM3Xo4,27254
+vllm/compilation/counter.py,sha256=fx9MGc70eex-sqBlL-IKdX-pmYBRmxgJKNOYi4zURTI,1854
+vllm/compilation/cuda_graph.py,sha256=srsYfNoiv6X_2gUiLzI9Wq6q3MeZqyJ_GLjW3gOBNS4,13345
vllm/compilation/decorators.py,sha256=xOb8eZ5JhusNy8Jl0lL0L3oHHtv4arwQFDkQ0OGHM7Y,26849
vllm/compilation/monitor.py,sha256=VkCQdUi9cLAG2DZnUj2bNfkOkRmSC-ooFB3JM9KrFmM,2202
vllm/compilation/partition_rules.py,sha256=hX0vsG3IceodNxfNAfvU5SBZhBANmLe3Kikz1v8c-k0,2296
@@ -136,13 +136,13 @@ vllm/compilation/passes/fusion/__pycache__/sequence_parallelism.cpython-312.pyc,
vllm/compilation/passes/fusion/act_quant_fusion.py,sha256=HcSahkJO6nOhDJG0JYuKJ64qmEqC1jO4hnNgG95yl6o,6929
vllm/compilation/passes/fusion/allreduce_rms_fusion.py,sha256=EQIhoUTSenqef9YBWi1sp1-4iEOqeKVgSdIKWAGeWeQ,32060
vllm/compilation/passes/fusion/attn_quant_fusion.py,sha256=LXR4NrhQ7H9IFO_gEFNFaai9rGvAEenWpj5983uBwys,13228
-vllm/compilation/passes/fusion/collective_fusion.py,sha256=0dpP9QNj5DIx44jboZUMH96C7mH2fQBaE3VXWk1lNnI,15000
+vllm/compilation/passes/fusion/collective_fusion.py,sha256=IUOqTzy7gaH8_GML13sOjns-NjbVxP0s768dUjyr-a8,15000
vllm/compilation/passes/fusion/matcher_utils.py,sha256=9sT6rbkiFb9otMrqEhy1BHmbs-FOPOE2TfpGW5p05eQ,15834
vllm/compilation/passes/fusion/qk_norm_rope_fusion.py,sha256=R6hOWdYz5nTAgeAPokPpVcDrcLNrFykHNUKYK72NSz8,8906
vllm/compilation/passes/fusion/rms_quant_fusion.py,sha256=unEcPpWRu3tsjDIxKjTFkPsd5LkGToFIHA_HQ9LtXds,22500
-vllm/compilation/passes/fusion/rocm_aiter_fusion.py,sha256=v1JTk9hCukpRlwF0j8DiQeP1-FGejRHe2w_HdTsxXmg,17267
+vllm/compilation/passes/fusion/rocm_aiter_fusion.py,sha256=SvOQog2ufjG7yrQ6IcKnnYkusWrFQ_dYtOT34rjZvpI,17091
vllm/compilation/passes/fusion/rope_kvcache_fusion.py,sha256=ZOxtqsNg-H8390X0G4Ouy-vV1F-G7jgGe7xPfYqYt3Q,8267
-vllm/compilation/passes/fusion/sequence_parallelism.py,sha256=4qGn7LENXs-9PxmRbI5fgmJFu2gDE1YRzGrZk7Etkgs,17753
+vllm/compilation/passes/fusion/sequence_parallelism.py,sha256=fy58E1FEgTq0LzjwkKeRdtdMPMoRrNn44jK-hyg7CIY,17666
vllm/compilation/passes/fx_utils.py,sha256=BvALST35BtfoeBhH0_H39Wrn7RaagLv7Dizm_kB3ujM,2681
vllm/compilation/passes/inductor_pass.py,sha256=IdgVtsX6yzMQiJ7ji00Fltu7qWkPaHQNuMyoUIhBahs,4003
vllm/compilation/passes/pass_manager.py,sha256=Xf51XSGWM806CT6fnTBUVuv2DCRHv1xBPXhguYsXfBg,7003
@@ -153,14 +153,14 @@ vllm/compilation/passes/utility/__pycache__/noop_elimination.cpython-312.pyc,,
vllm/compilation/passes/utility/__pycache__/post_cleanup.cpython-312.pyc,,
vllm/compilation/passes/utility/__pycache__/scatter_split_replace.cpython-312.pyc,,
vllm/compilation/passes/utility/__pycache__/split_coalescing.cpython-312.pyc,,
-vllm/compilation/passes/utility/fix_functionalization.py,sha256=7lRG0kQhM_XY07dztxOA_oH7HpJrcfWPcl6EidIYX-A,12785
+vllm/compilation/passes/utility/fix_functionalization.py,sha256=gzVk0IyuYqhLixrRyyhIAmHOAi_Pl-mhOW2M-6K6iHI,13021
vllm/compilation/passes/utility/noop_elimination.py,sha256=5IfkneFTFC_vK37tgGSnrAwvWh4DAJ9rgMeqgDH8A-k,5287
vllm/compilation/passes/utility/post_cleanup.py,sha256=HDixT8IJFly8-Mp2OhesTlKLl_WbVz5dUt8szn11jeM,743
vllm/compilation/passes/utility/scatter_split_replace.py,sha256=E4QCzDCXdooy6eW40aBUthduSUvdyGhOI8otxf3D_do,6168
vllm/compilation/passes/utility/split_coalescing.py,sha256=CwKs2eGtQ9ez4Si5FOlpuVgHc_KBgHiEbfvxY6i9CHE,2274
vllm/compilation/passes/vllm_inductor_pass.py,sha256=JcsztOiPeFIIdSeiKesbcsQBfbqnOt5mNqf9BKCXDpM,6783
-vllm/compilation/piecewise_backend.py,sha256=coB0urtV6sdcTvySBnZeWLPYU29qI08PFfRR6KLqJaI,13986
-vllm/compilation/wrapper.py,sha256=Ntbqm4hGTiIouXGUjfQPxHVQd3gxvaIue55r3yRoJ98,13050
+vllm/compilation/piecewise_backend.py,sha256=n49Pm2d0pmsmX8acy7IjqNJmTouY7vzS1_7qrJedjN8,13956
+vllm/compilation/wrapper.py,sha256=rTZkzKRt-k5xE77Cs-y54EqllTPeYuC0Lmo5OUT1jsk,15182
vllm/config/__init__.py,sha256=Johf-4RN5M25Yag9OQrJytrGprPizqln8CyoK0eEyjg,3810
vllm/config/__pycache__/__init__.cpython-312.pyc,,
vllm/config/__pycache__/attention.cpython-312.pyc,,
@@ -188,9 +188,9 @@ vllm/config/__pycache__/structured_outputs.cpython-312.pyc,,
vllm/config/__pycache__/utils.cpython-312.pyc,,
vllm/config/__pycache__/vllm.cpython-312.pyc,,
vllm/config/__pycache__/weight_transfer.cpython-312.pyc,,
-vllm/config/attention.py,sha256=9tS0q7u4dpz51FUceowhQnBj5Tolxe788U3m-mxGV60,2518
+vllm/config/attention.py,sha256=F93F9gXDE9GuB-lcqQYE4QgLASkgJqLr9vowGjkATNQ,2525
vllm/config/cache.py,sha256=eWUIia30-vZZxYWDI_hJQvxX7lDL0JzFHLnz3Pja48c,11993
-vllm/config/compilation.py,sha256=KePNi2O6PVBGiAoGMfVPWpuu0LWlHvr5SdLg1lIhnRo,50769
+vllm/config/compilation.py,sha256=1Mc_PGxVpIkwj84JECS_oZtqZToDKRAJwyqyqc5ZJ80,51104
vllm/config/device.py,sha256=NlIM9KMfuETmlvp_JHhxMrvKDsmoO3_-wkrMwV08M2M,2719
vllm/config/ec_transfer.py,sha256=6Cc3nmfPrfcFSkDHOmO_1QHoV7FhzUi8QtZna-G4-pw,3841
vllm/config/kernel.py,sha256=bCrlswXTw1wRJd9AbtSPyqiQnKzwBPHOKJ5H1yM9zA8,2694
@@ -198,31 +198,32 @@ vllm/config/kv_events.py,sha256=ZB7tW09-WdmXAC0uXnSSgepOzQLn3yqVQam2_10MRG4,1591
vllm/config/kv_transfer.py,sha256=-CnBqhNYnEEIhLTtRRBTNKpPe-gHcga7qEzcfAd3h0k,4254
vllm/config/load.py,sha256=qlVIok6eKmoRjEe63KXJheU_W59xn8Dh7tI7_qyR8Ck,5697
vllm/config/lora.py,sha256=igUl4RE5B38Q46JONF0HZNaWpiCsFKdKM_Oj7kron6k,4473
-vllm/config/model.py,sha256=AMS9fGA722XhhOg84AX8AKujBs07U5mEiScV2eh7I7U,83285
+vllm/config/model.py,sha256=uoGS6kmKizZJcC8Vx_6f18SJHy-jQOliwLVLzySI9w0,83647
vllm/config/model_arch.py,sha256=Z10j3I3r851inz9mivcErhNMtINjAMMGpaMpb318PKQ,1658
vllm/config/multimodal.py,sha256=2zBq1kXtUUoZ6kQ_c6TxtUHwnxxxJBQ7HzmoeFcaDpw,10917
vllm/config/observability.py,sha256=TN02f-Clehu5Av-FfSrGXPVHIn2MIobJfaTc5vo9Fps,6402
vllm/config/offload.py,sha256=by2z2aJzBYhHPVFVgRaQxiQbFnoh3LYvu9nrtjht5Cw,6522
-vllm/config/parallel.py,sha256=QH_NWb2FrERrfN5tbI5oE8KVFka_nlPpDYWkAdyBuAU,29859
+vllm/config/parallel.py,sha256=b0rHxXQF9lwzbURqjJjkO-NtcUofHslOA08jhxaOZFk,35453
vllm/config/pooler.py,sha256=v3OFT5ep3tjWwsXQIJDngXybuqFDaMSm2LrMCl3sdrI,5382
vllm/config/profiler.py,sha256=3tFoS0peSYRgpkZUIfDUrEMU0yJjvothMcFRyLKEDxU,4799
vllm/config/scheduler.py,sha256=dEUbBOBjvjVN8XlrYQu17zx4PBYxEF9QyFZpr8z86ws,12671
-vllm/config/speculative.py,sha256=luo8LX7UCcCzs1MHBAcJ_aw28rqzS8_Fz5GSJIddU6I,35030
+vllm/config/speculative.py,sha256=j8qwVoiZ4E9L1-vxcrnFMr7-tnHL_ei_kt4FAA7QQNk,42136
vllm/config/speech_to_text.py,sha256=mXSpXPUVxjxGrEz2xCLo1gllx3SYybIiblxphCXGcSE,1560
vllm/config/structured_outputs.py,sha256=wCVC9wuK7vaYo0ZbQuv45s-y6djocq75Qof_9Po5p9I,3296
vllm/config/utils.py,sha256=0swsZ-Ddp_mKhemVKjG3Kh9OcaaosnyJv2NBtodk0Go,15108
-vllm/config/vllm.py,sha256=uFE5EZ91DXFNq_eYhnR0Ea3Za1BwQkDeSrnEE8XSFsE,76962
-vllm/config/weight_transfer.py,sha256=BRBCpStnmCHbCw_5v8c6qYO3Kr-W61Cf9ElIa1Q_wOI,363
+vllm/config/vllm.py,sha256=IetpfDsD5qBAyzWS_GguYi0K_D7EmCLpMT65RsMwqlw,79354
+vllm/config/weight_transfer.py,sha256=iZGfB2dMA5vdL4OhWoMeWyuBoGJxyUu5JTxrIxB9ny8,370
vllm/connections.py,sha256=66GGcz2MO5RtwBp4NinqL355v5lJq8EjB9cMAB7eZtI,5352
vllm/device_allocator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/device_allocator/__pycache__/__init__.cpython-312.pyc,,
vllm/device_allocator/__pycache__/cumem.cpython-312.pyc,,
-vllm/device_allocator/cumem.py,sha256=2itL8dIqAIY8tkeNvFh6jHRqN6aZuYKXhJ4d1THDvxU,11525
+vllm/device_allocator/cumem.py,sha256=4zAbmNjZxkOFhK__yKLqLUuCM4lMGRqxDDwleZ1hhRI,11632
vllm/distributed/__init__.py,sha256=l_KMMLMmYq-MqxY5OeTuQjA38RYwdS4TtVy8JafQq3E,191
vllm/distributed/__pycache__/__init__.cpython-312.pyc,,
vllm/distributed/__pycache__/communication_op.cpython-312.pyc,,
vllm/distributed/__pycache__/kv_events.cpython-312.pyc,,
vllm/distributed/__pycache__/parallel_state.cpython-312.pyc,,
+vllm/distributed/__pycache__/stateless_coordinator.cpython-312.pyc,,
vllm/distributed/__pycache__/utils.cpython-312.pyc,,
vllm/distributed/communication_op.py,sha256=PD-Kxg2zjs6JHTncHi9vGUetQ4aFiT0uV3fQcjIL0bg,1323
vllm/distributed/device_communicators/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -245,16 +246,16 @@ vllm/distributed/device_communicators/__pycache__/shm_broadcast.cpython-312.pyc,
vllm/distributed/device_communicators/__pycache__/shm_object_storage.cpython-312.pyc,,
vllm/distributed/device_communicators/__pycache__/symm_mem.cpython-312.pyc,,
vllm/distributed/device_communicators/__pycache__/xpu_communicator.cpython-312.pyc,,
-vllm/distributed/device_communicators/all2all.py,sha256=CXZN9HaxCO5MI9njTUKD0Xe2s0emqR5g6-cp9yUrbNQ,23610
-vllm/distributed/device_communicators/all_reduce_utils.py,sha256=FXN49WKCqJM3WKxMlQntN1xVwrlss5IjrGRXjfIaoYY,12324
-vllm/distributed/device_communicators/base_device_communicator.py,sha256=2R29s4AZIQuagNYinQtWTOTkkCiLO3P4PbpjY55UQA4,13373
-vllm/distributed/device_communicators/cpu_communicator.py,sha256=ULLbY3DLSpqsLJg3znYfRTM_ooI-uRcD2WQOQRMRAuM,10115
-vllm/distributed/device_communicators/cuda_communicator.py,sha256=BXe2fjCkTenM1tAs30Bzl0G2dB6SESIMxW9vkFVweFI,17214
+vllm/distributed/device_communicators/all2all.py,sha256=b-lzmlCCIZYn8jhL_yr6yRr1vkArUq4bI9qgRqIBAAE,21093
+vllm/distributed/device_communicators/all_reduce_utils.py,sha256=tKM6UaaviabxhpA-t4EGclm23ZEYxLHfxob6X38nktA,13965
+vllm/distributed/device_communicators/base_device_communicator.py,sha256=tU-YjSUFT2FkgoYR7TzMLhU6hxP-1hE-EsTTIPSL-m8,14582
+vllm/distributed/device_communicators/cpu_communicator.py,sha256=p6YxyWRfgKGJdA38G5UIK0tI79bCVvtPYoOBQW7q4qA,11137
+vllm/distributed/device_communicators/cuda_communicator.py,sha256=Iy2FjDDm4QXj8QErGZVDU2Ew4iXGwlu4HwutmCSnjOg,16797
vllm/distributed/device_communicators/cuda_wrapper.py,sha256=N1znzgx_UgCm5XLXVnKdtmtMcguhmFE0B6lxCPoVcGY,7483
vllm/distributed/device_communicators/custom_all_reduce.py,sha256=qwIMODFiNFmXGFWjr9wyzXZ6dl6WASbya1wEqAEmXP0,12857
vllm/distributed/device_communicators/flashinfer_all_reduce.py,sha256=x1gkMrLjnnNgzk50L6r7JaTBjWdpRIwzOMpxKw_B5PM,7846
vllm/distributed/device_communicators/mnnvl_compat.py,sha256=zgi7ly54YOyLax2TgmlHn5R_iQ6JsdQwxqixY3Hfcqw,1207
-vllm/distributed/device_communicators/pynccl.py,sha256=IZRiSwMlWK-yS5cwGBPr7kn4Ov4XSVhA61aVP5ggpxs,13902
+vllm/distributed/device_communicators/pynccl.py,sha256=6rk6TKGVUMqAXOK9Bj_70ug9s9iple6arbDlKe2WUus,14954
vllm/distributed/device_communicators/pynccl_allocator.py,sha256=D4IyD8MX76loodZhzbsHSmtT0LLoyhcZbIMFN_BZVMU,6260
vllm/distributed/device_communicators/pynccl_wrapper.py,sha256=Md9XRm8EPyhzErzSp8xyN8r86psxZLG7sa-w3Uvw5HM,19098
vllm/distributed/device_communicators/quick_all_reduce.py,sha256=BU5G_-FaMmper_qOYtsA8ViItWNE0u8XUlwoORUX1cw,10903
@@ -275,14 +276,22 @@ vllm/distributed/ec_transfer/ec_connector/base.py,sha256=Nx_zH02nSLHgSicRDx0E4x9
vllm/distributed/ec_transfer/ec_connector/example_connector.py,sha256=tCi_MwxdnEgJuzJ21sJzJOTqlrcxTYLDxixiopzx0hU,7215
vllm/distributed/ec_transfer/ec_connector/factory.py,sha256=LRo-mSc6_mm5a9WjcHTxzSO7Lds7HrLHYLdkCpWfmKs,3070
vllm/distributed/ec_transfer/ec_transfer_state.py,sha256=60Iu-N4qVnjSa2kXaysANPa77HX-cYomPTu7GRk5Evw,1152
+vllm/distributed/elastic_ep/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+vllm/distributed/elastic_ep/__pycache__/__init__.cpython-312.pyc,,
+vllm/distributed/elastic_ep/__pycache__/elastic_execute.cpython-312.pyc,,
+vllm/distributed/elastic_ep/__pycache__/elastic_state.cpython-312.pyc,,
+vllm/distributed/elastic_ep/__pycache__/standby_state.cpython-312.pyc,,
+vllm/distributed/elastic_ep/elastic_execute.py,sha256=IfwBwevEgROygmxlOsHh0lyY1SKMRQdk4-jHPlVVp70,22059
+vllm/distributed/elastic_ep/elastic_state.py,sha256=xDT71-jpKGvC9-iKBt0HZ9osSvZIhY450yq-Fv2Km6M,21767
+vllm/distributed/elastic_ep/standby_state.py,sha256=t7A9rZOGCtUXxwFXebXwd9v1nGkcuxR0Hb1cb4vqPK0,3495
vllm/distributed/eplb/__init__.py,sha256=G5wu4iq7WIjDVMWqAfCUTyU-XiThNdAZd4OYu-SVpig,154
vllm/distributed/eplb/__pycache__/__init__.cpython-312.pyc,,
vllm/distributed/eplb/__pycache__/async_worker.cpython-312.pyc,,
vllm/distributed/eplb/__pycache__/eplb_state.cpython-312.pyc,,
vllm/distributed/eplb/__pycache__/eplb_utils.cpython-312.pyc,,
vllm/distributed/eplb/__pycache__/rebalance_execute.cpython-312.pyc,,
-vllm/distributed/eplb/async_worker.py,sha256=GZ1i90SmMYFqxaPmuKK2QH-MTVGbzUC3ICWrARxIOLo,7816
-vllm/distributed/eplb/eplb_state.py,sha256=tmOe5JfvpHvgjOf_EBXP2cqp7ufhi-bzSpdaMSoRZVc,48615
+vllm/distributed/eplb/async_worker.py,sha256=Wz5JRddZ_tI-l-y4m7IxU5LBBy-ADJigTVwmGXFSAYQ,7618
+vllm/distributed/eplb/eplb_state.py,sha256=FXlKFof025ZjKIatiM7aPp6YVDOosydOxeYwGcz4lFA,43591
vllm/distributed/eplb/eplb_utils.py,sha256=h8WJ1iqulRVQbxtoxb9xg7MZ5Ax0vPQkvjOCoVoV3K8,2281
vllm/distributed/eplb/policy/__init__.py,sha256=2usmbIhEKDKxq4yMTMBkiRs4uem6AK6jJPkmyH3G034,542
vllm/distributed/eplb/policy/__pycache__/__init__.cpython-312.pyc,,
@@ -290,8 +299,8 @@ vllm/distributed/eplb/policy/__pycache__/abstract.cpython-312.pyc,,
vllm/distributed/eplb/policy/__pycache__/default.cpython-312.pyc,,
vllm/distributed/eplb/policy/abstract.py,sha256=xkphHsTmcIZSkrw5_CbMhxDiZsDMSEYqVGL9QW4jHcE,1624
vllm/distributed/eplb/policy/default.py,sha256=TOe_1Vy5dxCau9rVn3uH9RQP7UXJJ2KjA8Ya0m9tDIs,16382
-vllm/distributed/eplb/rebalance_execute.py,sha256=ufQLtQp-RLGGORVVB8scSg2bZpYbk7SUNorE6_S3xQc,28133
-vllm/distributed/kv_events.py,sha256=_sUdT-GYOZPmzgcc7Qhkqm0NaDdJv4ZLxj0t_hOLqRA,16513
+vllm/distributed/eplb/rebalance_execute.py,sha256=zLj2KcKTjod--3GRQPaU79tkCgKX1oFCRfiRSMTIC0o,29492
+vllm/distributed/kv_events.py,sha256=f22G2ao2-PyXuA5PeHbV_mxgCTJO6rey6jqXnplHYjw,16658
vllm/distributed/kv_transfer/README.md,sha256=cKIw6vXYSxBlf0wWwO7haP82CX2oB2QzW0-RxZE5mT4,2007
vllm/distributed/kv_transfer/__init__.py,sha256=wb5OWOxpI7rmB6NVrSKJHHQPBWpp_3NYseEOo_0mgOo,552
vllm/distributed/kv_transfer/__pycache__/__init__.cpython-312.pyc,,
@@ -303,23 +312,25 @@ vllm/distributed/kv_transfer/kv_connector/__pycache__/base.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/__pycache__/factory.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/__pycache__/utils.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/base.py,sha256=KuKixI9XMfNTMWanVED-kedkMyMFtdcT34QO26lweJ0,370
-vllm/distributed/kv_transfer/kv_connector/factory.py,sha256=C2YTjVvEC7r_P7x3wGIGuXR1in0nLWkQYg7noMDKJFk,7245
-vllm/distributed/kv_transfer/kv_connector/utils.py,sha256=FnRC_x34z_AMTN_ZzmvTXpG0xu3uOQBdRKZR3VVLioU,19179
+vllm/distributed/kv_transfer/kv_connector/factory.py,sha256=DhEKiTtBrvkL8MYZwUfJ1lfPHK7exTdDOltajTzis94,7443
+vllm/distributed/kv_transfer/kv_connector/utils.py,sha256=4gLVJsbKl8pIxMCk1bV3l4Socd9ex2H9vRUwB93k2_8,20481
vllm/distributed/kv_transfer/kv_connector/v1/__init__.py,sha256=W-tsytEv3P4VgZgtpvGDIcrKUOc6azJZ6_PI3Is__Mo,508
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/__init__.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/base.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/decode_bench_connector.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/example_connector.cpython-312.pyc,,
+vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/example_hidden_states_connector.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/lmcache_connector.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/lmcache_mp_connector.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/metrics.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/multi_connector.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/nixl_connector.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/offloading_connector.cpython-312.pyc,,
-vllm/distributed/kv_transfer/kv_connector/v1/base.py,sha256=bfDdRRgMWeKL8HsE6_Gi8WWmdS5XmCyWWsarXG9C4Xk,21225
+vllm/distributed/kv_transfer/kv_connector/v1/base.py,sha256=QuWuf-D4UysNmPAcigABKrBzlXcsrNxiY5Xi8HJnQh4,22102
vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py,sha256=bzlQ8_pgzmMviZaOnHsdNhlWH6TUALRsajNoBHsbV4A,15531
-vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py,sha256=8O6lyNeinE1BHJwAx8kjBpBLk106_CGcJ9jHLv2bp_I,16765
-vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py,sha256=zDBxiFNCjg5kX1XKYXmMfRze0iLidEJaIXIEEFBE1iw,12005
+vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py,sha256=5sFs135kFbU-WfJT796s3rg63XXqGL7mAHtZ5kUAnXc,17539
+vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py,sha256=UoXHU0LKUg06MbiZfDtqo0rYhYH0WxDlfHRlZoPfbW4,12484
+vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py,sha256=rp-i3iAaXfkyniz4kA6_PkvuH0uo1mWLaceFuuhrM50,12429
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py,sha256=k3tOaUj8P0IjtWcZ0CGxO3iE8IuALduzLpW5ejYJAFE,426
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__pycache__/__init__.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__pycache__/multi_process_adapter.cpython-312.pyc,,
@@ -334,7 +345,7 @@ vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py,sha256=47DEQpj
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__pycache__/__init__.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__pycache__/mooncake_connector.cpython-312.pyc,,
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__pycache__/mooncake_utils.cpython-312.pyc,,
-vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py,sha256=3wBhN0pKAV-6ae59q7xoi3vwKSJU-SZ0fRFkP8ALPeA,51087
+vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py,sha256=wTGfVM2k2tCu8uqirfPLxBXryUx7bom6CesZJz7FL8o,52127
vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py,sha256=wSg8lscXSMpUESTGpDr8xFqiLgAkVsPJTXtkVIzurTA,4170
vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/distributed/kv_transfer/kv_connector/v1/moriio/__pycache__/__init__.cpython-312.pyc,,
@@ -344,8 +355,8 @@ vllm/distributed/kv_transfer/kv_connector/v1/moriio/__pycache__/moriio_engine.cp
vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py,sha256=hAA-iozOOJOf-IdwuynQDB3imwtwbKGI_DLUGuI5l6Q,8898
vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py,sha256=uyL5lnmtIK1LQKTu4y0Br-IFLx1PDrSBBbQiGRuXfVI,59062
vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py,sha256=5vX3LF6AOwyJFCV3jwEBNIDi5Ai2cEC964ScgBYrgwM,21306
-vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py,sha256=AyIOBx0NEzJfiitEeMsn7e9y4d25VeBvnTTC02mZPpo,20165
-vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py,sha256=x2mCraQgUZgWM7scX9dLQZ9rZ4KQTvEjfyHefi2Qq2o,116288
+vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py,sha256=EnyCeJ3Yy_EIuVjfIidjpQfo6tnFf__WjmU0n-_Qekw,20856
+vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py,sha256=9SuIjfh79psb536pCgOHEkQH87ePtFiDB8QRs4iVd5Y,121294
vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py,sha256=GbNAYMx4Fw6nqLqy9PJTcaKPfNyVZgEy_kcGosIThV0,31212
vllm/distributed/kv_transfer/kv_connector/v1/p2p/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/distributed/kv_transfer/kv_connector/v1/p2p/__pycache__/__init__.cpython-312.pyc,,
@@ -356,17 +367,20 @@ vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py,sha256=nY
vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py,sha256=gdHYw55Qm1z8cz5cWoq76siI8tL6tewzYldcxNCG2s4,23390
vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py,sha256=A2mMTFKJ1zBYEo0gPhe166-ms1S_gAaL4Zs9tULMbTc,9225
vllm/distributed/kv_transfer/kv_transfer_state.py,sha256=js4h2XWbnmB8IoshDJBaGK9yc7bX10uB2D7PY3A3TEs,2277
-vllm/distributed/parallel_state.py,sha256=Xv9s57_NweT5g3Vg6X3XL4wlwTSNbTvLAf1vT1uMed4,71800
-vllm/distributed/utils.py,sha256=BLdN-Vn16TYG-wY4MTEMXcGIc6MIKbMSk_60a4pQFSM,21419
+vllm/distributed/parallel_state.py,sha256=CUE7LArKpC31tWDjrPVsWQn06JCX_59wluqrdWe3Wc8,77664
+vllm/distributed/stateless_coordinator.py,sha256=opc5FRC4IU4TZ1CsFni87nO8uvq_kdv2aXQ2fE51eC0,11625
+vllm/distributed/utils.py,sha256=0aLYrVvnGLdokC9mCRzL8kahYKtHlJfpciPUgol-_dk,23717
vllm/distributed/weight_transfer/__init__.py,sha256=BNylvNnKXwhGl-mBSwRKuAdi0sGXpZzaQRk4DHn1XZI,333
vllm/distributed/weight_transfer/__pycache__/__init__.cpython-312.pyc,,
vllm/distributed/weight_transfer/__pycache__/base.cpython-312.pyc,,
vllm/distributed/weight_transfer/__pycache__/factory.cpython-312.pyc,,
+vllm/distributed/weight_transfer/__pycache__/ipc_engine.cpython-312.pyc,,
vllm/distributed/weight_transfer/__pycache__/nccl_engine.cpython-312.pyc,,
vllm/distributed/weight_transfer/__pycache__/packed_tensor.cpython-312.pyc,,
-vllm/distributed/weight_transfer/base.py,sha256=WsY9kVujjSIvFiDtyNwx8D_LF8WCHA-t38NQKp1iieM,5084
-vllm/distributed/weight_transfer/factory.py,sha256=a6-Jk4YmVWSACbxqY9KmSmpx8yXzPg1_afVzhy15FUI,3972
-vllm/distributed/weight_transfer/nccl_engine.py,sha256=1z12EBYiXFamrRkQi_Q1f-gShrww20nN_8vOssEtRqQ,12459
+vllm/distributed/weight_transfer/base.py,sha256=9DmAtuLTxWJY2aRdHbZWvY-59P_1Iu1Ek6kG_WOIj-E,6245
+vllm/distributed/weight_transfer/factory.py,sha256=OCyp9reYTzMMJHLeqHisf9Z_6L98BejEJq79zhddGqo,4113
+vllm/distributed/weight_transfer/ipc_engine.py,sha256=0oOLGMNy_CgNiZGMNGRcyFDH68Pya-My4nixHHpnhNU,11019
+vllm/distributed/weight_transfer/nccl_engine.py,sha256=AJQIVm5-IjJ7YC8Amv0rFsLsyQN41Cml8qCdhtt6nts,13089
vllm/distributed/weight_transfer/packed_tensor.py,sha256=Sl28ATzJUAjQx30o_QEZFyqyQb7R_U2-jpOGCBkkX_I,8989
vllm/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/engine/__pycache__/__init__.cpython-312.pyc,,
@@ -374,7 +388,7 @@ vllm/engine/__pycache__/arg_utils.cpython-312.pyc,,
vllm/engine/__pycache__/async_llm_engine.cpython-312.pyc,,
vllm/engine/__pycache__/llm_engine.cpython-312.pyc,,
vllm/engine/__pycache__/protocol.cpython-312.pyc,,
-vllm/engine/arg_utils.py,sha256=5KYdKSnnoZOUZEOPv6nX9NxqfM0d8SCQVORYc-Qpkhc,94731
+vllm/engine/arg_utils.py,sha256=hG_lreAPllHnzuSHCSxstp6wLxmxWQaoxoo2g2eKF58,94861
vllm/engine/async_llm_engine.py,sha256=P3UcYRNf9T5GO7j-Nj00LVs2OYKbzl0ivavGILQS21s,284
vllm/engine/llm_engine.py,sha256=Ormr0eFokwGJd2DBWJzJGraI47sc3TPbBBprWEvyU0I,296
vllm/engine/protocol.py,sha256=R8vhuOZRFsOgyH2bXz-TYGr-kLCDYcg25PeSo51-194,6963
@@ -394,11 +408,11 @@ vllm/entrypoints/anthropic/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/anthropic/__pycache__/api_router.cpython-312.pyc,,
vllm/entrypoints/anthropic/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/anthropic/__pycache__/serving.cpython-312.pyc,,
-vllm/entrypoints/anthropic/api_router.py,sha256=VvFf7mLGh4EXm0mK537cPkxax6fUDh3vXeK_lWtZ-64,3079
-vllm/entrypoints/anthropic/protocol.py,sha256=nrxqThJiw762ZKd0m-rHe0hWnmdcm-Hs_D8hhBZaU14,4582
-vllm/entrypoints/anthropic/serving.py,sha256=MC1AyTerhhuMsoolVF_TRf-fyyTsru6KB4xxWVdAYc4,21312
+vllm/entrypoints/anthropic/api_router.py,sha256=-_NhIEW9OnukH5lE6copkgOS44Ob7NPAxbAM3mqjsnE,4591
+vllm/entrypoints/anthropic/protocol.py,sha256=IQOP0NQuIi2sdzbg_8hCXjjbVd5r--ByJTXJ-rIaCgw,5620
+vllm/entrypoints/anthropic/serving.py,sha256=HOzHRtIrVk5WoEtvRjxunaAVT_rmTAnZRqLCEq4XthI,33710
vllm/entrypoints/api_server.py,sha256=JUM2xq5yDLIVY6lDfhEtY002kQBo_YomSfmrrRzu7cE,5902
-vllm/entrypoints/chat_utils.py,sha256=XpNCh1EjCSTTd5TdFBFsazf4CqEMHKaPrzkKJTSuT38,56673
+vllm/entrypoints/chat_utils.py,sha256=f0X_hOTfyJb6Iav1LDI8RihXk7X3eO8ySBBRm8FVsGg,56904
vllm/entrypoints/cli/__init__.py,sha256=-dU_jpAUU9ksRsxbw8Leuhy-wPcRTp7YQp_q5-tQM1Y,828
vllm/entrypoints/cli/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/cli/__pycache__/collect_env.cpython-312.pyc,,
@@ -429,13 +443,13 @@ vllm/entrypoints/cli/collect_env.py,sha256=H9qpqasc2FRmAQUvmRV5WuCxJqdj-ouYAQ2CM
vllm/entrypoints/cli/main.py,sha256=vHRDDzGjKK2ZE9EWyGn--jpfBWewDtiYFDab9srDc78,2468
vllm/entrypoints/cli/openai.py,sha256=CWzTEwp_nLWLXEPpvKlw7rZjiM-CSEnMaxic9s8IpIc,8134
vllm/entrypoints/cli/run_batch.py,sha256=8ChYNc3UbbCVEbtp_tTbgR32_jVq3WRJgO1RzGhP8yI,2242
-vllm/entrypoints/cli/serve.py,sha256=bNd6Mq7gyXw2hP89wKZSgk-KII_k6BmLJPW_oICtf4g,11424
+vllm/entrypoints/cli/serve.py,sha256=ZLRCfZidbjWzU9HXGAdaHGVg4cZQJt6FkqV8FWfnsOw,11829
vllm/entrypoints/cli/types.py,sha256=On70PbJbmYXWTQZaL3EbI7I-fd1eIE21FAhKuHnTcfo,812
vllm/entrypoints/constants.py,sha256=7rJGI5ZxBEJr2_7YjIJHzVaRX_9PChU7aOSYp95v6tE,356
-vllm/entrypoints/grpc_server.py,sha256=NrqiGJQr3C_3o41sYZ4AkkZ0Y3y136kQUTQ3G7PUTJs,17665
+vllm/entrypoints/grpc_server.py,sha256=HmPLZdHZuuXzY7Mg9J20l8D55ExVd74jL52-jMSuZE0,17994
vllm/entrypoints/launcher.py,sha256=iLd368aVcRHtb2i35cn99MZZVxPYVDpmbSKTvVRuNHM,6446
-vllm/entrypoints/llm.py,sha256=jJS8T-bhdeSTX0YS6E0qf3bXPMvxajaDRMUcoUeDJdI,85063
-vllm/entrypoints/logger.py,sha256=GeSbm6Y6FndcA5HU_PYmlFC-6xFe2oGmegmNUgcHMKA,2659
+vllm/entrypoints/llm.py,sha256=mmxbulxBDgC-Won2sSt7zbmHSIcgn2i_aZkMpbn8Kbg,85440
+vllm/entrypoints/logger.py,sha256=GZqaMQEyqkOPRCVpSdVe2NPD6HVV2SmKo0OLaSGy9-A,3297
vllm/entrypoints/mcp/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107
vllm/entrypoints/mcp/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/mcp/__pycache__/tool.cpython-312.pyc,,
@@ -458,24 +472,24 @@ vllm/entrypoints/openai/chat_completion/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/openai/chat_completion/__pycache__/serving.cpython-312.pyc,,
vllm/entrypoints/openai/chat_completion/__pycache__/stream_harmony.cpython-312.pyc,,
vllm/entrypoints/openai/chat_completion/api_router.py,sha256=XZsKmAyfurBNM5QzjCJk5Pv107cTH--Z8vi0NY3z4ds,3691
-vllm/entrypoints/openai/chat_completion/protocol.py,sha256=vt7RuTozB0X9j7jl1lG7VMSM0w63Q-Cple5Kgu99FF8,28983
-vllm/entrypoints/openai/chat_completion/serving.py,sha256=A-1WyPsqH2HGiPWysrlVKfvRc9H3qmIVlKP4InpJzLc,88403
+vllm/entrypoints/openai/chat_completion/protocol.py,sha256=CQw5DFLdgsBXDSp8MJSrvu4WjgOOi4tFADvb-67WTzQ,30517
+vllm/entrypoints/openai/chat_completion/serving.py,sha256=vot1wNgmHoKB7x_uRXI6-qX2Po9WZDqZEcvSNlm9xkE,89084
vllm/entrypoints/openai/chat_completion/stream_harmony.py,sha256=de1cKmdVEfJfkYVKPLpH_R2VEBC-Zefh4x8zfHkviWI,6158
-vllm/entrypoints/openai/cli_args.py,sha256=1BidcF3nh8AAdo0MoMDxCaiyzfQRFwkDbdNk0X83_iU,15969
+vllm/entrypoints/openai/cli_args.py,sha256=CMlFwCLd9kRWd7q4Llz5KLQocGeCwJNHSw5jWbchUvU,16302
vllm/entrypoints/openai/completion/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107
vllm/entrypoints/openai/completion/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/openai/completion/__pycache__/api_router.cpython-312.pyc,,
vllm/entrypoints/openai/completion/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/openai/completion/__pycache__/serving.cpython-312.pyc,,
vllm/entrypoints/openai/completion/api_router.py,sha256=QDNpQA5gC1YjLRsTFZZEg7eT9EDyEqD9Cou4CcP2pKo,3641
-vllm/entrypoints/openai/completion/protocol.py,sha256=oRPTwHWqWzZ4OabxCGH-C7sKRFhLepRTHGGUttlDgCE,17864
+vllm/entrypoints/openai/completion/protocol.py,sha256=zMu9piJ9s0UxBJCtQRIXQ5NZBofDhBAPwF5wPMNChsI,19367
vllm/entrypoints/openai/completion/serving.py,sha256=i4uMF351n5j7wsNYfTsBwm8zXFLTTJ1XyaHQgR6NjBM,27370
vllm/entrypoints/openai/engine/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107
vllm/entrypoints/openai/engine/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/openai/engine/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/openai/engine/__pycache__/serving.cpython-312.pyc,,
vllm/entrypoints/openai/engine/protocol.py,sha256=WpE0x_RUno9Gvs4kjiGgQB30bmiGV29Z0zKdgG9uNx8,9841
-vllm/entrypoints/openai/engine/serving.py,sha256=RTIOkuMZSC0p_EhvX6Eu5K3GqDfHA6SdqbNK36J_3xI,47254
+vllm/entrypoints/openai/engine/serving.py,sha256=GCcYWRN1ta_FTPw61Sxa4tQwQYDNXJAcr1ebtYnrI0c,46702
vllm/entrypoints/openai/generate/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/entrypoints/openai/generate/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/openai/generate/__pycache__/api_router.cpython-312.pyc,,
@@ -517,8 +531,8 @@ vllm/entrypoints/openai/responses/__pycache__/utils.cpython-312.pyc,,
vllm/entrypoints/openai/responses/api_router.py,sha256=Y8p4WT9POGG7nQstsYE1-T6x3U13T9O971eN7S71f3U,4748
vllm/entrypoints/openai/responses/context.py,sha256=l7JXOgngdBqd84LuJWUTlUT0qusm34rwbHnHJBDp_GM,35185
vllm/entrypoints/openai/responses/harmony.py,sha256=iqXkCyKNeIwxYs0l745c3tsmOnA0uj5Qs3RQCtzfA20,20068
-vllm/entrypoints/openai/responses/protocol.py,sha256=yzPLhSjZyY-8vBixnltO7FsJv1w4V_K-77GknPvt2EI,23685
-vllm/entrypoints/openai/responses/serving.py,sha256=VtogA39uqpHG9OQvFAHjysHPKLQSl31wH3iVbIA2SYU,70206
+vllm/entrypoints/openai/responses/protocol.py,sha256=AuXLRUo_OpWW9wF8Cr59C1__Wiy4SqPhXvf8_LO3mO4,23945
+vllm/entrypoints/openai/responses/serving.py,sha256=y7odWu6baXWx25MNLLQvaYij8sQZYIviDZgGhanX7F4,72058
vllm/entrypoints/openai/responses/streaming_events.py,sha256=ynVNwoBd4x38f0n9SNTCQyymzuFn_V3tnP6GhnJCvuQ,27220
vllm/entrypoints/openai/responses/utils.py,sha256=NSkKEUcUMmCrgVYMKzvshGoQPEd-7QbNXpt_8XcN_Ow,9128
vllm/entrypoints/openai/run_batch.py,sha256=VaHRg-2Ja_3u7wzGtq_EvvqEFDiJDGTlKKR0fp95gMM,29303
@@ -532,48 +546,47 @@ vllm/entrypoints/openai/speech_to_text/__pycache__/speech_to_text.cpython-312.py
vllm/entrypoints/openai/speech_to_text/api_router.py,sha256=-CFrCyG9z4FVf6fv4aQxJAmNPebdELGTAQ0I3pLShlg,5104
vllm/entrypoints/openai/speech_to_text/protocol.py,sha256=ZAnM3Jk4vCGQyTydyUXsaF0IbHOnU9fXm_HH9nzLKY0,17205
vllm/entrypoints/openai/speech_to_text/serving.py,sha256=mZl6FEbpzS-JyvKsov09odZeIXEY9c8Br82T0dTmoK0,6150
-vllm/entrypoints/openai/speech_to_text/speech_to_text.py,sha256=ByVpW4SKmYbwjgq-vqSjbcMrxUqAr6Es3AYGiuo8P5U,31459
-vllm/entrypoints/openai/translations/__init__.py,sha256=3L2Iwb-SbRw9JjhRlP-GUF9aaGJ0Z_52o6smXxbve0c,410
-vllm/entrypoints/openai/translations/__pycache__/__init__.cpython-312.pyc,,
-vllm/entrypoints/openai/translations/__pycache__/api_router.cpython-312.pyc,,
-vllm/entrypoints/openai/translations/__pycache__/protocol.cpython-312.pyc,,
-vllm/entrypoints/openai/translations/__pycache__/serving.cpython-312.pyc,,
-vllm/entrypoints/openai/translations/__pycache__/speech_to_text.cpython-312.pyc,,
-vllm/entrypoints/openai/translations/api_router.py,sha256=bz-XDlh_lQFLoPvrsa74dg61EHo5I6yG39T19IotaSc,508
-vllm/entrypoints/openai/translations/protocol.py,sha256=PUVFoOn9f1kpagYj9ogWtF-mRA5kNtWD2sqwE35z9Us,502
-vllm/entrypoints/openai/translations/serving.py,sha256=OMTUCsZQOY95vQmnaZ_CVykK1R5VzFFKmLZTimzhV7E,499
-vllm/entrypoints/openai/translations/speech_to_text.py,sha256=KP2KldbK8EafCwgtaQ5pHknF4-p-Oo00haK0FdpuHE0,527
+vllm/entrypoints/openai/speech_to_text/speech_to_text.py,sha256=dmonxmIRKckLuVTShHU1ja-G9YPxZIzJJLiIBpbCA7I,32279
vllm/entrypoints/openai/utils.py,sha256=4ws1wnAPy-ozf3Dw1-D2h1S07xWkjRC1AFpsUECWf0o,1624
-vllm/entrypoints/pooling/__init__.py,sha256=I4DWP-s2rOTaXSbjdfC3H_d9I0ZJVFo5jyz38vIPTX4,4325
+vllm/entrypoints/pooling/__init__.py,sha256=0aDd_nORG54TwBZQcHV2LniatQ1XuxI5AHzl1BGUSN4,4414
vllm/entrypoints/pooling/__pycache__/__init__.cpython-312.pyc,,
+vllm/entrypoints/pooling/__pycache__/io_processor_factories.cpython-312.pyc,,
+vllm/entrypoints/pooling/__pycache__/typing.cpython-312.pyc,,
vllm/entrypoints/pooling/__pycache__/utils.cpython-312.pyc,,
vllm/entrypoints/pooling/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/entrypoints/pooling/base/__pycache__/__init__.cpython-312.pyc,,
+vllm/entrypoints/pooling/base/__pycache__/io_processor.cpython-312.pyc,,
vllm/entrypoints/pooling/base/__pycache__/protocol.cpython-312.pyc,,
-vllm/entrypoints/pooling/base/protocol.py,sha256=peoLP5PbN28KHiUHe8ybS4ZsHor5XqphHxsF89pUDKg,7490
+vllm/entrypoints/pooling/base/__pycache__/serving.cpython-312.pyc,,
+vllm/entrypoints/pooling/base/io_processor.py,sha256=_VtgNro9ycFsbDCZIzs0ot9YA3qPpWlfJPp4PWb87x8,6374
+vllm/entrypoints/pooling/base/protocol.py,sha256=QA7obcw7N_XsNcxOfdWjBaKNMCOMVcD_VGLX1ifqtsk,7354
+vllm/entrypoints/pooling/base/serving.py,sha256=mTIKxd-d-yPWMs9ZYxsXvB09x1_xnzv4HEndQfEO7f0,13018
vllm/entrypoints/pooling/classify/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/entrypoints/pooling/classify/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/pooling/classify/__pycache__/api_router.cpython-312.pyc,,
+vllm/entrypoints/pooling/classify/__pycache__/io_processor.cpython-312.pyc,,
vllm/entrypoints/pooling/classify/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/pooling/classify/__pycache__/serving.cpython-312.pyc,,
-vllm/entrypoints/pooling/classify/api_router.py,sha256=oVlc9VP1R1cNy_WIw1mQofOmrHdjKyCHm9-8ybGu85w,1704
-vllm/entrypoints/pooling/classify/protocol.py,sha256=u_JCHyPp7DN97DRCUgPrt8GTNUGAtzLfhEoLqfvHnGY,2822
-vllm/entrypoints/pooling/classify/serving.py,sha256=R-B9-uiTDUeJYgM3zDAdHxHnt661A4kIaJwRUUmK4kI,5578
+vllm/entrypoints/pooling/classify/api_router.py,sha256=eId06qUMKV80fSWvmpYVUo1tf3IZi4t_s-wnkeAVC1w,1272
+vllm/entrypoints/pooling/classify/io_processor.py,sha256=Wck30F81yUmDAtM1mcL8O_8eS2Ucjk5uhTTnA846tN8,1997
+vllm/entrypoints/pooling/classify/protocol.py,sha256=o6oI_xXXTEsbuhChHQaWrn_lGdeh-2vtxK4vna54d40,2694
+vllm/entrypoints/pooling/classify/serving.py,sha256=nJ3o3swwS00bWhB04ddfUFRSvf-aCRpsqMJPEJVWh7k,2552
vllm/entrypoints/pooling/embed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/entrypoints/pooling/embed/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/pooling/embed/__pycache__/api_router.cpython-312.pyc,,
vllm/entrypoints/pooling/embed/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/pooling/embed/__pycache__/serving.cpython-312.pyc,,
vllm/entrypoints/pooling/embed/api_router.py,sha256=eO-thkJZ0_ACfBE6I26dSeblrxu-U4Bmbsw446dV3-A,2628
-vllm/entrypoints/pooling/embed/protocol.py,sha256=dX35bWvbj_ls2kK6hcB26y9B1pAK47yDB6WkixmBjRY,4454
+vllm/entrypoints/pooling/embed/protocol.py,sha256=uz5gJ7a05ssZLPwulaWhrxFKr5vDZhT79UqOWSMEqP0,3724
vllm/entrypoints/pooling/embed/serving.py,sha256=-byoFBEG4zpM5Qz6I9n0wq2l8twDI_vbVqjdQbeanL8,25045
+vllm/entrypoints/pooling/io_processor_factories.py,sha256=A0jAz430HHdgvDPtjCZ3jXrZAvZI8bOSPudP6kZzIEQ,1026
vllm/entrypoints/pooling/pooling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/entrypoints/pooling/pooling/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/pooling/pooling/__pycache__/api_router.cpython-312.pyc,,
vllm/entrypoints/pooling/pooling/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/pooling/pooling/__pycache__/serving.cpython-312.pyc,,
vllm/entrypoints/pooling/pooling/api_router.py,sha256=BT3Wf9C3Wf-0U-3Fr9WBIxE8ij2qGrLXNMSLI4hqR8s,2163
-vllm/entrypoints/pooling/pooling/protocol.py,sha256=Tov1Nx3PpL6xOliqP7sXendfrEh6mMwbqLZjHCSfr1g,4893
+vllm/entrypoints/pooling/pooling/protocol.py,sha256=lkF3Mu3lzzFUeURmO2wrsmqY5M1tIfCdeJ9oSKgyIjI,4163
vllm/entrypoints/pooling/pooling/serving.py,sha256=8Q7wD_UEyU_6Uf_mFEdKoGfCAUldL0izZvuB8X7a78E,12756
vllm/entrypoints/pooling/score/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/entrypoints/pooling/score/__pycache__/__init__.cpython-312.pyc,,
@@ -582,14 +595,15 @@ vllm/entrypoints/pooling/score/__pycache__/protocol.cpython-312.pyc,,
vllm/entrypoints/pooling/score/__pycache__/serving.cpython-312.pyc,,
vllm/entrypoints/pooling/score/__pycache__/utils.cpython-312.pyc,,
vllm/entrypoints/pooling/score/api_router.py,sha256=EgEEBPugF80xLXgMZwpLZcPGXpQ3-ZmYM3uYsQDzvPg,4690
-vllm/entrypoints/pooling/score/protocol.py,sha256=Vr4YLFtegqo6tFT6Dk_czMh6B_jmE2-ftiXoxpMrG7M,4025
-vllm/entrypoints/pooling/score/serving.py,sha256=jq2pbuajHZGrTll_by52AtCL9A0xrvIZxBROHvSOlF4,22113
-vllm/entrypoints/pooling/score/utils.py,sha256=L2y1JuMywqwnUpKs3PCuVZP_YmxIvKYIbWmHvxdsf3o,13844
+vllm/entrypoints/pooling/score/protocol.py,sha256=2nVYKBKp00HHfKm-g4CKMvq2cnK32wiEllHt_JxmE88,3897
+vllm/entrypoints/pooling/score/serving.py,sha256=IZHHmlA4LeaI757tidNif1WZkxsVQvewWGQyESjkdnE,22253
+vllm/entrypoints/pooling/score/utils.py,sha256=0F8ijuXas6wBey4ZXFfH8L8POaO8zB2jKQgUPyPSpqc,17063
+vllm/entrypoints/pooling/typing.py,sha256=p184UjkQARdEqT_iiJ2AfJVu6fH4QmN4_AHNVmupzwA,1302
vllm/entrypoints/pooling/utils.py,sha256=hutLstZWBpVpLPGur21YBatWDw1qKm8x-ye4clPjwrw,3150
vllm/entrypoints/sagemaker/__init__.py,sha256=Fh71s9Oigag26xVp42xT7JCAr_dvDRGnjoRhEV18cwk,155
vllm/entrypoints/sagemaker/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/sagemaker/__pycache__/api_router.cpython-312.pyc,,
-vllm/entrypoints/sagemaker/api_router.py,sha256=-A-nAq0SwH0yUX50EeeT06EjltvBLbqVprJV6JZ_TD8,6225
+vllm/entrypoints/sagemaker/api_router.py,sha256=Z3jd_n66sTm-9foJS0k5NfjBwN7IbXNE0QqmYXs4O7E,6307
vllm/entrypoints/serve/__init__.py,sha256=3NTr65tVBGmL_xe-ypSqffMKBMrCaHlVItF2QCiSmiU,1439
vllm/entrypoints/serve/__pycache__/__init__.cpython-312.pyc,,
vllm/entrypoints/serve/cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -655,11 +669,11 @@ vllm/entrypoints/serve/tokenize/api_router.py,sha256=HH4-2P4hruiUbusT6Rq0F8j2zIE
vllm/entrypoints/serve/tokenize/protocol.py,sha256=FQ_wXfO7rupRl94JQkqNznQXhQoNIIPneQhft10ZrNA,6071
vllm/entrypoints/serve/tokenize/serving.py,sha256=61jRHCiABN8U0G-p9IIxeO5RVHdwXMWFQK5zypvFDnU,7052
vllm/entrypoints/ssl.py,sha256=P6ApVGAQFLHERnJeoLNpZk-HQI1nNqWdlF8UAQpnC48,2712
-vllm/entrypoints/utils.py,sha256=IiE53NLG1Hn7byyvqx5WH8toPrCaeCP1yELV-f5DIK8,11253
-vllm/env_override.py,sha256=NkjKlikC3W5yjLfT9-wSl79RuFvtRDskUBahkmgK-Bk,18087
-vllm/envs.py,sha256=vNPBXbjpmKZY1EWKrU64VnhFnnp6HO6-lq0_mPNAAbA,83716
+vllm/entrypoints/utils.py,sha256=grCOSCM66SZzcom-_lSZIdjco22dVzRsbOG1j7WA77s,13270
+vllm/env_override.py,sha256=4HQLIXC8V6H6-pzHEVk8Jw0Y1BdWQmyAqQS6qzzxH94,19912
+vllm/envs.py,sha256=mmV1CyoMXLPSHEQ2DQqElndXnM-M5Wqvk5UhItcAEAI,87511
vllm/exceptions.py,sha256=AtRlqcG-WMiNYd_W1uJsbHdRoJusuBTWHAfpUru1NkQ,1067
-vllm/forward_context.py,sha256=poboqbv4FfR4uN5K84MI_JZDY92FoPjw5Hw-qUUAFic,16286
+vllm/forward_context.py,sha256=fyE3Z3oH1FBse9k145wQkUPFSGdewxk9rHkvCXUxkm8,16248
vllm/grpc/__init__.py,sha256=leqIiw5daoeDRF-SfkvbfjNXiJVmEp4p0vZcsNyYOME,504
vllm/grpc/__pycache__/__init__.cpython-312.pyc,,
vllm/grpc/__pycache__/compile_protos.cpython-312.pyc,,
@@ -686,7 +700,7 @@ vllm/kernels/helion/ops/__init__.py,sha256=vdhTpQvMYs9RDhS4APjiSh0S8dmeUMLBv1aV1
vllm/kernels/helion/ops/__pycache__/__init__.cpython-312.pyc,,
vllm/kernels/helion/ops/__pycache__/silu_mul_fp8.cpython-312.pyc,,
vllm/kernels/helion/ops/silu_mul_fp8.py,sha256=UMGebTNqxQ-HtMTxi0kN1ytwrbw4Gxw7Ai83d-DW1Mw,4836
-vllm/kernels/helion/register.py,sha256=5pZvkAeAVfaijr1vATcoXDMHiVe4wXRvUPQ5fsIkiw0,16371
+vllm/kernels/helion/register.py,sha256=whg7hgdYp9XNDnBEPCziZyVz38kKOG1Q7trfVR_Yxuw,19890
vllm/kernels/helion/utils.py,sha256=clw61SkYlr-fNwdm9DwhUowPgooDhEoycqFAassb8I8,3101
vllm/logger.py,sha256=stQuWFxaJ_CKcNOvN9EQIaOA__ylmSIOAUIVMLAp7CY,11133
vllm/logging_utils/__init__.py,sha256=91YL2ScSZ4kK19HWo6DCGXwhFH5K6jqDrweZWk3Ixg8,538
@@ -727,8 +741,8 @@ vllm/lora/layers/__pycache__/vocal_parallel_embedding.cpython-312.pyc,,
vllm/lora/layers/base.py,sha256=djqVQ30rFpuG71V8nIj4QARsbEDzjfeosBk9jTt98Ts,1825
vllm/lora/layers/base_linear.py,sha256=ixQzJ4ogWgyITKjvUFCb7vZU1Rul-9496vhVeLrbKkE,5909
vllm/lora/layers/column_parallel_linear.py,sha256=Rp3sSTG4AlG_N2FA7QwmEmva06qxRzg0mLFg8dz_7CM,22947
-vllm/lora/layers/fused_moe.py,sha256=PcbT2LH_IeASAYwkbIricUeBJXW9rBDHsbQavWRwJWM,29585
-vllm/lora/layers/logits_processor.py,sha256=W_ZzEyMsydDqrynWqabzIRMmjfZQVnrNuNnmc5DHkyo,6625
+vllm/lora/layers/fused_moe.py,sha256=g2N2Wi2zhRUPoyHJIyLT2US7wJbKnw0gEY7sjEIL5e8,29746
+vllm/lora/layers/logits_processor.py,sha256=6uUYo43HEL7-sfqpD7UddbjnGWI1NoCDRZp8BuXsCUg,6544
vllm/lora/layers/replicated_linear.py,sha256=oswJbpelrd1Smiz1AV9zlLRcZb62XW700uporDliAJE,2205
vllm/lora/layers/row_parallel_linear.py,sha256=y6MpIbqGU3eSv50nlzuaSDRZdyUfwUJp66E6Fwct8GM,6187
vllm/lora/layers/utils.py,sha256=PGhIDjvImTX6hte5t-PtuwBV9EGPB1yZeTKCEq_uRC4,3329
@@ -753,12 +767,12 @@ vllm/lora/ops/triton_ops/__pycache__/lora_kernel_metadata.cpython-312.pyc,,
vllm/lora/ops/triton_ops/__pycache__/lora_shrink_op.cpython-312.pyc,,
vllm/lora/ops/triton_ops/__pycache__/utils.cpython-312.pyc,,
vllm/lora/ops/triton_ops/fused_moe_lora_fp8_op.py,sha256=ZPbTYFG2YvplXu5Y5BT_0Q0CmVnk31irzeAgaRnL2vM,32151
-vllm/lora/ops/triton_ops/fused_moe_lora_op.py,sha256=jFWC86-yuhiUUAsYZom_8vpQLBQHodh2qgtHMxfMY38,24369
+vllm/lora/ops/triton_ops/fused_moe_lora_op.py,sha256=x4ONU29wqBfxD8tTdLLoVmn5_ifG0f1ZJG1FjpwCxEM,29829
vllm/lora/ops/triton_ops/kernel_utils.py,sha256=YiE1F7sojPFFF06QsVSeaDN19Fpo4X1MhFSoX4F8EXA,10375
-vllm/lora/ops/triton_ops/lora_expand_op.py,sha256=lqhtVPpv8kGz5bm2FY_ueiM3IBdcA1t0TeLUOXZGGd4,9477
-vllm/lora/ops/triton_ops/lora_kernel_metadata.py,sha256=oTXKTYFBLEYZR60kkro7g1AxiNX2Dre2FxxAnQzN0u0,6889
-vllm/lora/ops/triton_ops/lora_shrink_op.py,sha256=AEXWzv_099JVS2f4hJEnorwKQNx76pqtDYyx1nwXu-o,8932
-vllm/lora/ops/triton_ops/utils.py,sha256=kUfcxG1ALtR0_9M4LulLS7ah-PAmK0jBmg5OjTmR53Y,10511
+vllm/lora/ops/triton_ops/lora_expand_op.py,sha256=V2_Jnj-Lq5NMzm7opVTTuipWQ86hgTpFpzSA9WroUa4,9530
+vllm/lora/ops/triton_ops/lora_kernel_metadata.py,sha256=dmnu60keid9-9P8FCnnTk0znpjsUb-kyOfCB0XbGiIE,7695
+vllm/lora/ops/triton_ops/lora_shrink_op.py,sha256=AVrVt7uAprP94HpZEfi4hZn8-jG2Z6-fHViaINilh74,9204
+vllm/lora/ops/triton_ops/utils.py,sha256=OqcUyicFRFj2XrTWpOZAuZOD33XP6qiobVgUGe6M54c,10723
vllm/lora/ops/xpu_ops/__init__.py,sha256=7h0eSNAE9MWqAPvjKCCZOaNQSsXkgCxP-rBcJrEfqrI,258
vllm/lora/ops/xpu_ops/__pycache__/__init__.cpython-312.pyc,,
vllm/lora/ops/xpu_ops/__pycache__/lora_ops.cpython-312.pyc,,
@@ -774,7 +788,7 @@ vllm/lora/punica_wrapper/__pycache__/punica_xpu.cpython-312.pyc,,
vllm/lora/punica_wrapper/__pycache__/utils.cpython-312.pyc,,
vllm/lora/punica_wrapper/punica_base.py,sha256=gh0ZvV_6bdhnBhk5IbOoP74HB-RbQ94sShDDc2tBbO4,15864
vllm/lora/punica_wrapper/punica_cpu.py,sha256=qJLetoy4aPyQazVqbEq0pOOr848ox_hXvO3AOIH4XR4,10817
-vllm/lora/punica_wrapper/punica_gpu.py,sha256=uklNQbXbs6XPUgQthvSLjufEOvx-w02UE9fjhakUgXk,14605
+vllm/lora/punica_wrapper/punica_gpu.py,sha256=M9RfCIZ2hkB7VZqS3-6ZRXrkka4UrOrhXFiJTZZmLws,15232
vllm/lora/punica_wrapper/punica_selector.py,sha256=-IyHqLvTPs5onKUXacOqS9J03otlfaT8J7yj4iD_YRs,818
vllm/lora/punica_wrapper/punica_xpu.py,sha256=YjsIUsRV_ciLquTZO7GpC-Gwjr7EvoaebJ0QnwF67ys,13633
vllm/lora/punica_wrapper/utils.py,sha256=omm684H29vAOXrC4ny19kmKFRk0_-pIAMTR87CmRMXw,5517
@@ -790,7 +804,7 @@ vllm/model_executor/__pycache__/utils.cpython-312.pyc,,
vllm/model_executor/custom_op.py,sha256=gNfYzyjwHnT0GbpaWmeuD-P6PrbeBTfBUG_O3OoijCM,13728
vllm/model_executor/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/model_executor/kernels/__pycache__/__init__.cpython-312.pyc,,
-vllm/model_executor/kernels/linear/__init__.py,sha256=_TTL-2c1LaCjdwzX8aZj6sf3Zg5bKj_IlAenw-NhCM4,12957
+vllm/model_executor/kernels/linear/__init__.py,sha256=y3NUXd2h96j2EfJ1aDuqxwvWFui5RrFi8hKrcS8rQf0,12957
vllm/model_executor/kernels/linear/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/kernels/linear/mixed_precision/MPLinearKernel.py,sha256=dvSBv9nYGrOiSEN3udSv2ElQYpf7vQorIjA61b32D2g,2825
vllm/model_executor/kernels/linear/mixed_precision/__init__.py,sha256=sQyWnhpVrXm5f2SrQwmtap4F1dvUdKCAWAETyhwLgdw,1453
@@ -811,8 +825,8 @@ vllm/model_executor/kernels/linear/mixed_precision/cpu.py,sha256=OdM7Q6LUYwfT4ny
vllm/model_executor/kernels/linear/mixed_precision/cutlass.py,sha256=LY3XjsyGlY1y1cGflBYkC40MqsJ9-uuvF6DIi2-aaI8,4631
vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py,sha256=1NT07Ioc9wQN5CoQZ4CKE0yoFOKq8aAGngLIrcjg-dw,6115
vllm/model_executor/kernels/linear/mixed_precision/exllama.py,sha256=9vy2P7jHm_bDNCDpiNkCH8wEg20_pwN4HLAhpHg-Org,6443
-vllm/model_executor/kernels/linear/mixed_precision/machete.py,sha256=gdBdme4JUaZ_YcrKTzJ1JfdoeX_oG9e0SSQa204AvTI,5726
-vllm/model_executor/kernels/linear/mixed_precision/marlin.py,sha256=K9HDpxkAEuopvXCsExA8Gd7o3bwumkte9TzVTRkNUnc,7456
+vllm/model_executor/kernels/linear/mixed_precision/machete.py,sha256=YSXF_YG2KNEaz58sGvTKMCFyqRJUY1r0pko1H97rGes,6019
+vllm/model_executor/kernels/linear/mixed_precision/marlin.py,sha256=8CnwYzgqZF9ilcHbxWHO_f-hpcp6We3_v-qvbNU1TOU,13793
vllm/model_executor/kernels/linear/mixed_precision/xpu.py,sha256=sudCl-MAKmFx_Tx_B3YxjhxPT78KABj65lj_7YSc1zo,3012
vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py,sha256=hhjrsj8bfhSkrSU0pMGcMq31KzC5JbhFOhgIWhdNRHE,5577
vllm/model_executor/kernels/linear/scaled_mm/__init__.py,sha256=XAedDbqkte2vfh0Z0yw4_20yczzSt3Xud8lnFRpk6Tw,1839
@@ -828,7 +842,7 @@ vllm/model_executor/kernels/linear/scaled_mm/__pycache__/triton.cpython-312.pyc,
vllm/model_executor/kernels/linear/scaled_mm/__pycache__/xpu.cpython-312.pyc,,
vllm/model_executor/kernels/linear/scaled_mm/aiter.py,sha256=l2XlFMcOWwv8MWzC_O-U8bAgYRz-PnIJEhtpURfnQYU,4253
vllm/model_executor/kernels/linear/scaled_mm/cpu.py,sha256=TK1c_-zADd-McY4tmtlKRAZniNt7RXpfM5RH-b965nw,7980
-vllm/model_executor/kernels/linear/scaled_mm/cutlass.py,sha256=EbTxE3ti7Kysh3LWBu95O_0cHW1clPbOaV1zJ5bE3K0,6589
+vllm/model_executor/kernels/linear/scaled_mm/cutlass.py,sha256=ddc33PnnGAfOJ8Vb8qpd-WwPpVnF5vdwap2SN0FejKw,8027
vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py,sha256=dcs5UdeM-sWoDqkCuDue9O63kQ3RkKeROoXA_-UVGNQ,1783
vllm/model_executor/kernels/linear/scaled_mm/pytorch.py,sha256=4H6pqQN66UbgTMnkwa6ifZzxUb2Lfe5h4ukps5HxYBk,7730
vllm/model_executor/kernels/linear/scaled_mm/rocm.py,sha256=erNElT7iQxk1_4gMLvJqZzZoCkGSGcPrVtdx-3z74AA,3286
@@ -850,31 +864,33 @@ vllm/model_executor/layers/__pycache__/resampler.cpython-312.pyc,,
vllm/model_executor/layers/__pycache__/sparse_attn_indexer.cpython-312.pyc,,
vllm/model_executor/layers/__pycache__/utils.cpython-312.pyc,,
vllm/model_executor/layers/__pycache__/vocab_parallel_embedding.cpython-312.pyc,,
-vllm/model_executor/layers/activation.py,sha256=zNAl1vvd-aYcZOULsK0g6e5oD6iCfqSN_pBmgE_WTLU,24690
+vllm/model_executor/layers/activation.py,sha256=5KIB6Quq92koiYif6svQXN99fPPMRg4CvF9hX_KoaEM,24813
vllm/model_executor/layers/attention/__init__.py,sha256=WVCsMgceTb4i77vkNi4z42zaM5DhAZLttnO-gAcmZnY,912
vllm/model_executor/layers/attention/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/attention.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/chunked_local_attention.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/cross_attention.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/encoder_only_attention.cpython-312.pyc,,
+vllm/model_executor/layers/attention/__pycache__/extra_cache.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/kv_transfer_utils.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/mla_attention.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/mm_encoder_attention.cpython-312.pyc,,
vllm/model_executor/layers/attention/__pycache__/static_sink_attention.cpython-312.pyc,,
-vllm/model_executor/layers/attention/attention.py,sha256=DnmYpSuIH1UjammjzOGXQhWFseHvowkCA8nDd4b5_Ww,28445
+vllm/model_executor/layers/attention/attention.py,sha256=s3wmcNTMpfmbw6bqZw7YcTPYGSKdOQIetY96P4DbCzI,30292
vllm/model_executor/layers/attention/chunked_local_attention.py,sha256=Lra8sc8XRfHewPca_suNdn9wYKmivWNnIRdG_pseXe4,4619
vllm/model_executor/layers/attention/cross_attention.py,sha256=jEq6GJk4AgKm0Va52HTI82Oy1uX2s1gY68J7XaBqXd0,8117
vllm/model_executor/layers/attention/encoder_only_attention.py,sha256=7Pa-sQszZrggyFQawK-cfDRNiqd8_valjo3XYV-cRUI,3036
+vllm/model_executor/layers/attention/extra_cache.py,sha256=47q-6XEOyj29m_V0YLuBO7ocJosX3WtHD4qsQyiXLPY,4525
vllm/model_executor/layers/attention/kv_transfer_utils.py,sha256=zzMz0HAky07ebU5kZhnl8IdbBxhp-8lDiiBF65B04_4,2030
-vllm/model_executor/layers/attention/mla_attention.py,sha256=wjAHqWj-EQ8nTjajRZemuTcqC75mcSbEzYv8biDPd7A,120514
-vllm/model_executor/layers/attention/mm_encoder_attention.py,sha256=cdKkO1cgu-SUY0wv4qng6AMiZ47ztlPsW9KI805gyaI,8964
+vllm/model_executor/layers/attention/mla_attention.py,sha256=otfm87etz4vBu94_vHHfLNX4RPNn--34rI9coY2cmmk,122436
+vllm/model_executor/layers/attention/mm_encoder_attention.py,sha256=P0I_pUQgjp-zvCZ18uT2HsWHsq_UsVq3SvKaM8y9bFc,15052
vllm/model_executor/layers/attention/static_sink_attention.py,sha256=MITS6MM1-fied2C0YANgXFOdwLMzNmJY9W3RU52yca4,8379
vllm/model_executor/layers/attention_layer_base.py,sha256=T8X8sxru-eTivn94eWOVqjEosV6LNaJbpTrR8sOe1QQ,1015
vllm/model_executor/layers/batch_invariant.py,sha256=lJAO7RJc79q5Nl_r26c7M_T6SJzCAI6MZAsXZbzrx_w,33841
vllm/model_executor/layers/conv.py,sha256=UVHGlHFBo56rLzFK7bOE22WkDEB4W3xpMJp4pYtxOjc,8152
vllm/model_executor/layers/fla/__init__.py,sha256=Xbv6P8KG7bLpcOF17zPO5PYYHImweE9aiYYgGyMS9U4,391
vllm/model_executor/layers/fla/__pycache__/__init__.cpython-312.pyc,,
-vllm/model_executor/layers/fla/ops/__init__.py,sha256=Tx9ajmbsCCptuJofX-QpCC_Ut2SwZBTT56SUjGXGNs0,642
+vllm/model_executor/layers/fla/ops/__init__.py,sha256=6lsg78KTrKLkZ0Re9n2GmIJmVzqyLzzyWxS4ixI0Rzc,876
vllm/model_executor/layers/fla/ops/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/chunk.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/chunk_delta_h.cpython-312.pyc,,
@@ -882,6 +898,7 @@ vllm/model_executor/layers/fla/ops/__pycache__/chunk_o.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/cumsum.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/fused_recurrent.cpython-312.pyc,,
+vllm/model_executor/layers/fla/ops/__pycache__/fused_sigmoid_gating.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/index.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/kda.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/l2norm.cpython-312.pyc,,
@@ -890,21 +907,22 @@ vllm/model_executor/layers/fla/ops/__pycache__/op.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/solve_tril.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/utils.cpython-312.pyc,,
vllm/model_executor/layers/fla/ops/__pycache__/wy_fast.cpython-312.pyc,,
-vllm/model_executor/layers/fla/ops/chunk.py,sha256=YWK9csLUNXzExLLjGi2QMc-OocBUCzoLIZzdxucgfww,7885
-vllm/model_executor/layers/fla/ops/chunk_delta_h.py,sha256=SGaH1JZJB4Kqa_G2gXxwQSGD0eX2vBOKY6hDgUxw9T0,11944
-vllm/model_executor/layers/fla/ops/chunk_o.py,sha256=zshbCdGrTizDEbFVp9ahHmCEyiPOmxZ9QF9MPm1HOMM,5109
-vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py,sha256=_i7MMGO66MjD1RYQFDqNpa9yGknPyHBOTEDwChc9hFw,4653
+vllm/model_executor/layers/fla/ops/chunk.py,sha256=vM41AcghOkRolUZ-FhgKAlbXsonZMXvedUmVngACn9I,7870
+vllm/model_executor/layers/fla/ops/chunk_delta_h.py,sha256=gaRUkiXe67fo36q8z52wPB16tpNWp49cDSZR-X5_0cU,11940
+vllm/model_executor/layers/fla/ops/chunk_o.py,sha256=80kaFqamRCEq5Sl4raQ-J68e1Y11WP_WdBGjfSt6ZSI,5101
+vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py,sha256=bdoiRdFOvWHZPNT6F-CzHSkcR_JDw63jajArkF2OVIg,4645
vllm/model_executor/layers/fla/ops/cumsum.py,sha256=qhS_Lny-roiECMEDvtWBE2RpiFYDKjqWxwQONRdL2gA,8224
-vllm/model_executor/layers/fla/ops/fused_recurrent.py,sha256=W9qGwriB3oaeB1BLmMSmfs_Oy3NLD_BJg9p2wE668YA,13510
-vllm/model_executor/layers/fla/ops/index.py,sha256=cg2IV5Tn7cvDzPw8fTdI96sIYFy87tk960b0nPJFXSE,1221
-vllm/model_executor/layers/fla/ops/kda.py,sha256=TWHF34L4V1h-Ff0TpASmp7EmnvrpuIFarENkgtduLR4,38297
+vllm/model_executor/layers/fla/ops/fused_recurrent.py,sha256=Oio8UkWsuhJ68ybws_KHs1J3ao6x4fbkDCSnZ_rjwu8,21309
+vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py,sha256=AAq4mWr5eI_biEOmo7kYM-ehTIrMDh6gc6U2Mw9ky28,9166
+vllm/model_executor/layers/fla/ops/index.py,sha256=H6Nrd4M8Vs3B7Brz6jgiXYrp4frsWyGM9Ir1SC3fFaA,1185
+vllm/model_executor/layers/fla/ops/kda.py,sha256=mFUW0Pg9S-efqJQe5hwAMx8bbBxjduKuExMm2TdN1eY,38265
vllm/model_executor/layers/fla/ops/l2norm.py,sha256=HfUKWKl1EvrUKcowbDu4E5QPQCgZbPFumK43eli0Q6Q,3978
-vllm/model_executor/layers/fla/ops/layernorm_guard.py,sha256=ToZ3wAZ66DJ9RZN5FC0i4kXOZgEOkc3nr1gh-mZEhz8,12272
+vllm/model_executor/layers/fla/ops/layernorm_guard.py,sha256=pdEFfJWMfrjKALIqqCejzybUCUaYYLlNZwywdf98d_E,13019
vllm/model_executor/layers/fla/ops/op.py,sha256=atkZKsyTmBUTa68GnRk0VfEVfKAbmRoFLYtQlf8_JBM,1712
vllm/model_executor/layers/fla/ops/solve_tril.py,sha256=ARfbUgKOHld2Zbq3VlgjP07bjczFSCCQc0MRLmABcV8,19385
vllm/model_executor/layers/fla/ops/utils.py,sha256=8RRhqjpzSTv0DUsZpUor_qzrBrX6FkFuOOpha_HpmGc,6374
-vllm/model_executor/layers/fla/ops/wy_fast.py,sha256=TqyZCDUrAUo0ZFdtHPRY4ntUxavxsYsAGGj0m23oyiI,4372
-vllm/model_executor/layers/fused_moe/__init__.py,sha256=CLy7BL0W8uOuP1IuHlWwim3olUZfQmcX2VdF-eRn2Pc,4110
+vllm/model_executor/layers/fla/ops/wy_fast.py,sha256=9xUY538LvcvXnqKm2jt16wOUNBszvc41owl563RfCDQ,4368
+vllm/model_executor/layers/fused_moe/__init__.py,sha256=7RbCAApAQ4x32dvPzaIc91FP7NlqbUbs9x7cnb51_N0,4202
vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/activation.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/all2all_utils.cpython-312.pyc,,
@@ -932,8 +950,6 @@ vllm/model_executor/layers/fused_moe/__pycache__/modular_kernel.cpython-312.pyc,
vllm/model_executor/layers/fused_moe/__pycache__/moe_align_block_size.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/moe_permute_unpermute.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/mori_prepare_finalize.cpython-312.pyc,,
-vllm/model_executor/layers/fused_moe/__pycache__/pplx_prepare_finalize.cpython-312.pyc,,
-vllm/model_executor/layers/fused_moe/__pycache__/prepare_finalize.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/rocm_aiter_fused_moe.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/routed_experts_capturer.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/shared_fused_moe.cpython-312.pyc,,
@@ -945,10 +961,10 @@ vllm/model_executor/layers/fused_moe/__pycache__/unquantized_fused_moe_method.cp
vllm/model_executor/layers/fused_moe/__pycache__/utils.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/xpu_fused_moe.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/__pycache__/zero_expert_fused_moe.cpython-312.pyc,,
-vllm/model_executor/layers/fused_moe/activation.py,sha256=OiQlBhQsJbpZefi-ZOFAg2fF_qdQv5fZ_7EyX9ce-SI,4919
-vllm/model_executor/layers/fused_moe/all2all_utils.py,sha256=nXmupKl6wIitp3bwdAJd7w9H_URRVdncBJs_7R_2KYA,9574
-vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py,sha256=V4XdBBljjms2ToISaHvJY0-m33nl_hUTvNLmSWG3QaY,16068
-vllm/model_executor/layers/fused_moe/config.py,sha256=GTQD51HOaDCzWNzQADHsnKQzd5d1cT3tRBqhimZ7cmw,41671
+vllm/model_executor/layers/fused_moe/activation.py,sha256=U9QeFMk9XFPMdkhm91ncYvmobf9QsGOOk-vYFhI5HSQ,4733
+vllm/model_executor/layers/fused_moe/all2all_utils.py,sha256=Rb5piZhuzyXHGzGCrHjxePI2t0NU3o3qohqGVRD7HEc,8105
+vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py,sha256=gU7jzYR2hfgoRmjlTugnlNw5cE6jURjKDRAipWxSUWQ,16059
+vllm/model_executor/layers/fused_moe/config.py,sha256=Guwe4HSIbLxcwwY0XA5YtHWLDFdWbDSUB6FtNZgE1D4,41748
"vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json",sha256=iNGsE2ZeVnQEnN4A8UJ9Jv0d3hbRF2MJ9oBgjup5Szk,2737
"vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json",sha256=hH5rRN9Wtyv35azxMzyUMHWtiKgOHev5tNjIG8j6dsE,2751
"vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json",sha256=qPumkNxaHMvVBnEjPe_Xiuz9ICb6Hqc-9I1DAR8s3gA,4130
@@ -1016,6 +1032,8 @@ vllm/model_executor/layers/fused_moe/config.py,sha256=GTQD51HOaDCzWNzQADHsnKQzd5
"vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json",sha256=iplHjPj366nNBu44ZHlCRe7aMaV2T4QthXINmkWi218,3269
"vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json",sha256=l59WyUnnymmJQykoJzEvvkzwlhIlmyEeVtvCFRzXGRM,3279
"vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json",sha256=JtcHRlPz8xQEAqJ9EWI63oYvdmjQFG6VTHqtt85VOSA,3221
+"vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json",sha256=LJgf-vNZtRHawIm17vq6eN6cuzxw256i_MJ-VZGHIt0,3276
+"vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json",sha256=KoFfThKn3gvLmzE4NMKRe9-g7ZA1zDMJJy30ABJSfwE,3269
"vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json",sha256=vaiXepgaaixIxU3_vVGcYsA18XAiC8V13phIN5niRrc,3250
"vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json",sha256=f3iM3xm8hGUirJ4ilAIPO6Pe9bs4sm3qaRKMswN9SKE,4731
"vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8.json",sha256=Pux4G7sjfPL22uOJU6t35TXe-VU3OaoPA97TDj9_wZQ,3251
@@ -1250,39 +1268,49 @@ vllm/model_executor/layers/fused_moe/config.py,sha256=GTQD51HOaDCzWNzQADHsnKQzd5
"vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json",sha256=WQLKugnKzlQ0avf1N-41lRHtG6wJ56DfVPv_nip6NBc,3273
vllm/model_executor/layers/fused_moe/configs/README,sha256=W2yIZkP9O8GGlg97We9BJfTtWUtPbuz5ZH3esrrjBX0,572
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py,sha256=GDuJQLv1AzB6X6dFZBOXe24Vo2xN5x3nJgmyKqhpa8w,16076
-vllm/model_executor/layers/fused_moe/cutlass_moe.py,sha256=1NNO0WoatYpynhVerZdFGr67BnyRGN0fXu66YAL4mU4,40797
-vllm/model_executor/layers/fused_moe/deep_gemm_moe.py,sha256=TR0dsJm3n_XE5DmZVGjDhUrOYvnttHY7FFU24SWJnXI,10621
+vllm/model_executor/layers/fused_moe/cutlass_moe.py,sha256=I5-_euo_sy2FcOzyICtjdQrUwvtMz7mzbspdKA0jn0Y,40767
+vllm/model_executor/layers/fused_moe/deep_gemm_moe.py,sha256=3WPsba24gKi7QKL6Vq6r37Ij1aYYJZL1F2O3sVoypuI,10612
vllm/model_executor/layers/fused_moe/deep_gemm_utils.py,sha256=OZ5vnMME0tDx4xSaWTroxkaO3akmLyznT_VWs1IZ42w,13264
-vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py,sha256=ajZ_R-CdU9AsKTDRkK6fGY0yldMCCNsYYXPcvt8uSF4,15657
-vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py,sha256=ewj4N7Bl4gdr2M82CtlaeD7ADHtr1cffVl6DRFIEXI4,15598
-vllm/model_executor/layers/fused_moe/fallback.py,sha256=pU1jxeelW47TOVKlrM6R1t6CTn-MMfXmZ1gY8-btqoQ,6174
-vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py,sha256=AdVfb0Q-ExuEjYgiouFL_rF54yOGLm_4tRjCV6uX5zc,7067
-vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py,sha256=l7knaZMEdAcDR8-RGcravrNeELR0eeogshwi5FxCKx8,11981
-vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py,sha256=jkBpj1WT8KzUurnhSnhz7w8DzJRbXq1pcKbo8ZcI8rc,15070
-vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py,sha256=0Q-wv4UzfxkmS48M5IiSJTa1s0sy51z6UGwfQJv_0RE,15335
-vllm/model_executor/layers/fused_moe/fused_batched_moe.py,sha256=SCqKf05Ih-bRvfj-3HnxSyQNZg7B1mTd3BVGYMU1Jgo,35954
-vllm/model_executor/layers/fused_moe/fused_marlin_moe.py,sha256=q_3FZBo530Isl3Lfw-qDICQZ08QdqLGEBcBxZug4SVk,29612
-vllm/model_executor/layers/fused_moe/fused_moe.py,sha256=lzrRaouy1432Ksd_h9iuDJXLiVEhMhh-ctjyaMu4vrk,83276
-vllm/model_executor/layers/fused_moe/fused_moe_method_base.py,sha256=fBF60yY-oztJDLqNxGgeUEN7cWv3xvqUYM_vQTnbxTo,4598
-vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py,sha256=bF6Wplb4ep1c0kmlbYwWcNhKpWOpXqCyxy8ILiENJ9E,3409
-vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py,sha256=tPuxX4nzZUQPrHlremrDyz_U2aNFIjZdNAjGUEYWsRg,23335
-vllm/model_executor/layers/fused_moe/layer.py,sha256=-ipIIEZZ0Np4M_m8IC7VHHf0voqxJUgADYWA5EZ4QKY,62185
-vllm/model_executor/layers/fused_moe/modular_kernel.py,sha256=THV7B6GCwJIB7wv2N4RARUbthdb_PABX02P8UT-KdWI,51871
+vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py,sha256=wsLE3TDA9d_mqm-b3586sqAE3XCr3y4ZK7ZKHPHOcSI,15764
+vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py,sha256=d6A-1FC3IiDbQmJUmuEsFfjbtjIyCf-Rp6uKWhz-MZU,15607
+vllm/model_executor/layers/fused_moe/experts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+vllm/model_executor/layers/fused_moe/experts/__pycache__/__init__.cpython-312.pyc,,
+vllm/model_executor/layers/fused_moe/experts/__pycache__/trtllm_fp8_moe.cpython-312.pyc,,
+vllm/model_executor/layers/fused_moe/experts/__pycache__/trtllm_nvfp4_moe.cpython-312.pyc,,
+vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py,sha256=tjZVW89ghNWv_6JfFbwqyMaNSnTNYPyvNLGyrlx3fEY,12948
+vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py,sha256=O-DrUoyGWgQ65EB98yUpOGntAxLhlbgaKjOX5wEb6CI,11862
+vllm/model_executor/layers/fused_moe/fallback.py,sha256=b7uGJJs_NqZkjZLGH1AXRzv4KnbbPtAFB0GQJG-IYGs,6120
+vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py,sha256=fswvle7q9qP8Op8jVbPsjthuu0lGqs6dFG_48WlZ7qA,7144
+vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py,sha256=O-XyogSPLBaqZFZxdANa_81EYVC8CGT1_MRTxPqAqfs,11972
+vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py,sha256=hUqPWz2wxXBTvkvnZb6E4FgtdTMo2EXRg06I0w3Cpbg,15061
+vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py,sha256=yplWuWVHhtMClU7DZaUk3IN1Ldr_TaQ9WwSkTdF5PWo,4574
+vllm/model_executor/layers/fused_moe/fused_batched_moe.py,sha256=FqLxGcRCoFL4THzP67P-IXiAlxN05D_EpuvPdjzPIL8,35952
+vllm/model_executor/layers/fused_moe/fused_marlin_moe.py,sha256=By6-b6VX0FC4i6O1ttAIHvJoM4WTS1DshhDkC5BbqlU,29603
+vllm/model_executor/layers/fused_moe/fused_moe.py,sha256=8_cvvSE5Hk6tyksBPFnlnJQboIgMF0Rqq-iY_jY9ang,85591
+vllm/model_executor/layers/fused_moe/fused_moe_method_base.py,sha256=3dPVwK0OrOrPf6ZzSId4X_suscOH_rCZ3s20mEOq63Q,4453
+vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py,sha256=b395WXOpgOJoL554tkS13h5GKADyT9Gaw2Puvpei7MM,3430
+vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py,sha256=e6x5_qaZdwc7A_MyYwfzVF0jvOyVcJtGE94Ir4P6pV4,27690
+vllm/model_executor/layers/fused_moe/layer.py,sha256=99JSTy05tAKbP7umntGGpOwjBOnKTtYi-MgLLOPWvno,66842
+vllm/model_executor/layers/fused_moe/modular_kernel.py,sha256=bYIDbRbeMIrBuPbe9G_ZA9dVghbtGGgVdaFKvgOAglc,63407
vllm/model_executor/layers/fused_moe/moe_align_block_size.py,sha256=w_f8IIeDbwFgoyq5nOsdj26Hx5NnnaNN1nLiJct_oxs,8339
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py,sha256=ztL9u2t25hNBKFgD0Qj1buDbVmWlcIBgF9XjtV02Xok,4992
-vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py,sha256=Rm85qW912Y_e7vFHCeFE4lrnigIYyp7XDfVRmhJxZf0,4121
+vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py,sha256=6B-0hMQjOGSGZIhsbfTcQMGv08o54jdhgHZbsT2YGPY,4128
vllm/model_executor/layers/fused_moe/oracle/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107
vllm/model_executor/layers/fused_moe/oracle/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/oracle/__pycache__/fp8.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/oracle/__pycache__/nvfp4.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/oracle/__pycache__/unquantized.cpython-312.pyc,,
-vllm/model_executor/layers/fused_moe/oracle/fp8.py,sha256=QyPaQzy82JVVRTsL6X2uFtqF3FUnRLO1RrvdZGy-oLY,22226
-vllm/model_executor/layers/fused_moe/oracle/nvfp4.py,sha256=YM-prlv66hDjT3rGfSnjahPt2PcWOSaxjr_R3fs92sA,15875
-vllm/model_executor/layers/fused_moe/oracle/unquantized.py,sha256=jie64ySw2lz7Sb2ZGLqfKqWuq0G1TO5NFiFIRnsL00s,9533
-vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py,sha256=PND1lFLLUKz4KygyrX_1HefXzocVVeUAd07cobUXCME,12218
-vllm/model_executor/layers/fused_moe/prepare_finalize.py,sha256=eOvUSMWpdaiJuRTqPv98rHKNxI1Wa4_QIXDpi_E3FBk,7337
-vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py,sha256=EuP6YL2gFZEKLylrVHtBatmrzLHe_Tpzz1RplYpolv0,14004
-vllm/model_executor/layers/fused_moe/routed_experts_capturer.py,sha256=z7ufSRlsWlFAl3ULl1R2wLD650ouQ0MP4LinCbfego4,11568
+vllm/model_executor/layers/fused_moe/oracle/fp8.py,sha256=5PMXJ5zNFAqmGgPteG2hT95EBT86PdCzXM2Yi-6kL3M,20299
+vllm/model_executor/layers/fused_moe/oracle/nvfp4.py,sha256=dI6Okdrp6V6rDhMEYOudxU3sOmcXMdwtcFOsjuprNTM,14537
+vllm/model_executor/layers/fused_moe/oracle/unquantized.py,sha256=DS14JoXWTLDnQrkorcQZ8vhXY6EI5os0GP_osFFVT7Q,8916
+vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py,sha256=iwuOYZXkbmX0D3Udt6LpGTFlyo_mlK1jEyov1cxvqDk,822
+vllm/model_executor/layers/fused_moe/prepare_finalize/__pycache__/__init__.cpython-312.pyc,,
+vllm/model_executor/layers/fused_moe/prepare_finalize/__pycache__/naive_dp_ep.cpython-312.pyc,,
+vllm/model_executor/layers/fused_moe/prepare_finalize/__pycache__/no_dp_ep.cpython-312.pyc,,
+vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py,sha256=pqx-xzTaWh_v4wpg7oQKOZB0OknksFQi3yuju7ILBiI,8276
+vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py,sha256=4QjzExh0Wtb15JmAxOchreKDo1pdRMaS1XqgswXyKTA,4597
+vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py,sha256=9GSW36FKmdfkBFKjIGb14Sj8SiiJg0vvV56tQcILy3o,13995
+vllm/model_executor/layers/fused_moe/routed_experts_capturer.py,sha256=3b-7pwPTN_rvMase5plWxYxQnV5I6w0TXh3v0yUkHM0,11634
vllm/model_executor/layers/fused_moe/router/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107
vllm/model_executor/layers/fused_moe/router/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/router/__pycache__/base_router.cpython-312.pyc,,
@@ -1290,37 +1318,39 @@ vllm/model_executor/layers/fused_moe/router/__pycache__/custom_routing_router.cp
vllm/model_executor/layers/fused_moe/router/__pycache__/fused_moe_router.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/router/__pycache__/fused_topk_bias_router.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/router/__pycache__/fused_topk_router.cpython-312.pyc,,
+vllm/model_executor/layers/fused_moe/router/__pycache__/gate_linear.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/router/__pycache__/grouped_topk_router.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/router/__pycache__/router_factory.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/router/__pycache__/routing_simulator_router.cpython-312.pyc,,
-vllm/model_executor/layers/fused_moe/router/base_router.py,sha256=cLvlX4sFRnt9nC44GLThNoLIePmTcKUhFm3xS7XbE30,9601
+vllm/model_executor/layers/fused_moe/router/base_router.py,sha256=BMsX5xlhNYqY0xTu2cS4S1WZvX6TfrWe2bBbhx1n1Wo,9652
vllm/model_executor/layers/fused_moe/router/custom_routing_router.py,sha256=vDfX3qXVCb5-AXK4pKo-EymWoBxWODStPivZVPDpWl0,2122
vllm/model_executor/layers/fused_moe/router/fused_moe_router.py,sha256=DlbIBHHQs-kqisruGPe7vKzN_UeTu4ogcC-K0t9WNro,1261
-vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py,sha256=anUL0jqvRd24aTjo_aRwJDNVcdNjGDTZOFrG4ZwypC0,6144
-vllm/model_executor/layers/fused_moe/router/fused_topk_router.py,sha256=F60qRZ3sr1J-OCRopf6PTWK5QqGWbaPzPZGuy6zqi-c,4941
-vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py,sha256=hKXp091Heu6Ey0sjLXqSzd29gCwrcWjNgPBWRkGznWw,12916
-vllm/model_executor/layers/fused_moe/router/router_factory.py,sha256=Vegx6kR16u9ADbfCaN8tbL1rRpFs-zPo-WpjxrIsm6A,6337
+vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py,sha256=2VU8IAvMBEblib-bX6KgJyhzY-3cdR-hJyp9z63jAcU,6237
+vllm/model_executor/layers/fused_moe/router/fused_topk_router.py,sha256=uBDhmVUzKbgz3xKNu1s0deV0cxsx0ArLLLwIK_rKSaY,5061
+vllm/model_executor/layers/fused_moe/router/gate_linear.py,sha256=qVUylcrqPZIPO92jJy6kj7rynnCIhIadRk5iVNFjsYE,4255
+vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py,sha256=SvG7t755eUNPYlE4sJymxNLgvrh_Q5gL6Fw46OlOw0E,10288
+vllm/model_executor/layers/fused_moe/router/router_factory.py,sha256=UZPMhGivUlIj95GAMvnC8Rc6EPW4pw6FqlMk50HozjM,6337
vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py,sha256=tD040EnxUfoa8wKjjJj2KqcOhgIFYxXC31r6Hfu6qhk,12060
vllm/model_executor/layers/fused_moe/runner/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107
vllm/model_executor/layers/fused_moe/runner/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/runner/__pycache__/default_moe_runner.cpython-312.pyc,,
vllm/model_executor/layers/fused_moe/runner/__pycache__/moe_runner.cpython-312.pyc,,
-vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py,sha256=9ve4iCY0qAJUChLCGIqWyMuYsRsioL7MarCLJCWFyIs,30270
+vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py,sha256=qOK4FIcc-hvmrjeDQAgWxPAZ6J8-3j8Ew39TdPMAt1c,30096
vllm/model_executor/layers/fused_moe/runner/moe_runner.py,sha256=nQnuvW9D9DsMvA9Y_wWVVMvxozQD6UQg_WdsjFVOaZY,992
vllm/model_executor/layers/fused_moe/shared_fused_moe.py,sha256=xrkplE36sFIWFkhEiCrb06UmYJTuyNtx2HWjqZPjsfA,2161
-vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py,sha256=sLKn82zPW7Yij8Bcrt4sXW78SyuwIwDELfiqsXaTxsw,5801
-vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py,sha256=O5wADe6kTwJkHXoMKypPcxETSPpFPJfDankwEZgClQk,2684
-vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py,sha256=bADax16fYxG-wKPYuW9idzwIlpuVO9MTk7OeTczSMPI,2857
-vllm/model_executor/layers/fused_moe/trtllm_moe.py,sha256=i1J7aU7NMzQ-rxkgatIT-7sMGtwhqQ20D1Wr7uQQ3F4,6338
-vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py,sha256=yU1xzCy2LeXuAe3nMXEAtXnKLO0D1fgt66_yhdxjl4M,14792
+vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py,sha256=XC84o7emdFIyVMP8YfndvtavqfVPkDzauLGjheEkG4I,5874
+vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py,sha256=qLDd3tjuEfcdSZHU3apLMGQ5ioFwpa8Zb1oQ0JC4Iw0,2657
+vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py,sha256=4RKgs0D0xjTY2E2JFDfTLyRLdXZ17rAI5Usn_B4Nkqg,2830
+vllm/model_executor/layers/fused_moe/trtllm_moe.py,sha256=QcyXsi8Uufh-Se4R4etoVFjc1ZPOi07pvOq1B79Ji8c,6329
+vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py,sha256=gsbKewhv312LLpk03lp6-MKpPuv1RH04zN7rbIbO3-M,15008
vllm/model_executor/layers/fused_moe/utils.py,sha256=HmnOO-olejNBBTMf3XVtYR5mT3tc1gshzLeHjVMZ2Jk,11712
-vllm/model_executor/layers/fused_moe/xpu_fused_moe.py,sha256=fvl7rUCDeuLqtANnPkaMQCysc7CWmZKI_D4D7_RmaY8,4676
+vllm/model_executor/layers/fused_moe/xpu_fused_moe.py,sha256=6CD_APqC4SIK4qXBCTKEiWC1gxNTOYsjiQB5PsL0EIA,4667
vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py,sha256=4LiP1U9X2ESu_SOVdZ1SHrHZOe7XzdqBt2Ee0RDQPgM,7676
vllm/model_executor/layers/kda.py,sha256=9gozc7-nXg7ubhTPk_CzE0c1oWVNGZcxlP6SaWepRyo,15497
-vllm/model_executor/layers/layernorm.py,sha256=bsZuPCvDsDZo1-KnSawYwKoIGidSdMUB6j2yPu7XpU8,25348
+vllm/model_executor/layers/layernorm.py,sha256=k3Q_kdNGEO-m6hxm83Oog2SvPCju8BVogJLUFH9FEJs,25645
vllm/model_executor/layers/lightning_attn.py,sha256=yknJVad9ilzRGFdFhvjkoe_xEiffNp_KopUsd32Qzf8,20864
-vllm/model_executor/layers/linear.py,sha256=NzlAcv23OLV5PnN5IjF4ey7CtJkvhFt3IaFDcUUcSeY,62213
-vllm/model_executor/layers/logits_processor.py,sha256=UyFpVbQLr9Y7JxNxYtBOb6Tv9mIIB6RGhJB93Wf1dSE,6208
+vllm/model_executor/layers/linear.py,sha256=kr2f0QFQEtHYOIeB8KCJhCkVKomdQiAk3wCamy4Z4Wk,66525
+vllm/model_executor/layers/logits_processor.py,sha256=1H2XdBVL__PwzlQXlZ4lNPOobQNng222HMBZioXNeOE,6396
vllm/model_executor/layers/mamba/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/model_executor/layers/mamba/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/mamba/__pycache__/abstract.cpython-312.pyc,,
@@ -1330,10 +1360,10 @@ vllm/model_executor/layers/mamba/__pycache__/mamba_mixer2.cpython-312.pyc,,
vllm/model_executor/layers/mamba/__pycache__/mamba_utils.cpython-312.pyc,,
vllm/model_executor/layers/mamba/__pycache__/short_conv.cpython-312.pyc,,
vllm/model_executor/layers/mamba/abstract.py,sha256=8HIHxbE1th8PUbL2rIfhM81n9WhanyUwvPezabHuxsc,2213
-vllm/model_executor/layers/mamba/linear_attn.py,sha256=J00weDmLflZV1jshtA02XasPMHTrIYl5WYas5txSS4U,15254
-vllm/model_executor/layers/mamba/mamba_mixer.py,sha256=VNWCtLeI_E2VdgW2TuqEkhpKPYUu0cOvYY4bVMP2UwQ,19660
+vllm/model_executor/layers/mamba/linear_attn.py,sha256=Qm9bDDzN2o5JrC_SlEsZaHkty58oDSea9fiv3MJVXXM,16562
+vllm/model_executor/layers/mamba/mamba_mixer.py,sha256=pKZFo82YqqMD0pR0k65CC-EcRcJI0DxnV78rvtXMh4A,19902
vllm/model_executor/layers/mamba/mamba_mixer2.py,sha256=Fdr4hKModFKAP1WDTSnu-exWcqqTUuiZxjqj2YorAhA,38444
-vllm/model_executor/layers/mamba/mamba_utils.py,sha256=3w7V0LrNiZHBkFrZG-OBMYbQd7ED56K7jv3-rzPB6fQ,10372
+vllm/model_executor/layers/mamba/mamba_utils.py,sha256=NQ7XbYBIv2_VxFyZ_trUxaI3zYXSr-KszHkwn7szGzw,10326
vllm/model_executor/layers/mamba/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/model_executor/layers/mamba/ops/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/mamba/ops/__pycache__/causal_conv1d.cpython-312.pyc,,
@@ -1344,16 +1374,16 @@ vllm/model_executor/layers/mamba/ops/__pycache__/ssd_chunk_scan.cpython-312.pyc,
vllm/model_executor/layers/mamba/ops/__pycache__/ssd_chunk_state.cpython-312.pyc,,
vllm/model_executor/layers/mamba/ops/__pycache__/ssd_combined.cpython-312.pyc,,
vllm/model_executor/layers/mamba/ops/__pycache__/ssd_state_passing.cpython-312.pyc,,
-vllm/model_executor/layers/mamba/ops/causal_conv1d.py,sha256=g4ahWb0tzWIbXky0huQtfCJYA5Hn6SJ4NazsQ6-HZhE,49672
+vllm/model_executor/layers/mamba/ops/causal_conv1d.py,sha256=bi_Saq1xSfzL5cIzq7ABpxNyFCM7JEw7nWCm1a92pc8,49670
vllm/model_executor/layers/mamba/ops/layernorm_gated.py,sha256=KA5JRI8jEStDEkN6U6dqaOBQSCWugxUelAhZISZyf7A,5228
-vllm/model_executor/layers/mamba/ops/mamba_ssm.py,sha256=KRP4YhI9MQCFi_HpDxjkkR-fTxB3dqyjHFVVlhERVv0,20010
+vllm/model_executor/layers/mamba/ops/mamba_ssm.py,sha256=s5mUlU7lK5hoFwzNeViBpra_GtQrMWklAqb7EuaU16s,20118
vllm/model_executor/layers/mamba/ops/ssd_bmm.py,sha256=mlQ258Jsexnb5vHGh09YW0l-rF3SZM4H1R29o3hb2-s,6842
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py,sha256=Yow2h10PSWWKvrwT3TRz7UZTCq--99j7MZToXN59aCc,15550
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py,sha256=6-0Dl2-frs5av7ra-H0xHWPupVaSo0Vex4d5vZVq41Q,24226
vllm/model_executor/layers/mamba/ops/ssd_combined.py,sha256=oKPphPh0wjU3XT3k5U2t9cHb4jDLuYzIQuWK9w0sOcs,7353
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py,sha256=OWIEJ6WjdYHbX_7HRzhPt50oejUjvzhBTST6bQPhjas,5275
vllm/model_executor/layers/mamba/short_conv.py,sha256=peKNJ66UhX78ZWjfV2kLedNlcAOPJn58vQ1oUiCbJ1Y,8257
-vllm/model_executor/layers/mla.py,sha256=SNpV2Bi1BnP9iOcm0jnN1PL9V2i8LTcJ4kepi6Zu6yI,6440
+vllm/model_executor/layers/mla.py,sha256=3g8P4Zn2t_lWHbmBVws_3NvIsMNWu_agE6lIQ5s4HV8,6516
vllm/model_executor/layers/pooler/__init__.py,sha256=AvMJtY1sQcP_n7SWB9Zp0F7dBfgNxKOpoUfjuZfP1_k,176
vllm/model_executor/layers/pooler/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/pooler/__pycache__/abstract.cpython-312.pyc,,
@@ -1380,7 +1410,7 @@ vllm/model_executor/layers/pooler/tokwise/__pycache__/poolers.cpython-312.pyc,,
vllm/model_executor/layers/pooler/tokwise/heads.py,sha256=CCwLTzPIZ_s344JSCPoZWNgX6iO-vWl03uB7AbBFTLo,4149
vllm/model_executor/layers/pooler/tokwise/methods.py,sha256=ewPD-pazZ_mxULxR5ZW1m-3tfHoHz8SKlu_OKlCEO7c,4240
vllm/model_executor/layers/pooler/tokwise/poolers.py,sha256=BU6GwjUveV7Mx6aSuVdd9HBDwFY8MEMHJqreCC9QhpE,4096
-vllm/model_executor/layers/quantization/__init__.py,sha256=p-p8jPHEkQyu7nkhbkESZEo8D45OhooF8OLMUYjzdLE,5324
+vllm/model_executor/layers/quantization/__init__.py,sha256=2Vn15_EIGAHFE4BLagygwzfSzluLbpcTGZCEM4NJgB8,5550
vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/quantization/__pycache__/awq.cpython-312.pyc,,
vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-312.pyc,,
@@ -1406,8 +1436,9 @@ vllm/model_executor/layers/quantization/__pycache__/ptpc_fp8.cpython-312.pyc,,
vllm/model_executor/layers/quantization/__pycache__/qutlass_utils.cpython-312.pyc,,
vllm/model_executor/layers/quantization/__pycache__/schema.cpython-312.pyc,,
vllm/model_executor/layers/quantization/__pycache__/torchao.cpython-312.pyc,,
+vllm/model_executor/layers/quantization/__pycache__/w8a16.cpython-312.pyc,,
vllm/model_executor/layers/quantization/awq.py,sha256=EqFKOHtFQepFitFjmAdMMnrjeeRCwlQvjcL-k2rZJ1o,10208
-vllm/model_executor/layers/quantization/awq_marlin.py,sha256=FnOD0LI11ztLDgbVSf64BLEPOb8y4UU1la3UWeiDd5k,36870
+vllm/model_executor/layers/quantization/awq_marlin.py,sha256=uHW2Q0Pa6MV7S0j-I08opJjlq2napcjJeiVj-c-a8G0,38811
vllm/model_executor/layers/quantization/awq_triton.py,sha256=fRH9rX5jDSqQtdmcrXwxtPtQEiKxTUdPkAJmtvW_zng,11620
vllm/model_executor/layers/quantization/base_config.py,sha256=ZY_lnCDF33boaqF9u3DPV3D9-KWJ37MRr8amKC9aVRo,6563
vllm/model_executor/layers/quantization/bitsandbytes.py,sha256=JiHCrscq3pKS_FFaETQ9rpzEmripvXSxK_AbIKi04IE,20250
@@ -1417,9 +1448,9 @@ vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compresse
vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors_moe.cpython-312.pyc,,
vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/triton_scaled_mm.cpython-312.pyc,,
vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/utils.cpython-312.pyc,,
-vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py,sha256=sK8R-FeQtwOQ6TIcaBz87nMrC8D6vafmR4uW-wBahQo,43488
-vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=id0owM9G5wvMfL7WIelnELsPY7lvVP6WZcnFgECzfxA,125810
-vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py,sha256=9ulnB6i8TA_ET4N84o56K8j-vxYHpcjL_csVUZCo0kI,1357
+vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py,sha256=hlDlYVo7JQcuu4C2lCdtP03QrocshL1zJfPYttNgiGY,43131
+vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=JjZpIZ0RcoJLz73U2gLodCg2yl-R4NSxayR-kM0EOi0,185236
+vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py,sha256=9MgIcmvg_tbj7cvK1V2b2WNdsou4Etb7xStFdvnZ9X8,1298
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_24.cpython-312.pyc,,
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_scheme.cpython-312.pyc,,
@@ -1441,7 +1472,7 @@ vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_te
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py,sha256=cwo918k-M8UUqf0o7sN5IttmgnA--mWudwLplZbwH8M,5124
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py,sha256=IX-MRlV3KEnSaReMS5oP4cnzYD2qMOjziPh2A1v3plg,5365
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=3wzNEzEaS75TxKnymt3OkXuD-X4G7RengWzAFlPGs5g,7650
-vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=5_7c-lYG0XRXr-6dmB4w7w4Vt8QJlx2O98XpDt08cBY,7628
+vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=hEh9SeMz0AXnF_9SnVB51SJwOgLZgT2FJI4SNIoSk5g,5022
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py,sha256=pUnlwIWELDxSfKyOUnIaF1LKydCNhpW7cdEzTk4YNGU,7875
vllm/model_executor/layers/quantization/compressed_tensors/transform/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/model_executor/layers/quantization/compressed_tensors/transform/__pycache__/__init__.cpython-312.pyc,,
@@ -1460,33 +1491,33 @@ vllm/model_executor/layers/quantization/compressed_tensors/utils.py,sha256=MGBWA
vllm/model_executor/layers/quantization/cpu_wna16.py,sha256=CCh4moMa9rDrpMxDViuD9bMh-AObWLugjiLX8Wagovc,9713
vllm/model_executor/layers/quantization/experts_int8.py,sha256=uEOc_6HbAxRHk4kGVrVXKXY-NDmsZhKPzOz8f9O7E-8,6891
vllm/model_executor/layers/quantization/fbgemm_fp8.py,sha256=4WYLD1DECYNdK1vlV9JNm5xQz0tA-QW5PZLuwiqqT1M,6537
-vllm/model_executor/layers/quantization/fp8.py,sha256=ftGOMgzAt8hGmBBHHPvL1iZ0dV36zB7N_36EISWkkT4,51874
+vllm/model_executor/layers/quantization/fp8.py,sha256=BbDGeeAJcdzPT3TcEAeBBYnG368Lz5vwXFG_qJQL87g,49438
vllm/model_executor/layers/quantization/fp_quant.py,sha256=-_Y67uqhWYnF3-VgelEieYYyJdHGFNHQU7N0EZRkS30,12808
-vllm/model_executor/layers/quantization/gguf.py,sha256=vNzIk01jPHcfybG1DpW84SL2ZSCnGRwOjCPLFZ0lSJg,23210
-vllm/model_executor/layers/quantization/gptq.py,sha256=9N215iLPUsDYI4bIaaEh690XvH3Zei7VgXAMysOxEcI,14814
-vllm/model_executor/layers/quantization/gptq_marlin.py,sha256=jFL9Z4WCqD0PpR8R9k9wFBe--hSmWyvUJvOPIve6W5w,34985
+vllm/model_executor/layers/quantization/gguf.py,sha256=UloLOKvRfV26yLl2v5R8eUqaqMXZQ4a9JUP2yuvgC3g,23559
+vllm/model_executor/layers/quantization/gptq.py,sha256=VMFcCvRxWW8Y0yQ8g_0F4lZav5VoHmox3ck7IoeblyQ,14814
+vllm/model_executor/layers/quantization/gptq_marlin.py,sha256=cfXVJepdbqDkiKLKe7F4zZmaiehy4lssLQk4tZs87To,45705
vllm/model_executor/layers/quantization/inc.py,sha256=e1WXuAsEHbeCMhCDL_D-JiLpKuo6xz0o0udCjjbZaqo,16368
vllm/model_executor/layers/quantization/input_quant_fp8.py,sha256=dPgo7XUhapj0dcbi57i-iRxGtXrDfWGbae8Ck2A8pS4,9421
vllm/model_executor/layers/quantization/kv_cache.py,sha256=MEa3GgF2uX3PMmj27T5XaITpXpRKvanEhahDHHMvICE,6616
-vllm/model_executor/layers/quantization/modelopt.py,sha256=gTNqO6NhocH0472tJe9RjacPCP8wpHyWQB-ebVpdcyE,67413
+vllm/model_executor/layers/quantization/modelopt.py,sha256=9DyDh7ZPjCLRDS3FKP2eCEzWHb9ZnrEotzKi4BpZNPQ,68978
vllm/model_executor/layers/quantization/moe_wna16.py,sha256=kDQ--mJ4XsCp95yR_txEEIZGZBP0Pahj_RpTEKQC0Vc,20437
-vllm/model_executor/layers/quantization/mxfp4.py,sha256=Mbaol2gJw0nQRgyEwf2pYF0Ecjp2bCLCkggjmPghUOg,46827
+vllm/model_executor/layers/quantization/mxfp4.py,sha256=SbmLqaO6aMQ8vXINFQq-DvcieorwDywAzENwwf6yRKo,50841
vllm/model_executor/layers/quantization/petit.py,sha256=OxEhXQ83yAPyoz_wQanCo9ZIeIAr77ERIThMW7imJb0,11563
-vllm/model_executor/layers/quantization/ptpc_fp8.py,sha256=U2UlS5PeDUGVqNfI_QIk7DQr0-OhOJi2Psabz7dtyFY,5178
+vllm/model_executor/layers/quantization/ptpc_fp8.py,sha256=dGjDPG2Jry6tream6pXiFza1fc8vZcxJ3mJe29BaChM,5066
vllm/model_executor/layers/quantization/quark/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/model_executor/layers/quantization/quark/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/quantization/quark/__pycache__/quark.cpython-312.pyc,,
vllm/model_executor/layers/quantization/quark/__pycache__/quark_moe.cpython-312.pyc,,
vllm/model_executor/layers/quantization/quark/__pycache__/utils.cpython-312.pyc,,
-vllm/model_executor/layers/quantization/quark/quark.py,sha256=zhgVltnFyOecCtZ7uQP6ONV_kulSwe925Erxips6Zpk,22873
-vllm/model_executor/layers/quantization/quark/quark_moe.py,sha256=3j8Ah6ve7RPMGUJyQjTwOFeN5WTdroOqX6BBxJmU2cA,41962
+vllm/model_executor/layers/quantization/quark/quark.py,sha256=2tl5jg2Kz-VJ3YBhiSBwp8-CoR6k6eQZIN2Z_YXir9g,24362
+vllm/model_executor/layers/quantization/quark/quark_moe.py,sha256=s1n8LsqPzkfLpFAZAPzxUhDnr1S27YBTm6mbJxljgvE,47045
vllm/model_executor/layers/quantization/quark/schemes/__init__.py,sha256=A699VuIbNQcvuANSmPjZSu9SmSHuqKfSQhvn9Y7GlpE,343
vllm/model_executor/layers/quantization/quark/schemes/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_ocp_mx.cpython-312.pyc,,
vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_scheme.cpython-312.pyc,,
vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_fp8.cpython-312.pyc,,
vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_int8.cpython-312.pyc,,
-vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py,sha256=1_TY1bFMBtkP2N2_CblVGPcHmYyuIorMQAUoH6qLN5w,12912
+vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py,sha256=SD95VzSRowSDcRZvek20PXCERY3xXEydRMVVTSsTh1c,14048
vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py,sha256=VMPNUJAQ6PqGlGmVD0wUx3a2NstTp2aC5Dx6w5ZTHn0,1516
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py,sha256=VKI1F1WfR7bXnsMpAvipR_PgLoBjGZsnNBUrJzQbrmc,7051
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py,sha256=58h2xdOgp60F5I8juKR5kfs0Q_tCJ2HnVKsfD4LFxRs,4489
@@ -1501,6 +1532,7 @@ vllm/model_executor/layers/quantization/utils/__pycache__/flashinfer_fp4_moe.cpy
vllm/model_executor/layers/quantization/utils/__pycache__/flashinfer_mxint4_moe.cpython-312.pyc,,
vllm/model_executor/layers/quantization/utils/__pycache__/flashinfer_utils.cpython-312.pyc,,
vllm/model_executor/layers/quantization/utils/__pycache__/fp8_utils.cpython-312.pyc,,
+vllm/model_executor/layers/quantization/utils/__pycache__/gguf_utils.cpython-312.pyc,,
vllm/model_executor/layers/quantization/utils/__pycache__/gptq_utils.cpython-312.pyc,,
vllm/model_executor/layers/quantization/utils/__pycache__/int8_utils.cpython-312.pyc,,
vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-312.pyc,,
@@ -1738,27 +1770,29 @@ vllm/model_executor/layers/quantization/utils/allspark_utils.py,sha256=-VRplxaZT
"vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json",sha256=HIoWSUgAOcNaK2kj2YwDjDa23PzQVTT2C2ePW985Ovw,3805
"vllm/model_executor/layers/quantization/utils/configs/N=9216,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json",sha256=gZCqNrWU0B8v6xL1Vgizi1K7fO5gChFMlxp2UbNkOZ4,3254
vllm/model_executor/layers/quantization/utils/configs/README.md,sha256=kfjjurECwd-xH4EDjuueS0Xezi86c_pYu2yELgiw8Ew,102
-vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py,sha256=b_P-NpzJT-f9M_etWKOGj97hQU1ZFx65_hnAccUswAw,20186
+vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py,sha256=RiXhdFi85Oy3rr0_KYzFMhFFJzpTRhU6B-59ObhU5SY,10291
vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py,sha256=Y1hwfirmxoXradoUqkJ3CeLzHY0KH5jELIDHwjeEN_I,10157
-vllm/model_executor/layers/quantization/utils/flashinfer_utils.py,sha256=sAAMKIgyDSoBIltyMdpIomZ6XafCzcKfi7AGD5wUswg,16759
-vllm/model_executor/layers/quantization/utils/fp8_utils.py,sha256=8QKemKKq3WrbqTQ1SEjGg7bH18pYSeAppw-0Utp5dyo,53968
+vllm/model_executor/layers/quantization/utils/flashinfer_utils.py,sha256=VuifuvCV8lgj6uR7F4pLtl9KQZmFZIo7veKvYY8Dtbw,13212
+vllm/model_executor/layers/quantization/utils/fp8_utils.py,sha256=5HKm6LG5zyyVpCcQQ7AQyLzaBBtbyhsXNV3SSFOzp8Y,54014
+vllm/model_executor/layers/quantization/utils/gguf_utils.py,sha256=UsgKgIbb1sKz2CB8J-q6k7BuVNtwN2zXXgXCCQNgP2c,14654
vllm/model_executor/layers/quantization/utils/gptq_utils.py,sha256=-0GnKH51q6vGtt42RhlbGRkvDI4Yhta4QBG1FlWWV5w,5886
vllm/model_executor/layers/quantization/utils/int8_utils.py,sha256=JI0KrskdcFxZ073TZ-1iLMOgPWN6Gc-BZIA415T8q4s,14424
vllm/model_executor/layers/quantization/utils/layer_utils.py,sha256=CvEUZEj2ZFt7oxcDaRklOJo0IFc-llCiA2bVjybRKFc,1574
vllm/model_executor/layers/quantization/utils/machete_utils.py,sha256=Wv5uFC2UuuoBC_X1YGFOORBf23Zl_Yp5Ck_hFQIbRUA,1699
-vllm/model_executor/layers/quantization/utils/marlin_utils.py,sha256=kIsJzQpmpvhFueA17kgAIUsC82WiJeyo6bPfMXLyXaM,21561
+vllm/model_executor/layers/quantization/utils/marlin_utils.py,sha256=sPn-SwJww6GT4XSgCBdU0KtH9JsFLighLOgqjwdkUPQ,21044
vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py,sha256=L2lfxNwj86u3R3ZClLIECuxpnG1VkI-gw7NY3lCeda0,19454
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py,sha256=U7Ne8oQRqyV4POZKh0XwINlwUcdV7V4otNFPao0etDc,13977
vllm/model_executor/layers/quantization/utils/marlin_utils_test.py,sha256=D8c0-obb1g9qQVAMptBTTF_2SzHzXR5sIV9ant1_-1M,7011
-vllm/model_executor/layers/quantization/utils/mxfp4_utils.py,sha256=mleIf8cK_qaxk_nehd0Y3d-pl3DplvdLjEa-ikaLQIg,6477
+vllm/model_executor/layers/quantization/utils/mxfp4_utils.py,sha256=ow7rlvA-XCOStp35_UPEB092xXfVeImobEcdTjSifhA,6423
vllm/model_executor/layers/quantization/utils/mxfp6_utils.py,sha256=3ccQ0AmjiDMVtIyZ-G5WZ778QDKbpbRqjv8vfNE9u5I,4533
vllm/model_executor/layers/quantization/utils/mxfp8_utils.py,sha256=aod_8kn4USLCdVojSAgPIv_pEvI3gQeopHA36hH5j_U,7496
vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py,sha256=rFpYWNmA1DD4I536b0U0FgjB4zBVDSpN6cpeldjssvw,4633
vllm/model_executor/layers/quantization/utils/nvfp4_utils.py,sha256=xQ9PdFxtD4_LpV7Uv_zym-y0Zt7O00M8jjjxpixaCmw,13078
vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py,sha256=wkKQfLry_jtVO5sDNBDJDajAOUP2YjRixRrYY_8kqpY,2610
vllm/model_executor/layers/quantization/utils/petit_utils.py,sha256=jKuC_wfRf4D55z9IgDsSJQC8ycsBI2SjBOePR8qxDy4,3841
-vllm/model_executor/layers/quantization/utils/quant_utils.py,sha256=ZqlU3Q4PPBFJqAB0HwglvzETTfhStlWUMiwk-8Q8pUM,27903
+vllm/model_executor/layers/quantization/utils/quant_utils.py,sha256=9nx9Tq_EZNcP3FPMDnumFQxVUOVwj57BgxM8GjkeBtg,27974
vllm/model_executor/layers/quantization/utils/w8a8_utils.py,sha256=ltM6EmBuI9dpUz6bp3xnEoab58z5AfbFYig-3y5z2YU,4856
+vllm/model_executor/layers/quantization/w8a16.py,sha256=Et0jZF5U-2tyRAy6cEuM0IYaZx-2Jyj5Ed8l0wjbk3w,3350
vllm/model_executor/layers/resampler.py,sha256=kK_gbASYVbvMoPPN3petUTfNGK8DkdNdihYnFx-JLn4,9877
vllm/model_executor/layers/rotary_embedding/__init__.py,sha256=PllRvEzr6kmY_ZJeGA9p3u_ubdEKhnbtjoHzD0CmgO4,11094
vllm/model_executor/layers/rotary_embedding/__pycache__/__init__.cpython-312.pyc,,
@@ -1779,9 +1813,9 @@ vllm/model_executor/layers/rotary_embedding/__pycache__/ntk_scaling_rope.cpython
vllm/model_executor/layers/rotary_embedding/__pycache__/phi3_long_rope_scaled_rope.cpython-312.pyc,,
vllm/model_executor/layers/rotary_embedding/__pycache__/xdrope.cpython-312.pyc,,
vllm/model_executor/layers/rotary_embedding/__pycache__/yarn_scaling_rope.cpython-312.pyc,,
-vllm/model_executor/layers/rotary_embedding/base.py,sha256=MuIUUjEBL5jIxhPMN9_mIZlqNQivd-5-AmW6795cETk,10129
-vllm/model_executor/layers/rotary_embedding/common.py,sha256=2s9Oiytr2WeKr1ebloMvap-MKho8Yh94YqZw2qL0zTg,8505
-vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py,sha256=nWQRNBkTYYj1yrMIqUCKzlXnTi9TvnEsgp7RNCtv4v0,6143
+vllm/model_executor/layers/rotary_embedding/base.py,sha256=mFFI_QYNResCyYVOBFVNZ-oK_83POWHSsk8ljUmsVG8,10158
+vllm/model_executor/layers/rotary_embedding/common.py,sha256=sHf-fXTKfjulR7st_Fj5bvR8GwpUkWl9gmPuvdv3We8,7885
+vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py,sha256=13K3gZiUXQ1o15jCqXq7XPeZ2YRsKA8BkLRlsMG8hjo,3474
vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py,sha256=vCIgcU7FgX8eymrhLHyZ2MXl9i5-kevyyE-J3ex9Mtc,8461
vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py,sha256=gkD_OmIqduxIUvBgqEqFu22dVDm4zzOmtJ5SgRQJOAc,1289
vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py,sha256=6lZQAj1v-nvRbZkLIR9_4a9nzAczr09YByOF7IXDr_U,2681
@@ -1790,16 +1824,18 @@ vllm/model_executor/layers/rotary_embedding/fope.py,sha256=RqRQj8sMSazROYGZjF9f3
vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py,sha256=5a1OVeRvwUeU35x7xjUbhpc_UNtWd2MRWlZ1QftnxkI,4643
vllm/model_executor/layers/rotary_embedding/llama3_rope.py,sha256=EMeHz5tB4hKDXr6ce4GFfFpt8x03HKYlmedwXUgDcHw,1785
vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py,sha256=ls-dUts73Wyp99QPkk4_ED1fewJVH05BighSzRpCmm4,3313
-vllm/model_executor/layers/rotary_embedding/mrope.py,sha256=1fxcpkC_R_38J0R7zGtnsilMmshAu8c2_nXMYJjipRw,14342
+vllm/model_executor/layers/rotary_embedding/mrope.py,sha256=p6HWG-CXP1TsteD3HrnJPJNwgsLI13DogWo4wZEfo-k,14501
vllm/model_executor/layers/rotary_embedding/mrope_interleaved.py,sha256=U83KHEt0F8tM4jjgDgpybXYW47A3csafhFOT1qXbdgE,6291
vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py,sha256=KzRfhq8fROqVm5Eu6ETmQZyvc8LZjIKzde0ZGcr8Pkc,1462
-vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py,sha256=AT89hmBGCmGUtX1hoy4agKZfYdHd98eJnE1fGlzkQVA,5699
+vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py,sha256=XSIA8_Ab5B2QqrF60Tjhy6xyd6eppPBzxcHy6HeUiuo,5108
vllm/model_executor/layers/rotary_embedding/xdrope.py,sha256=MDHhx7D3Lir4BWx_bhmOYIruQfBJCzsD-OV7g2sUfGA,5108
vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py,sha256=rZENzx8gm_Ysbp_OzOn0qQB9v-D1faYWz5ZILHJBbrs,2863
-vllm/model_executor/layers/sparse_attn_indexer.py,sha256=0KOwB0Pb7Ut5FrnD8I6qRlSot7K5lnKNQ_F3lmQdrKc,11877
-vllm/model_executor/layers/utils.py,sha256=nkb9R36XoorqHg0mm5bYvTr_gbEiY61Es7BVVdfSVpE,9720
+vllm/model_executor/layers/shared_fused_moe/__pycache__/shared_fused_moe.cpython-312.pyc,,
+vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py,sha256=0KI0seYhUNqYNH15p_siLdW_CHYc2hec6P6OYI8VmH0,1936
+vllm/model_executor/layers/sparse_attn_indexer.py,sha256=H8JHgeV4igtlXBG1Y_ICGcmd8hetsc-brpTYN7Cikms,17630
+vllm/model_executor/layers/utils.py,sha256=5mO4gzXtSA8o7nq1W3kpQ26IrpWe8mrf_SXFEgP6I58,11933
vllm/model_executor/layers/vocab_parallel_embedding.py,sha256=nf9sq3wXeYmRzEjoVPJ1prWWzGN-VpTxZsIM3GjTGHk,22192
-vllm/model_executor/model_loader/__init__.py,sha256=snspEqkFK3If_Ev_pvYP5fcIOAq49tnfO-HQ1ocsbVs,5019
+vllm/model_executor/model_loader/__init__.py,sha256=LUgqYUBeNGqh0zXN2Gh_VQWBvGgiIFc2ADWNpKBgfKY,5108
vllm/model_executor/model_loader/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/model_loader/__pycache__/base_loader.cpython-312.pyc,,
vllm/model_executor/model_loader/__pycache__/bitsandbytes_loader.cpython-312.pyc,,
@@ -1814,7 +1850,7 @@ vllm/model_executor/model_loader/__pycache__/utils.cpython-312.pyc,,
vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-312.pyc,,
vllm/model_executor/model_loader/base_loader.py,sha256=VzYyzhjqk_IgUJuJ9pJm6WYHggNyunYCILg8F2usd94,3113
vllm/model_executor/model_loader/bitsandbytes_loader.py,sha256=QjnWFa8OGprcFE0qI2njvjEEmXgdumgBfyypwRfch3U,34623
-vllm/model_executor/model_loader/default_loader.py,sha256=ms-p8MRpMW0vR1-0u5bCPDUZEDwLFNyljMC_aRxdiLc,11817
+vllm/model_executor/model_loader/default_loader.py,sha256=GbtJE43yTXJmMleW6E4-1XndD6zcZ8f0_Met0eeSYuM,11928
vllm/model_executor/model_loader/dummy_loader.py,sha256=R7xjrtLVejnd9j8zejScUTuY9zeAN6JlpBcNCfeUXJ0,1129
vllm/model_executor/model_loader/gguf_loader.py,sha256=cDZIzZ18naZghuK1Si6d6gxe-GyJwvHdQpA5sjY-qBI,15557
vllm/model_executor/model_loader/reload/__init__.py,sha256=h3WL_a1isSc2RKQBuEzD-0HaYslFIVWPsR7BMPCZVdc,1435
@@ -1836,8 +1872,10 @@ vllm/model_executor/model_loader/sharded_state_loader.py,sha256=KxhCy60owpwAzg8O
vllm/model_executor/model_loader/tensorizer.py,sha256=kqhUvxmpKEb2WKDQdE-8kHqsogZMvFtUpTmPXjH49jA,29902
vllm/model_executor/model_loader/tensorizer_loader.py,sha256=Gq7pYVtrOr8engSSWbW9x1qtoos5HuaS-ZrzWQHvR40,5900
vllm/model_executor/model_loader/utils.py,sha256=LQ6dgzLXaDDpPghN1VaWPtIjCBSap2DyAYWRGgt9Fbk,11020
-vllm/model_executor/model_loader/weight_utils.py,sha256=_AROcisvdtIP-Nopk_oRwecL7Qs0qd3CJmbxtjamv3w,53018
+vllm/model_executor/model_loader/weight_utils.py,sha256=wSjU-TUkeMLTBv-sRrAOulQWtQTnB7Nb28m4LTsw4_w,49996
+vllm/model_executor/models/AXK1.py,sha256=gliagXFSEpAOjNjYxy0RU5tzjGpDsceiS8E2bOCHNtU,45919
vllm/model_executor/models/__init__.py,sha256=JC8KFZavvBlB4woyRgOBd_vcuoiA31nmBOR5dGcJSjo,997
+vllm/model_executor/models/__pycache__/AXK1.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/adapters.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/afmoe.cpython-312.pyc,,
@@ -1851,6 +1889,7 @@ vllm/model_executor/models/__pycache__/aya_vision.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/bagel.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/baichuan.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/bailing_moe.cpython-312.pyc,,
+vllm/model_executor/models/__pycache__/bailing_moe_linear.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/bamba.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/bee.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/bert.cpython-312.pyc,,
@@ -1888,9 +1927,11 @@ vllm/model_executor/models/__pycache__/exaone.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/exaone4.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/exaone_moe.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/exaone_moe_mtp.cpython-312.pyc,,
+vllm/model_executor/models/__pycache__/extract_hidden_states.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/fairseq2_llama.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/falcon.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/falcon_h1.cpython-312.pyc,,
+vllm/model_executor/models/__pycache__/fireredasr2.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/flex_olmo.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/funasr.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/funaudiochat.cpython-312.pyc,,
@@ -2021,6 +2062,7 @@ vllm/model_executor/models/__pycache__/ovis.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/ovis2_5.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/paddleocr_vl.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/paligemma.cpython-312.pyc,,
+vllm/model_executor/models/__pycache__/parakeet.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/persimmon.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/phi.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/phi3.cpython-312.pyc,,
@@ -2070,7 +2112,6 @@ vllm/model_executor/models/__pycache__/step3_vl.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/step3p5.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/step3p5_mtp.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/step_vl.cpython-312.pyc,,
-vllm/model_executor/models/__pycache__/swin.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/tarsier.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/telechat2.cpython-312.pyc,,
vllm/model_executor/models/__pycache__/teleflm.cpython-312.pyc,,
@@ -2097,6 +2138,7 @@ vllm/model_executor/models/aya_vision.py,sha256=5XLbDFjKP3TMhhMEZKiPvtuvqR13m-B_
vllm/model_executor/models/bagel.py,sha256=RFmNsYD-cHvkJYu2t_lqcwsUzYf2Ge-xBQWPOrad1iI,20840
vllm/model_executor/models/baichuan.py,sha256=UDYfE2aKMTOLJVLcCa5tPOlGnpgulbYgSp42EYnUs98,17841
vllm/model_executor/models/bailing_moe.py,sha256=_mYgI8Wx_MoTTWaso0jZcC7LzH2J-OmbabLzxUJYiR8,23151
+vllm/model_executor/models/bailing_moe_linear.py,sha256=2cFzHeSqLXII49rzD8hBLiQ4MNGxJn_WIwFVdw9PXGA,44543
vllm/model_executor/models/bamba.py,sha256=F5BlJxQgw4jtqoW_D62rmAkVV_pdWMA_IFtMhCyM9nU,17905
vllm/model_executor/models/bee.py,sha256=VyADm_f5piyv7b0-_D_mEZpSCq-tRld4dmJ9CJT5be4,5412
vllm/model_executor/models/bert.py,sha256=2BWsUrhAPe7XDu-2EvlPnMzR5sAU4DnpC-dGoT0zxcc,30889
@@ -2112,47 +2154,49 @@ vllm/model_executor/models/colbert.py,sha256=9khjV3Pvo30Rn5ZRiRTZOBt8JaltTSy4Ct8
vllm/model_executor/models/colmodernvbert.py,sha256=cb4sak9EiKmNdWgL4xJYjM2PpTGnMCbwRCCZWXOW2LU,15895
vllm/model_executor/models/colqwen3.py,sha256=uaJdiDF7kCNFE8tst-MWmKfUQtI9KKUqCyjLV8FmlIw,12519
vllm/model_executor/models/commandr.py,sha256=nt2FhYi9BMJy-dNWdt_n5QrrKIYLi4yGcq1TuMaGR7U,17426
-vllm/model_executor/models/config.py,sha256=nJ--ydGzN2cVqpdlq5ycKoxHfVoU0Isl2RPdSCEpT7w,27285
+vllm/model_executor/models/config.py,sha256=8e8wpp_tSIaKCmTtgAzXEhQbM7DSH7rV9h7gyRQGTK8,29241
vllm/model_executor/models/dbrx.py,sha256=lotryU3Rj2LkGDvDh56OmpUpett9FAG1OulvJFndfi0,17459
-vllm/model_executor/models/deepencoder.py,sha256=dGi2ga5F1pIvKzDkVXTx4QYfw2Jg17eJgD5MheqjIIs,23840
+vllm/model_executor/models/deepencoder.py,sha256=fQpcg2l92QZxT_zZxwO6l26YCrKEDUVHZxI8uRL9jxU,24021
vllm/model_executor/models/deepencoder2.py,sha256=wX5MmEGmVcOzXCYdBvUcPECUpTaY9c4uQkfl2Yu3bcw,9295
vllm/model_executor/models/deepseek_eagle.py,sha256=godN3QkVdWqAcpDrX7lpHe_8tVFmvvThd_BAzcexk8k,9347
-vllm/model_executor/models/deepseek_mtp.py,sha256=71PGNDRN0Kdg12MfGDdwGJDmcHubVkxIN4jLAxBAmNQ,18208
+vllm/model_executor/models/deepseek_mtp.py,sha256=0oOlEU_fAzzw8SpFZpHzcQ5DFcbmfQpZfmssmUCZYVs,24697
vllm/model_executor/models/deepseek_ocr.py,sha256=n2e35GbvjPPc5ROT8pKEBTyo9iorCfdayb4d6MOLyuo,21479
vllm/model_executor/models/deepseek_ocr2.py,sha256=XTfawQbWR_pj3s9MCwzc49DpnQCZRPyJcH7EIo2sOF4,15521
-vllm/model_executor/models/deepseek_v2.py,sha256=gxB24k4gMtblkj1FmwyHu8BpORBH4bnsJkvV9YiuuPo,63828
+vllm/model_executor/models/deepseek_v2.py,sha256=zimuPBOFDeebwyhjKQNXmiHOWl4ti_ZljNphvKcm7oY,66562
vllm/model_executor/models/deepseek_vl2.py,sha256=Xp9LsW6G7d2uF-D94luBaFW3om-A8bD4twa4ufKctGo,22435
vllm/model_executor/models/dots1.py,sha256=BqFjkVW2NlrLoTDhY4IWSWpUK1L_C-nd9FFiFmaoNUo,20865
vllm/model_executor/models/dots_ocr.py,sha256=Lt5qSAqHgZ8jiV8hNBDzv3usUMFJMDMMUhZkmbRra4I,26938
vllm/model_executor/models/eagle2_5_vl.py,sha256=4qvPgo27uo2UzrMTZWN94Ffilktau1ZGT5wyGX6rhRU,16273
vllm/model_executor/models/ernie45.py,sha256=gxkle1tmFbYsDkQx8AjoYZ-2_e93ALUtHMMdVnPA_fY,2244
-vllm/model_executor/models/ernie45_moe.py,sha256=V1DB-RLpKkN-DxrkUw5uZYzbhFOb4NK8_dbP6CGOyE0,28140
-vllm/model_executor/models/ernie45_vl.py,sha256=8H_gqsqmjADNqEjetMKt8bkb2LTG2wIrBoagRrstb2I,60912
+vllm/model_executor/models/ernie45_moe.py,sha256=_-UdUvUiUCxcqh5p8HYLv79WjFdu0Prs3fSlUiXOkPg,28234
+vllm/model_executor/models/ernie45_vl.py,sha256=brF7polIt5aOwjkAdIN82smjI1HmQnn5NtBkKswt_mA,61614
vllm/model_executor/models/ernie45_vl_moe.py,sha256=B_xBD4bvbQsGbV3Vk1tPPpt9TAvtnirxuyfSo2M-vBY,30498
vllm/model_executor/models/ernie_mtp.py,sha256=Fxb09Xis5scfTMtOngTHbftAFvYLMCx4Ja4R5qW1Ov0,10528
vllm/model_executor/models/exaone.py,sha256=ZSt7DOqMrA58yJ0pUWF-_lpmN72cFTUsRKrDkq31EsI,18832
vllm/model_executor/models/exaone4.py,sha256=UVO950qM-kG6qJg8CQ7CTUz7mobl2I-Ki1Fd-9MiTtc,18753
vllm/model_executor/models/exaone_moe.py,sha256=L2jDqHY1KPMUzXTIDPfTbkrJQVOcDA21KpDnqHPVsrw,21895
vllm/model_executor/models/exaone_moe_mtp.py,sha256=84xCAROtS2_BNWcySKRFm-v7bUS1sjUsIjhoE6clH_k,9294
+vllm/model_executor/models/extract_hidden_states.py,sha256=EcCjVVgjSTj7XqMTpXG6CLHJk6gbmaGOsJ4pKj4iYDI,13218
vllm/model_executor/models/fairseq2_llama.py,sha256=dIadCgG1u8yASjN4qmOTvkli4xxu-Ygl3t0iL2Czsik,6375
vllm/model_executor/models/falcon.py,sha256=LqHvEA3eLJG5qmZyby1y94rGXhuWQxvWJNMqogPoWXg,20490
vllm/model_executor/models/falcon_h1.py,sha256=fNGP9Yl1rc8pa4rkBHaP0VnQZrZmRWm6Td6pPXPHhhc,24732
+vllm/model_executor/models/fireredasr2.py,sha256=HT2ZxZgm20a-xdRoOhWR0qZLI1aZ0w5n0L3a-ACOySc,27555
vllm/model_executor/models/flex_olmo.py,sha256=8xGT9dxqgsyGeebsCzlV_4Z9WsJvUv868bOyWEd21xs,5610
vllm/model_executor/models/funasr.py,sha256=13zyQqR7uiqiw4r9wuARByxKE-rGPDZSksva5MDEm40,34185
-vllm/model_executor/models/funaudiochat.py,sha256=TneAjOvWfTleNiTKnTsL_LQ5oQtZiid2hoWwc_QYKvE,42133
+vllm/model_executor/models/funaudiochat.py,sha256=_XJvJA7-qEeraJaKwzjRcw_bhPbPd2wd_tA98W3fJfw,38660
vllm/model_executor/models/fuyu.py,sha256=60pJLMtw1ahoISbFi-rocZrTD2UWE7lXAh2Imo11VCU,12790
vllm/model_executor/models/gemma.py,sha256=3Jge3UOI-LmSOKHOaUFH8riFZ4SNxDxsRHx4Zb4M4hI,15564
vllm/model_executor/models/gemma2.py,sha256=OFktjmniv2G9EvQVF9CjMCQbeGbcty7a1sHdNawj2QM,16385
vllm/model_executor/models/gemma3.py,sha256=T7ulpHlAUoIBwpQv_1EmU51GclNoglosMX7hbC3rDak,19521
vllm/model_executor/models/gemma3_mm.py,sha256=CPBhSYTD5bQnYRcBTncCGUnUuWnDE8QFIsf8zYgQVQU,23327
-vllm/model_executor/models/gemma3n.py,sha256=yc7muBjfrsDIOZFPUiK4ZbGasGZ8wVngTVlU6WavluM,44213
+vllm/model_executor/models/gemma3n.py,sha256=FxdCWLWN-nETIWtn3dg8jTlw4NPkJApkPIPwv1MWGDA,44277
vllm/model_executor/models/gemma3n_audio_utils.py,sha256=8QRbJkqTWWcTtQZ8q8w-jBC6WlslT1ytYX49Zqn-oxQ,2414
vllm/model_executor/models/gemma3n_mm.py,sha256=yFLFL2y1u4iJxwBfmaRNlbSKgwBmLZqrzbrnzhZFwyE,30115
vllm/model_executor/models/glm.py,sha256=Xnak9Z2N9i2PFVY9m9-ocRcocTEQCPow5mGFena0URI,1108
vllm/model_executor/models/glm4.py,sha256=6aVSxPSrNVikr9hXe4r5NyplJE7YKni26JwEAOsa2qk,14460
vllm/model_executor/models/glm4_1v.py,sha256=lPbnzKQrxGjHvBKjW9cBUTd_zPopnLy18uw5EbLIBHo,63465
-vllm/model_executor/models/glm4_moe.py,sha256=1mXD0vtG1JyzCAVsK1XepWZUyq7ulGSE4WkkHeDX1nk,27686
-vllm/model_executor/models/glm4_moe_lite.py,sha256=RNW4hradY48Mbh7aLF6MU-RRHkg4vTVY_pRMVT7wNj8,24704
+vllm/model_executor/models/glm4_moe.py,sha256=xW2YcmruADEtMUmvTJk4fmeBJMx1rCmeiDFMuxe-tj4,28067
+vllm/model_executor/models/glm4_moe_lite.py,sha256=DVw9ss8nwtAG4tIO_8TAc_5buu7xOF_CWoZmKTDbsw8,28356
vllm/model_executor/models/glm4_moe_lite_mtp.py,sha256=KBVnPqBX9_ebWB-B6p81wYgl-gEUkGqSu4bhpbDJTyk,18956
vllm/model_executor/models/glm4_moe_mtp.py,sha256=xBuVd-8HQf0dPcXdCSuVRkIizLbOpKBJ63r2gxEY31U,14543
vllm/model_executor/models/glm4v.py,sha256=hLOKgSaMmkppivK7PKlJoUlQsT_y-SjGysoBOitKCds,22760
@@ -2164,7 +2208,7 @@ vllm/model_executor/models/gpt2.py,sha256=hn1Ke4iG7EFl2Z08cDh_h4eT_H6-A_cDkBWzvV
vllm/model_executor/models/gpt_bigcode.py,sha256=Z6CXbWDnof42fumqyNk1gkHP6EM3HaaSS52OnFlYnVQ,12165
vllm/model_executor/models/gpt_j.py,sha256=rCYV9bJeMnNPfOSmprPN7SgIBsDn83wzGinGflYh98s,12788
vllm/model_executor/models/gpt_neox.py,sha256=WmjxCXe50QcSJ4BYnmy3vLFwx62cUxwlfp4xKp3s6QA,12733
-vllm/model_executor/models/gpt_oss.py,sha256=zAwPWoFm2Cj1HVw3ypn3dA1vj6KSHNtaF1T5IwL3BJE,49178
+vllm/model_executor/models/gpt_oss.py,sha256=ez6LZkZUljft_qX8delpij9KY9iTQCnKZGjQnXH6-kQ,50254
vllm/model_executor/models/granite.py,sha256=fhn0J0lU3X2JDaOL0XAXwLPJWkHYgafLNjoEgubw4uQ,17392
vllm/model_executor/models/granite_speech.py,sha256=C0PblOLooxXdRDC39aFHv55T6E3_HAmgS351Ok_NVq4,34913
vllm/model_executor/models/granitemoe.py,sha256=X8iv2scVPqg0RmGZbb1sHzj8ajIK9DKTB_l1ansVbc8,21148
@@ -2174,8 +2218,8 @@ vllm/model_executor/models/gritlm.py,sha256=TI4KV2jQeocNtwY-VF1gGq_66EvX1kPo7NNG
vllm/model_executor/models/grok1.py,sha256=tNciZn3JfgQ1xiKrtT_mmNcHgpPx0zYXexaalcvXwjg,29214
vllm/model_executor/models/h2ovl.py,sha256=-_UGnfuUCDJgKP0eLyg_z79fOg8aCsMHY2ewRYcv4nI,17187
vllm/model_executor/models/hunyuan_v1.py,sha256=FetuRvv6Ry8ksbkbZ1l9257j9YnllULpxo85aURtm4A,40081
-vllm/model_executor/models/hunyuan_vision.py,sha256=RUhknygqSk0dlEnFjwXfwxlWD1JLMEHH1GDv8gFUQcg,34868
-vllm/model_executor/models/hyperclovax_vision.py,sha256=UejOXTZcqP8ko8CkMQa1fHK3xyT9hX726eNrvvzB_ns,40079
+vllm/model_executor/models/hunyuan_vision.py,sha256=acNd-RjmKUNg3eg7z0rNgA9TOKRlFNjSjIE_KavwsQs,35219
+vllm/model_executor/models/hyperclovax_vision.py,sha256=yfAs11Bh5EBnukembCU7k6iaQgIp1vtlRRFidmAGTb8,40057
vllm/model_executor/models/idefics2_vision_model.py,sha256=sq8wgfKf8s84Vfvhf0QvYgti7EuDHY0BUjXkfo-gWbk,15354
vllm/model_executor/models/idefics3.py,sha256=LVJLfwIXkQd1CLUX9rgAgjAM8E0OEYI0pshqAVjR45o,24142
vllm/model_executor/models/interfaces.py,sha256=qfCq1h0UP7A5MKnpV_DSg2ToG7wh-9ZuN3mz_6SRkQM,45722
@@ -2188,13 +2232,13 @@ vllm/model_executor/models/interns1_pro.py,sha256=DMT6NCMIeKUpg2_Meulj-bNmJdtJVn
vllm/model_executor/models/interns1_vit.py,sha256=_xYnzI342n_VEGVD6Bh_paQRAkMkhMl0wAh9P0VFe8U,15097
vllm/model_executor/models/internvl.py,sha256=gUczPpJolPFpESSGfxQnbUC3axD69Zrbnh_YobrJyTA,49486
vllm/model_executor/models/iquest_loopcoder.py,sha256=eoGgtrYyWMN4GbBm_Y26Blj5UCghLOZ-I-_OjQ6FwXk,22512
-vllm/model_executor/models/isaac.py,sha256=tAROikBtYtDAnVLqIV16p21c545eVIogfgEdYrMZQ_0,53516
+vllm/model_executor/models/isaac.py,sha256=QH42oeAUyYTEHiBsTncCeE1gA57rvvKL2MRG7ihnnKQ,53517
vllm/model_executor/models/jais.py,sha256=6Gh19Z0QfJ_Sy7fyLpYB7ev9OmtEoyFBRvNbEijb8u0,14182
vllm/model_executor/models/jais2.py,sha256=Pxu3I7vI3yy3jTCeIg3NbPdMYFiy7sqM6fccKKXejE0,18586
vllm/model_executor/models/jamba.py,sha256=MA-WJEfGyTWwgzaBDBvWLcBA7Nnq6OuejRc-X61FTsc,21509
vllm/model_executor/models/jina_vl.py,sha256=zp_BVbAQ5toYF6_pwlzTk8qsh9ezG0dW4IgBOaeoVtk,5084
vllm/model_executor/models/kanana_v.py,sha256=IORzJfTJBOSQfnkmpS7BCtnN5CsRyNIOP5QxGOoGRZg,27312
-vllm/model_executor/models/keye.py,sha256=JyFkuxrek5__J5p85I_dwkxmpNZm5Y3m93tL1QQE2c4,59197
+vllm/model_executor/models/keye.py,sha256=uEINJwt2PbcHWZghTfLxXkyp7qNv9wWP_bP-b4v3NB0,59543
vllm/model_executor/models/keye_vl1_5.py,sha256=a4ZzxhDIqDSU3J3GtzPm-I22lw34s8UzLgpFSoqnpJg,24967
vllm/model_executor/models/kimi_k25.py,sha256=tRvAZV2UxwxeaSvPatl0qncAFCRn0z8wx3RmpQB4MtQ,17313
vllm/model_executor/models/kimi_k25_vit.py,sha256=LU_0gXaYzaWo48mEXGxePpf-6KeMfxtKWwX7jd2MWjc,23163
@@ -2223,12 +2267,12 @@ vllm/model_executor/models/midashenglm.py,sha256=qKLjwkuRF4rX3lgsXeaAc5vCuqW3uGz
vllm/model_executor/models/mimo.py,sha256=oMHhy2sKkX4jrQUE4wxo9Ogk7ftrJ7PHJAxDzKaCfAs,7339
vllm/model_executor/models/mimo_mtp.py,sha256=c763S3lq5jzNswJ4TeyWlZ4YEjNEkLqt9qKKAP4uUxM,11068
vllm/model_executor/models/mimo_v2_flash.py,sha256=sYAPzCmLc_W5gKPGKjSwNnN1-4VDc0_oCGxjf_2NRJM,25791
-vllm/model_executor/models/minicpm.py,sha256=onw84qqdkTxWInx271tkZ1Ed6EAyk38px8ftGmTNAYE,23790
-vllm/model_executor/models/minicpm3.py,sha256=lA7ah336BFyl-3PKNHGEDqvdNdydk3CpEp2yJHRKHRc,8585
+vllm/model_executor/models/minicpm.py,sha256=UZGtCej1sXoS3QS5Gd8-jQvmnHxFjDkKKmlt4e5Yts4,23789
+vllm/model_executor/models/minicpm3.py,sha256=5PhiB4l7f6mlkGumDtiko6GdjijI2AqVjSOPcIgR88c,8584
vllm/model_executor/models/minicpm_eagle.py,sha256=-Y-N1X76_Crvq1OMd_r-2FqPfZ-d-5-OYbOhgZYHI1o,14435
vllm/model_executor/models/minicpmo.py,sha256=mlRcavEVtZJZuUXuXILevcdsJpcYcPsPXhEhgeRHkJk,30605
vllm/model_executor/models/minicpmv.py,sha256=igRUuvDYijPdX4OE-sSA37NZ251azZbb6qqT6ZyVYbQ,59008
-vllm/model_executor/models/minimax_m2.py,sha256=4LEPOv737S2cLwTTy27Ye4mNmaw4BBkcZLbGl4283t0,20952
+vllm/model_executor/models/minimax_m2.py,sha256=I57RFtBryk3ryVYRmQPUPewNTHUGmXodl34GwVg5JhY,20919
vllm/model_executor/models/minimax_text_01.py,sha256=KXXkFsjYQketFw_dz3uTVDtNJkq_9aULi5f1ypQ6XVk,38245
vllm/model_executor/models/minimax_vl_01.py,sha256=pzFFB2OvyRsx4CcRhQLUlkpwEXIHqFql-Vpqg7loKq8,14147
vllm/model_executor/models/mistral.py,sha256=YYklS51bg-BRZv2dNTODfpjXOF2jxtphb_xO4EZIhv8,11706
@@ -2245,13 +2289,13 @@ vllm/model_executor/models/molmo2.py,sha256=q26NYVOKkLzjZIwAkhCecE7WLONbl-gyJFsT
vllm/model_executor/models/moonvit.py,sha256=0kIg-6OIbPmRWWnNCuWcMSLgD5GN623A8c5f_EtOeQU,21491
vllm/model_executor/models/mpt.py,sha256=hT2CtvYFxHRbmGmkDp2fmZ2n5C7OWKuVgTFsbkml0t8,12149
vllm/model_executor/models/musicflamingo.py,sha256=W9K3xxiB-hbl07faLAZ5Tl_XaP2t87tZPhor4oPq8os,2383
-vllm/model_executor/models/nano_nemotron_vl.py,sha256=PdswnArosgwi4KipZEoGNdbAkIkAYCFmHPVPI2FFvwk,81121
+vllm/model_executor/models/nano_nemotron_vl.py,sha256=ncAcHPfWsiUIsaOUOmQO3qSrMWz29Q2_B9w3UNlcSHU,86755
vllm/model_executor/models/nemotron.py,sha256=NFaGuOxfWOJGbR9s1EjqWjHAf4V40P5LYvOAzLgoiTw,17952
-vllm/model_executor/models/nemotron_h.py,sha256=j_pTGXNsUVKRyufxV9q0Cc8jRCcVdwNfHn7YRxBNLgk,35030
+vllm/model_executor/models/nemotron_h.py,sha256=NZkUNMkfqasFpPIjFH1xl5ZC6uMuPacAV8nzv3XAwGI,35147
vllm/model_executor/models/nemotron_h_mtp.py,sha256=HKJFEN2ue2_vq5aym9HIiQ023j5v0ONq6WH-N7g4nww,17912
vllm/model_executor/models/nemotron_nas.py,sha256=Er37_CCWpEreyAQKoDujVmDTGDBZPupfUBnJNluK63I,17101
vllm/model_executor/models/nemotron_parse.py,sha256=K3Qr4c3Tc7jCtkRD0Z-uqduYsukIyPjzNmJ3UHCUde0,32172
-vllm/model_executor/models/nemotron_vl.py,sha256=ryzL1J8mKLlAovTs8cgP7Xr6oCCWVNShGbZv3hWSZhs,22405
+vllm/model_executor/models/nemotron_vl.py,sha256=nXwin_4Spok5Krdf0JhUB-4TU8i_qn04O3LPwnpj-nA,33620
vllm/model_executor/models/nvlm_d.py,sha256=piXDDA8xuNAWfruEPTHbYj5X5l3YZ_TJ1EaCvNXWNgU,7598
vllm/model_executor/models/olmo.py,sha256=5TXW9R5WZYCIr9ZXm340uIkrXJolsmorMAmTNt83S4A,14357
vllm/model_executor/models/olmo2.py,sha256=cd4UEDpwHV90pln9e8m-HmvPROAczb54qTP9q5Cy3lg,16618
@@ -2264,9 +2308,10 @@ vllm/model_executor/models/opt.py,sha256=VovPe87bR_iDzkUMnTiRiO8sq6pmzRbARTPDXy3
vllm/model_executor/models/orion.py,sha256=3-7b_q8dW9C8mbHdF59rp8bnJxUhLpp95nmIbs566VQ,13242
vllm/model_executor/models/ouro.py,sha256=FUcWm2suD7ROuA-RAcjHRe6ae-NRk_UPMTML_YceuBA,18806
vllm/model_executor/models/ovis.py,sha256=f-F1pVovgx_b7JegMEhHgXgFjZhbY2UyJXFinrWoRBY,20489
-vllm/model_executor/models/ovis2_5.py,sha256=L-XiG74IOJvbsVmj8TaVkxh-jNMZ2tLaa8v4CAb_SYc,24277
-vllm/model_executor/models/paddleocr_vl.py,sha256=0vaw2w5w8faNdmVgHuqHTaBjX8BrscxOsfCvBZOuq6A,43231
+vllm/model_executor/models/ovis2_5.py,sha256=FTR4GVuwDrlEMkncXpoBQZoearimbqHa-84Wgky1yv0,24189
+vllm/model_executor/models/paddleocr_vl.py,sha256=dTa73tvoy6Kicpi5_Anw8HtxcgYrGn68mDpcP83ht9s,43933
vllm/model_executor/models/paligemma.py,sha256=xZedgaDERYrwBP5OyTx0D7veWPo97Oj07J2NjSyFNoo,13873
+vllm/model_executor/models/parakeet.py,sha256=vwq0mgj_Y0EzMSoUmUb7BDlP9C-HPSmcFInb54Jw2Ts,5774
vllm/model_executor/models/persimmon.py,sha256=ZrT26c5oE6YXyuO9Lu-qFfFqwh6QQKHYxGtz-Ii12FQ,13620
vllm/model_executor/models/phi.py,sha256=CS86mMtB-SS1d-fxzKLu26EgvcE6CwyCcMYgx_jgdFw,12987
vllm/model_executor/models/phi3.py,sha256=N380ApMP7_3VL-rlWeVD27SUjC-kLXbldzYNWgG8ApM,456
@@ -2274,32 +2319,32 @@ vllm/model_executor/models/phi3v.py,sha256=16lU2Ry4D9HQWDXTSHgf_TYOYm-H9QrGOR3vk
vllm/model_executor/models/phi4mm.py,sha256=fY5utJ3CmhehRhawB2ITPdtsCZtZAfPgaXc_bEA6K7o,45373
vllm/model_executor/models/phi4mm_audio.py,sha256=4la2vPrfGNSC0Y_LAD5qeoJDOk514IXfVWij6_z_e70,50426
vllm/model_executor/models/phi4mm_utils.py,sha256=GxjDFelDuVA-jdf9HBYUlaXNfgs62ebCAVK80CFzqTo,66152
-vllm/model_executor/models/phimoe.py,sha256=BuUlupn3L7i49USU0tTuCEjOKNMg0EhD7TYLdF0Ki-8,23530
-vllm/model_executor/models/pixtral.py,sha256=ugP9TulFlcm_ynoJnld7bC1eRJ4YEL5j-DE2ZBuMiEc,49151
+vllm/model_executor/models/phimoe.py,sha256=4AZ8TQfyOrTYdhRgFs3YA_g8VH5Dq4ZqqoeAUWqo1KY,23631
+vllm/model_executor/models/pixtral.py,sha256=pbs1OE7xfwwNq78xozpcoRuTPcqwXAGQ0jUYiW7aUOc,49305
vllm/model_executor/models/plamo2.py,sha256=UH6XNDM0ll6GznSqOHSaLMRemk2NkHIoofflHYeD9fs,37903
vllm/model_executor/models/plamo3.py,sha256=CbqsDUY59dgrr5j37aT4_74mMZKYe3EEntzhHsExjpo,15850
-vllm/model_executor/models/qwen.py,sha256=hK3zUTqM8oRU2x5hl1CMpVulQpcXyVwFSu26QJdR-XU,13274
+vllm/model_executor/models/qwen.py,sha256=n0801rEic97cw6YOetUbBi1iYitoOBbiE1FW_Y7btAI,13282
vllm/model_executor/models/qwen2.py,sha256=-UoJnvCi1tQUODCDbPooJJesEDtOuiGnPat0eQYUvP4,22235
-vllm/model_executor/models/qwen2_5_omni_thinker.py,sha256=bWgLbLPXox9U0yVY0vwe8SqunBJn-II4SW4HLkm1wBY,55350
-vllm/model_executor/models/qwen2_5_vl.py,sha256=PDLj70FNFMHVw0L4O3iiFrsP-6_iRT4NW4xO6bK6oO0,53962
-vllm/model_executor/models/qwen2_audio.py,sha256=cY4Jy1r7EsVsWxBfFQcUbRTXPhi-O5hJ2bs8fA5okC8,16642
+vllm/model_executor/models/qwen2_5_omni_thinker.py,sha256=b7Htse1Mb2UgqO2EwMB4bn1Wdets0l132nOraCLoLeE,57875
+vllm/model_executor/models/qwen2_5_vl.py,sha256=LBG2zprvqKNkfDyKnB8yWmmGbnxerp9-Qvjmi5nS1Fs,54596
+vllm/model_executor/models/qwen2_audio.py,sha256=g0veswhwz4mWEZIDFHkgru5AdWiz_LwMOMSrFR_fJB4,17394
vllm/model_executor/models/qwen2_moe.py,sha256=WIma4Bczak6gDjd4h4iWnhemF4wROWFIasyp5lafSo8,23264
vllm/model_executor/models/qwen2_rm.py,sha256=qDtLznXFur9l3y6sQSeG_AMNIdMken95UHdKBbgVG6E,4120
-vllm/model_executor/models/qwen2_vl.py,sha256=kStWdKZsqnimKdxSBGqj1wkW2qFFBz4QicKd7LWJC7g,53333
-vllm/model_executor/models/qwen3.py,sha256=r7q8DSESqiHQjcI51rjIYHvFt1d9fB4MUzjRE11b9rY,13294
-vllm/model_executor/models/qwen3_5.py,sha256=RqkiCHK8zkTiPh7n8B80o-eebWWiPGCrE628UMo8_dQ,32317
-vllm/model_executor/models/qwen3_5_mtp.py,sha256=wlktoP4-wYSR_eHLqq-O2KSgzzD2_UqugeZ4TkPSzTk,16574
+vllm/model_executor/models/qwen2_vl.py,sha256=vaOZPPwEvbNCV-msL8XmcYiDR0TLFA_OutIbB4IkoCk,54141
+vllm/model_executor/models/qwen3.py,sha256=_WYUj0_aBKUDmJ7i4-1H-nyLT-AcoM_lpp7v_zJl4mc,13397
+vllm/model_executor/models/qwen3_5.py,sha256=jHhBKnHzl-rPUovQkg0Y6hIPqZ-VHu_JI62P8VxH6_k,32218
+vllm/model_executor/models/qwen3_5_mtp.py,sha256=6yjQt59R0VhrlCCe9toyScBmEM9cpcZ4Sq4lI2z83vA,16574
vllm/model_executor/models/qwen3_asr.py,sha256=H2s7E2svPp9Gpy1pKEIKTOzl3Y2mbQxn4cW19mNPd8Q,21181
vllm/model_executor/models/qwen3_asr_realtime.py,sha256=6k2QkZZ9E6wb5JEk9zb7MkJ7mhmquGzysEW1Nrqlcps,8879
-vllm/model_executor/models/qwen3_moe.py,sha256=iMmeP73CF02m4l_Mhk8OU3Mu6JVeoyv0yGbWoZagQqg,31302
-vllm/model_executor/models/qwen3_next.py,sha256=WEqxvBk95T5lRBaJ6BWjb_YgH3yXnutAKZzBSGDPMqA,56101
+vllm/model_executor/models/qwen3_moe.py,sha256=ue8JOXWEGoD4vUo-AJOtbv_EW0_k36njr-9tTacClxk,32314
+vllm/model_executor/models/qwen3_next.py,sha256=sDhPnDxtmyA3jISo-Nr7VyaayVYO7TMQ0GAhlAsakGg,62735
vllm/model_executor/models/qwen3_next_mtp.py,sha256=lMY5ULOAj2ykUsLFsQTnSx2-wfn_Z5Akwdw2g41Jtzw,10668
-vllm/model_executor/models/qwen3_omni_moe_thinker.py,sha256=8J-LV7hkxkx_gplEwALhMDwtbIFdksmFnNG2O3ETD7M,88291
-vllm/model_executor/models/qwen3_vl.py,sha256=k-9DiyzpFJECpyaYjqgRNZDTMIbHoYJ9Z1JqK5yA90w,81041
-vllm/model_executor/models/qwen3_vl_moe.py,sha256=_RQDu20sMXoPnKi3FslBORFgFl8V8LKjNyRupcXfr-0,19664
-vllm/model_executor/models/qwen_vl.py,sha256=dpe4J6aIOynwdnP3oMr43g8Okk9Qwuq0Lrwkda0jMsQ,25918
+vllm/model_executor/models/qwen3_omni_moe_thinker.py,sha256=0K3PbENWnKowRxqY59fIIxZ6ZYC0LY-LYlDpG0mdAHY,90452
+vllm/model_executor/models/qwen3_vl.py,sha256=BGwW4P8fzP6sKpZX6bUZP-jKgUNSdcizPX3VSmbts1U,95398
+vllm/model_executor/models/qwen3_vl_moe.py,sha256=P2W7NLHLT953mQDhrNFP1gm_hJ4_ILSIWrFRnl5jLOk,19890
+vllm/model_executor/models/qwen_vl.py,sha256=F0Ubyo0LslTfWNXTta-gbKs2gII7FjAJAIkt6tBfrCY,23987
vllm/model_executor/models/radio.py,sha256=46UFQzPd4bkoyZwt03rV8iM9Ld6YN4rXjwlt1FZNePk,25512
-vllm/model_executor/models/registry.py,sha256=V-IHKusPALbPvNiQsVrnpNk0ca2urn4D4ZE6rfrhqN0,51521
+vllm/model_executor/models/registry.py,sha256=RohMN3KxCV2iFc9iDy-jz4uMA-ehAjfys5xf4QdhZHE,52098
vllm/model_executor/models/roberta.py,sha256=kY8L-R4SqEbL7cGc_En3LMTyNVN-NkfHH8BITXjq4Gg,12952
vllm/model_executor/models/rvl.py,sha256=x0LpjLj9DPzT2ErbrxM2ZDZgxM84QMD-5m1ROp8Ino8,3529
vllm/model_executor/models/seed_oss.py,sha256=3au3bmOI5W0hYE_Qq1ExX8NJrktd-17Kn2pt9c1w_e0,17901
@@ -2311,15 +2356,14 @@ vllm/model_executor/models/solar.py,sha256=cp5LoWIxnQZ2_WFjeb6YdkTOpaMevW5lxw2XJ
vllm/model_executor/models/stablelm.py,sha256=qxsgs-JAfeR0UXOv58PAaBuEO8oKmU_PJe0TU46ivl4,13595
vllm/model_executor/models/starcoder2.py,sha256=3qw3Rg7K1qyNzrzXkuP1aDGqj4b7D1WxLhzOoi_Yvhk,13401
vllm/model_executor/models/step1.py,sha256=1hH9WJwssqYIP4wF9Nq4HqGjnKVwhllCXQcO6a5GYSc,15190
-vllm/model_executor/models/step3_text.py,sha256=Jho9rywZWBQn4cFeRbLbCegRxhMEOrllUsei3y4MZZM,19894
+vllm/model_executor/models/step3_text.py,sha256=I9J573qe95GSW72sstQZNrZOBr74NMCLAGZeEL8AlZg,19905
vllm/model_executor/models/step3_vl.py,sha256=Pakr_mwr9FnsSzALSDgIHJBq9oXKIN-6S13yKsYfW4s,39484
-vllm/model_executor/models/step3p5.py,sha256=6r-PcRDlGXE_jdJHsIQh8AQBgesnVdUd1BqegOwHBDA,37688
-vllm/model_executor/models/step3p5_mtp.py,sha256=lWs3tSxr_sYT4SDisJaBSq09Du2-Er9P9xRkGHK0KfI,12542
+vllm/model_executor/models/step3p5.py,sha256=oMQMKsqifSj4tMXPgPEo3JsZDZodmCVQR9Rriq-GWho,39071
+vllm/model_executor/models/step3p5_mtp.py,sha256=prsiPktOD0zQ4TemNe4ybXdlppaUMAsbB8sRg4Y3OI0,14279
vllm/model_executor/models/step_vl.py,sha256=464Mjc6zeUApfPiNC04yPBJzNBw_xS5VMIxjOFW35Y0,18859
-vllm/model_executor/models/swin.py,sha256=gKgBbW4e1gkSkUw_2OlqvLyijWGPkPu6YBYvfkKZ-ro,16428
vllm/model_executor/models/tarsier.py,sha256=7KdUQBig2GNlnP4i1KqZZuiH0CAsGyh-mLIRuZNyRvQ,22354
vllm/model_executor/models/telechat2.py,sha256=wOEquSezPeLg2qqUNZ7ncsIzpfxHxRCrFoproADyARE,6245
-vllm/model_executor/models/teleflm.py,sha256=vlTObvAPJAJhkhy-8Snyx_7R1jGqAbYU8OYbWAPgvh4,3030
+vllm/model_executor/models/teleflm.py,sha256=AQvwDlusYU4VyDuOO5_uj2FSyoD4pB7cUx9GLf4ZG50,3104
vllm/model_executor/models/terratorch.py,sha256=ZyRLssdTyADURq2uHROxn-ejp3mgnQwzDg4_aaa_IzA,11086
vllm/model_executor/models/transformers/__init__.py,sha256=KR0hAoCBU4oqFfOpZ1kT9u80zQ_s-Z6HWpTfkvra-iM,4264
vllm/model_executor/models/transformers/__pycache__/__init__.cpython-312.pyc,,
@@ -2330,15 +2374,15 @@ vllm/model_executor/models/transformers/__pycache__/moe.cpython-312.pyc,,
vllm/model_executor/models/transformers/__pycache__/multimodal.cpython-312.pyc,,
vllm/model_executor/models/transformers/__pycache__/pooling.cpython-312.pyc,,
vllm/model_executor/models/transformers/__pycache__/utils.cpython-312.pyc,,
-vllm/model_executor/models/transformers/base.py,sha256=CB8v6-tcJHPjPda8X5b55b0H-nWaLKR5_Vl6VrDY5ao,21105
+vllm/model_executor/models/transformers/base.py,sha256=Ans3HGdUNAg4R_6AkXvqwxeZDk8e2NuyDkImpgGdQvk,21881
vllm/model_executor/models/transformers/causal.py,sha256=lO_fPotGi08ZcpwRMKeEP2DKctafVjhW9XZAvNA-S7A,2661
vllm/model_executor/models/transformers/legacy.py,sha256=NY96kEaStY66o4Sa3tOFCqLhhwgiIf3eO2dnViXowIQ,3229
vllm/model_executor/models/transformers/moe.py,sha256=ey3ADmdqPZkwTFdQ5avKdC5gXVp_z8dagNg28uO9GIc,14328
-vllm/model_executor/models/transformers/multimodal.py,sha256=U70KCN4N5PmsX1md_6BYw_d7RNrxN-fZKb_X4uneJ-A,19289
+vllm/model_executor/models/transformers/multimodal.py,sha256=wFVkVDR9zBwfE-ZLBA_sBOe_TWE32tGNJddL7_WMq9E,20313
vllm/model_executor/models/transformers/pooling.py,sha256=SS_1ciigE75k4WB5IthheWqYLZPeIGZWiT3tvKCmaqE,4077
vllm/model_executor/models/transformers/utils.py,sha256=HVyyfp_2Y9XPwg15SacdGFlpbWSK8abHE1ZQRFb4-I8,8825
vllm/model_executor/models/ultravox.py,sha256=mRny8WOuDkOToEvAi_hx_NdY2gkLPc9Bom3pGL5BNbM,29549
-vllm/model_executor/models/utils.py,sha256=beOOPhvdL-75EhwiAbcUlg8kE2ApiHOp4-YKQ5TwqKQ,29424
+vllm/model_executor/models/utils.py,sha256=ThsVFW09be9dVIoReq4WFNaMJa9fhbKRAcS2P7cWHQY,29149
vllm/model_executor/models/vision.py,sha256=0jQuAx2aSJKCS7ZuVCB6gfFWbDqWmfL522yjmWnX66E,21391
vllm/model_executor/models/voxtral.py,sha256=QTP6ebmVb3uwzABO7rYzEmCQCjty0N811hBloLP2NJE,34824
vllm/model_executor/models/voxtral_realtime.py,sha256=lbMt7I0oDjUk8KjMdvFB6q451phLJo9vDW9S6l6py-E,18997
@@ -2357,14 +2401,14 @@ vllm/model_executor/offloader/base.py,sha256=Eg89-UqJ22EGGYYPQhSoZpXoWyEq8dfTtYZ
vllm/model_executor/offloader/prefetch.py,sha256=GnH8axqlOL2AJjaUFo9YSBJAoc8RXmcWDXPi35a2fd4,28295
vllm/model_executor/offloader/prefetch_ops.py,sha256=DB1mKrQ91LBO4pEIRwI6_pEOdamCX3hn2vHrrR3VvxI,2545
vllm/model_executor/offloader/uva.py,sha256=NaWXzOKNus3zNgqpL_cmNB55SnT01VddiKQGsykfE34,5220
-vllm/model_executor/parameter.py,sha256=E9TVgals5xAY4jx0mcgbi90hJygCJJL4qjb7gU5sjvE,21945
+vllm/model_executor/parameter.py,sha256=Fs6RjMq461a8eAt7lJP-497YU8Zcu9UD6aP09NbYC_E,22429
vllm/model_executor/utils.py,sha256=WI1ei6LPUsDfoZaw7fgwdSlVEfRkKegN6qK5rsKZHac,4542
vllm/model_executor/warmup/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/model_executor/warmup/__pycache__/__init__.cpython-312.pyc,,
vllm/model_executor/warmup/__pycache__/deep_gemm_warmup.cpython-312.pyc,,
vllm/model_executor/warmup/__pycache__/kernel_warmup.cpython-312.pyc,,
-vllm/model_executor/warmup/deep_gemm_warmup.py,sha256=OC6sL7gNXDU9lgmktF1HGNEkjKE0kVYH34NqpnYe-Eo,13025
-vllm/model_executor/warmup/kernel_warmup.py,sha256=_m1mE9tiMgcqja1kvjxPWg1byGJndrRm87sqOa2o8ws,3755
+vllm/model_executor/warmup/deep_gemm_warmup.py,sha256=rhnLDXmHTxV7_iymb_H-JHX3n2FFYUuf3_BJLL5Y5uA,13029
+vllm/model_executor/warmup/kernel_warmup.py,sha256=VSL0EIGcY59NQjgefYM4u5UbgqS3xzmr1lZ1SYwo2fo,4036
vllm/model_inspection.py,sha256=99U1s9FBinjggNPt5NvHvabMcIh3IK_dnXLgbFYC_jc,4743
vllm/multimodal/__init__.py,sha256=tnAxNnu1kVSEBizAbOTSXQ2zpRHvc_ZnyxORjLjmGBc,985
vllm/multimodal/__pycache__/__init__.cpython-312.pyc,,
@@ -2382,7 +2426,7 @@ vllm/multimodal/__pycache__/video.cpython-312.pyc,,
vllm/multimodal/audio.py,sha256=sS5AeYnH01hdM2TRzgAHHLgdPrTL2b3I9H5FsfTxMCc,11139
vllm/multimodal/cache.py,sha256=uZushqgTUTR4F-m__CFuVFRH2_BC_sF-j2VX0oFu1Bk,22324
vllm/multimodal/encoder_budget.py,sha256=dqryW7h1X3Iyxfic0VFzorsI2M2qqhB5cZmklEiltJs,7111
-vllm/multimodal/evs.py,sha256=fRGDSsF8VNjS6c665IH4SgKZd__RKIll-sEeDZP3Kyg,10886
+vllm/multimodal/evs.py,sha256=cs1ZvdSx7r3QpQDZjDDQkMvLqLdv_LmDB4mzEve0Adk,14494
vllm/multimodal/hasher.py,sha256=AET4pGTelW9oVmHIKGGlSo_aYMv9I5cbQ9fjQcxmZos,5340
vllm/multimodal/image.py,sha256=V7Ev9PMZt4MqN6bXefVCOTe1_RdZPlyldtkstzN-XH0,1192
vllm/multimodal/inputs.py,sha256=zcJhUiqaiLV3KANARgLRK2u2UEpPj6LN0qb8IXR-D1k,35665
@@ -2408,9 +2452,9 @@ vllm/multimodal/processing/__pycache__/processor.cpython-312.pyc,,
vllm/multimodal/processing/context.py,sha256=0xnFAD3iAVIYu5gbnpS93lWenQ_EFounyIPNkWpN9Lg,16378
vllm/multimodal/processing/dummy_inputs.py,sha256=C7wgoX9gmrtWsur4EPFB-FuzlC_9z_Zr6h6pHQmXdQU,6281
vllm/multimodal/processing/inputs.py,sha256=JCMdwmCD_w4XNvupdPSSFCkxDxuC-_AOU--IOU9ZO24,2798
-vllm/multimodal/processing/processor.py,sha256=PwZ4_2wjCw7nqdjoDHvigF7I1J52xPmiWRBYTVF2tmA,58929
+vllm/multimodal/processing/processor.py,sha256=V1lIReMzRuHG_2DsjHBHQNWpEv6aOGG9sEg_wiRXqes,58563
vllm/multimodal/registry.py,sha256=IX6LWAFYnZuKo0bF-SVpa2hk7nkONnnWOjm-IhyWLWA,12284
-vllm/multimodal/utils.py,sha256=vW2Md85sh_vygLaDjLZOY89IFsdYUh-KcFeR-gwpUYA,9832
+vllm/multimodal/utils.py,sha256=jLnnfxOa5ZMhlnanpQkTKNeKRJ9Jjr63XcoifWvg8LU,9281
vllm/multimodal/video.py,sha256=CiILeXYUrjyW7WSkd34AEtpYvtVpS8jeCA_SmsdslCA,30131
vllm/outputs.py,sha256=mRwJYHcKYZBmzL9GnMCM4elSPn_hnjoRbEzRPB4XUKA,12628
vllm/parser/__init__.py,sha256=tISZkyTnsqET8CWO-_w9DKGfR016aJRuy_7e00GoVow,920
@@ -2421,7 +2465,7 @@ vllm/parser/__pycache__/parser_manager.cpython-312.pyc,,
vllm/parser/abstract_parser.py,sha256=1vzOQp-KTgkhFQxlHkyqMsdE7VhwZFC6jNDr3lhSfUk,18946
vllm/parser/minimax_m2_parser.py,sha256=7ITK9L57iW60y18H2LLxJuCem2FL3fvflhogeo-gszU,1908
vllm/parser/parser_manager.py,sha256=AMWOnEKFSwaKx6hhxQDnZuZBpVxMEitD85jGvXp__o8,10514
-vllm/platforms/__init__.py,sha256=yVDxm_6yxNOgxHqxWbytL1k0gbvzSCLXrjTBzdDMqfc,9727
+vllm/platforms/__init__.py,sha256=-RUbNJzU1WSuvMPAhNh5g9NSKOl6W6OAA6u164WqHG0,9681
vllm/platforms/__pycache__/__init__.cpython-312.pyc,,
vllm/platforms/__pycache__/cpu.cpython-312.pyc,,
vllm/platforms/__pycache__/cuda.cpython-312.pyc,,
@@ -2429,25 +2473,25 @@ vllm/platforms/__pycache__/interface.cpython-312.pyc,,
vllm/platforms/__pycache__/rocm.cpython-312.pyc,,
vllm/platforms/__pycache__/tpu.cpython-312.pyc,,
vllm/platforms/__pycache__/xpu.cpython-312.pyc,,
-vllm/platforms/cpu.py,sha256=eK2HRaEKigFmFwChBOjfTdeo4SOsR3e0PdF_lTF3gLg,18338
-vllm/platforms/cuda.py,sha256=eqGY7NXVVP3CxWNAgtyUfSYqC_Vn8qW8vZ-UNGzOUpc,25508
-vllm/platforms/interface.py,sha256=BTxw-MJJnYQ0vC7w-pAZoaR7mEFX3ZMeJ-1GTEk1HN8,23504
-vllm/platforms/rocm.py,sha256=yIh1P_PSnf0-yNJES_Wm2vHq7HrPjq27MVhHq4_K6F0,26010
+vllm/platforms/cpu.py,sha256=qRBZPDrcJxQrqOGuTWo6OpxyHwQocHafifH_HVGzvXw,19426
+vllm/platforms/cuda.py,sha256=dH5TE7jcUkWIEkDnL5LIxKHjE5y0DJMP2l7uPggWzdw,26656
+vllm/platforms/interface.py,sha256=GkftAURIiA3AUXQT516vNeTiFLv5XJgvrg6T4AJpGic,23529
+vllm/platforms/rocm.py,sha256=oWfl-cNmhvLiQ4iyRtpLg-SPESYvFi8nG2qGj14eXS0,30601
vllm/platforms/tpu.py,sha256=jP2fyAJCwm7b0phOFcRlT7BxXHubFOiQMFLxA3oE31Q,497
-vllm/platforms/xpu.py,sha256=Q9asZAboAyNz5bB2hi67iajp19fVOeIkOxCKKRs57WI,11389
+vllm/platforms/xpu.py,sha256=1q6G6_jsxBCmTiDSJ0ukjIcyJtPGby9SB5jjVsmp66A,11716
vllm/plugins/__init__.py,sha256=S-ZhkM6u6dBGX2Kt6AGo6UqQfXq5_bCmf6FM6HRIrp8,2902
vllm/plugins/__pycache__/__init__.cpython-312.pyc,,
-vllm/plugins/io_processors/__init__.py,sha256=UhJEF--e0B816MCdrtyBxon4RQgsuM6vxLTBu6V6Y4A,2517
+vllm/plugins/io_processors/__init__.py,sha256=cJWVuQGdBZT0wyg1_omUBpx-PcIERlN_KUAG7Lqbbvo,3067
vllm/plugins/io_processors/__pycache__/__init__.cpython-312.pyc,,
vllm/plugins/io_processors/__pycache__/interface.cpython-312.pyc,,
-vllm/plugins/io_processors/interface.py,sha256=cl9MdLYrYUq-EmbqBu-59d3TW34luNT7XXf0NxHWIlc,4260
+vllm/plugins/io_processors/interface.py,sha256=JM4M0WsF7MKwKRrnRXnPSUAcrzHTrr5SD4ZT6GKAoho,4324
vllm/plugins/lora_resolvers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/plugins/lora_resolvers/__pycache__/__init__.cpython-312.pyc,,
vllm/plugins/lora_resolvers/__pycache__/filesystem_resolver.cpython-312.pyc,,
vllm/plugins/lora_resolvers/__pycache__/hf_hub_resolver.cpython-312.pyc,,
vllm/plugins/lora_resolvers/filesystem_resolver.py,sha256=ys4JylLFIcxzO8maZID3OJKM_ZbP_A7JRh3wmqNs2_M,2345
vllm/plugins/lora_resolvers/hf_hub_resolver.py,sha256=dsXM5Vlb2YVNhYTeUKaF-grzx-nnN4I41tUR6M7FlRQ,5600
-vllm/pooling_params.py,sha256=m8gF0DVPyP-blqTcYKL6PxbPwPHvkL7wHjbMSLNc8tM,7925
+vllm/pooling_params.py,sha256=e1Hvm4fzEunEb-5MT86oWyC_FJwf97GzI0Rj7w2OD7U,7522
vllm/profiler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/profiler/__pycache__/__init__.cpython-312.pyc,,
vllm/profiler/__pycache__/layerwise_profile.cpython-312.pyc,,
@@ -2463,7 +2507,7 @@ vllm/ray/__pycache__/lazy_utils.cpython-312.pyc,,
vllm/ray/__pycache__/ray_env.cpython-312.pyc,,
vllm/ray/lazy_utils.py,sha256=0QXF-uTIUdzl992ugkYgkvHbWUG-85yjuIDNSLVu2gQ,651
vllm/ray/ray_env.py,sha256=x04IQZOk22XIrW7rUJo7EUJFJChzm8utQNv72-1jRh4,4169
-vllm/reasoning/__init__.py,sha256=WY68O9kcZXb9Lmj6wUdX_y6Gnic-bnXX8n8Am3RQFkE,2576
+vllm/reasoning/__init__.py,sha256=5wv00rMoOjbqRTLDmUCXdtmkr6N0arx06NRotQHcWgM,2556
vllm/reasoning/__pycache__/__init__.cpython-312.pyc,,
vllm/reasoning/__pycache__/abs_reasoning_parsers.cpython-312.pyc,,
vllm/reasoning/__pycache__/basic_parsers.cpython-312.pyc,,
@@ -2474,6 +2518,7 @@ vllm/reasoning/__pycache__/gptoss_reasoning_parser.cpython-312.pyc,,
vllm/reasoning/__pycache__/granite_reasoning_parser.cpython-312.pyc,,
vllm/reasoning/__pycache__/hunyuan_a13b_reasoning_parser.cpython-312.pyc,,
vllm/reasoning/__pycache__/identity_reasoning_parser.cpython-312.pyc,,
+vllm/reasoning/__pycache__/kimi_k2_reasoning_parser.cpython-312.pyc,,
vllm/reasoning/__pycache__/minimax_m2_reasoning_parser.cpython-312.pyc,,
vllm/reasoning/__pycache__/mistral_reasoning_parser.cpython-312.pyc,,
vllm/reasoning/__pycache__/olmo3_reasoning_parser.cpython-312.pyc,,
@@ -2490,10 +2535,11 @@ vllm/reasoning/gptoss_reasoning_parser.py,sha256=HWJ2eXOICySjOm3jFuamMJMcALsDPb2
vllm/reasoning/granite_reasoning_parser.py,sha256=MeDBbB2Iyv1SrKJisfsqYHR1Cn_JQfgE3QD3TeFXH5U,15272
vllm/reasoning/hunyuan_a13b_reasoning_parser.py,sha256=ThfBaOuTVFqY7OEU8vq_OWsiRBZXQ7Uu-xqcv8Y3bgI,9651
vllm/reasoning/identity_reasoning_parser.py,sha256=0wNlliwNu_wea0IdvdJqdZq1xYoIwtwwef04oV8slQo,2210
-vllm/reasoning/minimax_m2_reasoning_parser.py,sha256=ag2xV0h86r1yIdsDneIUs5m9W3b2tDmBj0ptUx6w__8,3941
+vllm/reasoning/kimi_k2_reasoning_parser.py,sha256=j-6uGi3WcXgL-lgNb3LpkmIkfrYuj-fGAQL5VeXTJDo,8309
+vllm/reasoning/minimax_m2_reasoning_parser.py,sha256=m0PlsWAa_PLA6mHLM8Ox6UQnwkKjrRGOv81Slu5SiAA,4134
vllm/reasoning/mistral_reasoning_parser.py,sha256=1OVWfroTTHBwlOI-L2COq7FPQvVDDZnD8vMB3VTqQnE,6319
vllm/reasoning/olmo3_reasoning_parser.py,sha256=ocBot-3mlnfuEkWg941OY-bbpcdB5A8PHU1jlI89brE,11123
-vllm/reasoning/qwen3_reasoning_parser.py,sha256=_gMOu9SXAnxUHVajPYfY4TVJRROY7fmWnNOOIpznzsc,5339
+vllm/reasoning/qwen3_reasoning_parser.py,sha256=OLWU8THTcNcLHWmlMLBjIDP3ky3NXrynsDvPVuGMO5o,6099
vllm/reasoning/seedoss_reasoning_parser.py,sha256=Oz4zSgr6seYtssSyDKsFi7J-b9kTy-IcuC01HNJkv_c,832
vllm/reasoning/step3_reasoning_parser.py,sha256=PD92XgxuGaH7yJf3pQLt6DTXeZkekorTIv9bY4SDPlQ,4403
vllm/reasoning/step3p5_reasoning_parser.py,sha256=RA4AmtxGLicQL_V2i3duPW4_fMRUCz0rzF8F6rKnaJE,7227
@@ -2506,6 +2552,7 @@ vllm/renderers/__pycache__/grok2.cpython-312.pyc,,
vllm/renderers/__pycache__/hf.cpython-312.pyc,,
vllm/renderers/__pycache__/mistral.cpython-312.pyc,,
vllm/renderers/__pycache__/params.cpython-312.pyc,,
+vllm/renderers/__pycache__/qwen_vl.cpython-312.pyc,,
vllm/renderers/__pycache__/registry.cpython-312.pyc,,
vllm/renderers/__pycache__/terratorch.cpython-312.pyc,,
vllm/renderers/base.py,sha256=fP7HzLgm-oMDOr5vbPDtxDpvIvOZZJ2NNx30VZmkUCw,26149
@@ -2521,9 +2568,10 @@ vllm/renderers/inputs/preprocess.py,sha256=OgO3sBDbaKcFk95rkQJZ1Kuq0cfoGmNkgH5HE
vllm/renderers/inputs/tokenize.py,sha256=p1SGoV6Z_BdoW1ZclSNhzHm8aH4gFzOLv0m799XqKe0,1393
vllm/renderers/mistral.py,sha256=5xJEpFTCaJgxrvYti3eB2yKtyXGEcgsdANx-bILpbGs,4337
vllm/renderers/params.py,sha256=KZYXBx9JRuxhD1EVxizfgnRzYvhZnZgWsgg2qlKY7o8,14128
-vllm/renderers/registry.py,sha256=eOVvDfFdt7mQwlwvjqruBTeyzFVDoodf8Mu1P02XbAg,2850
+vllm/renderers/qwen_vl.py,sha256=nDMh0Siz9z85FCJl2qDZfWhEcxTruNuv0cMzC6VMLc8,871
+vllm/renderers/registry.py,sha256=wzx4GDoucNRHKGFRqfJnuM6Li6FP98Gz61J7wEtftEI,2896
vllm/renderers/terratorch.py,sha256=QNmRVC-dddsBXE3NV6v_rBK-u8aVamBRh54bkBLMNc4,2278
-vllm/sampling_params.py,sha256=_qcGtMojZzw74PGt3ZuhItnus2e2yOG0rzAIuzC8EfI,35852
+vllm/sampling_params.py,sha256=QUuPyJnCrWmXvKP6CdQe1KLgkCr5n_NuoDJnDpB5y2o,37465
vllm/scalar_type.py,sha256=s8CZpZru6-4vEexIofKsJzhw5OBjVuvueG6b6yb6A94,12573
vllm/scripts.py,sha256=d2RU3e6Qi_xY7QUD44uqsdz8DdBR-_zQnFf4rYUW_iw,504
vllm/sequence.py,sha256=q10MDnLiFsnhKuuKBeIrQ6Pyj9LpGzVcHKnm-1uK0YI,2173
@@ -2533,7 +2581,7 @@ vllm/third_party/__pycache__/__init__.cpython-312.pyc,,
vllm/third_party/__pycache__/pynvml.cpython-312.pyc,,
vllm/third_party/flashmla/__init__.py,sha256=eMVQe_qZ6bEFHduSNgQ7ZAP3AvV3oMbEmyy_A-s-RSc,31
vllm/third_party/flashmla/__pycache__/__init__.cpython-312.pyc,,
-vllm/third_party/pynvml.py,sha256=HRQEbE5ZB-AgMIySG4j24hjlEgSGLrNvHJUq7UawCfA,234653
+vllm/third_party/pynvml.py,sha256=7ylxuC4d6eXttiQdX9CjOAVWe2sqEEtIztGQKLCf_RE,234646
vllm/tokenizers/__init__.py,sha256=F3xPUuvHJZfUCeLImLWYyDBZa-e_1aMlj37SfUZHr_A,418
vllm/tokenizers/__pycache__/__init__.cpython-312.pyc,,
vllm/tokenizers/__pycache__/deepseek_v32.cpython-312.pyc,,
@@ -2543,15 +2591,17 @@ vllm/tokenizers/__pycache__/grok2.cpython-312.pyc,,
vllm/tokenizers/__pycache__/hf.cpython-312.pyc,,
vllm/tokenizers/__pycache__/mistral.cpython-312.pyc,,
vllm/tokenizers/__pycache__/protocol.cpython-312.pyc,,
+vllm/tokenizers/__pycache__/qwen_vl.cpython-312.pyc,,
vllm/tokenizers/__pycache__/registry.cpython-312.pyc,,
-vllm/tokenizers/deepseek_v32.py,sha256=b0nnVfswcvF-IVct3zOnVNuOIf0m18bBQI2zDaIvzHM,3213
+vllm/tokenizers/deepseek_v32.py,sha256=nNR_ToCswY76ZhuYX2UUxrUHqO2uXICMs6YM94cl1Uo,3221
vllm/tokenizers/deepseek_v32_encoding.py,sha256=3jX1CQW-tmoh-YB8w2uxm_jWdo3eUHcPtjnnQyvjVLU,15812
vllm/tokenizers/detokenizer_utils.py,sha256=nyIW3eO0H7ovSUyyHt4mGbmmKOr3Ip-jAgK-3BN5KXQ,7821
vllm/tokenizers/grok2.py,sha256=x5j44_g31paTCXVp4WvccBRGA5XnqecwrXjDyk4wKpk,14522
vllm/tokenizers/hf.py,sha256=9GIQPaNBIgUkAqmWwSdRLR4cXQqiX8Ewz5fUCCqS9vg,4493
vllm/tokenizers/mistral.py,sha256=kntyYVaIafQ3N5fEn8PAmxIbSIQVV-Z2d_8bHgDqO0Q,21738
vllm/tokenizers/protocol.py,sha256=_LfaT2d9M6AEu-nPJW82pdw4b8Byk-T6Ps2viGAh1EI,3291
-vllm/tokenizers/registry.py,sha256=5KEg0piy6hX9wi0OGxvmTNPIQq0b-ykgOdes6gajDT4,7991
+vllm/tokenizers/qwen_vl.py,sha256=cDjvLkjkfNgfkBjM2SbFQmnK-WT7Z2t5yxtH1MMu2uc,2165
+vllm/tokenizers/registry.py,sha256=83OHhZhAmMKu7HIQjAY-rI2jZLc42J-z62Hqcv5O-_g,8177
vllm/tool_parsers/__init__.py,sha256=TvPqzR3Wlr2qzbX50fmzbt22zxN9_bGOwHSORuh7Ebc,3666
vllm/tool_parsers/__pycache__/__init__.cpython-312.pyc,,
vllm/tool_parsers/__pycache__/abstract_tool_parser.cpython-312.pyc,,
@@ -2613,12 +2663,12 @@ vllm/tool_parsers/olmo3_tool_parser.py,sha256=-pGGNRErh_F8AJTCUo4uOTitEDE1hpoDai
vllm/tool_parsers/openai_tool_parser.py,sha256=NnO1_6IwnAgmbsPiHZ0-Et3e1EW-WyKrBoRC_Or5K_M,4279
vllm/tool_parsers/phi4mini_tool_parser.py,sha256=5QFCzfRrwpUGuwGaXrOakOhwpsFqvQdyXQF-q_U99Bo,4084
vllm/tool_parsers/pythonic_tool_parser.py,sha256=9C8fOprJgamB_8g0rHsHfgBhrU_KW7BQRv6jmsiY4tQ,12269
-vllm/tool_parsers/qwen3coder_tool_parser.py,sha256=MBPZKQoCpQB4EMtn_JmXevEZ4-bxWIrfs3dqSwXKEYQ,32804
+vllm/tool_parsers/qwen3coder_tool_parser.py,sha256=zuu4Nml-xRndFBcfZSVcVCFWt2VklPZ9XyKMvCYq32Q,29185
vllm/tool_parsers/qwen3xml_tool_parser.py,sha256=0cRnsspFfki7fPwrGFD8RD1wAfTyl0fWtuK-JVMYD8A,53857
vllm/tool_parsers/seed_oss_tool_parser.py,sha256=gEQbkx3uY54FBcXK513d9kWPurkuGIFPUMnpOcTHO68,31404
vllm/tool_parsers/step3_tool_parser.py,sha256=wakZ3uuPcW8IOFe0tG3SONYtfEF55LTKhy3DA_y6098,12159
vllm/tool_parsers/step3p5_tool_parser.py,sha256=zmSzhP9207A1cEEodvfviFihTBePIWDpPsjQwKdvbn4,61867
-vllm/tool_parsers/utils.py,sha256=PaTL3WKpzrZwU7wJHqI-XlO9McJR4LGL4ldFk1ukMC0,7460
+vllm/tool_parsers/utils.py,sha256=TDnPFr7xYcLQuEoAfEMIYkZs5oSu_lkm-gCkezspxZI,7087
vllm/tool_parsers/xlam_tool_parser.py,sha256=1xOuZ1IRvsvbhTwLx2WXejI2ylfiFrDYQLy7N-SAaHg,24689
vllm/tracing/__init__.py,sha256=LYFnkXkGpQBWKbAY95Am9yKeOmoi-iRmM4WQhoju8mE,4308
vllm/tracing/__pycache__/__init__.cpython-312.pyc,,
@@ -2652,7 +2702,9 @@ vllm/transformers_utils/chat_templates/template_fuyu.jinja,sha256=hzdsPgeUMaZnd5
vllm/transformers_utils/chat_templates/template_minicpmv45.jinja,sha256=tNvJgf7S0je1GqFxPUQ-q5nSGwXoINdqW3-VOFNekh8,4231
vllm/transformers_utils/config.py,sha256=PP_sZlNDfxx_yHTIJYITVv_Nzg4v5SjYzfmiix8MWhw,42310
vllm/transformers_utils/config_parser_base.py,sha256=qSyL1AYBvvnQTII-lS9_N602OJ_Huf5BkUZYBhQEFtU,523
-vllm/transformers_utils/configs/__init__.py,sha256=MCljcKMzfFhdTPLIaOTMl3IEMKt0mlEq75stsaRD1J0,5482
+vllm/transformers_utils/configs/AXK1.py,sha256=7KL6K89TunBNrye7TKtNH76tbeT67m4lvP97wQttqF0,10995
+vllm/transformers_utils/configs/__init__.py,sha256=EbMCeMBZf-g46jMnb1C-QyY1gVLGk8NgaOttlI5bnQ8,5558
+vllm/transformers_utils/configs/__pycache__/AXK1.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/__init__.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/afmoe.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/arctic.cpython-312.pyc,,
@@ -2663,6 +2715,7 @@ vllm/transformers_utils/configs/__pycache__/colqwen3.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/deepseek_vl2.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/dotsocr.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/eagle.cpython-312.pyc,,
+vllm/transformers_utils/configs/__pycache__/extract_hidden_states.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/falcon.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/flex_olmo.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/funaudiochat.cpython-312.pyc,,
@@ -2682,6 +2735,7 @@ vllm/transformers_utils/configs/__pycache__/nemotron.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/nemotron_h.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/olmo3.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/ovis.cpython-312.pyc,,
+vllm/transformers_utils/configs/__pycache__/parakeet.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/qwen3_5.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/qwen3_5_moe.cpython-312.pyc,,
vllm/transformers_utils/configs/__pycache__/qwen3_asr.cpython-312.pyc,,
@@ -2700,6 +2754,7 @@ vllm/transformers_utils/configs/colqwen3.py,sha256=IUckSjIN7qxmKVDEttyXK91jP-6iq
vllm/transformers_utils/configs/deepseek_vl2.py,sha256=I4Ge43Aqaw1cS1M7gps9utrM7XwmXj_ktbxRcHAeoKI,4075
vllm/transformers_utils/configs/dotsocr.py,sha256=_mxuA6xDpO6N2-FD4R5XVqLDP79D8i0OAZcPvqUCSdU,2445
vllm/transformers_utils/configs/eagle.py,sha256=R8LT924n_tTE8LkfYcENCv_UVvLmsIWKl6th6okyOWg,3012
+vllm/transformers_utils/configs/extract_hidden_states.py,sha256=84LXbdhcsDSJQ1E3bKFgVgOPmKuDiz-7RnUPgyh3FQE,1749
vllm/transformers_utils/configs/falcon.py,sha256=XNrg2x08UDZngZqtv5T0CiEFZq7qhoJPWdEYnI64e_Q,2937
vllm/transformers_utils/configs/flex_olmo.py,sha256=NInzKT7pO5w6CKxxaaEAXW_YnN6f1Sv3hDWQYhPL7ko,3157
vllm/transformers_utils/configs/funaudiochat.py,sha256=F3R7akXupVr2KQTh9vycB6S8Z90oEFWDiihyLx2qr6o,4695
@@ -2719,6 +2774,7 @@ vllm/transformers_utils/configs/nemotron.py,sha256=EAYX0byXwgziKfac4nL2cFtaatPyD
vllm/transformers_utils/configs/nemotron_h.py,sha256=fL8DudDBuHNAdBAdkXjP1XhbGpGRFIOW98qXH3zHCdc,13414
vllm/transformers_utils/configs/olmo3.py,sha256=pnjjin2JhHw7lA0JzpsVaRiF4l3-jDwj5_DrroEEVTU,3026
vllm/transformers_utils/configs/ovis.py,sha256=zjfzz8-btPauEhKwQm9dRqmsZ8jakSaVCqqZ2rJZv2g,7535
+vllm/transformers_utils/configs/parakeet.py,sha256=B2bzUzRWpUqojQhx4_ZK9bBrymoXawnTbfue9iPyhSg,1658
vllm/transformers_utils/configs/qwen3_5.py,sha256=d4NwouDihpFiCkOGWGPxq_5rPjWhT4C_LeUZ9SnJiW4,6996
vllm/transformers_utils/configs/qwen3_5_moe.py,sha256=t85XK6BlTnhevVDiUr55xlr3oIh0X0_0AVl78EwcL9Q,7661
vllm/transformers_utils/configs/qwen3_asr.py,sha256=Vv0fHqk53_R1M7BSli0pQ3XoVILtMQ5IzK1ALPVpmBg,20563
@@ -2736,13 +2792,14 @@ vllm/transformers_utils/configs/tarsier2.py,sha256=ydiT9TPSYbSL58QHX8y2KwtboQ0gt
vllm/transformers_utils/configs/ultravox.py,sha256=4DwreIyyKxSwyie6av42UJof5wejQHeDLWtcA5jBPO0,4944
vllm/transformers_utils/dynamic_module.py,sha256=Gh0_B_IHLEedgKAHPNUZcQo2F_dAxicKo6luvfTo7zw,2038
vllm/transformers_utils/gguf_utils.py,sha256=DHxZAHenR0iZAwxRaK2KuzhUSZXCgJ-wgAw7SF9OgtI,10279
-vllm/transformers_utils/model_arch_config_convertor.py,sha256=MKieLiPi4Se4DjICUz4jq6jhg5kbBIRhbSdFol_Yiyw,17803
-vllm/transformers_utils/processor.py,sha256=QEIBQBoYMr-zgNnj2aGHyhmtC0I93ZbD1NlEnE0a8kw,17747
-vllm/transformers_utils/processors/__init__.py,sha256=MHOjadEoO3x803thzujcfC-raEWzofV5uCC7EQIQaXg,1067
+vllm/transformers_utils/model_arch_config_convertor.py,sha256=6yGeq38_M65ZyWvUJvVlkPJg8QSdlJ2OkWmxqBTFqek,17980
+vllm/transformers_utils/processor.py,sha256=UF5zLB6CwEYpEFT5tGGwqIvRguVQeNLBOBgoB25QUR4,17576
+vllm/transformers_utils/processors/__init__.py,sha256=TDMetNekay1ikxBwruohrq1aMXi2y5l_OnzgakelPJY,1194
vllm/transformers_utils/processors/__pycache__/__init__.cpython-312.pyc,,
vllm/transformers_utils/processors/__pycache__/bagel.cpython-312.pyc,,
vllm/transformers_utils/processors/__pycache__/deepseek_ocr.cpython-312.pyc,,
vllm/transformers_utils/processors/__pycache__/deepseek_vl2.cpython-312.pyc,,
+vllm/transformers_utils/processors/__pycache__/fireredasr2_processor.cpython-312.pyc,,
vllm/transformers_utils/processors/__pycache__/funasr_processor.cpython-312.pyc,,
vllm/transformers_utils/processors/__pycache__/hunyuan_vl.cpython-312.pyc,,
vllm/transformers_utils/processors/__pycache__/hunyuan_vl_image.cpython-312.pyc,,
@@ -2752,20 +2809,23 @@ vllm/transformers_utils/processors/__pycache__/qwen3_asr.cpython-312.pyc,,
vllm/transformers_utils/processors/bagel.py,sha256=zi9b5HjfDpnd7a7GS-ANquWCft-3KcMrMuVfcVZSJuA,2745
vllm/transformers_utils/processors/deepseek_ocr.py,sha256=IG7NDyVUoZa5lNsMebYz5FBPzN-wy4o0LsKgYwR2r_8,15899
vllm/transformers_utils/processors/deepseek_vl2.py,sha256=YWmA16RtLAOE_eMHTJwHJSFoRFlBU3wD7cR3ZXP9JyQ,15121
+vllm/transformers_utils/processors/fireredasr2_processor.py,sha256=NvYz59sEnOEt93IGTo7LETY2WvyvKNohlDjecjBRMos,13047
vllm/transformers_utils/processors/funasr_processor.py,sha256=cUe7FyL5vwQHIIYT2Y0U16kUk0bXJWuX62CZVMb8Y_Q,18547
vllm/transformers_utils/processors/hunyuan_vl.py,sha256=4Bo9Ge6OXn4_KnzVArwFr37gdMNdPI_EpZhufv9Nnnk,9435
vllm/transformers_utils/processors/hunyuan_vl_image.py,sha256=KeauATCER1RdedquOh8bvXPKk9GT6eQAG8avSRqSdz8,22292
vllm/transformers_utils/processors/ovis.py,sha256=X7F5oc4jep7cV_G2Zj9r7uJE3QOccZCbeB9NsK2mZpg,19608
vllm/transformers_utils/processors/ovis2_5.py,sha256=LbudomwDaXccixhZcYT7V-J5eWYNbEnANn4Q_69RzbE,20083
vllm/transformers_utils/processors/qwen3_asr.py,sha256=w82qQ7wMbWe97bGoqrGjBZYpkXMr5BQ5gVAki9_1_Rk,9095
-vllm/transformers_utils/repo_utils.py,sha256=Ktp76jBwmBN8YUDxlc8adZ6IYQOHVgD0IyTt1gE5ClE,9005
+vllm/transformers_utils/repo_utils.py,sha256=vTY2zuYnYiQGsjPsRu4sl4ORMyy6XjV6o047Af7mi7M,9015
vllm/transformers_utils/runai_utils.py,sha256=SUDsljkSzlbmLhGzqpgykG4kgNVaVOxe2g4-kmXm4z0,3154
vllm/transformers_utils/s3_utils.py,sha256=3Vk1LybSbByVKEZZzOly7d3ggdhMdPRwS2wvqxUOpcE,2770
vllm/transformers_utils/tokenizer.py,sha256=v4LtF_1D5_aqj_1cpv9d53FNsVUAWqtX4-osRvkU-rs,665
vllm/transformers_utils/utils.py,sha256=H5CxGzLlIpCA-Zd8lkKlghMcFgeHopgYBA_9p7Rrle4,3078
vllm/triton_utils/__init__.py,sha256=FY20qv1flNw_TDjM9j544voqnflClkYGSOEF5lPwv0M,567
vllm/triton_utils/__pycache__/__init__.cpython-312.pyc,,
+vllm/triton_utils/__pycache__/allocation.cpython-312.pyc,,
vllm/triton_utils/__pycache__/importing.cpython-312.pyc,,
+vllm/triton_utils/allocation.py,sha256=oGBpkpw7qD5MH7ZNr5jb7y0IJXUPYQw6v_7Ya7bgDhA,376
vllm/triton_utils/importing.py,sha256=CcaJ7kpP7bKP3yfuR_1w_vNRBxiZk8qNdcatjsvoAYI,3589
vllm/usage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/usage/__pycache__/__init__.cpython-312.pyc,,
@@ -2806,14 +2866,14 @@ vllm/utils/async_utils.py,sha256=8QiEviVKKi3NQ_OHRZVQqrPk4L-K_rAqFV77gZDQCHU,113
vllm/utils/cache.py,sha256=p6Cvrul9e1fOuw4h9-n74qmY_lHMQyS0BqhHXQqN-_Y,6206
vllm/utils/collection_utils.py,sha256=UHEf5oYJtCzYI8bEW6K5J-zLhGTF8gnmIzBnFhwiFVs,3592
vllm/utils/counter.py,sha256=6Xa0WjHey2LwyPJkGfYa37mgOsQfcHJF9kRypIcaHDQ,1165
-vllm/utils/deep_gemm.py,sha256=Kl9l_pcLZIYmuECIpA-XVG6D_38L3nXoBACWIXDRerU,14925
-vllm/utils/flashinfer.py,sha256=Y820S37DtbPacSFvrXfxxa8F4ty1L8l7OBqL0AQq1jE,24792
+vllm/utils/deep_gemm.py,sha256=gzpu5Yzh7zN1qLP7NFar4d024HF3xUfhxkEwbjNKVbQ,19593
+vllm/utils/flashinfer.py,sha256=apQmb9MkQtVB75qpZjKkeOEwPLUKKfMkLvNorw3BBRc,24831
vllm/utils/func_utils.py,sha256=pLXKoub-DlGrxZRwliCGmuMy2eiEHrglZPEQSbvPjKg,7706
vllm/utils/gc_utils.py,sha256=zMuLZmEcjECblk30kHjKGwLdOHANND5vopi17x_IRjI,4970
vllm/utils/hashing.py,sha256=ibS9K9JGpb_c8A12bqOKgYI3auz9a2t9XEfZaxVZJD4,3618
-vllm/utils/import_utils.py,sha256=iv2ycHmNYKHljx3OpNxtoIwO6DE10HaJ388a-H6isTg,14154
+vllm/utils/import_utils.py,sha256=cbPEomdgp8bfCu5ueAoNb19DwXCHc8lnOImpIiXpXIE,14021
vllm/utils/jsontree.py,sha256=lktXNEOq0w1qqAGTPgpJ5eDogjetHDC15Hq___2OLmg,3755
-vllm/utils/math_utils.py,sha256=j0iLxkURCK07Y3Kc9zr5NGWn_oKK57ijNKKB2wvyD7E,890
+vllm/utils/math_utils.py,sha256=3HHFjy16D9iPqbPd_4xdBgNQrgXLZ-O3TcZHqZEFaL4,1042
vllm/utils/mem_constants.py,sha256=DerNf9-iEKE2kx8c6ufmLY23mn7vV3K1vskiAw3SWJU,390
vllm/utils/mem_utils.py,sha256=fUzur9cxpAP1E5Jf0mGQ1Q9Xtv5TIvCJDjIuu1dI43M,10285
vllm/utils/mistral.py,sha256=aEBEQhLIXcgkv8RaDEi6mSruMq9_5-sgflgpTP4dQ04,1029
@@ -2825,9 +2885,9 @@ vllm/utils/print_utils.py,sha256=M_6PwgPusaGHGQLPo9e3TovYSAdpW5HoNsyM4ZZxqYY,302
vllm/utils/profiling.py,sha256=UWku3rWFiMP4UclIokbOxF1fh5xVuwiOc4vjsq2Fdo4,1469
vllm/utils/registry.py,sha256=gWRZ2XRwU-YapctL5A05O1bG2Vgm-QBcPrxjUPRugFU,1555
vllm/utils/serial_utils.py,sha256=W_2bYeQugj6AlLZLGHgsm0TXsfEz1kDxe0ecvIwfj0Q,2805
-vllm/utils/system_utils.py,sha256=3T02U9mDJzjxQ70tbRxnZLUtcbYmUpktzps-fDYXj54,9241
+vllm/utils/system_utils.py,sha256=sKIPCcmjTXHjaMhCGpalqPBzLJ6097C4WE4-aDl93hs,9767
vllm/utils/tensor_schema.py,sha256=Q2GzQ-CzlvQjA9qvZBcP8f3layD454s7c5nRt_FeZWU,9140
-vllm/utils/torch_utils.py,sha256=TyjxKoyoWZh4dS2-XmKCjNwjZZK5LXH_zZOSMIjGFzg,26997
+vllm/utils/torch_utils.py,sha256=ymLP-Yv-tk3wvoUfG19z0Y0HGvQR5_Ei_NtnKOxIXeY,27945
vllm/utils/tqdm_utils.py,sha256=JW5ymX2EnmIix7RLymsiDPFasapB6udhbSurNvXTlXM,831
vllm/v1/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/__pycache__/__init__.cpython-312.pyc,,
@@ -2841,7 +2901,7 @@ vllm/v1/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
vllm/v1/attention/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/attention/__pycache__/backend.cpython-312.pyc,,
vllm/v1/attention/__pycache__/selector.cpython-312.pyc,,
-vllm/v1/attention/backend.py,sha256=muIkHtECY0qPILG3MO4DSzIwKx2_lGgkl_XzGDKDTB4,30672
+vllm/v1/attention/backend.py,sha256=P6tojKvUNwYyZRzt7n_S-9aRD3S-x8BsRYfaTy7ZbO4,32661
vllm/v1/attention/backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/attention/backends/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/attention/backends/__pycache__/cpu_attn.cpython-312.pyc,,
@@ -2864,16 +2924,16 @@ vllm/v1/attention/backends/__pycache__/tree_attn.cpython-312.pyc,,
vllm/v1/attention/backends/__pycache__/triton_attn.cpython-312.pyc,,
vllm/v1/attention/backends/__pycache__/utils.cpython-312.pyc,,
vllm/v1/attention/backends/cpu_attn.py,sha256=iYyIvrNtpSBcKrIplpde1u49-KcmqkNTTHJGCK76oA8,18338
-vllm/v1/attention/backends/fa_utils.py,sha256=JUYQ594bQYVZGUU6Pj_jOeWt15t_yhQF7r67abH0rQ0,6110
-vllm/v1/attention/backends/flash_attn.py,sha256=v6wtmZFVPTl__VjipyEBqI4SKwNWKEOwLmrlXSAEN1Y,51112
+vllm/v1/attention/backends/fa_utils.py,sha256=qeo-t4CkCimRDBlo__qumUeAsKzsv3DT4LDk4fQJY8w,7318
+vllm/v1/attention/backends/flash_attn.py,sha256=ry_b2avMWWD9SbNnV-uyQFCzCWxyCRwg2gEKxGZNbyA,60988
vllm/v1/attention/backends/flash_attn_diffkv.py,sha256=8EU5rcDQZDtc8znzCvrL8mc5AFWFbQbPUDsPIap9DgE,11068
-vllm/v1/attention/backends/flashinfer.py,sha256=NAP38iPLg3gU6kOAJpxDZrbUR4jgZ5XwkumwuA8Qo5I,70745
+vllm/v1/attention/backends/flashinfer.py,sha256=49NpSoqLCwY_765DXN4Gx50gH2_58vHRT4h-tYbTVEE,69003
vllm/v1/attention/backends/flex_attention.py,sha256=0Hk1rJaq04PnApS8y23K0n_meHwd5SctGOxgIQ0ewfc,39786
vllm/v1/attention/backends/gdn_attn.py,sha256=93921EKaiUJ_ThvRVayxt00tRY6hIHCDyQgnjRql1CM,17327
vllm/v1/attention/backends/linear_attn.py,sha256=5CzIcHkvZMu-2Nkf2kaTWl3ZRuAWIOJGIqwxtHpvzM0,2722
-vllm/v1/attention/backends/mamba1_attn.py,sha256=b35cfbixMEN6XEETvY3sMpXJbVBykJJ7o9AIyTIc7Do,811
-vllm/v1/attention/backends/mamba2_attn.py,sha256=xCB_VlCKDHP-twOhbuUdZuTK1cP7-dLKVkWQJBMwtdc,9885
-vllm/v1/attention/backends/mamba_attn.py,sha256=QBoMLjkGgK_T3wWmSRh-TDf_ly0jnGdJyNbuS_kngFc,18401
+vllm/v1/attention/backends/mamba1_attn.py,sha256=rP-FJ3xGx1QHO6tcjHn57vnZHqc5Yk3Q4TsjDxIMdCo,1743
+vllm/v1/attention/backends/mamba2_attn.py,sha256=aY0ksoNS_hCdFr2vwLfjtjStt6BCjm77KXwstUJfZj4,5692
+vllm/v1/attention/backends/mamba_attn.py,sha256=vpu7bKD0xbsmFKpZhS211qL5FyWVJx68ge7OFGcrnPk,23072
vllm/v1/attention/backends/mla/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/attention/backends/mla/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/attention/backends/mla/__pycache__/aiter_triton_mla.cpython-312.pyc,,
@@ -2890,20 +2950,20 @@ vllm/v1/attention/backends/mla/__pycache__/sparse_utils.cpython-312.pyc,,
vllm/v1/attention/backends/mla/__pycache__/triton_mla.cpython-312.pyc,,
vllm/v1/attention/backends/mla/aiter_triton_mla.py,sha256=7Q4oXNrgagGRUMQcyVizwLXOsOTymdUxjRyQ4sI5U5A,2000
vllm/v1/attention/backends/mla/cutlass_mla.py,sha256=U5piAUug73lO6ioNkDe48gD0gQNG1q5p7ikcXFgMmo0,8799
-vllm/v1/attention/backends/mla/flashattn_mla.py,sha256=RUh_zQxwPTSeYLQyWfc9O-Lnkvk98_OljSsOlJtVHZI,12805
+vllm/v1/attention/backends/mla/flashattn_mla.py,sha256=AJEMZ4_dcy-MN_naxGwVcoBcgg5kIupb9vpNS_nzFG0,12974
vllm/v1/attention/backends/mla/flashinfer_mla.py,sha256=RnAiv3mmxEZ4spmkqPNcARcQN-5EUqR5pe-1_sqpKRc,6591
vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py,sha256=L9kd6cixFsX-bwzUT_fhS4WxJMfPqVilK_0KTFvahPo,11872
-vllm/v1/attention/backends/mla/flashmla.py,sha256=ZgpQjV1kltnILiyOGPwPG5jTXEpotue2H4KPpx3c6G4,11002
-vllm/v1/attention/backends/mla/flashmla_sparse.py,sha256=UkyZDdbijKAcJozt_BZ5MeL2XwDWAW7rBR-RywuuX2Q,33493
-vllm/v1/attention/backends/mla/indexer.py,sha256=QnxZDvmRnTyolQYA_ZPP5Zpo0fSmC3DE_BdNoSOWY8I,14592
-vllm/v1/attention/backends/mla/rocm_aiter_mla.py,sha256=lsCZA_CKMXQft__BMA6dVbK2dTWuwtTAqMXRQOvickg,9871
+vllm/v1/attention/backends/mla/flashmla.py,sha256=bNDNcVPaWPFOA7yVxCtDboo0fAVkuDKYdopgxb7_i5E,11171
+vllm/v1/attention/backends/mla/flashmla_sparse.py,sha256=g2iNW_VqDaA1_vQGjhSRYyIpox9dM_z6dy8ElSgZsyA,36349
+vllm/v1/attention/backends/mla/indexer.py,sha256=YRALm24dJuR3z7plkjuun9aoddITD1mYL16lo9o6duQ,21150
+vllm/v1/attention/backends/mla/rocm_aiter_mla.py,sha256=ppnGFEhOslTchdZN_81-7gW6Q1QJFStAayeR1_JL8-w,10040
vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py,sha256=A-psJkVpEhf6rF_g_EHso0hR6jWMO415e52Z96lTWTw,12857
vllm/v1/attention/backends/mla/sparse_utils.py,sha256=Niep6oTN_8_CyrqTvO_JPByN-7cRoeY7SvnYekSYumU,7124
-vllm/v1/attention/backends/mla/triton_mla.py,sha256=4ChkFVZNMuyhGe87gNSBB7em6xMffJ9WS2Cj3Vlscdc,7038
+vllm/v1/attention/backends/mla/triton_mla.py,sha256=62lvVPNyWlDBYb-lXHonDcNYC_ZmlTVVZfUYR8TucgA,7616
vllm/v1/attention/backends/registry.py,sha256=3nIonTWV4rJx-rI_LIDJUa0fjVaev7N5n3HzSJ1BdQw,9695
vllm/v1/attention/backends/rocm_aiter_fa.py,sha256=VD-JfyOSt2jFfP3_lscGG7gPmRlDIqMpZJ_NfixZJvc,53117
-vllm/v1/attention/backends/rocm_aiter_unified_attn.py,sha256=UU14VRIMXKiNEupXHveZuBIA_MgtpSgbT1VOgGQJuBE,8056
-vllm/v1/attention/backends/rocm_attn.py,sha256=R6i2CkKwPi9MMiTLvjoShfTzDDPFpI1HjKGJ8VrJTwI,16626
+vllm/v1/attention/backends/rocm_aiter_unified_attn.py,sha256=ObOdNzBZ-j5RxsUMuwbVX1Re4_fs2-jN-nKBweYi1IU,9329
+vllm/v1/attention/backends/rocm_attn.py,sha256=vllPGdybcBMF13lR1tgXTZ-wrtkE8N-ZUiP1Jjp_DT0,19124
vllm/v1/attention/backends/short_conv_attn.py,sha256=MEf-oDqbQl62cWs8vmKeSo5_ZcOuNUjW4PC724iGIns,835
vllm/v1/attention/backends/tree_attn.py,sha256=ctIymZevnFTX2VuZCrpa6Qce2RqpmK_iWE8hsf8ecSg,15716
vllm/v1/attention/backends/triton_attn.py,sha256=rymjUf9jVPJhYd2PRLH2dX_mpWBfZt61P8MIpMpOi5o,23275
@@ -2925,7 +2985,7 @@ vllm/v1/attention/ops/__pycache__/triton_unified_attention.cpython-312.pyc,,
vllm/v1/attention/ops/__pycache__/vit_attn_wrappers.cpython-312.pyc,,
vllm/v1/attention/ops/chunked_prefill_paged_decode.py,sha256=QJOLdArCSXYIhTX6s-HoVbYa3Yol6oR3zLOmNNDEBrA,15750
vllm/v1/attention/ops/common.py,sha256=J_g0iiIkHoqwaxgVe8IGTRNluldoeriDRRg6cmLYA9s,13602
-vllm/v1/attention/ops/flashmla.py,sha256=bl_rqebAj-ip7_5M-U_8bH2EpE8DlkeD4t-72mQZc_c,5119
+vllm/v1/attention/ops/flashmla.py,sha256=N6aXtzwa_u3TBTXClwzoiPGiMx9GKTQ4XIRDq8riTYI,6025
vllm/v1/attention/ops/merge_attn_states.py,sha256=gwJoIzgUPGqEr20vWSF8ByEUPxkBDTSeUd5whJPvAXM,1658
vllm/v1/attention/ops/paged_attn.py,sha256=AWqV-JLsdRyGs3vh4SFxutnwtAdrvyIZ_zNvTpxQNlM,1437
vllm/v1/attention/ops/prefix_prefill.py,sha256=_p33Jl6zj3PqWkq7klbBo6pS_RJvnImBnaA9WUzAaxM,27286
@@ -2935,7 +2995,7 @@ vllm/v1/attention/ops/triton_merge_attn_states.py,sha256=4fOT2PY6AocmhvwAYt5Zu1P
vllm/v1/attention/ops/triton_prefill_attention.py,sha256=jO4m6SCLKLk9GYc5zA2ki8qxWoYG34RR-KmYJMwmPbk,7585
vllm/v1/attention/ops/triton_reshape_and_cache_flash.py,sha256=l23FcUrpKLfL1CAjtRKRE7jDotvOm7bGhh4FTBy84cU,12771
vllm/v1/attention/ops/triton_unified_attention.py,sha256=jo05PE1VFUfeOXhZRirXw3UCMIQUWOZY1zOB2_P1kAU,39926
-vllm/v1/attention/ops/vit_attn_wrappers.py,sha256=H7TzVtBZeFed44ec_NJOF2eGfJOVGw7VIF8VJWxcvyw,7404
+vllm/v1/attention/ops/vit_attn_wrappers.py,sha256=mTc6aReCGePTbFltpxQXba1E9YIOY46-8fgQeX-Rxnc,10054
vllm/v1/attention/selector.py,sha256=6YRzCyWndCuO773kxQnjgFRjkpU6OJAFfL1LDuY-jA4,4765
vllm/v1/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/core/__pycache__/__init__.cpython-312.pyc,,
@@ -2949,9 +3009,9 @@ vllm/v1/core/__pycache__/single_type_kv_cache_manager.cpython-312.pyc,,
vllm/v1/core/block_pool.py,sha256=LqkyYorGoGYH3Uje16kI9wW1XD62gBIyJk7i2xQsbNk,20030
vllm/v1/core/encoder_cache_manager.py,sha256=yG0-WuDAGL0ljoYtxZfbW1ulyaY93k0Z68y8m2HP5DI,15537
vllm/v1/core/kv_cache_coordinator.py,sha256=jwTjrG1gjcMC0Ha5fEm9E_geGwUGNShYCyyUfCvMkeg,22399
-vllm/v1/core/kv_cache_manager.py,sha256=5owMv2aer3Sj_p6iNLc-MN2pRiVZESmdOo429zBpin4,20597
+vllm/v1/core/kv_cache_manager.py,sha256=9byU-exEalqxVO5ZvjoAmy2YlXuZ9p3bx2BFm87nQlI,22107
vllm/v1/core/kv_cache_metrics.py,sha256=7-3sBP7_XwSm1RxsRkKQtSq4wknWUSGpq58TcNa2f0c,3138
-vllm/v1/core/kv_cache_utils.py,sha256=ZkLJ_QXRbjpJIuAo2NtW2oaR0zc739pc0rm4HE3Bcqo,67368
+vllm/v1/core/kv_cache_utils.py,sha256=SOMvwuYKJcbrmBwTPNlBKFTQwXTPlwq1GFbyc78ady8,68571
vllm/v1/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/core/sched/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/core/sched/__pycache__/async_scheduler.cpython-312.pyc,,
@@ -2962,13 +3022,13 @@ vllm/v1/core/sched/__pycache__/scheduler.cpython-312.pyc,,
vllm/v1/core/sched/__pycache__/utils.cpython-312.pyc,,
vllm/v1/core/sched/async_scheduler.py,sha256=TfN-iEL7BtXM47xBTpQjoJgI6TNqlqJVMN0icPB40JM,2659
vllm/v1/core/sched/interface.py,sha256=W-rtDuHPA6D3YndYfY97-e7njS0fAkhxNEqMwH60520,9077
-vllm/v1/core/sched/output.py,sha256=_ynXfniAPeJ3dfOj97D1mUAqjhmfvEs8kDIYRrW4_aA,9828
+vllm/v1/core/sched/output.py,sha256=X6_GOOUpT_AmpbAZJ_s0JPughfBIC3WKgdvZo_LBXfo,10129
vllm/v1/core/sched/request_queue.py,sha256=S42Tjl-4FS_mEDC4qpkfmD_KyaQFw37YkIBF4F6Crl4,6989
-vllm/v1/core/sched/scheduler.py,sha256=3Tp8IkUUBcdMj996i70WTEFbk2KC6ZjLdolCaB6u-WU,101673
-vllm/v1/core/sched/utils.py,sha256=3EmgmNpjOubCbvW5q0BFekIRQoGkGjSHdX1EskOOP9w,2148
-vllm/v1/core/single_type_kv_cache_manager.py,sha256=SfO1IxrkoY64LyN17CbJbSg_USU6TQmuVr_nRMC0Dmo,48201
-vllm/v1/cudagraph_dispatcher.py,sha256=D5UbrFbjT2lake_WN8shyvVLFAGCpwJmNk4rcqFbjNg,14433
-vllm/v1/engine/__init__.py,sha256=-eyCzUh0c7eXKl0R2LAGqTROR5eDAsVS1olvFlTH-vM,7650
+vllm/v1/core/sched/scheduler.py,sha256=yJ-_VWDF8E2SVLV6uHdZO1YR-x77IgU37XUFpE3XM00,126753
+vllm/v1/core/sched/utils.py,sha256=hegurlVaA0l60qwVQO1WKmw2_CYYWqYjNyXJFIFqobM,4115
+vllm/v1/core/single_type_kv_cache_manager.py,sha256=wTqh7tVrf9jqJQrv57Xftd2LYi5jWm9ijjBaUNgz2ZY,48726
+vllm/v1/cudagraph_dispatcher.py,sha256=tucITt0mub8XERm8UhoveEkUutts2TTFjC8EkFGXLmA,15325
+vllm/v1/engine/__init__.py,sha256=gIHKUrwQ5kP1I985MbkbsFU6CA0oZCPWBtuN_SrV7NM,8342
vllm/v1/engine/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/engine/__pycache__/async_llm.cpython-312.pyc,,
vllm/v1/engine/__pycache__/coordinator.cpython-312.pyc,,
@@ -2982,18 +3042,18 @@ vllm/v1/engine/__pycache__/logprobs.cpython-312.pyc,,
vllm/v1/engine/__pycache__/output_processor.cpython-312.pyc,,
vllm/v1/engine/__pycache__/parallel_sampling.cpython-312.pyc,,
vllm/v1/engine/__pycache__/utils.cpython-312.pyc,,
-vllm/v1/engine/async_llm.py,sha256=0xjQJ_tdSIhgY3uuHqY09Trz8ISVvteCuuMeWjjBhkI,41673
-vllm/v1/engine/coordinator.py,sha256=wGbIqHFaHGThFBjgDDsdANVMI33mcp6G1uahTybRpto,17400
-vllm/v1/engine/core.py,sha256=caGTN0fD35dWGXP5iHnI7VzKt73aIxoe3qplR1braSA,74937
-vllm/v1/engine/core_client.py,sha256=_WBCgwuez4dD-qF0iGR1pmQb4V2eZpD-yK_cWuHrWqM,55909
+vllm/v1/engine/async_llm.py,sha256=r4BTWzNMdCA5-FU0hZHoeVPMXY_LFtNvjIPN3dVcn70,42353
+vllm/v1/engine/coordinator.py,sha256=kF3GdPZlPdZcVU4pudADKsZxRzfrpDjpu5UzIheSG2c,18416
+vllm/v1/engine/core.py,sha256=XEzaITtIP0gSG22sMgjkRVIOLmiunWXWHfVrF1EMaMY,87333
+vllm/v1/engine/core_client.py,sha256=P6XQMGBLHJfA6D9Z6G32syo1IdiC6P_dPPYXy-cTu-8,67207
vllm/v1/engine/detokenizer.py,sha256=2bi9yMmZWhtJxxPf_7DGXmQNwqAFewSg35BqdtCQwWc,12779
vllm/v1/engine/exceptions.py,sha256=Uqu6bUKEvMYpeVJYxlqYeveRcd9WnsFUENKt0JLMins,732
-vllm/v1/engine/input_processor.py,sha256=vFNkZ4gEZMwFkUiOTSTA4dDv5Ktm2QfZsZptgVB4Clo,19620
-vllm/v1/engine/llm_engine.py,sha256=-zRacelQUffg0eRLI70fu4874bsZ8-ykhD3mKx7zHi4,16761
+vllm/v1/engine/input_processor.py,sha256=tzo6FjUp3e7MbTMaqj8A19KLahWKtSxjGBXb_id7vbk,19184
+vllm/v1/engine/llm_engine.py,sha256=3l-FyV1wKuDWgmosrxJPoD7yOa0c2lrtMTl2v18rqVI,16788
vllm/v1/engine/logprobs.py,sha256=3KDfKD9nIkUyjzc9Jf5SgfBn8tPN3cnJkO8Vh4Sat3U,8682
vllm/v1/engine/output_processor.py,sha256=VcabTEItWEKOsuBqsnqDuzMFg8VCU1JPsyun_wQRkqI,31131
vllm/v1/engine/parallel_sampling.py,sha256=vz8KPGZAqvcG0bb37DqOCalw892efpb6apAXeeqE9h0,5481
-vllm/v1/engine/utils.py,sha256=QnqRiiHXdLk6T-0k3jZjRotKmSSCNXo8W2T98tkrjUI,42322
+vllm/v1/engine/utils.py,sha256=uc8uVI0inY4mxNHogC6yCkmgHX3X2_eeRdEXyqTAFGI,43327
vllm/v1/executor/__init__.py,sha256=VwU0pRwKjvcCcXII_7N2FIYT2oNj1fquLjvUS-qqE2U,227
vllm/v1/executor/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/executor/__pycache__/abstract.cpython-312.pyc,,
@@ -3002,13 +3062,13 @@ vllm/v1/executor/__pycache__/ray_distributed_executor.cpython-312.pyc,,
vllm/v1/executor/__pycache__/ray_executor.cpython-312.pyc,,
vllm/v1/executor/__pycache__/ray_utils.cpython-312.pyc,,
vllm/v1/executor/__pycache__/uniproc_executor.cpython-312.pyc,,
-vllm/v1/executor/abstract.py,sha256=wugXhVnGaxPbqqNeBXnoLrDvenn9O28tsw55ZAQR-xU,13543
-vllm/v1/executor/multiproc_executor.py,sha256=jnhZ3EmZqTP48gi8gFgIGp6TCOeQF5JksK-pjtHPywY,36296
+vllm/v1/executor/abstract.py,sha256=wVly5AWYsLBqaVmW5kBgUKTr8d8pxAXsyPvk9fIwVlA,13984
+vllm/v1/executor/multiproc_executor.py,sha256=_mctcYpZNrK9D847MN-u8yS402BjBMJyCgi_cUHFHtg,37701
vllm/v1/executor/ray_distributed_executor.py,sha256=tl8FrTK1jqQk2SKF-Rt6RtXacGD6ND-hXIjTrquAbkc,289
-vllm/v1/executor/ray_executor.py,sha256=ZTQBTnJWWC2v8-WLNmmX2jInSo7vnpd3oq9Zvzr5dSs,25988
-vllm/v1/executor/ray_utils.py,sha256=BMK5PvSLYPVGGGj4BUWN9bTj70uGe6ZoXDBz7w19uc0,19384
-vllm/v1/executor/uniproc_executor.py,sha256=ZOJR_Xq5V581VWJjkvPQFwJMHuMqNDL7k_O7tzqNvQo,7244
-vllm/v1/kv_cache_interface.py,sha256=SAnPe2VhcWNmJLgrY_EABO9Atn4uhsqbhKftTyuN6hc,17920
+vllm/v1/executor/ray_executor.py,sha256=FfbQAD4_jA6yCyfhSihOcJyssdIJhR3oIL_JjWEBrhs,26095
+vllm/v1/executor/ray_utils.py,sha256=uQ0LTGxcW3xMTx3DvgI7xCjO6230GsMXCJfi65IN0T8,21725
+vllm/v1/executor/uniproc_executor.py,sha256=J29IY8nFtbJIpnpvgBKofNx-cTiY8H6MUvgPwURqvfQ,6930
+vllm/v1/kv_cache_interface.py,sha256=qrp2bwc4nZGnJAepkfF4Ab5LYFxFkBUeiYWst61q6pE,21143
vllm/v1/kv_offload/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/kv_offload/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/kv_offload/__pycache__/abstract.cpython-312.pyc,,
@@ -3035,7 +3095,7 @@ vllm/v1/kv_offload/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
vllm/v1/kv_offload/worker/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/kv_offload/worker/__pycache__/cpu_gpu.cpython-312.pyc,,
vllm/v1/kv_offload/worker/__pycache__/worker.cpython-312.pyc,,
-vllm/v1/kv_offload/worker/cpu_gpu.py,sha256=MjoWZ6i1prPRYxcFLvasoZgXLEw_NmPaZ79ylgedudU,12260
+vllm/v1/kv_offload/worker/cpu_gpu.py,sha256=INfbItgRU9Jr0Hven2rcLbjKBZQ7GDp_T6ejCVdsdu4,12515
vllm/v1/kv_offload/worker/worker.py,sha256=fbTsSg9KkNsFLDyq8NfO64M4pa9A0iGbtpsBH9m-B4M,5275
vllm/v1/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/metrics/__pycache__/__init__.cpython-312.pyc,,
@@ -3051,25 +3111,25 @@ vllm/v1/metrics/prometheus.py,sha256=dhI7DQbT7X3gFlcQOXmoa7PY5JPyw3fvFm0azVEP-ok
vllm/v1/metrics/ray_wrappers.py,sha256=dOK01HU7uWdjDIjJ14Y98tnFXp81YfFA9-4COAig-SQ,6614
vllm/v1/metrics/reader.py,sha256=drwtEKJn4gOnaCFvlaYEjibmQx7cQHIsDiQii4B9Uk4,8694
vllm/v1/metrics/stats.py,sha256=8KdlA4K8y2-sfWQMDWfsGZIIfjccXhEiVToYJj047Gc,17808
-vllm/v1/outputs.py,sha256=CGVW4PTP6aUnyEwD-EbFrDsDTc-XhOZrIk1h-Dnw-tg,8705
+vllm/v1/outputs.py,sha256=6B8Mapc9GXmNgQEFmuv1QiXU06tNj1Vp1ldsE-Cwyfo,10547
vllm/v1/pool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/pool/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/pool/__pycache__/metadata.cpython-312.pyc,,
vllm/v1/pool/metadata.py,sha256=_ztjF86T0kNUDHkBtRyoXJPeM02qt0cCmQqv76B9lUE,4106
-vllm/v1/request.py,sha256=hqCuC8XgYG9b8JqPs7um07-Mv_ZXY3upv61f2UAR184,12414
+vllm/v1/request.py,sha256=9fdnsu3d_QLkn5xUj_rfUeW6b41k8NbaFnChdqZyYz8,12516
vllm/v1/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/sample/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/sample/__pycache__/metadata.cpython-312.pyc,,
vllm/v1/sample/__pycache__/rejection_sampler.cpython-312.pyc,,
vllm/v1/sample/__pycache__/sampler.cpython-312.pyc,,
-vllm/v1/sample/logits_processor/__init__.py,sha256=Hs6S01p-p_Xo-LcG5ihCwwGjqypt_8d_TEFnNLNfbIU,12047
+vllm/v1/sample/logits_processor/__init__.py,sha256=zxtm6oiHsp642GXAylVFjVjRSf4xJJrl8sAHoltyPT0,12085
vllm/v1/sample/logits_processor/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/sample/logits_processor/__pycache__/builtin.cpython-312.pyc,,
vllm/v1/sample/logits_processor/__pycache__/interface.cpython-312.pyc,,
vllm/v1/sample/logits_processor/__pycache__/state.cpython-312.pyc,,
-vllm/v1/sample/logits_processor/builtin.py,sha256=ejN7QaSXMarG07G5rVPwCZ-MbYua3ji6JbnjErnz1x8,10323
+vllm/v1/sample/logits_processor/builtin.py,sha256=N-Bq-Nk_Nzc2xgrx7LdPLJpovdmfliMmyc0j5_YfGPM,12440
vllm/v1/sample/logits_processor/interface.py,sha256=t2hnDF_0VXH9D0ZzZ1v3KAteJslPs6fZyBe6XxA6gIY,3205
-vllm/v1/sample/logits_processor/state.py,sha256=Fk_xGySqz7y2MiAi8hL0Q95DNXWelyu5E4XB-pD0d4E,5666
+vllm/v1/sample/logits_processor/state.py,sha256=qWQR8UNixF-DnhgdyzrAHVuICBK8yzVS_BeGOyGfRKw,5676
vllm/v1/sample/metadata.py,sha256=jLeOuEJuWBJhgcy9TBL2gfjFj2DaJLWk-zPXARUEbsI,1142
vllm/v1/sample/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/sample/ops/__pycache__/__init__.cpython-312.pyc,,
@@ -3083,27 +3143,31 @@ vllm/v1/sample/ops/logprobs.py,sha256=krdqPu-zZQChtxbssdvIQtm49l4EckUTljJNVSxZfj
vllm/v1/sample/ops/penalties.py,sha256=OOrkMB8o3sFEbYjnCq8fvTbT0EEdL7Zn2U1Euf44CjM,1929
vllm/v1/sample/ops/topk_topp_sampler.py,sha256=dfxdo8oMZWuCxd_RYyz3bXAVN21LAIoSO6RsTW42-PU,15403
vllm/v1/sample/ops/topk_topp_triton.py,sha256=Fp-qC5xPzkFodAWqKevHKQL2XxGhothV0okhin6CRAo,49608
-vllm/v1/sample/rejection_sampler.py,sha256=EFjwscvMpQKlUjKYl4oB2tYjvfPkYKMczi5rK4edzSQ,31448
+vllm/v1/sample/rejection_sampler.py,sha256=Mbe4cf4PxqXHzDVSvyexW81_tzMJ1LI5S5UMDeWmqq0,31742
vllm/v1/sample/sampler.py,sha256=uTYAiQ2r3lQxCjmzTPKpYXP7tmIRudSnQ4bBq57krl4,12583
vllm/v1/serial_utils.py,sha256=kDvyyWcvgeomZ5i8K4Vvap8wCbn8ZjiQyJFFv-HoR5k,19753
vllm/v1/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/spec_decode/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/draft_model.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/eagle.cpython-312.pyc,,
+vllm/v1/spec_decode/__pycache__/extract_hidden_states.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/medusa.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/metadata.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/metrics.cpython-312.pyc,,
+vllm/v1/spec_decode/__pycache__/multi_layer_eagle.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/ngram_proposer.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/suffix_decoding.cpython-312.pyc,,
vllm/v1/spec_decode/__pycache__/utils.cpython-312.pyc,,
vllm/v1/spec_decode/draft_model.py,sha256=mcv4VYfi9qcg_IUIwpYV82NXPUoJqHg_kDq9byFUSo8,2904
-vllm/v1/spec_decode/eagle.py,sha256=BnMaN9L6Ablyl38UCGnLMgNpTw-VX_6aRwmHnEcnK2Q,78574
+vllm/v1/spec_decode/eagle.py,sha256=K1yxTolHg1FqE1JxRl_KTKzFSw_Mw0L0JFmh_Ecuaqk,84237
+vllm/v1/spec_decode/extract_hidden_states.py,sha256=U2OVdz1_cTgAKm-cMRzpMuGm2tcLeUAcJU8l4NDc-ag,16545
vllm/v1/spec_decode/medusa.py,sha256=8MlI6DXXZLx6dZbGeu-PkKUuWmRUW0eg9TnCH6tk_V4,2694
-vllm/v1/spec_decode/metadata.py,sha256=pmlkh7PIgjwRn2Ad6BUENBve_v6Vuv5nPB9XCrICjag,2353
+vllm/v1/spec_decode/metadata.py,sha256=ekCirq6LvkOCh39r4pAfJ-3MTp3FOjYI8fcjypnKJoI,3771
vllm/v1/spec_decode/metrics.py,sha256=KZ4k_OZ5hyJFc9EciPxjEXJCTDWG6BuT9369H0VANxI,7867
+vllm/v1/spec_decode/multi_layer_eagle.py,sha256=0Tn6vHKEzHfibuiMjqheF8_KXIZeXGNRkFNa8h_conY,16506
vllm/v1/spec_decode/ngram_proposer.py,sha256=TTS_3Z_72a1U5pgTTJYorYgbk8cFGOrC1I8OUB-cZ2M,10935
vllm/v1/spec_decode/suffix_decoding.py,sha256=cbt6zu0gqI8Fe7mjNqea2vKy2dQWLStQs3yIvj5i2Pw,4442
-vllm/v1/spec_decode/utils.py,sha256=Bn6fU9Jzg6P2mNP-SjRCAfQs-Iw0lSqlReAgUW_puqA,15121
+vllm/v1/spec_decode/utils.py,sha256=iTaZanw9Azukt0U2mZYTxfcL8xLAH3ljY7JRrdcepBA,15624
vllm/v1/structured_output/__init__.py,sha256=RnhQEalmOswNbXZGTVF-nvyzl8U31ppO6mGFin8cU4Q,14970
vllm/v1/structured_output/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/structured_output/__pycache__/backend_guidance.cpython-312.pyc,,
@@ -3146,9 +3210,9 @@ vllm/v1/worker/__pycache__/xpu_model_runner.cpython-312.pyc,,
vllm/v1/worker/__pycache__/xpu_worker.cpython-312.pyc,,
vllm/v1/worker/block_table.py,sha256=G2urbvJ3fJQXzTimkvkxokCG5DVVL2KDQ76w4TuQ5uU,13264
vllm/v1/worker/cp_utils.py,sha256=BOVYrUTA2bv1YpaoRRTQPiy0oAERoWldGwFFkitUCcg,2389
-vllm/v1/worker/cpu_model_runner.py,sha256=ALLO8HBcqsntz68X8w92AvjSUmE49fJfpamsrjq75Qc,4316
-vllm/v1/worker/cpu_worker.py,sha256=nasaeTVYKSwDOa4cds-kcnHfmK4FP2nhmCA0eANJsrc,9015
-vllm/v1/worker/dp_utils.py,sha256=HF6MZ5bGffPbVV7rsyfasF18J3yqkj6PSp27NBMo2Eo,8851
+vllm/v1/worker/cpu_model_runner.py,sha256=ymDHFo4eiV06SIrUVvWpWfbgb1S9YxugtyukIPxIvr0,4531
+vllm/v1/worker/cpu_worker.py,sha256=dhxpYHOor_HFRckX6K_MaogHBHypOwOnrX7IKjZGq1k,9066
+vllm/v1/worker/dp_utils.py,sha256=wUPKqk9kn4OV14Pf3rhr9yLv0VTNoTzz5wRBxW7qPyU,8688
vllm/v1/worker/ec_connector_model_runner_mixin.py,sha256=j9YWGXs4tuxFrDJ-SLOcdPbTJ6623cI5ptbDeQCpzTs,2988
vllm/v1/worker/gpu/README.md,sha256=1H_fnHpMus5c95vNXWX5-Ryx3UkQiRwY_b5WDd0D8Ww,181
vllm/v1/worker/gpu/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -3167,14 +3231,15 @@ vllm/v1/worker/gpu/__pycache__/model_runner.cpython-312.pyc,,
vllm/v1/worker/gpu/__pycache__/pp_utils.cpython-312.pyc,,
vllm/v1/worker/gpu/__pycache__/states.cpython-312.pyc,,
vllm/v1/worker/gpu/__pycache__/structured_outputs.cpython-312.pyc,,
-vllm/v1/worker/gpu/async_utils.py,sha256=EJlTaJg44ySvT2WFxfIgQk71K1ikYoVLOXT3MnJAGas,3533
+vllm/v1/worker/gpu/__pycache__/warmup.cpython-312.pyc,,
+vllm/v1/worker/gpu/async_utils.py,sha256=ulLo02hHgh5QxMd_nCGN9nXBfRiAcnyJpGjlpL86mGg,4923
vllm/v1/worker/gpu/attn_utils.py,sha256=AZ_u3enlzaIsQT0y5r2A4-aEr7QeQnEy1ZscR28zmy0,8624
-vllm/v1/worker/gpu/block_table.py,sha256=9r3ddR4ikKo6BaXYRHWmh0-QYWYtZM5NE2pFYWzkW6Q,9709
+vllm/v1/worker/gpu/block_table.py,sha256=1o5qeyq8neXRt_yJf8WIWlvmoDRP-hK8tEWaehCTabQ,10454
vllm/v1/worker/gpu/buffer_utils.py,sha256=rECYHbR75Ls1zLIRa4hJfRokKfYaR5tuDgFpy9UHZYc,7209
vllm/v1/worker/gpu/cp_utils.py,sha256=ffa-o_Ujw2M5D4ceuTBqSBeFuLD0lbDHDqhASxvuJ7E,1700
-vllm/v1/worker/gpu/cudagraph_utils.py,sha256=0Q5OHMcWm5FIwVTDp7EjoBl5QlGirZknUcmzlv0BGJM,17605
+vllm/v1/worker/gpu/cudagraph_utils.py,sha256=_bzWlD267-JvuMOkFKuHgVYiBkyVHU28z-wbMenj-8c,16053
vllm/v1/worker/gpu/dp_utils.py,sha256=ZgZkMHZgAZGBfpJE-ZDQEOgEENRh583xt3Xq_LQrOdE,2936
-vllm/v1/worker/gpu/input_batch.py,sha256=4NaX9YxOlGjG2WjwuzB6XG4xDQjS2QKpqVGthxsDrzk,17462
+vllm/v1/worker/gpu/input_batch.py,sha256=Hj6nmlXwoD8zKT66wHeuAs6sygVlu7xt7BR5YI6DJE0,17907
vllm/v1/worker/gpu/kv_connector.py,sha256=rEIQ0BaHqWdq0KNafnsBwwmwNo0l0ljA2MHKLq3eTb0,4709
vllm/v1/worker/gpu/lora_utils.py,sha256=XF82PxJNFA4AR6BUk2kyqmgKHdfF9o6tkLibEELbMM0,1563
vllm/v1/worker/gpu/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -3183,11 +3248,23 @@ vllm/v1/worker/gpu/metrics/__pycache__/logits.cpython-312.pyc,,
vllm/v1/worker/gpu/metrics/logits.py,sha256=6gWNdRCgd6Vfisfkw9WWUINGJE6QjLNyhl1xWfYu3Ik,1205
vllm/v1/worker/gpu/mm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/worker/gpu/mm/__pycache__/__init__.cpython-312.pyc,,
+vllm/v1/worker/gpu/mm/__pycache__/encoder_cache.cpython-312.pyc,,
vllm/v1/worker/gpu/mm/__pycache__/encoder_runner.cpython-312.pyc,,
vllm/v1/worker/gpu/mm/__pycache__/mrope_utils.cpython-312.pyc,,
-vllm/v1/worker/gpu/mm/encoder_runner.py,sha256=mBrs8T1_A2niX2Qo6ktDYJoTPT7Karv8bDV8TBPBwLA,7244
+vllm/v1/worker/gpu/mm/encoder_cache.py,sha256=B5nfNNjgosWQPKvWbJXcacBs4Yxg8ZGL8fdKlorhCIk,1326
+vllm/v1/worker/gpu/mm/encoder_runner.py,sha256=_J3xYnU6qh4S7QbB_KMgLVCQCz_Ubxwhl6mhnD7hk4g,6159
vllm/v1/worker/gpu/mm/mrope_utils.py,sha256=NO9QS1rPfqAAMOI8ppNp4aFJrsmhagge8ZAZpTYS-RM,4879
-vllm/v1/worker/gpu/model_runner.py,sha256=5ggLg-eUroMW5c1fz4GOnx8cfe3_b6blXjfZTmoKB_g,47058
+vllm/v1/worker/gpu/model_runner.py,sha256=WreH6o1QR0TjJLBslGTxpn_H8huNpCf3M6EiNsRSrcs,48326
+vllm/v1/worker/gpu/model_states/__init__.py,sha256=WkIIo16IxVTj5oOB_opmueNvTS_XHhf0WJ1g6gOSu04,530
+vllm/v1/worker/gpu/model_states/__pycache__/__init__.cpython-312.pyc,,
+vllm/v1/worker/gpu/model_states/__pycache__/default.cpython-312.pyc,,
+vllm/v1/worker/gpu/model_states/__pycache__/interface.cpython-312.pyc,,
+vllm/v1/worker/gpu/model_states/default.py,sha256=rKG1UEuAEylulebXXZvCp8aVwm57bmrZOVHaFZ-NVCo,6345
+vllm/v1/worker/gpu/model_states/interface.py,sha256=xZ5GQBbbWr6DnpHhz8X-a6xm583BSbf8nF94XtdZBrg,1971
+vllm/v1/worker/gpu/pool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+vllm/v1/worker/gpu/pool/__pycache__/__init__.cpython-312.pyc,,
+vllm/v1/worker/gpu/pool/__pycache__/pooling_runner.cpython-312.pyc,,
+vllm/v1/worker/gpu/pool/pooling_runner.py,sha256=_RNihQqlvQnR1SfH46vZunkIAxFfAB-co40frFR5nQ0,1682
vllm/v1/worker/gpu/pp_utils.py,sha256=4iDJrkigTxM3IcyoEAnZzHlT50wBVGKQwKqHu7vhBP8,1371
vllm/v1/worker/gpu/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
vllm/v1/worker/gpu/sample/__pycache__/__init__.cpython-312.pyc,,
@@ -3201,16 +3278,16 @@ vllm/v1/worker/gpu/sample/__pycache__/penalties.cpython-312.pyc,,
vllm/v1/worker/gpu/sample/__pycache__/prompt_logprob.cpython-312.pyc,,
vllm/v1/worker/gpu/sample/__pycache__/sampler.cpython-312.pyc,,
vllm/v1/worker/gpu/sample/__pycache__/states.cpython-312.pyc,,
-vllm/v1/worker/gpu/sample/bad_words.py,sha256=rU2u_9OnLDLdTsJX2pOqy5xaAihPew7YUlG5AqfW-gE,6483
-vllm/v1/worker/gpu/sample/gumbel.py,sha256=_sCRSx0VS0H743wu2ZYMvAMsL4XsJj0iDjvGzFFQWbo,4372
-vllm/v1/worker/gpu/sample/logit_bias.py,sha256=c2meBhCqk2KplYd5r1_HlafQjqTEodM04Rfhx06qb04,9438
+vllm/v1/worker/gpu/sample/bad_words.py,sha256=6ZWcrjNSk_DZ5iosRDjrBSRjC3w_VAcis6r3SZrJ55Q,6489
+vllm/v1/worker/gpu/sample/gumbel.py,sha256=saZxdQQBUBuAff1GIE_aozusW_hxUb1T22jJLCt8xXI,4458
+vllm/v1/worker/gpu/sample/logit_bias.py,sha256=BbkqDRpL4pGiYtcH-hRIK564p7QVTX5QPRh1A8Lt4EY,9496
vllm/v1/worker/gpu/sample/logprob.py,sha256=mhQxQN8aKznHd4hOTtBcwNY8dIeasQkpksdggp1pX-8,3983
-vllm/v1/worker/gpu/sample/min_p.py,sha256=lQPdt01iJjWeIvkC5Tan-3oQM1kBOMaC-4IzhpjOjwQ,1661
+vllm/v1/worker/gpu/sample/min_p.py,sha256=iEcKMo0tc9Z7UP2i_J-tf59tmDPyo5qTzbpPp9Y2Ij0,1761
vllm/v1/worker/gpu/sample/output.py,sha256=Uzj6Lujb5LEswYwYUsTmJ9HkDwJ-QoV1s1uyH4VZ0CE,349
-vllm/v1/worker/gpu/sample/penalties.py,sha256=w_WNkWAH1wBLZ5omhaE44uex-C8mL3AnO-hJ5r4pOso,10636
+vllm/v1/worker/gpu/sample/penalties.py,sha256=7HUBXmO-CHSLVpIpn-OUG-D64RIqW8XURZF_grCjoJs,10757
vllm/v1/worker/gpu/sample/prompt_logprob.py,sha256=bAdWFBumN1rPxQbUi_tcYZCKVN-3l6dKMtdmwzZNwms,8131
-vllm/v1/worker/gpu/sample/sampler.py,sha256=wwaw91gaFN1glvTNEo6Pg56iWGZvW7AnZIRwcg-qiuY,5838
-vllm/v1/worker/gpu/sample/states.py,sha256=hkXodXx-Vqe6Vj9m7XHJSurskz9TuR_7OzORx0ZMJ_Q,3886
+vllm/v1/worker/gpu/sample/sampler.py,sha256=EBx21NwKubvWQu_X-5qov_Jmn2vgUeoUPSNMA0cuAKM,5972
+vllm/v1/worker/gpu/sample/states.py,sha256=OW-3nHrPXB3Ksa_H_UMA2bJAe-_IdNHoiuxnwtTQ5qk,3949
vllm/v1/worker/gpu/spec_decode/__init__.py,sha256=MwFTIriVJZZqVhVWbas6czZW_eXW5TwSoCg1Ro-ETqQ,584
vllm/v1/worker/gpu/spec_decode/__pycache__/__init__.cpython-312.pyc,,
vllm/v1/worker/gpu/spec_decode/__pycache__/rejection_sample.cpython-312.pyc,,
@@ -3221,28 +3298,33 @@ vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/cudagraph.cpython-312.pyc,,
vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/eagle3_utils.cpython-312.pyc,,
vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/speculator.cpython-312.pyc,,
vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/utils.cpython-312.pyc,,
-vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py,sha256=79E6LvNpC2GK13mCX-uKGhbvDokyR1DcTFyxnsns8AA,6807
+vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py,sha256=19wpXyZ5sT6MwdBNnDJYgumwpYuuzEacnO70gWsIHXM,7681
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py,sha256=h2v6qPgEyrmyA_VYBrEXXyAMiPWe7qmNsHMaB2i69gc,1726
-vllm/v1/worker/gpu/spec_decode/eagle/speculator.py,sha256=LgYzo-ddGBlwgMrG_Lrv5iN0AJ1rG8wUiRrKHbmu83k,21159
+vllm/v1/worker/gpu/spec_decode/eagle/speculator.py,sha256=XJJn3xQiq-NVBY4LNFVinMWPrdNsEZy-mOuFCIoe1OM,22298
vllm/v1/worker/gpu/spec_decode/eagle/utils.py,sha256=1gmAeKdUTXhgnEf1NS33GEaNyIuBx_KQTEKZbPA8P5U,2170
vllm/v1/worker/gpu/spec_decode/rejection_sample.py,sha256=316CS-i86IEDzWt6y6bEhwQkb4VviFQfqaXDTAUuCPg,2218
vllm/v1/worker/gpu/spec_decode/utils.py,sha256=w82Z2kbCmr05EjZHIPMaDWhwMGbSN9qLFkw0vntVpOA,1915
vllm/v1/worker/gpu/states.py,sha256=ZaKsOUHd9kd7j_oC9RHkcQnRXJCLJ3dPyAdND0DQe3E,4806
vllm/v1/worker/gpu/structured_outputs.py,sha256=qykZ-Y9PFFk2cY8JZgz1l9Rhg_zXT5XHKsGyxw5cIcg,4108
-vllm/v1/worker/gpu_input_batch.py,sha256=Eb9Biwl5MM2UKzTolIaNY3DLMv1U4OpIAN6FDkJaqYM,43321
-vllm/v1/worker/gpu_model_runner.py,sha256=2wfvqi9aOVmYPeRk0p582CxdNNBNlYwEwqViq-_0LVA,277994
+vllm/v1/worker/gpu/warmup.py,sha256=mK46LlfKN0f9dVt8jPUfjCXLuZ0nd4p_1Uyu0uy-5AE,4091
+vllm/v1/worker/gpu_input_batch.py,sha256=hpOGLUJEaFj4FKhD2lMka9N5aobDZ1muy3PPT8fucFE,46692
+vllm/v1/worker/gpu_model_runner.py,sha256=xgmFWuLHIpF1pLkF8nrz5e8l7iM7AjXUK-dBYdtO_1Q,295633
vllm/v1/worker/gpu_ubatch_wrapper.py,sha256=M_sKqzRgCPDEGTP8Fpjahm_9k-gSerCHRnPUgsXXfaA,19455
-vllm/v1/worker/gpu_worker.py,sha256=Jo_zldnw2w1dlJiBcZUQa5IVgo9WbjJsl33sigOvIJ4,47469
+vllm/v1/worker/gpu_worker.py,sha256=O4Hdp4J2rS8Rw_wI1eMyTP9Um3HNPL-vlpr-yx1xIXs,40459
vllm/v1/worker/kv_connector_model_runner_mixin.py,sha256=JZrNatrj2mu_1H5BMHmV2PC3_npX3MKd68RUDFWp_Co,10927
vllm/v1/worker/lora_model_runner_mixin.py,sha256=SbtsW2hCNf1T6yF2al3rwXYXZ4gBwWVW8JdiUUqSFMY,11014
-vllm/v1/worker/mamba_utils.py,sha256=4tM43Oh8BJnHdx7Zuo_yF3kcqFP55gaUt8YdJXGXhbE,10073
+vllm/v1/worker/mamba_utils.py,sha256=BA1n7csi212an-tECS1Ui8MNkKV1CVqTK8prqQF6uNo,10489
vllm/v1/worker/tpu_input_batch.py,sha256=Cc9z4nxam3SR1QSIPEkPYsAJ1e2ins8oLL5FcppSAyU,24177
vllm/v1/worker/ubatch_utils.py,sha256=LKzttglKXQ-8k-aecrgp3wt1gYuQsCwOyiw0rTshUeQ,8550
vllm/v1/worker/ubatching.py,sha256=QDkSQcVk_rXxbHeJiuauFS7W5xpGguKkBjh3hdjeAtc,8394
-vllm/v1/worker/utils.py,sha256=oYzLhX7_-dSG1qq7EInQC8PkYlD7SpCFB5HZzOSxg1g,9446
-vllm/v1/worker/worker_base.py,sha256=sHdv9d9kNFNcZnzgPIwGzt8YJD3VKJ9JZJA2WDM_8Zw,14129
-vllm/v1/worker/workspace.py,sha256=YgWweZ2HOEV75SnSlTpIlUQonFSktC_EjGzOvXmTbH4,9102
+vllm/v1/worker/utils.py,sha256=HAUB-4Ba9FE8eJBmmimrikc8DIu9W9lBT3ipUch2MNU,23346
+vllm/v1/worker/worker_base.py,sha256=qF5yIRnHw1lm5jCb4OKQ3gE0h2TWxCUaPwVIOP-6qvo,13986
+vllm/v1/worker/workspace.py,sha256=ZDxKVepVtm9gdJK_dzHyxpkg2iz1tDA5mVkqf4zqqsE,10070
vllm/v1/worker/xpu_model_runner.py,sha256=Vuw43jij32tueGHg-2Hxiw59ZwxKsWdn2y2WipBudK4,1534
vllm/v1/worker/xpu_worker.py,sha256=XLL1AXvgP8k_PTXpZRMFYn2el156VZ068Ktx9i4ML4Y,4105
-vllm/version.py,sha256=nYdjue0uc9o6X6OSVlXN3DROsCofHt0xAlw6EDm5vOA,57
+vllm/version.py,sha256=qlCNmDctpuxf10HWjBqD2XAPyLvtWXBrLFVw274Zk7w,54
vllm/vllm_flash_attn/.gitkeep,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+vllm/vllm_flash_attn/__init__.py,sha256=qVPVedLMtxi42E_M9z7_Z8nNtQPhunOQ6RIaKzTcjq4,689
+vllm/vllm_flash_attn/__pycache__/__init__.cpython-312.pyc,,
+vllm/vllm_flash_attn/__pycache__/flash_attn_interface.cpython-312.pyc,,
+vllm/vllm_flash_attn/flash_attn_interface.py,sha256=VclHCNwVRywtvnHOVklDhkBcm1fdnIGLvQlz8h4FWCs,20201
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED b/vllm-0.17.0+corex.20260420090923.dist-info/REQUESTED
similarity index 100%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED
rename to vllm-0.17.0+corex.20260420090923.dist-info/REQUESTED
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL b/vllm-0.17.0+corex.20260420090923.dist-info/WHEEL
similarity index 65%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL
rename to vllm-0.17.0+corex.20260420090923.dist-info/WHEEL
index 1ef5583..0885d05 100644
--- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL
+++ b/vllm-0.17.0+corex.20260420090923.dist-info/WHEEL
@@ -1,5 +1,5 @@
Wheel-Version: 1.0
-Generator: setuptools (82.0.0)
+Generator: setuptools (80.10.2)
Root-Is-Purelib: true
Tag: py3-none-any
diff --git a/vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json b/vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json
new file mode 100644
index 0000000..2795c48
--- /dev/null
+++ b/vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json
@@ -0,0 +1 @@
+{"archive_info": {"hash": "sha256=844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740", "hashes": {"sha256": "844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740"}}, "url": "file:///home/poweruser/zrl/code/vllm_hub/vllm/build_pip/vllm-0.17.0%2Bcorex.20260420090923-py3-none-any.whl"}
\ No newline at end of file
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt b/vllm-0.17.0+corex.20260420090923.dist-info/entry_points.txt
similarity index 100%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt
rename to vllm-0.17.0+corex.20260420090923.dist-info/entry_points.txt
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE b/vllm-0.17.0+corex.20260420090923.dist-info/licenses/LICENSE
similarity index 100%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE
rename to vllm-0.17.0+corex.20260420090923.dist-info/licenses/LICENSE
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt b/vllm-0.17.0+corex.20260420090923.dist-info/top_level.txt
similarity index 100%
rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt
rename to vllm-0.17.0+corex.20260420090923.dist-info/top_level.txt
diff --git a/vllm/.gitignore b/vllm/.gitignore
deleted file mode 100644
index 3492fbe..0000000
--- a/vllm/.gitignore
+++ /dev/null
@@ -1,244 +0,0 @@
-# version file generated by setuptools-scm
-/vllm/_version.py
-
-# vllm-flash-attn built from source
-vllm/vllm_flash_attn/*
-!vllm/vllm_flash_attn/__init__.py
-!vllm/vllm_flash_attn/flash_attn_interface.py
-
-# OpenAI triton kernels copied from source
-vllm/third_party/triton_kernels/*
-
-# FlashMLA interface copied from source
-vllm/third_party/flashmla/flash_mla_interface.py
-
-# triton jit
-.triton
-
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-cmake-build-*/
-CMakeUserPresets.json
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-/.deps/
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# generated files
-**/generated/**
-
-# uv
-uv.lock
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/#use-with-ide
-.pdm.toml
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-docs/argparse
-docs/examples/*
-!docs/examples/README.md
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-.idea/
-
-# VSCode
-.vscode/
-
-# Claude
-.claude/
-
-# Codex
-.codex/
-
-# Cursor
-.cursor/
-
-# DS Store
-.DS_Store
-
-# Results
-*.csv
-
-# Python pickle files
-*.pkl
-
-# Sphinx documentation
-_build/
-
-# vim swap files
-*.swo
-*.swp
-
-# hip files generated by PyTorch
-*.hip
-*_hip*
-hip_compat.h
-
-# Benchmark dataset
-benchmarks/**/*.json
-
-# Linting
-actionlint
-shellcheck*/
-
-# Ignore moe/marlin_moe gen code
-csrc/moe/marlin_moe_wna16/kernel_*
-
-# Ignore ep_kernels_workspace folder
-ep_kernels_workspace/
-
-# Allow tracked library source folders under submodules (e.g., benchmarks/lib)
-!vllm/benchmarks/lib/
-
-# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto)
-vllm/grpc/vllm_engine_pb2.py
-vllm/grpc/vllm_engine_pb2_grpc.py
-vllm/grpc/vllm_engine_pb2.pyi
-
-# Ignore generated cpu headers
-csrc/cpu/cpu_attn_dispatch_generated.h
-
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 19b2cdc..968d1a1 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -14,8 +14,6 @@ import typing
import vllm.env_override # noqa: F401
MODULE_ATTRS = {
- "bc_linter_skip": "._bc_linter:bc_linter_skip",
- "bc_linter_include": "._bc_linter:bc_linter_include",
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
@@ -62,8 +60,6 @@ if typing.TYPE_CHECKING:
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.executor.ray_utils import initialize_ray_cluster
-
- from ._bc_linter import bc_linter_include, bc_linter_skip
else:
def __getattr__(name: str) -> typing.Any:
@@ -79,8 +75,6 @@ else:
__all__ = [
"__version__",
- "bc_linter_skip",
- "bc_linter_include",
"__version_tuple__",
"LLM",
"ModelRegistry",
diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py
index 8ef34bf..c8366ec 100644
--- a/vllm/_aiter_ops.py
+++ b/vllm/_aiter_ops.py
@@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
+ hidden_pad: int = 0,
+ intermediate_pad: int = 0,
+ bias1: torch.Tensor | None = None,
+ bias2: torch.Tensor | None = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
@@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl(
a2_scale,
num_local_tokens=num_local_tokens,
dtype=output_dtype,
+ hidden_pad=hidden_pad,
+ intermediate_pad=intermediate_pad,
+ bias1=bias1,
+ bias2=bias2,
)
@@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake(
pass
+def _rocm_aiter_fused_topk_impl(
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ gate_up: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ from aiter.fused_moe import fused_topk
+
+ # fused_topk returns (topk_weights, topk_indices)
+ return fused_topk(x, router_logits, top_k, gate_up)
+
+
+def _rocm_aiter_fused_topk_fake(
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ gate_up: bool,
+) -> None:
+    # The real op returns tuple[torch.Tensor, torch.Tensor]; this fake impl intentionally returns nothing.
+ pass
+
+
# Cache whether aiter supports FP8 MLA parameters
_AITER_MLA_SUPPORTS_FP8: bool | None = None
@@ -994,6 +1024,70 @@ class rocm_aiter_ops:
cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS
cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM
+ @staticmethod
+ def get_aiter_activation_type(activation_str: str):
+ """
+ Given an activation type as a string, returns the corresponding aiter ActivationType enum.
+ Supported activation types: "no", "none", "silu", "gelu", "swiglu".
+        Returns None if aiter cannot be imported or the name is not recognized.
+
+ Args:
+ activation_str (str): Activation type as string.
+
+ Returns:
+ Aiter ActivationType enum value, or None if not found.
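+
+        Example (illustrative; requires aiter to be importable):
+            >>> act = rocm_aiter_ops.get_aiter_activation_type("silu")
+            >>> act is not None  # maps to ActivationType.Silu
+            True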
+ """
+        # Import lazily, since aiter may not always be available.
+ try:
+ from aiter import ActivationType
+ except ImportError:
+ return None
+
+ if not isinstance(activation_str, str):
+ return None
+
+ name = activation_str.strip().lower()
+ mapping = {
+ "none": ActivationType.No,
+ "no": ActivationType.No,
+ "silu": ActivationType.Silu,
+ "gelu": ActivationType.Gelu,
+ "swiglu": ActivationType.Swiglu,
+ }
+ return mapping.get(name)
+
+ @staticmethod
+ def get_aiter_quant_type(quant_type_str: str):
+ """
+ Given a quantization type as a string, returns the corresponding aiter QuantType enum.
+ Supported quantization types: "no", "per_tensor", "per_token", "per_1x32", "per_1x128", "per_128x128".
+        Returns None if aiter cannot be imported or the name is not recognized.
+
+ Args:
+ quant_type_str (str): Quantization type as string.
+
+ Returns:
+ Aiter QuantType enum value, or None if not found.
+ """
+ try:
+ from aiter import QuantType
+ except ImportError:
+ return None
+
+ if not isinstance(quant_type_str, str):
+ return None
+
+ name = quant_type_str.strip().lower()
+ mapping = {
+ "no": QuantType.No,
+ "per_tensor": QuantType.per_Tensor,
+ "per_token": QuantType.per_Token,
+ "per_1x32": QuantType.per_1x32,
+ "per_1x128": QuantType.per_1x128,
+ "per_128x128": QuantType.per_128x128,
+ }
+ return mapping.get(name)
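+
+    # Hedged usage sketch for the two mappers above (the calling code is
+    # illustrative, not part of this patch):
+    #   act = rocm_aiter_ops.get_aiter_activation_type("silu")
+    #   qt = rocm_aiter_ops.get_aiter_quant_type("per_token")
+    #   if act is None or qt is None:
+    #       ...  # aiter missing or name unknown: fall back to a non-AITER path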
+
@classmethod
@if_aiter_supported
def is_enabled(cls) -> bool:
@@ -1127,6 +1221,14 @@ class rocm_aiter_ops:
dispatch_key=current_platform.dispatch_key,
)
+ direct_register_custom_op(
+ op_name="rocm_aiter_fused_topk",
+ op_func=_rocm_aiter_fused_topk_impl,
+ mutates_args=[],
+ fake_impl=_rocm_aiter_fused_topk_fake,
+ dispatch_key=current_platform.dispatch_key,
+ )
+
direct_register_custom_op(
op_name="rocm_aiter_mla_decode_fwd",
op_func=_rocm_aiter_mla_decode_fwd_impl,
@@ -1360,6 +1462,10 @@ class rocm_aiter_ops:
a2_scale: torch.Tensor | None = None,
num_local_tokens: torch.Tensor | None = None,
output_dtype: torch.dtype | None = None,
+ hidden_pad: int = 0,
+ intermediate_pad: int = 0,
+ bias1: torch.Tensor | None = None,
+ bias2: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_fused_moe(
hidden_states,
@@ -1377,6 +1483,10 @@ class rocm_aiter_ops:
a2_scale,
num_local_tokens,
output_dtype,
+ hidden_pad,
+ intermediate_pad,
+ bias1,
+ bias2,
)
@staticmethod
@@ -1481,6 +1591,15 @@ class rocm_aiter_ops:
routed_scaling_factor,
)
+ @staticmethod
+ def fused_topk(
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ top_k: int,
+ gate_up: bool,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
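+        # Thin wrapper over the "rocm_aiter_fused_topk" custom op registered
+        # earlier in this class, so the call goes through the registered
+        # custom-op path.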
+ return torch.ops.vllm.rocm_aiter_fused_topk(x, router_logits, top_k, gate_up)
+
@staticmethod
def mla_decode_fwd(
q: torch.Tensor,
@@ -1701,6 +1820,47 @@ class rocm_aiter_ops:
return shuffle_weight(tensor, layout=layout)
+ @staticmethod
+ def shuffle_weight_a16w4(
+ tensor: "torch.Tensor",
+ nLane: int,
+ gate_up: bool,
+ ) -> "torch.Tensor":
+ """
+ Shuffles the weight tensor into (A16W4) layout for AITER kernels.
+
+ Args:
+ tensor: The input weight tensor to be shuffled.
+            nLane: Lane count used by the AITER A16W4 shuffle layout.
+            gate_up: Whether the weight is for w13 (True) or w2 (False).
+
+ Returns:
+ torch.Tensor: The shuffled tensor.
+ """
+ from aiter.ops.shuffle import shuffle_weight_a16w4
+
+ return shuffle_weight_a16w4(tensor, nLane, gate_up)
+
+ @staticmethod
+ def shuffle_scale_a16w4(
+ tensor: "torch.Tensor",
+ num_experts: int,
+ gate_up: bool,
+ ) -> "torch.Tensor":
+ """
+ Shuffles the scale tensor into (A16W4) layout for AITER kernels.
+
+ Args:
+ tensor: The input scale tensor to be shuffled.
+ num_experts: Number of experts, needed for reshaping logic.
+ gate_up: Whether the scale is for w13 (True) or w2 (False).
+
+ Returns:
+ torch.Tensor: The shuffled scale tensor.
+ """
+ from aiter.ops.shuffle import shuffle_scale_a16w4
+
+ return shuffle_scale_a16w4(tensor, num_experts, gate_up)
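+
+    # Hedged usage sketch (argument values are illustrative, not taken from this patch):
+    #   w13 = rocm_aiter_ops.shuffle_weight_a16w4(w13, nLane=16, gate_up=True)
+    #   w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(w13_scale, num_experts, gate_up=True)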
+
@staticmethod
def shuffle_weights(
*tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
diff --git a/vllm/_bc_linter.py b/vllm/_bc_linter.py
deleted file mode 100644
index 2929a8b..0000000
--- a/vllm/_bc_linter.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# vllm/_bc_linter.py
-from collections.abc import Callable
-from typing import Any, TypeVar, overload
-
-T = TypeVar("T")
-
-
-@overload
-def bc_linter_skip(obj: T) -> T: ...
-
-
-@overload
-def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ...
-
-
-def bc_linter_skip(obj: Any = None, *, reason: str | None = None):
- """
- No-op decorator to mark symbols/files for BC-linter suppression.
-
- Usage:
- @bc_linter_skip
- def legacy_api(...): ...
- """
-
- def _wrap(x: T) -> T:
- return x
-
- return _wrap if obj is None else obj
-
-
-@overload
-def bc_linter_include(obj: T) -> T: ...
-
-
-@overload
-def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ...
-
-
-def bc_linter_include(obj: Any = None, *, reason: str | None = None):
- """
- Usage:
- @bc_linter_include
- def public_api(...): ...
- """
-
- def _wrap(x: T) -> T:
- return x
-
- return _wrap if obj is None else obj
-
-
-__all__ = ["bc_linter_skip", "bc_linter_include"]
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 5f60a3d..1cbebfe 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Literal, Optional, List, Dict, Any
import torch
-
-import vllm.envs as envs
+import torch.nn.functional as F
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
@@ -14,13 +13,12 @@ from vllm.utils.flashinfer import (
)
from vllm.utils.math_utils import cdiv
-import torch.nn.functional as F
import ixformer.inference.functions as ops
-
-from ixformer.core import config
from ixformer.distributed import _distributed as cdist
-
-logger = init_logger(__name__)
+import vllm.envs as envs
+from ixformer.core import config
+import math
+_USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS
current_platform.import_kernels()
@@ -34,25 +32,12 @@ else:
except ImportError:
from torch.library import impl_abstract as register_fake
-
-def swiglustep_and_mul_torch(output, input, limit=7.0):
- b, n = input.shape
- d = n // 2
-
- gate = input[:, :d]
- up = input[:, d:]
-
- # 直接写入 output
- torch.mul(
- torch.clamp(F.silu(gate), max=limit),
- torch.clamp(up, -limit, limit),
- out=output
- )
-
-
+# activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
ops.silu_and_mul(x, out)
+
+
+def silu_and_mul_quant(input: torch.Tensor, out_dim: int,
+                       i8_output: torch.Tensor = None,
+                       output_scales: torch.Tensor = None,
+                       is_dynamic: bool = True):
+    return ops.silu_and_mul_quant(input, out_dim, i8_output, output_scales, is_dynamic)
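+
+# Assumed layout (not verified against the ixformer kernel): `input` holds the gate
+# and up halves concatenated on the last dim, i.e. input.shape[-1] == 2 * out_dim,
+# mirroring silu_and_mul above.
+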
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
ops.gelu_and_mul(x, out)
@@ -60,95 +45,31 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
ops.gelu_tanh_and_mul(x, out)
+
+def swigluoai_and_mul(out: torch.Tensor, x: torch.Tensor,
+ alpha: float = 1.702, limit: float = 7.0) -> None:
+ ops.swigluoai_and_mul(x, out, alpha, limit)
-
-def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> None:
- raise NotImplementedError("FIX soon")
-
-
+# Reference: https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
- out.copy_(F.gelu(x, approximate="tanh"))
+ x = 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
+ out.copy_(x)
return out
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
- out.copy_(F.gelu(x, approximate="tanh"))
+ x = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+ out.copy_(x)
return out
def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
- out.copy_(F.gelu(x, approximate="tanh"))
+    # in-place: out = x * sigmoid(1.702 * x)
+ out.copy_(x)
+ out.mul_(torch.sigmoid(x * 1.702))
return out
-def swigluoai_and_mul(
- out: torch.Tensor, x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
-) -> None:
- ops.swigluoai_and_mul(x, out, alpha, limit)
- # return
-
-
-def swigluoai_and_mul_torch(
- out: torch.Tensor, x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0
-):
- gate, up = x[..., ::2], x[..., 1::2]
- gate = gate.clamp(min=None, max=limit)
- up = up.clamp(min=-limit, max=limit)
- glu = gate * torch.sigmoid(gate * alpha)
- gated_output = (up + 1) * glu
- out.copy_(gated_output)
-
-
-def rms_norm_qk(
- output_q: torch.Tensor,
- output_k: torch.Tensor,
- input_q: torch.Tensor,
- input_k: torch.Tensor,
- weight_q: torch.Tensor,
- weight_k: torch.Tensor,
- epsilon: float,
-) -> None:
- torch.ops.ixf_ops.rms_norm_qk(
- output_q, output_k, input_q, input_k, weight_q, weight_k, epsilon
- )
-
-
-def advance_step_flashattn(
- num_seqs: int,
- num_queries: int,
- block_size: int,
- input_tokens: torch.Tensor,
- sampled_token_ids: torch.Tensor,
- input_positions: torch.Tensor,
- seq_lens: torch.Tensor,
- slot_mapping: torch.Tensor,
- block_tables: torch.Tensor,
-) -> None:
- """Advance a step on GPU for existing inputs for a multi-step runner"""
- return ops.advance_step_flashattn(
- num_seqs,
- num_queries,
- block_size,
- input_tokens,
- sampled_token_ids,
- input_positions,
- seq_lens,
- slot_mapping,
- block_tables,
- )
-
-
-def quant_kv(kv):
- amax_, _ = torch.max(torch.abs(kv), dim=-1, keepdim=True)
- f_scale = amax_.float() / 127.0
- scales = f_scale.view(kv.shape[:-1])
-
- # 量化
- kv = kv / f_scale
- kv = torch.clamp(torch.round(kv), -127, 127).to(torch.int8)
- return kv, scales
-
-
# page attention ops
def paged_attention_v1(
out: torch.Tensor,
@@ -193,7 +114,6 @@ def paged_attention_v1(
blocksparse_head_sliding_step,
)
-
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
@@ -298,9 +218,7 @@ def mla_decode_kvcache_cpu(
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
) -> None:
- torch.ops._C_cpu.mla_decode_kvcache(
- out, query, kv_cache, scale, block_tables, seq_lens
- )
+ torch.ops._C.mla_decode_kvcache(out, query, kv_cache, scale, block_tables, seq_lens)
# merge attn states ops
@@ -442,65 +360,54 @@ def rotary_embedding(
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
+    rotary_dim: int = -1,
) -> None:
- # torch.ops._C.rotary_embedding(
- # positions, query, key, head_size, cos_sin_cache, is_neox
- # )
- ops.vllm_rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
+ ops.vllm_rotary_embedding(positions, query, key, head_size,
+ cos_sin_cache, is_neox, rotary_dim=rotary_dim)
-
-def batched_rotary_embedding(
+def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
+ key: Optional[torch.Tensor], head_size: int,
+ cos_sin_cache: torch.Tensor, is_neox: bool,
+ rot_dim: int,
+ cos_sin_cache_offsets: torch.Tensor) -> None:
+ ops.vllm_batched_rotary_embedding(positions, query, key, head_size,
+ cos_sin_cache, is_neox, rot_dim,
+ cos_sin_cache_offsets)
+def m_rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
- key: torch.Tensor | None,
+ key: Optional[torch.Tensor],
head_size: int,
cos_sin_cache: torch.Tensor,
+ smrope_section: torch.Tensor,
is_neox: bool,
- rot_dim: int,
- cos_sin_cache_offsets: torch.Tensor,
) -> None:
- ops.vllm_batched_rotary_embedding(
- positions,
- query,
- key,
- head_size,
- cos_sin_cache,
- is_neox,
- rot_dim,
- cos_sin_cache_offsets,
- )
-
+ ops.vllm_m_rotary_embedding(positions, query, key, head_size,
+ cos_sin_cache, smrope_section, is_neox)
# layer norm ops
-def rms_norm(
- out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float
-) -> None:
- # torch.ops._C.rms_norm(out, input, weight, epsilon)
- input_contiguous = input.contiguous()
- ops.rms_norm(input_contiguous, weight, epsilon, out)
+def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
+ epsilon: float) -> None:
+ ops.rms_norm(input, weight, epsilon, out)
-# def fused_add_rms_norm(
-# input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float
-# ) -> None:
-# torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
-def fused_add_rms_norm(
- input: torch.Tensor,
- residual: torch.Tensor,
- weight: torch.Tensor,
+def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
+                       weight: torch.Tensor, epsilon: float,
+                       residual_alpha: Optional[float] = 1.0,
+                       ) -> tuple[torch.Tensor, torch.Tensor]:
+    output, residual_output = ops.residual_rms_norm(input, weight, epsilon,
+                                                    residual_alpha, residual)
+    return output, residual_output
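+
+# Minimal reference sketch for sanity-checking fused_add_rms_norm on small tensors
+# (assumption: ops.residual_rms_norm computes RMSNorm(input + residual_alpha * residual);
+# verify against the ixformer kernel before relying on this):
+#   added = input + residual_alpha * residual
+#   ref_out = added * torch.rsqrt(added.pow(2).mean(-1, keepdim=True) + epsilon) * weight
+#   ref_residual = added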
+
+def rms_norm_qk(
+ output_q: torch.Tensor,
+ output_k: torch.Tensor,
+ input_q: torch.Tensor,
+ input_k: torch.Tensor,
+ weight_q: torch.Tensor,
+ weight_k: torch.Tensor,
epsilon: float,
- residual_alpha: float | None = 1.0,
) -> None:
- # torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
- output, residual_output = ops.residual_rms_norm(
- input=input,
- weight=weight,
- residual=residual,
- eps=epsilon,
- residual_alpha=residual_alpha,
- )
- input[:] = output
- residual[:] = residual_output
+ ops.rms_norm_qk(
+ input_q, input_k, weight_q, weight_k, epsilon, output_q, output_k)
def fused_qk_norm_rope(
@@ -553,10 +460,7 @@ def apply_repetition_penalties_cuda(
output_mask: torch.Tensor,
repetition_penalties: torch.Tensor,
) -> None:
- # torch.ops._C.apply_repetition_penalties_(
- # logits, prompt_mask, output_mask, repetition_penalties
- # )
- apply_repetition_penalties_torch(
+ torch.ops._C.apply_repetition_penalties_(
logits, prompt_mask, output_mask, repetition_penalties
)
@@ -575,15 +479,14 @@ def apply_repetition_penalties(
output_mask: A boolean tensor indicating which tokens appear in the output.
repetition_penalties: The repetition penalties of shape (num_seqs, ).
"""
- if logits.is_cuda and logits.is_contiguous():
- apply_repetition_penalties_cuda(
- logits, prompt_mask, output_mask, repetition_penalties
- )
- else:
- apply_repetition_penalties_torch(
- logits, prompt_mask, output_mask, repetition_penalties
- )
-
+ # if logits.is_cuda and logits.is_contiguous():
+ # apply_repetition_penalties_cuda(
+ # logits, prompt_mask, output_mask, repetition_penalties
+ # )
+ # else:
+ apply_repetition_penalties_torch(
+ logits, prompt_mask, output_mask, repetition_penalties
+ )
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
@@ -700,73 +603,33 @@ if hasattr(torch.ops._C, "awq_dequantize"):
return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device)
-# def awq_gemm(
-# input: torch.Tensor,
-# qweight: torch.Tensor,
-# scales: torch.Tensor,
-# qzeros: torch.Tensor,
-# split_k_iters: int,
-# ) -> torch.Tensor:
-# if envs.VLLM_USE_TRITON_AWQ:
-# from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton
-
-# return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters)
-# return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters)
-
-
-def awq_gemm(
- input: torch.Tensor,
- qweight: torch.Tensor,
- scales: torch.Tensor,
- qzeros: torch.Tensor,
- pack_factor,
- group_size: int = 128,
-) -> torch.Tensor:
+def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor,
+             qzeros: torch.Tensor, pack_factor, group_size: int = 128) -> torch.Tensor:
+ if _USE_TORCH_OPS:
+ return torch.ops.ixf_ops.wui4a16(
+ input, qweight, scales, qzeros, None, group_size, "NN"
+ )
return ops.wui4a16(input, qweight, scales, qzeros, None, group_size, "NN")
-if hasattr(torch.ops._C, "awq_gemm"):
- @register_fake("_C::awq_gemm")
- def _awq_gemm_fake(
- input: torch.Tensor,
- qweight: torch.Tensor,
- scales: torch.Tensor,
- qzeros: torch.Tensor,
- split_k_iters: torch.SymInt,
- ) -> torch.Tensor:
- num_in_feats = input.size(0)
- return torch.empty(
- (split_k_iters, num_in_feats, qweight.size(1) * 8),
- dtype=input.dtype,
- device=input.device,
- ).sum(0)
+def custom_gptq_marlin_gemm(input: torch.Tensor, qweight: torch.Tensor,
+                            scales: torch.Tensor, qzeros: torch.Tensor,
+                            pack_factor, group_size: int = 128,
+                            bias=None) -> torch.Tensor:
+    if _USE_TORCH_OPS:
+        return torch.ops.ixf_ops.wui4a16(input, qweight, scales, qzeros, bias,
+                                         group_size, "NN")
+    else:
+        return ops.wui4a16(input, qweight, scales, qzeros, bias, group_size, "NN")
# gptq
-def gptq_gemm(
- a: torch.Tensor,
- b_q_weight: torch.Tensor,
- b_gptq_qzeros: torch.Tensor,
- b_gptq_scales: torch.Tensor,
- b_g_idx: torch.Tensor,
- use_exllama: bool,
- use_v2_format: bool,
- bit: int,
-) -> torch.Tensor:
- # return torch.ops._C.gptq_gemm(
- # a,
- # b_q_weight,
- # b_gptq_qzeros,
- # b_gptq_scales,
- # b_g_idx,
- # use_exllama,
- # use_v2_format,
- # bit,
- # )
- return ops.gptq_gemm(
- a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit
- )
+def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+ b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
+ b_g_idx: torch.Tensor, use_exllama: bool, use_v2_format: bool,
+ bit: int) -> torch.Tensor:
+ if use_v2_format:
+        raise NotImplementedError("gptq_gemm does not support use_v2_format")
+    return ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
+                         b_g_idx, use_exllama, bit)
if hasattr(torch.ops._C, "gptq_gemm"):
@@ -787,8 +650,8 @@ if hasattr(torch.ops._C, "gptq_gemm"):
)
-def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
- # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
+ bit: int) -> None:
ops.vllm_gptq_shuffle(q_weight, q_perm, bit)
@@ -896,12 +759,10 @@ def cutlass_scaled_fp4_mm(
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
- # return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
return False
def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
- # return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability)
return False
@@ -912,7 +773,8 @@ def cutlass_scaled_mm(
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
- format: str | None = "TN",
+ format: str = "TN",
+ is_w4a8_linear: bool = False
) -> torch.Tensor:
"""
`cutlass_scaled_mm` implements a fused version of
@@ -936,38 +798,41 @@ def cutlass_scaled_mm(
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
- target_shape = (*a.shape[:-1], b.shape[1])
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
- assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype
- # a is x, b is weight
- m = a.shape[:-1]
- n = b.shape[1] * 2 if envs.VLLM_W8A8_LINEAR_USE_W4A8 else b.shape[1]
+ assert a.ndim <= 3 and b.ndim == 2
+    assert bias is None or (
+        bias.numel() == b.shape[1] * 2
+        if is_w4a8_linear
+        else bias.dtype == out_dtype and bias.numel() == b.shape[1]
+    )
+
+ # NN format: GEMM is a @ b, requires a.shape[1] == b.shape[0].
+ # When weight K-dim was padded for 64-alignment (e.g. K=160 → pad_K=192),
+ # input must be padded to match.
+ if format == "NN" and a.ndim == 2 and b.ndim == 2 and a.shape[-1] < b.shape[0]:
+ pad_size = b.shape[0] - a.shape[-1]
+ a = F.pad(a, (0, pad_size), mode="constant", value=0)
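+        # e.g. a: (m, 160), b: (192, n) -> pad_size = 32, so a is padded to (m, 192)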
+
+ m = a.shape[0]
+ n = b.shape[1] if not is_w4a8_linear else b.shape[1] * 2
if format == "TN":
b = b.t()
- out = torch.empty(m + (n,), dtype=out_dtype, device=a.device)
- if envs.VLLM_W8A8_LINEAR_USE_W4A8:
- assert format == "NN"
- ops.w4a8(
- a,
- b,
- scale_a,
- scale_b,
- bias=bias,
- format=0,
- output=out.view(-1, n),
- output_dtype=out_dtype,
- )
+ if a.shape[-1] != b.shape[1]:
+ padding = b.shape[1] - a.shape[-1]
+ a = torch.nn.functional.pad(a, (0, padding), mode='constant', value=0)
else:
- ops.w8a8(
- a,
- b,
- scale_a,
- scale_b,
- bias,
- format=format,
- output=out.view(-1, n),
- out_dtype=out_dtype,
- )
+ if a.shape[-1] != b.shape[0]:
+ padding = b.shape[0] - a.shape[-1]
+ a = torch.nn.functional.pad(a, (0, padding), mode='constant', value=0)
+
+ if a.ndim == 2:
+ out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+ else:
+ bs = a.shape[1]
+ out = torch.empty((m, bs, n), dtype=out_dtype, device=a.device)
+
+    if is_w4a8_linear:
+        assert format == "NN", "W4A8 linear only supports NN format"
+        ops.w4a8(a, b, scale_a, scale_b, bias=bias, format=0,
+                 output=out.view(-1, n), output_dtype=out_dtype)
+    else:
+        ops.w8a8(a, b, scale_a, scale_b, bias, format=format,
+                 output=out, out_dtype=out_dtype)
return out
@@ -1000,10 +865,8 @@ def cutlass_scaled_mm_azp(
torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias)
return out.view(*target_shape)
-
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
- # return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability)
- return False
+ return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability)
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
@@ -1187,7 +1050,7 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor):
return output_tensor
-def get_cutlass_pplx_moe_mm_data(
+def get_cutlass_batched_moe_mm_data(
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
@@ -1210,7 +1073,7 @@ def get_cutlass_pplx_moe_mm_data(
multiplication in two grouped MMs used in
the fused MoE operation.
"""
- return torch.ops._C.get_cutlass_pplx_moe_mm_data(
+ return torch.ops._C.get_cutlass_batched_moe_mm_data(
expert_offsets,
problem_sizes1,
problem_sizes2,
@@ -1303,6 +1166,76 @@ def cutlass_fp4_moe_mm(
)
+def mxfp8_experts_quant(
+ input_tensor: torch.Tensor,
+ problem_sizes: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ blockscale_offsets: torch.Tensor,
+ quant_output: torch.Tensor,
+ scale_factor: torch.Tensor,
+) -> None:
+ torch.ops._C.mxfp8_experts_quant(
+ input_tensor,
+ problem_sizes,
+ expert_offsets,
+ blockscale_offsets,
+ quant_output,
+ scale_factor,
+ )
+
+
+def cutlass_mxfp8_grouped_mm(
+ a_tensors: torch.Tensor,
+ b_tensors: torch.Tensor,
+ a_scales: torch.Tensor,
+ b_scales: torch.Tensor,
+ out_tensors: torch.Tensor,
+ problem_sizes: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ blockscale_offsets: torch.Tensor,
+) -> None:
+ torch.ops._C.cutlass_mxfp8_grouped_mm(
+ a_tensors,
+ b_tensors,
+ a_scales,
+ b_scales,
+ out_tensors,
+ problem_sizes,
+ expert_offsets,
+ blockscale_offsets,
+ )
+
+
+if hasattr(torch.ops._C, "mxfp8_experts_quant"):
+
+ @register_fake("_C::mxfp8_experts_quant")
+ def _mxfp8_experts_quant_fake(
+ input_tensor: torch.Tensor,
+ problem_sizes: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ blockscale_offsets: torch.Tensor,
+ quant_output: torch.Tensor,
+ scale_factor: torch.Tensor,
+ ) -> None:
+ return None
+
+
+if hasattr(torch.ops._C, "cutlass_mxfp8_grouped_mm"):
+
+ @register_fake("_C::cutlass_mxfp8_grouped_mm")
+ def _cutlass_mxfp8_grouped_mm_fake(
+ a_tensors: torch.Tensor,
+ b_tensors: torch.Tensor,
+ a_scales: torch.Tensor,
+ b_scales: torch.Tensor,
+ out_tensors: torch.Tensor,
+ problem_sizes: torch.Tensor,
+ expert_offsets: torch.Tensor,
+ blockscale_offsets: torch.Tensor,
+ ) -> None:
+ return None
+
+
# gptq_marlin
def gptq_marlin_repack(
b_q_weight: torch.Tensor,
@@ -2123,7 +2056,7 @@ def scaled_int8_quant(
assert symmetric == (azp is None), (
"azp must only be provided for asymmetric quantization."
)
- torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
+ ops.static_scaled_int8_quant(output, input, scale)
return output, scale, azp
# dynamic-per-token quantization.
@@ -2131,9 +2064,6 @@ def scaled_int8_quant(
(input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32
)
input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
- # torch.ops._C.dynamic_scaled_int8_quant(
- # output, input.contiguous(), input_scales, input_azp
- # )
ops.dynamic_scaled_int8_quant(output, input, input_scales)
return output, input_scales, input_azp
@@ -2153,7 +2083,6 @@ def ggml_mul_mat_vec_a8(
) -> torch.Tensor:
return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)
-
def ggml_mul_mat_a8(
W: torch.Tensor,
X: torch.Tensor,
@@ -2202,7 +2131,6 @@ def ggml_moe_a8_vec(
def ggml_moe_get_block_size(quant_type: int) -> int:
return torch.ops._C.ggml_moe_get_block_size(quant_type)
-
# mamba
def selective_scan_fwd(
u: torch.Tensor,
@@ -2223,6 +2151,8 @@ def selective_scan_fwd(
block_idx_first_scheduled_token: torch.Tensor | None = None,
block_idx_last_scheduled_token: torch.Tensor | None = None,
initial_state_idx: torch.Tensor | None = None,
+ cu_chunk_seqlen: torch.Tensor | None = None,
+ last_chunk_indices: torch.Tensor | None = None,
):
torch.ops._C.selective_scan_fwd(
u,
@@ -2243,6 +2173,8 @@ def selective_scan_fwd(
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
+ cu_chunk_seqlen,
+ last_chunk_indices,
)
@@ -2291,23 +2223,9 @@ def moe_align_block_size(
num_tokens_post_pad: torch.Tensor,
expert_map: torch.Tensor | None = None,
) -> None:
- # torch.ops._moe_C.moe_align_block_size(
- # topk_ids,
- # num_experts,
- # block_size,
- # sorted_token_ids,
- # experts_ids,
- # num_tokens_post_pad,
- # expert_map,
- # )
- ops.vllm_moe_align_block_size(
- topk_ids,
- num_experts,
- block_size,
- sorted_token_ids,
- experts_ids,
- num_tokens_post_pad,
- )
+ ops.vllm_moe_align_block_size(topk_ids, num_experts, block_size,
+ sorted_token_ids, experts_ids,
+ num_tokens_post_pad)
def batched_moe_align_block_size(
@@ -2398,6 +2316,23 @@ def moe_wna16_gemm(
)
+def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+ """bf16 x bf16 -> fp32 GEMM via cuBLAS. weight shape: (N, K)."""
+ return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight)
+
+
+if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"):
+
+ @register_fake("_moe_C::router_gemm_bf16_fp32")
+ def router_gemm_bf16_fp32_fake(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ ) -> torch.Tensor:
+ return torch.empty(
+ input.shape[0], weight.shape[0], dtype=torch.float32, device=input.device
+ )
+
+
def dsv3_router_gemm(
hidden_states: torch.Tensor,
router_weight: torch.Tensor,
@@ -2421,53 +2356,15 @@ def topk_softmax(
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> None:
- # torch.ops._moe_C.topk_softmax(
- # topk_weights,
- # topk_ids,
- # token_expert_indices,
- # gating_output,
- # renormalize,
- # e_score_correction_bias,
- # )
- ops.vllm_moe_topk_softmax(
- topk_weights, topk_ids, token_expert_indices, gating_output
- )
-
- if renormalize:
- topk_weights[:] = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-
-
-def topk_sigmoid_torch(
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- token_expert_indices: torch.Tensor,
- gating_output: torch.Tensor,
- renormalize: bool = False,
- e_score_correction_bias: torch.Tensor | None = None,
-):
- batch, num_experts = gating_output.shape
- k = topk_weights.shape[1]
-
- # Sigmoid + bias
- probs = torch.sigmoid(gating_output)
- if e_score_correction_bias is not None:
- probs = probs + e_score_correction_bias.unsqueeze(0)
-
- # Top-K
- topk_vals, topk_idx = torch.topk(probs, k=k, dim=1)
-
- # 写入结果
- topk_weights[:] = topk_vals
- topk_ids[:] = topk_idx
- token_expert_indices[:] = (torch.arange(k, device=gating_output.device).unsqueeze(0) * batch
- + torch.arange(batch, device=gating_output.device).unsqueeze(1))
-
- # renormalize
- if renormalize:
- denom = topk_weights.sum(dim=1, keepdim=True)
- denom = torch.where(denom > 0, denom, torch.ones_like(denom))
- topk_weights[:] = topk_weights / denom
+ ops.moe_fused_topk_bias(
+ gating_output=gating_output,
+ topk=topk_weights.shape[-1],
+ scoring_func="softmax",
+ routed_scaling_factor=1,
+ renormalize=renormalize,
+ e_score_correction_bias=e_score_correction_bias,
+ topk_weight=topk_weights,
+ topk_ids=topk_ids)
def topk_sigmoid(
@@ -2478,15 +2375,15 @@ def topk_sigmoid(
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
) -> None:
- # torch.ops._moe_C.topk_sigmoid(
- # topk_weights,
- # topk_ids,
- # token_expert_indices,
- # gating_output,
- # renormalize,
- # e_score_correction_bias,
- # )
- topk_sigmoid_torch(topk_weights, topk_ids, token_expert_indices, gating_output, renormalize, e_score_correction_bias)
+ ops.moe_fused_topk_bias(
+ gating_output=gating_output,
+ topk=topk_weights.shape[-1],
+ scoring_func="sigmoid",
+ routed_scaling_factor=1,
+ renormalize=renormalize,
+ e_score_correction_bias=e_score_correction_bias,
+ topk_weight=topk_weights,
+ topk_ids=topk_ids)
def grouped_topk(
@@ -2685,26 +2582,24 @@ def reshape_and_cache_flash(
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
- # torch.ops._C_cache_ops.reshape_and_cache_flash(
- # key,
- # value,
- # key_cache,
- # value_cache,
- # slot_mapping,
- # kv_cache_dtype,
- # k_scale,
- # v_scale,
- # )
- ops.reshape_and_cache_flash(
- key,
- value,
- key_cache,
- value_cache,
- slot_mapping,
- kv_cache_dtype,
- 1.0,
- 1.0,
- )
+ ops.reshape_and_cache_flash(key, value, key_cache,
+ value_cache, slot_mapping,
+ kv_cache_dtype, 1.0, 1.0)
+
+
+def reshape_and_cache_flash_mix(
+ key: torch.Tensor,
+ value: torch.Tensor,
+ key_scale: torch.Tensor,
+ key_cache: torch.Tensor,
+ value_cache: torch.Tensor,
+ key_scale_cache: torch.Tensor,
+ slot_mapping: torch.Tensor,
+ kv_cache_dtype: str,
+):
+ ops.reshape_and_cache_flash_mix(key, value, key_scale,
+ key_cache, value_cache, key_scale_cache,
+ slot_mapping, kv_cache_dtype)
def concat_and_cache_mla(
@@ -2715,37 +2610,11 @@ def concat_and_cache_mla(
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
- # torch.ops._C_cache_ops.concat_and_cache_mla(
- # kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
- # )
- ops.vllm_concat_and_cache_mla(
- kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
- )
+ ops.vllm_concat_and_cache_mla(kv_c, k_pe, kv_cache,
+ slot_mapping, kv_cache_dtype,
+ scale)
-def concat_and_cache_mla_int8(
- kv_c_int8: torch.Tensor,
- kv_c_scale: torch.Tensor,
- k_pe_int8: torch.Tensor,
- k_pe_scale: torch.Tensor,
- kv_cache: torch.Tensor,
- kv_cache_scale: torch.Tensor,
- slot_mapping: torch.Tensor,
- kv_cache_dtype: str,
- scale: torch.Tensor,
-) -> None:
- ops.vllm_concat_and_cache_mla_int8(
- kv_c_int8,
- kv_c_scale,
- k_pe_int8,
- k_pe_scale,
- kv_cache,
- kv_cache_scale,
- slot_mapping,
- kv_cache_dtype,
- scale,
- )
-
def concat_and_cache_mla_rope_fused(
positions: torch.Tensor,
@@ -2799,7 +2668,6 @@ def swap_blocks(
but not both on cpu.
the block mapping tensor must on cpu.
"""
- # torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
ops.vllm_swap_blocks(src, dst, block_mapping)
@@ -2841,9 +2709,6 @@ def cp_gather_cache(
batch_size: int,
seq_starts: torch.Tensor | None = None,
) -> None:
- # torch.ops._C_cache_ops.cp_gather_cache(
- # src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts
- # )
ops.vllm_cp_gather_cache(
src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts
)
@@ -2882,19 +2747,309 @@ def indexer_k_quant_and_cache(
torch.ops._C_cache_ops.indexer_k_quant_and_cache(
k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype
)
+
+def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,
+                    slot_mapping: torch.Tensor) -> None:
+    ops.indexer_k_cache(k, kv_cache, slot_mapping)
+
+
+def dsa_convert_req_index_to_global_index(
+ req_id: torch.Tensor, # int32 [num_tokens]
+ block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req]
+ token_indices: torch.Tensor, # int32 [num_tokens, num_topk_tokens]
+ block_size: int,
+ output: torch.Tensor = None,
+) -> torch.Tensor:
+ """
+ Convert request-local token indices to global KV cache indices.
-def cp_gather_indexer_k_quant_cache(
- kv_cache: torch.Tensor,
- dst_k: torch.Tensor,
- dst_scale: torch.Tensor,
- block_table: torch.Tensor,
- cu_seq_lens: torch.Tensor,
-) -> None:
- torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache(
- kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
+ For each (token_id, indice_id):
+ tok = token_indices[token_id, indice_id]
+ if tok < 0 or tok // block_size >= max_num_blocks_per_req:
+ output[token_id, indice_id] = -1
+ else:
+ req = req_id[token_id]
+ block_idx = block_table[req, tok // block_size]
+ output[token_id, indice_id] = block_idx * block_size + tok % block_size
+ """
+ return ops.dsa_convert_req_index_to_global_index(
+ req_id, block_table, token_indices, block_size, output
)
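+
+# Worked example of the mapping above (arithmetic only): with block_size=16 and
+# tok=37, tok // 16 == 2 and tok % 16 == 5, so the global slot is
+# block_table[req, 2] * 16 + 5.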
+
+def ref_mqa_logits(
+    q: torch.Tensor,  # [num_tokens, n_head, head_dim] - possibly already quantized
+    k: torch.Tensor,  # [num_blocks, block_size, head_dim] or flattened - possibly already quantized
+    weights: torch.Tensor,  # [num_tokens, n_head, 1] - per-head weights
+    cu_seqlen_ks: torch.Tensor,  # sequence start positions
+    cu_seqlen_ke: torch.Tensor,  # sequence end positions
+) -> torch.Tensor:
+    """
+    PyTorch reference implementation of multi-query attention logits.
+ """
+ M, H, D = q.shape
+ N = k.shape[0]
+ device = q.device
+    # Initialize output logits [M, N]
+ logits = torch.full((M, N), -float('inf'), device=device, dtype=torch.float32)
+ for i in range(M):
+ seq_start = cu_seqlen_ks[i]
+ seq_end = cu_seqlen_ke[i]
+
+ if seq_start >= seq_end:
+ continue
+
+        # Q for the current query [H, D]
+ q_i = q[i] # [H, D]
+
+ seq_k = k[seq_start:seq_end] # [seq_len, head_dim]
+
+        # Compute attention scores [H, seq_len]
+        attention_scores = torch.matmul(q_i, seq_k.T)  # computed in BF16
+ attention_scores = F.relu(attention_scores)
+
+        # Apply per-head weights [H, seq_len]
+ attention_scores_f32 = attention_scores.float()
+ weights_i = weights[i].unsqueeze(1) # [H, 1]
+ weighted_scores = attention_scores_f32 * weights_i # [H, seq_len]
+
+        # Sum logits across all heads [seq_len]
+ logits_i = torch.sum(weighted_scores, dim=0) # [seq_len]
+
+        # Write the result into the corresponding slice of the output logits
+ logits[i, seq_start:seq_end] = logits_i
+
+ return logits
+
+def ref_paged_mqa_logits(
+ q: torch.Tensor,
+ kv_cache: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ max_model_len: int,
+ clean_logits: bool = True
+) -> torch.Tensor:
+ """使用分页KV缓存计算FP8多查询注意力logits的PyTorch实现
+
+ Args:
+ q: 查询张量 [B, next_n, H, D]
+ kv_cache: 分页KV缓存 [num_blocks, block_size, 1, D]
+ weights: 权重张量 [B * next_n, H], dtype=torch.float32
+ context_lens: 上下文长度 [B], dtype=int32
+ block_tables: 块映射表 [B, max_blocks], dtype=int32
+ schedule_metadata: 调度元数据
+ max_model_len: 最大序列长度,用于确定输出logits大小
+
+ Returns:
+ Logits张量 [B * next_n, max_model_len], dtype=torch.float32
+ """
+ def reassemble_k_from_paged_cache(
+ kv_cache: torch.Tensor,
+ block_table: torch.Tensor,
+ context_len: int,
+ head_dim: int,
+ block_size: int
+ ) -> torch.Tensor:
+ """从分页缓存中重组K值"""
+ num_blocks_needed = (context_len + block_size - 1) // block_size
+ valid_blocks = block_table[:num_blocks_needed]
+ device = kv_cache.device
+        # Initialize the output K sequence [context_len, head_dim]
+ k_sequence = torch.zeros(context_len, head_dim, device=device, dtype=kv_cache.dtype)
+ token_offset = 0
+ for block_idx in valid_blocks:
+ if block_idx < 0:
+ break
+            # Number of tokens in the current block
+ tokens_in_block = min(block_size, context_len - token_offset)
+ if tokens_in_block <= 0:
+ break
+            # Read the cache block holding these K values
+            block_data = kv_cache[block_idx]  # [block_size, 1, D]
+
+            # Copy the K values into the output sequence
+ k_sequence[token_offset:token_offset + tokens_in_block] = block_data[:tokens_in_block, 0, :head_dim] # [tokens_in_block, D]
+ token_offset += tokens_in_block
+
+ return k_sequence
+
+ def compute_mqa_logits(
+ q: torch.Tensor, # [next_n, H, D]
+ k: torch.Tensor, # [context_len, D]
+ weights: torch.Tensor, # [next_n, H]
+ context_len: int,
+ max_model_len: int
+ ) -> torch.Tensor:
+ """计算多查询注意力logits"""
+ next_n, H, D = q.shape
+ device = q.device
+
+ # 初始化批次logits [next_n, max_model_len]
+ batch_logits = torch.full((next_n, max_model_len), -float('inf'),
+ device=device, dtype=torch.float32)
+
+ # 扩展K以匹配多头 [context_len, H, D]
+ k_expanded = k.unsqueeze(1).expand(-1, H, -1) # [context_len, H, D]
+
+ # 转置以便矩阵乘法
+ q_transposed = q.transpose(0, 1) # [H, next_n, D]
+ k_transposed = k_expanded.transpose(0, 1) # [H, context_len, D]
+ # 批量计算注意力分数 [H, next_n, context_len]
+ attention_scores = torch.bmm(q_transposed, k_transposed.transpose(1, 2)) # [H, next_n, context_len]
+ attention_scores = F.relu(attention_scores)
+ # 应用权重并汇总所有头 [next_n, context_len]
+ weights_expanded = weights.transpose(0, 1).unsqueeze(2) # [H, next_n, 1]
+ weighted_scores = attention_scores * weights_expanded # [H, next_n, context_len]
+ logits_per_token = weighted_scores.sum(dim=0) # [next_n, context_len]
+
+ # 填充到输出logits中
+ batch_logits[:, :context_len] = logits_per_token
+
+ return batch_logits
+ def clean_logits_tensor(
+ logits: torch.Tensor,
+ context_lens: torch.Tensor,
+ next_n: int,
+ max_model_len: int
+ ) -> torch.Tensor:
+ """清理logits张量,将超出上下文长度的位置设为负无穷"""
+ B = len(context_lens)
+
+ for batch_idx in range(B):
+ context_len = context_lens[batch_idx].item()
+ if context_len >= max_model_len:
+ continue
+
+            # Rows of the logits belonging to the current batch
+            batch_start = batch_idx * next_n
+            batch_end = (batch_idx + 1) * next_n
+
+            # Set positions beyond the context length to -inf
+ logits[batch_start:batch_end, context_len:] = -float('inf')
+
+ return logits
+
+ B, next_n, H, D = q.shape
+ num_blocks, block_size, _, cache_stride = kv_cache.shape
+ device = q.device
+
+    # Initialize output logits [B * next_n, max_model_len]
+ logits = torch.full((B * next_n, max_model_len), -float('inf'),
+ device=device, dtype=torch.float32)
+
+    # Process each batch
+ for batch_idx in range(B):
+ context_len = context_lens[batch_idx].item()
+ if context_len == 0:
+ continue
+
+        # Queries for the current batch [next_n, H, D]
+        batch_q = q[batch_idx] # [next_n, H, D]
+
+        # Weights for the current batch [next_n, H]
+        batch_weights_start = batch_idx * next_n
+        batch_weights_end = (batch_idx + 1) * next_n
+        batch_weights = weights[batch_weights_start:batch_weights_end] # [next_n, H]
+
+        # Reassemble K values from the paged cache
+ batch_k = reassemble_k_from_paged_cache(
+ kv_cache, block_tables[batch_idx], context_len, D, block_size
+ ) # [context_len, D]
+        # Compute multi-query attention logits
+ batch_logits = compute_mqa_logits(
+ batch_q, batch_k, batch_weights, context_len, max_model_len
+ ) # [next_n, max_model_len]
+
+        # Fill into the output logits
+        logits[batch_weights_start:batch_weights_end] = batch_logits
+
+    if clean_logits:
+        # Clean logits: set positions beyond the context length to -inf
+ logits = clean_logits_tensor(logits, context_lens, next_n, max_model_len)
+
+ return logits
+
+def sparse_prefill_fwd(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ indices: torch.Tensor,
+ sm_scale: float,
+ d_v: int = 512,
+):
+ return ops.sparse_flash_attn(q, kv, indices, sm_scale, d_v)
+
+def ref_sparse_prefill_fwd(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ indices: torch.Tensor,
+ sm_scale: float,
+ d_v: int = 512,
+):
+ """
+ 稀疏注意力预填充内核的PyTorch实现
+
+ Args:
+ - q: [s_q, h_q, d_qk], bfloat16
+ - kv: [s_kv, h_kv, d_qk], bfloat16
+ - indices: [s_q, h_kv, topk], int32. 无效索引设为-1或>=s_kv
+ - sm_scale: float
+ - d_v: 值向量的维度,只能为512
+
+ Returns:
+ - (output, max_logits, lse)
+ - output: [s_q, h_q, d_v], bfloat16
+ - max_logits: [s_q, h_q], float
+ - lse: [s_q, h_q], float, 以2为底的对数求和指数
+ """
+ def ref_masked_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ sm_scale: float,
+ ) -> torch.Tensor:
+ query = query * sm_scale
+ dtype = query.dtype
+ device = query.device
+ attn = torch.einsum("qhd,khd->hqk", query, key)
+ attn = attn.to(torch.float)
+ attn = torch.softmax(attn, dim=-1)
+ value = value.to(torch.float)
+ out = torch.einsum("hqk,khd->qhd", attn, value)
+ out = out.to(device).to(dtype)
+ return out
+ s_q, h_q, d_qk = q.shape
+ s_kv, h_kv, _ = kv.shape
+ _, _, topk = indices.shape
+
+ device = q.device
+ dtype = q.dtype
+
+    # Separate K and V views from the packed kv tensor
+ k = kv # [s_kv, h_kv, d_qk]
+ v = kv[:, :, :d_v] # [s_kv, h_kv, d_v]
+
+    # Initialize the output
+    output = torch.zeros(s_q, h_q, d_v, device=device, dtype=dtype)
+    # Process each query position
+ for i in range(s_q):
+        # Current query [h_q, d_qk]
+        q_i = q[i].unsqueeze(0) # [1, h_q, d_qk]
+        # Sparse indices for the current query position [topk]
+        sparse_indices = indices[i, 0] # [topk]
+        # Filter valid indices (>= 0 and < s_kv)
+        valid_mask = (sparse_indices >= 0) & (sparse_indices < s_kv)
+        valid_indices = sparse_indices[valid_mask]
+        # Gather the valid K and V
+ valid_k = k[valid_indices].repeat(1, h_q, 1) # [valid_len, h_q, d_qk]
+ valid_v = v[valid_indices].repeat(1, h_q, 1) # [valid_len, h_q, d_v]
+ out = ref_masked_attention(
+ q_i,
+ valid_k,
+ valid_v,
+ sm_scale
+ )
+ out = out.view(h_q, d_v)
+ output[i].copy_(out, non_blocking=True)
+ return output
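+
+# Hedged shape example for the reference above (the head dim is illustrative):
+#   q [s_q, h_q, 576], kv [s_kv, 1, 576], indices [s_q, 1, topk]
+#   -> output [s_q, h_q, 512] with the default d_v=512.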
def get_device_attribute(attribute: int, device: int) -> int:
return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
@@ -2902,10 +3057,9 @@ def get_device_attribute(attribute: int, device: int) -> int:
def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# ruff: noqa: E501
- # return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
- # device
- # )
- return 32 * 1024
+ return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
+ device
+ )
# custom ar
@@ -2951,7 +3105,6 @@ def register_graph_buffers(
) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
-
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
@@ -3014,7 +3167,6 @@ def get_flash_mla_metadata(
cache_seqlens, num_heads_per_head_k, num_heads_k
)
-
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
@@ -3164,211 +3316,6 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
return torch.empty((M, N), dtype=out_dtype)
-# Add our new features here..
-
-
-# moe
-def invoke_fused_moe_kernel(
- A: torch.Tensor,
- B: torch.Tensor,
- C: torch.Tensor,
- A_scale: torch.Tensor | None,
- B_scale: torch.Tensor | None,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- sorted_token_ids: torch.Tensor,
- expert_ids: torch.Tensor,
- num_tokens_post_padded: torch.Tensor,
- mul_routed_weight: bool,
- top_k: int,
- config: dict[str, "any"],
- compute_type,
- use_fp8_w8a8: bool,
- use_int8_w8a16: bool,
- block_shape: list[int] | None = None,
- bias: torch.Tensor | None = None,
-) -> None:
- ops.vllm_invoke_fused_moe_kernel(
- A,
- B,
- C,
- topk_weights,
- topk_ids,
- sorted_token_ids,
- expert_ids,
- num_tokens_post_padded,
- mul_routed_weight,
- top_k,
- config["BLOCK_SIZE_M"],
- bias=bias
- )
-
-# broadcast
-class Async_helper:
- # For now, the comm and the other kernels are in the same stream, so we can remove the stream wait..
- def wait(
- self,
- ):
- return True
-
-
-def broadcast(tensor, src=0, group=None, async_op=False):
- cdist.broadcast(tensor, src, group, async_op=True)
- if async_op:
- return Async_helper()
- else:
- pass
-
-
-# w8a16
-def linear_w8a16(
- x: torch.Tensor,
- qweight: torch.Tensor,
- scales: torch.Tensor,
- group_size: int = -1,
- format: str = "TN",
-) -> torch.Tensor:
- return ops.w8a16(x, qweight, scales, format="TN", group_size=group_size)
-
-
-## lora sgmv / bgmv
-def sbgmv_expand(
- x: torch.Tensor,
- w_t_all: torch.Tensor,
- y: torch.Tensor,
- b_seq_start_loc: torch.Tensor = None,
- seq_len_tensor: torch.Tensor = None,
- lora_indices_tensor: torch.Tensor = None,
- batches: int = -1,
- max_seq_length: int = -1,
- token_nums: int = -1,
- add_input=True,
-):
- """
- x: inputs
- w_t_all: lora weight
- y: output
-
- y += x@wt_t_all
- """
- assert x.dtype in [torch.float16, torch.bfloat16, torch.float32]
- assert w_t_all.dtype in [
- torch.float16,
- torch.bfloat16,
- ]
-
- assert x.is_contiguous()
- # assert y.is_contiguous()
- if x.dtype == torch.float:
- x = x.to(w_t_all.dtype)
-
- if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank)
- assert w_t_all.size(1) == 1
- w_t_all = w_t_all.squeeze(dim=1)
- else:
- assert w_t_all.ndim == 3 # shape:(lora_num,size,rank)
- assert w_t_all.is_contiguous()
-
- assert add_input == True
-
- lora_indices = lora_indices_tensor.cpu().tolist()
- lora_num = w_t_all.shape[0]
-
- ## 单一lora model, 且所有request均使用lora
- if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices):
- if lora_indices[0] != -1:
- w_t = w_t_all[0]
- y += torch.matmul(x, w_t.t())
- ## 多个lora model
- else:
- ## prefill
- if batches != -1:
- for i, lora_id, start, seq_len in zip(
- range(batches), lora_indices, b_seq_start_loc, seq_len_tensor
- ):
- if lora_id != -1:
- xi = x[start : start + seq_len]
- w_t = w_t_all[lora_id]
- y[start : start + seq_len] += xi @ w_t.t()
- ## decode
- else:
- batches = x.shape[0]
- for i, lora_id in zip(range(batches), lora_indices):
- if lora_id != -1:
- xi = x[i].unsqueeze(0)
- w_t = w_t_all[lora_id]
- y[i] += (xi @ w_t.t()).squeeze(0)
-
- return y
-
-
-def sbgmv_shrink(
- x: torch.Tensor,
- w_t_all: torch.Tensor,
- y: torch.Tensor,
- b_seq_start_loc: torch.Tensor = None,
- seq_len_tensor: torch.Tensor = None,
- lora_indices_tensor: torch.Tensor = None,
- batches: int = -1,
- max_seq_length: int = -1,
- token_nums: int = -1,
- scale: float = 1.0,
-):
- """
- xx: inputs
- w_t_all: lora weight
- y: output
- scale: float
-
- y = x@w_t_all * scale
- """
- assert x.dtype == w_t_all.dtype
- assert x.dtype in [torch.float16, torch.bfloat16]
- assert x.is_contiguous()
- assert y.is_contiguous()
-
- if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank)
- assert w_t_all.size(1) == 1
- w_t_all = w_t_all.squeeze(dim=1)
- else:
- assert w_t_all.ndim == 3 # shape:(lora_num,size,rank)
- assert w_t_all.is_contiguous()
-
- lora_num = w_t_all.shape[0]
- lora_indices = lora_indices_tensor.cpu().tolist()
-
- ## 单一lora model, 且所有request均使用lora
- if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices):
- if lora_indices[0] != -1:
- w_t = w_t_all[0]
- y = torch.matmul(x, w_t.t()) * scale
- ## 多个lora model
- else:
- ## prefill
- if batches != -1:
- for i, lora_id, start, seq_len in zip(
- range(batches), lora_indices, b_seq_start_loc, seq_len_tensor
- ):
- if lora_id != -1:
- xi = x[start : start + seq_len]
- w_t = w_t_all[lora_id]
- y[start : start + seq_len] = (xi @ w_t.t()) * scale
- ## decode
- else:
- batches = x.shape[0]
- for i, lora_id in zip(range(batches), lora_indices):
- if lora_id != -1:
- xi = x[i].unsqueeze(0)
- w_t = w_t_all[lora_id]
- y[i] = (xi @ w_t.t()).squeeze(0) * scale
-
- return y
-
-
-def dynamic_scaled_quant_dynamic_int8(x, input_scales=None, int8_out=None, scales=None):
- return ops.dynamic_scaled_quant_smoothquant(x, input_scales, int8_out, scales)
-
-
class CPUDNNLGEMMHandler:
def __init__(self) -> None:
self.handler_tensor: torch.Tensor | None = None
@@ -3799,50 +3746,246 @@ if hasattr(torch.ops._C, "hadacore_transform"):
@register_fake("_C::hadacore_transform")
def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor:
return torch.empty_like(x) if not inplace else x
-
-
+
+
+# Add our new features here..
def gather_cache(
- src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...]
- block_table: torch.Tensor, # [BATCH, BLOCK_INDICES]
- cu_seq_lens: torch.Tensor, # [BATCH+1]
+ src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+ dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...]
+ block_table: torch.Tensor, # [BATCH, BLOCK_INDICES]
+ cu_seq_lens: torch.Tensor, # [BATCH+1]
batch_size: int,
- seq_starts: torch.Tensor = None,
+ seq_starts: torch.Tensor = None
):
- ops.vllm_gather_cache(
- src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts
- )
-
-
+ ops.vllm_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts)
+
def gather_cache_int8(
- src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
- src_cache_scale: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, 2]
+ src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
+    src_cache_scale: torch.Tensor,  # [NUM_BLOCKS, BLOCK_SIZE, 2]
kv_lora_rank: int,
- dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...]
- block_table: torch.Tensor, # [BATCH, BLOCK_INDICES]
- cu_seq_lens: torch.Tensor, # [BATCH+1]
+ dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...]
+ block_table: torch.Tensor, # [BATCH, BLOCK_INDICES]
+ cu_seq_lens: torch.Tensor, # [BATCH+1]
batch_size: int,
- seq_starts: torch.Tensor = None,
+ seq_starts: torch.Tensor = None
):
- ops.vllm_gather_cache_int8(
- src_cache,
- src_cache_scale,
- kv_lora_rank,
- dst,
- block_table,
- cu_seq_lens,
- batch_size,
- seq_starts,
+    ops.vllm_gather_cache_int8(src_cache, src_cache_scale, kv_lora_rank, dst,
+                               block_table, cu_seq_lens, batch_size, seq_starts)
+
+def quant_kv(kv):
+ amax_, _ = torch.max(torch.abs(kv), dim=-1, keepdim=True)
+ f_scale = amax_.float() / 127.0
+ scales = f_scale.view(kv.shape[:-1])
+
+    # quantize to symmetric int8
+ kv = kv / f_scale
+ kv = torch.clamp(torch.round(kv), -127, 127).to(torch.int8)
+ return kv, scales
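+
+# Sanity-check sketch (follows directly from quant_kv above):
+#   kv_q, scales = quant_kv(kv)
+#   kv_ref = kv_q.float() * scales.unsqueeze(-1)  # approximately recovers kv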
+
+
+def concat_and_cache_mla_int8(
+ kv_c_int8: torch.Tensor,
+ kv_c_scale: torch.Tensor,
+ k_pe_int8: torch.Tensor,
+ k_pe_scale: torch.Tensor,
+ kv_cache: torch.Tensor,
+ kv_cache_scale: torch.Tensor,
+ slot_mapping: torch.Tensor,
+ kv_cache_dtype: str,
+ scale: torch.Tensor,
+) -> None:
+    ops.vllm_concat_and_cache_mla_int8(kv_c_int8, kv_c_scale, k_pe_int8, k_pe_scale,
+                                       kv_cache, kv_cache_scale, slot_mapping,
+                                       kv_cache_dtype, scale)
+
+
+def invoke_fused_moe_kernel(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ C: torch.Tensor,
+ A_scale: Optional[torch.Tensor],
+ B_scale: Optional[torch.Tensor],
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ sorted_token_ids: torch.Tensor,
+ expert_ids: torch.Tensor,
+ num_tokens_post_padded: torch.Tensor,
+ mul_routed_weight: bool,
+ top_k: int,
+ config: Dict[str, Any],
+ compute_type,
+ use_fp8_w8a8: bool,
+ use_int8_w8a16: bool,
+ block_shape: Optional[List[int]] = None,
+ bias: Optional[torch.Tensor] = None,
+) -> None:
+ ops.vllm_invoke_fused_moe_kernel(
+ A,
+ B,
+ C,
+ topk_weights,
+ topk_ids,
+ sorted_token_ids,
+ expert_ids,
+ num_tokens_post_padded,
+ mul_routed_weight,
+ top_k,
+ config['BLOCK_SIZE_M'],
+ bias=bias
)
+
+
+# broadcast
+class Async_helper:
+    # For now, the comm and the other kernels are in the same stream,
+    # so we can remove the stream wait..
+    def wait(self):
+ return True
+
+
+def broadcast(tensor, src=0, group=None, async_op=False):
+    cdist.broadcast(tensor, src, group, async_op=True)
+ if async_op:
+ return Async_helper()
+ else:
+ pass
+
+
+# w8a16
+def linear_w8a16(x: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor,
+                 group_size: int = -1, format: str = "TN") -> torch.Tensor:
+ return ops.w8a16(x, qweight, scales, format="TN", group_size=group_size)
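+# Note: linear_w8a16 accepts a `format` argument for signature compatibility, but
+# the call above always passes format="TN".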
+
+
+## lora sgmv / bgmv
+def sbgmv_expand(x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ y: torch.Tensor,
+ b_seq_start_loc: torch.Tensor = None,
+ seq_len_tensor: torch.Tensor = None,
+ lora_indices_tensor: torch.Tensor = None,
+ batches: int = -1,
+ max_seq_length: int = -1,
+ token_nums: int = -1,
+ add_input=True,
+ ):
+    """
+    x: inputs
+    w_t_all: lora weight
+    y: output
+
+    y += x @ w_t_all
+    """
+ assert x.dtype in [torch.float16, torch.bfloat16, torch.float32]
+ assert w_t_all.dtype in [
+ torch.float16,
+ torch.bfloat16,
+ ]
+
+ assert x.is_contiguous()
+ # assert y.is_contiguous()
+ if x.dtype == torch.float:
+ x = x.to(w_t_all.dtype)
+
+ if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank)
+ assert w_t_all.size(1) == 1
+ w_t_all = w_t_all.squeeze(dim=1)
+ else:
+ assert w_t_all.ndim == 3 # shape:(lora_num,size,rank)
+ assert w_t_all.is_contiguous()
+
+ assert add_input == True
+
+ lora_indices = lora_indices_tensor.cpu().tolist()
+ lora_num = w_t_all.shape[0]
+
+    ## single LoRA model, and all requests use the same LoRA
+    if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices):
+        if lora_indices[0] != -1:
+            w_t = w_t_all[0]
+            y += torch.matmul(x, w_t.t())
+    ## multiple LoRA models
+ else:
+ ## prefill
+ if batches != -1:
+ for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor):
+ if lora_id != -1:
+ xi = x[start: start+seq_len]
+ w_t = w_t_all[lora_id]
+ y[start:start+seq_len] += (xi @ w_t.t())
+ ## decode
+ else:
+ batches = x.shape[0]
+ for i, lora_id in zip(range(batches), lora_indices):
+ if lora_id != -1:
+ xi = x[i].unsqueeze(0)
+ w_t = w_t_all[lora_id]
+ y[i] += (xi @ w_t.t()).squeeze(0)
+
+ return y
+
+
+def sbgmv_shrink(x: torch.Tensor,
+ w_t_all: torch.Tensor,
+ y: torch.Tensor,
+ b_seq_start_loc: torch.Tensor = None,
+ seq_len_tensor: torch.Tensor = None,
+ lora_indices_tensor: torch.Tensor = None,
+ batches: int = -1,
+ max_seq_length: int = -1,
+ token_nums: int = -1,
+ scale: float = 1.0,):
+ """
+    x: inputs
+ w_t_all: lora weight
+ y: output
+ scale: float
+
+ y = x@w_t_all * scale
+ """
+ assert x.dtype == w_t_all.dtype
+ assert x.dtype in [torch.float16, torch.bfloat16]
+ assert x.is_contiguous()
+ assert y.is_contiguous()
+
+ if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank)
+ assert w_t_all.size(1) == 1
+ w_t_all = w_t_all.squeeze(dim=1)
+ else:
+ assert w_t_all.ndim == 3 # shape:(lora_num,size,rank)
+ assert w_t_all.is_contiguous()
+
+ lora_num = w_t_all.shape[0]
+ lora_indices = lora_indices_tensor.cpu().tolist()
+
+    ## single LoRA model, and all requests use the same LoRA
+    if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices):
+        if lora_indices[0] != -1:
+            w_t = w_t_all[0]
+            y = torch.matmul(x, w_t.t()) * scale
+    ## multiple LoRA models
+ else:
+ ## prefill
+ if batches != -1:
+ for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor):
+ if lora_id != -1:
+ xi = x[start: start+seq_len]
+ w_t = w_t_all[lora_id]
+ y[start:start+seq_len] = (xi @ w_t.t())* scale
+ ## decode
+ else:
+ batches = x.shape[0]
+ for i, lora_id in zip(range(batches), lora_indices):
+ if lora_id != -1:
+ xi = x[i].unsqueeze(0)
+ w_t = w_t_all[lora_id]
+ y[i] = (xi @ w_t.t()).squeeze(0) * scale
+
+ return y
+
+def dynamic_scaled_quant_dynamic_int8(x, input_scales=None, int8_out=None, scales=None):
+ return ops.dynamic_scaled_quant_smoothquant(x, input_scales, int8_out, scales)
+
+
def rejection_greedy_sample_torch(
- output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1]
- cu_num_draft_tokens: torch.Tensor, # [batch_size] (前缀和形式)
- draft_token_ids: torch.Tensor, # [num_tokens]
- target_argmax: torch.Tensor, # [num_tokens]
- bonus_token_ids: torch.Tensor, # [batch_size]
- is_greedy: torch.Tensor = None, # [batch_size] 或 None
+ output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens: torch.Tensor,     # [batch_size] (prefix-sum form)
+    draft_token_ids: torch.Tensor,         # [num_tokens]
+    target_argmax: torch.Tensor,           # [num_tokens]
+    bonus_token_ids: torch.Tensor,         # [batch_size]
+    is_greedy: torch.Tensor = None,        # [batch_size] or None
):
"""
完全等价于 rejection_greedy_sample_kernel 的 PyTorch 实现
@@ -3886,18 +4029,17 @@ def rejection_greedy_sample_torch(
return output_token_ids # 原位修改
-
def rejection_random_sample_torch(
- output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1]
- cu_num_draft_tokens: torch.Tensor, # [batch_size] (前缀和形式)
- draft_token_ids: torch.Tensor, # [num_tokens]
- draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] 或 None
- target_probs: torch.Tensor, # [num_tokens, vocab_size]
- bonus_token_ids: torch.Tensor, # [batch_size]
- recovered_token_ids: torch.Tensor, # [num_tokens]
- uniform_probs: torch.Tensor, # [num_tokens] (0~1均匀分布)
- is_greedy: torch.Tensor | None, # [batch_size] 或 None
- NO_DRAFT_PROBS: bool = False, # 是否忽略draft_probs
+ output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1]
+    cu_num_draft_tokens: torch.Tensor,    # [batch_size] (prefix-sum form)
+    draft_token_ids: torch.Tensor,        # [num_tokens]
+    draft_probs: torch.Tensor | None,     # [num_tokens, vocab_size] or None
+    target_probs: torch.Tensor,           # [num_tokens, vocab_size]
+    bonus_token_ids: torch.Tensor,        # [batch_size]
+    recovered_token_ids: torch.Tensor,    # [num_tokens]
+    uniform_probs: torch.Tensor,          # [num_tokens] (uniform in [0, 1))
+    is_greedy: torch.Tensor | None,       # [batch_size] or None
+    NO_DRAFT_PROBS: bool = False,         # whether to ignore draft_probs
):
batch_size = output_token_ids.size(0)
max_spec_len_plus_1 = output_token_ids.size(1)
@@ -3928,9 +4070,7 @@ def rejection_random_sample_torch(
if NO_DRAFT_PROBS:
draft_prob = 1.0
else:
- assert (
- draft_probs is not None
- ), "draft_probs不能为None当NO_DRAFT_PROBS=False"
+            assert draft_probs is not None, (
+                "draft_probs must not be None when NO_DRAFT_PROBS=False")
draft_prob = draft_probs[global_pos, draft_token_id]
# 获取target概率和均匀随机数
@@ -3952,287 +4092,4 @@ def rejection_random_sample_torch(
return output_token_ids
-
weak_ref_tensor = ops.weak_ref_tensor
-
-
-def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,slot_mapping: torch.Tensor)-> None:
- num_tokens, head_dim = k.shape
- _, block_size, cache_stride = kv_cache.shape
- assert head_dim == cache_stride
- for i in range(num_tokens):
- block_idx = torch.div(slot_mapping[i], block_size, rounding_mode="floor")
- block_offset = slot_mapping[i] % block_size
- kv_cache[block_idx, block_offset, :] = k[i]
-
-
-def ref_mqa_logits(
- q: torch.Tensor, # [num_tokens, n_head, head_dim] - 可能已量化
- k: torch.Tensor, # [num_blocks, block_size, head_dim] 或展开形式 - 可能已量化
- weights: torch.Tensor, # [num_tokens, n_head, 1] - 权重
- cu_seqlen_ks: torch.Tensor, # 序列起始位置
- cu_seqlen_ke: torch.Tensor, # 序列结束位置
-) -> torch.Tensor:
- """
- 多查询注意力logits计算的PyTorch等价实现
- """
-
- M, H, D = q.shape
- N = k.shape[0]
- device = q.device
- # 初始化输出logits [M, N]
- logits = torch.full((M, N), -float('inf'), device=device, dtype=torch.float32)
- for i in range(M):
- seq_start = cu_seqlen_ks[i]
- seq_end = cu_seqlen_ke[i]
-
- if seq_start >= seq_end:
- continue
-
- #当前查询的Q [H, D]
- q_i = q[i] # [H, D]
-
- seq_k = k[seq_start:seq_end] # [seq_len, head_dim]
-
- # 计算注意力分数 [H, seq_len]
- attention_scores = torch.matmul(q_i, seq_k.T) # BF16计算
- attention_scores = F.relu(attention_scores)
-
- # 应用权重 [H, seq_len]
- attention_scores_f32 = attention_scores.float()
- weights_i = weights[i].unsqueeze(1) # [H, 1]
- weighted_scores = attention_scores_f32 * weights_i # [H, seq_len]
-
- # 汇总所有头的logits [seq_len]
- logits_i = torch.sum(weighted_scores, dim=0) # [seq_len]
-
-        # Write the result into the corresponding positions of the output logits
- logits[i, seq_start:seq_end] = logits_i
-
- return logits
-
-
-def ref_paged_mqa_logits(
- q: torch.Tensor,
- kv_cache: torch.Tensor,
- weights: torch.Tensor,
- context_lens: torch.Tensor,
- block_tables: torch.Tensor,
- max_model_len: int,
- clean_logits: bool = True
-) -> torch.Tensor:
- """使用分页KV缓存计算FP8多查询注意力logits的PyTorch实现
-
- Args:
- q: 查询张量 [B, next_n, H, D]
- kv_cache: 分页KV缓存 [num_blocks, block_size, 1, D]
- weights: 权重张量 [B * next_n, H], dtype=torch.float32
- context_lens: 上下文长度 [B], dtype=int32
- block_tables: 块映射表 [B, max_blocks], dtype=int32
- schedule_metadata: 调度元数据
- max_model_len: 最大序列长度,用于确定输出logits大小
-
- Returns:
- Logits张量 [B * next_n, max_model_len], dtype=torch.float32
- """
- def reassemble_k_from_paged_cache(
- kv_cache: torch.Tensor,
- block_table: torch.Tensor,
- context_len: int,
- head_dim: int,
- block_size: int
- ) -> torch.Tensor:
- """从分页缓存中重组K值"""
- num_blocks_needed = (context_len + block_size - 1) // block_size
- valid_blocks = block_table[:num_blocks_needed]
- device = kv_cache.device
-        # Initialize the output K sequence [context_len, head_dim]
- k_sequence = torch.zeros(context_len, head_dim, device=device, dtype=kv_cache.dtype)
- token_offset = 0
- for block_idx in valid_blocks:
- if block_idx < 0:
- break
-            # Number of tokens in the current block
- tokens_in_block = min(block_size, context_len - token_offset)
- if tokens_in_block <= 0:
- break
-            # Extract K values from the cache block
- block_data = kv_cache[block_idx] # [block_size, 1, D]
-
-            # Extract K values
- k_sequence[token_offset:token_offset + tokens_in_block] = block_data[:tokens_in_block, 0, :head_dim] # [tokens_in_block, D]
- token_offset += tokens_in_block
-
- return k_sequence
-
- def compute_mqa_logits(
- q: torch.Tensor, # [next_n, H, D]
- k: torch.Tensor, # [context_len, D]
- weights: torch.Tensor, # [next_n, H]
- context_len: int,
- max_model_len: int
- ) -> torch.Tensor:
- """计算多查询注意力logits"""
- next_n, H, D = q.shape
- device = q.device
-
-        # Initialize per-batch logits [next_n, max_model_len]
- batch_logits = torch.full((next_n, max_model_len), -float('inf'),
- device=device, dtype=torch.float32)
-
-        # Expand K to match the number of heads [context_len, H, D]
- k_expanded = k.unsqueeze(1).expand(-1, H, -1) # [context_len, H, D]
-
-        # Transpose for matrix multiplication
- q_transposed = q.transpose(0, 1) # [H, next_n, D]
- k_transposed = k_expanded.transpose(0, 1) # [H, context_len, D]
-        # Compute attention scores in batch [H, next_n, context_len]
- attention_scores = torch.bmm(q_transposed, k_transposed.transpose(1, 2)) # [H, next_n, context_len]
- attention_scores = F.relu(attention_scores)
-        # Apply weights and sum over all heads [next_n, context_len]
- weights_expanded = weights.transpose(0, 1).unsqueeze(2) # [H, next_n, 1]
- weighted_scores = attention_scores * weights_expanded # [H, next_n, context_len]
- logits_per_token = weighted_scores.sum(dim=0) # [next_n, context_len]
-
-        # Write into the output logits
- batch_logits[:, :context_len] = logits_per_token
-
- return batch_logits
-
- def clean_logits_tensor(
- logits: torch.Tensor,
- context_lens: torch.Tensor,
- next_n: int,
- max_model_len: int
- ) -> torch.Tensor:
- """清理logits张量,将超出上下文长度的位置设为负无穷"""
- B = len(context_lens)
-
- for batch_idx in range(B):
- context_len = context_lens[batch_idx].item()
- if context_len >= max_model_len:
- continue
-
-            # Position of the current batch within the logits
- batch_start = batch_idx * next_n
- batch_end = (batch_idx + 1) * next_n
-
-            # Set positions beyond the context length to -inf
- logits[batch_start:batch_end, context_len:] = -float('inf')
-
- return logits
-
- B, next_n, H, D = q.shape
- num_blocks, block_size, _, cache_stride = kv_cache.shape
- device = q.device
-
-    # Initialize output logits [B * next_n, max_model_len]
- logits = torch.full((B * next_n, max_model_len), -float('inf'),
- device=device, dtype=torch.float32)
-
-    # Process each batch
- for batch_idx in range(B):
- context_len = context_lens[batch_idx].item()
- if context_len == 0:
- continue
-
-        # Queries of the current batch [next_n, H, D]
- batch_q = q[batch_idx] # [next_n, H, D]
-
-        # Weights of the current batch [next_n, H]
- batch_weights_start = batch_idx * next_n
- batch_weights_end = (batch_idx + 1) * next_n
- batch_weights = weights[batch_weights_start:batch_weights_end] # [next_n, H]
-
-        # Reassemble K values from the paged cache
- batch_k = reassemble_k_from_paged_cache(
- kv_cache, block_tables[batch_idx], context_len, D, block_size
- ) # [context_len, D]
-        # Compute multi-query attention logits
- batch_logits = compute_mqa_logits(
- batch_q, batch_k, batch_weights, context_len, max_model_len
- ) # [next_n, max_model_len]
-
-        # Write into the output logits
- logits[batch_weights_start:batch_weights_end] = batch_logits
-
- if clean_logits:
-        # Clean logits: set positions beyond the context length to -inf
- logits = clean_logits_tensor(logits, context_lens, next_n, max_model_len)
-
- return logits
-
-
-def sparse_prefill_fwd(
- q: torch.Tensor,
- kv: torch.Tensor,
- indices: torch.Tensor,
- sm_scale: float,
- d_v: int = 512,
-):
- """
-    PyTorch implementation of the sparse-attention prefill kernel
-
- Args:
- - q: [s_q, h_q, d_qk], bfloat16
- - kv: [s_kv, h_kv, d_qk], bfloat16
-    - indices: [s_q, h_kv, topk], int32. Invalid indices are set to -1 or >= s_kv
-    - sm_scale: float
-    - d_v: dimension of the value vectors; must be 512
-
- Returns:
- - (output, max_logits, lse)
- - output: [s_q, h_q, d_v], bfloat16
- - max_logits: [s_q, h_q], float
-    - lse: [s_q, h_q], float, log-sum-exp in base 2
- """
- def ref_masked_attention(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- sm_scale: float,
- ) -> torch.Tensor:
- query = query * sm_scale
- dtype = query.dtype
- device = query.device
- attn = torch.einsum("qhd,khd->hqk", query, key)
- attn = attn.to(torch.float)
- attn = torch.softmax(attn, dim=-1)
- value = value.to(torch.float)
- out = torch.einsum("hqk,khd->qhd", attn, value)
- out = out.to(device).to(dtype)
- return out
- s_q, h_q, d_qk = q.shape
- s_kv, h_kv, _ = kv.shape
- _, _, topk = indices.shape
-
- device = q.device
- dtype = q.dtype
-
-    # Split out K and V
- k = kv # [s_kv, h_kv, d_qk]
- v = kv[:, :, :d_v] # [s_kv, h_kv, d_v]
-
-    # Initialize output
- output = torch.zeros(s_q, h_q, d_v, device=device, dtype=dtype)
-    # Process each query position
- for i in range(s_q):
-        # Current query [h_q, d_qk]
- q_i = q[i].unsqueeze(0) # [1, h_q, d_qk]
-        # Sparse indices for the current query position [topk]
- sparse_indices = indices[i, 0] # [topk]
-        # Filter valid indices (>= 0 and < s_kv)
- valid_mask = (sparse_indices >= 0) & (sparse_indices < s_kv)
- valid_indices = sparse_indices[valid_mask]
-        # Gather the valid K and V
- valid_k = k[valid_indices].repeat(1, h_q, 1) # [valid_len, h_q, d_qk]
- valid_v = v[valid_indices].repeat(1, h_q, 1) # [valid_len, h_q, d_v]
- out = ref_masked_attention(
- q_i,
- valid_k,
- valid_v,
- sm_scale
- )
- out = out.view(h_q, d_v)
- output[i].copy_(out, non_blocking=True)
- return output
\ No newline at end of file
diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py
index a8b6b21..21ebeb9 100644
--- a/vllm/benchmarks/datasets.py
+++ b/vllm/benchmarks/datasets.py
@@ -31,6 +31,7 @@ from tempfile import NamedTemporaryFile
from typing import Any, cast
import numpy as np
+from huggingface_hub import snapshot_download
from PIL import Image
from typing_extensions import deprecated
@@ -60,6 +61,8 @@ except ImportError:
logger = logging.getLogger(__name__)
+DEFAULT_NUM_PROMPTS = 1000
+
# -----------------------------------------------------------------------------
# Data Classes
# -----------------------------------------------------------------------------
@@ -303,9 +306,11 @@ def process_image(image: Any) -> Mapping[str, Any]:
a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns
a dictionary with the image as a base64 data URL.
- 3. String input: - Treats the string as a URL or local file path. -
- Prepends "file://" if the string doesn't start with "http://" or
- "file://". - Returns a dictionary with the image URL.
+ 3. String input: - Treats the string as a URL, local file path, or base64
+ encoded data. - If string starts with "data:image/", treats as base64.
+ - If string starts with "http://", "https://", or "file://", treats as URL.
+ - Otherwise treats as local file path and prepends "file://".
+ - Returns a dictionary with the image URL or base64 data.
Raises:
ValueError: If the input is not a supported type.
@@ -325,14 +330,14 @@ def process_image(image: Any) -> Mapping[str, Any]:
if isinstance(image, str):
image_url = (
image
- if image.startswith(("http://", "https://", "file://"))
+ if image.startswith(("http://", "https://", "file://", "data:image/"))
else f"file://{image}"
)
return {"type": "image_url", "image_url": {"url": image_url}}
raise ValueError(
- f"Invalid image input {image}. Must be a PIL.Image.Image"
- " or str or dictionary with raw image bytes."
+ f"Invalid image input {image}. Must be a PIL.Image.Image, "
+ "str (URL, file path, or base64 data URL), or dictionary with raw image bytes."
)
@@ -1338,7 +1343,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
parser.add_argument(
"--num-prompts",
type=int,
- default=1000,
+ default=DEFAULT_NUM_PROMPTS,
help="Number of prompts to process.",
)
parser.add_argument(
@@ -2676,6 +2681,14 @@ class MMVUDataset(HuggingFaceDataset):
+ (" ".join(f"{k}.{v}" for k, v in x["choices"].items())),
}
+ def __init__(self, **kwargs) -> None:
+ super().__init__(**kwargs)
+
+ self._remote_path_root = (
+ f"https://huggingface.co/datasets/{self.hf_name}/resolve/main"
+ )
+ self._local_path_root = snapshot_download(self.hf_name, repo_type="dataset")
+
def sample(
self,
tokenizer: TokenizerLike,
@@ -2698,7 +2711,9 @@ class MMVUDataset(HuggingFaceDataset):
break
prompt = parser_fn(item)
- mm_content = process_video(item["video"])
+ mm_content = process_video(
+ item["video"].replace(self._remote_path_root, self._local_path_root)
+ )
prompt_len = len(tokenizer.encode(prompt))
if enable_multimodal_chat:
# Note: when chat is enabled the request prompt_len is no longer
diff --git a/vllm/benchmarks/lib/__init__.py b/vllm/benchmarks/lib/__init__.py
new file mode 100644
index 0000000..005e87a
--- /dev/null
+++ b/vllm/benchmarks/lib/__init__.py
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Benchmark library utilities."""
diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py
new file mode 100644
index 0000000..e231ccf
--- /dev/null
+++ b/vllm/benchmarks/lib/endpoint_request_func.py
@@ -0,0 +1,802 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""The request function for API endpoints."""
+
+import io
+import json
+import os
+import sys
+import time
+import traceback
+from collections.abc import Awaitable
+from dataclasses import dataclass, field
+from typing import Any, Literal, Protocol
+
+import aiohttp
+import regex as re
+from tqdm.asyncio import tqdm
+
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+
+
+class StreamedResponseHandler:
+ """Handles streaming HTTP responses by accumulating chunks until complete
+ messages are available."""
+
+ def __init__(self):
+ self.buffer = ""
+
+ def add_chunk(self, chunk_bytes: bytes) -> list[str]:
+ """Add a chunk of bytes to the buffer and return any complete
+ messages."""
+ chunk_str = chunk_bytes.decode("utf-8")
+ self.buffer += chunk_str
+
+ messages = []
+
+ # Split by double newlines (SSE message separator)
+ while "\n\n" in self.buffer:
+ message, self.buffer = self.buffer.split("\n\n", 1)
+ message = message.strip()
+ if message:
+ messages.append(message)
+
+ # if self.buffer is not empty, check if it is a complete message
+ # by removing data: prefix and check if it is a valid JSON
+ if self.buffer.startswith("data: "):
+ message_content = self.buffer.removeprefix("data: ").strip()
+ if message_content == "[DONE]":
+ messages.append(self.buffer.strip())
+ self.buffer = ""
+ elif message_content:
+ try:
+ json.loads(message_content)
+ messages.append(self.buffer.strip())
+ self.buffer = ""
+ except json.JSONDecodeError:
+ # Incomplete JSON, wait for more chunks.
+ pass
+
+ return messages
+
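# A minimal usage sketch of the handler above (the demo function name is
# illustrative): partial SSE frames are buffered until a complete
# "\n\n"-terminated message, or a parseable trailing "data: {...}" frame,
# becomes available.
def _demo_streamed_response_handler() -> None:
    handler = StreamedResponseHandler()
    assert handler.add_chunk(b'data: {"choices": [') == []
    assert handler.add_chunk(b'{"text": "hi"}]}\n\n') == [
        'data: {"choices": [{"text": "hi"}]}'
    ]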
+
+@dataclass
+class RequestFuncInput:
+ """The input for the request function."""
+
+ prompt: str | list[str]
+ api_url: str
+ prompt_len: int
+ output_len: int
+ model: str
+ model_name: str | None = None
+ logprobs: int | None = None
+ extra_headers: dict | None = None
+ extra_body: dict | None = None
+ multi_modal_content: dict | list[dict] | None = None
+ ignore_eos: bool = False
+ language: str | None = None
+ request_id: str | None = None
+
+
+@dataclass
+class RequestFuncOutput:
+ """The output of the request function including metrics."""
+
+ generated_text: str = ""
+ success: bool = False
+ latency: float = 0.0
+ output_tokens: int = 0
+ ttft: float = 0.0 # Time to first token
+ itl: list[float] = field(default_factory=list) # list of inter-token latencies
+ tpot: float = 0.0 # avg next-token latencies
+ prompt_len: int = 0
+ error: str = ""
+ start_time: float = 0.0
+ input_audio_duration: float = 0.0 # in seconds
+
+
+class RequestFunc(Protocol):
+ def __call__(
+ self,
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+ ) -> Awaitable[RequestFuncOutput]: ...
+
+
+def _validate_api_url(
+ api_url: str,
+ api_name: str,
+ expected_suffixes: str | set[str],
+) -> None:
+ if isinstance(expected_suffixes, str):
+ expected_suffixes = {expected_suffixes}
+
+ expected_suffixes = {*expected_suffixes, "profile"}
+
+ if not api_url.endswith(tuple(expected_suffixes)):
+ raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.")
+
+
+def _update_payload_common(
+ payload: dict[str, Any],
+ request_func_input: RequestFuncInput,
+) -> None:
+ if request_func_input.ignore_eos:
+ payload["ignore_eos"] = request_func_input.ignore_eos
+ if request_func_input.extra_body:
+ payload.update(request_func_input.extra_body)
+
+
+def _update_headers_common(
+ headers: dict[str, Any],
+ request_func_input: RequestFuncInput,
+) -> None:
+ if request_func_input.extra_headers:
+ headers |= request_func_input.extra_headers
+ if request_func_input.request_id:
+ headers["x-request-id"] = request_func_input.request_id
+
+
+def _get_headers(content_type: str | None = None) -> dict[str, str]:
+ headers = {}
+ if content_type:
+ headers["Content-Type"] = content_type
+ api_key = os.environ.get("OPENAI_API_KEY")
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+ return headers
+
+
+async def async_request_openai_completions(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ """The async request function for the OpenAI Completions API.
+
+ Args:
+ request_func_input: The input for the request function.
+ pbar: The progress bar to display the progress.
+
+ Returns:
+ The output of the request function.
+ """
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "OpenAI Completions API", "completions")
+
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ "prompt": request_func_input.prompt,
+ "repetition_penalty": 1.0,
+ "max_tokens": request_func_input.output_len,
+ "logprobs": request_func_input.logprobs,
+ "stream": True,
+ "stream_options": {
+ "include_usage": True,
+ },
+ }
+ _update_payload_common(payload, request_func_input)
+
+ headers = _get_headers()
+ _update_headers_common(headers, request_func_input)
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ generated_text = ""
+ st = time.perf_counter()
+ output.start_time = st
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload, headers=headers) as response:
+ if response.status == 200:
+ first_chunk_received = False
+ handler = StreamedResponseHandler()
+
+ async for chunk_bytes in response.content.iter_any():
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ messages = handler.add_chunk(chunk_bytes)
+ for message in messages:
+ # NOTE: SSE comments (often used as pings) start with
+ # a colon. These are not JSON data payload and should
+ # be skipped.
+ if message.startswith(":"):
+ continue
+
+ chunk = message.removeprefix("data: ")
+
+ if chunk != "[DONE]":
+ data = json.loads(chunk)
+
+ # NOTE: Some completion API might have a last
+ # usage summary response without a token so we
+ # want to check a token was generated
+ if choices := data.get("choices"):
+ # Note that text could be empty here
+ # e.g. for special tokens
+ text = choices[0].get("text")
+ timestamp = time.perf_counter()
+ # First token
+ if not first_chunk_received:
+ first_chunk_received = True
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp - most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+ generated_text += text or ""
+ elif usage := data.get("usage"):
+ output.output_tokens = usage.get("completion_tokens")
+ if first_chunk_received:
+ output.success = True
+ else:
+ output.success = False
+ output.error = (
+ "Never received a valid chunk to calculate TTFT."
+ "This response will be marked as failed!"
+ )
+ output.generated_text = generated_text
+ output.latency = most_recent_timestamp - st
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+def _get_chat_content(
+ request_func_input: RequestFuncInput,
+ mm_position: Literal["first", "last"] = "last",
+) -> list[dict[str, Any]]:
+ text_contents = [{"type": "text", "text": request_func_input.prompt}]
+
+ mm_contents = []
+ if request_func_input.multi_modal_content:
+ mm_content = request_func_input.multi_modal_content
+ if isinstance(mm_content, list):
+ mm_contents.extend(request_func_input.multi_modal_content)
+ elif isinstance(mm_content, dict):
+ mm_contents.append(request_func_input.multi_modal_content)
+ else:
+ raise TypeError(
+ "multi_modal_content must be a dict or list[dict] for openai-chat"
+ )
+
+ if mm_position == "first":
+ return mm_contents + text_contents
+
+ return text_contents + mm_contents
+
+
+async def async_request_openai_chat_completions(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+ mm_position: Literal["first", "last"] = "last",
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions")
+
+ content = _get_chat_content(request_func_input, mm_position=mm_position)
+
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ "messages": [
+ {"role": "user", "content": content},
+ ],
+ "max_completion_tokens": request_func_input.output_len,
+ "stream": True,
+ "stream_options": {
+ "include_usage": True,
+ },
+ }
+ _update_payload_common(payload, request_func_input)
+
+ headers = _get_headers("application/json")
+ _update_headers_common(headers, request_func_input)
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ generated_text = ""
+ ttft = 0.0
+ st = time.perf_counter()
+ output.start_time = st
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload, headers=headers) as response:
+ if response.status == 200:
+ handler = StreamedResponseHandler()
+ async for chunk_bytes in response.content.iter_any():
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ messages = handler.add_chunk(chunk_bytes)
+ for message in messages:
+ # NOTE: SSE comments (often used as pings) start with
+ # a colon. These are not JSON data payload and should
+ # be skipped.
+ if message.startswith(":"):
+ continue
+
+ chunk = message.removeprefix("data: ")
+
+ if chunk != "[DONE]":
+ timestamp = time.perf_counter()
+ data = json.loads(chunk)
+
+ if choices := data.get("choices"):
+ content = choices[0]["delta"].get("content")
+ # First token
+ if ttft == 0.0:
+ ttft = timestamp - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp - most_recent_timestamp)
+
+ generated_text += content or ""
+ elif usage := data.get("usage"):
+ output.output_tokens = usage.get("completion_tokens")
+
+ most_recent_timestamp = timestamp
+
+ output.generated_text = generated_text
+ output.success = True
+ output.latency = most_recent_timestamp - st
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_openai_audio(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ # Lazy import without PlaceholderModule to avoid vllm dep.
+ import soundfile
+
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"})
+
+ content = [{"type": "text", "text": request_func_input.prompt}]
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ "max_completion_tokens": request_func_input.output_len,
+ "stream": True,
+ "language": "en",
+ # Flattened due to multipart/form-data
+ "stream_include_usage": True,
+ "stream_continuous_usage_stats": True,
+ }
+ _update_payload_common(payload, request_func_input)
+
+ headers = _get_headers()
+ _update_headers_common(headers, request_func_input)
+
+ # Send audio file
+ def to_bytes(y, sr):
+ buffer = io.BytesIO()
+ soundfile.write(buffer, y, sr, format="WAV")
+ buffer.seek(0)
+ return buffer
+
+ mm_audio = request_func_input.multi_modal_content
+ if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
+ raise TypeError("multi_modal_content must be a dict containing 'audio'")
+ with to_bytes(*mm_audio["audio"]) as f:
+ form = aiohttp.FormData()
+ form.add_field("file", f, content_type="audio/wav")
+ for key, value in payload.items():
+ form.add_field(key, str(value))
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+ output.input_audio_duration = soundfile.info(f).duration
+ f.seek(0)
+
+ generated_text = ""
+ ttft = 0.0
+ st = time.perf_counter()
+ output.start_time = st
+ most_recent_timestamp = st
+ try:
+ async with session.post(
+ url=api_url, data=form, headers=headers
+ ) as response:
+ if response.status == 200:
+ handler = StreamedResponseHandler()
+
+ async for chunk_bytes in response.content.iter_any():
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ messages = handler.add_chunk(chunk_bytes)
+ for message in messages:
+ if type(message) is bytes:
+ message = message.decode("utf-8")
+ chunk = message.removeprefix("data: ")
+ if chunk != "[DONE]":
+ timestamp = time.perf_counter()
+ data = json.loads(chunk)
+
+ if choices := data.get("choices"):
+ content = choices[0]["delta"].get("content")
+ # First token
+ if ttft == 0.0:
+ ttft = timestamp - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(
+ timestamp - most_recent_timestamp
+ )
+
+ generated_text += content or ""
+ elif usage := data.get("usage"):
+ output.output_tokens = usage.get(
+ "completion_tokens"
+ )
+
+ most_recent_timestamp = timestamp
+
+ output.generated_text = generated_text
+ output.success = True
+ output.latency = most_recent_timestamp - st
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def _run_pooling_request(
+ session: aiohttp.ClientSession,
+ api_url: str,
+ payload: dict[str, Any],
+ headers: dict[str, Any],
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ output = RequestFuncOutput()
+ st = time.perf_counter()
+ output.start_time = st
+ try:
+ async with session.post(url=api_url, headers=headers, json=payload) as response:
+ if response.status == 200:
+ output.ttft = output.latency = time.perf_counter() - st
+
+ if payload.get("encoding_format", "float") == "bytes":
+ metadata = json.loads(response.headers["metadata"])
+ usage = metadata.get("usage", {})
+ else:
+ data = await response.json()
+ usage = data.get("usage", {})
+
+ output.success = True
+ output.generated_text = ""
+ output.prompt_len = usage.get("prompt_tokens", 0)
+ else:
+ output.success = False
+ output.error = response.reason or ""
+ except Exception as e:
+ output.success = False
+ output.error = str(e)
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_openai_embeddings(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
+
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ "input": request_func_input.prompt,
+ # Many embedding models have short context length,
+ # this is to avoid dropping some of the requests.
+ "truncate_prompt_tokens": -1,
+ }
+ _update_payload_common(payload, request_func_input)
+
+ headers = _get_headers("application/json")
+ _update_headers_common(headers, request_func_input)
+
+ return await _run_pooling_request(
+ session,
+ api_url,
+ payload=payload,
+ headers=headers,
+ pbar=pbar,
+ )
+
+
+async def async_request_vllm_rerank(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "vLLM score API", "rerank")
+
+ assert (
+ isinstance(request_func_input.prompt, list)
+ and len(request_func_input.prompt) > 1
+ )
+
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ "query": request_func_input.prompt[0],
+ "documents": request_func_input.prompt[1:],
+ # Many reranker models have short context length,
+ # this is to avoid dropping some of the requests.
+ "truncate_prompt_tokens": -1,
+ }
+
+ headers = _get_headers("application/json")
+ _update_headers_common(headers, request_func_input)
+
+ return await _run_pooling_request(
+ session,
+ api_url,
+ payload=payload,
+ headers=headers,
+ pbar=pbar,
+ )
+
+
+async def async_request_openai_embeddings_chat(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+ mm_position: Literal["first", "last"] = "last",
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings")
+
+ content = _get_chat_content(request_func_input, mm_position=mm_position)
+
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ "messages": [
+ {"role": "user", "content": content},
+ ],
+ # Many embedding models have short context length,
+ # this is to avoid dropping some of the requests.
+ "truncate_prompt_tokens": -1,
+ }
+ _update_payload_common(payload, request_func_input)
+
+ headers = _get_headers("application/json")
+ _update_headers_common(headers, request_func_input)
+
+ return await _run_pooling_request(
+ session,
+ api_url,
+ payload=payload,
+ headers=headers,
+ pbar=pbar,
+ )
+
+
+def _try_extract_request_idx(request_func_input: RequestFuncInput):
+ if request_func_input.request_id:
+ match = re.search(r"(\d+)$", request_func_input.request_id)
+ if match:
+ try:
+ return int(match.group(1))
+ except ValueError:
+ pass
+
+ return None
+
+
+def _preprocess_clip(request_func_input: RequestFuncInput):
+ if request_func_input.multi_modal_content:
+ # Image input
+ request_func_input.prompt = ""
+
+
+def _preprocess_vlm2vec(request_func_input: RequestFuncInput):
+ if request_func_input.multi_modal_content:
+ request_idx = _try_extract_request_idx(request_func_input)
+
+ # Adjust the ratio manually if needed.
+ use_image_only_prompt = request_idx is None or request_idx % 2 == 0
+
+ if use_image_only_prompt:
+ # Image input
+ request_func_input.prompt = "Represent the given image."
+ else:
+ # Text+Image input
+ request_func_input.prompt = (
+ f"Represent the given image with the following question: "
+ f"{request_func_input.prompt}"
+ )
+
+
+async def async_request_openai_embeddings_clip(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ _preprocess_clip(request_func_input)
+
+ return await async_request_openai_embeddings_chat(
+ request_func_input,
+ session,
+ pbar=pbar,
+ )
+
+
+async def async_request_openai_embeddings_vlm2vec(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ _preprocess_vlm2vec(request_func_input)
+
+ return await async_request_openai_embeddings_chat(
+ request_func_input,
+ session,
+ pbar=pbar,
+ mm_position="first",
+ )
+
+
+async def async_request_infinity_embeddings(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "Infinity Embeddings API", "embeddings")
+
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ }
+
+ if request_func_input.prompt:
+ payload["input"] = request_func_input.prompt
+ else:
+ mm_content = request_func_input.multi_modal_content
+ assert isinstance(mm_content, dict)
+
+ mm_type = mm_content["type"]
+ payload["input"] = mm_content[mm_type]["url"]
+ payload["modality"] = mm_type.split("_", 1)[0]
+
+ _update_payload_common(payload, request_func_input)
+
+ headers = _get_headers("application/json")
+ _update_headers_common(headers, request_func_input)
+
+ return await _run_pooling_request(
+ session,
+ api_url,
+ payload=payload,
+ headers=headers,
+ pbar=pbar,
+ )
+
+
+async def async_request_infinity_embeddings_clip(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ _preprocess_clip(request_func_input)
+
+ return await async_request_infinity_embeddings(
+ request_func_input,
+ session,
+ pbar=pbar,
+ )
+
+
+async def async_request_vllm_pooling(
+ request_func_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ pbar: tqdm | None = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ _validate_api_url(api_url, "vLLM Pooling API", "pooling")
+
+ payload = {
+ "model": request_func_input.model_name
+ if request_func_input.model_name
+ else request_func_input.model,
+ "truncate_prompt_tokens": -1,
+ }
+
+ payload = payload | request_func_input.prompt
+
+ _update_payload_common(payload, request_func_input)
+
+ headers = _get_headers("application/json")
+ _update_headers_common(headers, request_func_input)
+
+ return await _run_pooling_request(
+ session,
+ api_url,
+ payload=payload,
+ headers=headers,
+ pbar=pbar,
+ )
+
+
+# TODO: Add more request functions for different API protocols.
+ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
+ "vllm": async_request_openai_completions,
+ "openai": async_request_openai_completions,
+ "openai-chat": async_request_openai_chat_completions,
+ "openai-audio": async_request_openai_audio,
+ "openai-embeddings": async_request_openai_embeddings,
+ "openai-embeddings-chat": async_request_openai_embeddings_chat,
+ "openai-embeddings-clip": async_request_openai_embeddings_clip,
+ "openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec,
+ # Infinity embedding server: https://github.com/michaelfeil/infinity
+ "infinity-embeddings": async_request_infinity_embeddings,
+ "infinity-embeddings-clip": async_request_infinity_embeddings_clip,
+ # (Infinity embedding server does not support vlm2vec)
+ "vllm-pooling": async_request_vllm_pooling,
+ "vllm-rerank": async_request_vllm_rerank,
+}
+
+OPENAI_COMPATIBLE_BACKENDS = [
+ k
+ for k, v in ASYNC_REQUEST_FUNCS.items()
+ if v in (async_request_openai_completions, async_request_openai_chat_completions)
+]
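
# A minimal end-to-end sketch of dispatching through the registry above; the
# endpoint URL and model name are placeholders for a locally running
# OpenAI-compatible server, and the demo function name is illustrative.
async def _demo_dispatch() -> None:
    request_func = ASYNC_REQUEST_FUNCS["openai"]
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        result = await request_func(
            request_func_input=RequestFuncInput(
                prompt="Hello",
                api_url="http://localhost:8000/v1/completions",
                prompt_len=1,
                output_len=16,
                model="my-model",
            ),
            session=session,
        )
        print(result.success, result.ttft, result.generated_text)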
diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py
new file mode 100644
index 0000000..eec4a42
--- /dev/null
+++ b/vllm/benchmarks/lib/ready_checker.py
@@ -0,0 +1,79 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Utilities for checking endpoint readiness."""
+
+import asyncio
+import time
+
+import aiohttp
+from tqdm.asyncio import tqdm
+
+from vllm.logger import init_logger
+
+from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput
+
+logger = init_logger(__name__)
+
+
+async def wait_for_endpoint(
+ request_func: RequestFunc,
+ test_input: RequestFuncInput,
+ session: aiohttp.ClientSession,
+ timeout_seconds: int = 600,
+ retry_interval: int = 5,
+) -> RequestFuncOutput:
+ """
+ Wait for an endpoint to become available before starting benchmarks.
+
+ Args:
+ request_func: The async request function to call
+        test_input: The RequestFuncInput to test with
+        session: The aiohttp client session used for the probe requests
+ timeout_seconds: Maximum time to wait in seconds (default: 10 minutes)
+ retry_interval: Time between retries in seconds (default: 5 seconds)
+
+ Returns:
+ RequestFuncOutput: The successful response
+
+ Raises:
+ ValueError: If the endpoint doesn't become available within the timeout
+ """
+ deadline = time.perf_counter() + timeout_seconds
+ output = RequestFuncOutput(success=False)
+ print(f"Waiting for endpoint to become up in {timeout_seconds} seconds")
+
+ with tqdm(
+ total=timeout_seconds,
+ bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining",
+ unit="s",
+ ) as pbar:
+ while True:
+ # update progress bar
+ remaining = deadline - time.perf_counter()
+ elapsed = timeout_seconds - remaining
+ update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n)
+ pbar.update(update_amount)
+ pbar.refresh()
+ if remaining <= 0:
+ pbar.close()
+ break
+
+ # ping the endpoint using request_func
+ try:
+ output = await request_func(
+ request_func_input=test_input, session=session
+ )
+ if output.success:
+ pbar.close()
+ return output
+ else:
+ err_last_line = str(output.error).rstrip().rsplit("\n", 1)[-1]
+ logger.warning("Endpoint is not ready. Error='%s'", err_last_line)
+ except aiohttp.ClientConnectorError:
+ pass
+
+ # retry after a delay
+ sleep_duration = min(retry_interval, remaining)
+ if sleep_duration > 0:
+ await asyncio.sleep(sleep_duration)
+
+ return output
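
# A minimal readiness-probe sketch; the endpoint URL and model name are
# placeholders, and the demo function name is illustrative.
async def _demo_wait_for_endpoint() -> None:
    from .endpoint_request_func import async_request_openai_completions

    probe = RequestFuncInput(
        prompt="ping",
        api_url="http://localhost:8000/v1/completions",
        prompt_len=1,
        output_len=1,
        model="my-model",
    )
    async with aiohttp.ClientSession() as session:
        result = await wait_for_endpoint(
            async_request_openai_completions,
            probe,
            session,
            timeout_seconds=60,
            retry_interval=2,
        )
    if not result.success:
        raise RuntimeError("endpoint did not become ready in time")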
diff --git a/vllm/benchmarks/lib/utils.py b/vllm/benchmarks/lib/utils.py
new file mode 100644
index 0000000..99a3bf9
--- /dev/null
+++ b/vllm/benchmarks/lib/utils.py
@@ -0,0 +1,131 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import argparse
+import json
+import math
+import os
+from contextlib import contextmanager
+from typing import Any
+
+
+def extract_field(
+ args: argparse.Namespace, extra_info: dict[str, Any], field_name: str
+) -> str:
+ if field_name in extra_info:
+ return extra_info[field_name]
+
+ v = args
+ # For example, args.compilation_config.mode
+ for nested_field in field_name.split("."):
+ if not hasattr(v, nested_field):
+ return ""
+ v = getattr(v, nested_field)
+ return v
+
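# A small sketch of the nested-attribute lookup above; the namespace contents
# are placeholders.
#
#     ns = argparse.Namespace(compilation_config=argparse.Namespace(mode="3"))
#     extract_field(ns, {}, "compilation_config.mode")  # -> "3"
#     extract_field(ns, {"optimization_level": 2}, "optimization_level")  # -> 2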
+
+def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool:
+ """
+ Check if the benchmark is run with torch.compile
+ """
+ return not (
+ extract_field(args, extra_info, "compilation_config.mode") == "0"
+ or "eager" in getattr(args, "output_json", "")
+ or "eager" in getattr(args, "result_filename", "")
+ )
+
+
+def convert_to_pytorch_benchmark_format(
+ args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
+) -> list:
+ """
+ Save the benchmark results in the format used by PyTorch OSS benchmark with
+    one metric per record
+ https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
+ """
+ records = []
+ if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
+ return records
+
+ for name, benchmark_values in metrics.items():
+ if not isinstance(benchmark_values, list):
+ raise TypeError(
+ f"benchmark_values for metric '{name}' must be a list, "
+ f"but got {type(benchmark_values).__name__}"
+ )
+
+ record = {
+ "benchmark": {
+ "name": "vLLM benchmark",
+ "extra_info": {
+ "args": vars(args),
+ "compilation_config.mode": extract_field(
+ args, extra_info, "compilation_config.mode"
+ ),
+ "optimization_level": extract_field(
+ args, extra_info, "optimization_level"
+ ),
+ # A boolean field used by vLLM benchmark HUD dashboard
+ "use_compile": use_compile(args, extra_info),
+ },
+ },
+ "model": {
+ "name": args.model,
+ },
+ "metric": {
+ "name": name,
+ "benchmark_values": benchmark_values,
+ "extra_info": extra_info,
+ },
+ }
+
+ tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
+ # Save tensor_parallel_size parameter if it's part of the metadata
+ if not tp and "tensor_parallel_size" in extra_info:
+ record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
+ extra_info["tensor_parallel_size"]
+ )
+
+ records.append(record)
+
+ return records
+
+
+class InfEncoder(json.JSONEncoder):
+ def clear_inf(self, o: Any):
+ if isinstance(o, dict):
+ return {
+ str(k)
+ if not isinstance(k, (str, int, float, bool, type(None)))
+ else k: self.clear_inf(v)
+ for k, v in o.items()
+ }
+ elif isinstance(o, list):
+ return [self.clear_inf(v) for v in o]
+ elif isinstance(o, float) and math.isinf(o):
+ return "inf"
+ return o
+
+ def iterencode(self, o: Any, *args, **kwargs) -> Any:
+ return super().iterencode(self.clear_inf(o), *args, **kwargs)
+
+
+def write_to_json(filename: str, records: list) -> None:
+ with open(filename, "w") as f:
+ json.dump(
+ records,
+ f,
+ cls=InfEncoder,
+ default=lambda o: f"<{type(o).__name__} is not JSON serializable>",
+ )
+
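# A minimal sketch of the two helpers above: records are only produced when the
# SAVE_TO_PYTORCH_BENCHMARK_FORMAT environment variable is set, and InfEncoder
# serializes float("inf") values as the string "inf". Argument values and the
# demo function name are placeholders.
def _demo_benchmark_records() -> None:
    args = argparse.Namespace(model="my-model", tensor_parallel_size=1)
    records = convert_to_pytorch_benchmark_format(
        args,
        metrics={"request_throughput": [12.3, float("inf")]},
        extra_info={"optimization_level": 3},
    )
    write_to_json("benchmark_records.json", records)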
+
+@contextmanager
+def default_vllm_config():
+ """Set a default VllmConfig for cases that directly test CustomOps or pathways
+ that use get_current_vllm_config() outside of a full engine context.
+ """
+ from vllm.config import VllmConfig, set_current_vllm_config
+
+ with set_current_vllm_config(VllmConfig()):
+ yield
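
# A minimal usage sketch: the context manager is intended for tests or scripts
# that exercise code calling get_current_vllm_config() without building a full
# engine. The op constructed inside the block is a hypothetical example.
#
#     with default_vllm_config():
#         op = SomeCustomOp()  # placeholder for a CustomOp subclass under test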
diff --git a/vllm/benchmarks/plot.py b/vllm/benchmarks/plot.py
new file mode 100644
index 0000000..3f36ede
--- /dev/null
+++ b/vllm/benchmarks/plot.py
@@ -0,0 +1,316 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Generate plots for benchmark results."""
+
+from pathlib import Path
+from typing import Any
+
+from vllm.utils.import_utils import PlaceholderModule
+
+try:
+ import plotly.express as px
+ import plotly.io as pio
+except ImportError:
+ _plotly = PlaceholderModule("plotly")
+ px = _plotly.placeholder_attr("express")
+ pio = _plotly.placeholder_attr("io")
+
+try:
+ import matplotlib.pyplot as plt
+except ImportError:
+ _matplotlib = PlaceholderModule("matplotlib")
+ plt = _matplotlib.placeholder_attr("pyplot")
+
+
+def generate_timeline_plot(
+ results: list[dict[str, Any]],
+ output_path: Path,
+ colors: list[str] | None = None,
+ itl_thresholds: list[float] | None = None,
+ labels: list[str] | None = None,
+) -> None:
+ """
+ Generate an HTML timeline plot from benchmark results.
+
+ Args:
+ results: List of per-request result dictionaries containing:
+ - start_time: Request start time (seconds)
+ - ttft: Time to first token (seconds)
+ - itl: List of inter-token latencies (seconds)
+ - latency: Total request latency (seconds)
+ - prompt_len: Number of prompt tokens
+ - output_tokens: Number of output tokens
+ output_path: Path where the HTML file will be saved
+        colors: List of colors for ITL categories (default: green, orange, red)
+        itl_thresholds: ITL thresholds in seconds (default: [0.025, 0.050])
+ labels: Labels for ITL categories (default based on thresholds)
+ """
+
+ # Set defaults
+ if colors is None:
+ colors = ["#109618", "#FF7F0E", "#D62728"]
+ if itl_thresholds is None:
+ itl_thresholds = [0.025, 0.050]
+ if labels is None:
+ labels = [
+ f"ITL < {itl_thresholds[0] * 1000:.0f}ms",
+ f"{itl_thresholds[0] * 1000:.0f}ms ≤ ITL < {itl_thresholds[1] * 1000:.0f}ms", # noqa
+ f"ITL ≥ {itl_thresholds[1] * 1000:.0f}ms",
+ ]
+
+ labels_colors = {"TTFT": "#636EFA", **dict(zip(labels, colors))}
+ labels_order = ["TTFT"] + labels
+
+ timeline_data = construct_timeline_data(results, itl_thresholds, labels)
+
+ if not timeline_data:
+ print("No timeline data to plot")
+ return
+
+ # Create the plot
+ fig = px.timeline(
+ timeline_data,
+ x_start="start",
+ x_end="end",
+ y="request_id",
+ color="type",
+ color_discrete_map=labels_colors,
+ category_orders={"type": labels_order},
+ hover_data=[
+ "prompt_tokens",
+ "output_tokens",
+ "req_start_time",
+ "req_finish_time",
+ "segment_start",
+ "segment_end",
+ "duration",
+ ],
+ )
+
+ # Customize hover template to show only time without date
+ fig.update_traces(
+ hovertemplate="%{y}
"
+ "Type: %{fullData.name}
"
+ "Start: %{customdata[4]}
"
+ "End: %{customdata[5]}
"
+ "Duration: %{customdata[6]}
"
+ "Prompt Tokens: %{customdata[0]}
"
+ "Output Tokens: %{customdata[1]}
"
+ "Request Start Time: %{customdata[2]}
"
+ "Request End Time: %{customdata[3]}
"
+ ""
+ )
+
+ fig.update_yaxes(autorange="reversed")
+ fig.update_layout(
+ xaxis_title="Time",
+ yaxis_title="Request ID",
+ showlegend=True,
+ )
+
+ # Save to HTML
+ pio.write_html(fig, str(output_path))
+ print(f"Timeline plot saved to: {output_path}")
+
+
+def construct_timeline_data(
+ requests_data: list[dict[str, Any]],
+ itl_thresholds: list[float],
+ labels: list[str],
+) -> list[dict[str, Any]]:
+ """
+ Construct timeline data from request results.
+
+ Args:
+ requests_data: List of per-request result dictionaries
+ itl_thresholds: ITL thresholds in seconds
+ labels: Labels for ITL categories
+
+ Returns:
+ List of timeline segments for plotting
+ """
+
+ def tostr(sec_time: float) -> str:
+ """Convert seconds to HH:MM:SS.mmm format."""
+ h = int(sec_time // 3600)
+ assert h < 100, "time seems to last more than 100 hours"
+ m = int((sec_time % 3600) // 60)
+ s = sec_time % 60
+ return f"{h:02d}:{m:02d}:{s:06.3f}"
+
+ def itl_type(itl: float) -> str:
+ """Categorize ITL based on thresholds."""
+ if itl < itl_thresholds[0]:
+ return labels[0]
+ elif itl < itl_thresholds[1]:
+ return labels[1]
+ else:
+ return labels[2]
+
+ # Find the earliest start time to use as t0
+ t0 = None
+ for request in requests_data:
+ start_time = request.get("start_time")
+ if start_time is not None and (t0 is None or start_time < t0):
+ t0 = start_time
+
+ if t0 is None:
+ return []
+
+ timeline_data = []
+
+ for i, request in enumerate(requests_data):
+ start_time = request.get("start_time")
+ ttft = request.get("ttft")
+ itl = request.get("itl", [])
+ latency = request.get("latency")
+ prompt_len = request.get("prompt_len", 0)
+ output_tokens = request.get("output_tokens", 0)
+
+ # Skip requests without required data
+ if start_time is None or ttft is None or latency is None:
+ continue
+
+ # Normalize start time
+ start_time = start_time - t0
+ start_time_str = tostr(start_time)
+
+ # TTFT segment
+ ttft_end = start_time + ttft
+ ttft_end_str = tostr(ttft_end)
+
+ timeline_data.append(
+ {
+ "request_id": f"Req {i}",
+ "start": start_time_str,
+ "end": ttft_end_str,
+ "type": "TTFT",
+ "prompt_tokens": prompt_len,
+ "output_tokens": output_tokens,
+ "req_start_time": tostr(start_time),
+ "req_finish_time": tostr(start_time + latency),
+ "segment_start": start_time_str,
+ "segment_end": ttft_end_str,
+ "duration": f"{ttft:.3f}s",
+ }
+ )
+
+ # ITL segments
+ prev_time = ttft_end
+ prev_time_str = ttft_end_str
+
+ for itl_value in itl:
+ itl_end = prev_time + itl_value
+ itl_end_str = tostr(itl_end)
+
+ timeline_data.append(
+ {
+ "request_id": f"Req {i}",
+ "start": prev_time_str,
+ "end": itl_end_str,
+ "type": itl_type(itl_value),
+ "prompt_tokens": prompt_len,
+ "output_tokens": output_tokens,
+ "req_start_time": tostr(start_time),
+ "req_finish_time": tostr(start_time + latency),
+ "segment_start": prev_time_str,
+ "segment_end": itl_end_str,
+ "duration": f"{itl_value:.3f}s",
+ }
+ )
+
+ prev_time = itl_end
+ prev_time_str = itl_end_str
+
+ return timeline_data
+
+
+def generate_dataset_stats_plot(
+ results: list[dict[str, Any]],
+ output_path: Path,
+) -> None:
+ """
+ Generate a matplotlib figure with dataset statistics.
+
+ Creates a figure with 4 subplots:
+ - Top-left: Prompt tokens distribution (histogram)
+ - Top-right: Output tokens distribution (histogram)
+ - Bottom-left: Prompt+output tokens distribution (histogram)
+ - Bottom-right: Stacked bar chart (request_id vs tokens)
+
+ Args:
+ results: List of per-request result dictionaries containing:
+ - prompt_len: Number of prompt tokens
+ - output_tokens: Number of output tokens
+ output_path: Path where the figure will be saved
+ """
+ # Extract data
+ prompt_tokens = []
+ output_tokens = []
+ total_tokens = []
+
+ for request in results:
+ prompt_len = request.get("prompt_len", 0)
+ output_len = request.get("output_tokens", 0)
+
+ prompt_tokens.append(prompt_len)
+ output_tokens.append(output_len)
+ total_tokens.append(prompt_len + output_len)
+
+ if not prompt_tokens:
+ print("No data available for dataset statistics plot")
+ return
+
+ # Create figure with 4 subplots
+ fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))
+
+ # Top-left: Prompt tokens distribution
+ ax1.hist(prompt_tokens, bins=30, color="steelblue", edgecolor="black", alpha=0.7)
+ ax1.set_xlabel("Prompt Tokens")
+ ax1.set_ylabel("Frequency")
+ ax1.set_title("Prompt Tokens Distribution")
+ ax1.grid(True, alpha=0.3)
+
+ # Top-right: Output tokens distribution
+ ax2.hist(output_tokens, bins=30, color="coral", edgecolor="black", alpha=0.7)
+ ax2.set_xlabel("Output Tokens")
+ ax2.set_ylabel("Frequency")
+ ax2.set_title("Output Tokens Distribution")
+ ax2.grid(True, alpha=0.3)
+
+ # Bottom-left: Prompt+output tokens distribution
+ ax3.hist(
+ total_tokens, bins=30, color="mediumseagreen", edgecolor="black", alpha=0.7
+ )
+ ax3.set_xlabel("Total Tokens (Prompt + Output)")
+ ax3.set_ylabel("Frequency")
+ ax3.set_title("Total Tokens Distribution")
+ ax3.grid(True, alpha=0.3)
+
+ # Bottom-right: Stacked bar chart
+ request_ids = list(range(len(prompt_tokens)))
+ ax4.bar(
+ request_ids, prompt_tokens, label="Prompt Tokens", color="steelblue", alpha=0.7
+ )
+ ax4.bar(
+ request_ids,
+ output_tokens,
+ bottom=prompt_tokens,
+ label="Output Tokens",
+ color="coral",
+ alpha=0.7,
+ )
+ ax4.set_xlabel("Request ID")
+ ax4.set_ylabel("Tokens")
+ ax4.set_title("Tokens per Request (Stacked)")
+ ax4.legend()
+ ax4.grid(True, alpha=0.3, axis="y")
+
+ # Adjust layout to prevent overlap
+ plt.tight_layout()
+
+ # Save figure
+ plt.savefig(str(output_path), dpi=150, bbox_inches="tight")
+ plt.close(fig)
+
+ print(f"Dataset statistics plot saved to: {output_path}")
diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py
index 06e67f9..7c9a95e 100644
--- a/vllm/benchmarks/serve.py
+++ b/vllm/benchmarks/serve.py
@@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
+from pathlib import Path
from typing import Any, Literal
import aiohttp
@@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format(
write_to_json(pt_file, pt_records)
+def compute_result_filename(
+ args: argparse.Namespace,
+ model_id: str,
+ label: str,
+ current_dt: str,
+) -> str | None:
+ """Compute the result filename based on benchmark configuration.
+
+ Args:
+ args: Command line arguments containing result configuration
+ model_id: The model identifier
+ label: The benchmark label
+ current_dt: Current datetime string
+
+ Returns:
+        The computed filename path, or None if neither plotting nor result
+        saving is requested
+    """
+    if not (
+        args.plot_timeline
+        or args.plot_dataset_stats
+        or args.save_result
+        or args.append_result
+    ):
+ return None
+
+ base_model_id = model_id.split("/")[-1]
+ max_concurrency_str = (
+ f"-concurrency{args.max_concurrency}"
+ if args.max_concurrency is not None
+ else ""
+ )
+ label = label or args.backend
+
+ if args.ramp_up_strategy is not None:
+ file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
+ else:
+ file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
+
+ if args.result_filename:
+ file_name = args.result_filename
+
+ if args.result_dir:
+ os.makedirs(args.result_dir, exist_ok=True)
+ file_name = os.path.join(args.result_dir, file_name)
+
+ return file_name
+
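# A sketch of the resulting naming convention (values are placeholders): with
# label="openai-chat", request_rate=5, max_concurrency=32,
# model_id="org/My-Model", no ramp-up strategy and no explicit
# --result-filename, the computed name is roughly
# "openai-chat-5qps-concurrency32-My-Model-<current_dt>.json".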
+
def add_cli_args(parser: argparse.ArgumentParser):
add_dataset_parser(parser)
parser.add_argument(
@@ -1277,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
+ - "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins.""",
)
parser.add_argument("--use-beam-search", action="store_true")
@@ -1535,6 +1580,30 @@ def add_cli_args(parser: argparse.ArgumentParser):
"connecting to servers with self-signed certificates.",
)
+ parser.add_argument(
+ "--plot-timeline",
+ action="store_true",
+ help="Generate an HTML timeline plot showing request execution. "
+ "The plot will be saved alongside the results JSON file.",
+ )
+ parser.add_argument(
+ "--timeline-itl-thresholds",
+ type=float,
+ nargs=2,
+ default=[25.0, 50.0],
+ metavar=("THRESHOLD1", "THRESHOLD2"),
+ help="ITL thresholds in milliseconds for timeline plot coloring. "
+ "Specify two values to categorize inter-token latencies into three groups: "
+ "below first threshold (green), between thresholds (orange), "
+ "and above second threshold (red). Default: 25 50 (milliseconds).",
+ )
+ parser.add_argument(
+ "--plot-dataset-stats",
+ action="store_true",
+ help="Generate a matplotlib figure with dataset statistics showing "
+ "prompt tokens, output tokens, and combined token distributions.",
+ )
+
def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args))
@@ -1770,6 +1839,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
# Merge with benchmark result
result_json = {**result_json, **benchmark_result}
+ # Compute file_name once before using it for plots or saving results
+ file_name = compute_result_filename(args, model_id, label, current_dt)
+
+ # Generate timeline plot if requested
+ if args.plot_timeline:
+ try:
+ from vllm.benchmarks.plot import generate_timeline_plot
+
+ # Prepare per-request data for timeline
+ per_request_data = []
+ start_times = benchmark_result.get("start_times", [])
+ ttfts = benchmark_result.get("ttfts", [])
+ itls = benchmark_result.get("itls", [])
+ input_lens = benchmark_result.get("input_lens", [])
+ output_lens = benchmark_result.get("output_lens", [])
+
+ if start_times and ttfts and itls:
+ for i in range(len(start_times)):
+ # Calculate latency as ttft + sum of all itls
+ latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i]
+
+ per_request_data.append(
+ {
+ "start_time": start_times[i],
+ "ttft": ttfts[i],
+ "itl": itls[i],
+ "latency": latency,
+ "prompt_len": input_lens[i],
+ "output_tokens": output_lens[i],
+ }
+ )
+
+ timeline_path = Path(file_name).with_suffix(".timeline.html")
+ # Convert thresholds from milliseconds to seconds
+ itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds]
+ generate_timeline_plot(
+ per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec
+ )
+ else:
+ warnings.warn(
+ "Timeline plot requires detailed metrics. "
+ "Ensure the benchmark completed successfully.",
+ stacklevel=2,
+ )
+ except Exception as e:
+ warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2)
+
+ # Generate dataset statistics plot if requested
+ if args.plot_dataset_stats:
+ try:
+ from vllm.benchmarks.plot import generate_dataset_stats_plot
+
+ # Prepare per-request data for dataset stats
+ per_request_data = []
+ input_lens = benchmark_result.get("input_lens", [])
+ output_lens = benchmark_result.get("output_lens", [])
+
+ if input_lens and output_lens:
+ for req_input_len, req_output_len in zip(input_lens, output_lens):
+ per_request_data.append(
+ {
+ "prompt_len": req_input_len,
+ "output_tokens": req_output_len,
+ }
+ )
+
+ stats_path = Path(file_name).with_suffix(".dataset_stats.png")
+ generate_dataset_stats_plot(per_request_data, stats_path)
+ else:
+ warnings.warn(
+ "Dataset statistics plot requires input and "
+ "output length data. Ensure the benchmark completed "
+ "successfully.",
+ stacklevel=2,
+ )
+ except Exception as e:
+ warnings.warn(
+ f"Failed to generate dataset statistics plot: {e}", stacklevel=2
+ )
+
if not args.save_detailed:
# Remove fields with too many data points
for field in [
@@ -1786,24 +1935,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
if field in benchmark_result:
del benchmark_result[field]
- # Save to file
+ # Save to file
if args.save_result or args.append_result:
- base_model_id = model_id.split("/")[-1]
- max_concurrency_str = (
- f"-concurrency{args.max_concurrency}"
- if args.max_concurrency is not None
- else ""
- )
- label = label or args.backend
- if args.ramp_up_strategy is not None:
- file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
- else:
- file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
- if args.result_filename:
- file_name = args.result_filename
- if args.result_dir:
- os.makedirs(args.result_dir, exist_ok=True)
- file_name = os.path.join(args.result_dir, file_name)
with open(
file_name, mode="a+" if args.append_result else "w", encoding="utf-8"
) as outfile:
diff --git a/vllm/benchmarks/sweep/cli.py b/vllm/benchmarks/sweep/cli.py
index a752000..7554910 100644
--- a/vllm/benchmarks/sweep/cli.py
+++ b/vllm/benchmarks/sweep/cli.py
@@ -10,14 +10,14 @@ from .plot_pareto import SweepPlotParetoArgs
from .plot_pareto import main as plot_pareto_main
from .serve import SweepServeArgs
from .serve import main as serve_main
-from .serve_sla import SweepServeSLAArgs
-from .serve_sla import main as serve_sla_main
+from .serve_workload import SweepServeWorkloadArgs
+from .serve_workload import main as serve_workload_main
from .startup import SweepStartupArgs
from .startup import main as startup_main
SUBCOMMANDS = (
(SweepServeArgs, serve_main),
- (SweepServeSLAArgs, serve_sla_main),
+ (SweepServeWorkloadArgs, serve_workload_main),
(SweepStartupArgs, startup_main),
(SweepPlotArgs, plot_main),
(SweepPlotParetoArgs, plot_pareto_main),
diff --git a/vllm/benchmarks/sweep/plot.py b/vllm/benchmarks/sweep/plot.py
index 53c7db3..156e18f 100644
--- a/vllm/benchmarks/sweep/plot.py
+++ b/vllm/benchmarks/sweep/plot.py
@@ -324,6 +324,11 @@ def _plot_fig(
df = filter_by.apply(df)
df = bin_by.apply(df)
+ if len(df) == 0:
+ print(f"No data to plot. Filters: {filter_by}")
+ print("[END FIGURE]")
+ return
+
# Sort by curve_by columns alphabetically for consistent legend ordering
if curve_by:
df = df.sort_values(by=curve_by)
@@ -494,7 +499,7 @@ class SweepPlotArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
- output_dir = Path(args.OUTPUT_DIR)
+ output_dir = Path(args.EXPERIMENT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
@@ -526,11 +531,9 @@ class SweepPlotArgs:
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument(
- "OUTPUT_DIR",
+ "EXPERIMENT_DIR",
type=str,
- default="results",
- help="The directory containing the results to plot, "
- "i.e., the `--output-dir` argument to the parameter sweep script.",
+ help="The directory containing the sweep results to plot.",
)
parser.add_argument(
"--fig-dir",
@@ -570,13 +573,13 @@ class SweepPlotArgs:
parser.add_argument(
"--var-x",
type=str,
- default="request_throughput",
+ default="total_token_throughput",
help="The variable for the x-axis.",
)
parser.add_argument(
"--var-y",
type=str,
- default="p99_ttft_ms",
+ default="median_ttft_ms",
help="The variable for the y-axis",
)
parser.add_argument(
diff --git a/vllm/benchmarks/sweep/plot_pareto.py b/vllm/benchmarks/sweep/plot_pareto.py
index 3d17e47..365e87f 100644
--- a/vllm/benchmarks/sweep/plot_pareto.py
+++ b/vllm/benchmarks/sweep/plot_pareto.py
@@ -325,7 +325,7 @@ class SweepPlotParetoArgs:
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
- output_dir = Path(args.OUTPUT_DIR)
+ output_dir = Path(args.EXPERIMENT_DIR)
if not output_dir.exists():
raise ValueError(f"No parameter sweep results under {output_dir}")
@@ -342,9 +342,8 @@ class SweepPlotParetoArgs:
@classmethod
def add_cli_args(cls, parser: argparse.ArgumentParser):
parser.add_argument(
- "OUTPUT_DIR",
+ "EXPERIMENT_DIR",
type=str,
- default="results",
help="The directory containing the sweep results to plot.",
)
parser.add_argument(
diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py
index 7420f25..f64006e 100644
--- a/vllm/benchmarks/sweep/serve.py
+++ b/vllm/benchmarks/sweep/serve.py
@@ -4,6 +4,7 @@ import argparse
import contextlib
import json
import shlex
+from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@@ -135,17 +136,21 @@ def run_benchmark(
def _get_comb_base_path(
- output_dir: Path,
+ experiment_dir: Path,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
+ *,
+ extra_parts: tuple[str, ...] = (),
):
parts = list[str]()
if serve_comb:
parts.extend(("SERVE-", serve_comb.name))
if bench_comb:
parts.extend(("BENCH-", bench_comb.name))
+ if extra_parts:
+ parts.extend(extra_parts)
- return output_dir / sanitize_filename("-".join(parts))
+ return experiment_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None):
@@ -158,10 +163,10 @@ def _get_comb_run_path(base_path: Path, run_number: int | None):
def _comb_needs_server(
serve_comb: ParameterSweepItem,
bench_combs: ParameterSweep,
- output_dir: Path,
+ experiment_dir: Path,
):
for bench_comb in bench_combs:
- base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
+ base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
if not _get_comb_run_path(base_path, run_number=None).exists():
return True
@@ -175,11 +180,11 @@ def server_ctx(
show_stdout: bool,
serve_comb: ParameterSweepItem,
bench_params: ParameterSweep,
- output_dir: Path,
+ experiment_dir: Path,
dry_run: bool,
server_ready_timeout: int = 300,
):
- if not _comb_needs_server(serve_comb, bench_params, output_dir):
+ if not _comb_needs_server(serve_comb, bench_params, experiment_dir):
return contextlib.nullcontext()
return run_server(
@@ -211,10 +216,10 @@ def run_comb(
*,
serve_comb: ParameterSweepItem,
bench_comb: ParameterSweepItem,
+ link_vars: list[tuple[str, str]],
base_path: Path,
num_runs: int,
dry_run: bool,
- link_vars: list[tuple[str, str]],
):
if not _comb_is_valid(serve_comb, bench_comb, link_vars):
return None
@@ -253,10 +258,10 @@ def run_combs(
server_ready_timeout: int,
serve_params: ParameterSweep,
bench_params: ParameterSweep,
- output_dir: Path,
+ link_vars: list[tuple[str, str]],
+ experiment_dir: Path,
num_runs: int,
dry_run: bool,
- link_vars: list[tuple[str, str]],
):
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
@@ -266,22 +271,22 @@ def run_combs(
show_stdout=show_stdout,
serve_comb=serve_comb,
bench_params=bench_params,
- output_dir=output_dir,
+ experiment_dir=experiment_dir,
dry_run=dry_run,
server_ready_timeout=server_ready_timeout,
) as server:
for bench_comb in bench_params:
- base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb)
+ base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb)
comb_data = run_comb(
server,
bench_cmd,
serve_comb=serve_comb,
bench_comb=bench_comb,
+ link_vars=link_vars,
base_path=base_path,
num_runs=num_runs,
dry_run=dry_run,
- link_vars=link_vars,
)
if comb_data is not None:
@@ -291,7 +296,7 @@ def run_combs(
return None
combined_df = pd.DataFrame.from_records(all_data)
- combined_df.to_csv(output_dir / "summary.csv")
+ combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@@ -305,11 +310,12 @@ class SweepServeArgs:
server_ready_timeout: int
serve_params: ParameterSweep
bench_params: ParameterSweep
+ link_vars: list[tuple[str, str]]
output_dir: Path
+ experiment_name: str
num_runs: int
dry_run: bool
- resume: str | None
- link_vars: list[tuple[str, str]]
+ resume: bool
parser_name: ClassVar[str] = "serve"
parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings."
@@ -336,6 +342,11 @@ class SweepServeArgs:
link_vars = cls.parse_link_vars(args.link_vars)
+ if args.experiment_name:
+ experiment_name = args.experiment_name
+ else:
+ experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
+
num_runs = args.num_runs
if num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -347,11 +358,12 @@ class SweepServeArgs:
show_stdout=args.show_stdout,
serve_params=serve_params,
bench_params=bench_params,
+ link_vars=link_vars,
output_dir=Path(args.output_dir),
+ experiment_name=experiment_name,
num_runs=num_runs,
dry_run=args.dry_run,
resume=args.resume,
- link_vars=link_vars,
server_ready_timeout=args.server_ready_timeout,
)
@@ -388,6 +400,7 @@ class SweepServeArgs:
default=300,
help="Timeout in seconds to wait for the server to become ready.",
)
+
parser.add_argument(
"--serve-params",
type=str,
@@ -398,6 +411,16 @@ class SweepServeArgs:
"If both `serve_params` and `bench_params` are given, "
"this script will iterate over their Cartesian product.",
)
+ parser.add_argument(
+ "--link-vars",
+ type=str,
+ default="",
+ help=(
+ "Comma-separated list of linked variables between serve and bench, "
+ "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
+ ),
+ )
+
parser.add_argument(
"--bench-params",
type=str,
@@ -413,7 +436,15 @@ class SweepServeArgs:
"--output-dir",
type=str,
default="results",
- help="The directory to which results are written.",
+ help="The main directory to which results are written.",
+ )
+ parser.add_argument(
+ "-e",
+ "--experiment-name",
+ type=str,
+ default=None,
+ help="The name of this experiment (defaults to current timestamp). "
+ "Results will be stored under `output_dir/experiment_name`.",
)
parser.add_argument(
"--num-runs",
@@ -429,21 +460,10 @@ class SweepServeArgs:
)
parser.add_argument(
"--resume",
- type=str,
- default=None,
- help="Set this to the name of a directory under `output_dir` (which is a "
- "timestamp) to resume a previous execution of this script, i.e., only run "
- "parameter combinations for which there are still no output files.",
- )
-
- parser.add_argument(
- "--link-vars",
- type=str,
- default="",
- help=(
- "Comma-separated list of linked variables between serve and bench, "
- "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len"
- ),
+ action="store_true",
+ help="Resume a previous execution of this script, i.e., only run "
+ "parameter combinations for which there are still no output files "
+ "under `output_dir/experiment_name`.",
)
return parser
@@ -458,33 +478,52 @@ class SweepServeArgs:
pairs.append((a.strip(), b.strip()))
return pairs
+ def resolve_experiment_dir(self) -> Path:
+ experiment_dir = self.output_dir / self.experiment_name
+
+ if self.resume:
+ if not experiment_dir.exists():
+ raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
+ else:
+ if experiment_dir.exists():
+ raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
+
+ return experiment_dir
+
+ @contextmanager
+ def run_ctx(self, experiment_dir: Path):
+ if self.dry_run:
+ yield
+ print(f"Experiment will be saved at: {experiment_dir}")
+ return
+
+ try:
+ yield
+ print(f"Experiment has been saved at: {experiment_dir}")
+ except BaseException as exc:
+ raise RuntimeError(
+ "The script was terminated early. Use `--resume` "
+ "to continue the script from its last checkpoint."
+ ) from exc
+
def run_main(args: SweepServeArgs):
- timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
- output_dir = args.output_dir / timestamp
+ experiment_dir = args.resolve_experiment_dir()
- if args.resume and not output_dir.exists():
- raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
-
- try:
+ with args.run_ctx(experiment_dir):
return run_combs(
serve_cmd=args.serve_cmd,
bench_cmd=args.bench_cmd,
+ link_vars=args.link_vars,
after_bench_cmd=args.after_bench_cmd,
show_stdout=args.show_stdout,
server_ready_timeout=args.server_ready_timeout,
serve_params=args.serve_params,
bench_params=args.bench_params,
- output_dir=output_dir,
+ experiment_dir=experiment_dir,
num_runs=args.num_runs,
dry_run=args.dry_run,
- link_vars=args.link_vars,
)
- except BaseException as exc:
- raise RuntimeError(
- f"The script was terminated early. Use `--resume {timestamp}` "
- f"to continue the script from its last checkpoint."
- ) from exc
def main(args: argparse.Namespace):
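
For context on the reworked output layout: results now land under `output_dir/experiment_name`, a fresh run refuses to overwrite an existing experiment directory, and `--resume` (now a flag) only fills in missing parameter combinations. Below is a minimal standalone sketch of that flow; it is not part of the patch, it merely mirrors `resolve_experiment_dir` and `run_ctx` above, and the CLI invocations mentioned afterwards are illustrative only.

```python
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path


def resolve_experiment_dir(output_dir: Path, experiment_name: str | None, resume: bool) -> Path:
    # Default the experiment name to a timestamp, as SweepServeArgs.from_cli_args does.
    name = experiment_name or datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_dir = output_dir / name
    if resume and not experiment_dir.exists():
        raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
    if not resume and experiment_dir.exists():
        raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
    return experiment_dir


@contextmanager
def run_ctx(experiment_dir: Path, dry_run: bool):
    # On any early termination, point the user at --resume instead of silently
    # leaving a partially filled experiment directory behind.
    if dry_run:
        yield
        print(f"Experiment will be saved at: {experiment_dir}")
        return
    try:
        yield
        print(f"Experiment has been saved at: {experiment_dir}")
    except BaseException as exc:
        raise RuntimeError("Terminated early; rerun with `--resume`.") from exc
```

With the flags added above, a fresh sweep would pass `-o results -e my_exp`, and re-running the same command with `--resume` fills in only the missing combinations; the surrounding top-level command name is whatever the sweep CLI is registered under.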
diff --git a/vllm/benchmarks/sweep/serve_sla.py b/vllm/benchmarks/sweep/serve_sla.py
deleted file mode 100644
index 89169ec..0000000
--- a/vllm/benchmarks/sweep/serve_sla.py
+++ /dev/null
@@ -1,305 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import argparse
-import math
-from dataclasses import asdict, dataclass
-from datetime import datetime
-from pathlib import Path
-from typing import ClassVar, Literal, get_args
-
-import numpy as np
-from typing_extensions import assert_never
-
-from vllm.utils.import_utils import PlaceholderModule
-
-from .param_sweep import ParameterSweep, ParameterSweepItem
-from .serve import (
- SweepServeArgs,
- _get_comb_base_path,
- run_comb,
- server_ctx,
-)
-from .server import ServerProcess
-
-try:
- import pandas as pd
-except ImportError:
- pd = PlaceholderModule("pandas")
-
-
-SLAVariable = Literal["request_rate", "max_concurrency"]
-
-
-def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable):
- request_throughput = float(run_data["request_throughput"]) # type: ignore
- if sla_variable == "request_rate":
- return request_throughput
- if sla_variable == "max_concurrency":
- mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
- return request_throughput * mean_latency_ms / 1000
-
- assert_never(sla_variable)
-
-
-def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable):
- return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs)
-
-
-def run_comb_sla(
- server: ServerProcess | None,
- bench_cmd: list[str],
- *,
- serve_comb: ParameterSweepItem,
- bench_comb: ParameterSweepItem,
- output_dir: Path,
- num_runs: int,
- dry_run: bool,
- link_vars: list[tuple[str, str]],
- sla_variable: SLAVariable,
- sla_value: int,
-) -> list[dict[str, object]] | None:
- bench_comb_sla = bench_comb | {sla_variable: sla_value}
-
- return run_comb(
- server,
- bench_cmd,
- serve_comb=serve_comb,
- bench_comb=bench_comb_sla,
- base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla),
- num_runs=num_runs,
- dry_run=dry_run,
- link_vars=link_vars,
- )
-
-
-def explore_sla(
- server: ServerProcess | None,
- bench_cmd: list[str],
- *,
- serve_comb: ParameterSweepItem,
- bench_comb: ParameterSweepItem,
- sla_variable: SLAVariable,
- sla_iters: int,
- output_dir: Path,
- num_runs: int,
- dry_run: bool,
- link_vars: list[tuple[str, str]],
-):
- print("[SLA START]")
- print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
- print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
- print(f"Number of SLA iterations: {sla_iters}")
-
- if sla_iters < 2:
- raise ValueError("`sla_iters` should be at least 2")
-
- serial_comb_data = run_comb_sla(
- server,
- bench_cmd,
- serve_comb=serve_comb,
- bench_comb=bench_comb,
- output_dir=output_dir,
- num_runs=num_runs,
- dry_run=dry_run,
- link_vars=link_vars,
- sla_variable=sla_variable,
- sla_value=1,
- )
- batch_comb_data = run_comb_sla(
- server,
- bench_cmd,
- serve_comb=serve_comb,
- bench_comb=bench_comb,
- output_dir=output_dir,
- num_runs=num_runs,
- dry_run=dry_run,
- link_vars=link_vars,
- sla_variable=sla_variable,
- sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore
- )
-
- if serial_comb_data is None or batch_comb_data is None:
- if dry_run:
- print("Omitting intermediate SLA iterations.")
- print("[SLA END]")
-
- return
-
- serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable))
- print(f"Serial inference: {sla_variable}={serial_sla_value}")
-
- batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable))
- print(f"Batch inference: {sla_variable}={batch_sla_value}")
-
- # Avoid duplicated runs for intermediate values if the range between
- # `serial_sla_value` and `batch_sla_value` is small
- inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1]
- inter_sla_values = sorted(set(map(round, inter_sla_values)))
-
- inter_combs_data: list[dict[str, object]] = []
- for inter_sla_value in inter_sla_values:
- print(f"Exploring: {sla_variable}={inter_sla_value}")
- inter_comb_data = run_comb_sla(
- server,
- bench_cmd,
- serve_comb=serve_comb,
- bench_comb=bench_comb,
- output_dir=output_dir,
- num_runs=num_runs,
- dry_run=dry_run,
- link_vars=link_vars,
- sla_variable=sla_variable,
- sla_value=inter_sla_value,
- )
- if inter_comb_data is not None:
- inter_combs_data.extend(inter_comb_data)
-
- print("[SLA END]")
-
- return serial_comb_data + inter_combs_data + batch_comb_data
-
-
-def run_slas(
- serve_cmd: list[str],
- bench_cmd: list[str],
- after_bench_cmd: list[str],
- *,
- show_stdout: bool,
- server_ready_timeout: int,
- serve_params: ParameterSweep,
- bench_params: ParameterSweep,
- sla_variable: SLAVariable,
- sla_iters: int,
- output_dir: Path,
- num_runs: int,
- dry_run: bool,
- link_vars: list[tuple[str, str]],
-):
- if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params):
- raise ValueError(
- f"You should not override `{sla_variable}` in `bench_params` in SLA mode, "
- "since it is supposed to be determined automatically."
- )
-
- all_data = list[dict[str, object]]()
- for serve_comb in serve_params:
- with server_ctx(
- serve_cmd,
- after_bench_cmd,
- show_stdout=show_stdout,
- server_ready_timeout=server_ready_timeout,
- serve_comb=serve_comb,
- bench_params=bench_params,
- output_dir=output_dir,
- dry_run=dry_run,
- ) as server:
- for bench_comb in bench_params:
- comb_data = explore_sla(
- server,
- bench_cmd,
- serve_comb=serve_comb,
- bench_comb=bench_comb,
- sla_variable=sla_variable,
- sla_iters=sla_iters,
- output_dir=output_dir,
- num_runs=num_runs,
- dry_run=dry_run,
- link_vars=link_vars,
- )
-
- if comb_data is not None:
- all_data.extend(comb_data)
-
- if dry_run:
- return None
-
- combined_df = pd.DataFrame.from_records(all_data)
- combined_df.to_csv(output_dir / "summary.csv")
-
- return combined_df
-
-
-@dataclass
-class SweepServeSLAArgs(SweepServeArgs):
- sla_variable: SLAVariable
- sla_iters: int
-
- parser_name: ClassVar[str] = "serve_sla"
- parser_help: ClassVar[str] = (
- "Explore the latency-throughput space for determining SLAs."
- )
-
- @classmethod
- def from_cli_args(cls, args: argparse.Namespace):
- # NOTE: Don't use super() as `from_cli_args` calls `cls()`
- base_args = SweepServeArgs.from_cli_args(args)
-
- return cls(
- **asdict(base_args),
- sla_variable=args.sla_variable,
- sla_iters=args.sla_iters,
- )
-
- @classmethod
- def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
- parser = super().add_cli_args(parser)
-
- sla_group = parser.add_argument_group("sla options")
- sla_group.add_argument(
- "--sla-variable",
- type=str,
- choices=get_args(SLAVariable),
- default="request_rate",
- help="The variable to adjust in each iteration.",
- )
- sla_group.add_argument(
- "--sla-iters",
- type=int,
- default=10,
- help="Number of iterations used to explore the latency-throughput space. "
- "This includes the first two iterations used to interpolate the value of "
- "`sla_variable` for remaining iterations.",
- )
-
- return parser
-
-
-def run_main(args: SweepServeSLAArgs):
- timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
- output_dir = args.output_dir / timestamp
-
- if args.resume and not output_dir.exists():
- raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
-
- try:
- return run_slas(
- serve_cmd=args.serve_cmd,
- bench_cmd=args.bench_cmd,
- after_bench_cmd=args.after_bench_cmd,
- show_stdout=args.show_stdout,
- server_ready_timeout=args.server_ready_timeout,
- serve_params=args.serve_params,
- bench_params=args.bench_params,
- sla_variable=args.sla_variable,
- sla_iters=args.sla_iters,
- output_dir=output_dir,
- num_runs=args.num_runs,
- dry_run=args.dry_run,
- link_vars=args.link_vars,
- )
- except BaseException as exc:
- raise RuntimeError(
- f"The script was terminated early. Use `--resume {timestamp}` "
- f"to continue the script from its last checkpoint."
- ) from exc
-
-
-def main(args: argparse.Namespace):
- run_main(SweepServeSLAArgs.from_cli_args(args))
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help)
- SweepServeSLAArgs.add_cli_args(parser)
-
- main(parser.parse_args())
diff --git a/vllm/benchmarks/sweep/serve_workload.py b/vllm/benchmarks/sweep/serve_workload.py
new file mode 100644
index 0000000..ca7ba09
--- /dev/null
+++ b/vllm/benchmarks/sweep/serve_workload.py
@@ -0,0 +1,328 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+import math
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import ClassVar, Literal, get_args
+
+import numpy as np
+from typing_extensions import assert_never
+
+from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS
+from vllm.utils.import_utils import PlaceholderModule
+
+from .param_sweep import ParameterSweep, ParameterSweepItem
+from .serve import (
+ SweepServeArgs,
+ _get_comb_base_path,
+ run_comb,
+ server_ctx,
+)
+from .server import ServerProcess
+
+try:
+ import pandas as pd
+except ImportError:
+ pd = PlaceholderModule("pandas")
+
+
+WorkloadVariable = Literal["request_rate", "max_concurrency"]
+
+
+def _estimate_workload_value(
+ run_data: dict[str, object],
+ workload_var: WorkloadVariable,
+):
+ request_throughput = float(run_data["request_throughput"]) # type: ignore
+ if workload_var == "request_rate":
+ return request_throughput
+ if workload_var == "max_concurrency":
+ mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore
+ return request_throughput * mean_latency_ms / 1000
+
+ assert_never(workload_var)
+
+
+def _estimate_workload_avg(
+ runs: list[dict[str, object]],
+ workload_var: WorkloadVariable,
+):
+ total = sum(_estimate_workload_value(run, workload_var) for run in runs)
+ return total / len(runs)
+
+
+def run_comb_workload(
+ server: ServerProcess | None,
+ bench_cmd: list[str],
+ *,
+ serve_comb: ParameterSweepItem,
+ bench_comb: ParameterSweepItem,
+ link_vars: list[tuple[str, str]],
+ experiment_dir: Path,
+ num_runs: int,
+ dry_run: bool,
+ workload_var: WorkloadVariable,
+ workload_value: int,
+) -> list[dict[str, object]] | None:
+ bench_comb_workload = bench_comb | {workload_var: workload_value}
+
+ return run_comb(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb_workload,
+ link_vars=link_vars,
+ base_path=_get_comb_base_path(
+ experiment_dir,
+ serve_comb,
+ bench_comb,
+ extra_parts=("WL-", f"{workload_var}={workload_value}"),
+ ),
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+
+def explore_comb_workloads(
+ server: ServerProcess | None,
+ bench_cmd: list[str],
+ *,
+ serve_comb: ParameterSweepItem,
+ bench_comb: ParameterSweepItem,
+ link_vars: list[tuple[str, str]],
+ workload_var: WorkloadVariable,
+ workload_iters: int,
+ experiment_dir: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ print("[WL START]")
+ print(f"Serve parameters: {serve_comb.as_text() or '(None)'}")
+ print(f"Bench parameters: {bench_comb.as_text() or '(None)'}")
+ print(f"Number of workload iterations: {workload_iters}")
+
+ if workload_iters < 2:
+ raise ValueError("`workload_iters` should be at least 2")
+
+ dataset_size = DEFAULT_NUM_PROMPTS
+ if "num_prompts" in bench_comb:
+ dataset_size = int(bench_comb["num_prompts"]) # type: ignore
+ else:
+ for i, arg in enumerate(bench_cmd):
+ if arg == "--num-prompts" and i + 1 < len(bench_cmd):
+ dataset_size = int(bench_cmd[i + 1])
+ break
+ elif arg.startswith("--num-prompts="):
+ dataset_size = int(arg.split("=", 1)[1])
+ break
+
+ print(f"Dataset size: {dataset_size}")
+
+ serial_workload_data = run_comb_workload(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb | {"max_concurrency": 1},
+ link_vars=link_vars,
+ experiment_dir=experiment_dir,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ workload_var=workload_var,
+ workload_value=1,
+ )
+ batch_workload_data = run_comb_workload(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb | {"max_concurrency": dataset_size},
+ link_vars=link_vars,
+ experiment_dir=experiment_dir,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ workload_var=workload_var,
+ workload_value=dataset_size,
+ )
+
+ if serial_workload_data is None or batch_workload_data is None:
+ if dry_run:
+ print("Omitting intermediate Workload iterations.")
+ print("[WL END]")
+
+ return
+
+ serial_workload_value = math.ceil(
+ _estimate_workload_avg(serial_workload_data, workload_var)
+ )
+ print(f"Serial inference: {workload_var}={serial_workload_value}")
+
+ batch_workload_value = math.floor(
+ _estimate_workload_avg(batch_workload_data, workload_var)
+ )
+ print(f"Batch inference: {workload_var}={batch_workload_value}")
+
+ # Avoid duplicated runs for intermediate values if the range between
+ # `serial_workload_value` and `batch_workload_value` is small
+ inter_workload_values = np.linspace(
+ serial_workload_value, batch_workload_value, workload_iters
+ )[1:-1]
+ inter_workload_values = sorted(set(map(round, inter_workload_values)))
+
+ inter_workloads_data: list[dict[str, object]] = []
+ for inter_workload_value in inter_workload_values:
+ print(f"Exploring: {workload_var}={inter_workload_value}")
+ inter_workload_data = run_comb_workload(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb,
+ link_vars=link_vars,
+ experiment_dir=experiment_dir,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ workload_var=workload_var,
+ workload_value=inter_workload_value,
+ )
+ if inter_workload_data is not None:
+ inter_workloads_data.extend(inter_workload_data)
+
+ print("[WL END]")
+
+ return serial_workload_data + inter_workloads_data + batch_workload_data
+
+
+def explore_combs_workloads(
+ serve_cmd: list[str],
+ bench_cmd: list[str],
+ after_bench_cmd: list[str],
+ *,
+ show_stdout: bool,
+ server_ready_timeout: int,
+ serve_params: ParameterSweep,
+ bench_params: ParameterSweep,
+ link_vars: list[tuple[str, str]],
+ workload_var: WorkloadVariable,
+ workload_iters: int,
+ experiment_dir: Path,
+ num_runs: int,
+ dry_run: bool,
+):
+ if any(bench_comb.has_param(workload_var) for bench_comb in bench_params):
+ raise ValueError(
+ f"You should not override `{workload_var}` in `bench_params` "
+ "since it is supposed to be explored automatically."
+ )
+
+ all_data = list[dict[str, object]]()
+ for serve_comb in serve_params:
+ with server_ctx(
+ serve_cmd,
+ after_bench_cmd,
+ show_stdout=show_stdout,
+ server_ready_timeout=server_ready_timeout,
+ serve_comb=serve_comb,
+ bench_params=bench_params,
+ experiment_dir=experiment_dir,
+ dry_run=dry_run,
+ ) as server:
+ for bench_comb in bench_params:
+ comb_data = explore_comb_workloads(
+ server,
+ bench_cmd,
+ serve_comb=serve_comb,
+ bench_comb=bench_comb,
+ link_vars=link_vars,
+ workload_var=workload_var,
+ workload_iters=workload_iters,
+ experiment_dir=experiment_dir,
+ num_runs=num_runs,
+ dry_run=dry_run,
+ )
+
+ if comb_data is not None:
+ all_data.extend(comb_data)
+
+ if dry_run:
+ return None
+
+ combined_df = pd.DataFrame.from_records(all_data)
+ combined_df.to_csv(experiment_dir / "summary.csv")
+
+ return combined_df
+
+
+@dataclass
+class SweepServeWorkloadArgs(SweepServeArgs):
+ workload_var: WorkloadVariable
+ workload_iters: int
+
+ parser_name: ClassVar[str] = "serve_workload"
+ parser_help: ClassVar[str] = (
+ "Explore the latency-throughput tradeoff for different workload levels."
+ )
+
+ @classmethod
+ def from_cli_args(cls, args: argparse.Namespace):
+ # NOTE: Don't use super() as `from_cli_args` calls `cls()`
+ base_args = SweepServeArgs.from_cli_args(args)
+
+ return cls(
+ **asdict(base_args),
+ workload_var=args.workload_var,
+ workload_iters=args.workload_iters,
+ )
+
+ @classmethod
+ def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
+ parser = super().add_cli_args(parser)
+
+ workload_group = parser.add_argument_group("workload options")
+ workload_group.add_argument(
+ "--workload-var",
+ type=str,
+ choices=get_args(WorkloadVariable),
+ default="request_rate",
+ help="The variable to adjust in each iteration.",
+ )
+ workload_group.add_argument(
+ "--workload-iters",
+ type=int,
+ default=10,
+ help="Number of workload levels to explore. "
+ "This includes the first two iterations used to interpolate the value of "
+ "`workload_var` for remaining iterations.",
+ )
+
+ return parser
+
+
+def run_main(args: SweepServeWorkloadArgs):
+ experiment_dir = args.resolve_experiment_dir()
+
+ with args.run_ctx(experiment_dir):
+ return explore_combs_workloads(
+ serve_cmd=args.serve_cmd,
+ bench_cmd=args.bench_cmd,
+ after_bench_cmd=args.after_bench_cmd,
+ show_stdout=args.show_stdout,
+ server_ready_timeout=args.server_ready_timeout,
+ serve_params=args.serve_params,
+ bench_params=args.bench_params,
+ link_vars=args.link_vars,
+ workload_var=args.workload_var,
+ workload_iters=args.workload_iters,
+ experiment_dir=experiment_dir,
+ num_runs=args.num_runs,
+ dry_run=args.dry_run,
+ )
+
+
+def main(args: argparse.Namespace):
+ run_main(SweepServeWorkloadArgs.from_cli_args(args))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description=SweepServeWorkloadArgs.parser_help)
+ SweepServeWorkloadArgs.add_cli_args(parser)
+
+ main(parser.parse_args())
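
The heart of the new workload sweep is the interpolation between the measured serial point (`max_concurrency=1`) and batch point (`max_concurrency=dataset_size`). A minimal sketch of just that step, using the same `np.linspace` plus `round`/`set` de-duplication as the file above; the endpoint numbers in the example are made up for illustration.

```python
import numpy as np


def intermediate_workload_values(serial_value: int, batch_value: int, iters: int) -> list[int]:
    """Interior workload values between the measured serial and batch endpoints."""
    if iters < 2:
        raise ValueError("`iters` should be at least 2")
    # The two endpoints are benchmarked separately; only interior points are
    # interpolated. Rounding and set() avoid duplicate runs when the range is small.
    interior = np.linspace(serial_value, batch_value, iters)[1:-1]
    return sorted(set(map(round, interior)))


# e.g. serial throughput rounded up to 2 req/s, batch rounded down to 40 req/s,
# with --workload-iters 10 -> eight evenly spaced interior request rates to benchmark.
print(intermediate_workload_values(2, 40, 10))
```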
diff --git a/vllm/benchmarks/sweep/startup.py b/vllm/benchmarks/sweep/startup.py
index b4d979b..6f5217e 100644
--- a/vllm/benchmarks/sweep/startup.py
+++ b/vllm/benchmarks/sweep/startup.py
@@ -4,6 +4,7 @@ import argparse
import json
import shlex
import subprocess
+from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
@@ -111,7 +112,7 @@ def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]:
def _get_comb_base_path(
- output_dir: Path,
+ experiment_dir: Path,
serve_comb: ParameterSweepItem,
startup_comb: ParameterSweepItem,
) -> Path:
@@ -120,7 +121,8 @@ def _get_comb_base_path(
parts.extend(("SERVE-", serve_comb.name))
if startup_comb:
parts.extend(("STARTUP-", startup_comb.name))
- return output_dir / sanitize_filename("-".join(parts))
+
+ return experiment_dir / sanitize_filename("-".join(parts))
def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path:
@@ -225,7 +227,7 @@ def run_combs(
*,
serve_params: ParameterSweep,
startup_params: ParameterSweep,
- output_dir: Path,
+ experiment_dir: Path,
num_runs: int,
show_stdout: bool,
dry_run: bool,
@@ -233,7 +235,7 @@ def run_combs(
all_data = list[dict[str, object]]()
for serve_comb in serve_params:
for startup_comb in startup_params:
- base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb)
+ base_path = _get_comb_base_path(experiment_dir, serve_comb, startup_comb)
comb_data = run_comb(
startup_cmd,
serve_comb=serve_comb,
@@ -250,7 +252,7 @@ def run_combs(
return None
combined_df = pd.DataFrame.from_records(all_data)
- combined_df.to_csv(output_dir / "summary.csv")
+ combined_df.to_csv(experiment_dir / "summary.csv")
return combined_df
@@ -260,11 +262,11 @@ class SweepStartupArgs:
serve_params: ParameterSweep
startup_params: ParameterSweep
output_dir: Path
+ experiment_name: str
num_runs: int
show_stdout: bool
dry_run: bool
- resume: str | None
- strict_params: bool
+ resume: bool
parser_name: ClassVar[str] = "startup"
parser_help: ClassVar[str] = (
@@ -286,13 +288,19 @@ class SweepStartupArgs:
startup_params = ParameterSweep.from_records([{}])
supported = _get_supported_startup_keys()
+ strict_params = args.strict_params
serve_params = _filter_params(
- serve_params, supported=supported, strict=args.strict_params
+ serve_params, supported=supported, strict=strict_params
)
startup_params = _filter_params(
- startup_params, supported=supported, strict=args.strict_params
+ startup_params, supported=supported, strict=strict_params
)
+ if args.experiment_name:
+ experiment_name = args.experiment_name
+ else:
+ experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S")
+
if args.num_runs < 1:
raise ValueError("`num_runs` should be at least 1.")
@@ -301,11 +309,11 @@ class SweepStartupArgs:
serve_params=serve_params,
startup_params=startup_params,
output_dir=Path(args.output_dir),
+ experiment_name=experiment_name,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
resume=args.resume,
- strict_params=args.strict_params,
)
@classmethod
@@ -316,6 +324,7 @@ class SweepStartupArgs:
default="vllm bench startup",
help="The command used to run the startup benchmark.",
)
+
parser.add_argument(
"--serve-params",
type=str,
@@ -331,12 +340,27 @@ class SweepStartupArgs:
help="Path to JSON file containing parameter combinations "
"for the `vllm bench startup` command.",
)
+ parser.add_argument(
+ "--strict-params",
+ action="store_true",
+ help="If set, unknown parameters in sweep files raise an error "
+ "instead of being ignored.",
+ )
+
parser.add_argument(
"-o",
"--output-dir",
type=str,
default="results",
- help="The directory to which results are written.",
+ help="The main directory to which results are written.",
+ )
+ parser.add_argument(
+ "-e",
+ "--experiment-name",
+ type=str,
+ default=None,
+ help="The name of this experiment (defaults to current timestamp). "
+ "Results will be stored under `output_dir/experiment_name`.",
)
parser.add_argument(
"--num-runs",
@@ -357,43 +381,56 @@ class SweepStartupArgs:
)
parser.add_argument(
"--resume",
- type=str,
- default=None,
- help="Set this to the name of a directory under `output_dir` (which is a "
- "timestamp) to resume a previous execution of this script, i.e., only run "
- "parameter combinations for which there are still no output files.",
- )
- parser.add_argument(
- "--strict-params",
action="store_true",
- help="If set, unknown parameters in sweep files raise an error "
- "instead of being ignored.",
+ help="Resume a previous execution of this script, i.e., only run "
+ "parameter combinations for which there are still no output files "
+ "under `output_dir/experiment_name`.",
)
+
return parser
+ def resolve_experiment_dir(self) -> Path:
+ experiment_dir = self.output_dir / self.experiment_name
+
+ if self.resume:
+ if not experiment_dir.exists():
+ raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
+ else:
+ if experiment_dir.exists():
+ raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
+
+ return experiment_dir
+
+ @contextmanager
+ def run_ctx(self, experiment_dir: Path):
+ if self.dry_run:
+ yield
+ print(f"Experiment will be saved at: {experiment_dir}")
+ return
+
+ try:
+ yield
+ print(f"Experiment has been saved at: {experiment_dir}")
+ except BaseException as exc:
+ raise RuntimeError(
+ "The script was terminated early. Use `--resume` "
+ "to continue the script from its last checkpoint."
+ ) from exc
+
def run_main(args: SweepStartupArgs):
- timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S")
- output_dir = args.output_dir / timestamp
+ experiment_dir = args.resolve_experiment_dir()
- if args.resume and not output_dir.exists():
- raise ValueError(f"Cannot resume from non-existent directory ({output_dir})")
-
- try:
+ with args.run_ctx(experiment_dir):
return run_combs(
startup_cmd=args.startup_cmd,
serve_params=args.serve_params,
startup_params=args.startup_params,
- output_dir=output_dir,
+ experiment_dir=experiment_dir,
num_runs=args.num_runs,
show_stdout=args.show_stdout,
dry_run=args.dry_run,
)
- except BaseException as exc:
- raise RuntimeError(
- f"The script was terminated early. Use `--resume {timestamp}` "
- f"to continue the script from its last checkpoint."
- ) from exc
def main(args: argparse.Namespace):
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index b974ca0..843c042 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -282,61 +282,13 @@ class CompilerManager:
maybe_key += f"{compile_range.start}_{compile_range.end}"
maybe_key += f"_subgraph_{graph_index}"
with self.compile_context(compile_range):
- # There is a compilation time optimization here.
- #
- # If the (input metadata, graph, compiler config) are the same, then
- # we want to avoid compiling the same artifact again. If we didn't
- # do this optimization, the backend compilation (InductorAdaptor or
- # InductorStandaloneAdaptor)
- # is able to cache hit and produce an artifact faster if it was
- # already created, but it is still a duplicate artifact that
- # requires unnecessary things e.g. disk IO.
- #
- # The optimization is: If the backend compilation cache hits,
- # then do an early return from the backend compilation and look up
- # which of the previous in-memory artifacts we created to reuse.
- #
- # We implemented this by monkey-patching torch (torch does not
- # easily expose the cache_key function), but in the future torch
- # should expose the cache_key function that we can just call
- # directly before invoking backend compilation.
- cache_key = None
- orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key
-
- def autograd_cache_key(*args, **kwargs):
- result = orig(*args, **kwargs)
- if result is None:
- return None
- nonlocal cache_key
- cache_key = result[0]
- if cache_key in self.loaded_artifacts:
- raise StopCompiling()
- return result
-
- from unittest.mock import patch
-
- with (
- # Graphs that are isometric (different node names but same
- # structure) should be treated as the same.
- torch._functorch.config.patch(autograd_cache_normalize_inputs=True),
- patch(
- "torch._functorch._aot_autograd.autograd_cache.autograd_cache_key",
- autograd_cache_key,
- ),
- ):
- try:
- compiled_graph, handle = self.compiler.compile(
- graph,
- example_inputs,
- additional_inductor_config,
- compile_range,
- maybe_key,
- )
- except StopCompiling:
- assert cache_key is not None
- return self.loaded_artifacts[cache_key]
- if cache_key is not None and compiled_graph is not None:
- self.loaded_artifacts[cache_key] = compiled_graph
+ compiled_graph, handle = self.compiler.compile(
+ graph,
+ example_inputs,
+ additional_inductor_config,
+ compile_range,
+ maybe_key,
+ )
assert compiled_graph is not None, "Failed to compile the graph"
@@ -497,7 +449,7 @@ def wrap_with_cudagraph_if_needed(
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
return static_graph_wrapper_class(
- runnable=piecewise_backend,
+ runnable=piecewise_backend.graph.forward,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
@@ -780,7 +732,7 @@ class VllmBackend:
return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map
def configure_post_pass(self) -> None:
- # self.pass_manager.configure(self.vllm_config)
+ self.pass_manager.configure(self.vllm_config)
# Post-grad custom passes are run using the post_grad_custom_post_pass
# hook. If a pass for that hook exists, add it to the pass manager.
@@ -846,7 +798,7 @@ class VllmBackend:
),
)
- def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
+ def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any], **kwargs) -> Any:
from .caching import (
VllmSerializableFunction,
)
@@ -988,7 +940,7 @@ class VllmBackend:
assert not self._called, "VllmBackend can only be called once"
self.graph = graph
- self.configure_post_pass()
+ # self.configure_post_pass()
if self.compilation_config.use_inductor_graph_partition:
# Let Inductor decide partitioning; avoid FX-level pre-splitting.
diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py
index 07f9db4..3917a4f 100644
--- a/vllm/compilation/caching.py
+++ b/vllm/compilation/caching.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import contextlib
import hashlib
import inspect
import os
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
self.loaded_submodule_store = {}
+@contextlib.contextmanager
+def patch_pytree_map_over_slice():
+ pytree._private_register_pytree_node(
+ slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
+ )
+
+ try:
+ yield
+ finally:
+ pytree._deregister_pytree_node(slice)
+
+
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
"""
A wrapper around a compiled function by vllm. It will forward the tensor
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
- with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
+ with (
+ patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
+ patch_pytree_map_over_slice(),
+ ):
state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None)
)
@@ -261,7 +277,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
- state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
+ with patch_pytree_map_over_slice():
+ state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index c00486a..b6fa5d1 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -184,6 +184,47 @@ def is_compile_cache_enabled(
)
+def _patch_standalone_compile_atomic_save() -> None:
+ """Backport of pytorch/pytorch#162432 for torch < 2.10.0.
+
+ Patches CompiledArtifact.save() to use write_atomic for binary format,
+ preventing corrupt cache files when multiple processes compile
+ concurrently.
+ """
+ from torch._inductor.codecache import write_atomic
+ from torch._inductor.standalone_compile import CompiledArtifact as cls
+
+ if getattr(cls.save, "_vllm_patched", False):
+ return
+
+ original_save = cls.save
+
+ def _save(
+ self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary"
+ ) -> None:
+ if format != "binary":
+ return original_save(self, path=path, format=format)
+ from torch._dynamo.utils import dynamo_timed
+ from torch._inductor.codecache import torch_key
+ from torch.utils._appending_byte_serializer import BytesWriter
+
+ with dynamo_timed("CompiledArtifact.save"):
+ assert self._artifacts is not None
+ artifact_bytes, cache_info = self._artifacts
+ assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
+ key = cache_info.aot_autograd_artifacts[0]
+ assert not os.path.isdir(path)
+ writer = BytesWriter()
+ writer.write_bytes(torch_key())
+ writer.write_str(key)
+ writer.write_bytes(artifact_bytes)
+ write_atomic(path, writer.to_bytes())
+
+ _save._vllm_patched = True # type: ignore[attr-defined]
+ cls.save = _save # type: ignore[assignment]
+ logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
+
+
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
@@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface):
name = "inductor_standalone"
def __init__(self, save_format: Literal["binary", "unpacked"]) -> None:
+ if not is_torch_equal_or_newer("2.10.0"):
+ _patch_standalone_compile_atomic_save()
self.save_format = save_format
def compute_hash(self, vllm_config: VllmConfig) -> str:
@@ -224,7 +267,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, compile_range)
- set_functorch_config()
+ # set_functorch_config()
if compile_range.is_single_size():
dynamic_shapes = "from_example_inputs"
@@ -325,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
path=path, format=self.save_format
)
+ compilation_counter.num_compiled_artifacts_loaded += 1
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
@@ -395,7 +439,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, compile_range)
- set_functorch_config()
+ # set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
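
The backport above matters because `CompiledArtifact.save()` on older torch writes the binary cache file in place, so two workers compiling the same key can interleave writes and corrupt the artifact. The fix routes the bytes through torch's `write_atomic`; the general write-then-rename pattern it relies on is sketched below (this is not torch's implementation, just the idea):

```python
import os
import tempfile


def write_atomic_sketch(path: str, data: bytes) -> None:
    """Write `data` to `path` so that readers never observe a half-written file."""
    directory = os.path.dirname(path) or "."
    # Write to a temporary file in the same directory, then rename over the
    # destination; os.replace is atomic within a single filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        os.replace(tmp_path, path)
    except BaseException:
        os.unlink(tmp_path)
        raise
```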
diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py
index 29d3045..2ed49b9 100644
--- a/vllm/compilation/counter.py
+++ b/vllm/compilation/counter.py
@@ -29,6 +29,8 @@ class CompilationCounter:
num_cache_entries_updated: int = 0
# The number of standalone_compile compiled artifacts saved
num_compiled_artifacts_saved: int = 0
+ # The number of standalone_compile compiled artifacts loaded from cache
+ num_compiled_artifacts_loaded: int = 0
# Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE
stock_torch_compile_count: int = 0
diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py
index 2014808..dfc9dc8 100644
--- a/vllm/compilation/cuda_graph.py
+++ b/vllm/compilation/cuda_graph.py
@@ -21,6 +21,7 @@ from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
from vllm.sequence import IntermediateTensors
+
logger = init_logger(__name__)
@@ -204,14 +205,14 @@ class CUDAGraphWrapper:
def unwrap(self) -> Callable[..., Any]:
# in case we need to access the original runnable.
return self.runnable
-
+
def weak_ref_tensors_with_intermediate(self, output):
if isinstance(output, IntermediateTensors):
intermediate_states = IntermediateTensors(
tensors={key: weak_ref_tensors(value) for key, value in output.tensors.items()})
return intermediate_states
return weak_ref_tensors(output)
-
+
def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
@@ -298,12 +299,10 @@ class CUDAGraphWrapper:
# the last graph in piecewise cuadgraph mode, because
# the output of the last graph will not be used by
# any other cuda graph.
- # output = weak_ref_tensors(output)
output = self.weak_ref_tensors_with_intermediate(output)
# here we always use weak ref for the output
# to save memory
- # entry.output = weak_ref_tensors(output)
entry.output = self.weak_ref_tensors_with_intermediate(output)
entry.cudagraph = cudagraph
diff --git a/vllm/compilation/passes/fusion/collective_fusion.py b/vllm/compilation/passes/fusion/collective_fusion.py
index 55a5a2e..a9b64ad 100644
--- a/vllm/compilation/passes/fusion/collective_fusion.py
+++ b/vllm/compilation/passes/fusion/collective_fusion.py
@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
mul,
mm_weight,
- "avg",
+ "sum",
scatter_dim=0,
group_name=self.tp.device_group.group_name,
)
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
- "avg",
+ "sum",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
mat2,
scale_a,
scale_b,
- "avg",
+ "sum",
scatter_dim, # orig_scatter_dim
scatter_dim, # scatter_dim_after_maybe_reshape
self.tp.device_group.group_name,
diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py
index d8131ce..59c94db 100644
--- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py
+++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py
@@ -5,7 +5,6 @@ import torch
import torch._inductor.pattern_matcher as pm
from torch import fx
from torch._inductor.pattern_matcher import PatternMatcherPass
-from torch._ops import OpOverload
import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401
from vllm._aiter_ops import rocm_aiter_ops
@@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
+ kFp8Dynamic128Sym,
)
from vllm.platforms import current_platform
@@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass):
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
self.matched_count = self.patterns.apply(graph)
- logger.debug("Replaced %s patterns", self.matched_count)
+ logger.debug(
+ "%s Replaced %s patterns", self.__class__.__name__, self.matched_count
+ )
def uuid(self) -> str:
fusion_patterns = [
@@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op()
- def __init__(self, quant_op: OpOverload) -> None:
+ def __init__(self) -> None:
self.silu_and_mul_matcher = MatcherSiluAndMul()
- self.quant_op = quant_op
+ self.quant_matcher = MatcherQuantFP8(
+ quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True
+ )
def get_inputs(self) -> list[torch.Tensor]:
return [
@@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern):
input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
at1 = self.silu_and_mul_matcher(input)
- at2 = self.quant_op(at1, 128)
+ at2 = self.quant_matcher(at1)
return at2[0], at2[1]
def replacement(
@@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
- AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op()
- TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default
-
- QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP]
-
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
super().__init__(config)
@@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass):
pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass"
)
- for quant_op in self.QUANT_OPS:
- AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns)
+ AiterSiluMulFp8GroupQuantPattern().register(self.patterns)
self.dump_patterns(config, self.patterns)
diff --git a/vllm/compilation/passes/fusion/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py
index 63de859..b7ae3dc 100644
--- a/vllm/compilation/passes/fusion/sequence_parallelism.py
+++ b/vllm/compilation/passes/fusion/sequence_parallelism.py
@@ -18,7 +18,6 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
-from vllm.platforms import current_platform
from ..inductor_pass import enable_fake_mode
from ..utility.noop_elimination import NoOpEliminationPass
@@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper):
)
-FP8_DTYPE = current_platform.fp8_dtype()
-
-
class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper):
def __init__(
self,
diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py
index c7df5f9..1b656d0 100644
--- a/vllm/compilation/passes/utility/fix_functionalization.py
+++ b/vllm/compilation/passes/utility/fix_functionalization.py
@@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass):
self.nodes_to_remove: list[torch.fx.Node] = []
count = 0
+
+ rope_targets = [torch.ops._C.rotary_embedding.default]
+
+ if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"):
+ rope_targets.append(
+ torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default
+ )
+
for node in graph.nodes:
if not is_func(node, auto_functionalized):
continue # Avoid deep if-elif nesting
@@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass):
kwargs = node.kwargs
at_target = node.args[0]
- if at_target == torch.ops._C.rotary_embedding.default:
+ if at_target in rope_targets:
query = kwargs["query"]
key = kwargs["key"]
getitem_nodes = self.getitem_users(node)
diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py
index f9eb245..91d5dc6 100644
--- a/vllm/compilation/piecewise_backend.py
+++ b/vllm/compilation/piecewise_backend.py
@@ -298,18 +298,18 @@ class PiecewiseBackend:
else list(args)
)
- with (
- torch._functorch.config.patch("bundled_autograd_cache", True),
- ):
- range_entry.runnable = self.vllm_backend.compiler_manager.compile(
- self.graph,
- args_list,
- self.vllm_backend.inductor_config,
- self.compilation_config,
- compile_range=range_entry.compile_range,
- graph_index=self.piecewise_compile_index,
- num_graphs=self.total_piecewise_compiles,
- )
+ # with (
+ # torch._functorch.config.patch("bundled_autograd_cache", True),
+ # ):
+ range_entry.runnable = self.vllm_backend.compiler_manager.compile(
+ self.graph,
+ args_list,
+ self.vllm_backend.inductor_config,
+ self.compilation_config,
+ compile_range=range_entry.compile_range,
+ graph_index=self.piecewise_compile_index,
+ num_graphs=self.total_piecewise_compiles,
+ )
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py
index 850ddae..5dff296 100644
--- a/vllm/compilation/wrapper.py
+++ b/vllm/compilation/wrapper.py
@@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper:
yield
finally:
self.__class__.forward.__code__ = original
+
+
+def reset_compile_wrapper(model: torch.nn.Module) -> None:
+ """
+ Clean up compiled model and captured CUDA graphs for elastic EP.
+ """
+ if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
+ model, "model"
+ ):
+ model = model.model
+ if not isinstance(model, TorchCompileWithNoGuardsWrapper):
+ return
+ # model.do_not_compile is set by the @support_torch_compile decorator
+ if hasattr(model, "do_not_compile") and model.do_not_compile:
+ return
+ from vllm.compilation.counter import compilation_counter
+
+ # reset the compilation counter
+ compilation_counter.num_models_seen = 0
+ compilation_counter.num_graphs_seen = 0
+ compilation_counter.num_piecewise_graphs_seen = 0
+ compilation_counter.num_piecewise_capturable_graphs_seen = 0
+ compilation_counter.num_backend_compilations = 0
+ compilation_counter.num_gpu_runner_capture_triggers = 0
+ compilation_counter.num_cudagraph_captured = 0
+ compilation_counter.num_inductor_compiles = 0
+ compilation_counter.num_eager_compiles = 0
+ compilation_counter.num_cache_entries_updated = 0
+ compilation_counter.num_compiled_artifacts_saved = 0
+ compilation_counter.stock_torch_compile_count = 0
+
+ # Clear the AOT compiled function so the model is forced to
+ # recompile on the next call. Without this, decorators.py
+ # __call__ uses the stale aot_compiled_fn whose torchinductor
+ # kernels have old parameters (expert_map size for example)
+ # baked in as compile-time constants.
+ if hasattr(model, "aot_compiled_fn"):
+ model.aot_compiled_fn = None
+ if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
+ model.was_aot_compile_fn_loaded_from_disk = False
+
+ # Reset the cache_dir so VllmBackend recomputes the hash
+ # (data_parallel_size changed, so the config hash differs).
+ compilation_config = model.vllm_config.compilation_config
+ compilation_config.cache_dir = ""
+ compilation_config.local_cache_dir = ""
+
+ model.__class__.forward.__code__ = model.original_code_object()
+ TorchCompileWithNoGuardsWrapper.__init__(model)
diff --git a/vllm/config/attention.py b/vllm/config/attention.py
index 97a139c..74bb3d6 100644
--- a/vllm/config/attention.py
+++ b/vllm/config/attention.py
@@ -16,8 +16,8 @@ class AttentionConfig:
backend: AttentionBackendEnum | None = None
"""Attention backend to use. If None, will be selected automatically."""
- flash_attn_version: Literal[2, 3] | None = None
- """Force vllm to use a specific flash-attention version (2 or 3).
+ flash_attn_version: Literal[2, 3, 4] | None = None
+ """Force vllm to use a specific flash-attention version (2, 3, or 4).
Only valid when using the flash-attention backend."""
use_prefill_decode_attention: bool = False
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index d22e9a9..10fc0a0 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -86,9 +86,16 @@ class CUDAGraphMode(enum.Enum):
def separate_routine(self) -> bool:
return isinstance(self.value, tuple)
+
+ def decode_use_graph(self) -> bool:
+ return self.decode_mode() == CUDAGraphMode.FULL
- def valid_runtime_modes(self) -> bool:
- return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]
+ @classmethod
+ def valid_runtime_modes(cls) -> frozenset["CUDAGraphMode"]:
+ return frozenset({cls.NONE, cls.PIECEWISE, cls.FULL})
+
+ def is_valid_runtime_mode(self) -> bool:
+ return self in CUDAGraphMode.valid_runtime_modes()
def __str__(self) -> str:
return self.name
@@ -385,7 +392,7 @@ class CompilationConfig:
Please use mode. Currently all levels are mapped to mode.
"""
# Top-level Compilation control
- mode: CompilationMode = Field(default=None)
+ mode: CompilationMode = Field(default=CompilationMode.NONE)
"""The compilation approach used for torch.compile-based compilation of the
model.
@@ -503,7 +510,7 @@ class CompilationConfig:
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`."""
# CudaGraph compilation
- cudagraph_mode: CUDAGraphMode = Field(default=None)
+ cudagraph_mode: CUDAGraphMode = Field(default=CUDAGraphMode.FULL_DECODE_ONLY)
"""
The mode of the cudagraph:
@@ -1003,6 +1010,7 @@ class CompilationConfig:
# https://github.com/vllm-project/vllm/issues/33267
if not self.use_inductor_graph_partition:
self.splitting_ops.append("vllm::unified_kv_cache_update")
+ self.splitting_ops.append("vllm::unified_mla_kv_cache_update")
elif len(self.splitting_ops) == 0:
if (
@@ -1045,7 +1053,7 @@ class CompilationConfig:
"are optimized for prefill and are incompatible with CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as "
- "deepep_low_latency, pplx, or allgather_reducescatter."
+ "deepep_low_latency or allgather_reducescatter."
)
self.cudagraph_mode = CUDAGraphMode.NONE
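
Note that `valid_runtime_modes` changes from an instance-level boolean check to a classmethod returning the set of runtime-capable modes, with the old boolean semantics moved to `is_valid_runtime_mode()`; callers that previously wrote `mode.valid_runtime_modes()` need the new spelling. A small usage sketch (import path assumed from the file location):

```python
from vllm.config.compilation import CUDAGraphMode

mode = CUDAGraphMode.PIECEWISE
assert mode.is_valid_runtime_mode()

# FULL_DECODE_ONLY is a config-level mode that splits into per-routine runtime
# modes, so it is not itself a valid runtime mode.
assert CUDAGraphMode.FULL_DECODE_ONLY not in CUDAGraphMode.valid_runtime_modes()
```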
diff --git a/vllm/config/model.py b/vllm/config/model.py
index 377a84a..3686537 100644
--- a/vllm/config/model.py
+++ b/vllm/config/model.py
@@ -50,8 +50,6 @@ from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.import_utils import LazyLoader
from vllm.v1.attention.backends.registry import AttentionBackendEnum
-import os
-
if TYPE_CHECKING:
from transformers import PretrainedConfig
@@ -128,6 +126,7 @@ class ModelConfig:
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
+ - "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins."""
trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
@@ -463,8 +462,6 @@ class ModelConfig:
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
- from vllm.platforms import current_platform
-
if self.override_attention_dtype is not None and not current_platform.is_rocm():
warnings.warn(
"override-attention-dtype is set but not using ROCm platform",
@@ -473,10 +470,9 @@ class ModelConfig:
if self.enable_sleep_mode and not current_platform.is_sleep_mode_available():
raise ValueError("Sleep mode is not supported on current platform.")
-
- temp_hf_config_path = os.environ.get("CUSTOM_QUANT_CONFIG", None)
+
hf_config = get_config(
- temp_hf_config_path or self.hf_config_path or self.model,
+ self.hf_config_path or self.model,
self.trust_remote_code,
self.revision,
self.code_revision,
@@ -622,6 +618,16 @@ class ModelConfig:
self._try_verify_and_update_model_config()
self._verify_quantization()
self._verify_cuda_graph()
+ import os
+ enforce_cuda_graph = os.environ.get("VLLM_ENFORCE_CUDA_GRAPH", None)
+ if enforce_cuda_graph in ("1", "y", "Y"):
+ self.enforce_eager = False
+ else:
+ self.enforce_eager = True
+ logger.warning_once(
+ "Please export VLLM_ENFORCE_CUDA_GRAPH=1 to enable cuda graph. "
+ "For now, cuda graph is not used and --enforce-eager is enabled; "
+ "we are working toward making cuda graph the default mode.")
self._verify_bnb_config()
def get_model_arch_config(
@@ -886,6 +892,7 @@ class ModelConfig:
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
+ "modelopt_mixed",
"petit_nvfp4",
# Ensure heavy backends are probed last to avoid unnecessary
# imports during override detection (e.g., MXFP4 imports Triton)
@@ -942,8 +949,6 @@ class ModelConfig:
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}."
)
- from vllm.platforms import current_platform
-
current_platform.verify_quantization(self.quantization)
if self.quantization in me_quant.DEPRECATED_QUANTIZATION_METHODS:
@@ -1813,8 +1818,6 @@ def _resolve_auto_dtype(
*,
is_pooling_model: bool,
):
- from vllm.platforms import current_platform
-
supported_dtypes = [
dtype
for dtype in current_platform.supported_dtypes
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index cc2cfa9..6e84cf1 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -152,7 +152,6 @@ class ParallelConfig:
- "naive": Naive all2all implementation using broadcasts\n
- "allgather_reducescatter": All2all based on allgather and reducescatter\n
- - "pplx": Use pplx kernels\n
- "deepep_high_throughput": Use deepep high-throughput kernels\n
- "deepep_low_latency": Use deepep low-latency kernels\n
- "mori": Use mori kernels\n
@@ -166,6 +165,9 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False
"""Disable the custom all-reduce kernel and fall back to NCCL."""
+ enable_elastic_ep: bool = False
+ """Enable elastic expert parallelism with stateless NCCL groups for DP/EP."""
+
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
@@ -245,6 +247,34 @@ class ParallelConfig:
Set to be private as it's not intended to be configured by users.
"""
+ _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list)
+ """List of open ports for stateless DP groups when enable_elastic_ep is True.
+ Set to be private as it's not intended to be configured by users.
+ It is a list of list[int], where each inner list contains the 3 ports
+ used to set up the stateless CPU/device/TCPStore groups
+ in StatelessGroupCoordinator. The number of inner lists equals
+ the number of DP groups,
+ i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size,
+ and len(self._stateless_dp_group_port_list[i]) == 3 for all i.
+ """
+
+ _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list)
+ """List of open ports for stateless EP groups when enable_elastic_ep is True.
+ Set to be private as it's not intended to be configured by users.
+ len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size.
+ """
+
+ _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list)
+ """List of open ports for stateless EPLB groups when enable_elastic_ep is True.
+ Same topology as EP but separate NCCL communicator to avoid deadlocks.
+ """
+
+ _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list)
+ """List of open ports for stateless world group when enable_elastic_ep is True.
+ Set to be private as it's not intended to be configured by users.
+ len(self._stateless_world_group_port_list) == 1.
+ """
+
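For illustration, with 2 DP groups these private fields would be shaped like the following sketch (port numbers are hypothetical):

# One 3-port bundle (CPU group / device group / TCPStore) per stateless group.
_stateless_dp_group_port_list = [
    [50001, 50002, 50003],  # DP group 0
    [50011, 50012, 50013],  # DP group 1
]
_stateless_world_group_port_list = [[50021, 50022, 50023]]  # always exactly one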
decode_context_parallel_size: int = 1
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
@@ -310,6 +340,13 @@ class ParallelConfig:
f"but found: {self._api_process_rank}"
)
+ if self.all2all_backend == "pplx":
+ logger.warning(
+ "The 'pplx' all2all backend has been removed. "
+ "Falling back to 'allgather_reducescatter'."
+ )
+ self.all2all_backend = "allgather_reducescatter"
+
if self.data_parallel_size_local > self.data_parallel_size:
raise ValueError(
f"data_parallel_size_local ({self.data_parallel_size_local}) "
@@ -396,7 +433,67 @@ class ParallelConfig:
return answer
- def stateless_init_dp_group(self) -> ProcessGroup:
+ def allocate_elastic_ep_ports(self) -> None:
+ """Allocate all ports for elastic EP (stateless groups + DP master).
+
+ Must be called AFTER ray.init() so that ports claimed by Ray's
+ idle worker pool are already in use and won't be returned by
+ get_open_ports_list().
+ """
+ if not self.enable_elastic_ep:
+ return
+ if self._stateless_world_group_port_list:
+ return
+
+ num_world_groups = 1
+ dp_size = self.data_parallel_size
+ ep_size = self.data_parallel_size * self.world_size_across_dp
+ num_dp_groups = max(1, self.world_size_across_dp // dp_size)
+ num_ep_groups = max(1, self.world_size_across_dp // ep_size)
+ num_eplb_groups = num_ep_groups
+ total_stateless_ports = (
+ num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups
+ ) * 3
+ num_dp_master_ports = 5
+
+ all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports)
+
+ self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:]
+ self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
+ all_ports = all_ports[:-num_dp_master_ports]
+
+ self._stateless_world_group_port_list = [
+ all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3)
+ ]
+ start_idx = num_world_groups * 3
+ self._stateless_dp_group_port_list = [
+ all_ports[i : i + 3]
+ for i in range(start_idx, start_idx + num_dp_groups * 3, 3)
+ ]
+ start_idx += num_dp_groups * 3
+ self._stateless_ep_group_port_list = [
+ all_ports[i : i + 3]
+ for i in range(start_idx, start_idx + num_ep_groups * 3, 3)
+ ]
+ start_idx += num_ep_groups * 3
+ self._stateless_eplb_group_port_list = [
+ all_ports[i : i + 3]
+ for i in range(start_idx, start_idx + num_eplb_groups * 3, 3)
+ ]
+
+ def get_next_stateless_world_group_port(self) -> list[int]:
+ return self._stateless_world_group_port_list.pop()
+
+ def get_next_stateless_dp_group_port(self) -> list[int]:
+ return self._stateless_dp_group_port_list.pop()
+
+ def get_next_stateless_ep_group_port(self) -> list[int]:
+ return self._stateless_ep_group_port_list.pop()
+
+ def get_next_stateless_eplb_group_port(self) -> list[int]:
+ return self._stateless_eplb_group_port_list.pop()
+
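A rough sketch of the intended call order, following the docstring above (parallel_config stands for a ParallelConfig instance with enable_elastic_ep=True):

import ray

ray.init()  # let Ray's idle worker pool claim its ports first
parallel_config.allocate_elastic_ep_ports()

# Each getter pops one 3-port bundle for a stateless group being created.
world_ports = parallel_config.get_next_stateless_world_group_port()  # e.g. [p0, p1, p2]
dp_ports = parallel_config.get_next_stateless_dp_group_port()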
+ def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup:
# NOTE: In high-concurrency scenarios multiple processes
# can pick the same (currently free) port through a race
# condition when calling `get_open_port()`. When the first
@@ -420,7 +517,8 @@ class ParallelConfig:
self.get_next_dp_init_port(),
self.data_parallel_rank,
self.data_parallel_size,
- backend=current_platform.dist_backend,
+ backend="gloo",
+ return_store=return_store,
)
except DistNetworkError as e:
# We only want to retry when the root cause is EADDRINUSE.
@@ -442,7 +540,6 @@ class ParallelConfig:
# In this case, ensure the input to the experts is sequence parallel
# to avoid the excess work.
#
- # Not needed for pplx-kernels as it can handle duplicate input tokens.
@property
def use_sequence_parallel_moe(self) -> bool:
return (
@@ -556,6 +653,21 @@ class ParallelConfig:
logger.info("Using external launcher for distributed inference.")
self.world_size *= self.data_parallel_size
+ if self.enable_elastic_ep:
+ if not self.enable_eplb:
+ raise ValueError("Elastic EP is only supported with enable_eplb=True.")
+ if self.pipeline_parallel_size > 1:
+ raise ValueError(
+ "Elastic EP is not supported with pipeline parallelism "
+ f"(pipeline_parallel_size={self.pipeline_parallel_size})."
+ )
+ if self.data_parallel_external_lb or self.data_parallel_hybrid_lb:
+ raise NotImplementedError(
+ "Elastic EP is not compatible with data_parallel_external_lb "
+ "or data_parallel_hybrid_lb. Elastic EP relies on a single API "
+ "server and core client to coordinate scale up/down."
+ )
+
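A configuration sketch that passes the checks above (field names from this class; sizes are illustrative):

from vllm.config import ParallelConfig

parallel_config = ParallelConfig(
    enable_elastic_ep=True,
    enable_eplb=True,           # required: elastic EP depends on EPLB
    pipeline_parallel_size=1,   # PP > 1 is rejected
    data_parallel_size=4,       # external/hybrid DP load balancing must stay off
)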
if self.data_parallel_size > 1 or self.data_parallel_size_local == 0:
# Data parallel was specified in the engine args.
if self.distributed_executor_backend == "external_launcher":
@@ -568,9 +680,12 @@ class ParallelConfig:
"Set data_parallel_rank to %d automatically.",
self.data_parallel_rank,
)
- if not self._data_parallel_master_port_list:
- self._data_parallel_master_port_list = get_open_ports_list(5)
- self.data_parallel_master_port = self._data_parallel_master_port_list.pop()
+ if not self.enable_elastic_ep:
+ if not self._data_parallel_master_port_list:
+ self._data_parallel_master_port_list = get_open_ports_list(5)
+ self.data_parallel_master_port = (
+ self._data_parallel_master_port_list.pop()
+ )
if not (0 <= self.data_parallel_rank < self.data_parallel_size):
raise ValueError(
@@ -597,7 +712,7 @@ class ParallelConfig:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
logger.info("Disabling V1 multiprocessing for external launcher.")
- if self.distributed_executor_backend is None and self.world_size > 1:
+ if self.distributed_executor_backend is None and self.world_size_across_dp > 1:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
@@ -659,6 +774,17 @@ class ParallelConfig:
"backend is mp, uni or external_launcher."
)
+ if (
+ self.all2all_backend in ("allgather_reducescatter", "naive")
+ and self.eplb_config.use_async
+ ):
+ logger.warning(
+ "Async EPLB causes hangs with the '%s' all2all backend. "
+ "Forcing synchronous EPLB.",
+ self.all2all_backend,
+ )
+ self.eplb_config.use_async = False
+
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index c2bced7..308356a 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
+import copy
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
@@ -45,7 +46,7 @@ MTPModelTypes = Literal[
"pangu_ultra_moe_mtp",
"step3p5_mtp",
]
-EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
+EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
@@ -77,12 +78,24 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
+ enable_multi_layers_mtp: bool = False
+ """If set to True, the MTP method will run multiple layers of MTP
+ speculator. If set to False, it will run only one layer of MTP speculator.
+ This is only effective when the method is set to `mtp`."""
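A hedged sketch of turning this on through the speculative config (key names from this class; the token count is a placeholder):

speculative_config = {
    "method": "mtp",
    "num_speculative_tokens": 3,
    # Run distinct MTP layers instead of repeating the same layer per step.
    "enable_multi_layers_mtp": True,
}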
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
+ draft_pipeline_parallel_size: int | None = Field(default=None, ge=1)
+ """The degree of pipeline parallelism for the draft model.
+
+ Defaults to the target model's pipeline parallel size. Set this to 1 to
+ run the drafter locally on the last target PP stage."""
tensor_parallel_size: int | None = None
"""Users should pass "draft_tensor_parallel_size". This parameter's purpose is to
warn users when they mistakenly provide the wrong argument."""
+ pipeline_parallel_size: int | None = None
+ """Users should pass "draft_pipeline_parallel_size". This parameter's
+ purpose is to warn users when they mistakenly provide the wrong argument."""
# Draft model configuration
quantization: me_quant.QuantizationMethods | None = None
@@ -181,9 +194,22 @@ class SpeculativeConfig:
the final hidden states.
"""
factors: list[Any] = []
- # Eagle3 affects the computation graph because it returns intermediate
- # hidden states in addition to the final hidden state.
- factors.append(self.method == "eagle3")
+ # Eagle3 and extract_hidden_states affect the computation graph because
+ # they return intermediate hidden states in addition to the final hidden state.
+ uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states")
+ factors.append(uses_aux_hidden_states)
+
+ # The specific layers used also affect the computation graph
+ if uses_aux_hidden_states and self.draft_model_config is not None:
+ layer_ids = getattr(
+ self.draft_model_config.hf_config,
+ "eagle_aux_hidden_state_layer_ids",
+ None,
+ )
+ if layer_ids is not None:
+ # Convert to tuple to make it hashable
+ factors.append(tuple(layer_ids))
+
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
@@ -352,6 +378,8 @@ class SpeculativeConfig:
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
+ elif self.method == "extract_hidden_states":
+ self.model = "extract_hidden_states"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -394,6 +422,34 @@ class SpeculativeConfig:
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self._validate_suffix_decoding()
+ elif self.method == "extract_hidden_states":
+ from vllm.transformers_utils.configs.extract_hidden_states import (
+ ExtractHiddenStatesConfig,
+ )
+
+ # ExtractHiddenStatesModel is instantiated manually in load_model()
+ # We just need to store the target model config for KV cache shape info
+ self.model = "extract_hidden_states"
+ self.prompt_lookup_max = 0
+ self.prompt_lookup_min = 0
+
+ if hasattr(self.draft_model_config, "hf_config"):
+ hf_config = self.draft_model_config.hf_config.to_dict()
+ elif (
+ isinstance(self.draft_model_config, dict)
+ and "hf_config" in self.draft_model_config
+ ):
+ hf_config = self.draft_model_config["hf_config"]
+ else:
+ hf_config = {}
+
+ self.draft_model_config = copy.copy(self.target_model_config)
+ self.draft_model_config.hf_config = ExtractHiddenStatesConfig(
+ self.draft_model_config.hf_config, **hf_config
+ )
+ self.update_arch_()
+ self.draft_parallel_config = self.target_parallel_config
+
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -439,7 +495,10 @@ class SpeculativeConfig:
MTPModelTypes
):
self.method = "mtp"
- if self.num_speculative_tokens > 1:
+ if (
+ not self.enable_multi_layers_mtp
+ and self.num_speculative_tokens > 1
+ ):
logger.warning(
"Enabling num_speculative_tokens > 1 will run "
"multiple times of forward on same MTP layer"
@@ -478,23 +537,8 @@ class SpeculativeConfig:
method=self.method,
model_type="eagle",
)
- # EAGLEConfig primarily updates architectures, so update
- # all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config
- self.draft_model_config.hf_text_config = get_hf_text_config(
- self.draft_model_config.hf_config
- )
- self.draft_model_config.model_arch_config = (
- self.draft_model_config.get_model_arch_config()
- )
- model_info, arch = (
- self.draft_model_config.registry.inspect_model_cls(
- self.draft_model_config.architectures,
- self.draft_model_config,
- )
- )
- self.draft_model_config._model_info = model_info
- self.draft_model_config._architecture = arch
+ self.update_arch_()
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"
@@ -510,6 +554,17 @@ class SpeculativeConfig:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
+ elif (
+ self.method == "mtp"
+ and self.enable_multi_layers_mtp
+ and self.num_speculative_tokens > n_predict
+ ):
+ logger.warning_once(
+ "For multi_layer_eagle, num_speculative_tokens "
+ "is greater than the layer_num, adjusting to "
+ "layer_num"
+ )
+ self.num_speculative_tokens = n_predict
elif (
self.num_speculative_tokens > n_predict
and self.num_speculative_tokens % n_predict != 0
@@ -555,9 +610,17 @@ class SpeculativeConfig:
)
)
+ self.draft_pipeline_parallel_size = (
+ SpeculativeConfig._verify_and_get_draft_pp(
+ self.target_parallel_config,
+ self.draft_pipeline_parallel_size,
+ )
+ )
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
- self.target_parallel_config, self.draft_tensor_parallel_size
+ self.target_parallel_config,
+ self.draft_tensor_parallel_size,
+ self.draft_pipeline_parallel_size,
)
)
return self
@@ -671,17 +734,61 @@ class SpeculativeConfig:
)
return speculative_draft_tensor_parallel_size
+ @staticmethod
+ def _verify_and_get_draft_pp(
+ target_parallel_config: ParallelConfig,
+ speculative_draft_pipeline_parallel_size: int | None,
+ ) -> int:
+ """
+ Verifies and adjusts the pipeline parallel size for a draft model
+ specified using speculative_draft_pipeline_parallel_size.
+ """
+ if speculative_draft_pipeline_parallel_size is None:
+ return target_parallel_config.pipeline_parallel_size
+
+ if speculative_draft_pipeline_parallel_size not in (
+ 1,
+ target_parallel_config.pipeline_parallel_size,
+ ):
+ raise ValueError(
+ f"{speculative_draft_pipeline_parallel_size=} cannot be "
+ "other value than 1 or target model "
+ f"pipeline_parallel_size="
+ f"{target_parallel_config.pipeline_parallel_size}"
+ )
+ return speculative_draft_pipeline_parallel_size
+
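Per _verify_and_get_draft_pp above, only two values are accepted; a sketch (the other keys are placeholders):

speculative_config = {
    "method": "eagle",
    "num_speculative_tokens": 2,
    # Either 1 (drafter on the last target PP stage) or the target's
    # pipeline_parallel_size; any other value raises ValueError.
    "draft_pipeline_parallel_size": 1,
}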
+ def update_arch_(self):
+ """
+ EAGLEConfig and ExtractHiddenStatesConfig primarily update architectures, so
+ refresh all architecture-related fields in self.draft_model_config.
+ """
+ self.draft_model_config.hf_text_config = get_hf_text_config(
+ self.draft_model_config.hf_config
+ )
+ self.draft_model_config.model_arch_config = (
+ self.draft_model_config.get_model_arch_config()
+ )
+ model_info, arch = self.draft_model_config.registry.inspect_model_cls(
+ self.draft_model_config.architectures,
+ self.draft_model_config,
+ )
+ self.draft_model_config._model_info = model_info
+ self.draft_model_config._architecture = arch
+
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
+ speculative_draft_pipeline_parallel_size: int,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
- This is mostly a copy of the target parallel config, except the tp_size.
+ This is mostly a copy of the target parallel config, except the tp/pp
+ sizes used by the draft model.
"""
draft_parallel_config = ParallelConfig(
- pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
+ pipeline_parallel_size=speculative_draft_pipeline_parallel_size,
tensor_parallel_size=speculative_draft_tensor_parallel_size,
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
@@ -699,6 +806,12 @@ class SpeculativeConfig:
"'tensor_parallel_size' is not a valid argument in the "
"speculative_config. Please pass 'draft_tensor_parallel_size' instead."
)
+ if self.pipeline_parallel_size is not None:
+ raise ValueError(
+ "'pipeline_parallel_size' is not a valid argument in the "
+ "speculative_config. Please pass "
+ "'draft_pipeline_parallel_size' instead."
+ )
if self.num_speculative_tokens is None:
raise ValueError(
@@ -718,7 +831,7 @@ class SpeculativeConfig:
self.draft_parallel_config
)
- eagle3_target_supported = [
+ aux_hidden_states_supported = [
"llama",
"qwen",
"minicpm",
@@ -729,16 +842,16 @@ class SpeculativeConfig:
"nemotron_h",
]
if (
- self.method == "eagle3"
+ self.method in ("eagle3", "extract_hidden_states")
and self.target_model_config
and not any(
supported_model in self.target_model_config.hf_text_config.model_type
- for supported_model in eagle3_target_supported
+ for supported_model in aux_hidden_states_supported
)
):
raise ValueError(
- f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
- f"Got {self.target_model_config.hf_text_config.model_type=}"
+ f"{self.method} is only supported for {aux_hidden_states_supported}"
+ f" models. Got {self.target_model_config.hf_text_config.model_type=}"
)
self.verify_equal_vocab_size_if_draft_model()
return self
@@ -782,8 +895,65 @@ class SpeculativeConfig:
def uses_draft_model(self) -> bool:
return self.method == "draft_model"
+ def uses_extract_hidden_states(self) -> bool:
+ return self.method == "extract_hidden_states"
+
+ def needs_partial_pp_draft_remap(
+ self, target_parallel_config: ParallelConfig
+ ) -> bool:
+ """Whether draft PP is smaller than target PP and needs rank remap."""
+ if self.draft_parallel_config is None:
+ return False
+ return (
+ target_parallel_config.pipeline_parallel_size
+ > self.draft_parallel_config.pipeline_parallel_size
+ )
+
+ def resolve_partial_pp_draft_rank(
+ self, target_parallel_config: ParallelConfig
+ ) -> int:
+ """Map a target rank to the local draft rank for partial-PP drafting.
+
+ Currently this only supports running the draft model with `draft_pp=1`
+ on the last target PP stage.
+ """
+ if not self.needs_partial_pp_draft_remap(target_parallel_config):
+ return target_parallel_config.rank
+
+ assert self.draft_parallel_config is not None
+ draft_pp = self.draft_parallel_config.pipeline_parallel_size
+ if draft_pp != 1:
+ raise ValueError(
+ "Partial pp drafter rank remapping only supports "
+ "draft_pipeline_parallel_size=1 when target PP is larger."
+ )
+
+ target_tp = target_parallel_config.tensor_parallel_size
+ draft_tp = self.draft_parallel_config.tensor_parallel_size
+ if draft_tp != target_tp:
+ raise ValueError(
+ "Partial pp drafter rank remapping requires "
+ "draft_tensor_parallel_size to equal target tensor_parallel_size. "
+ f"Got draft_tp={draft_tp}, target_tp={target_tp}."
+ )
+
+ target_pp = target_parallel_config.pipeline_parallel_size
+ target_rank = target_parallel_config.rank
+ target_pp_rank = target_rank // target_tp
+ target_tp_rank = target_rank % target_tp
+ if target_pp_rank != target_pp - 1:
+ raise ValueError(
+ "Partial pp drafter should only run on the last "
+ f"pipeline stage, but got pp rank {target_pp_rank} / {target_pp}"
+ )
+ return target_tp_rank
+
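A worked example of the remapping above, assuming target pp=4, tp=2 and draft_pp=1 (all values hypothetical):

target_pp, target_tp = 4, 2
for target_rank in (6, 7):                     # the two ranks on the last PP stage
    target_pp_rank = target_rank // target_tp  # 3 == target_pp - 1, so no error
    target_tp_rank = target_rank % target_tp   # 0 or 1
    print(target_rank, "->", target_tp_rank)   # 6 -> 0, 7 -> 1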
def __repr__(self) -> str:
method = self.method
- model = None if method in ("ngram", "suffix") else self.draft_model_config.model
+ model = (
+ None
+ if method in ("ngram", "suffix", "extract_hidden_states")
+ else self.draft_model_config.model
+ )
num_spec_tokens = self.num_speculative_tokens
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py
index 6a90a4c..c236837 100644
--- a/vllm/config/vllm.py
+++ b/vllm/config/vllm.py
@@ -126,6 +126,9 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool:
# tp-dp combination broken:
# https://github.com/vllm-project/vllm/issues/34458
and cfg.parallel_config.data_parallel_size == 1
+ # tp-pp combination broken:
+ # https://github.com/vllm-project/vllm/issues/35426
+ and cfg.parallel_config.pipeline_parallel_size == 1
)
@@ -857,7 +860,7 @@ class VllmConfig:
self.compilation_config.pass_config.fuse_gemm_comms = False
else:
# Compute SP threshold early; disable if None (model too
- # small) before +rms_norm gets forced into custom_ops.
+ # small for SP to be beneficial).
pass_config = self.compilation_config.pass_config
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
@@ -880,15 +883,13 @@ class VllmConfig:
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
- if self.compilation_config.pass_config.enable_sp:
- if "-rms_norm" in self.compilation_config.custom_ops:
- logger.warning(
- "RMS norm force disabled, sequence parallelism might break"
- )
- else:
- self.compilation_config.custom_ops.append("+rms_norm")
+ from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
- if self.compilation_config.fast_moe_cold_start is None:
+ if HAS_OPAQUE_TYPE:
+ # On torch >= 2.11 the hoisted OpaqueObject approach supersedes
+ # fast_moe_cold_start, so force it off.
+ self.compilation_config.fast_moe_cold_start = False
+ elif self.compilation_config.fast_moe_cold_start is None:
# resolve default behavior: try to be as safe as possible
# this config is unsafe if any spec decoding draft model has a MOE.
# We'll conservatively turn it off if we see spec decoding.
@@ -907,9 +908,9 @@ class VllmConfig:
):
logger.warning_once(
"Pooling models do not support full cudagraphs. "
- "Overriding cudagraph_mode to PIECEWISE."
+ "Overriding cudagraph_mode to NONE."
)
- self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+ self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
elif (
model_config.is_encoder_decoder
and self.compilation_config.cudagraph_mode
@@ -924,6 +925,33 @@ class VllmConfig:
CUDAGraphMode.FULL_DECODE_ONLY
)
+ # Check if KV connector requires PIECEWISE mode for CUDA graphs
+ if (
+ self.kv_transfer_config is not None
+ and self.kv_transfer_config.is_kv_transfer_instance
+ and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+ ):
+ # Lazy import to avoid circular dependencies
+ from vllm.distributed.kv_transfer.kv_connector.factory import (
+ KVConnectorFactory,
+ )
+
+ connector_cls = KVConnectorFactory.get_connector_class(
+ self.kv_transfer_config
+ )
+ if connector_cls.requires_piecewise_for_cudagraph(
+ self.kv_transfer_config.kv_connector_extra_config
+ ):
+ logger.warning_once(
+ "KV connector %s requires PIECEWISE CUDA graph mode "
+ "due to layerwise async operations that cannot be "
+ "captured in CUDA graphs. "
+ "Overriding cudagraph_mode from %s to PIECEWISE.",
+ connector_cls.__name__,
+ self.compilation_config.cudagraph_mode.name,
+ )
+ self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE
+
# disable cudagraph when enforce eager execution
if self.model_config is not None and self.model_config.enforce_eager:
logger.info("Cudagraph is disabled under eager mode")
@@ -1113,6 +1141,20 @@ class VllmConfig:
if not self.instance_id:
self.instance_id = random_uuid()[:5]
+
+ def is_ixserver_connector(kv_transfer_config) -> bool:
+ if kv_transfer_config is not None and hasattr(
+ kv_transfer_config, "kv_connector"
+ ):
+ connector = kv_transfer_config.kv_connector
+ if isinstance(connector, str):
+ connector_name = connector
+ else:
+ connector_name = getattr(
+ type(connector), "__name__", str(connector)
+ )
+ return "IxServer" in connector_name
+ return False
# Hybrid KV cache manager (HMA) runtime rules:
# - Explicit enable (--no-disable-kv-cache-manager): error if runtime
@@ -1154,21 +1196,29 @@ class VllmConfig:
if self.scheduler_config.disable_hybrid_kv_cache_manager is None:
# Default to disable HMA, but only if the user didn't express a preference.
if self.kv_transfer_config is not None:
+ if is_ixserver_connector(self.kv_transfer_config):
+ pass
# NOTE(Kuntai): turn HMA off for connector unless specifically enabled.
- need_disable_hybrid_kv_cache_manager = True
- logger.warning(
- "Turning off hybrid kv cache manager because "
- "`--kv-transfer-config` is set. This will reduce the "
- "performance of vLLM on LLMs with sliding window attention "
- "or Mamba attention. If you are a developer of kv connector"
- ", please consider supporting hybrid kv cache manager for "
- "your connector by making sure your connector is a subclass"
- " of `SupportsHMA` defined in kv_connector/v1/base.py and"
- " use --no-disable-hybrid-kv-cache-manager to start vLLM."
+ else:
+ need_disable_hybrid_kv_cache_manager = True
+ logger.warning(
+ "Turning off hybrid kv cache manager because "
+ "`--kv-transfer-config` is set. This will reduce the "
+ "performance of vLLM on LLMs with sliding window attention "
+ "or Mamba attention. If you are a developer of kv connector"
+ ", please consider supporting hybrid kv cache manager for "
+ "your connector by making sure your connector is a subclass"
+ " of `SupportsHMA` defined in kv_connector/v1/base.py and"
+ " use --no-disable-hybrid-kv-cache-manager to start vLLM."
+ )
+ self.scheduler_config.disable_hybrid_kv_cache_manager = (
+ need_disable_hybrid_kv_cache_manager
+ )
+
+ else:
+ self.scheduler_config.disable_hybrid_kv_cache_manager = (
+ need_disable_hybrid_kv_cache_manager
)
- self.scheduler_config.disable_hybrid_kv_cache_manager = (
- need_disable_hybrid_kv_cache_manager
- )
elif (
self.scheduler_config.disable_hybrid_kv_cache_manager is False
and need_disable_hybrid_kv_cache_manager
@@ -1466,22 +1516,22 @@ class VllmConfig:
if compile_range_end is not None:
computed_compile_ranges_split_points.append(compile_range_end)
- # # Add the compile ranges for flashinfer
- # if compilation_config.pass_config.fuse_allreduce_rms:
- # tp_size = self.parallel_config.tensor_parallel_size
- # max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
- # if max_size is not None:
- # max_token_num = max_size // (
- # self.model_config.get_hidden_size()
- # * self.model_config.dtype.itemsize
- # )
- # if compile_range_end is not None and max_token_num < compile_range_end:
- # computed_compile_ranges_split_points.append(max_token_num)
- # else:
- # logger.debug(
- # "Max num batched tokens below allreduce-rms fusion threshold, "
- # "allreduce-rms fusion will be enabled for all num_tokens."
- # )
+ # Add the compile ranges for flashinfer
+ if compilation_config.pass_config.fuse_allreduce_rms:
+ tp_size = self.parallel_config.tensor_parallel_size
+ max_size = compilation_config.pass_config.flashinfer_max_size(tp_size)
+ if max_size is not None:
+ max_token_num = max_size // (
+ self.model_config.get_hidden_size()
+ * self.model_config.dtype.itemsize
+ )
+ if compile_range_end is not None and max_token_num < compile_range_end:
+ computed_compile_ranges_split_points.append(max_token_num)
+ else:
+ logger.debug(
+ "Max num batched tokens below allreduce-rms fusion threshold, "
+ "allreduce-rms fusion will be enabled for all num_tokens."
+ )
# Add the compile ranges for sequence parallelism
if compilation_config.pass_config.enable_sp:
@@ -1618,6 +1668,7 @@ class VllmConfig:
f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa
f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa
f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa
+ f"quantization={self.model_config.quantization}, "
f"enforce_eager={self.model_config.enforce_eager}, "
f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa
f"kv_cache_dtype={self.cache_config.cache_dtype}, "
diff --git a/vllm/config/weight_transfer.py b/vllm/config/weight_transfer.py
index 855b0d9..1da1f96 100644
--- a/vllm/config/weight_transfer.py
+++ b/vllm/config/weight_transfer.py
@@ -9,5 +9,5 @@ from vllm.config.utils import config
class WeightTransferConfig:
"""Configuration for weight transfer during RL training."""
- backend: Literal["nccl"] = "nccl"
+ backend: Literal["nccl", "ipc"] = "nccl"
"""The backend to use for weight transfer."""
diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py
index 2f97288..554a34b 100644
--- a/vllm/device_allocator/cumem.py
+++ b/vllm/device_allocator/cumem.py
@@ -11,7 +11,7 @@
import dataclasses
import gc
import os
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
from contextlib import contextmanager
from typing import Any
@@ -25,6 +25,7 @@ logger = init_logger(__name__)
cumem_available = False
+libcudart: Any = None
try:
from vllm.cumem_allocator import (
init_module,
@@ -41,9 +42,7 @@ except ModuleNotFoundError:
init_module = None
python_create_and_map = None
python_unmap_and_release = None
- CudaRTLibrary = None
lib_name = None
- libcudart = None
# py_device, py_alignedSize, py_d_mem, py_p_memHandle
HandleType = tuple[int, int, int, int]
@@ -65,7 +64,8 @@ def unmap_and_release(allocation_handle: HandleType) -> None:
def get_pluggable_allocator(
- python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
+ python_malloc_fn: Callable[[HandleType], None],
+ python_free_func: Callable[[int], HandleType],
) -> torch.cuda.memory.CUDAPluggableAllocator:
init_module(python_malloc_fn, python_free_func)
new_alloc = torch.cuda.memory.CUDAPluggableAllocator(
@@ -76,8 +76,11 @@ def get_pluggable_allocator(
@contextmanager
def use_memory_pool_with_allocator(
- python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None]
-) -> None:
+ python_malloc_fn: Callable[[HandleType], None],
+ python_free_func: Callable[[int], HandleType],
+) -> Iterator[
+ tuple[torch.cuda.memory.MemPool, torch.cuda.memory.CUDAPluggableAllocator]
+]:
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool):
@@ -109,7 +112,7 @@ class CuMemAllocator:
not work as expected.
"""
- instance: "CuMemAllocator" = None
+ instance: "CuMemAllocator | None" = None
default_tag: str = "default"
@staticmethod
diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py
index 678cd45..7e1c9e6 100644
--- a/vllm/distributed/device_communicators/all2all.py
+++ b/vllm/distributed/device_communicators/all2all.py
@@ -3,14 +3,13 @@
from typing import Any
import torch
-import torch.distributed as dist
import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils.flashinfer import has_flashinfer_all2all
-from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori
from .base_device_communicator import All2AllManagerBase, Cache
@@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase):
debugging.
"""
- def __init__(self, cpu_group):
- super().__init__(cpu_group)
+ def __init__(self, cpu_group, tcp_store_group=None):
+ super().__init__(cpu_group, tcp_store_group)
def naive_multicast(
self,
@@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase):
all-gather (dispatch) and reduce-scatter (combine).
"""
- def __init__(self, cpu_group):
- super().__init__(cpu_group)
+ def __init__(self, cpu_group, tcp_store_group=None):
+ super().__init__(cpu_group, tcp_store_group)
def dispatch_router_logits(
self,
@@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase):
pass
-class PPLXAll2AllManager(All2AllManagerBase):
- """
- All2All communication based on PPLX kernels.
- """
-
- def __init__(self, cpu_group):
- assert has_pplx(), (
- "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
- " to install pplx_kernels."
- )
- super().__init__(cpu_group)
-
- if self.internode:
- # inter-node communication needs nvshmem,
- # intra-node communication uses p2p mapping directly
- from pplx_kernels.nvshmem import ( # type: ignore[import-not-found]
- nvshmem_alloc_empty_unique_id,
- nvshmem_get_unique_id,
- nvshmem_init,
- )
-
- logger.debug(
- "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d",
- self.rank,
- self.world_size,
- )
- uid = (
- nvshmem_get_unique_id()
- if self.rank == 0
- else nvshmem_alloc_empty_unique_id()
- )
- dist.broadcast(
- uid,
- src=dist.get_process_group_ranks(self.cpu_group)[0],
- group=self.cpu_group,
- )
- logger.debug("PPLX NVSHMEM UID = %s", uid)
- nvshmem_init(uid, self.rank, self.world_size)
-
- self.handle_cache = Cache()
-
- def get_handle(self, kwargs):
- import pplx_kernels as pplx # type: ignore[import-not-found]
-
- return self.handle_cache.get_or_create(
- kwargs,
- pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode,
- )
-
- def dispatch_router_logits(
- self,
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- is_sequence_parallel: bool = False,
- extra_tensors: list[torch.Tensor] | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
- def dispatch(
- self,
- hidden_states: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- is_sequence_parallel: bool = False,
- extra_tensors: list[torch.Tensor] | None = None,
- ) -> (
- tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
- ):
- raise NotImplementedError
-
- def combine(
- self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
- ) -> torch.Tensor:
- raise NotImplementedError
-
- def destroy(self):
- with self.handle_cache._lock:
- for _, handle in self.handle_cache._cache.items():
- handle.destroy()
-
- if self.internode:
- from pplx_kernels.nvshmem import (
- nvshmem_finalize, # type: ignore[import-not-found]
- )
-
- logger.debug("PPLX NVSHMEM finalize")
- nvshmem_finalize()
-
-
class DeepEPAll2AllManagerBase(All2AllManagerBase):
"""
All2All communication based on DeepEP High-Throughput kernels.
"""
- def __init__(self, cpu_group):
+ def __init__(self, cpu_group, tcp_store_group=None):
assert has_deep_ep(), (
"DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md"
" to install DeepEP kernels."
) # noqa
- super().__init__(cpu_group)
+ super().__init__(cpu_group, tcp_store_group)
self.handle_cache = Cache()
# This is the DeepEP default. Stick to it till we can establish
@@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
raise NotImplementedError
def destroy(self):
- pass
+ with self.handle_cache._lock:
+ for _, handle in self.handle_cache._cache.items():
+ handle.destroy()
+ self.handle_cache._cache.clear()
class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
@@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP High-Throughput kernels.
"""
- def __init__(self, cpu_group):
- super().__init__(cpu_group)
+ def __init__(self, cpu_group, tcp_store_group=None):
+ super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests.
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
+ explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
All2All communication based on DeepEP Low-Latency kernels.
"""
- def __init__(self, cpu_group):
- super().__init__(cpu_group)
+ def __init__(self, cpu_group, tcp_store_group=None):
+ super().__init__(cpu_group, tcp_store_group)
def _make_all2all_kwargs(
self,
@@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
- allow_nvlink_for_low_latency_mode=True,
- allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
+ # allow_nvlink_for_low_latency_mode=True,
+ # allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
+ explicitly_destroy=True,
)
def get_handle(self, kwargs):
@@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase):
rank: int
world_size: int
- def __init__(self, cpu_group):
+ def __init__(self, cpu_group, tcp_store_group=None):
assert has_flashinfer_all2all(), (
"flashinfer all2all module not found. Please install/check flashinfer"
) # noqa
- super().__init__(cpu_group)
+ super().__init__(cpu_group, tcp_store_group)
logger.debug(
"Initialize for flashinfer All2All rank=%d, world size=%d",
self.rank,
diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py
index ff2d743..3c347ef 100644
--- a/vllm/distributed/device_communicators/all_reduce_utils.py
+++ b/vllm/distributed/device_communicators/all_reduce_utils.py
@@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
logger = init_logger(__name__)
+KiB = 1024
MiB = 1024 * 1024
# Max size for each world size in case symmetric memory is available
# For different SM architectures
@@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
},
}
+# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks.
+# PyNCCL-symm outperforms custom_AR for small and large tensor sizes,
+# while custom_AR wins for mid-range sizes.
+#
+# Benchmark results (8 GPUs):
+# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster)
+# 32K - 64K: custom_AR wins
+# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster)
+#
+# Benchmark results (4 GPUs):
+# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster)
+# 32K - 256K: custom_AR wins (1.07x - 1.35x faster)
+# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster)
+#
+# The config defines ranges where custom_AR is preferred (symm_mem disabled).
NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
"min_world_size": 4,
- "thresholds": {
- 4: 2 * MiB, # 2 MB
- 8: 1 * MiB, # 1 MB
+ # Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound)
+ # NCCL symm_mem will NOT be used for sizes in range: lower < size < upper
+ "custom_ar_preferred_ranges": {
+ 4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K
+ 8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K
},
"always_use_above_world_size": 8, # Always use symm mem for world_size > 8
}
def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool:
+ """
+ Determine if NCCL symmetric memory allreduce should be used.
+
+ Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for:
+ - Small tensors (≤16K): Lower latency than custom_AR
+ - Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth
+
+ Custom_AR is preferred for mid-range sizes where its P2P approach
+ has lower overhead than the symm_mem copy-in/copy-out pattern.
+ """
from vllm.distributed.device_communicators.pynccl_allocator import (
is_symmetric_memory_enabled,
)
@@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
if not is_symmetric_memory_enabled():
return False
+
if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
return False
- threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
- if threshold is not None and input_tensor.nbytes >= threshold:
- return True
+
+ tensor_size = input_tensor.nbytes
+ custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get(
+ world_size
+ )
+
+ if custom_ar_range is not None:
+ lower_bound, upper_bound = custom_ar_range
+ # Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound)
+ # Use custom_AR (not symm_mem) for mid-range sizes
+ return tensor_size <= lower_bound or tensor_size >= upper_bound
return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"]
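A small sketch of the resulting decision at world_size=8, using the thresholds above:

KiB = 1024
lower, upper = 16 * KiB, 128 * KiB  # custom_ar_preferred_ranges[8]

for nbytes in (8 * KiB, 64 * KiB, 1024 * KiB):
    use_symm_mem = nbytes <= lower or nbytes >= upper
    print(f"{nbytes // KiB:>5} KiB -> {'symm_mem' if use_symm_mem else 'custom_AR'}")
# 8 KiB -> symm_mem, 64 KiB -> custom_AR, 1024 KiB -> symm_mem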
diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py
index fb055ba..52c6954 100644
--- a/vllm/distributed/device_communicators/base_device_communicator.py
+++ b/vllm/distributed/device_communicators/base_device_communicator.py
@@ -30,8 +30,9 @@ class All2AllManagerBase:
rank: int
world_size: int
- def __init__(self, cpu_group):
+ def __init__(self, cpu_group, tcp_store_group=None):
self.cpu_group = cpu_group
+ self.tcp_store_group = tcp_store_group
# compute some common properties
from vllm.distributed.parallel_state import (
@@ -48,12 +49,17 @@ class All2AllManagerBase:
# when we create this object
self.dp_rank = self.dp_group.rank_in_group
self.dp_world_size = self.dp_group.world_size
- self.rank = dist.get_rank(cpu_group)
- self.world_size = dist.get_world_size(cpu_group)
+ self.rank = cpu_group.rank()
+ self.world_size = cpu_group.size()
# all2all communication often has separate implementations for
# intra-node and inter-node communication
- self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+ if tcp_store_group is None:
+ self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0))
+ else:
+ self.internode = not all(
+ in_the_same_node_as(tcp_store_group, source_rank=0)
+ )
def get_handle(self, kwargs):
# get a handle for the all2all communication,
@@ -122,17 +128,36 @@ class DeviceCommunicatorBase:
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
+ global_ranks: list[int] | None = None,
+ global_world_size: int | None = None,
):
self.device = device or torch.device("cpu")
self.cpu_group = cpu_group
self.device_group = device_group
self.unique_name = unique_name
- self.rank = dist.get_rank(cpu_group)
- self.world_size = dist.get_world_size(cpu_group)
- self.ranks = dist.get_process_group_ranks(cpu_group)
- self.global_rank = dist.get_rank()
- self.global_world_size = dist.get_world_size()
- self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
+
+ # Check if this is a stateless process group
+ from torch.distributed.distributed_c10d import _world
+
+ is_stateless = _world.pg_map.get(cpu_group, None) is None
+
+ if is_stateless:
+ # For stateless groups, we can't use torch.distributed methods
+ self.rank = cpu_group.rank()
+ self.world_size = cpu_group.size()
+ assert global_ranks is not None
+ assert global_world_size is not None
+ self.ranks = global_ranks
+ self.global_rank = self.ranks[self.rank]
+ self.global_world_size = global_world_size
+ self.rank_in_group = self.rank
+ else:
+ self.rank = dist.get_rank(cpu_group)
+ self.world_size = dist.get_world_size(cpu_group)
+ self.ranks = dist.get_process_group_ranks(cpu_group)
+ self.global_rank = dist.get_rank()
+ self.global_world_size = dist.get_world_size()
+ self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank)
use_ep = False
all2all_backend = None
@@ -146,7 +171,7 @@ class DeviceCommunicatorBase:
use_ep = config.parallel_config.data_parallel_size > 1
all2all_backend = config.parallel_config.all2all_backend
- self.is_ep_communicator = "ep" in unique_name
+ self.is_ep_communicator = unique_name.split(":")[0] == "ep"
self.use_all2all = self.is_ep_communicator and use_ep
self.all2all_backend = all2all_backend
self.all2all_manager: All2AllManagerBase | None = None
@@ -175,9 +200,7 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
- torch.distributed.all_gather_into_tensor(output_tensor,
- input_,
- group=self.device_group)
+ dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group)
# Reshape
output_tensor = output_tensor.reshape((self.world_size,) + input_size)
output_tensor = output_tensor.movedim(0, dim)
@@ -263,10 +286,9 @@ class DeviceCommunicatorBase:
group=self.device_group,
async_op=True)
else:
- torch.distributed.gather(input_,
- gather_list,
- dst=self.ranks[dst],
- group=self.device_group)
+ torch.distributed.gather(
+ input_, gather_list, dst=self.ranks[dst], group=self.device_group
+ )
if self.rank_in_group == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
@@ -292,6 +314,13 @@ class DeviceCommunicatorBase:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
+ def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
+ """Broadcast a tensor from source rank to all ranks."""
+ if self.world_size == 1:
+ return tensor
+ torch.distributed.broadcast(tensor, self.ranks[src], self.device_group)
+ return tensor
+
def destroy(self):
pass
@@ -360,3 +389,6 @@ class DeviceCommunicatorBase:
This is a no-op in the base class.
"""
return hidden_states
+
+ def batch_isend_irecv(self, p2p_ops: list):
+ raise NotImplementedError
diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py
index 23be8fc..2bce5fa 100644
--- a/vllm/distributed/device_communicators/cpu_communicator.py
+++ b/vllm/distributed/device_communicators/cpu_communicator.py
@@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase):
)
and hasattr(torch.ops._C, "init_shm_manager")
and (unique_name.startswith("tp") or unique_name.startswith("pp"))
+ and self._all_group_ranks_share_shm_group_name()
):
self.dist_module = _CPUSHMDistributed(self)
+ elif unique_name.startswith("tp") or unique_name.startswith("pp"):
+ logger.info(
+ "CPU SHM communicator disabled for group %s: ranks do not share "
+ "the same SHM group name, falling back to torch.distributed.",
+ unique_name,
+ )
if self.use_all2all:
if self.all2all_backend != "naive": # type: ignore[has-type]
@@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
logger.info("Using naive all2all manager.")
+ def _all_group_ranks_share_shm_group_name(self) -> bool:
+ """
+ CPUSHM requires all ranks in this group to agree on one SHM group name.
+ This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs.
+ """
+ local_name = _CPUSHMDistributed.make_group_name(self)
+ names: list[str] = [""] * self.world_size
+ torch.distributed.all_gather_object(
+ names,
+ local_name,
+ group=self.device_group,
+ )
+ return len(set(names)) == 1
+
def all_reduce(self, input_):
self.dist_module.all_reduce(input_, group=self.device_group)
return input_
@@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase):
class _CPUSHMDistributed:
def __init__(self, communicator: CpuCommunicator):
+ self.communicator = communicator
+
+ self.group_name = self.make_group_name(communicator)
+
+ self.handle = self._init_cpu_shm()
+
+ @staticmethod
+ def make_group_name(communicator: CpuCommunicator) -> str:
instance_identifier = os.environ["VLLM_DIST_IDENT"]
unique_name = communicator.unique_name
instance_identifier = f"{instance_identifier}-{unique_name}"
- self.communicator = communicator
-
- group_ranks = [str(rank) for rank in self.communicator.ranks]
+ group_ranks = [str(rank) for rank in communicator.ranks]
shm_group_identifier = f"[{'-'.join(group_ranks)}]"
- self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm"
-
- self.handle = self._init_cpu_shm()
+ return f"{instance_identifier}-{shm_group_identifier}-cpushm"
def _init_cpu_shm(self) -> int:
thread_num_tensor = torch.tensor(
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index 6302e50..519a37f 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import (
from vllm.logger import init_logger
from vllm.platforms import current_platform
+from ..utils import StatelessProcessGroup
from .base_device_communicator import DeviceCommunicatorBase
import ixformer.distributed as ixfd
import os
@@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase):
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
+ global_ranks: list[int] | None = None,
+ global_world_size: int | None = None,
+ tcp_store_group: StatelessProcessGroup | None = None,
):
- super().__init__(cpu_group, device, device_group, unique_name)
+ super().__init__(
+ cpu_group,
+ device,
+ device_group,
+ unique_name,
+ global_ranks,
+ global_world_size,
+ )
if "tp" not in unique_name:
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
@@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem
self.use_flashinfer_allreduce = use_flashinfer_allreduce
-
+
self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"]
+
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
@@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.pynccl_comm: PyNcclCommunicator | None = None
if self.world_size > 1:
self.pynccl_comm = PyNcclCommunicator(
- group=self.cpu_group,
+ group=self.cpu_group if tcp_store_group is None else tcp_store_group,
device=self.device,
)
if is_symmetric_memory_enabled():
@@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase):
if self.all2all_backend == "naive":
from .all2all import NaiveAll2AllManager
- self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+ self.all2all_manager = NaiveAll2AllManager(
+ self.cpu_group, tcp_store_group
+ )
elif self.all2all_backend == "allgather_reducescatter":
from .all2all import AgRsAll2AllManager
- self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
- elif self.all2all_backend == "pplx":
- from .all2all import PPLXAll2AllManager
-
- self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
+ self.all2all_manager = AgRsAll2AllManager(
+ self.cpu_group, tcp_store_group
+ )
elif self.all2all_backend == "deepep_high_throughput":
from .all2all import DeepEPHTAll2AllManager
- self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
+ self.all2all_manager = DeepEPHTAll2AllManager(
+ self.cpu_group, tcp_store_group
+ )
elif self.all2all_backend == "deepep_low_latency":
from .all2all import DeepEPLLAll2AllManager
- self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
+ self.all2all_manager = DeepEPLLAll2AllManager(
+ self.cpu_group, tcp_store_group
+ )
elif self.all2all_backend == "mori":
from .all2all import MoriAll2AllManager
@@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase):
elif self.all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
- self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
+ self.all2all_manager = FlashInferAllToAllManager(
+ self.cpu_group, tcp_store_group
+ )
else:
raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")
@@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase):
return out
if self.world_size == 1:
return input_
+
if self.use_vllm_comm:
- # torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True)
ixfd.all_reduce(input_, group=self.device_group, async_op=True)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
- pynccl_comm = self.pynccl_comm
- if pynccl_comm is None or pynccl_comm.disabled:
- out = input_.clone()
- torch.distributed.all_reduce(out, group=self.device_group)
- return out
- assert pynccl_comm is not None
- out = pynccl_comm.all_reduce(input_)
- if out is None:
- # fall back to the default all-reduce using PyTorch.
- # this usually happens during testing.
- # when we run the model, allreduce only happens for the TP
- # group, where we always have either custom allreduce or pynccl.
- out = input_.clone()
- torch.distributed.all_reduce(out, group=self.device_group)
- return out
def reduce_scatter(self, input_: torch.Tensor, dim: int = -1):
world_size = self.world_size
@@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
output_shape, dtype=input_tensor.dtype, device=input_tensor.device
)
- # pynccl_comm.reduce_scatter(output, input_tensor)
- torch.distributed.reduce_scatter_tensor(output,
- input_tensor,
- group=self.device_group)
+ # Perform the reduce-scatter via the ixformer distributed backend.
+ ixfd.reduce_scatter_tensor(output, input_tensor, group=self.device_group, async_op=True)
# Reshape before returning
return output.movedim(0, dim).contiguous()
@@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
"""NOTE: `dst` is the local rank of the destination rank."""
if dst is None:
dst = (self.rank_in_group + 1) % self.world_size
-
- pynccl_comm = self.pynccl_comm
- # if pynccl_comm is not None and not pynccl_comm.disabled:
- # pynccl_comm.send(tensor, dst)
- # else:
- # torch.distributed.send(tensor, self.ranks[dst], self.device_group)
if self.use_vllm_comm:
ixfd.send(tensor, self.ranks[dst], self.device_group)
else:
@@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase):
src = (self.rank_in_group - 1) % self.world_size
tensor = torch.empty(size, dtype=dtype, device=self.device)
- # pynccl_comm = self.pynccl_comm
- # if pynccl_comm is not None and not pynccl_comm.disabled:
- # pynccl_comm.recv(tensor, src)
- # else:
- # torch.distributed.recv(tensor, self.ranks[src], self.device_group)
if self.use_vllm_comm:
ixfd.recv(tensor, self.ranks[src], self.device_group)
else:
torch.distributed.recv(tensor, self.ranks[src], self.device_group)
return tensor
+ def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor:
+ """Broadcast a tensor from source rank to all ranks."""
+ if self.world_size == 1:
+ return tensor
+
+ pynccl_comm = self.pynccl_comm
+ if pynccl_comm is not None and not pynccl_comm.disabled:
+ pynccl_comm.broadcast(tensor, src)
+ return tensor
+ else:
+ raise ValueError("No PyNCCL communicator found")
+
def destroy(self):
if self.pynccl_comm is not None:
self.pynccl_comm = None
@@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase):
self.fi_ar_comm = None
if self.all2all_manager is not None:
self.all2all_manager.destroy()
- self.all2all_manager = None
+ self.all2all_manager = None # type: ignore[assignment]
def all_gatherv(
self,
@@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase):
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
- extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
This is a no-op in the base class.
"""
assert self.all2all_manager is not None
- # return self.all2all_manager.dispatch(
- # hidden_states,
- # topk_weights,
- # topk_ids,
- # is_sequence_parallel,
- # extra_tensors=extra_tensors,
- # )
- hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch(
- hidden_states, extra_residual, router_logits)
- return hidden_states, extra_residual, router_logits
+ return self.all2all_manager.dispatch(
+ hidden_states,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel,
+ extra_tensors=extra_tensors,
+ )
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase):
hidden_states,
is_sequence_parallel,
)
+
+ def batch_isend_irecv(self, p2p_ops: list):
+ pynccl_comm = self.pynccl_comm
+ if pynccl_comm is not None and not pynccl_comm.disabled:
+ pynccl_comm.batch_isend_irecv(p2p_ops)
+ else:
+ raise ValueError("No PyNCCL communicator found")
diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py
index 2fc35e8..44dc113 100644
--- a/vllm/distributed/device_communicators/pynccl.py
+++ b/vllm/distributed/device_communicators/pynccl.py
@@ -312,10 +312,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
+ if tensor.dtype in [
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
+ torch.float8_e4m3fnuz,
+ torch.float8_e5m2fnuz,
+ ]:
+ nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
+ else:
+ nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclSend(
buffer_type(tensor.data_ptr()),
tensor.numel(),
- ncclDataTypeEnum.from_torch(tensor.dtype),
+ nccl_dtype,
dst,
self.comm,
cudaStream_t(stream.cuda_stream),
@@ -330,10 +339,19 @@ class PyNcclCommunicator:
)
if stream is None:
stream = current_stream()
+ if tensor.dtype in [
+ torch.float8_e5m2,
+ torch.float8_e4m3fn,
+ torch.float8_e4m3fnuz,
+ torch.float8_e5m2fnuz,
+ ]:
+ nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8)
+ else:
+ nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype)
self.nccl.ncclRecv(
buffer_type(tensor.data_ptr()),
tensor.numel(),
- ncclDataTypeEnum.from_torch(tensor.dtype),
+ nccl_dtype,
src,
self.comm,
cudaStream_t(stream.cuda_stream),
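Editorial note, not part of the patch: the FP8-to-uint8 substitution above is byte-exact because every FP8 variant is one byte wide, so the element count handed to ncclSend/ncclRecv describes the same number of bytes. A minimal self-check, assuming a PyTorch build that exposes these dtypes:

import torch

fp8_dtypes = [
    torch.float8_e5m2,
    torch.float8_e4m3fn,
    torch.float8_e4m3fnuz,
    torch.float8_e5m2fnuz,
]
# Every FP8 element occupies exactly one byte, the same as uint8, so sending
# the buffer as uint8 moves identical bytes.
for dt in fp8_dtypes:
    assert torch.empty(0, dtype=dt).element_size() == 1
assert torch.empty(0, dtype=torch.uint8).element_size() == 1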
@@ -384,3 +402,17 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
+
+ def batch_isend_irecv(self, p2p_ops: list, stream=None):
+ if self.disabled:
+ return
+ if stream is None:
+ stream = current_stream()
+ self.group_start()
+ for op in p2p_ops:
+ if op.op is torch.distributed.isend:
+ self.send(op.tensor, op.group_peer, stream)
+ elif op.op is torch.distributed.irecv:
+ self.recv(op.tensor, op.group_peer, stream)
+
+ self.group_end()
diff --git a/vllm/distributed/elastic_ep/__init__.py b/vllm/distributed/elastic_ep/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py
new file mode 100644
index 0000000..22d5706
--- /dev/null
+++ b/vllm/distributed/elastic_ep/elastic_execute.py
@@ -0,0 +1,529 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
+import gc
+import weakref
+from collections.abc import Iterable, Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributed import P2POp
+
+from vllm.compilation.counter import compilation_counter
+from vllm.compilation.cuda_graph import CUDAGraphWrapper
+from vllm.compilation.wrapper import reset_compile_wrapper
+from vllm.config import (
+ CompilationMode,
+ set_current_vllm_config,
+)
+from vllm.distributed import (
+ get_dp_group,
+ get_ep_group,
+ get_pcp_group,
+ get_tp_group,
+)
+from vllm.distributed.elastic_ep.standby_state import (
+ create_standby_groups,
+ get_standby_dp_group,
+ get_standby_ep_group,
+ pop_standby_groups,
+)
+from vllm.distributed.parallel_state import (
+ _replace_active_groups,
+ prepare_communication_buffer_for_model,
+)
+from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig
+from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
+from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
+from vllm.v1.worker.workspace import lock_workspace, unlock_workspace
+
+logger = init_logger(__name__)
+
+
+def batch_transfer_weights(
+ model: nn.Module,
+ is_sender: bool,
+ peer_rank: int,
+ dp_group: StatelessGroupCoordinator,
+ expert_weights: Sequence[Iterable[torch.Tensor]],
+) -> None:
+ device_comm = dp_group.device_communicator
+ if device_comm is None:
+ raise ValueError("No device communicator found")
+
+ expert_weights_set = set()
+ for weight_group in expert_weights:
+ for weight in weight_group:
+ expert_weights_set.add(weight.data_ptr())
+
+ state_dict = model.state_dict()
+ all_params = []
+
+ for name, param in state_dict.items():
+ if name.endswith("expert_map"):
+ continue
+ if param.data_ptr() not in expert_weights_set:
+ all_params.append(param.data)
+
+ assert len(all_params) > 0
+ p2p_ops = []
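+    # Bare P2POp records are built without running P2POp.__init__: only .op,
+    # .tensor and .group_peer are read by the pynccl batch_isend_irecv path
+    # below, so no torch.distributed group/rank resolution is needed here.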
+ for param in all_params:
+ op = object.__new__(P2POp)
+        op.op = torch.distributed.isend if is_sender else torch.distributed.irecv
+        op.tensor = param
+ op.group_peer = peer_rank
+ p2p_ops.append(op)
+ device_comm.batch_isend_irecv(p2p_ops)
+
+
+def broadcast_expert_mapping(
+ physical_to_logical: torch.Tensor | None,
+ num_local_physical_experts: int | None,
+ num_logical_experts: int | None,
+ dp_group: StatelessGroupCoordinator,
+ device: torch.device,
+ src_rank: int = 0,
+) -> tuple[torch.Tensor, int, int]:
+ if dp_group.rank_in_group == src_rank:
+ assert physical_to_logical is not None
+ assert num_local_physical_experts is not None
+ assert num_logical_experts is not None
+ assert physical_to_logical.dtype == torch.int64
+ shape_tensor = torch.tensor(
+ list(physical_to_logical.shape), dtype=torch.int64, device="cpu"
+ )
+ metadata_tensor = torch.tensor(
+ [num_local_physical_experts, num_logical_experts],
+ dtype=torch.int64,
+ device="cpu",
+ )
+ else:
+ shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
+ metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu")
+
+ shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank)
+ metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank)
+
+ if dp_group.rank_in_group != src_rank:
+ assert device is not None
+ physical_to_logical = torch.empty(
+ tuple(shape_tensor.tolist()),
+ dtype=torch.int64,
+ device=device,
+ )
+
+ assert physical_to_logical is not None
+ physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank)
+ num_local_physical_experts = int(metadata_tensor[0].item())
+ num_logical_experts = int(metadata_tensor[1].item())
+
+ return physical_to_logical, num_local_physical_experts, num_logical_experts
+
+
+class ElasticEPScalingExecutor:
+ def __init__(self, worker):
+ self.worker_ref = weakref.ref(worker)
+ self.reconfig_request = None
+
+ @property
+ def worker(self):
+ worker = self.worker_ref()
+ if worker is None:
+ raise RuntimeError("Worker has been garbage collected")
+ return worker
+
+ def execute(self, execute_method: str, *args, **kwargs):
+ method = getattr(self, execute_method, None)
+ if method is None:
+ raise ValueError(f"Unknown execute method: {execute_method}")
+ return method(*args, **kwargs)
+
+ def create_standby_groups(
+ self, reconfig_request: ReconfigureDistributedRequest
+ ) -> None:
+ self.reconfig_request = reconfig_request
+ new_dp_size = reconfig_request.new_data_parallel_size
+ world_size = self.worker.vllm_config.parallel_config.world_size
+ new_world_size_across_dp = world_size * new_dp_size
+ updated_config = copy.copy(self.worker.vllm_config)
+ updated_config.parallel_config = copy.deepcopy(
+ self.worker.vllm_config.parallel_config
+ )
+ updated_config.parallel_config.data_parallel_size = new_dp_size
+ with set_current_vllm_config(updated_config):
+ create_standby_groups(
+ new_dp_size=new_dp_size,
+ new_world_size_across_dp=new_world_size_across_dp,
+ master_ip=reconfig_request.new_data_parallel_master_ip,
+ world_group_ports=reconfig_request.new_stateless_world_group_port_list,
+ dp_group_ports=reconfig_request.new_stateless_dp_group_port_list,
+ ep_group_ports=reconfig_request.new_stateless_ep_group_port_list,
+ eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list,
+ )
+ self.worker.model_runner.eep_eplb_suppressed = True
+ standby_ep_group = get_standby_ep_group()
+ assert standby_ep_group is not None
+ if standby_ep_group.rank == 0:
+ logger.info("[Elastic EP] EPLB disabled during elastic scaling transition")
+
+ def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None:
+ standby_dp_group = get_standby_dp_group()
+ assert standby_dp_group is not None
+        # Broadcast old_dp_size to all workers in the standby group
+ if standby_dp_group.rank_in_group < old_dp_size:
+ old_dp_size_tensor = torch.tensor(
+ [old_dp_size], dtype=torch.int64, device="cpu"
+ )
+ else:
+ old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
+ old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast(
+ old_dp_size_tensor, 0
+ )
+
+ num_new_workers = new_dp_size - old_dp_size
+ dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
+
+        # Sender-receiver pairing: with k = num_new_workers // old_dp_size,
+        # the first (num_new_workers % old_dp_size) senders each get k + 1
+        # contiguous receivers; the remaining senders each get k receivers.
+ num_dst_per_sender = num_new_workers // old_dp_size
+ remainder = num_new_workers % old_dp_size
+
+ if dp_rank < remainder:
+ recv_begin = dp_rank * (num_dst_per_sender + 1)
+ recv_end = recv_begin + num_dst_per_sender + 1
+ else:
+ recv_begin = (
+ remainder * (num_dst_per_sender + 1)
+ + (dp_rank - remainder) * num_dst_per_sender
+ )
+ recv_end = recv_begin + num_dst_per_sender
+
+ ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end))
+
+ model = self.worker.model_runner.get_model()
+ for new_worker_rank in sorted(ranks_to_send):
+ batch_transfer_weights(
+ model=model,
+ is_sender=True,
+ peer_rank=new_worker_rank,
+ dp_group=standby_dp_group,
+ expert_weights=model.expert_weights,
+ )
+ torch.cuda.synchronize()
+
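Editorial sketch, not part of the patch: the pairing arithmetic above, together with its inverse in receive_weights() below, can be checked with made-up sizes (old_dp_size=3, new_dp_size=8):

old_dp_size, new_dp_size = 3, 8
num_new_workers = new_dp_size - old_dp_size  # 5
k = num_new_workers // old_dp_size           # 1
remainder = num_new_workers % old_dp_size    # 2

def receivers_of(dp_rank: int) -> list[int]:
    # Mirrors transfer_weights(): which new DP ranks this existing rank sends to.
    if dp_rank < remainder:
        begin = dp_rank * (k + 1)
        end = begin + k + 1
    else:
        begin = remainder * (k + 1) + (dp_rank - remainder) * k
        end = begin + k
    return list(range(old_dp_size + begin, old_dp_size + end))

def sender_of(new_rank: int) -> int:
    # Mirrors receive_weights(): which existing rank a new DP rank receives from.
    idx = new_rank - old_dp_size
    if idx < remainder * (k + 1):
        return idx // (k + 1)
    return remainder + (idx - remainder * (k + 1)) // k

assert [receivers_of(r) for r in range(old_dp_size)] == [[3, 4], [5, 6], [7]]
assert all(r in receivers_of(sender_of(r)) for r in range(old_dp_size, new_dp_size))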
+ def broadcast_expert_mapping(self) -> None:
+ standby_dp_group = get_standby_dp_group()
+ assert standby_dp_group is not None
+ model_config = self.worker.model_runner.model_config
+ eplb_state = self.worker.model_runner.eplb_state
+ assert eplb_state is not None
+ eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
+ physical_to_logical = eplb_model_state.physical_to_logical_map
+ num_physical_experts = physical_to_logical.shape[1]
+ num_local_physical_experts = num_physical_experts // get_ep_group().world_size
+ num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
+ broadcast_expert_mapping(
+ physical_to_logical=physical_to_logical,
+ num_local_physical_experts=num_local_physical_experts,
+ num_logical_experts=num_logical_experts,
+ dp_group=standby_dp_group,
+ src_rank=0,
+ device=self.worker.device,
+ )
+
+ def switch_and_remove(self) -> None:
+ _replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None)
+
+ def switch_and_prepare(self) -> None:
+ old_dp_size = get_dp_group().world_size
+ old_ep_size = get_ep_group().world_size
+
+ _replace_active_groups(**pop_standby_groups())
+
+ parallel_config = self.worker.vllm_config.parallel_config
+ reconfig_request = self.reconfig_request
+ assert reconfig_request is not None
+ new_dp_size = reconfig_request.new_data_parallel_size
+ new_ep_size = get_ep_group().world_size
+
+ parallel_config.data_parallel_size = new_dp_size
+ if (
+ reconfig_request.new_data_parallel_rank
+ != ReconfigureRankType.KEEP_CURRENT_RANK
+ ):
+ parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
+ if (
+ reconfig_request.new_data_parallel_rank_local
+ != ReconfigureRankType.KEEP_CURRENT_RANK
+ ):
+ parallel_config.data_parallel_rank_local = (
+ reconfig_request.new_data_parallel_rank_local
+ )
+ parallel_config.data_parallel_master_ip = (
+ reconfig_request.new_data_parallel_master_ip
+ )
+ parallel_config.data_parallel_master_port = (
+ reconfig_request.new_data_parallel_master_port
+ )
+
+ # Reconfigure MoE modules with new EP size
+ moe_modules = [
+ module
+ for module in self.worker.model_runner.model.modules()
+ if (
+ module.__class__.__name__ == "FusedMoE"
+ or module.__class__.__name__ == "SharedFusedMoE"
+ )
+ ]
+ num_local_experts = moe_modules[0].moe_config.num_local_experts
+ assert all(
+ module.moe_config.num_local_experts == num_local_experts
+ for module in moe_modules
+ ), "All MoE modules must have the same number of experts"
+ for module in moe_modules:
+ module.moe_config.num_experts = num_local_experts * new_ep_size
+ module.global_num_experts = module.moe_config.num_experts
+ tp_size = get_tp_group().world_size
+ is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+ sp_size = tp_size if is_sequence_parallel else 1
+ module.moe_parallel_config = FusedMoEParallelConfig.make(
+ tp_size_=tp_size,
+ pcp_size_=get_pcp_group().world_size,
+ dp_size_=get_dp_group().world_size,
+ sp_size_=sp_size,
+ vllm_parallel_config=parallel_config,
+ )
+ module.moe_config.moe_parallel_config = module.moe_parallel_config
+
+ # Update EPLB state
+ eplb_state = self.worker.model_runner.eplb_state
+ assert eplb_state is not None
+ model_config = self.worker.model_runner.model_config
+ eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
+
+ num_physical_experts = num_local_experts * new_ep_size
+ num_logical_experts = eplb_model_state.logical_replica_count.shape[1]
+ parallel_config.eplb_config.num_redundant_experts = (
+ num_physical_experts - num_logical_experts
+ )
+ old_physical_to_logical = eplb_model_state.physical_to_logical_map
+ num_moe_layers = old_physical_to_logical.shape[0]
+ num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size
+ if new_dp_size > old_dp_size:
+ expanded_physical_to_logical = torch.full(
+ (num_moe_layers, num_local_experts * new_ep_size),
+ -1,
+ dtype=old_physical_to_logical.dtype,
+ device=old_physical_to_logical.device,
+ )
+ expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = (
+ old_physical_to_logical
+ )
+ eplb_model_state.physical_to_logical_map = expanded_physical_to_logical
+
+ old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1]
+ pad_size = num_physical_experts - old_num_physical_experts
+ if new_dp_size > old_dp_size:
+ assert pad_size > 0
+ expanded_expert_load_pass = F.pad(
+ eplb_model_state.expert_load_pass, (0, pad_size), value=0
+ )
+ expanded_expert_load_window = F.pad(
+ eplb_model_state.expert_load_window, (0, pad_size), value=0
+ )
+ eplb_model_state.expert_load_pass = expanded_expert_load_pass
+ eplb_model_state.expert_load_window = expanded_expert_load_window
+ eplb_state.num_valid_physical_experts = old_num_physical_experts
+ else:
+ assert pad_size < 0
+ eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[
+ :, :num_physical_experts
+ ]
+ eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[
+ :, :, :num_physical_experts
+ ]
+ eplb_state.num_valid_physical_experts = num_physical_experts
+
+ model = self.worker.model_runner.get_model()
+ model.expert_weights = []
+ with set_current_vllm_config(self.worker.vllm_config):
+ model.set_eplb_state(
+ eplb_model_state.expert_load_pass,
+ eplb_model_state.logical_to_physical_map,
+ eplb_model_state.logical_replica_count,
+ )
+ model.update_physical_experts_metadata(
+ num_physical_experts=num_physical_experts,
+ num_local_physical_experts=num_local_experts,
+ )
+ # Force re-creation of the modular kernel (and all2all manager)
+ # for the new EP size by resetting quant_method to base
+ for module in moe_modules:
+ if hasattr(module.quant_method, "old_quant_method"):
+ module.quant_method = module.quant_method.old_quant_method
+ module.runner = module._init_runner()
+ prepare_communication_buffer_for_model(self.worker.model_runner.model)
+ if (
+ self.worker.vllm_config.compilation_config.mode
+ == CompilationMode.STOCK_TORCH_COMPILE
+ ):
+ # NOTE(yongji): when using stock torch.compile,
+ # torch.compile is triggered during GPUModelRunner's load_model()
+            # TODO(yongji): check whether we need to re-trigger torch.compile here;
+ # any changes to the tensor shapes in execution should already
+ # be handled internally by torch.compile.
+ backend = self.worker.vllm_config.compilation_config.init_backend(
+ self.worker.vllm_config
+ )
+ compilation_counter.stock_torch_compile_count += 1
+ self.worker.model_runner.model.compile(fullgraph=True, backend=backend)
+
+ # release all previously captured CUDA graphs
+ if isinstance(self.worker.model_runner.model, CUDAGraphWrapper):
+ wrapper = self.worker.model_runner.model
+ wrapper.concrete_cudagraph_entries = {}
+ elif isinstance(self.worker.model_runner.model, UBatchWrapper):
+ raise RuntimeError("DBO is not yet supported in elastic EP")
+
+ multi_block_table = self.worker.model_runner.input_batch.block_table
+ saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = []
+ for bt in multi_block_table.block_tables:
+ saved_block_tables.append(
+ (bt.block_table.gpu.clone(), bt.block_table.cpu.clone())
+ )
+ multi_block_table.clear()
+
+ # reset the compile wrapper
+ torch.compiler.reset()
+ with set_current_vllm_config(self.worker.vllm_config):
+ reset_compile_wrapper(self.worker.model_runner.get_model())
+
+ gc.collect()
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ unlock_workspace()
+ self.worker.compile_or_warm_up_model()
+ lock_workspace()
+
+ for bt, (saved_gpu, saved_cpu) in zip(
+ multi_block_table.block_tables, saved_block_tables
+ ):
+ bt.block_table.gpu.copy_(saved_gpu)
+ bt.block_table.cpu.copy_(saved_cpu)
+
+ def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None:
+ if get_ep_group().rank == 0:
+ logger.info("[Elastic EP] Starting expert resharding...")
+
+ eplb_state = self.worker.model_runner.eplb_state
+ assert eplb_state is not None
+
+ model_config = self.worker.model_runner.model_config
+ eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
+ is_async_enabled = eplb_state.is_async
+ eplb_state.is_async = False
+ if new_dp_size is None:
+ eplb_state.rearrange()
+ else:
+ # scale down
+ parallel_config = self.worker.vllm_config.parallel_config
+ tp_size = parallel_config.tensor_parallel_size
+ old_ep_size = parallel_config.data_parallel_size * tp_size
+ new_ep_size = new_dp_size * tp_size
+
+ rank_mapping = {
+ old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1
+ for old_ep_rank in range(old_ep_size)
+ }
+
+ eplb_state.rearrange(rank_mapping=rank_mapping)
+ # NOTE(yongji): check whether we need to synchronize here
+ torch.cuda.synchronize()
+ # reset expert_rearrangement_step to ensure all ranks are synchronized
+ eplb_state.expert_rearrangement_step = 0
+ eplb_state.num_valid_physical_experts = (
+ eplb_model_state.physical_to_logical_map.shape[1]
+ )
+ eplb_state.is_async = is_async_enabled
+ self.worker.model_runner.eep_eplb_suppressed = False
+ if get_ep_group().rank == 0:
+ logger.info("[Elastic EP] Expert resharding completed")
+
+ def receive_weights(self) -> None:
+ dp_group = get_dp_group()
+ assert isinstance(dp_group, StatelessGroupCoordinator)
+ new_dp_size = dp_group.world_size
+ dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank
+
+ # Receive old_dp_size broadcasted during transfer_weights
+ old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu")
+ old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0)
+ old_dp_size = int(old_dp_size_tensor[0].item())
+
+ # Calculate which existing worker will send to this new worker
+ num_new_workers = new_dp_size - old_dp_size
+ new_worker_idx = dp_rank - old_dp_size
+ num_dst_per_sender = num_new_workers // old_dp_size
+ remainder = num_new_workers % old_dp_size
+
+ if new_worker_idx < remainder * (num_dst_per_sender + 1):
+ sender_rank = new_worker_idx // (num_dst_per_sender + 1)
+ else:
+ sender_rank = (
+ remainder
+ + (new_worker_idx - remainder * (num_dst_per_sender + 1))
+ // num_dst_per_sender
+ )
+
+ model = self.worker.model_runner.get_model()
+ batch_transfer_weights(
+ model=model,
+ is_sender=False,
+ peer_rank=sender_rank,
+ dp_group=dp_group,
+ expert_weights=model.expert_weights,
+ )
+ torch.cuda.synchronize()
+
+ def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]:
+ dp_group = get_dp_group()
+ assert isinstance(dp_group, StatelessGroupCoordinator)
+ physical_to_logical, num_local_physical_experts, num_logical_experts = (
+ broadcast_expert_mapping(
+ physical_to_logical=None,
+ num_local_physical_experts=None,
+ num_logical_experts=None,
+ dp_group=dp_group,
+ src_rank=0,
+ device=self.worker.device,
+ )
+ )
+ num_moe_layers = physical_to_logical.shape[0]
+ new_dp_size = get_dp_group().world_size
+ tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size
+ new_ep_size = new_dp_size * tp_size
+ expanded_physical_to_logical = torch.full(
+ (num_moe_layers, num_local_physical_experts * new_ep_size),
+ -1,
+ dtype=physical_to_logical.dtype,
+ device=physical_to_logical.device,
+ )
+ old_num_physical_experts = physical_to_logical.shape[1]
+ expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical
+ return (
+ expanded_physical_to_logical,
+ num_logical_experts,
+ old_num_physical_experts,
+ )
+
+ def prepare_new_worker(self) -> None:
+ with set_current_vllm_config(self.worker.vllm_config):
+ prepare_communication_buffer_for_model(self.worker.model_runner.get_model())
diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py
new file mode 100644
index 0000000..4845a16
--- /dev/null
+++ b/vllm/distributed/elastic_ep/elastic_state.py
@@ -0,0 +1,563 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import enum
+import time
+import weakref
+from datetime import timedelta
+from typing import TYPE_CHECKING, Literal
+
+import torch.distributed
+
+from vllm.config import ParallelConfig
+from vllm.distributed import (
+ sched_yield,
+ stateless_destroy_torch_distributed_process_group,
+)
+from vllm.logger import init_logger
+from vllm.v1.engine import (
+ EEPNotificationType,
+ ReconfigureDistributedRequest,
+ ReconfigureRankType,
+)
+from vllm.v1.engine.core import DPEngineCoreProc
+
+if TYPE_CHECKING:
+ from vllm.config import VllmConfig
+ from vllm.v1.executor.abstract import Executor
+
+logger = init_logger(__name__)
+
+WorkerType = Literal["existing", "new", "removing"]
+
+
+class ScaleUpExistingEngineState(enum.IntEnum):
+ WAIT_NEW_CORE_ENGINES_INIT = 0
+ CREATE_STANDBY_GROUPS = 1
+ TRANSFER_EXPERT_MAPPING = 2
+ WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3
+ TRANSFER_WEIGHTS = 4
+ SYNC_KV_CACHE_MEMORY_SIZE = 5
+ SWITCH_AND_PREPARE = 6
+ EPLB_RESHUFFLE = 7
+ COMPLETE = 8
+
+
+class ScaleUpNewEngineState(enum.IntEnum):
+ PREPARE = 0
+ EPLB_RESHUFFLE = 1
+ COMPLETE = 2
+
+
+class ScaleDownRemainingEngineState(enum.IntEnum):
+ PREPARE = 0
+ EPLB_RESHUFFLE = 1
+ SWITCH_AND_PREPARE = 2
+ COMPLETE = 3
+
+
+class ScaleDownRemovingEngineState(enum.IntEnum):
+ PREPARE = 0
+ EPLB_RESHUFFLE = 1
+ COMPLETE = 2
+
+
+class _BarrierTimeoutError(RuntimeError):
+    """
+    Raised when the first stage of the two-staged TCPStore-based barrier
+    (used to synchronize the execution of all engines in the DP group)
+    times out.
+    """
+
+
+class ElasticEPScalingState:
+ def __init__(
+ self,
+ model_executor: "Executor",
+ engine_core: "DPEngineCoreProc",
+ vllm_config: "VllmConfig",
+ new_parallel_config: ParallelConfig,
+ worker_type: WorkerType,
+ scale_type: Literal["scale_up", "scale_down"],
+ reconfig_request: ReconfigureDistributedRequest | None = None,
+ ):
+ self.model_executor_ref = weakref.ref(model_executor)
+ self.engine_core_ref = weakref.ref(engine_core)
+ self.vllm_config = vllm_config
+ self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
+ self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
+ self.new_parallel_config: ParallelConfig = new_parallel_config
+ self.new_dp_group: torch.distributed.ProcessGroup | None = (
+ self.engine_core.dp_group if worker_type == "new" else None
+ )
+ self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
+ self.worker_type = worker_type
+ self.scale_type = scale_type
+ self.reconfig_request = reconfig_request
+
+ if scale_type == "scale_up":
+ self.state = (
+ ScaleUpNewEngineState.PREPARE
+ if worker_type == "new"
+ else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
+ )
+ else:
+ self.state = (
+ ScaleDownRemovingEngineState.PREPARE
+ if worker_type == "removing"
+ else ScaleDownRemainingEngineState.PREPARE
+ )
+
+ @property
+ def model_executor(self) -> "Executor":
+ model_executor = self.model_executor_ref()
+ if model_executor is None:
+ raise RuntimeError("Model executor has been garbage collected")
+ return model_executor
+
+ @property
+ def engine_core(self) -> "DPEngineCoreProc":
+ engine_core = self.engine_core_ref()
+ if engine_core is None:
+ raise RuntimeError("Engine core has been garbage collected")
+ return engine_core
+
+ def progress(self) -> bool:
+ if self.scale_type == "scale_up":
+ return (
+ self._progress_new_engine()
+ if self.worker_type == "new"
+ else self._progress_existing_engine()
+ )
+ return (
+ self._progress_removing_engine()
+ if self.worker_type == "removing"
+ else self._progress_remaining_engine()
+ )
+
+ def _execute_tcp_store_barrier(
+ self, dp_store, group_rank, group_size, barrier_id, timeout=None
+ ):
+ arrival_key = f"arrival_{barrier_id}_{group_rank}"
+ dp_store.set(arrival_key, b"1")
+
+ start_time = time.time()
+ processes_arrived: set[int] = set()
+
+ while len(processes_arrived) < group_size:
+ if (
+ timeout is not None
+ and time.time() - start_time > timeout.total_seconds()
+ ):
+ raise _BarrierTimeoutError(
+ f"Barrier timed out after {timeout.total_seconds()} seconds"
+ )
+
+ for i in range(group_size):
+ if i in processes_arrived:
+ continue
+
+ key = f"arrival_{barrier_id}_{i}"
+ present = dp_store.check([key])
+ if present:
+ processes_arrived.add(i)
+
+ if len(processes_arrived) < group_size:
+ sched_yield()
+
+ def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
+ """
+ Execute a two-staged barrier to synchronize all engines in the DP group.
+
+        Some DP EngineCores may receive the reconfiguration notification later
+        than others and may have already proceeded to an engine step (model
+        forward) in the busy loop. In that case, the EngineCores that have
+        already reached this barrier should postpone reconfiguration and run
+        model forward for one more step, so that all EngineCores are
+        synchronized at the next step.
+        We use a two-staged barrier to achieve this: the first time each
+        EngineCore executes the barrier, hitting the timeout before the barrier
+        completes means that some EngineCores have already entered the engine
+        step. The EngineCores that timed out then proceed to the engine step
+        and synchronize with the other EngineCores in the next step using a
+        barrier without a timeout.
+ """
+ dp_store = self.new_dp_store if use_new_group else self.old_dp_store
+ dp_group = self.new_dp_group if use_new_group else self.old_dp_group
+ assert dp_group is not None
+
+ group_rank = dp_group.rank()
+ group_size = dp_group.size()
+ barrier_id = f"eep_barrier_{barrier_name}"
+ sync_key = f"{barrier_id}_sync"
+
+ # TODO(yongji): figure out appropriate timeout for the barrier
+ timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)
+
+ try:
+ self._execute_tcp_store_barrier(
+ dp_store, group_rank, group_size, barrier_id, timeout=timeout
+ )
+ torch.distributed.barrier(dp_group)
+ if group_rank == 0:
+ dp_store.delete_key(sync_key)
+ for i in range(group_size):
+ dp_store.delete_key(f"arrival_{barrier_id}_{i}")
+ return True
+ except _BarrierTimeoutError as e:
+ if timeout is None:
+ raise RuntimeError("Unexpected timeout encountered") from e
+ dp_store.compare_set(sync_key, "", b"1")
+ return False
+
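Editorial sketch, not part of the patch: the caller-side retry pattern described in the docstring; run_one_engine_step is a hypothetical stand-in for one iteration of the busy-loop forward step:

def progress_with_staged_barrier(scaling_state, run_one_engine_step) -> None:
    # First attempt may time out (returns False) if peers are still inside
    # their current engine step.
    while not scaling_state._staged_barrier(
        use_new_group=False, barrier_name="create_standby_groups"
    ):
        # Run one more forward step; on the next attempt every EngineCore has
        # reached the barrier, the sync key is set, and the wait has no timeout.
        run_one_engine_step()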
+ def _progress_existing_engine(self) -> bool:
+ state = self.state
+
+ if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
+ return False
+
+ elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
+ # NOTE(yongji): wait for all existing workers to receive the request
+ if (
+ int(self.old_dp_store.get("eep_barrier_engine_count"))
+ < self.old_dp_group.size()
+ ):
+ return False
+ if not self._staged_barrier(
+ use_new_group=False, barrier_name="create_standby_groups"
+ ):
+ return False
+ if self.old_dp_group.rank() == 0:
+ self.old_dp_store.delete_key("eep_barrier_engine_count")
+ self._create_standby_groups()
+ self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
+ return True
+
+ elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
+ self._transfer_expert_mapping()
+ self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
+ return True
+
+ elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
+ return False
+
+ elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
+ if (
+ int(self.old_dp_store.get("eep_barrier_engine_count"))
+ < self.old_dp_group.size()
+ ):
+ return False
+ if not self._staged_barrier(
+ use_new_group=False, barrier_name="transfer_weights"
+ ):
+ return False
+ if self.old_dp_group.rank() == 0:
+ self.old_dp_store.delete_key("eep_barrier_engine_count")
+ self._transfer_weights()
+ self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
+ return True
+
+ elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
+ self._sync_kv_cache_memory_size()
+ self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
+ return True
+
+ elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
+ self._switch_and_prepare()
+ self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
+ self.new_dp_store.add("eep_barrier_engine_count", 1)
+ return True
+
+ elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
+ assert self.new_dp_group is not None
+ if (
+ int(self.new_dp_store.get("eep_barrier_engine_count"))
+ < self.new_dp_group.size()
+ ):
+ return False
+ if not self._staged_barrier(
+ use_new_group=True, barrier_name="eplb_reshuffle"
+ ):
+ return False
+ if self.new_dp_group.rank() == 0:
+ self.new_dp_store.delete_key("eep_barrier_engine_count")
+ self._eplb_reshuffle()
+ self.state = ScaleUpExistingEngineState.COMPLETE
+ self._update_parallel_config()
+ return True
+
+ else:
+ assert self.state == ScaleUpExistingEngineState.COMPLETE
+ return True
+
+ def _progress_new_engine(self) -> bool:
+ state = self.state
+ assert self.new_dp_group is not None
+
+ if state == ScaleUpNewEngineState.PREPARE:
+ tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
+ torch.distributed.all_reduce(
+ tensor,
+ op=torch.distributed.ReduceOp.MAX,
+ group=self.new_dp_group,
+ )
+ data = tensor.tolist()
+ self.engine_core.engines_running = bool(data[0])
+ self.engine_core.current_wave = int(data[1])
+ self.engine_core.step_counter = int(data[2])
+ self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
+ self.new_dp_store.add("eep_barrier_engine_count", 1)
+ return True
+
+ elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
+ if (
+ int(self.new_dp_store.get("eep_barrier_engine_count"))
+ < self.new_dp_group.size()
+ ):
+ return False
+ if not self._staged_barrier(
+ use_new_group=True, barrier_name="eplb_reshuffle"
+ ):
+ return False
+ assert self.new_dp_group.rank() > 0
+ self._eplb_reshuffle()
+ self.state = ScaleUpNewEngineState.COMPLETE
+ return True
+
+ else:
+ assert self.state == ScaleUpNewEngineState.COMPLETE
+ return True
+
+ def _progress_remaining_engine(self) -> bool:
+ state = self.state
+
+ if state == ScaleDownRemainingEngineState.PREPARE:
+ self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
+ self.old_dp_store.add("eep_barrier_engine_count", 1)
+ return True
+
+ elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
+ if (
+ int(self.old_dp_store.get("eep_barrier_engine_count"))
+ < self.old_dp_group.size()
+ ):
+ return False
+ if not self._staged_barrier(
+ use_new_group=False, barrier_name="eplb_reshuffle"
+ ):
+ return False
+ if self.old_dp_group.rank() == 0:
+ self.old_dp_store.delete_key("eep_barrier_engine_count")
+ self._eplb_reshuffle_before_scale_down()
+ self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
+ # NOTE(yongji): currently, after EPLB reshuffle
+ # that redistributes experts to remaining workers, workers
+ # to be removed will immediately initiate shutdown;
+ # existing workers can no longer execute forward steps using
+ # the old setup. In the future, we may keep
+ # the removing workers alive a bit longer,
+ # e.g., to drain in-batch requests.
+ self._create_standby_groups()
+ self._switch_and_prepare()
+ self._update_parallel_config()
+ self.state = ScaleDownRemainingEngineState.COMPLETE
+ return True
+
+ else:
+ assert self.state == ScaleDownRemainingEngineState.COMPLETE
+ return True
+
+ def _progress_removing_engine(self) -> bool:
+ state = self.state
+
+ if state == ScaleDownRemovingEngineState.PREPARE:
+ self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
+ self.old_dp_store.add("eep_barrier_engine_count", 1)
+ return True
+
+ if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
+ if (
+ int(self.old_dp_store.get("eep_barrier_engine_count"))
+ < self.old_dp_group.size()
+ ):
+ return False
+ if not self._staged_barrier(
+ use_new_group=False, barrier_name="eplb_reshuffle"
+ ):
+ return False
+ assert self.old_dp_group.rank() > 0
+ self._eplb_reshuffle_before_scale_down()
+ self._switch_and_remove()
+ self.state = ScaleDownRemovingEngineState.COMPLETE
+ self.engine_core._eep_send_engine_core_notification(
+ EEPNotificationType.SHUTDOWN_COMPLETE
+ )
+ self.engine_core.shutdown()
+ return True
+
+ else:
+ assert self.state == ScaleDownRemovingEngineState.COMPLETE
+ return True
+
+ def handle_notification(self, notification_type: EEPNotificationType):
+ assert self.worker_type != "new"
+ if (
+ notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
+ and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
+ ):
+ self.old_dp_store.add("eep_barrier_engine_count", 1)
+ self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
+ elif (
+ notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
+ and self.state
+ == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
+ ):
+ self.old_dp_store.add("eep_barrier_engine_count", 1)
+ self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS
+
+ def is_complete(self) -> bool:
+ if self.scale_type == "scale_up":
+ return (
+ self.state == ScaleUpNewEngineState.COMPLETE
+ if self.worker_type == "new"
+ else self.state == ScaleUpExistingEngineState.COMPLETE
+ )
+ return (
+ self.state == ScaleDownRemovingEngineState.COMPLETE
+ if self.worker_type == "removing"
+ else self.state == ScaleDownRemainingEngineState.COMPLETE
+ )
+
+ def _create_standby_groups(self):
+ self.new_dp_group, self.new_dp_store = (
+ self.new_parallel_config.stateless_init_dp_group(return_store=True)
+ )
+ self.model_executor.collective_rpc(
+ "elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
+ )
+ if self.old_dp_group.rank() == 0:
+ logger.info("[Elastic EP] Created standby communication groups")
+
+ def _transfer_weights(self):
+ assert self.reconfig_request is not None
+ old_dp_size = self.old_dp_group.size()
+ new_dp_size = self.reconfig_request.new_data_parallel_size
+
+ self.model_executor.collective_rpc(
+ "elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
+ )
+ if self.old_dp_group.rank() == 0:
+ logger.info("[Elastic EP] Transferred weights to new workers")
+
+ def _transfer_expert_mapping(self):
+ self.model_executor.collective_rpc(
+ "elastic_ep_execute", args=("broadcast_expert_mapping",)
+ )
+ if self.old_dp_group.rank() == 0:
+ logger.info("[Elastic EP] Broadcasted expert mapping to new workers")
+
+ def _sync_kv_cache_memory_size(self):
+ assert self.engine_core.available_gpu_memory_for_kv_cache > 0
+ assert self.new_dp_group is not None
+ ParallelConfig.sync_kv_cache_memory_size(
+ self.new_dp_group,
+ self.engine_core.available_gpu_memory_for_kv_cache,
+ )
+ if self.old_dp_group.rank() == 0:
+ logger.info("[Elastic EP] Synced KV cache memory size to new workers")
+
+ def _switch_and_prepare(self):
+ self.model_executor.collective_rpc(
+ "elastic_ep_execute", args=("switch_and_prepare",)
+ )
+ old_dp_group = self.old_dp_group
+ stateless_destroy_torch_distributed_process_group(old_dp_group)
+ assert self.new_dp_group is not None
+ new_dp_group = self.new_dp_group
+ self.engine_core.dp_group = new_dp_group
+ self.engine_core.dp_rank = new_dp_group.rank()
+ self.engine_core.dp_store = self.new_dp_store
+ engines_running = int(self.engine_core.engines_running)
+ current_wave = self.engine_core.current_wave
+ step_counter = self.engine_core.step_counter
+ tensor = torch.tensor(
+ [engines_running, current_wave, step_counter],
+ dtype=torch.int32,
+ device="cpu",
+ )
+ torch.distributed.all_reduce(
+ tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
+ )
+ data = tensor.tolist()
+ self.engine_core.engines_running = bool(data[0])
+ self.engine_core.current_wave = int(data[1])
+ self.engine_core.step_counter = int(data[2])
+ if new_dp_group.rank() == 0:
+ self.engine_core._eep_send_engine_core_notification(
+ EEPNotificationType.RECONFIGURE_FINISHED
+ )
+ logger.info("[Elastic EP] Switched to new setup")
+
+ def _eplb_reshuffle(self):
+ self.model_executor.collective_rpc(
+ "elastic_ep_execute", args=("perform_eplb_reshuffle",)
+ )
+ assert self.new_dp_group is not None
+ if self.new_dp_group.rank() == 0:
+ logger.info("[Elastic EP] EPLB reshuffle completed")
+
+ def _eplb_reshuffle_before_scale_down(self):
+ assert self.reconfig_request is not None
+ self.model_executor.collective_rpc(
+ "elastic_ep_execute",
+ args=(
+ "perform_eplb_reshuffle",
+ self.reconfig_request.new_data_parallel_size,
+ ),
+ )
+ if self.old_dp_group.rank() == 0:
+ logger.info("[Elastic EP] EPLB reshuffle completed")
+
+ def _switch_and_remove(self):
+ self.model_executor.collective_rpc(
+ "elastic_ep_execute", args=("switch_and_remove",)
+ )
+
+ def _update_parallel_config(self):
+ assert self.reconfig_request is not None
+ reconfig_request = self.reconfig_request
+ parallel_config = self.vllm_config.parallel_config
+ parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
+ if (
+ reconfig_request.new_data_parallel_rank
+ != ReconfigureRankType.KEEP_CURRENT_RANK
+ ):
+ parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
+ if (
+ reconfig_request.new_data_parallel_rank_local
+ != ReconfigureRankType.KEEP_CURRENT_RANK
+ ):
+ parallel_config.data_parallel_rank_local = (
+ reconfig_request.new_data_parallel_rank_local
+ )
+ parallel_config.data_parallel_master_ip = (
+ reconfig_request.new_data_parallel_master_ip
+ )
+ parallel_config.data_parallel_master_port = (
+ reconfig_request.new_data_parallel_master_port
+ )
+ parallel_config._data_parallel_master_port_list = (
+ reconfig_request.new_data_parallel_master_port_list
+ )
+ parallel_config._stateless_world_group_port_list = (
+ reconfig_request.new_stateless_world_group_port_list
+ )
+ parallel_config._stateless_dp_group_port_list = (
+ reconfig_request.new_stateless_dp_group_port_list
+ )
+ parallel_config._stateless_ep_group_port_list = (
+ reconfig_request.new_stateless_ep_group_port_list
+ )
+ parallel_config._stateless_eplb_group_port_list = (
+ reconfig_request.new_stateless_eplb_group_port_list
+ )
diff --git a/vllm/distributed/elastic_ep/standby_state.py b/vllm/distributed/elastic_ep/standby_state.py
new file mode 100644
index 0000000..d11e0b5
--- /dev/null
+++ b/vllm/distributed/elastic_ep/standby_state.py
@@ -0,0 +1,117 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+from vllm.distributed.parallel_state import (
+ _init_stateless_group,
+ _node_count,
+ get_pp_group,
+ get_tp_group,
+ get_world_group,
+)
+from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
+
+_STANDBY_WORLD: StatelessGroupCoordinator | None = None
+_STANDBY_WORLD_NODE_COUNT: int | None = None
+_STANDBY_DP: StatelessGroupCoordinator | None = None
+_STANDBY_EP: StatelessGroupCoordinator | None = None
+_STANDBY_EPLB: StatelessGroupCoordinator | None = None
+
+
+def get_standby_dp_group() -> StatelessGroupCoordinator | None:
+ return _STANDBY_DP
+
+
+def get_standby_ep_group() -> StatelessGroupCoordinator | None:
+ return _STANDBY_EP
+
+
+def get_standby_eplb_group() -> StatelessGroupCoordinator | None:
+ return _STANDBY_EPLB
+
+
+def get_standby_world_group() -> StatelessGroupCoordinator | None:
+ return _STANDBY_WORLD
+
+
+def create_standby_groups(
+ new_dp_size: int,
+ new_world_size_across_dp: int,
+ master_ip: str,
+ world_group_ports: list[list[int]],
+ dp_group_ports: list[list[int]],
+ ep_group_ports: list[list[int]],
+ eplb_group_ports: list[list[int]] | None = None,
+ backend: str | None = None,
+) -> None:
+ global \
+ _STANDBY_WORLD, \
+ _STANDBY_WORLD_NODE_COUNT, \
+ _STANDBY_DP, \
+ _STANDBY_EP, \
+ _STANDBY_EPLB
+
+ assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size
+ world_group = get_world_group()
+ assert isinstance(world_group, StatelessGroupCoordinator)
+ backend = backend or world_group.backend
+
+ standby_world_ranks = [list(range(new_world_size_across_dp))]
+ _STANDBY_WORLD = _init_stateless_group(
+ standby_world_ranks,
+ "world",
+ world_group_ports,
+ master_ip,
+ backend,
+ use_device_communicator=False,
+ )
+ _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group)
+
+ tp_size = get_tp_group().world_size
+ pp_size = get_pp_group().world_size
+
+ all_ranks = torch.arange(new_world_size_across_dp).reshape(
+ -1, new_dp_size, pp_size, tp_size
+ )
+ standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
+ standby_dp_ranks = [x.tolist() for x in standby_dp_ranks]
+ _STANDBY_DP = _init_stateless_group(
+ standby_dp_ranks, "dp", dp_group_ports, master_ip, backend
+ )
+
+ standby_ep_ranks = (
+ all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
+ )
+ standby_ep_ranks = [x.tolist() for x in standby_ep_ranks]
+ _STANDBY_EP = _init_stateless_group(
+ standby_ep_ranks, "ep", ep_group_ports, master_ip, backend
+ )
+
+ if eplb_group_ports is not None:
+ _STANDBY_EPLB = _init_stateless_group(
+ standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend
+ )
+
+
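Editorial illustration, not part of the patch: how the rank regrouping above partitions ranks, assuming new_dp_size=2, pp_size=1, tp_size=2 (so new_world_size_across_dp=4):

import torch

new_dp_size, pp_size, tp_size = 2, 1, 2
new_world_size_across_dp = new_dp_size * pp_size * tp_size  # 4

all_ranks = torch.arange(new_world_size_across_dp).reshape(
    -1, new_dp_size, pp_size, tp_size
)
dp_groups = [
    x.tolist() for x in all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0)
]
ep_groups = [
    x.tolist()
    for x in all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0)
]

# Each DP group collects the ranks that share a TP index; the EP group spans
# all DP * TP ranks.
assert dp_groups == [[0, 2], [1, 3]]
assert ep_groups == [[0, 1, 2, 3]]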
+def pop_standby_groups() -> dict:
+ """Return all standby groups and clear the standby state."""
+ global \
+ _STANDBY_WORLD, \
+ _STANDBY_WORLD_NODE_COUNT, \
+ _STANDBY_DP, \
+ _STANDBY_EP, \
+ _STANDBY_EPLB
+
+ result = dict(
+ world=_STANDBY_WORLD,
+ dp=_STANDBY_DP,
+ ep=_STANDBY_EP,
+ eplb=_STANDBY_EPLB,
+ node_count=_STANDBY_WORLD_NODE_COUNT,
+ )
+ _STANDBY_WORLD = None
+ _STANDBY_WORLD_NODE_COUNT = None
+ _STANDBY_DP = None
+ _STANDBY_EP = None
+ _STANDBY_EPLB = None
+ return result
diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py
index b81c7fa..5dd862f 100644
--- a/vllm/distributed/eplb/async_worker.py
+++ b/vllm/distributed/eplb/async_worker.py
@@ -24,7 +24,6 @@ logger = init_logger(__name__)
def start_async_worker(
state: "EplbState",
- rank_mapping: dict[int, int] | None = None,
is_profile: bool = False,
) -> threading.Thread:
eplb_group = get_eplb_group().device_group
@@ -45,7 +44,6 @@ def start_async_worker(
eplb_group=eplb_group,
cuda_stream=cuda_stream,
is_profile=is_profile,
- rank_mapping=rank_mapping,
)
)
except Exception as exc: # pragma: no cover - diagnostic path
@@ -107,7 +105,6 @@ async def transfer_run_periodically(
eplb_group: ProcessGroup,
cuda_stream: torch.cuda.Stream,
is_profile: bool = False,
- rank_mapping: dict[int, int] | None = None,
) -> None:
while True:
await asyncio.to_thread(state.rearrange_event.wait)
@@ -176,7 +173,6 @@ async def transfer_run_periodically(
ep_group=eplb_group,
is_profile=is_profile,
cuda_stream=cuda_stream,
- rank_mapping=rank_mapping,
)
event = torch.cuda.Event(blocking=False)
cuda_stream.record_event(event)
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 7c3701b..b417c2b 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import (
get_node_count,
in_the_same_node_as,
)
+from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
@@ -159,7 +160,7 @@ class EplbModelState:
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
- across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
+ across different dispatch methods (naive all-to-all, DeepEP).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
@@ -302,6 +303,14 @@ class EplbState:
"""
CUDA device index for the async EPLB worker thread.
"""
+ self.num_valid_physical_experts: int = 0
+ """
+ Number of valid physical experts.
+ This is the number of physical experts that are
+ actually mapped to logical experts. In elastic EP,
+ newly started EP ranks may not have physical experts
+ mapped yet.
+ """
if self.device.type == "cuda":
self.cuda_device_index = self.device.index
if self.cuda_device_index is None and torch.cuda.is_available():
@@ -367,9 +376,6 @@ class EplbState:
self,
model: MixtureOfExperts,
model_config: ModelConfig,
- global_expert_load: torch.Tensor | None = None,
- old_global_expert_indices: torch.Tensor | None = None,
- rank_mapping: dict[int, int] | None = None,
):
"""
Build the initial EPLB state.
@@ -462,75 +468,15 @@ class EplbState:
)
self.expert_rearrangement_step_interval = eplb_step_interval
- # Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %s", policy_type)
- if global_expert_load is not None:
- ep_group = get_ep_group().device_group
- assert global_expert_load.shape == (
- model.num_moe_layers,
- model.num_logical_experts,
- )
- assert global_expert_load.dtype == torch.int64
- num_replicas = model.num_physical_experts
- num_groups = model.num_expert_groups
- num_nodes = get_node_count()
- num_gpus = ep_group.size()
-
- if num_gpus % num_nodes != 0:
- num_nodes = 1
- logger.warning_once(
- f"num_gpus % num_nodes != 0, "
- "not using hierarchical rearrangement algorithm.\n"
- f"{num_gpus=}, {num_nodes=}"
- )
-
- # Get new expert mappings
- (
- new_physical_to_logical_map,
- new_logical_to_physical_map,
- new_logical_replica_count,
- ) = self.policy.rebalance_experts(
- global_expert_load,
- num_replicas,
- num_groups,
- num_nodes,
- num_gpus,
- )
-
- max_physical_slots = new_logical_to_physical_map.shape[-1]
- assert max_physical_slots <= logical_to_physical_map.shape[-1]
- new_logical_to_physical_map = torch.nn.functional.pad(
- new_logical_to_physical_map,
- (0, logical_to_physical_map.shape[-1] - max_physical_slots),
- value=-1,
- )
- physical_to_logical_map = new_physical_to_logical_map.to(self.device)
- logical_to_physical_map.copy_(new_logical_to_physical_map)
- logical_replica_count.copy_(new_logical_replica_count)
- else:
- new_physical_to_logical_map = None
-
- new_logical_to_physical_map = None
-
- new_logical_replica_count = None
model.set_eplb_state(
expert_load_pass,
logical_to_physical_map,
logical_replica_count,
)
- if global_expert_load is not None:
- rearrange_expert_weights_inplace(
- old_global_expert_indices,
- new_physical_to_logical_map,
- model.expert_weights,
- ep_group,
- False,
- rank_mapping,
- )
- self.expert_rearrangement_step = 0
expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]]
@@ -561,11 +507,12 @@ class EplbState:
recv_dst_rows=np.array([]),
),
cuda_device_index=self.cuda_device_index,
- new_physical_to_logical_map=new_physical_to_logical_map,
- new_logical_to_physical_map=new_logical_to_physical_map,
- new_logical_replica_count=new_logical_replica_count,
+ new_physical_to_logical_map=None,
+ new_logical_to_physical_map=None,
+ new_logical_replica_count=None,
)
self.model_states[model_config.compute_hash()] = model_state
+ self.num_valid_physical_experts = model.num_physical_experts
def step(
self,
@@ -696,8 +643,6 @@ class EplbState:
def rearrange(
self,
is_profile: bool = False,
- execute_shuffle: bool = True,
- global_expert_loads: list[torch.Tensor] | None = None,
rank_mapping: dict[int, int] | None = None,
) -> torch.Tensor | None:
"""
@@ -707,12 +652,6 @@ class EplbState:
is_profile (bool): If `True`, perform a dummy rearrangement.
This is used in `profile_run` to reserve enough memory,
no memory movement will be performed. Default is False.
- execute_shuffle (bool): If `True`, execute the shuffle
- in elastic expert parallel (EEP). Default is True.
- global_expert_loads (list[torch.Tensor] | None): The global expert
- loads when scaling is done in EEP.
- List of expert loads for the main and drafter
- (when spec decode is used) models.
rank_mapping (dict[int, int] | None): The rank mapping
when scaling is done in EEP.
"""
@@ -734,67 +673,34 @@ class EplbState:
"(profile)" if is_profile else "",
)
- if global_expert_loads is None:
- # Map the physical expert load to global logical experts
- global_expert_load_windows = []
- if not execute_shuffle:
- num_models = torch.tensor(
- [len(self.model_states)], dtype=torch.int32, device="cpu"
- )
- torch.distributed.broadcast(
- num_models, group=get_ep_group().cpu_group, group_src=0
- )
-
- for eplb_model_state in self.model_states.values():
- logical_expert_load_window = torch.zeros(
- self.expert_load_window_size,
- eplb_model_state.model.num_moe_layers,
- eplb_model_state.model.num_logical_experts,
- dtype=eplb_model_state.expert_load_window.dtype,
- device=eplb_model_state.expert_load_window.device,
- )
- logical_expert_load_window.scatter_add_(
- dim=-1,
- index=eplb_model_state.physical_to_logical_map.unsqueeze(0)
- .expand_as(eplb_model_state.expert_load_window)
- .long(),
- src=eplb_model_state.expert_load_window,
- )
-
- if not execute_shuffle:
- metadata = torch.tensor(
- [
- eplb_model_state.model.num_moe_layers,
- eplb_model_state.model.num_logical_experts,
- eplb_model_state.physical_to_logical_map.shape[1],
- ],
- dtype=torch.int32,
- device="cpu",
- )
- torch.distributed.broadcast(
- metadata, group=get_ep_group().cpu_group, group_src=0
- )
-
- global_expert_load_window = logical_expert_load_window.sum(dim=0)
- global_expert_load_windows.append(global_expert_load_window)
- # Perform all-reduce to get the expert load across all ranks for each model
- global_expert_load_windows = self._allreduce_list(
- global_expert_load_windows
+ # Map the physical expert load to global logical experts
+ global_expert_load_windows = []
+ for eplb_model_state in self.model_states.values():
+ expert_load_window = eplb_model_state.expert_load_window[
+ :, :, : self.num_valid_physical_experts
+ ]
+ logical_expert_load_window = torch.zeros(
+ self.expert_load_window_size,
+ eplb_model_state.model.num_moe_layers,
+ eplb_model_state.model.num_logical_experts,
+ dtype=eplb_model_state.expert_load_window.dtype,
+ device=eplb_model_state.expert_load_window.device,
)
- if not execute_shuffle:
- for eplb_model_state, global_expert_load_window in zip(
- self.model_states.values(), global_expert_load_windows
- ):
- # (num_moe_layers, old_num_physical_experts)
- old_global_expert_indices = eplb_model_state.physical_to_logical_map
- torch.distributed.broadcast(
- old_global_expert_indices, group=ep_group, group_src=0
- )
- if not execute_shuffle:
- return global_expert_load_windows
- else:
- assert execute_shuffle
- global_expert_load_windows = global_expert_loads
+ logical_expert_load_window.scatter_add_(
+ dim=-1,
+ index=eplb_model_state.physical_to_logical_map[
+ :, : self.num_valid_physical_experts
+ ]
+ .unsqueeze(0)
+ .expand_as(expert_load_window)
+ .long(),
+ src=expert_load_window,
+ )
+
+ global_expert_load_window = logical_expert_load_window.sum(dim=0)
+ global_expert_load_windows.append(global_expert_load_window)
+ # Perform all-reduce to get the expert load across all ranks for each model
+ global_expert_load_windows = self._allreduce_list(global_expert_load_windows)
# TODO(bowen): Treat differently for prefill and decode nodes
eplb_model_state = next(iter(self.model_states.values()))
@@ -806,8 +712,10 @@ class EplbState:
# NOTE(yongji): scale down, we need to rebalance the experts on
# remaining GPUs, transfer the experts while we haven't shutdown
# the GPUs to be released.
- cpu_group = get_ep_group().cpu_group
- num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping)
+ coordinator = get_ep_group()
+ assert isinstance(coordinator, StatelessGroupCoordinator)
+ tcp_store_group = coordinator.tcp_store_group
+ num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping)
num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values())
num_replicas = (
num_replicas // ep_group.size() * num_gpus
@@ -933,7 +841,6 @@ class EplbState:
if self.async_worker is None:
self.async_worker = start_async_worker(
self,
- rank_mapping=rank_mapping,
is_profile=is_profile,
)
@@ -1089,83 +996,6 @@ class EplbState:
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None
- @staticmethod
- def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]:
- """
- Receive the expert load and old placement from the master rank.
- """
- ep_group = get_ep_group()
- num_models = torch.empty(1, dtype=torch.int32, device="cpu")
- torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0)
- num_models = num_models.item()
- global_expert_loads = []
- old_global_expert_indices_per_model = []
- for _ in range(num_models):
- metadata = torch.empty(3, dtype=torch.int32, device="cpu")
- torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0)
- num_moe_layers, num_logical_experts, num_old_physical_experts = (
- metadata.tolist()
- )
- global_expert_load = torch.zeros(
- (num_moe_layers, num_logical_experts),
- dtype=torch.int64,
- device=ep_group.device,
- )
- all_reduce(global_expert_load, group=ep_group.device_group)
- old_global_expert_indices = torch.empty(
- (num_moe_layers, num_old_physical_experts),
- dtype=torch.int64,
- device=ep_group.device,
- )
- torch.distributed.broadcast(
- old_global_expert_indices,
- group=ep_group.device_group,
- group_src=0,
- )
- global_expert_loads.append(global_expert_load)
- old_global_expert_indices_per_model.append(old_global_expert_indices)
- return global_expert_loads, old_global_expert_indices_per_model
-
- @classmethod
- def get_eep_state(
- cls, parallel_config: ParallelConfig
- ) -> tuple[
- list[torch.Tensor] | None,
- list[torch.Tensor] | None,
- dict[int, int] | None,
- ]:
- num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu")
- torch.distributed.broadcast(
- num_local_physical_experts,
- group=get_ep_group().cpu_group,
- group_src=0,
- )
- num_local_physical_experts = int(num_local_physical_experts.item())
- new_ep_size = get_ep_group().world_size
- global_expert_loads, old_global_expert_indices_per_model = (
- EplbState.recv_state()
- )
-
- # EP configuration for all models has to be the same so as eplb config
- num_logical_experts = global_expert_loads[0].shape[1]
- parallel_config.eplb_config.num_redundant_experts = (
- num_local_physical_experts * new_ep_size - num_logical_experts
- )
- assert (
- old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts
- == 0
- )
- old_ep_size = (
- old_global_expert_indices_per_model[0].shape[1]
- // num_local_physical_experts
- )
- rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)}
- return (
- global_expert_loads,
- old_global_expert_indices_per_model,
- rank_mapping,
- )
-
def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]:
"""
All-reduce a list of tensors.
@@ -1203,6 +1033,60 @@ class EplbState:
load_pass_list.append(eplb_model_state.expert_load_pass.clone())
return self._allreduce_list(load_pass_list)
+ @classmethod
+ def from_mapping(
+ cls,
+ model: MixtureOfExperts,
+ model_config: ModelConfig,
+ device: torch.device,
+ parallel_config: ParallelConfig,
+ expanded_physical_to_logical: torch.Tensor,
+ num_valid_physical_experts: int,
+ ) -> "EplbState":
+ eplb_state = cls(
+ parallel_config=parallel_config,
+ device=device,
+ )
+ eplb_state.add_model(
+ model=model,
+ model_config=model_config,
+ )
+ eplb_state.num_valid_physical_experts = num_valid_physical_experts
+ num_moe_layers = expanded_physical_to_logical.shape[0]
+ num_physical_experts = expanded_physical_to_logical.shape[1]
+ eplb_model_state = eplb_state.model_states[model_config.compute_hash()]
+ eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical)
+
+ logical_to_physical_map = torch.full(
+ (
+ num_moe_layers,
+ model.num_logical_experts,
+ eplb_model_state.logical_to_physical_map.shape[2],
+ ),
+ -1,
+ dtype=torch.int64,
+ )
+ logical_replica_count = torch.zeros(
+ (num_moe_layers, model.num_logical_experts),
+ dtype=torch.int64,
+ )
+ expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy()
+ for layer_idx in range(num_moe_layers):
+ for phys_idx in range(num_physical_experts):
+ logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx]
+ if logical_idx >= 0:
+ replica_idx = logical_replica_count[layer_idx, logical_idx]
+ logical_to_physical_map[layer_idx, logical_idx, replica_idx] = (
+ phys_idx
+ )
+ logical_replica_count[layer_idx, logical_idx] += 1
+
+ logical_to_physical_map = logical_to_physical_map.to(device)
+ logical_replica_count = logical_replica_count.to(device)
+ eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map)
+ eplb_model_state.logical_replica_count.copy_(logical_replica_count)
+ return eplb_state
+
@dataclass
class EplbLayerState:
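
For reference, the inversion performed by `from_mapping` above can be illustrated with a
small standalone sketch (the single layer, four physical slots, and the replica cap of 2
are assumed example values, not taken from the patch):

    import torch

    # One MoE layer, 4 physical slots hosting 3 logical experts; expert 0 is replicated.
    physical_to_logical = torch.tensor([[0, 1, 2, 0]])  # (num_layers, num_physical)
    num_logical, max_replicas = 3, 2

    logical_to_physical = torch.full((1, num_logical, max_replicas), -1, dtype=torch.int64)
    replica_count = torch.zeros((1, num_logical), dtype=torch.int64)

    for layer in range(physical_to_logical.shape[0]):
        for phys in range(physical_to_logical.shape[1]):
            logical = int(physical_to_logical[layer, phys])
            if logical >= 0:
                slot = int(replica_count[layer, logical])
                logical_to_physical[layer, logical, slot] = phys
                replica_count[layer, logical] += 1

    # logical_to_physical[0] -> [[0, 3], [1, -1], [2, -1]]; replica_count[0] -> [2, 1, 1]
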
diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py
index 1be1e24..777f9c5 100644
--- a/vllm/distributed/eplb/rebalance_execute.py
+++ b/vllm/distributed/eplb/rebalance_execute.py
@@ -19,6 +19,8 @@ from torch.distributed import (
get_global_rank,
)
+from vllm.distributed.parallel_state import get_ep_group
+from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -249,10 +251,18 @@ def move_to_buffer(
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
+    coordinator = get_ep_group()
+    if isinstance(coordinator, StatelessGroupCoordinator):
+        # The stateless coordinator handles its own batched P2P via its
+        # device_communicator, so use it directly as the EP group here.
+        ep_group = coordinator
+        is_stateless = True
+    else:
+        is_stateless = False
- # Pre-compute global ranks mapping
+ # Pre-compute global ranks mapping (only needed for non-stateless groups)
ep_size = ep_group.size()
- rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
+ if not is_stateless:
+ rank_to_global = {
+ rank: get_global_rank(ep_group, rank) for rank in range(ep_size)
+ }
# 2. Post sends
if send_count > 0:
@@ -284,15 +294,23 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
- dst_global = rank_to_global[dst]
- p2p_ops += [
- P2POp(
- torch.distributed.isend,
- w[src],
- dst_global,
- )
- for w in expert_weights
- ]
+ if is_stateless:
+ for w in expert_weights:
+                    # Build the P2POp manually (bypassing __init__) so the peer
+                    # stays a group-local rank for the stateless communicator
+                    # to resolve.
+                    op = object.__new__(P2POp)
+                    op.op = torch.distributed.isend
+                    op.tensor = w[src]
+                    op.group_peer = dst
+ p2p_ops.append(op)
+ else:
+ dst_global = rank_to_global[dst]
+ p2p_ops += [
+ P2POp(
+ torch.distributed.isend,
+ w[src],
+ dst_global,
+ )
+ for w in expert_weights
+ ]
# 3. Post recvs
if recv_count > 0:
@@ -321,26 +339,40 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
- src_global = rank_to_global[src]
- p2p_ops += [
- P2POp(
- torch.distributed.irecv,
- b[dst],
- src_global,
- )
- for b in expert_weights_buffers
- ]
+ if is_stateless:
+ for b in expert_weights_buffers:
+ op = object.__new__(P2POp)
+ op.op = torch.distributed.irecv
+ op.tensor = b[dst]
+ op.group_peer = src
+ p2p_ops.append(op)
+ else:
+ src_global = rank_to_global[src]
+ p2p_ops += [
+ P2POp(
+ torch.distributed.irecv,
+ b[dst],
+ src_global,
+ )
+ for b in expert_weights_buffers
+ ]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
+ if is_stateless:
+ ep_group.device_communicator.batch_isend_irecv(p2p_ops)
+ else:
+ reqs = batch_isend_irecv(p2p_ops)
+ for req in reqs:
+ req.wait()
+ elif p2p_ops:
+ if is_stateless:
+ ep_group.device_communicator.batch_isend_irecv(p2p_ops)
+ else:
reqs = batch_isend_irecv(p2p_ops)
for req in reqs:
req.wait()
- elif p2p_ops:
- reqs = batch_isend_irecv(p2p_ops)
- for req in reqs:
- req.wait()
# wait for the communication to finish
return (
is_unchanged,
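
The non-stateless branch above still goes through torch.distributed's batched P2P API,
while the stateless branch hands the same list of ops to the coordinator's
device_communicator.batch_isend_irecv. A minimal standalone sketch of the
torch.distributed pattern (assuming a default process group is already initialized and
`peer` is a valid rank), illustrative only:

    import torch
    import torch.distributed as dist

    def exchange(send_buf: torch.Tensor, recv_buf: torch.Tensor, peer: int) -> None:
        # Post matching isend/irecv ops in one batch and wait for completion.
        ops = [
            dist.P2POp(dist.isend, send_buf, peer),
            dist.P2POp(dist.irecv, recv_buf, peer),
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
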
diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py
index 096ed44..21ec7a3 100644
--- a/vllm/distributed/kv_events.py
+++ b/vllm/distributed/kv_events.py
@@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC):
def clear_events(self) -> None:
raise NotImplementedError
+ def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
+ self.add_events(other.get_all_events())
+ return self
+
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches with data parallelism
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index 1ceac39..d5a40fc 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -149,6 +149,12 @@ KVConnectorFactory.register_connector(
"ExampleConnector",
)
+KVConnectorFactory.register_connector(
+ "ExampleHiddenStatesConnector",
+ "vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector",
+ "ExampleHiddenStatesConnector",
+)
+
KVConnectorFactory.register_connector(
"P2pNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector",
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index f9367da..1b012d4 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -413,7 +413,20 @@ class TpKVTopology:
f"by local tensor parallel size {self.tp_size}."
)
-        # P TP > D TP case, return the ratio as negative
-        return -remote_tp_size // self.tp_size
+        # P TP > D TP case: the ratio is now always returned as positive.
+        return remote_tp_size // self.tp_size
+
+ def pp_ratio(
+ self,
+ remote_pp_size: int,
+ ) -> int:
+ """
+ Calculate the pipeline parallel ratio between local and remote PP.
+ """
+        assert (
+            self.pp_size % remote_pp_size == 0
+            or remote_pp_size % self.pp_size == 0
+        ), (
+            f"Local pipeline parallel size {self.pp_size} is not divisible "
+            f"by remote pipeline parallel size {remote_pp_size} or vice versa."
+        )
+        if self.pp_size % remote_pp_size == 0:
+            return self.pp_size // remote_pp_size
+        return remote_pp_size // self.pp_size
def block_size_ratio(
self,
@@ -457,6 +470,7 @@ class TpKVTopology:
def get_target_remote_ranks(
self,
remote_tp_size: int,
+        remote_pp_size: int,
) -> list[int]:
"""
Get the remote TP rank (on P) that the current local TP rank
@@ -464,19 +478,36 @@ class TpKVTopology:
read from multiple remote ranks.
"""
tp_ratio = self.tp_ratio(remote_tp_size)
- if tp_ratio > 0:
- return [self.tp_rank // tp_ratio]
+ pp_ratio = self.pp_ratio(remote_pp_size)
+ target_pp_rank_list = []
+ target_tp_rank_list = []
+ if self.pp_size < remote_pp_size:
+ for i in range(pp_ratio):
+ target_pp_rank_list.append(self.pp_rank * pp_ratio + i)
+ else:
+ target_pp_rank_list.append(self.pp_rank // pp_ratio)
- # P TP > D TP case, D reads from |tp_ratio| remote workers.
- tp_ratio = -tp_ratio
- return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)]
+ if self.tp_size < remote_tp_size:
+ for i in range(tp_ratio):
+ target_tp_rank_list.append(self.tp_rank * tp_ratio + i)
+ else:
+ target_tp_rank_list.append(self.tp_rank // tp_ratio)
+
+ target_rank_list = []
+ for pp_rank in target_pp_rank_list:
+ for tp_rank in target_tp_rank_list:
+ target_rank = pp_rank * remote_tp_size + tp_rank
+ target_rank_list.append((target_rank, pp_rank, tp_rank))
+
+ return target_rank_list
def get_target_remote_ranks_from_engine_id(
self,
remote_engine_id: EngineId,
) -> list[int]:
remote_tp_size = self.remote_tp_size[remote_engine_id]
- return self.get_target_remote_ranks(remote_tp_size)
+ remote_pp_size = self.remote_pp_size[remote_engine_id]
+ return self.get_target_remote_ranks(remote_tp_size, remote_pp_size)
def get_current_attn_backend(vllm_config: VllmConfig):
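
To make the new TP/PP rank mapping concrete, here is a standalone sketch that reproduces
the same arithmetic for one assumed configuration (a TP=2/PP=1 decode worker reading from
a TP=4/PP=2 prefill deployment); it is illustrative only and not part of the patch:

    def target_remote_ranks(tp_rank, pp_rank, tp_size, pp_size,
                            remote_tp_size, remote_pp_size):
        tp_ratio = (remote_tp_size // tp_size if tp_size <= remote_tp_size
                    else tp_size // remote_tp_size)
        pp_ratio = (remote_pp_size // pp_size if pp_size <= remote_pp_size
                    else pp_size // remote_pp_size)
        pp_ranks = ([pp_rank * pp_ratio + i for i in range(pp_ratio)]
                    if pp_size < remote_pp_size else [pp_rank // pp_ratio])
        tp_ranks = ([tp_rank * tp_ratio + i for i in range(tp_ratio)]
                    if tp_size < remote_tp_size else [tp_rank // tp_ratio])
        return [(p * remote_tp_size + t, p, t) for p in pp_ranks for t in tp_ranks]

    # Local rank (tp_rank=1, pp_rank=0) with TP=2/PP=1 pulling from TP=4/PP=2:
    # target_remote_ranks(1, 0, 2, 1, 4, 2) -> [(2, 0, 2), (3, 0, 3), (6, 1, 2), (7, 1, 3)]
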
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index a0e03b0..c096827 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC):
)
return None
+ @classmethod
+ def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
+ """
+ Check if this connector requires PIECEWISE CUDA graph mode.
+
+ Connectors that use asynchronous layer-by-layer operations
+ (wait_for_layer_load/save_kv_layer) should override this method
+ to return True when those operations are enabled. These operations
+ cannot be captured in CUDA graphs and will be skipped during replay,
+ causing data races. PIECEWISE mode allows Python code to execute
+ between graph pieces, ensuring proper synchronization.
+
+ Args:
+ extra_config: The kv_connector_extra_config dict from
+ KVTransferConfig.
+
+ Returns:
+ True if this connector requires PIECEWISE CUDA graph mode,
+ False otherwise.
+ """
+ return False
+
def get_finished_count(self) -> int | None:
"""
Get the count of requests expected to complete send/receive operations
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
index d4a99cf..14feafc 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py
@@ -17,6 +17,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
+from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
@@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1):
The number of elements in kv_caches and layer_names should be
the same.
"""
- attn_metadata = forward_context.attn_metadata
def inject_kv_into_layer(
dst_kv_cache_layer: torch.Tensor,
src_kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
+ attn_metadata: AttentionMetadata,
) -> None:
"""Inject the KV cache into the layer.
@@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1):
num_pages * page_size, -1
)
dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache
+ elif isinstance(attn_metadata, TritonAttentionMetadata):
+ block_idxs = slot_mapping // self._block_size
+ offsets = slot_mapping % self._block_size
+ dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache
else:
num_pages = dst_kv_cache_layer_shape[1]
page_size = dst_kv_cache_layer_shape[2]
@@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1):
layer_name, request.token_ids, request.mm_hashes
)
kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda()
- inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping)
+ if isinstance(attn_metadata, dict):
+ inject_kv_into_layer(
+ kv_cache_layer,
+ kv_cache,
+ request.slot_mapping,
+ attn_metadata[layer_name],
+ )
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
@@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1):
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
+ elif isinstance(attn_metadata, TritonAttentionMetadata):
+ block_idxs = slot_mapping // self._block_size
+ offsets = slot_mapping % self._block_size
+ return layer[block_idxs, :, offsets]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
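
The Triton-backend branches above address the paged cache by splitting each flat slot
index into a (block, offset) pair; a tiny standalone illustration with an assumed block
size of 16, not taken from the patch:

    import torch

    block_size = 16  # assumed value for the example
    slot_mapping = torch.tensor([0, 17, 35])

    block_idxs = slot_mapping // block_size  # tensor([0, 1, 2])
    offsets = slot_mapping % block_size      # tensor([0, 1, 3])
    # The layer is then indexed with both tensors, as in
    # dst_kv_cache_layer[block_idxs, :, offsets] above.
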
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
new file mode 100644
index 0000000..945f8d9
--- /dev/null
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py
@@ -0,0 +1,354 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Optional
+
+import safetensors
+import torch
+
+from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+ KVConnectorBase_V1,
+ KVConnectorMetadata,
+ KVConnectorRole,
+)
+from vllm.logger import init_logger
+from vllm.v1.attention.backend import AttentionMetadata
+from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
+
+if TYPE_CHECKING:
+ from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+ from vllm.v1.kv_cache_interface import KVCacheConfig
+ from vllm.v1.request import Request
+
+logger = init_logger(__name__)
+
+
+def extract_from_kv_cache(
+ kv_cache: torch.Tensor,
+ slot_mapping: torch.Tensor,
+ num_tokens: int,
+) -> torch.Tensor:
+ """Extract data from KV cache
+ Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size)
+ """
+
+ padded_kv = kv_cache.flatten(0, 1)[slot_mapping]
+ # shape: [len(slot_mapping), num_heads, head_size]
+ return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size]
+
+
+@dataclass
+class ReqMeta:
+ # Request ID
+ req_id: str
+ # Request filename
+ filename: str
+ # Request tokens
+ token_ids: torch.Tensor
+ # Slot mappings, should have the same length as token_ids
+ slot_mapping: torch.Tensor
+ # Whether this request is a new request or partially computed already
+ new_req: bool
+
+ @staticmethod
+ def make_meta(
+ req_id: str,
+ filename: str,
+ token_ids: list[int],
+ block_ids: list[int],
+ block_size: int,
+ new_req: bool,
+ ) -> "ReqMeta":
+ token_ids_tensor = torch.tensor(token_ids)
+ block_ids_tensor = torch.tensor(block_ids)
+ num_blocks = block_ids_tensor.shape[0]
+ block_offsets = torch.arange(0, block_size)
+ slot_mapping = (
+ block_offsets.reshape((1, block_size))
+ + block_ids_tensor.reshape((num_blocks, 1)) * block_size
+ )
+ slot_mapping = slot_mapping.flatten()
+ return ReqMeta(
+ req_id=req_id,
+ filename=filename,
+ token_ids=token_ids_tensor,
+ slot_mapping=slot_mapping,
+ new_req=new_req,
+ )
+
+
+@dataclass
+class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata):
+ requests: list[ReqMeta] = field(default_factory=list)
+
+ def add_request(
+ self,
+ req_id: str,
+ filename: str,
+ token_ids: list[int],
+ block_ids: list[int],
+ block_size: int,
+ new_req: bool = True,
+ ) -> None:
+ self.requests.append(
+ ReqMeta.make_meta(
+ req_id, filename, token_ids, block_ids, block_size, new_req
+ )
+ )
+
+
+class ExampleHiddenStatesConnector(KVConnectorBase_V1):
+ """
+ Simple debug implementation of a HiddenStatesConnector.
+
+ Simply extracts the hidden states from the kv cache and stores them to disk.
+ Must be used in conjunction with the `extract_hidden_states` spec decoding method.
+ """
+
+ @property
+ def prefer_cross_layer_blocks(self) -> bool:
+ """
+ Indicates whether this connector prefers KV blocks that hold KV data for all
+ layers, which can speed up KV data transfers. Defaults to False.
+ """
+ # Must be False so that drafter kv cache isn't merged with verifier's
+ return False
+
+ def __init__(
+ self,
+ vllm_config: "VllmConfig",
+ role: KVConnectorRole,
+ kv_cache_config: Optional["KVCacheConfig"] = None,
+ ):
+ super().__init__(
+ vllm_config=vllm_config,
+ role=role,
+ kv_cache_config=kv_cache_config,
+ )
+ self._block_size = vllm_config.cache_config.block_size
+ self._storage_path = self._kv_transfer_config.get_from_extra_config(
+ "shared_storage_path", "/tmp"
+ )
+ self.cache_layers: list[str] = [] # set by self.register_kv_caches
+ logger.info(self._kv_transfer_config)
+ logger.info("Shared storage path is %s", self._storage_path)
+
+ assert self._vllm_config.speculative_config is not None, (
+ "ExampleHiddenStatesConnector only works when using "
+ "'extract_hidden_states' speculative method"
+ )
+ spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config
+ self.num_hidden_states = len(
+ getattr(spec_config, "eagle_aux_hidden_state_layer_ids", [])
+ )
+
+ self._request_filenames: dict[str, str] = {}
+ self._active_requests: dict[str, NewRequestData] = {}
+ self._req_blocks: dict[str, list[int]] = {}
+
+ # ==============================
+ # Worker-side methods
+ # ==============================
+ def start_load_kv(self, *args, **kwargs: Any) -> None:
+ pass # Empty implementation of abstract method
+
+ def wait_for_layer_load(self, layer_name: str) -> None:
+ pass # Empty implementation of abstract method
+
+ def wait_for_save(self):
+ pass # Empty implementation of abstract method
+
+ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+ from vllm.model_executor.models.extract_hidden_states import (
+ CacheOnlyAttentionLayer,
+ )
+
+ # Filter layers to only include CacheOnlyAttentionLayers
+ layers = get_layers_from_vllm_config(
+ self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys())
+ )
+ self.cache_layers = list(layers.keys())
+ assert len(self.cache_layers) == 1, (
+ f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}"
+ )
+
+ def save_kv_layer(
+ self,
+ layer_name: str,
+ kv_layer: torch.Tensor,
+ attn_metadata: AttentionMetadata,
+ **kwargs: Any,
+ ) -> None:
+ """Start saving the KV cache of the layer from vLLM's paged buffer
+ to the connector.
+
+ Args:
+ layer_name (str): the name of the layer.
+ kv_layer (torch.Tensor): the paged KV buffer of the current
+ layer in vLLM.
+ attn_metadata (AttentionMetadata): the attention metadata.
+ **kwargs: additional arguments for the save operation.
+ """
+ if layer_name not in self.cache_layers:
+ return
+
+ from vllm.model_executor.models.extract_hidden_states import (
+ CacheOnlyAttentionMetadata,
+ )
+
+ assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), (
+ "ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend"
+ )
+
+ connector_metadata = self._get_connector_metadata()
+ assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata)
+
+ os.makedirs(self._storage_path, exist_ok=True)
+ for request in connector_metadata.requests:
+ hidden_states = extract_from_kv_cache(
+ kv_layer, request.slot_mapping, request.token_ids.shape[0]
+ )
+ tensors = {
+ "hidden_states": hidden_states.detach().cpu(),
+ "token_ids": request.token_ids.detach().cpu(),
+ }
+ safetensors.torch.save_file(tensors, request.filename)
+
+ # ==============================
+ # Scheduler-side methods
+ # ==============================
+
+ def get_num_new_matched_tokens(
+ self,
+ request: "Request",
+ num_computed_tokens: int,
+ ) -> tuple[int | None, bool]:
+ """
+ Get number of new tokens that can be loaded from the
+ external KV cache beyond the num_computed_tokens.
+
+ Args:
+ request (Request): the request object.
+ num_computed_tokens (int): the number of locally
+ computed tokens for this request
+
+ Returns:
+ the number of tokens that can be loaded from the
+ external KV cache beyond what is already computed.
+ """
+ # This connector is store-only, so we don't need to load any tokens
+ return 0, False
+
+ def update_state_after_alloc(
+ self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
+ ):
+ # Usually used to handle allocation of new blocks for requests that are loading
+ # tokens from connector's external kv cache. We never load from external cache
+ # so this is a no-op.
+ assert num_external_tokens == 0, "This connector is store-only"
+
+ def build_connector_meta(
+ self,
+ scheduler_output: SchedulerOutput,
+ ) -> KVConnectorMetadata:
+ """Build the connector metadata for this step.
+
+ This function should NOT modify any fields in the scheduler_output.
+ Also, calling this function will reset the state of the connector.
+
+ Args:
+ scheduler_output (SchedulerOutput): the scheduler output object.
+ """
+ meta = ExampleHiddenStatesConnectorMetadata()
+ for new_req in scheduler_output.scheduled_new_reqs:
+ token_ids = new_req.prompt_token_ids or []
+ filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors")
+ meta.add_request(
+ new_req.req_id,
+ filename=filename,
+ token_ids=token_ids,
+ block_ids=new_req.block_ids[0],
+ block_size=self._block_size,
+ )
+ self._request_filenames[new_req.req_id] = filename
+ self._active_requests[new_req.req_id] = new_req
+ self._req_blocks[new_req.req_id] = list(new_req.block_ids[0])
+
+ cached_reqs = scheduler_output.scheduled_cached_reqs
+ for i, req_id in enumerate(cached_reqs.req_ids):
+ if req_id not in self._active_requests:
+ continue
+
+ new_block_ids = cached_reqs.new_block_ids[i]
+
+ cached_req = self._active_requests[req_id]
+ req_block_ids = self._req_blocks[req_id]
+
+ assert new_block_ids is not None
+ block_ids = new_block_ids[0]
+
+ req_block_ids.extend(block_ids)
+ filename = os.path.join(self._storage_path, f"{req_id}.safetensors")
+
+ meta.add_request(
+ req_id=req_id,
+ filename=filename,
+ token_ids=cached_req.prompt_token_ids or [],
+ block_ids=req_block_ids,
+ block_size=self._block_size,
+ new_req=False,
+ )
+
+ return meta
+
+ def request_finished(
+ self,
+ request: "Request",
+ block_ids: list[int],
+ ) -> tuple[bool, dict[str, Any] | None]:
+ """
+ Called exactly once when a request has finished, before its blocks are
+ freed.
+
+        The connector may assume responsibility for freeing the blocks
+ asynchronously by returning True.
+
+ Returns:
+ True if the request is being saved/sent asynchronously and blocks
+ should not be freed until the request_id is returned from
+ get_finished().
+ Optional KVTransferParams to be included in the request outputs
+ returned by the engine.
+ """
+ req_id = request.request_id
+ req_filename = self._request_filenames.pop(req_id, None)
+ _ = self._active_requests.pop(req_id, None)
+ _ = self._req_blocks.pop(req_id, None)
+
+ return False, {"hidden_states_path": req_filename}
+
+ @classmethod
+ def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None:
+ """
+ Get the required KV cache layout for this connector.
+ Args:
+ vllm_config (VllmConfig): the vllm config.
+
+ Returns:
+ str: the required KV cache layout. e.g. HND, or NHD.
+ None if the connector does not require a specific layout.
+ """
+
+ if cls is KVConnectorBase_V1:
+ raise TypeError(
+ "get_required_kvcache_layout should not be called "
+ "on the abstract base class"
+ )
+ # NHD means we have (num_tokens, num_heads)
+ # HND means we have (num_heads, num_tokens)
+ # For now, we only support NHD layout since this keeps the
+ # hidden states for each token together in memory.
+ # HND is primarily used when sharding heads across devices.
+ return "NHD"
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
index 376215e..64aee2b 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
@@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents):
class LMCacheConnectorV1(KVConnectorBase_V1):
+ @classmethod
+ def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
+ """
+ LMCache requires PIECEWISE CUDA graph mode when layerwise
+ operations are enabled. The wait_for_layer_load and save_kv_layer
+ methods perform actual async synchronization that cannot be
+ captured in CUDA graphs.
+ """
+ return extra_config.get("use_layerwise", False)
+
def __init__(
self,
vllm_config: "VllmConfig",
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
index f105d34..744d763 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py
@@ -173,6 +173,29 @@ class MooncakeConnector(KVConnectorBase_V1):
self.connector_scheduler = None
self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id)
+
+ ############################################################
+ # Class Methods
+ ############################################################
+ @classmethod
+ def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
+ if vllm_config.model_config is None:
+ logger.warning_once(
+ "Unable to detect current VLLM config. "
+ "Fallback to default kv cache layout."
+ )
+ return None
+ use_mla = vllm_config.model_config.use_mla
+ if use_mla:
+            # Return None when MLA is used: the layout should not matter in
+            # that case, so we fall back to the default behavior.
+ return None
+ logger.info_once(
+ "MooncakeConnector setting KV cache layout to HND for better xfer performance."
+ )
+ return "HND"
+
############################################################
# Scheduler Side Methods
############################################################
@@ -941,7 +964,13 @@ class MooncakeConnectorWorker:
assert tensor_size_bytes == curr_tensor_size_bytes, (
"All kv cache tensors must have the same size"
)
- kernel_block_size = cache.shape[-2 if self.use_mla else -3]
+
+ cache_layout = get_kv_cache_layout()
+ if cache_layout == "HND":
+ kernel_block_size = cache.shape[-2]
+ else:
+ kernel_block_size = cache.shape[-3]
+
assert self.block_size == kernel_block_size
kv_data_ptrs.append(base_addr)
kv_data_lens.append(tensor_size_bytes)
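
The HND/NHD distinction above only changes which dimension carries the kernel block
size; a small sketch with assumed (non-MLA) cache shapes, not taken from the patch:

    import torch

    # HND: (2, num_blocks, num_kv_heads, block_size, head_size)
    # NHD: (2, num_blocks, block_size, num_kv_heads, head_size)
    hnd = torch.empty(2, 4, 8, 16, 64)
    nhd = torch.empty(2, 4, 16, 8, 64)

    assert hnd.shape[-2] == 16  # HND keeps block_size in the second-to-last dim
    assert nhd.shape[-3] == 16  # NHD keeps block_size in the third-to-last dim
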
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index 3f0c983..7052886 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1):
- Save to all connectors.
"""
+ @classmethod
+ def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
+ """
+ MultiConnector requires PIECEWISE CUDA graph mode if any of its
+ child connectors require it.
+ """
+ connectors_config = extra_config.get("connectors", [])
+ for conn_config in connectors_config:
+ temp_ktc = KVTransferConfig(**conn_config)
+ connector_cls = KVConnectorFactory.get_connector_class(temp_ktc)
+ child_extra_config = conn_config.get("kv_connector_extra_config", {})
+ if connector_cls.requires_piecewise_for_cudagraph(child_extra_config):
+ return True
+ return False
+
def __init__(
self,
vllm_config: "VllmConfig",
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index b3f2ae7..4cafe12 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -30,7 +30,6 @@ from vllm.distributed.kv_transfer.kv_connector.utils import (
kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive,
kv_postprocess_layout_on_receive,
- yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
CopyBlocksOp,
@@ -49,6 +48,7 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
+ get_pp_group,
)
from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
@@ -135,7 +135,10 @@ _NIXL_SUPPORTED_DEVICE = {
"cpu",
),
"tpu": ("cpu",),
- "xpu": ("cpu",),
+ "xpu": (
+ "cpu",
+ "xpu",
+ ),
"cpu": ("cpu",),
}
# support for oot platform by providing mapping in current_platform
@@ -150,8 +153,12 @@ class NixlAgentMetadata:
device_id: int
num_blocks: int
block_lens: list[int]
+ attn_backend_name: str
kv_cache_layout: str
block_size: int
+ pp_split: str
+ pp_rank: int
+ tp_rank: int
@dataclass
@@ -245,10 +252,26 @@ class RemoteMeta:
@dataclass
class ReqMeta:
local_block_ids: list[int]
- # To be used when logical block size does not match the kernel block size
local_physical_block_ids: list[int]
+ remote_block_ids: list[int]
+ remote_host: str
+ remote_port: int
+ remote_engine_id: str
tp_size: int
- remote: RemoteMeta | None = None
+ remote_pp_size: int = 1
+ remote_tp_size: int = 1
+ remote_pp_rank: int = 0
+ remote_tp_rank: int = 0
+
+ @property
+ def remote(self) -> RemoteMeta:
+ return RemoteMeta(
+ block_ids=self.remote_block_ids,
+ host=self.remote_host,
+ port=self.remote_port,
+ engine_id=self.remote_engine_id,
+ request_id="",
+ )
class NixlConnectorMetadata(KVConnectorMetadata):
@@ -259,43 +282,32 @@ class NixlConnectorMetadata(KVConnectorMetadata):
self.reqs_in_batch: set[ReqId] = set()
self.reqs_not_processed: set[ReqId] = set()
- def _add_new_req(
+ def add_new_req(
self,
+ request_id: ReqId,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
- ) -> ReqMeta:
- return ReqMeta(
+ load_remote_cache: bool = True,
+ save_to_host: bool = False,
+ ):
+ assert load_remote_cache ^ save_to_host
+ _req = ReqMeta(
local_block_ids=local_block_ids,
local_physical_block_ids=local_block_ids,
- # P workers don't need to receive tp_size from proxy here.
+ remote_block_ids=kv_transfer_params["remote_block_ids"],
+ remote_engine_id=kv_transfer_params["remote_engine_id"],
+ remote_host=kv_transfer_params["remote_host"],
+ remote_port=kv_transfer_params["remote_port"],
tp_size=kv_transfer_params.get("tp_size", 1),
+ remote_tp_size=kv_transfer_params.get("tp_size", 1),
+ remote_pp_size=kv_transfer_params.get("pp_size", 1),
+ remote_tp_rank=0,
+ remote_pp_rank=0,
)
-
- def add_new_req_to_save(
- self,
- request_id: ReqId,
- local_block_ids: list[int],
- kv_transfer_params: dict[str, Any],
- ):
- self.reqs_to_save[request_id] = self._add_new_req(
- local_block_ids, kv_transfer_params
- )
-
- def add_new_req_to_recv(
- self,
- request_id: ReqId,
- local_block_ids: list[int],
- kv_transfer_params: dict[str, Any],
- ):
- req = self._add_new_req(local_block_ids, kv_transfer_params)
- req.remote = RemoteMeta(
- block_ids=kv_transfer_params["remote_block_ids"],
- engine_id=kv_transfer_params["remote_engine_id"],
- request_id=kv_transfer_params["remote_request_id"],
- host=kv_transfer_params["remote_host"],
- port=kv_transfer_params["remote_port"],
- )
- self.reqs_to_recv[request_id] = req
+ if save_to_host:
+ self.reqs_to_save[request_id] = _req
+ if load_remote_cache:
+ self.reqs_to_recv[request_id] = _req
class NixlConnector(KVConnectorBase_V1):
@@ -543,7 +555,7 @@ class NixlConnectorScheduler:
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {}
- self._reqs_need_save: dict[ReqId, Request] = {}
+ self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {}
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self._reqs_in_batch: set[ReqId] = set()
@@ -568,17 +580,17 @@ class NixlConnectorScheduler:
"""
encoded_data: dict[int, bytes] = {}
encoder = msgspec.msgpack.Encoder()
- for tp_rank, rank_metadata in metadata.items():
+ for global_rank, rank_metadata in metadata.items():
if not isinstance(rank_metadata, NixlHandshakePayload):
raise ValueError(
"NixlConnectorScheduler expects NixlHandshakePayload for "
"handshake metadata."
)
- encoded_data[tp_rank] = encoder.encode(rank_metadata)
+ encoded_data[global_rank] = encoder.encode(rank_metadata)
logger.debug(
- "Tp rank %d: encoded NixlHandshakePayload size: %s bytes",
- tp_rank,
- str(len(encoded_data[tp_rank])),
+ "Global rank %d: encoded NixlHandshakePayload size: %s bytes",
+ global_rank,
+ str(len(encoded_data[global_rank])),
)
self._encoded_xfer_handshake_metadata = encoded_data
@@ -625,14 +637,14 @@ class NixlConnectorScheduler:
break
continue
# Decode the message which contains (GET_META_MSG, rank)
- msg, target_tp_rank = msgspec.msgpack.decode(msg)
+ msg, target_rank = msgspec.msgpack.decode(msg)
logger.debug(
- "Received message for tp rank %s",
- target_tp_rank,
+ "Received message for rank %s",
+ target_rank,
)
if msg != GET_META_MSG:
logger.warning("Connection listener got unexpected message %s", msg)
- sock.send_multipart((identity, b"", encoded_data[target_tp_rank]))
+ sock.send_multipart((identity, b"", encoded_data[target_rank]))
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
@@ -689,7 +701,9 @@ class NixlConnectorScheduler:
if self.use_host_buffer and params.get("do_remote_decode"):
# NOTE: when accelerator is not directly supported by Nixl,
# prefilled blocks need to be saved to host memory before transfer.
- self._reqs_need_save[request.request_id] = request
+ block_ids = blocks.get_block_ids()[0]
+ if block_ids:
+ self._reqs_need_save[request.request_id] = (request, block_ids)
elif params.get("do_remote_prefill"):
if params.get("remote_block_ids"):
if all(
@@ -735,38 +749,23 @@ class NixlConnectorScheduler:
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
- meta.add_new_req_to_recv(
+ meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
+ load_remote_cache=True,
+ save_to_host=False,
)
- # NOTE: For the prefill side, there might be a chance that an early added
- # request is a chunked prefill, so we need to check if new blocks are added
- for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output):
- req_to_save = self._reqs_need_save.get(req_id)
- if req_to_save is None or new_block_id_groups is None:
- continue
- req = req_to_save
-
+ for req_id, (req, block_ids) in self._reqs_need_save.items():
assert req.kv_transfer_params is not None
- meta.add_new_req_to_save(
+ meta.add_new_req(
request_id=req_id,
- local_block_ids=new_block_id_groups[0],
+ local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
+ load_remote_cache=False,
+ save_to_host=True,
)
- assert scheduler_output.num_scheduled_tokens is not None
- num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
- is_partial = (
- req.num_computed_tokens + num_scheduled_tokens
- ) < req.num_prompt_tokens
- if not is_partial:
- # For non-partial prefills, once new req_meta is scheduled, it
- # can be removed from _reqs_need_save.
- # For partial prefill case, we will retain the request in
- # _reqs_need_save until all blocks are scheduled with req_meta.
- # Therefore, only pop if `not is_partial`.
- self._reqs_need_save.pop(req_id)
meta.reqs_to_send = self._reqs_need_send
meta.reqs_in_batch = self._reqs_in_batch
@@ -774,6 +773,7 @@ class NixlConnectorScheduler:
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
+ self._reqs_need_save.clear()
self._reqs_in_batch = set()
self._reqs_not_processed = set()
self._reqs_need_send = {}
@@ -819,8 +819,6 @@ class NixlConnectorScheduler:
# Also include the case of a P/D Prefill request with immediate
# block free (eg abort). Stop tracking this request.
self._reqs_not_processed.add(request.request_id)
- # Clear _reqs_need_save if a request is aborted as partial prefill.
- self._reqs_need_save.pop(request.request_id, None)
return False, None
# TODO: check whether block_ids actually ever be 0. If not we could
@@ -848,6 +846,7 @@ class NixlConnectorScheduler:
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
+ pp_size=self.vllm_config.parallel_config.pipeline_parallel_size,
)
@@ -901,6 +900,12 @@ class NixlConnectorWorker:
# Metadata.
self.engine_id: EngineId = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
+ self.pp_rank = get_pp_group().rank_in_group
+ self.tp_world_size = get_tp_group().world_size
+ self.pp_world_size = get_pp_group().world_size
+ self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+ self.dp_world_size = vllm_config.parallel_config.data_parallel_size
+ self.global_rank = self.pp_rank * self.tp_world_size + self.tp_rank
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.num_blocks = 0
@@ -945,7 +950,7 @@ class NixlConnectorWorker:
# type based on kv_buffer_device
nixl_memory_type = current_platform.get_nixl_memory_type()
if nixl_memory_type is None:
- if self.kv_buffer_device == "cuda":
+ if self.kv_buffer_device in ["cuda", "xpu"]:
nixl_memory_type = "VRAM"
elif self.kv_buffer_device == "cpu":
nixl_memory_type = "DRAM"
@@ -960,10 +965,9 @@ class NixlConnectorWorker:
self.copy_blocks: CopyBlocksOp | None = None
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
+ # rank will still only pull from a single remote TP worker.
+ self.kv_caches_base_addr: dict[EngineId, list[int]] = {}
self.device_id: int = 0
- # Current rank may pull from multiple remote TP workers.
- # EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer
- self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict)
# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
@@ -971,12 +975,12 @@ class NixlConnectorWorker:
self.num_layers = 0
# nixl_prepped_dlist_handle.
- self.src_xfer_handles_by_block_size: dict[int, int] = {}
- # Populated dynamically during handshake based on remote configuration.
- # Keep track of regions at different tp_ratio values. tp_ratio->handles
- self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {}
- # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}.
- self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict)
+ self.src_xfer_side_handle: int = 0
+        # Map of engine_id -> {remote rank: nixl_prepped_dlist_handle (int)}.
+ self.dst_xfer_side_handles: dict[EngineId, dict[int, int]] = {}
+ self.src_xfer_side_handles: dict[EngineId, dict[int, int]] = {}
+ # Map of (offset, len) -> nixl_prepped_dlist_handle (int).
+        self.src_registered_xfer_side_headles: dict[tuple[int, int], int] = {}
# Map of engine_id -> num_blocks. All ranks in the same deployment will
# have the same number of blocks.
@@ -986,7 +990,7 @@ class NixlConnectorWorker:
# In progress transfers.
# [req_id -> list[handle]]
self._recving_metadata: dict[ReqId, ReqMeta] = {}
- self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list)
+ self._recving_transfers = defaultdict[ReqId, list[tuple[int, float]]](list)
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
# Set of requests that have been part of a batch, regardless of status.
@@ -1034,7 +1038,8 @@ class NixlConnectorWorker:
self.compat_hash: str | None = None
self.kv_topo: TpKVTopology | None = None
- self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
+ self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_world_size}
+ self._pp_size: dict[EngineId, int] = {self.engine_id: self.pp_world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
# With heterogeneous TP, P must wait for all assigned D TP workers to
# finish reading before safely freeing the blocks.
@@ -1047,35 +1052,50 @@ class NixlConnectorWorker:
"enforce_handshake_compat", True
)
+ pp_split = envs.VLLM_PP_LAYER_PARTITION
+ if pp_split is None and self.pp_world_size > 1:
+ raise RuntimeError(
+ "VLLM_PP_LAYER_PARTITION must be set when using nixl PD with pp > 1."
+ )
+ self.pp_split = pp_split or "0"
+ if self.pp_world_size > 1:
+ assert len(self.pp_split.split(",")) == self.pp_world_size, (
+ "VLLM_PP_LAYER_PARTITION split count must equal pp_size"
+ )
+        self.local_start_index: dict[EngineId, dict[int, int | tuple[int, int]]] = {}
+        self.local_end_index: dict[EngineId, dict[int, int | tuple[int, int]]] = {}
+ self.remote_nums: dict[EngineId, dict[int, int]] = {}
+ self.is_kv = 1 if self.use_mla else 2
+ self.kv_element_size = None
+ self.ds_v32 = hasattr(self.model_config.hf_config, "index_topk")
+
def _nixl_handshake(
self,
host: str,
port: int,
remote_tp_size: int,
+ remote_pp_size: int,
expected_engine_id: str,
- ) -> dict[int, str]:
+ ) -> dict[int, tuple[str, bool]]:
"""Do a NIXL handshake with a remote instance."""
- # When target instance TP > local TP, we need to perform multiple
- # handshakes. Do it in a single background job for simplicity.
- # Regardless, only handshake with the remote TP rank(s) that current
- # local rank will read from. Note that With homogeneous TP,
- # this happens to be the same single rank_i.
assert self.kv_topo is not None
- p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
- remote_rank_to_agent_name = {}
+ p_remote_rank_list = self.kv_topo.get_target_remote_ranks(
+ remote_tp_size, remote_pp_size
+ )
+ remote_agent_dict: dict[int, tuple[str, bool]] = {}
path = make_zmq_path("tcp", host, port)
with zmq_ctx(zmq.REQ, path) as sock:
- for remote_rank in p_remote_ranks:
+ for p_remote_rank, remote_pp_rank, remote_tp_rank in p_remote_rank_list:
logger.debug(
- "Querying metadata on path: %s at remote tp rank %s",
- path,
- remote_rank,
+ "Querying metadata on path: %s at remote rank %s "
+ "(pp_rank=%s, tp_rank=%s)",
+ path, p_remote_rank, remote_pp_rank, remote_tp_rank,
)
start_time = time.perf_counter()
# Send query for the request.
- msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank))
+ msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank))
# Set receive timeout to 5 seconds to avoid hanging on dead server
sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds
sock.send(msg)
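
The pp_split string carried in the handshake metadata is the comma-separated per-stage
layer count from VLLM_PP_LAYER_PARTITION; a standalone sketch of the cumulative offsets
derived from it later when aligning local and remote stages (the 32-layer, pp_size=2
partition is an assumed example):

    pp_split = "16,16"  # assumed VLLM_PP_LAYER_PARTITION value
    layers_per_stage = [int(x) for x in pp_split.split(",")]

    cum = [0]
    for n in layers_per_stage:
        cum.append(cum[-1] + n)
    # cum == [0, 16, 32]: pp stage i owns layers [cum[i], cum[i + 1]).
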
@@ -1108,10 +1128,7 @@ class NixlConnectorWorker:
f"Local: {self.compat_hash}, "
f"Remote: {handshake_payload.compatibility_hash}. "
f"Prefill and decode instances have incompatible "
- f"configurations. This may be due to: different vLLM versions,"
- f" models, dtypes, KV cache layouts, attention backends, etc. "
- f"Both instances must use identical configurations."
- f"Disable this check using "
+ f"configurations. Disable this check using "
f'--kv-transfer-config \'{{"kv_connector_extra_config": '
f'{{"enforce_handshake_compat": false}}}}\''
)
@@ -1140,18 +1157,22 @@ class NixlConnectorWorker:
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
- setup_agent_time = time.perf_counter()
- # Register Remote agent.
- remote_agent_name = self.add_remote_agent(
- metadata, remote_rank, remote_tp_size
+ assert metadata.block_size <= self.block_size, (
+ "nP > nD is not supported yet."
)
+ remote_agent_name, read = self.add_remote_agent(
+ metadata, p_remote_rank, remote_tp_rank,
+ remote_tp_size, remote_pp_rank, remote_pp_size,
+ )
+
+ setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,
)
- remote_rank_to_agent_name[remote_rank] = remote_agent_name
- return remote_rank_to_agent_name
+ remote_agent_dict[p_remote_rank] = (remote_agent_name, read)
+ return remote_agent_dict
def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None:
"""
@@ -1228,15 +1249,14 @@ class NixlConnectorWorker:
# Try to get metadata from in progress transfers when not provided
meta = self._recving_metadata.get(req_id)
- if meta and meta.remote:
+ if meta:
context.update(
{
- "remote_engine_id": meta.remote.engine_id,
- "remote_request_id": meta.remote.request_id,
- "remote_host": meta.remote.host,
- "remote_port": meta.remote.port,
+ "remote_engine_id": meta.remote_engine_id,
+ "remote_host": meta.remote_host,
+ "remote_port": meta.remote_port,
"num_local_blocks": len(meta.local_block_ids),
- "num_remote_blocks": len(meta.remote.block_ids),
+ "num_remote_blocks": len(meta.remote_block_ids),
"local_block_ids_sample": meta.local_block_ids[:10],
}
)
@@ -1259,12 +1279,12 @@ class NixlConnectorWorker:
# Do NIXL handshake in background and add to _ready_requests when done.
fut = self._handshake_futures.get(remote_engine_id)
if fut is None:
- assert meta.remote is not None
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake,
- meta.remote.host,
- meta.remote.port,
- meta.tp_size,
+ meta.remote_host,
+ meta.remote_port,
+ meta.remote_tp_size,
+ meta.remote_pp_size,
remote_engine_id,
)
self._handshake_futures[remote_engine_id] = fut
@@ -1317,6 +1337,48 @@ class NixlConnectorWorker:
attn_backend=self.attn_backend,
tensor_shape=next(iter(kv_caches.values())).shape,
)
+ original_get_target = self.kv_topo.get_target_remote_ranks
+
+ def _pp_aware_get_target_remote_ranks(
+ remote_tp_size, remote_pp_size=1
+ ):
+ if remote_pp_size == 1 and self.pp_world_size == 1:
+                # The updated TpKVTopology already returns
+                # (global_rank, pp_rank, tp_rank) tuples.
+                return original_get_target(remote_tp_size, remote_pp_size)
+
+ tp_ratio = self.kv_topo.tp_ratio(remote_tp_size)
+ pp_ratio = (
+ self.pp_world_size // remote_pp_size
+ if self.pp_world_size >= remote_pp_size
+ else remote_pp_size // self.pp_world_size
+ )
+ target_pp_rank_list = []
+ target_tp_rank_list = []
+ if self.pp_world_size < remote_pp_size:
+ for i in range(pp_ratio):
+ target_pp_rank_list.append(
+ self.pp_rank * pp_ratio + i)
+ else:
+ target_pp_rank_list.append(
+ self.pp_rank // pp_ratio)
+
+ if self.tp_world_size < remote_tp_size:
+ for i in range(tp_ratio):
+ target_tp_rank_list.append(
+ self.tp_rank * tp_ratio + i)
+ else:
+ target_tp_rank_list.append(
+ self.tp_rank // tp_ratio)
+
+ result = []
+ for pp_rank in target_pp_rank_list:
+ for tp_rank in target_tp_rank_list:
+ global_rank = pp_rank * remote_tp_size + tp_rank
+ result.append((global_rank, pp_rank, tp_rank))
+ return result
+
+ self.kv_topo.get_target_remote_ranks = (
+ _pp_aware_get_target_remote_ranks)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
@@ -1422,7 +1484,7 @@ class NixlConnectorWorker:
assert len(self.block_len_per_layer) == len(seen_base_addresses)
assert self.num_blocks != 0
- self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses
+ self.kv_caches_base_addr[self.engine_id] = seen_base_addresses
self.num_regions = len(caches_data)
self.num_layers = len(xfer_buffers.keys())
@@ -1450,8 +1512,9 @@ class NixlConnectorWorker:
# Register local/src descr for NIXL xfer.
self.seen_base_addresses = seen_base_addresses
- self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = (
- self.register_local_xfer_handler(self.block_size)
+ first_kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=0)
+ self.src_registered_xfer_side_headles[(0, first_kv_block_len)] = (
+ self.register_local_xfer_handler(self.block_size)[0]
)
# TODO(mgoin): Hybrid memory allocator is currently disabled for
@@ -1481,13 +1544,17 @@ class NixlConnectorWorker:
engine_id=self.engine_id,
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
device_id=self.device_id,
- kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank],
+ kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
block_lens=self.block_len_per_layer,
+ attn_backend_name=self.backend_name,
kv_cache_layout=self.kv_cache_layout
if not self.use_host_buffer
else self.host_buffer_kv_cache_layout,
block_size=self.block_size,
+ pp_split=self.pp_split,
+ pp_rank=self.pp_rank,
+ tp_rank=self.tp_rank,
)
# Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
@@ -1555,63 +1622,33 @@ class NixlConnectorWorker:
def add_remote_agent(
self,
nixl_agent_meta: NixlAgentMetadata,
+ p_remote_rank: int = 0,
remote_tp_rank: int = 0,
remote_tp_size: int = 1,
- ) -> str:
+ remote_pp_rank: int = 0,
+ remote_pp_size: int = 1,
+ ) -> tuple[str, bool]:
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
- blocks from remote.
-
- In particular, handle both homogeneous and heterogeneous TP. The former
- requires local rank_i to read from remote rank_i.
- The latter, in the case of D.world_size < P.world_size, requires that a
- local (D) TP worker reads from multiple remote (P) TP workers.
- Conversely, assuming D.world_size > P.world_size, two or more local TP
- workers will read from a single remote TP worker.
-
- Here's an example for the last case described above (non-MLA):
-
- rank_offset p_remote_tp_rank
- (kv split no)
- --------------------------------
- 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ]
- /
- 1 0 Worker1 ---- 2nd half of KV -----/
-
- 0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ]
- /
- 1 1 Worker3 ---- 2nd half of KV -----/
-
-
- Decoder TP workers Prefix TP workers
- (world_size=4) (world_size=2)
- tp_ratio = 4 // 2 = 2
-
- Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
- then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
- Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
- first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
- along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
-
- Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1.
-
- Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
- so that the whole cache is shared by "tp_ratio" D TP workers.
- """ # noqa: E501
+ blocks from remote, with PP support.
+ """
engine_id = nixl_agent_meta.engine_id
- # TODO re-evaluate refreshing for scaling/recovery
- if remote_tp_rank in self._remote_agents.get(engine_id, {}):
+ if p_remote_rank in self._remote_agents.get(engine_id, {}):
logger.debug(
- "Remote agent with engine_id %s and rank"
+ "Remote agent with engine_id %s and rank "
"%s already exchanged metadata, skip handshake.",
- engine_id,
- remote_tp_rank,
+ engine_id, p_remote_rank,
)
- return self._remote_agents[engine_id][remote_tp_rank]
+ return self._remote_agents[engine_id][p_remote_rank]
+
+ if (self.tp_world_size == remote_tp_size
+ and self.pp_world_size == remote_pp_size):
+ assert self.tp_rank == remote_tp_rank
+ assert self.pp_rank == remote_pp_rank
- ### Register remote agent metadata
if engine_id not in self._tp_size:
self._tp_size[engine_id] = remote_tp_size
+ self._pp_size[engine_id] = remote_pp_size
if engine_id not in self._block_size:
self._block_size[engine_id] = nixl_agent_meta.block_size
@@ -1619,139 +1656,275 @@ class NixlConnectorWorker:
nixl_agent_meta.agent_metadata
)
- # Create dst descs and xfer side handles. TP workers have same #blocks
- # so we only register once per engine_id.
- # Example:
- # block_size_ratio > 1:
- # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
- # local origin:| 0| 1| 8| 12|
- # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
+ remote_replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id)
+ local_replicates_kv_cache = self.kv_topo.replicates_kv_cache(self.engine_id)
+
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
+ assert block_size_ratio == 1, "block_size_ratio != 1 not supported with PP"
if engine_id not in self.dst_num_blocks:
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
- # Keep track of remote agent kv caches base addresses.
- self.kv_caches_base_addr[engine_id][remote_tp_rank] = (
- nixl_agent_meta.kv_caches_base_addr
+ self._validate_remote_agent_handshake(
+ nixl_agent_meta, remote_tp_size, remote_pp_size
)
- self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size)
- # This is 1 when P and D `--tensor-parallel-size` match. Otherwise,
- # this is the ratio between the two sizes.
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id)
+ total_num_kv_heads = self.model_config.get_total_num_kv_heads()
- # Handle tp_size>num_kv_heads: replicate KV cache.
- indexes_into_remote = (
- not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0
- )
-
- logger.debug(
- "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s",
- engine_id,
- remote_tp_rank,
- tp_ratio,
- )
-
- ### (Optional) Register local agent memory regions. MLA is not split.
- if (
- tp_ratio < 0
- and not self.use_mla
- and tp_ratio not in self.src_xfer_handles_by_tp_ratio
- ):
- # Remote tp_size > local tp_size: read from multiple remote ranks.
- # Logically "split" own regions into |tp_ratio| chunks. Mind that
- # we only do this once per remote tp_size (replica-friendly).
- self.src_xfer_handles_by_tp_ratio[tp_ratio] = []
- for i in range(-tp_ratio):
- blocks_data = []
- for memory_region in self.src_blocks_data:
- addr, local_block_len, own_tp_rank = memory_region
- # Computing block len layer by layer allows for different
- # block sizes to be used.
- remote_block_len = local_block_len // (-tp_ratio)
- addr = addr + i * remote_block_len
- blocks_data.append((addr, remote_block_len, own_tp_rank))
- descs = self.nixl_wrapper.get_xfer_descs(
- blocks_data, self.nixl_memory_type
- )
- handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs)
- self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle)
-
- ### Register remote agent memory regions
blocks_data = []
- # With homogeneous TP, D pulls the whole kv cache from corresponding
- # rank. With heterogeneous TP, prepare the descriptors by splitting the
- # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
- # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+ if engine_id not in self.local_start_index:
+ self.local_start_index[engine_id] = {}
+ self.local_end_index[engine_id] = {}
+ self.remote_nums[engine_id] = {}
+ self.local_start_index[engine_id][p_remote_rank] = 0
+ self.local_end_index[engine_id][p_remote_rank] = self.num_regions
+ self.remote_nums[engine_id][p_remote_rank] = self.num_regions
+ remote_start = 0
+ remote_end = self.num_regions
- # Register all remote blocks, but only the corresponding kv heads.
- for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr):
- # Read our whole local region size from remote.
- local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i)
- remote_kv_block_len = local_block_len // block_size_ratio
- if block_size_ratio > 1:
- # using remote kv_block_len as transfer unit
- local_block_len = remote_kv_block_len
-
- if tp_ratio < 0 and not self.use_mla:
- # Remote tp is bigger: read a chunk of local region from remote
- local_block_len = local_block_len // (-tp_ratio)
- rank_offset = (
- self.tp_rank % tp_ratio * remote_kv_block_len
- if indexes_into_remote
- else 0
+ if self.pp_world_size != remote_pp_size:
+ assert self.pp_split != "" or nixl_agent_meta.pp_split != "", (
+ "pp_split must be set when pp_size > 1"
)
+ assert (self.pp_world_size % remote_pp_size == 0
+ or remote_pp_size % self.pp_world_size == 0)
+
+ local_pp_layer_split = [int(x) for x in self.pp_split.split(",")]
+ local_pp_layer_index_cum = [0]
+ for x in local_pp_layer_split:
+ local_pp_layer_index_cum.append(
+ local_pp_layer_index_cum[-1] + x)
+
+ remote_pp_layer_split = [
+ int(x) for x in nixl_agent_meta.pp_split.split(",")
+ ]
+ remote_pp_layer_index_cum = [0]
+ for x in remote_pp_layer_split:
+ remote_pp_layer_index_cum.append(
+ remote_pp_layer_index_cum[-1] + x)
+
+ if self.pp_world_size < remote_pp_size:
+ ratio = remote_pp_size // self.pp_world_size
+ assert remote_pp_rank == int(nixl_agent_meta.pp_rank)
+ assert (local_pp_layer_index_cum[self.pp_rank + 1]
+ >= remote_pp_layer_index_cum[remote_pp_rank + 1])
+
+ if (remote_pp_rank + 1) % ratio == 0:
+ assert (local_pp_layer_index_cum[self.pp_rank + 1]
+ == remote_pp_layer_index_cum[remote_pp_rank + 1])
+
+ lst = (0 if remote_pp_rank == 0
+ else (remote_pp_layer_index_cum[remote_pp_rank]
+ - local_pp_layer_index_cum[
+ remote_pp_rank // ratio]))
+ lst = lst * self.is_kv
+ led = (remote_pp_layer_split[remote_pp_rank]
+ * self.is_kv + lst)
+ self.local_start_index[engine_id][p_remote_rank] = lst
+ self.local_end_index[engine_id][p_remote_rank] = led
+ if self.ds_v32:
+ local_pp_layers = local_pp_layer_split[self.pp_rank]
+ self.local_start_index[engine_id][p_remote_rank] = (
+ lst, lst + local_pp_layers)
+ self.local_end_index[engine_id][p_remote_rank] = (
+ led, led + local_pp_layers)
+ remote_start = 0
+ remote_end = (remote_pp_layer_split[remote_pp_rank]
+ * self.is_kv)
+ self.remote_nums[engine_id][p_remote_rank] = (
+ remote_end if not self.ds_v32 else remote_end * 2)
+ else:
+ ratio = self.pp_world_size // remote_pp_size
+ assert remote_pp_rank == int(nixl_agent_meta.pp_rank)
+ assert self.pp_rank // ratio == remote_pp_rank
+ assert (local_pp_layer_index_cum[self.pp_rank + 1]
+ <= remote_pp_layer_index_cum[remote_pp_rank + 1])
+
+ if (self.pp_rank + 1) % ratio == 0:
+ assert (local_pp_layer_index_cum[self.pp_rank + 1]
+ == remote_pp_layer_index_cum[
+ remote_pp_rank + 1])
+
+ remote_start = (
+ 0 if self.pp_rank == 0
+ else (local_pp_layer_index_cum[self.pp_rank]
+ - remote_pp_layer_index_cum[remote_pp_rank]))
+ remote_start = remote_start * self.is_kv
+ remote_end = (local_pp_layer_split[self.pp_rank]
+ * self.is_kv + remote_start)
+ self.remote_nums[engine_id][p_remote_rank] = (
+ remote_end - remote_start if not self.ds_v32
+ else (remote_end - remote_start) * 2)
+ else:
+ if self.pp_world_size > 1:
+ assert self.pp_split == nixl_agent_meta.pp_split, (
+ f"pp_split must match: {self.pp_split} != "
+ f"{nixl_agent_meta.pp_split}")
+
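+        # Illustrative example of the PP mapping above: a local (D) worker
+        # with pp_size=1 (pp_split "8") pulling from a remote (P) deployment
+        # with pp_size=2 and pp_split "4,4" records local region ranges
+        # [0, 4) for remote pp rank 0 and [4, 8) for remote pp rank 1 (each
+        # bound scaled by self.is_kv), while each remote rank only exposes
+        # its own 4 layers.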
+ kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=0)
+ read = True
+ local_rank_offset = 0
+ local_block_len = kv_block_len
+ remote_rank_offset = 0
+ default_local_xfer_dlist_key = (0, kv_block_len)
+ current_local_xfer_dlist = self.src_registered_xfer_side_headles[
+ default_local_xfer_dlist_key]
+ remote_block_len = nixl_agent_meta.block_lens[0]
+ block_len_ratio = 1
+
+ if self.use_mla or (local_replicates_kv_cache
+ and remote_replicates_kv_cache):
+ assert self.block_len_per_layer[0] == nixl_agent_meta.block_lens[0], (
+ "KV cache sizes must match between P and D when replicated"
+ )
+ if self.tp_world_size < remote_tp_size:
+ read = bool(remote_tp_rank % tp_ratio == 0)
+ else:
+ if self.tp_world_size >= remote_tp_size:
+ assert not remote_replicates_kv_cache
+ if not local_replicates_kv_cache:
+ remote_rank_offset = (self.tp_rank % tp_ratio
+ * local_block_len)
+ assert remote_block_len == (
+ self.block_len_per_layer[0] * tp_ratio)
+ else:
+ local_repeat_times = (self._tp_size[self.engine_id]
+ // total_num_kv_heads)
+ remote_rank_offset = (
+ (self.tp_rank // local_repeat_times)
+ % (tp_ratio // local_repeat_times)
+ * local_block_len)
+ assert remote_block_len == (
+ self.block_len_per_layer[0]
+ * (tp_ratio // local_repeat_times))
+ block_len_ratio = max(remote_block_len // local_block_len, 1)
+ else:
+ assert not local_replicates_kv_cache
+ if not remote_replicates_kv_cache:
+ assert remote_block_len == local_block_len // tp_ratio
+ local_rank_offset = (remote_block_len
+ * (remote_tp_rank % tp_ratio))
+ assert remote_block_len == (
+ self.block_len_per_layer[0] // tp_ratio)
+ else:
+ remote_repeat_times = (remote_tp_size
+ // total_num_kv_heads)
+ assert remote_block_len == (
+ local_block_len
+ // (tp_ratio // remote_repeat_times))
+ read = bool(
+ remote_tp_rank % remote_repeat_times == 0)
+ local_rank_offset = (
+ remote_block_len
+ * ((remote_tp_rank // remote_repeat_times)
+ % (tp_ratio // remote_repeat_times)))
+ assert remote_block_len == (
+ self.block_len_per_layer[0]
+ // (tp_ratio // remote_repeat_times))
+ local_block_len = remote_block_len
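+            # Rough picture of the offset math above: when the local TP size
+            # is at least the remote TP size, each remote block covers the KV
+            # heads of several local ranks, so this rank reads its slice at
+            # remote_rank_offset within the larger remote block. When the
+            # remote TP size is larger, each remote rank holds only a
+            # fraction of the local block and local_rank_offset places that
+            # remote slice at the right position locally.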
+
+ if not read:
+ local_rank_offset = 0
+
+ def register_new_local_xfer_dlist(offset, block_len):
+ bdata = []
+ for base_addr in self.kv_caches_base_addr[self.engine_id]:
+ for block_id in range(self.num_blocks):
+ boffset = (block_id
+ * self.get_backend_aware_kv_block_len(0))
+ addr = base_addr + boffset + offset
+ bdata.append((addr, block_len, self.global_rank))
+ descs = self.nixl_wrapper.get_xfer_descs(
+ bdata, self.nixl_memory_type)
+ return self.nixl_wrapper.prep_xfer_dlist(
+ "NIXL_INIT_AGENT", descs)
+
+ if (local_rank_offset, local_block_len) != default_local_xfer_dlist_key:
+ key = (local_rank_offset, local_block_len)
+ if key in self.src_registered_xfer_side_headles:
+ current_local_xfer_dlist = (
+ self.src_registered_xfer_side_headles[key])
+ else:
+ current_local_xfer_dlist = register_new_local_xfer_dlist(
+ local_rank_offset, local_block_len)
+ self.src_registered_xfer_side_headles[key] = (
+ current_local_xfer_dlist)
+
+ self.kv_caches_base_addr[engine_id] = (
+ nixl_agent_meta.kv_caches_base_addr[remote_start:remote_end])
+
+ for i, base_addr in enumerate(
+ nixl_agent_meta.kv_caches_base_addr):
+ if i < remote_start or i >= remote_end:
+ if self.ds_v32:
+ half = len(nixl_agent_meta.kv_caches_base_addr) // 2
+ if (i < remote_start + half
+ or i >= remote_end + half):
+ continue
+ else:
+ continue
for block_id in range(nixl_agent_meta.num_blocks):
- block_offset = block_id * nixl_agent_meta.block_lens[i]
- # For each block, grab the heads chunk belonging to rank_i
- # of size remote_nheads // tp_ratio, which correspond to
- # self.block_len == remote_block_len//tp_ratio bytes.
- addr = base_addr + block_offset + rank_offset
- # (addr, len, device id)
- blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id))
+ cur_layer_block_len = nixl_agent_meta.block_lens[i]
+ block_offset = block_id * cur_layer_block_len
+ addr = base_addr + block_offset + remote_rank_offset
+ blocks_data.append((
+ addr,
+ cur_layer_block_len // block_len_ratio,
+ nixl_agent_meta.device_id,
+ ))
if self.kv_topo.is_kv_layout_blocks_first:
- # With FlashInfer index V separately to allow head splitting.
- for block_id in range(nixl_agent_meta.num_blocks):
- block_offset = block_id * nixl_agent_meta.block_lens[i]
- addr = base_addr + block_offset + rank_offset
- v_addr = addr + nixl_agent_meta.block_lens[i] // 2
- blocks_data.append(
- (v_addr, local_block_len, nixl_agent_meta.device_id)
- )
+ raise NotImplementedError(
+ "FlashInfer not supported with PP yet")
logger.debug(
- "Created %s blocks for dst engine %s with remote rank %s and local rank %s",
- len(blocks_data),
- engine_id,
- remote_tp_rank,
- self.tp_rank,
+ "Created %s blocks for dst engine %s with remote rank %s "
+ "and local rank %s",
+ len(blocks_data), engine_id, p_remote_rank, self.global_rank,
)
- # Register with NIXL.
- descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type)
- self.dst_xfer_side_handles[engine_id][remote_tp_rank] = (
- self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)
+ descs = self.nixl_wrapper.get_xfer_descs(
+ blocks_data, self.nixl_memory_type)
+ if engine_id not in self.dst_xfer_side_handles:
+ self.dst_xfer_side_handles[engine_id] = {
+ p_remote_rank: self.nixl_wrapper.prep_xfer_dlist(
+ remote_agent_name, descs)}
+ self.src_xfer_side_handles[engine_id] = {
+ p_remote_rank: current_local_xfer_dlist}
+ else:
+ self.dst_xfer_side_handles[engine_id][p_remote_rank] = (
+ self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs))
+ self.src_xfer_side_handles[engine_id][p_remote_rank] = (
+ current_local_xfer_dlist)
+
+ logger.info(
+        "Created remote transfer info: "
+ "local: tp_rank %s, tp_size %s, pp_rank %s, pp_size %s, "
+ "remote: tp_rank %s, tp_size %s, pp_rank %s, pp_size %s, "
+ "local_rank_offset: %s, local_block_len: %s, "
+ "local start: %s, local end: %s, "
+ "remote_rank_offset %s, remote_block_len %s, "
+ "remote start: %s, remote end: %s, read: %s",
+ self.tp_rank, self.tp_world_size, self.pp_rank,
+ self.pp_world_size,
+ remote_tp_rank, remote_tp_size, remote_pp_rank,
+ remote_pp_size,
+ local_rank_offset, local_block_len,
+ self.local_start_index[engine_id][p_remote_rank],
+ self.local_end_index[engine_id][p_remote_rank],
+ remote_rank_offset, remote_block_len,
+ remote_start, remote_end, read,
)
- if block_size_ratio > 1:
- # when prefill with smaller block_size, we need to init a
- # new handler with same block_len to match
- self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = (
- self.register_local_xfer_handler(nixl_agent_meta.block_size)[0]
- )
-
- return remote_agent_name
+ return (remote_agent_name, read)
def _validate_remote_agent_handshake(
- self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int
+ self, nixl_agent_meta: NixlAgentMetadata,
+ remote_tp_size: int, remote_pp_size: int
):
- """
- Validate the remote agent handshake metadata ensuring the
- invariants hold true.
- """
+ """Validate the remote agent handshake metadata."""
remote_engine_id = nixl_agent_meta.engine_id
assert self._tp_size[remote_engine_id] == remote_tp_size
@@ -1761,8 +1934,6 @@ class NixlConnectorWorker:
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
remote_engine_id
)
- # Num kv_heads > tp_size and P TP > D TP case, not supported
- assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
kv_cache_layout = (
self.kv_cache_layout
@@ -1786,46 +1957,51 @@ class NixlConnectorWorker:
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
- # Block len can only vary across layers when using MLA.
remote_block_len = nixl_agent_meta.block_lens[0]
if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id):
- # With replicated KV cache, only the number of blocks can differ.
- for i in range(len(self.block_len_per_layer)):
- assert (
- self.block_len_per_layer[i] // block_size_ratio
- == nixl_agent_meta.block_lens[i]
- ), "KV cache sizes must match between P and D when replicated"
+ local_pp_split = [int(i) for i in self.pp_split.split(",")]
+ local_pp_layer_start = sum(local_pp_split[:self.pp_rank])
+ local_pp_layers = local_pp_split[self.pp_rank]
+ local_pp_layer_end = local_pp_layer_start + local_pp_layers
+
+ remote_pp_split_vals = [
+ int(i) for i in nixl_agent_meta.pp_split.split(",")]
+ rpp_rank = nixl_agent_meta.pp_rank
+ remote_pp_layer_start = sum(remote_pp_split_vals[:rpp_rank])
+ remote_pp_layers = remote_pp_split_vals[rpp_rank]
+ remote_pp_layer_end = remote_pp_layer_start + remote_pp_layers
+
+ for i in range(local_pp_layer_start, local_pp_layer_end):
+ if i >= remote_pp_layer_start and i < remote_pp_layer_end:
+ assert (
+ self.block_len_per_layer[i - local_pp_layer_start]
+ // block_size_ratio
+ == nixl_agent_meta.block_lens[
+ i - remote_pp_layer_start]
+ ), "KV cache sizes must match between P and D"
+ if self.ds_v32:
+ assert (
+ self.block_len_per_layer[
+ local_pp_layers + i - local_pp_layer_start]
+ == nixl_agent_meta.block_lens[
+ remote_pp_layers + i - remote_pp_layer_start]
+ ), "indexer_key_cache sizes must match"
else:
- # When MLA is not used, this is a list of the same block length
for block_len in nixl_agent_meta.block_lens:
assert block_len == remote_block_len, (
"All remote layers must have the same block size"
)
+ assert (
+ remote_block_len
+ == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
+ ), (
+ "Remote P worker KV layer cache shape mismatch."
+ )
- if tp_ratio > 0:
- # Remote tp is smaller: remote block_len size is bigger
- assert (
- remote_block_len
- == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio
- ), (
- "Remote P worker KV layer cache must be of shape [2, N, "
- "local_kv_heads*tp_ratio, page_size, head_dim] and same dtype."
- ) # noqa: E501
- else:
- assert block_size_ratio == 1, (
- "Different local/remote block sizes are not supported when"
- " P TP > D TP."
- )
- # Remote tp is bigger: remote block_len size is smaller
- assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), (
- "Remote P worker KV layer cache must be of shape [2, N, "
- "local_kv_heads/tp_ratio, page_size, head_dim] and same dtype."
- ) # noqa: E501
-
- # TP workers that handhshake with same remote have same #blocks.
assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks
- # Same number of regions/~layers.
- assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer)
+ if self.pp_world_size == remote_pp_size:
+ assert (len(nixl_agent_meta.kv_caches_base_addr)
+ == len(self.block_len_per_layer))
def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta):
"""copy recved kv from host buffer to device."""
@@ -1959,13 +2135,12 @@ class NixlConnectorWorker:
# clean up metadata for completed requests
meta = self._recving_metadata.pop(req_id, None)
assert meta is not None, f"{req_id} not found in recving_metadata list"
- assert meta.remote is not None
if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta)
# post processing for heteroblocksize
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
- meta.remote.engine_id
+ meta.remote_engine_id
)
if not self.use_mla and (
block_size_ratio > 1 or self.enable_permute_local_kv
@@ -2004,14 +2179,28 @@ class NixlConnectorWorker:
def _get_new_notifs(self) -> set[str]:
"""
Get req_ids which got a remote xfer message. When multiple consumers
- are reading from the same producer (heterogeneous TP scenario), wait
- for all consumers to be done pulling.
+ are reading from the same producer (heterogeneous TP/PP scenario),
+ wait for all consumers to be done pulling.
"""
assert self.kv_topo is not None
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
- req_id, tp_size = notif.decode("utf-8").rsplit(":", 1)
+ parts = notif.decode("utf-8").rsplit(":", 2)
+ if len(parts) == 3:
+ req_id, tp_ratio_str, pp_ratio_str = parts
+ consumers_per_producer = (
+ int(tp_ratio_str) * int(pp_ratio_str))
+ elif len(parts) == 2:
+ req_id, tp_size_str = parts
+ n_consumers = int(tp_size_str)
+ tp_ratio = self.kv_topo.tp_ratio(n_consumers)
+ consumers_per_producer = (
+ -tp_ratio if n_consumers > self.world_size else 1)
+ else:
+ logger.warning("Unexpected notif format: %s", notif)
+ continue
+
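+                # e.g. a notif b"req-42:2:1" decodes to tp_ratio=2 and
+                # pp_ratio=1, so this producer waits for 2 * 1 = 2 consumer
+                # reads of req-42 before freeing the request's blocks.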
if (
req_id not in self._reqs_to_send
and req_id not in self._reqs_to_process
@@ -2024,18 +2213,7 @@ class NixlConnectorWorker:
)
continue
- # NOTE: `tp_ratio` is the opposite when swapping local<>remote
- n_consumers = int(tp_size)
- tp_ratio = self.kv_topo.tp_ratio(n_consumers)
-
- # Number of reads *per producer* to wait for.
- # When remote D TP > local P TP we expect `tp_ratio` reads.
- consumers_per_producer = (
- -tp_ratio if n_consumers > self.world_size else 1
- )
-
self.consumer_notification_counts_by_req[req_id] += 1
- # Wait all consumers (D) to be done reading before freeing.
if (
self.consumer_notification_counts_by_req[req_id]
== consumers_per_producer
@@ -2046,27 +2224,28 @@ class NixlConnectorWorker:
self._reqs_to_send.pop(req_id, None)
return notified_req_ids
- def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
+ def _pop_done_transfers(
+ self, transfers: dict[str, list[tuple[int, float]]]
+ ) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
- transfers: dict of req_id -> list[running_xfer]
+ transfers: dict of req_id -> list[(handle, start_time)]
Returns:
set of req_ids that have all done xfers
"""
done_req_ids: set[str] = set()
- for req_id, handles in list(transfers.items()):
+ for req_id, handle_list in list(transfers.items()):
in_progress = []
- for handle in handles:
+ for handle, start_time in handle_list:
try:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
- # Get telemetry from NIXL
res = self.nixl_wrapper.get_xfer_telemetry(handle)
self.xfer_stats.record_transfer(res)
self.nixl_wrapper.release_xfer_handle(handle)
elif xfer_state == "PROC":
- in_progress.append(handle)
+ in_progress.append((handle, start_time))
continue
else:
self._log_failure(
@@ -2167,63 +2346,16 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
- assert meta.remote is not None and self.kv_topo is not None
- remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
- meta.remote.engine_id
+ logger.debug(
+ "Remote agent %s available, calling _read_blocks for req %s",
+ meta.remote_engine_id, req_id,
+ )
+ self._read_blocks(
+ request_id=req_id,
+ dst_engine_id=meta.remote_engine_id,
+ local_block_ids=meta.local_physical_block_ids,
+ remote_block_ids=meta.remote_block_ids,
)
- tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id)
- # D may have to perform multiple reads from different remote ranks.
- for i, remote_rank in enumerate(remote_ranks):
- if self.use_mla and tp_ratio < 0 and i > 0:
- # MLA opt: when P TP > D TP, only a single read is executed for
- # the first remote rank (cache is duplicated)..
- break
-
- remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id]
- logger.debug(
- "Remote agent %s available, calling _read_blocks"
- " on remote rank %s with remote block size %s for req %s",
- meta.remote.engine_id,
- remote_rank,
- remote_block_size,
- req_id,
- )
- # Get side handles.
- if tp_ratio < 0 and not self.use_mla:
- assert remote_block_size == self.block_size
- # Remote tp_size > local tp_size: we must perform multiple
- # reads. Get the memory chunk onto which we will write to.
- local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i]
- else:
- # Single read from remote, we write to the whole memory region.
- # Also handle remote block size different from local block size.
- local_xfer_side_handle = self.src_xfer_handles_by_block_size[
- remote_block_size
- ]
-
- # Destination handle: remote_engine_id -> remote_rank -> handle.
- remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][
- remote_rank
- ]
- self._read_blocks(
- request_id=req_id,
- dst_engine_id=meta.remote.engine_id,
- remote_request_id=meta.remote.request_id,
- local_block_ids=meta.local_physical_block_ids,
- remote_block_ids=meta.remote.block_ids,
- remote_rank=remote_rank,
- local_xfer_side_handle=local_xfer_side_handle,
- remote_xfer_side_handle=remote_xfer_side_handle,
- )
-
- if self.use_mla and tp_ratio < 0:
- # ..but we still need to notify the other remote ranks that we
- # have the blocks we need so they can update the request state.
- notif_id = f"{req_id}:{self.world_size}".encode()
- remote_agents = self._remote_agents[meta.remote.engine_id]
- for rank_to_notify, agent in remote_agents.items():
- if rank_to_notify != remote_rank:
- self.nixl_wrapper.send_notif(agent, notif_msg=notif_id)
def _read_blocks(
self,
@@ -2231,165 +2363,147 @@ class NixlConnectorWorker:
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
- remote_request_id: str,
- remote_rank: int,
- local_xfer_side_handle: int,
- remote_xfer_side_handle: int,
):
"""
- Post a READ point-to-point xfer request from a single local worker to
- a single remote worker.
+        Post READ xfer requests from the local worker to remote workers.
+ With PP, may read from multiple remote PP ranks.
"""
assert self.kv_topo is not None
- block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
+ block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
+ dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
np.asarray(local_block_ids), block_size_ratio
)
if len(local_block_ids) > len(remote_block_ids):
- # NOTE:
- # get_mapped_blocks will always expand block_ids for n times.
- # ex:
- # prefill block_ids with block_size as 4:
- # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
- # Local decode block_ids with block_size as 16: [1, 2, 3]
- # expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to
- # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
- # Then we clip local to align with prefill
- # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to
- # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
- local_block_ids = local_block_ids[: len(remote_block_ids)]
- # NOTE(rob): having the staging blocks be on the READER side is
- # not going to work well (since we will have to call rearrange tensors).
- # after we detect the txn is complete (which means we cannot make the
- # read trxn async easily). If we want to make "READ" happen cleanly,
- # then we will need to have the staging blocks on the remote side.
+ local_block_ids = local_block_ids[:len(remote_block_ids)]
- # NOTE(rob): according to nvidia the staging blocks are used to
- # saturate IB with heterogeneous TP sizes. We should remove the staging
- # blocks until we are ready.
-
- # Number of D TP workers that will read from dst P. Propagate info
- # on notification so that dst worker can wait before freeing blocks.
- notif_id = f"{remote_request_id}:{self.world_size}".encode()
-
- # Full prefix cache hit: do not need to read remote blocks,
- # just notify P worker that we have the blocks we need.
- num_local_blocks = len(local_block_ids)
- if num_local_blocks == 0:
- agent_name = self._remote_agents[dst_engine_id][remote_rank]
- try:
- self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
- except Exception as e:
- self._log_failure(
- failure_type="notification_failed",
- msg="P worker blocks will be freed after timeout. "
- "This may indicate network issues.",
- req_id=request_id,
- error=e,
- dst_engine_id=dst_engine_id,
- remote_rank=remote_rank,
- remote_agent_name=agent_name,
- )
- self.xfer_stats.record_failed_notification()
- return
-
- # Partial prefix cache hit: just read uncomputed blocks.
- num_remote_blocks = len(remote_block_ids)
- assert num_local_blocks <= num_remote_blocks
- if num_local_blocks < num_remote_blocks:
- remote_block_ids = remote_block_ids[-num_local_blocks:]
-
- # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
- # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
- # workers will issue xfers to parts of the P worker remote kv caches.
-
- # Get descs ids.
- local_block_descs_ids: np.ndarray
- remote_block_descs_ids: np.ndarray
-
- if not self.block_window_per_layer:
- # Default case: assume global attention
- remote_block_descs_ids = self._get_block_descs_ids(
- dst_engine_id,
- remote_block_ids,
- )
- local_block_descs_ids = self._get_block_descs_ids(
- self.engine_id,
- local_block_ids,
- block_size_ratio=block_size_ratio,
- )
+ if self._tp_size[self.engine_id] > self._tp_size[dst_engine_id]:
+ tp_ratio = (self._tp_size[self.engine_id]
+ // self._tp_size[dst_engine_id])
else:
- # TODO(mgoin): remove this once we have hybrid memory allocator
- # Optimization for models with local attention (Llama 4)
- local_descs_list = []
- remote_descs_list = []
- for layer_idx, block_window in enumerate(self.block_window_per_layer):
- # For each layer:
- if block_window is None:
- # If not chunked, we just use the
- # full block lists (global attention)
- layer_local_block_ids = local_block_ids
- layer_remote_block_ids = remote_block_ids
+ tp_ratio = 1
+
+ if self._pp_size[self.engine_id] > self._pp_size[dst_engine_id]:
+ pp_ratio = (self._pp_size[self.engine_id]
+ // self._pp_size[dst_engine_id])
+ else:
+ pp_ratio = 1
+
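+        # The notification payload is "request_id:tp_ratio:pp_ratio", so the
+        # producer side (_get_new_notifs) knows how many consumer reads to
+        # wait for before freeing the request's blocks.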
+ notif_id = f"{request_id}:{tp_ratio}:{pp_ratio}".encode()
+
+ remote_infos = self._remote_agents[dst_engine_id]
+ local_side_handles = self.src_xfer_side_handles[dst_engine_id]
+ remote_side_handles = self.dst_xfer_side_handles[dst_engine_id]
+
+ for remote_index, (remote_agent_name, read) in remote_infos.items():
+ num_local_blocks = len(local_block_ids)
+ if num_local_blocks == 0:
+ try:
+ self.nixl_wrapper.send_notif(
+ remote_agent_name, notif_msg=notif_id)
+ except Exception as e:
+ self._log_failure(
+ failure_type="notification_failed",
+ msg="P worker blocks will be freed after timeout.",
+ req_id=request_id,
+ error=e,
+ dst_engine_id=dst_engine_id,
+ )
+ self.xfer_stats.record_failed_notification()
+ continue
+
+ if read:
+ num_remote_blocks = len(remote_block_ids)
+ assert num_local_blocks <= num_remote_blocks
+ if num_local_blocks < num_remote_blocks:
+ remote_block_ids_used = remote_block_ids[
+ -num_local_blocks:]
else:
- # If chunked, get the last block_window blocks
- layer_local_block_ids = local_block_ids[-block_window:]
- layer_remote_block_ids = remote_block_ids[-block_window:]
+ remote_block_ids_used = remote_block_ids
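+                # Partial prefix-cache hit: only the last num_local_blocks
+                # remote blocks are still needed, e.g. 10 remote blocks but
+                # 6 uncomputed local blocks -> read remote_block_ids[-6:].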
- # Get descs ids for the layer.
- layer_local_desc_ids = self._get_block_descs_ids(
- self.engine_id,
- layer_local_block_ids,
- layer_idx,
- block_size_ratio=block_size_ratio,
- )
- layer_remote_desc_ids = self._get_block_descs_ids(
- dst_engine_id,
- layer_remote_block_ids,
- layer_idx,
- )
+ local_xfer_side_handle = local_side_handles[remote_index]
+ remote_xfer_side_handle = remote_side_handles[remote_index]
- local_descs_list.append(layer_local_desc_ids)
- remote_descs_list.append(layer_remote_desc_ids)
+ local_block_descs_ids: np.ndarray
+ remote_block_descs_ids: np.ndarray
- local_block_descs_ids = np.concatenate(local_descs_list)
- remote_block_descs_ids = np.concatenate(remote_descs_list)
+ if not self.block_window_per_layer:
+ remote_block_descs_ids = self._get_block_descs_ids(
+ dst_engine_id, remote_block_ids_used,
+ is_local=False,
+ remote_engine_id=dst_engine_id,
+ remote_index=remote_index,
+ )
+ local_block_descs_ids = self._get_block_descs_ids(
+ self.engine_id, local_block_ids,
+ is_local=True,
+ remote_engine_id=dst_engine_id,
+ remote_index=remote_index,
+ block_size_ratio=block_size_ratio,
+ )
+ else:
+ local_descs_list = []
+ remote_descs_list = []
+ for layer_idx, block_window in enumerate(
+ self.block_window_per_layer):
+ if block_window is None:
+ layer_local_block_ids = local_block_ids
+ layer_remote_block_ids = remote_block_ids_used
+ else:
+ layer_local_block_ids = local_block_ids[
+ -block_window:]
+ layer_remote_block_ids = remote_block_ids_used[
+ -block_window:]
- assert len(local_block_descs_ids) == len(remote_block_descs_ids)
+ layer_local_desc_ids = self._get_block_descs_ids(
+ self.engine_id, layer_local_block_ids, layer_idx,
+ block_size_ratio=block_size_ratio,
+ )
+ layer_remote_desc_ids = self._get_block_descs_ids(
+ dst_engine_id, layer_remote_block_ids, layer_idx,
+ )
+ local_descs_list.append(layer_local_desc_ids)
+ remote_descs_list.append(layer_remote_desc_ids)
- # Prepare transfer with Nixl.
- handle = None
- try:
- handle = self.nixl_wrapper.make_prepped_xfer(
- "READ",
- local_xfer_side_handle,
- local_block_descs_ids,
- remote_xfer_side_handle,
- remote_block_descs_ids,
- notif_msg=notif_id,
- )
+ local_block_descs_ids = np.concatenate(local_descs_list)
+ remote_block_descs_ids = np.concatenate(
+ remote_descs_list)
- # Begin async xfer.
- self.nixl_wrapper.transfer(handle)
+ assert len(local_block_descs_ids) == len(
+ remote_block_descs_ids)
- # Use handle to check completion in future step().
- self._recving_transfers[request_id].append(handle)
- except Exception as e:
- # mark all (logical) blocks for this request as invalid
- self._log_failure(
- failure_type="transfer_setup_failed",
- req_id=request_id,
- msg="Marking blocks as invalid",
- error=e,
- dst_engine_id=dst_engine_id,
- remote_rank=remote_rank,
- )
- if meta := self._recving_metadata.get(request_id):
- self._invalid_block_ids.update(meta.local_block_ids)
- self.xfer_stats.record_failed_transfer()
- if handle is not None:
- self.nixl_wrapper.release_xfer_handle(handle)
- self._failed_recv_reqs.add(request_id)
+ handle = None
+ try:
+ handle = self.nixl_wrapper.make_prepped_xfer(
+ "READ",
+ local_xfer_side_handle,
+ local_block_descs_ids,
+ remote_xfer_side_handle,
+ remote_block_descs_ids,
+ notif_msg=notif_id,
+ )
+ self.nixl_wrapper.transfer(handle)
+
+ self._recving_transfers[request_id].append(
+ (handle, time.perf_counter()))
+ except Exception as e:
+ self._log_failure(
+ failure_type="transfer_setup_failed",
+ req_id=request_id,
+ msg="Marking blocks as invalid",
+ error=e,
+ dst_engine_id=dst_engine_id,
+ )
+ if meta := self._recving_metadata.get(request_id):
+ self._invalid_block_ids.update(meta.local_block_ids)
+ self.xfer_stats.record_failed_transfer()
+ if handle is not None:
+ self.nixl_wrapper.release_xfer_handle(handle)
+ self._failed_recv_reqs.add(request_id)
+ else:
+ self.nixl_wrapper.send_notif(
+ remote_agent_name, notif_msg=notif_id)
def get_mapped_blocks(self, block_ids, block_size_ratio):
"""
@@ -2415,25 +2529,52 @@ class NixlConnectorWorker:
engine_id: str,
block_ids: list[int],
layer_idx: int | None = None,
+ is_local: bool = True,
+ remote_engine_id: str | None = None,
+ remote_index: int | None = None,
block_size_ratio: float | None = None,
) -> np.ndarray:
"""
Get the descs ids for a set of block ids.
- If layer_idx is provided, we use the region_ids for the given layer.
- Otherwise, we use all regions.
+ With PP support, uses local_start_index/local_end_index for region
+ selection when reading from specific remote PP ranks.
"""
if layer_idx is None:
- region_ids = np.arange(self.num_regions)
+ if is_local:
+ if (remote_engine_id is not None
+ and remote_index is not None
+ and remote_engine_id in self.local_start_index
+ and remote_index in self.local_start_index[
+ remote_engine_id]):
+ si = self.local_start_index[remote_engine_id][
+ remote_index]
+ ei = self.local_end_index[remote_engine_id][
+ remote_index]
+ if isinstance(si, tuple):
+ region_ids = np.arange(0)
+ for start, end in zip(si, ei):
+ region_ids = np.concatenate(
+ [region_ids, np.arange(start, end)])
+ else:
+ region_ids = np.arange(si, ei)
+ else:
+ region_ids = np.arange(self.num_regions)
+ else:
+ if (remote_engine_id is not None
+ and remote_index is not None
+ and remote_engine_id in self.remote_nums
+ and remote_index in self.remote_nums[
+ remote_engine_id]):
+ region_ids = np.arange(
+ self.remote_nums[remote_engine_id][remote_index])
+ else:
+ region_ids = np.arange(self.num_regions)
else:
assert layer_idx < self.num_layers
if self.num_layers < self.num_regions:
- # If we have more regions than layers, we assume that
- # the regions are organized as [K0, V0, K1, V1, ...]
- # and we select K_i and V_i
assert 2 * self.num_layers == self.num_regions
region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2)
else:
- # Otherwise, we assume we have MLA and select i-th layer
assert self.num_layers == self.num_regions
region_ids = np.arange(layer_idx, layer_idx + 1)
@@ -2441,10 +2582,9 @@ class NixlConnectorWorker:
if block_size_ratio is not None:
num_blocks = int(num_blocks * block_size_ratio)
- # Compute the desc ids for each block.
region_ids = region_ids[:, None]
- block_ids = np.array(block_ids)[None, :]
- descs_ids = region_ids * num_blocks + block_ids
+ block_ids_arr = np.array(block_ids)[None, :]
+ descs_ids = region_ids * num_blocks + block_ids_arr
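+        # e.g. region_ids=[0, 1], num_blocks=100, block_ids=[3, 7] yields
+        # descriptor ids [3, 7, 103, 107]: one entry per (region, block).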
return descs_ids.flatten()
def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]:
@@ -2506,21 +2646,24 @@ class NixlConnectorWorker:
def shutdown(self):
"""Shutdown the connector worker."""
+ if not hasattr(self, "_handshake_initiation_executor"):
+            # An error occurred during init; there is nothing to shut down.
+ return
self._handshake_initiation_executor.shutdown(wait=False)
- for handles in self._recving_transfers.values():
- for handle in handles:
+ for handle_list in self._recving_transfers.values():
+ for handle, _ in handle_list:
self.nixl_wrapper.release_xfer_handle(handle)
self._recving_transfers.clear()
- for handle in self.src_xfer_handles_by_block_size.values():
+ for handle in self.src_registered_xfer_side_headles.values():
self.nixl_wrapper.release_dlist_handle(handle)
- self.src_xfer_handles_by_block_size.clear()
- for handles in self.src_xfer_handles_by_tp_ratio.values():
- for handle in handles:
+ self.src_registered_xfer_side_headles.clear()
+ for src_handles in self.src_xfer_side_handles.values():
+ for handle in src_handles.values():
+ self.nixl_wrapper.release_dlist_handle(handle)
+ self.src_xfer_side_handles.clear()
+ for dst_handles in self.dst_xfer_side_handles.values():
+ for handle in dst_handles.values():
self.nixl_wrapper.release_dlist_handle(handle)
- self.src_xfer_handles_by_tp_ratio.clear()
- for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
- for dst_xfer_side_handle in dst_xfer_side_handles.values():
- self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle)
self.dst_xfer_side_handles.clear()
for remote_agents in self._remote_agents.values():
for agent_name in remote_agents.values():
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 61646dd..ee170d4 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline
import contextlib
import gc
+import os
import pickle
import weakref
from collections import namedtuple
@@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
-from typing import Any, Protocol
+from typing import TYPE_CHECKING, Any, Protocol
from unittest.mock import patch
import torch
@@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout
from vllm.utils.torch_utils import (
direct_register_custom_op,
)
+
+if TYPE_CHECKING:
+ from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
+
import ixformer.distributed as ixfd
import vllm._custom_ops as ops
@@ -327,6 +332,8 @@ class GroupCoordinator:
self.rank = torch.distributed.get_rank()
self.local_rank = local_rank
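+        # Prefer the ixformer (ixfd) communicator for this group unless the
+        # user forces the plain torch/NCCL group via VLLM_FORCE_NCCL_COMM.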
+ use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", "Y", "y"}
+
self_device_group = None
self_cpu_group = None
@@ -339,7 +346,7 @@ class GroupCoordinator:
with suppress_stdout():
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
if self.rank in ranks:
- self.ixfd_group = ixfd.init_comm_with_store(device_group)
+ self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None
self.ranks = ranks
self.world_size = len(ranks)
self.rank_in_group = ranks.index(self.rank)
@@ -372,8 +379,7 @@ class GroupCoordinator:
self.device_communicator = device_comm_cls(
cpu_group=self.cpu_group,
device=self.device,
- # device_group=self.device_group,
- device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group,
+ device_group=self.ixfd_group if use_vllm_comm else self.device_group,
unique_name=self.unique_name,
)
@@ -385,11 +391,6 @@ class GroupCoordinator:
self.cpu_group, 1 << 22, 6
)
- from vllm.platforms import current_platform
-
- # self.use_custom_op_call = (
- # current_platform.is_cuda_alike() or current_platform.is_tpu()
- # )
self.use_custom_op_call = False
self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
@@ -468,14 +469,12 @@ class GroupCoordinator:
# only cuda uses this function,
# so we don't abstract it into the base class
maybe_ca_context = nullcontext()
- # from vllm.distributed.device_communicators.cuda_communicator import (
- # CudaCommunicator,
- # )
- from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase
+ from vllm.distributed.device_communicators.cuda_communicator import (
+ CudaCommunicator,
+ )
if self.device_communicator is not None:
- # assert isinstance(self.device_communicator, CudaCommunicator)
- assert isinstance(self.device_communicator, DeviceCommunicatorBase)
+ assert isinstance(self.device_communicator, CudaCommunicator)
ca_comm = self.device_communicator.ca_comm
if ca_comm is not None:
maybe_ca_context = ca_comm.capture() # type: ignore
@@ -608,9 +607,9 @@ class GroupCoordinator:
src=self.ranks[src],
group=self.device_group)
else:
- torch.distributed.broadcast(input_,
- src=self.ranks[src],
- group=self.device_group)
+ torch.distributed.broadcast(
+ input_, src=self.ranks[src], group=self.device_group
+ )
return input_
def broadcast_object(self, obj: Any | None = None, src: int = 0):
@@ -764,10 +763,9 @@ class GroupCoordinator:
group=group,
async_op=True)
else:
- handle = torch.distributed.broadcast(tensor,
- src=self.ranks[src],
- group=group,
- async_op=True)
+ handle = torch.distributed.broadcast(
+ tensor, src=self.ranks[src], group=group, async_op=True
+ )
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
@@ -802,10 +800,8 @@ class GroupCoordinator:
async_op=True)
else:
handle = torch.distributed.broadcast(
- tensor,
- src=self.ranks[src],
- group=group,
- async_op=True)
+ tensor, src=self.ranks[src], group=group, async_op=True
+ )
async_handles.append(handle)
tensor_dict[key] = tensor
else:
@@ -876,6 +872,10 @@ class GroupCoordinator:
if self.world_size <= 1:
return []
+ if dst is None:
+ dst = (self.rank_in_group + 1) % self.world_size
+ assert dst < self.world_size, f"Invalid dst rank ({dst})"
+
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -893,10 +893,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
- if dst is None:
- dst = (self.rank_in_group + 1) % self.world_size
- assert dst < self.world_size, f"Invalid dst rank ({dst})"
-
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
self.send_object(metadata_list, dst=dst)
@@ -917,6 +913,7 @@ class GroupCoordinator:
handle = torch.distributed.isend(
tensor, dst=self.ranks[dst], group=comm_group
)
+
if tensor.is_cuda:
tensor.record_stream(torch.cuda.current_stream(tensor.device))
handles.append(handle)
@@ -973,6 +970,11 @@ class GroupCoordinator:
]:
if not torch.distributed.is_initialized() or self.world_size == 1:
return None, [], []
+
+ if src is None:
+ src = (self.rank_in_group - 1) % self.world_size
+ assert src < self.world_size, f"Invalid src rank ({src})"
+
if self.use_cpu_custom_send_recv:
if self.device_communicator is None:
raise ValueError("No device communicator found")
@@ -990,10 +992,6 @@ class GroupCoordinator:
group = self.device_group
metadata_group = self.cpu_group
- if src is None:
- src = (self.rank_in_group - 1) % self.world_size
- assert src < self.world_size, f"Invalid src rank ({src})"
-
recv_metadata_list = self.recv_object(src=src)
tensor_dict: dict[str, Any] = {}
handles: list[Handle] = []
@@ -1072,14 +1070,13 @@ class GroupCoordinator:
return self.device_communicator.recv(size, dtype, src)
def destroy(self):
- if hasattr(self, "device_group"):
- # torch.distributed.destroy_process_group(self.device_group)
+ if self.device_group is not None:
if self.device_communicator and self.device_communicator.use_vllm_comm:
ixfd.destroy_process_group(self.device_group)
else:
torch.distributed.destroy_process_group(self.device_group)
- del self.device_group
- if hasattr(self, "cpu_group"):
+ self.device_group = None
+ if self.cpu_group is not None:
torch.distributed.destroy_process_group(self.cpu_group)
del self.cpu_group
if self.device_communicator is not None:
@@ -1094,7 +1091,6 @@ class GroupCoordinator:
def dispatch_router_logits(
self,
hidden_states: torch.Tensor,
- extra_residual:torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
@@ -1105,13 +1101,12 @@ class GroupCoordinator:
if self.device_communicator is not None:
return self.device_communicator.dispatch_router_logits(
hidden_states,
- extra_residual,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
- return hidden_states, extra_residual, router_logits
+ return hidden_states, router_logits
def dispatch(
self,
@@ -1189,6 +1184,55 @@ def init_model_parallel_group(
)
+def _init_stateless_group(
+ group_ranks: list[list[int]],
+ group_name: str,
+ group_ports: list[list[int]],
+ host: str,
+ backend: str,
+ use_device_communicator: bool = True,
+) -> "StatelessGroupCoordinator":
+ """Create a StatelessGroupCoordinator with the given parameters."""
+ from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
+
+ world = get_world_group()
+ return StatelessGroupCoordinator(
+ group_ranks=group_ranks,
+ local_rank=world.local_rank,
+ torch_distributed_backend=backend,
+ use_device_communicator=use_device_communicator,
+ group_name=group_name,
+ host=host,
+ group_ports=group_ports,
+ global_rank=world.rank,
+ global_world_size=world.world_size,
+ )
+
+
+def _replace_active_groups(
+ *,
+ world: GroupCoordinator | None,
+ dp: GroupCoordinator | None,
+ ep: GroupCoordinator | None,
+ eplb: GroupCoordinator | None,
+ node_count: int | None,
+) -> None:
+ """Destroy the current DP/EP/WORLD/EPLB groups and replace them.
+
+ Destruction is collective — all ranks in the old groups must call this
+ function together. Pass all-``None`` to tear down without replacement.
+ """
+ global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
+ for group in (_DP, _EP, _WORLD, _EPLB):
+ if group is not None:
+ group.destroy()
+ _WORLD = world
+ _DP = dp
+ _EP = ep
+ _EPLB = eplb
+ _NODE_COUNT = node_count
+
+
_TP: GroupCoordinator | None = None
@@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool):
_ENABLE_CUSTOM_ALL_REDUCE = enable
+def _init_elastic_ep_world(
+ config, local_rank: int, backend: str, rank: int, world_size: int
+) -> None:
+ from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator
+
+ global _WORLD, _NODE_COUNT
+ assert _WORLD is None, "world group already initialized"
+ parallel_config = config.parallel_config
+ global_rank = parallel_config.data_parallel_rank * world_size + rank
+ global_world_size = parallel_config.world_size_across_dp
+ all_ranks = list(range(global_world_size))
+ group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
+ if global_rank in all_ranks:
+ group_ranks = [all_ranks]
+ group_ports = [parallel_config.get_next_stateless_world_group_port()]
+ world = StatelessGroupCoordinator(
+ group_ranks=group_ranks,
+ local_rank=local_rank,
+ torch_distributed_backend=backend,
+ use_device_communicator=False,
+ group_name="world",
+ host=parallel_config.data_parallel_master_ip,
+ group_ports=group_ports,
+ global_rank=global_rank,
+ global_world_size=global_world_size,
+ )
+ assert parallel_config.nnodes_within_dp == 1, (
+ "Elastic EP is not supported with multi-node TP/PP"
+ )
+ _NODE_COUNT = _node_count(world.tcp_store_group)
+ _WORLD = world
+
+
def init_distributed_environment(
world_size: int = -1,
rank: int = -1,
@@ -1305,6 +1382,7 @@ def init_distributed_environment(
from vllm.config import get_current_vllm_config_or_none
config = get_current_vllm_config_or_none()
+ enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
if (
config is not None
and config.parallel_config.distributed_executor_backend != "external_launcher"
@@ -1312,6 +1390,7 @@ def init_distributed_environment(
config.parallel_config.nnodes > 1
or config.parallel_config.data_parallel_size > 1
)
+ and not enable_elastic_ep
):
parallel_config = config.parallel_config
# adjust to take into account data parallelism
@@ -1365,6 +1444,18 @@ def init_distributed_environment(
rank=rank,
timeout=timeout,
)
+ if enable_elastic_ep:
+ tp_pp_cpu_group = torch.distributed.new_group(
+ backend="gloo", timeout=timeout
+ )
+ if _node_count(tp_pp_cpu_group) > 1:
+ # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
+ # to initialize all DP/EP groups, hence all ranks within TP/PP group
+ # must reside on the same node
+ raise RuntimeError(
+ "Elastic EP is not yet supported with multi-node TP/PP"
+ )
+
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816
@@ -1373,6 +1464,9 @@ def init_distributed_environment(
# setting, where we can use rank as local rank
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
+ if enable_elastic_ep:
+ _init_elastic_ep_world(config, local_rank, backend, rank, world_size)
+ return
if _WORLD is None:
ranks = list(range(torch.distributed.get_world_size()))
_WORLD = init_world_group(ranks, local_rank, backend)
@@ -1436,16 +1530,33 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
- world_size: int = torch.distributed.get_world_size()
- rank = torch.distributed.get_rank()
- backend = backend or torch.distributed.get_backend(get_world_group().device_group)
- data_parallel_size = 1
- from vllm.config import get_current_vllm_config_or_none
+ from vllm.config import get_current_vllm_config
- config = get_current_vllm_config_or_none()
- if config is not None:
- data_parallel_size = config.parallel_config.data_parallel_size
+ config = get_current_vllm_config()
+ data_parallel_size = config.parallel_config.data_parallel_size
+ enable_elastic_ep = config.parallel_config.enable_elastic_ep
+ if enable_elastic_ep:
+ # Use stateless world group for global information
+ world_size = get_world_group().world_size
+ rank = get_world_group().rank
+ backend = backend or "nccl"
+ tp_pp_pcp_size = (
+ tensor_model_parallel_size
+ * pipeline_model_parallel_size
+ * prefill_context_model_parallel_size
+ )
+ local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
+ pipeline_model_parallel_size,
+ prefill_context_model_parallel_size,
+ tensor_model_parallel_size,
+ )
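+        # With elastic EP, the TP/PP/DCP/PCP groups below are built from
+        # these node-local ranks, while the DP/EP/EPLB groups are created as
+        # stateless groups so they can later be re-formed with a different
+        # set of participants.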
+ else:
+ world_size = torch.distributed.get_world_size()
+ rank = torch.distributed.get_rank()
+ backend = backend or torch.distributed.get_backend(
+ get_world_group().device_group
+ )
# the layout order is: ExternalDP x DP x PP x TP
# ExternalDP is the data parallel group that is not part of the model,
@@ -1469,7 +1580,9 @@ def initialize_model_parallel(
assert _TP is None, "tensor model parallel group is already initialized"
group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
-
+ if enable_elastic_ep:
+ group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
+ group_ranks = [x.tolist() for x in group_ranks]
# message queue broadcaster is only used in tensor model parallel group
_TP = init_model_parallel_group(
group_ranks,
@@ -1488,6 +1601,11 @@ def initialize_model_parallel(
# TP group into tp_size//dcp_size DCP groups.
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
+ if enable_elastic_ep:
+ group_ranks = local_all_ranks.reshape(
+ -1, decode_context_model_parallel_size
+ ).unbind(0)
+ group_ranks = [x.tolist() for x in group_ranks]
_DCP = init_model_parallel_group(
group_ranks,
get_world_group().local_rank,
@@ -1504,6 +1622,13 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
+ if enable_elastic_ep:
+ group_ranks = (
+ local_all_ranks.transpose(1, 2)
+ .reshape(-1, prefill_context_model_parallel_size)
+ .unbind(0)
+ )
+ group_ranks = [x.tolist() for x in group_ranks]
_PCP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pcp"
)
@@ -1515,6 +1640,13 @@ def initialize_model_parallel(
all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
+ if enable_elastic_ep:
+ group_ranks = (
+ local_all_ranks.transpose(0, 2)
+ .reshape(-1, pipeline_model_parallel_size)
+ .unbind(0)
+ )
+ group_ranks = [x.tolist() for x in group_ranks]
_PP = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="pp"
)
@@ -1523,14 +1655,27 @@ def initialize_model_parallel(
assert _DP is None, "data parallel group is already initialized"
group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
- _DP = init_model_parallel_group(
- group_ranks, get_world_group().local_rank, backend, group_name="dp"
- )
+ if enable_elastic_ep:
+ parallel_config = config.parallel_config
+ dp_ports = [
+ parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
+ ]
+ _DP = _init_stateless_group(
+ group_ranks,
+ "dp",
+ dp_ports,
+ parallel_config.data_parallel_master_ip,
+ backend,
+ )
+ else:
+ _DP = init_model_parallel_group(
+ group_ranks, get_world_group().local_rank, backend, group_name="dp"
+ )
global _EP
assert _EP is None, "expert parallel group is already initialized"
# Don't create EP group for dense models.
- if config is None or config.model_config is None or config.model_config.is_moe:
+ if config.model_config is None or config.model_config.is_moe:
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
@@ -1542,9 +1687,22 @@ def initialize_model_parallel(
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
- _EP = init_model_parallel_group(
- group_ranks, get_world_group().local_rank, backend, group_name="ep"
- )
+ if enable_elastic_ep:
+ parallel_config = config.parallel_config
+ ep_ports = [
+ parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
+ ]
+ _EP = _init_stateless_group(
+ group_ranks,
+ "ep",
+ ep_ports,
+ parallel_config.data_parallel_master_ip,
+ backend,
+ )
+ else:
+ _EP = init_model_parallel_group(
+ group_ranks, get_world_group().local_rank, backend, group_name="ep"
+ )
# Create EPLB group with the same ranks as EP if EPLB is enabled.
# This is a separate process group to isolate EPLB communications
@@ -1557,10 +1715,25 @@ def initialize_model_parallel(
and config.parallel_config is not None
and config.parallel_config.enable_eplb
):
- # Reuse the same group_ranks from EP
- _EPLB = init_model_parallel_group(
- group_ranks, get_world_group().local_rank, backend, group_name="eplb"
- )
+ if enable_elastic_ep:
+ eplb_ports = [
+ parallel_config.get_next_stateless_eplb_group_port()
+ for _ in group_ranks
+ ]
+ _EPLB = _init_stateless_group(
+ group_ranks,
+ "eplb",
+ eplb_ports,
+ parallel_config.data_parallel_master_ip,
+ backend,
+ )
+ else:
+ _EPLB = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="eplb",
+ )
# If no EP group needed, _EP remains None
# If no EPLB group needed, _EPLB remains None
@@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized(
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
- backend = backend or torch.distributed.get_backend(get_world_group().device_group)
+ world_group = get_world_group()
+ if hasattr(world_group, "backend"):
+ backend = backend or world_group.backend
+ else:
+ backend = backend or torch.distributed.get_backend(world_group.device_group)
if not model_parallel_is_initialized():
initialize_model_parallel(
tensor_model_parallel_size,
diff --git a/vllm/distributed/stateless_coordinator.py b/vllm/distributed/stateless_coordinator.py
new file mode 100644
index 0000000..f2126fd
--- /dev/null
+++ b/vllm/distributed/stateless_coordinator.py
@@ -0,0 +1,322 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any, Optional
+
+import torch
+from torch.distributed import Backend, ProcessGroup
+
+from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
+from vllm.distributed.parallel_state import (
+ GroupCoordinator,
+ TensorMetadata,
+ _get_unique_name,
+ _register_group,
+ _split_tensor_dict,
+)
+from vllm.distributed.utils import (
+ StatelessProcessGroup,
+ stateless_destroy_torch_distributed_process_group,
+ stateless_init_torch_distributed_process_group,
+)
+from vllm.logger import init_logger
+from vllm.utils.import_utils import resolve_obj_by_qualname
+
+logger = init_logger(__name__)
+
+
+class StatelessGroupCoordinator(GroupCoordinator):
+    """
+    A stateless version of the GroupCoordinator class in parallel_state.
+    It creates CPU, device and TCPStore-based communication groups that are
+    independent of PyTorch's WORLD group, so communication groups with a
+    different set of participant GPUs can be created without destroying the
+    existing ones.
+    """
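+    # Typical construction (illustrative sketch; the host and the three
+    # ports per group normally come from ParallelConfig, and the port
+    # numbers below are placeholders):
+    #   coord = StatelessGroupCoordinator(
+    #       group_ranks=[[0, 1]], local_rank=0,
+    #       torch_distributed_backend="nccl", use_device_communicator=True,
+    #       group_name="dp", host="10.0.0.1",
+    #       group_ports=[[29500, 29501, 29502]],
+    #       global_rank=0, global_world_size=2,
+    #   )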
+
+ def __init__(
+ self,
+ group_ranks: list[list[int]],
+ local_rank: int,
+ torch_distributed_backend: str | Backend,
+ use_device_communicator: bool,
+ use_message_queue_broadcaster: bool = False,
+ group_name: str | None = None,
+ host: str = "127.0.0.1",
+ group_ports: list[list[int]] | None = None,
+ global_rank: int = 0,
+ global_world_size: int = 1,
+ ):
+ group_name = group_name or "anonymous"
+ self.unique_name = _get_unique_name(group_name)
+ _register_group(self)
+
+ self.rank = global_rank
+ self.local_rank = local_rank
+
+ self_device_group = None
+ self_cpu_group = None
+ self_tcp_store_group = None
+
+ from vllm.platforms import current_platform
+
+ backend = str(torch_distributed_backend)
+ self.backend = backend
+ assert group_ports is not None, "group_ports is not provided"
+ for idx, ranks in enumerate(group_ranks):
+ if self.rank in ranks:
+ self.ranks = ranks
+ self.world_size = len(ranks)
+ self.rank_in_group = ranks.index(self.rank)
+
+ ports = group_ports[idx]
+ device_port = ports[0]
+ cpu_port = ports[1]
+ tcp_store_port = ports[2]
+
+ device_group = stateless_init_torch_distributed_process_group(
+ host=host,
+ port=device_port,
+ rank=self.rank_in_group,
+ world_size=self.world_size,
+ backend=backend,
+ group_name=f"{self.unique_name}_device",
+ )
+ cpu_group = stateless_init_torch_distributed_process_group(
+ host=host,
+ port=cpu_port,
+ rank=self.rank_in_group,
+ world_size=self.world_size,
+ backend="gloo",
+ group_name=f"{self.unique_name}_cpu",
+ )
+ tcp_store_group = StatelessProcessGroup.create(
+ host=host,
+ port=tcp_store_port,
+ rank=self.rank_in_group,
+ world_size=self.world_size,
+ )
+
+ self_device_group = device_group
+ self_cpu_group = cpu_group
+ self_tcp_store_group = tcp_store_group
+
+ assert self_cpu_group is not None
+ assert self_device_group is not None
+ assert self_tcp_store_group is not None
+
+ self.cpu_group = self_cpu_group
+ self.device_group = self_device_group
+ self.tcp_store_group = self_tcp_store_group
+
+ if current_platform.is_cuda_alike():
+ self.device = torch.device(f"cuda:{local_rank}")
+ elif current_platform.is_xpu():
+ self.device = torch.device(f"xpu:{local_rank}")
+ elif current_platform.is_out_of_tree():
+ self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
+ else:
+ self.device = torch.device("cpu")
+
+ self.use_device_communicator = use_device_communicator
+ self.device_communicator = None
+ if use_device_communicator and self.world_size > 1:
+ device_comm_cls = resolve_obj_by_qualname(
+ current_platform.get_device_communicator_cls()
+ )
+ assert device_comm_cls == CudaCommunicator
+ self.device_communicator = CudaCommunicator(
+ cpu_group=self.cpu_group,
+ device=self.device,
+ device_group=self.device_group,
+ unique_name=self.unique_name,
+ global_ranks=self.ranks,
+ global_world_size=global_world_size,
+ tcp_store_group=self.tcp_store_group,
+ )
+
+ self.mq_broadcaster = None
+
+ self.use_custom_op_call = (
+ current_platform.is_cuda_alike() or current_platform.is_tpu()
+ )
+ self.use_cpu_custom_send_recv = False
+
+ def destroy(self):
+ if self.device_communicator:
+ self.device_communicator.destroy()
+ if self.device_group:
+ stateless_destroy_torch_distributed_process_group(self.device_group)
+ if self.cpu_group:
+ stateless_destroy_torch_distributed_process_group(self.cpu_group)
+
+ def size(self) -> int:
+ """Return the world size of this group."""
+ return self.world_size
+
+ def broadcast(self, input_: torch.Tensor, src: int = 0):
+ if self.world_size == 1:
+ return input_
+
+ if self.device_communicator and input_.is_cuda:
+ return self.device_communicator.broadcast(input_, src)
+ else:
+ return self.tcp_store_group.broadcast(input_, src)
+
+ def broadcast_object(self, obj=None, src: int = 0):
+ if self.world_size == 1:
+ return obj
+ return self.tcp_store_group.broadcast_obj(obj, src)
+
+ def broadcast_object_list(
+ self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
+ ):
+ assert src < self.world_size
+
+ if self.world_size == 1:
+ return obj_list
+
+ if self.rank_in_group == src:
+ for obj in obj_list:
+ self.tcp_store_group.broadcast_obj(obj, src)
+ else:
+ for i in range(len(obj_list)):
+ obj_list[i] = self.tcp_store_group.broadcast_obj(None, src)
+
+ return obj_list
+
+ def broadcast_tensor_dict(
+ self,
+ tensor_dict: dict[str, torch.Tensor | Any] | None = None,
+ src: int = 0,
+ group: ProcessGroup | None = None,
+ metadata_group: ProcessGroup | None = None,
+ ) -> dict[str, torch.Tensor | Any] | None:
+ if self.world_size == 1:
+ return tensor_dict
+
+ if self.rank_in_group == src:
+ assert isinstance(tensor_dict, dict), (
+ f"Expecting a dictionary, got {type(tensor_dict)}"
+ )
+ metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+ else:
+ metadata_list = None
+ tensor_list = []
+
+ recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj(
+ metadata_list, src
+ )
+
+ if self.rank_in_group != src:
+ tensor_dict = {}
+ for key, value in recv_metadata_list:
+ if isinstance(value, TensorMetadata):
+ tensor = torch.empty(
+ value.size, dtype=value.dtype, device=value.device
+ )
+ tensor_list.append(tensor)
+ tensor_dict[key] = tensor
+ else:
+ tensor_dict[key] = value
+
+ for tensor in tensor_list:
+ if tensor.numel() == 0:
+ continue
+ if self.device_communicator and tensor.is_cuda:
+ tensor.copy_(self.device_communicator.broadcast(tensor, src))
+ else:
+ tensor.copy_(self.tcp_store_group.broadcast(tensor, src))
+
+ return tensor_dict
+
+ def send_object(self, obj, dst: int) -> None:
+ assert dst < self.world_size
+ assert dst != self.rank_in_group
+ self.tcp_store_group.send_obj(obj, dst)
+
+ def recv_object(self, src: int):
+ assert src < self.world_size
+ assert src != self.rank_in_group
+ return self.tcp_store_group.recv_obj(src)
+
+ def send_tensor_dict(
+ self,
+ tensor_dict: dict[str, torch.Tensor | Any],
+ dst: int | None = None,
+ all_gather_group: Optional["GroupCoordinator"] = None,
+ all_gather_tensors: dict[str, bool] | None = None,
+ ) -> dict[str, torch.Tensor | Any] | None:
+ if self.world_size == 1:
+ return tensor_dict
+
+ if dst is None:
+ dst = (self.rank_in_group + 1) % self.world_size
+ assert dst < self.world_size
+
+ metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+ self.tcp_store_group.send_obj(metadata_list, dst)
+
+ for tensor in tensor_list:
+ if tensor.numel() == 0:
+ continue
+ if self.device_communicator and tensor.is_cuda:
+ self.device_communicator.send(tensor, dst)
+ else:
+ self.tcp_store_group.send(tensor, dst)
+
+ return None
+
+ def recv_tensor_dict(
+ self,
+ src: int | None = None,
+ all_gather_group: Optional["GroupCoordinator"] = None,
+ all_gather_tensors: dict[str, bool] | None = None,
+ ) -> dict[str, torch.Tensor | Any] | None:
+ if self.world_size == 1:
+ return None
+
+ if src is None:
+ src = (self.rank_in_group - 1) % self.world_size
+ assert src < self.world_size
+
+ recv_metadata_list = self.tcp_store_group.recv_obj(src)
+ tensor_dict = {}
+ for key, value in recv_metadata_list:
+ if isinstance(value, TensorMetadata):
+ tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
+ if tensor.numel() > 0:
+ if self.device_communicator and tensor.is_cuda:
+ tensor = self.device_communicator.recv(
+ tensor.size(), tensor.dtype, src
+ )
+ else:
+ tensor = self.tcp_store_group.recv(tensor, src)
+ tensor_dict[key] = tensor
+ else:
+ tensor_dict[key] = value
+ return tensor_dict
+
+ def barrier(self):
+ self.tcp_store_group.barrier()
+
+ def gather(
+ self, input_: torch.Tensor, dst: int = 0, dim: int = -1
+ ) -> torch.Tensor | None:
+ if self.world_size == 1:
+ return input_
+
+ if self.device_communicator is None:
+ raise ValueError("No device communicator found")
+
+ if self.rank_in_group == dst:
+ gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)]
+ gathered_list[self.rank_in_group] = input_
+ for src_rank in range(self.world_size):
+ if src_rank != self.rank_in_group:
+ gathered_list[src_rank] = self.device_communicator.recv(
+ input_.size(), input_.dtype, src_rank
+ )
+ return torch.cat(gathered_list, dim=dim)
+ else:
+ self.device_communicator.send(input_, dst)
+ return None
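
For reference, a minimal sketch of how the tensor-dict helpers above are driven from each rank. This is illustrative only: `group` stands for an instance of the coordinator class defined above, constructed elsewhere with a CUDA device communicator available.

import torch


def exchange_step(group, rank: int):
    # Metadata travels through the TCP store; CUDA tensors go through the
    # device communicator, CPU tensors fall back to the store path.
    if rank == 0:
        payload = {
            "step": 42,
            "weights": torch.randn(4, 4, device="cuda"),
            "mask": torch.zeros(4, dtype=torch.bool),
        }
        group.broadcast_tensor_dict(payload, src=0)
    else:
        payload = group.broadcast_tensor_dict(None, src=0)
    group.barrier()
    return payload
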
diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py
index 1737525..102f2f7 100644
--- a/vllm/distributed/utils.py
+++ b/vllm/distributed/utils.py
@@ -18,7 +18,7 @@ from datetime import timedelta
from typing import Any
import torch
-from torch.distributed import ProcessGroup, TCPStore
+from torch.distributed import ProcessGroup, Store, TCPStore
from torch.distributed.distributed_c10d import (
Backend,
PrefixStore,
@@ -228,6 +228,55 @@ class StatelessProcessGroup:
gathered_objs.append(recv_obj)
return gathered_objs
+ def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+ """Broadcast a tensor from source rank to all other ranks."""
+ if self.rank == src:
+ tensor_bytes = pickle.dumps(tensor)
+ self.expire_data()
+ key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}"
+ self.store.set(key, tensor_bytes)
+ self.broadcast_send_counter += 1
+ self.entries.append((key, time.time()))
+ return tensor
+ else:
+ key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}"
+ tensor = pickle.loads(self.store.get(key))
+ self.broadcast_recv_src_counter[src] += 1
+ return tensor
+
+ def send(self, tensor: torch.Tensor, dst: int):
+ """Send a tensor to a destination rank."""
+ self.expire_data()
+ key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}"
+ self.store.set(key, pickle.dumps(tensor))
+ self.send_dst_counter[dst] += 1
+ self.entries.append((key, time.time()))
+
+ def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
+ """Receive a tensor from a source rank."""
+ key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}"
+ received = pickle.loads(self.store.get(key))
+ self.recv_src_counter[src] += 1
+ tensor.copy_(received)
+ return tensor
+
+ def all_reduce(
+ self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM
+ ) -> torch.Tensor:
+ """All-reduce a tensor across all ranks."""
+ tensors = self.all_gather_obj(tensor)
+ result = tensors[0].clone()
+ for t in tensors[1:]:
+ if op == torch.distributed.ReduceOp.SUM:
+ result.add_(t)
+ elif op == torch.distributed.ReduceOp.PRODUCT:
+ result.mul_(t)
+ elif op == torch.distributed.ReduceOp.MAX:
+ result = torch.maximum(result, t)
+ elif op == torch.distributed.ReduceOp.MIN:
+ result = torch.minimum(result, t)
+ return result
+
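
A small usage sketch for the new store-backed tensor helpers, assuming `rank` and `world_size` are provided by the process launcher. Tensors travel through the TCP store as pickled blobs, so this path is only suitable for small control-plane payloads.

import torch

from vllm.distributed.utils import StatelessProcessGroup

# Every rank creates the group with the same host/port and its own rank.
group = StatelessProcessGroup.create(
    host="127.0.0.1", port=29555, rank=rank, world_size=world_size
)

# Point-to-point: rank 0 sends, rank 1 receives into a pre-allocated buffer.
if rank == 0:
    group.send(torch.arange(8), dst=1)
elif rank == 1:
    buf = torch.empty(8, dtype=torch.int64)
    group.recv(buf, src=0)

# Store-backed all-reduce (defaults to ReduceOp.SUM).
total = group.all_reduce(torch.tensor([float(rank)]))
group.barrier()
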
def barrier(self, timeout: float = 30.0):
"""A robust barrier to synchronize all ranks.
@@ -448,8 +497,14 @@ def init_gloo_process_group(
def stateless_init_torch_distributed_process_group(
- host: str, port: int, rank: int, world_size: int, backend: str
-) -> ProcessGroup:
+ host: str,
+ port: int,
+ rank: int,
+ world_size: int,
+ backend: str,
+ group_name: str | None = None,
+ return_store: bool = False,
+) -> ProcessGroup | tuple[ProcessGroup, Store]:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
@@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group(
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store = PrefixStore(init_method, store)
- try:
+
+ if backend == "gloo":
+ pg = init_gloo_process_group(
+ prefix_store=prefix_store,
+ group_rank=group_rank,
+ group_size=group_size,
+ timeout=timeout,
+ )
+ else:
from vllm.platforms import current_platform
- return current_platform.stateless_init_device_torch_dist_pg(
+ pg = current_platform.stateless_init_device_torch_dist_pg(
backend=backend,
prefix_store=prefix_store,
group_rank=group_rank,
group_size=group_size,
timeout=timeout,
)
- except NotImplementedError:
- # If platform doesn't implement stateless_init_device_torch_dist_pg, it
- # will raise a NotImplementedError. In this case, we fall back to gloo.
- return init_gloo_process_group(
- prefix_store=prefix_store,
- group_rank=group_rank,
- group_size=group_size,
- timeout=timeout,
- )
+
+ if group_name is not None:
+ from torch._C._distributed_c10d import _register_process_group
+
+ pg._set_group_name(group_name)
+ _register_process_group(group_name, pg)
+
+ if return_store:
+ return pg, store
+ else:
+ return pg
def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None:
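
With the new keyword arguments, callers can register the group under a name and get the underlying store back in one call. A sketch, with host, port, and group name as placeholder values:

from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group,
)

pg, store = stateless_init_torch_distributed_process_group(
    host="127.0.0.1",
    port=29600,
    rank=rank,
    world_size=world_size,
    backend="gloo",
    group_name="elastic_ep_cpu",  # registered with c10d under this name
    return_store=True,            # also hand back the underlying TCPStore
)
# ... use `pg` for collectives and `store` for ad-hoc key/value exchange ...
stateless_destroy_torch_distributed_process_group(pg)
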
diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py
index b87f190..788dcef 100644
--- a/vllm/distributed/weight_transfer/base.py
+++ b/vllm/distributed/weight_transfer/base.py
@@ -3,7 +3,7 @@
"""Base class for weight transfer engines."""
from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Iterator
from dataclasses import KW_ONLY, dataclass, field
from typing import Any, Generic, TypeVar
@@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
This should be called when the worker is shutting down.
"""
raise NotImplementedError
+
+ @staticmethod
+ @abstractmethod
+ def trainer_send_weights(
+ iterator: Iterator[tuple[str, torch.Tensor]],
+ trainer_args: dict[str, Any] | Any,
+ ) -> None:
+ """
+ Send weights from trainer to inference workers.
+
+ This is a static method that can be called from the trainer process
+ to send weights to all inference workers.
+
+ Args:
+ iterator: Iterator of model parameters. Returns (name, tensor) tuples.
+ The tensors should be on the appropriate device for the backend.
+ trainer_args: Dictionary containing backend-specific arguments needed
+ to send weights. The structure depends on the backend:
+ - NCCL: Contains 'group', 'src', 'packed', etc.
+ - IPC: Contains 'mode' ('http' or 'ray'),
+ 'llm_handle' (for Ray), 'url' (for HTTP), etc.
+
+ Example:
+ >>> param_iter = ((n, p) for n, p in model.named_parameters())
+ >>> engine.trainer_send_weights(param_iter, trainer_args)
+ """
+ raise NotImplementedError
diff --git a/vllm/distributed/weight_transfer/factory.py b/vllm/distributed/weight_transfer/factory.py
index 7235e30..f8e9c86 100644
--- a/vllm/distributed/weight_transfer/factory.py
+++ b/vllm/distributed/weight_transfer/factory.py
@@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine(
"vllm.distributed.weight_transfer.nccl_engine",
"NCCLWeightTransferEngine",
)
+
+WeightTransferEngineFactory.register_engine(
+ "ipc",
+ "vllm.distributed.weight_transfer.ipc_engine",
+ "IPCWeightTransferEngine",
+)
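
Out-of-tree backends can hook into the same registry. A hypothetical sketch; the module path and class name below are placeholders, not part of vLLM:

from vllm.distributed.weight_transfer.factory import WeightTransferEngineFactory

WeightTransferEngineFactory.register_engine(
    "my_rdma",                                 # backend name used in the config
    "my_project.weight_transfer.rdma_engine",  # importable module path
    "RDMAWeightTransferEngine",                # engine class inside that module
)
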
diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py
new file mode 100644
index 0000000..2edbec6
--- /dev/null
+++ b/vllm/distributed/weight_transfer/ipc_engine.py
@@ -0,0 +1,291 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""IPC-based weight transfer engine using CUDA IPC for communication."""
+
+import base64
+import pickle
+from collections.abc import Callable, Iterator
+from dataclasses import asdict, dataclass
+from typing import Any
+
+import requests
+import torch
+from torch.multiprocessing.reductions import reduce_tensor
+
+from vllm.config.parallel import ParallelConfig
+from vllm.config.weight_transfer import WeightTransferConfig
+from vllm.distributed.weight_transfer.base import (
+ WeightTransferEngine,
+ WeightTransferInitInfo,
+ WeightTransferUpdateInfo,
+)
+
+
+@dataclass
+class IPCTrainerSendWeightsArgs:
+ """Arguments for IPC trainer_send_weights method."""
+
+ mode: str
+ """Transport mode: 'http' or 'ray'."""
+ llm_handle: Any = None
+ """Ray ObjectRef to LLM handle (required for 'ray' mode)."""
+ url: str | None = None
+ """Base URL for HTTP endpoint (required for 'http' mode)."""
+
+ def __post_init__(self):
+ """Validate that required arguments are provided for the selected mode."""
+ if self.mode == "ray" and self.llm_handle is None:
+ raise ValueError("llm_handle is required for 'ray' mode")
+ if self.mode == "http" and self.url is None:
+ raise ValueError("url is required for 'http' mode")
+ if self.mode not in ("ray", "http"):
+ raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}")
+
+
+@dataclass
+class IPCWeightTransferInitInfo(WeightTransferInitInfo):
+ """Initialization info for IPC weight transfer backend. No init needed for IPC."""
+
+ pass
+
+
+@dataclass
+class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
+ """Update info for IPC weight transfer backend.
+
+ Accepts IPC handles either directly via ``ipc_handles`` (Ray transport)
+ or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport).
+ Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set
+ it is unpickled into ``ipc_handles`` during ``__post_init__``.
+ """
+
+ names: list[str]
+ dtype_names: list[str]
+ shapes: list[list[int]]
+ ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None
+ """IPC handles mapping physical GPU UUID to (func, args) tuple.
+ Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples."""
+ ipc_handles_pickled: str | None = None
+ """Base64-encoded pickled IPC handles, used for HTTP transport."""
+
+ def __post_init__(self):
+ if self.ipc_handles_pickled is not None:
+ if self.ipc_handles is not None:
+ raise ValueError(
+ "Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
+ )
+ self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
+ self.ipc_handles_pickled = None
+
+ if self.ipc_handles is None:
+ raise ValueError(
+ "Either `ipc_handles` or `ipc_handles_pickled` must be provided"
+ )
+
+ num_params = len(self.names)
+ if len(self.dtype_names) != num_params:
+ raise ValueError(
+ f"`dtype_names` should be of the same size as `names`: "
+ f"got {len(self.dtype_names)} and {len(self.names)}"
+ )
+ if len(self.shapes) != num_params:
+ raise ValueError(
+ f"`shapes` should be of the same size as `names`: "
+ f"got {len(self.shapes)} and {len(self.names)}"
+ )
+ if len(self.ipc_handles) != num_params:
+ raise ValueError(
+ f"`ipc_handles` should be of the same size as `names`: "
+ f"got {len(self.ipc_handles)} and {len(self.names)}"
+ )
+
+
+class IPCWeightTransferEngine(
+ WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo]
+):
+ """
+ Weight transfer engine using CUDA IPC for communication between trainer and workers.
+
+ This implementation uses CUDA IPC to transfer weights from the trainer (rank 0)
+ to all inference workers in a process group. IPC handles are used to share
+ memory between processes on the same node.
+ """
+
+ # Define backend-specific dataclass types
+ init_info_cls = IPCWeightTransferInitInfo
+ update_info_cls = IPCWeightTransferUpdateInfo
+
+ def __init__(
+ self, config: WeightTransferConfig, parallel_config: ParallelConfig
+ ) -> None:
+ """
+ Initialize the IPC weight transfer engine.
+
+ Args:
+ config: The configuration for the weight transfer engine
+ parallel_config: The configuration for the parallel setup
+ """
+ super().__init__(config, parallel_config)
+
+ def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None:
+ """
+ Initialize the weight transfer mechanism.
+ This is called once at the beginning of training.
+ No initialization needed for IPC backend.
+
+ Args:
+ init_info: IPC initialization info (empty)
+ """
+ pass
+
+ def receive_weights(
+ self,
+ update_info: IPCWeightTransferUpdateInfo,
+ load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
+ ) -> None:
+ """
+ Receive weights from the trainer via CUDA IPC handles.
+
+ Args:
+ update_info: IPC update info containing parameter names, dtypes, shapes,
+ and IPC handles. Each IPC handle is a mapping between physical
+ GPU UUID and the IPC handle tuple (func, args).
+ load_weights: Callable that loads weights into the model. Called
+ incrementally for each weight to avoid OOM.
+ """
+ assert update_info.ipc_handles is not None
+ weights = []
+        # The target device does not change across parameters, so resolve the
+        # physical GPU UUID once before iterating over the IPC handles.
+        device_index = torch.cuda.current_device()
+        props = torch.cuda.get_device_properties(device_index)
+        physical_gpu_id = str(props.uuid)
+        for name, _dtype_name, _shape, ipc_handle in zip(
+            update_info.names,
+            update_info.dtype_names,
+            update_info.shapes,
+            update_info.ipc_handles,
+        ):
+
+ if physical_gpu_id not in ipc_handle:
+ raise ValueError(
+ f"IPC handle not found for GPU UUID {physical_gpu_id}. "
+ f"Available UUIDs: {list(ipc_handle.keys())}"
+ )
+
+ handle = ipc_handle[physical_gpu_id]
+
+ func, args = handle
+ list_args = list(args) # type: ignore
+ # Index 6 is the device_index parameter in torch's
+ # IPC handle tuple (rebuild_cuda_tensor). Update it
+ # to the current device since the logical index can
+ # differ between sender and receiver.
+ list_args[6] = device_index
+ weight = func(*list_args) # type: ignore
+ weights.append((name, weight))
+
+ load_weights(weights)
+
+ def shutdown(self) -> None:
+ """
+ Shutdown the weight transfer engine.
+ """
+ pass
+
+ @staticmethod
+ def trainer_send_weights(
+ iterator: Iterator[tuple[str, torch.Tensor]],
+ trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs,
+ ) -> None:
+ """
+ Send weights from trainer to inference workers via CUDA IPC.
+
+ Supports two modes:
+ - 'ray': Sends weights via Ray RPC to a Ray-based LLM handle
+ - 'http': Sends weights via HTTP POST to a vLLM HTTP server
+
+ Args:
+ iterator: Iterator of model parameters. Returns (name, tensor) tuples.
+ Tensors should be on the same GPU as the inference workers.
+            trainer_args: Dictionary or IPCTrainerSendWeightsArgs instance
+                containing IPC-specific arguments. If a dict, it should
+                contain keys from IPCTrainerSendWeightsArgs:
+ - mode: 'ray' or 'http'
+ - llm_handle: Ray ObjectRef (for 'ray' mode)
+ - url: Base URL string (for 'http' mode)
+
+ Example (Ray mode):
+ >>> from vllm.distributed.weight_transfer.ipc_engine import (
+ ... IPCWeightTransferEngine,
+ ... IPCTrainerSendWeightsArgs,
+ ... )
+ >>> param_iter = ((n, p) for n, p in model.named_parameters())
+ >>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle)
+ >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
+
+ Example (HTTP mode):
+ >>> args = IPCTrainerSendWeightsArgs(
+ ... mode="http", url="http://localhost:8000"
+ ... )
+ >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args))
+ """
+ # Parse trainer args - accept either dict or dataclass instance
+ if isinstance(trainer_args, dict):
+ args = IPCTrainerSendWeightsArgs(**trainer_args)
+ else:
+ args = trainer_args
+
+ # Get physical GPU UUID
+ device_index = torch.cuda.current_device()
+ props = torch.cuda.get_device_properties(device_index)
+ gpu_uuid = str(props.uuid)
+
+ # Collect weight metadata and create IPC handles
+ names = []
+ dtype_names = []
+ shapes = []
+ ipc_handles = []
+
+ for name, tensor in iterator:
+ names.append(name)
+ dtype_names.append(str(tensor.dtype).split(".")[-1])
+ shapes.append(list(tensor.shape))
+
+ # Create IPC handle for this weight tensor
+ # The tensor must remain in memory for IPC to work
+ weight = tensor.detach().contiguous()
+ ipc_handle = reduce_tensor(weight)
+ ipc_handles.append({gpu_uuid: ipc_handle})
+
+ # Send weights based on mode
+ if args.mode == "ray":
+ # Ray mode: send via Ray RPC
+ import ray
+
+ update_info = asdict(
+ IPCWeightTransferUpdateInfo(
+ names=names,
+ dtype_names=dtype_names,
+ shapes=shapes,
+ ipc_handles=ipc_handles,
+ )
+ )
+ ray.get(
+ args.llm_handle.update_weights.remote(dict(update_info=update_info))
+ )
+ elif args.mode == "http":
+ # HTTP mode: send via HTTP POST with pickled handles
+ # Pickle and base64 encode IPC handles for HTTP transmission
+ pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode(
+ "utf-8"
+ )
+
+ url = f"{args.url}/update_weights"
+ payload = {
+ "update_info": {
+ "names": names,
+ "dtype_names": dtype_names,
+ "shapes": shapes,
+ "ipc_handles_pickled": pickled_handles,
+ }
+ }
+ response = requests.post(url, json=payload, timeout=300)
+ response.raise_for_status()
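
For orientation, a sketch of the worker-side counterpart to `trainer_send_weights` in HTTP mode. It assumes `engine` is the `IPCWeightTransferEngine` created by vLLM, `payload` is the posted JSON body, and the model exposes a `load_weights` helper that accepts `(name, tensor)` pairs:

from vllm.distributed.weight_transfer.ipc_engine import IPCWeightTransferUpdateInfo


def apply_update(engine, model, payload: dict) -> None:
    # __post_init__ unpickles `ipc_handles_pickled` back into `ipc_handles`.
    update_info = IPCWeightTransferUpdateInfo(**payload["update_info"])

    def load_weights(weights):
        # `weights` is a list of (name, tensor) pairs rebuilt from IPC handles.
        model.load_weights(weights)

    engine.receive_weights(update_info, load_weights)
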
diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py
index 5c90198..e8a1091 100644
--- a/vllm/distributed/weight_transfer/nccl_engine.py
+++ b/vllm/distributed/weight_transfer/nccl_engine.py
@@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo):
world_size: int
+@dataclass
+class NCCLTrainerSendWeightsArgs:
+ """Arguments for NCCL trainer_send_weights method."""
+
+ group: Any
+ """Process group (PyNcclCommunicator) for NCCL communication."""
+ src: int = 0
+ """Source rank (default 0, trainer is typically rank 0)."""
+ post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None
+ """Optional function to apply to each (name, tensor) pair before broadcasting.
+ If None, extracts just the tensor."""
+ packed: bool = False
+ """Whether to use packed tensor broadcasting for efficiency.
+ When True, multiple tensors are batched together before broadcasting
+ to reduce NCCL communication overhead."""
+ stream: torch.cuda.Stream | None = None
+ """CUDA stream to use for broadcasting if packed is False.
+ If packed is True, new streams will be created for each buffer."""
+ packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
+ """Size in bytes for each packed tensor buffer.
+ Must match the value used in NCCLWeightTransferUpdateInfo."""
+ packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
+ """Number of buffers for double/triple buffering during packed transfer.
+ Must match the value used in NCCLWeightTransferUpdateInfo."""
+
+
@dataclass
class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
"""Update info for NCCL weight transfer backend."""
@@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo):
When True, multiple tensors are batched together before broadcasting
to reduce NCCL communication overhead."""
packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES
- """Size in bytes for each packed tensor buffer. Default is 1GB.
+ """Size in bytes for each packed tensor buffer.
Both producer and consumer must use the same value."""
packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS
"""Number of buffers for double/triple buffering during packed transfer.
@@ -186,47 +212,38 @@ class NCCLWeightTransferEngine(
@staticmethod
def trainer_send_weights(
iterator: Iterator[tuple[str, torch.Tensor]],
- group: Any,
- src: int = 0,
- post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor]
- | None = None,
- packed: bool = False,
- stream: torch.cuda.Stream | None = None,
- packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES,
- packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS,
+ trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs,
) -> None:
"""Broadcast weights from trainer to vLLM workers.
Args:
iterator: Iterator of model parameters. Returns (name, tensor) tuples
- group: Process group (PyNcclCommunicator)
- src: Source rank (default 0, trainer is typically rank 0)
- post_iter_func: Optional function to apply to each (name, tensor) pair
- before broadcasting. If None, extracts just the tensor.
- packed: Whether to use packed tensor broadcasting for efficiency.
- When True, multiple tensors are batched together before
- broadcasting to reduce NCCL communication overhead.
- stream: CUDA stream to use for broadcasting if packed is False.
- If packed is True, new streams will be created for each buffer.
- packed_buffer_size_bytes: Size in bytes for each packed tensor buffer.
- Must match the value used in NCCLWeightTransferUpdateInfo.
- packed_num_buffers: Number of buffers for double/triple buffering.
- Must match the value used in NCCLWeightTransferUpdateInfo.
+ trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing
+ NCCL-specific arguments. If a dict, should contain keys from
+ NCCLTrainerSendWeightsArgs.
Example:
>>> from vllm.distributed.weight_transfer.nccl_engine import (
... NCCLWeightTransferEngine,
+ ... NCCLTrainerSendWeightsArgs,
... )
>>> param_iter = ((n, p) for n, p in model.named_parameters())
- >>> NCCLWeightTransferEngine.trainer_send_weights(
- ... param_iter, group, packed=True
- ... )
+ >>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True)
+ >>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args)
"""
- if post_iter_func is None:
+ # Parse trainer args - accept either dict or dataclass instance
+ if isinstance(trainer_args, dict):
+ args = NCCLTrainerSendWeightsArgs(**trainer_args)
+ else:
+ args = trainer_args
+
+ if args.post_iter_func is None:
# Default: extract just the tensor from (name, tensor) tuple
post_iter_func = lambda x: x[1]
+ else:
+ post_iter_func = args.post_iter_func
- if packed:
+ if args.packed:
# Use packed tensor broadcasting for efficiency
from vllm.distributed.weight_transfer.packed_tensor import (
packed_broadcast_producer,
@@ -234,18 +251,20 @@ class NCCLWeightTransferEngine(
packed_broadcast_producer(
iterator=iterator,
- group=group,
- src=src,
+ group=args.group,
+ src=args.src,
post_iter_func=post_iter_func,
- buffer_size_bytes=packed_buffer_size_bytes,
- num_buffers=packed_num_buffers,
+ buffer_size_bytes=args.packed_buffer_size_bytes,
+ num_buffers=args.packed_num_buffers,
)
else:
# Use simple one-by-one broadcasting
for item in iterator:
tensor = post_iter_func(item)
- group.broadcast(
- tensor, src=src, stream=stream or torch.cuda.current_stream()
+ args.group.broadcast(
+ tensor,
+ src=args.src,
+ stream=args.stream or torch.cuda.current_stream(),
)
@staticmethod
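
Because `trainer_send_weights` now accepts either the dataclass or a plain dict, existing call sites can migrate with a one-line change. A sketch of the dict form on the unpacked path; `model` and `pynccl_group` (the communicator set up during `init_transfer_engine`) are assumed to exist:

from vllm.distributed.weight_transfer.nccl_engine import NCCLWeightTransferEngine

param_iter = ((n, p) for n, p in model.named_parameters())
NCCLWeightTransferEngine.trainer_send_weights(
    param_iter,
    {"group": pynccl_group, "src": 0, "packed": False},
)
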
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 8c48661..c4d3c03 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -419,6 +419,7 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
moe_backend: MoEBackend = KernelConfig.moe_backend
all2all_backend: All2AllBackend = ParallelConfig.all2all_backend
+ enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep
enable_dbo: bool = ParallelConfig.enable_dbo
ubatch_size: int = ParallelConfig.ubatch_size
dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold
@@ -896,6 +897,9 @@ class EngineArgs:
"--ubatch-size",
**parallel_kwargs["ubatch_size"],
)
+ parallel_group.add_argument(
+ "--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"]
+ )
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"],
@@ -1321,6 +1325,7 @@ class EngineArgs:
"launched vLLM.",
self.seed,
)
+
return ModelConfig(
model=self.model,
model_weights=self.model_weights,
@@ -1697,6 +1702,7 @@ class EngineArgs:
is_moe_model=model_config.is_moe,
enable_expert_parallel=self.enable_expert_parallel,
all2all_backend=self.all2all_backend,
+ enable_elastic_ep=self.enable_elastic_ep,
enable_dbo=self.enable_dbo,
ubatch_size=self.ubatch_size,
dbo_decode_token_threshold=self.dbo_decode_token_threshold,
@@ -1905,6 +1911,7 @@ class EngineArgs:
performance_mode=self.performance_mode,
weight_transfer_config=self.weight_transfer_config,
)
+
return config
def _check_feature_supported(self):
@@ -2074,20 +2081,19 @@ class EngineArgs:
)
# Disable chunked prefill and prefix caching for:
- # POWER (ppc64le)/RISCV CPUs in V1
+ # RISCV CPUs in V1
if current_platform.is_cpu() and current_platform.get_cpu_architecture() in (
- CpuArchEnum.POWERPC,
CpuArchEnum.RISCV,
):
logger.info(
- "Chunked prefill is not supported for POWER, "
- "and RISC-V CPUs; "
+ "Chunked prefill is not supported for"
+ "RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_chunked_prefill = False
logger.info(
- "Prefix caching is not supported for POWER, "
- "and RISC-V CPUs; "
+ "Prefix caching is not supported for "
+ "RISC-V CPUs; "
"disabling it for V1 backend."
)
self.enable_prefix_caching = False
@@ -2181,14 +2187,10 @@ class AsyncEngineArgs(EngineArgs):
"--enable-log-requests",
action=argparse.BooleanOptionalAction,
default=AsyncEngineArgs.enable_log_requests,
- help="Enable logging requests.",
- )
- parser.add_argument(
- "--disable-log-requests",
- action=argparse.BooleanOptionalAction,
- default=not AsyncEngineArgs.enable_log_requests,
- help="[DEPRECATED] Disable logging requests.",
- deprecated=True,
+ help="Enable logging request information, dependant on log level:\n"
+ "- INFO: Request ID, parameters and LoRA request.\n"
+ "- DEBUG: Prompt inputs (e.g: text, token IDs).\n"
+ "You can set the minimum log level via `VLLM_LOGGING_LEVEL`.",
)
current_platform.pre_register_and_update(parser)
return parser
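
The new flag is also reachable programmatically through `EngineArgs`. A sketch; the model name is a placeholder:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="my-org/my-moe-model",  # placeholder
    enable_expert_parallel=True,
    enable_elastic_ep=True,
)
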
diff --git a/vllm/entrypoints/anthropic/api_router.py b/vllm/entrypoints/anthropic/api_router.py
index 1494dd7..2b65fff 100644
--- a/vllm/entrypoints/anthropic/api_router.py
+++ b/vllm/entrypoints/anthropic/api_router.py
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.anthropic.protocol import (
+ AnthropicCountTokensRequest,
+ AnthropicCountTokensResponse,
AnthropicError,
AnthropicErrorResponse,
AnthropicMessagesRequest,
@@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages:
return request.app.state.anthropic_serving_messages
+def translate_error_response(response: ErrorResponse) -> JSONResponse:
+ anthropic_error = AnthropicErrorResponse(
+ error=AnthropicError(
+ type=response.error.type,
+ message=response.error.message,
+ )
+ )
+ return JSONResponse(
+ status_code=response.error.code, content=anthropic_error.model_dump()
+ )
+
+
@router.post(
"/v1/messages",
dependencies=[Depends(validate_json_request)],
@@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages:
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
- def translate_error_response(response: ErrorResponse) -> JSONResponse:
- anthropic_error = AnthropicErrorResponse(
- error=AnthropicError(
- type=response.error.type,
- message=response.error.message,
- )
- )
- return JSONResponse(
- status_code=response.error.code, content=anthropic_error.model_dump()
- )
-
handler = messages(raw_request)
if handler is None:
base_server = raw_request.app.state.openai_serving_tokenization
@@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques
return StreamingResponse(content=generator, media_type="text/event-stream")
+@router.post(
+ "/v1/messages/count_tokens",
+ dependencies=[Depends(validate_json_request)],
+ responses={
+ HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse},
+ HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
+ HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
+ HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
+ },
+)
+@load_aware_call
+@with_cancellation
+async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request):
+ handler = messages(raw_request)
+ if handler is None:
+ base_server = raw_request.app.state.openai_serving_tokenization
+ error = base_server.create_error_response(
+ message="The model does not support Messages API"
+ )
+ return translate_error_response(error)
+
+ try:
+ response = await handler.count_tokens(request, raw_request)
+ except Exception as e:
+ logger.exception("Error in count_tokens: %s", e)
+ return JSONResponse(
+ status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+ content=AnthropicErrorResponse(
+ error=AnthropicError(
+ type="internal_error",
+ message=str(e),
+ )
+ ).model_dump(),
+ )
+
+ if isinstance(response, ErrorResponse):
+ return translate_error_response(response)
+
+ return JSONResponse(content=response.model_dump(exclude_none=True))
+
+
def attach_router(app: FastAPI):
app.include_router(router)
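
A sketch of exercising the new endpoint against a locally running server; the URL and model name are placeholders:

import requests

resp = requests.post(
    "http://localhost:8000/v1/messages/count_tokens",
    json={
        "model": "my-model",
        "system": "You are a helpful assistant.",
        "messages": [{"role": "user", "content": "How many tokens is this?"}],
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["input_tokens"])
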
diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py
index af9430e..c541db5 100644
--- a/vllm/entrypoints/anthropic/protocol.py
+++ b/vllm/entrypoints/anthropic/protocol.py
@@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel):
class AnthropicContentBlock(BaseModel):
"""Content block in message"""
- type: Literal["text", "image", "tool_use", "tool_result"]
+ type: Literal["text", "image", "tool_use", "tool_result", "thinking"]
text: str | None = None
# For image content
source: dict[str, Any] | None = None
@@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel):
input: dict[str, Any] | None = None
content: str | list[dict[str, Any]] | None = None
is_error: bool | None = None
+ # For thinking content
+ thinking: str | None = None
+ signature: str | None = None
class AnthropicMessage(BaseModel):
@@ -74,7 +77,7 @@ class AnthropicTool(BaseModel):
class AnthropicToolChoice(BaseModel):
"""Tool Choice definition"""
- type: Literal["auto", "any", "tool"]
+ type: Literal["auto", "any", "tool", "none"]
name: str | None = None
@model_validator(mode="after")
@@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel):
class AnthropicDelta(BaseModel):
"""Delta for streaming responses"""
- type: Literal["text_delta", "input_json_delta"] | None = None
+ type: (
+ Literal["text_delta", "input_json_delta", "thinking_delta", "signature_delta"]
+ | None
+ ) = None
text: str | None = None
+ thinking: str | None = None
partial_json: str | None = None
+ signature: str | None = None
# Message delta
stop_reason: (
@@ -167,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel):
def model_post_init(self, __context):
if not self.id:
self.id = f"msg_{int(time.time() * 1000)}"
+
+
+class AnthropicContextManagement(BaseModel):
+ """Context management information for token counting."""
+
+ original_input_tokens: int
+
+
+class AnthropicCountTokensRequest(BaseModel):
+ """Anthropic messages.count_tokens request"""
+
+ model: str
+ messages: list[AnthropicMessage]
+ system: str | list[AnthropicContentBlock] | None = None
+ tool_choice: AnthropicToolChoice | None = None
+ tools: list[AnthropicTool] | None = None
+
+ @field_validator("model")
+ @classmethod
+ def validate_model(cls, v):
+ if not v:
+ raise ValueError("Model is required")
+ return v
+
+
+class AnthropicCountTokensResponse(BaseModel):
+ """Anthropic messages.count_tokens response"""
+
+ input_tokens: int
+ context_management: AnthropicContextManagement | None = None
diff --git a/vllm/entrypoints/anthropic/serving.py b/vllm/entrypoints/anthropic/serving.py
index dc03731..85232e9 100644
--- a/vllm/entrypoints/anthropic/serving.py
+++ b/vllm/entrypoints/anthropic/serving.py
@@ -8,6 +8,7 @@
import json
import logging
import time
+import uuid
from collections.abc import AsyncGenerator
from typing import Any
@@ -16,6 +17,9 @@ from fastapi import Request
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
AnthropicContentBlock,
+ AnthropicContextManagement,
+ AnthropicCountTokensRequest,
+ AnthropicCountTokensResponse,
AnthropicDelta,
AnthropicError,
AnthropicMessagesRequest,
@@ -85,94 +89,225 @@ class AnthropicServingMessages(OpenAIServingChat):
"tool_calls": "tool_use",
}
+ @staticmethod
+ def _convert_image_source_to_url(source: dict[str, Any]) -> str:
+ """Convert an Anthropic image source to an OpenAI-compatible URL.
+
+ Anthropic supports two image source types:
+ - base64: {"type": "base64", "media_type": "image/jpeg", "data": "..."}
+ - url: {"type": "url", "url": "https://..."}
+
+ For base64 sources, this constructs a proper data URI that
+ downstream processors (e.g. vLLM's media connector) can handle.
+ """
+ source_type = source.get("type")
+ if source_type == "url":
+ return source.get("url", "")
+ # Default to base64 processing if type is "base64"
+ # or missing, ensuring a proper data URI is always
+ # constructed for non-URL sources.
+ media_type = source.get("media_type", "image/jpeg")
+ data = source.get("data", "")
+ return f"data:{media_type};base64,{data}"
+
+ @classmethod
def _convert_anthropic_to_openai_request(
- self, anthropic_request: AnthropicMessagesRequest
+ cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest
) -> ChatCompletionRequest:
"""Convert Anthropic message format to OpenAI format"""
- openai_messages = []
+ openai_messages: list[dict[str, Any]] = []
- # Add system message if provided
- if anthropic_request.system:
- if isinstance(anthropic_request.system, str):
- openai_messages.append(
- {"role": "system", "content": anthropic_request.system}
- )
- else:
- system_prompt = ""
- for block in anthropic_request.system:
- if block.type == "text" and block.text:
- system_prompt += block.text
- openai_messages.append({"role": "system", "content": system_prompt})
+ cls._convert_system_message(anthropic_request, openai_messages)
+ cls._convert_messages(anthropic_request.messages, openai_messages)
+ req = cls._build_base_request(anthropic_request, openai_messages)
+ cls._handle_streaming_options(req, anthropic_request)
+ cls._convert_tool_choice(anthropic_request, req)
+ cls._convert_tools(anthropic_request, req)
+ return req
- for msg in anthropic_request.messages:
+ @classmethod
+ def _convert_system_message(
+ cls,
+ anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+ openai_messages: list[dict[str, Any]],
+ ) -> None:
+ """Convert Anthropic system message to OpenAI format"""
+ if not anthropic_request.system:
+ return
+
+ if isinstance(anthropic_request.system, str):
+ openai_messages.append(
+ {"role": "system", "content": anthropic_request.system}
+ )
+ else:
+ system_prompt = ""
+ for block in anthropic_request.system:
+ if block.type == "text" and block.text:
+ system_prompt += block.text
+ openai_messages.append({"role": "system", "content": system_prompt})
+
+ @classmethod
+ def _convert_messages(
+ cls, messages: list, openai_messages: list[dict[str, Any]]
+ ) -> None:
+ """Convert Anthropic messages to OpenAI format"""
+ for msg in messages:
openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore
+
if isinstance(msg.content, str):
openai_msg["content"] = msg.content
else:
- # Handle complex content blocks
- content_parts: list[dict[str, Any]] = []
- tool_calls: list[dict[str, Any]] = []
-
- for block in msg.content:
- if block.type == "text" and block.text:
- content_parts.append({"type": "text", "text": block.text})
- elif block.type == "image" and block.source:
- content_parts.append(
- {
- "type": "image_url",
- "image_url": {"url": block.source.get("data", "")},
- }
- )
- elif block.type == "tool_use":
- # Convert tool use to function call format
- tool_call = {
- "id": block.id or f"call_{int(time.time())}",
- "type": "function",
- "function": {
- "name": block.name or "",
- "arguments": json.dumps(block.input or {}),
- },
- }
- tool_calls.append(tool_call)
- elif block.type == "tool_result":
- if msg.role == "user":
- openai_messages.append(
- {
- "role": "tool",
- "tool_call_id": block.tool_use_id or "",
- "content": str(block.content)
- if block.content
- else "",
- }
- )
- else:
- # Assistant tool result becomes regular text
- tool_result_text = (
- str(block.content) if block.content else ""
- )
- content_parts.append(
- {
- "type": "text",
- "text": f"Tool result: {tool_result_text}",
- }
- )
-
- # Add tool calls to the message if any
- if tool_calls:
- openai_msg["tool_calls"] = tool_calls # type: ignore
-
- # Add content parts if any
- if content_parts:
- if len(content_parts) == 1 and content_parts[0]["type"] == "text":
- openai_msg["content"] = content_parts[0]["text"]
- else:
- openai_msg["content"] = content_parts # type: ignore
- elif not tool_calls:
- continue
+                if not cls._convert_message_content(
+                    msg, openai_msg, openai_messages
+                ):
+                    # Preserve the old behaviour: skip messages that end up
+                    # with no usable content, tool calls, or reasoning.
+                    continue
openai_messages.append(openai_msg)
- req = ChatCompletionRequest(
+ @classmethod
+ def _convert_message_content(
+ cls,
+ msg,
+ openai_msg: dict[str, Any],
+ openai_messages: list[dict[str, Any]],
+    ) -> bool:
+        """Convert complex message content blocks.
+
+        Returns False when the message carries no usable content, tool calls,
+        or reasoning, so the caller can skip appending it.
+        """
+ content_parts: list[dict[str, Any]] = []
+ tool_calls: list[dict[str, Any]] = []
+ reasoning_parts: list[str] = []
+
+ for block in msg.content:
+ cls._convert_block(
+ block,
+ msg.role,
+ content_parts,
+ tool_calls,
+ reasoning_parts,
+ openai_messages,
+ )
+
+ if reasoning_parts:
+ openai_msg["reasoning"] = "".join(reasoning_parts)
+
+ if tool_calls:
+ openai_msg["tool_calls"] = tool_calls # type: ignore
+
+        if content_parts:
+            if len(content_parts) == 1 and content_parts[0]["type"] == "text":
+                openai_msg["content"] = content_parts[0]["text"]
+            else:
+                openai_msg["content"] = content_parts  # type: ignore
+        elif not tool_calls and not reasoning_parts:
+            return False
+        return True
+
+ @classmethod
+ def _convert_block(
+ cls,
+ block,
+ role: str,
+ content_parts: list[dict[str, Any]],
+ tool_calls: list[dict[str, Any]],
+ reasoning_parts: list[str],
+ openai_messages: list[dict[str, Any]],
+ ) -> None:
+ """Convert individual content block"""
+ if block.type == "text" and block.text:
+ content_parts.append({"type": "text", "text": block.text})
+ elif block.type == "image" and block.source:
+ image_url = cls._convert_image_source_to_url(block.source)
+ content_parts.append({"type": "image_url", "image_url": {"url": image_url}})
+ elif block.type == "thinking" and block.thinking is not None:
+ reasoning_parts.append(block.thinking)
+ elif block.type == "tool_use":
+ cls._convert_tool_use_block(block, tool_calls)
+ elif block.type == "tool_result":
+ cls._convert_tool_result_block(block, role, openai_messages, content_parts)
+
+ @classmethod
+ def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None:
+ """Convert tool_use block to OpenAI function call format"""
+ tool_call = {
+ "id": block.id or f"call_{int(time.time())}",
+ "type": "function",
+ "function": {
+ "name": block.name or "",
+ "arguments": json.dumps(block.input or {}),
+ },
+ }
+ tool_calls.append(tool_call)
+
+ @classmethod
+ def _convert_tool_result_block(
+ cls,
+ block,
+ role: str,
+ openai_messages: list[dict[str, Any]],
+ content_parts: list[dict[str, Any]],
+ ) -> None:
+ """Convert tool_result block to OpenAI format"""
+ if role == "user":
+ cls._convert_user_tool_result(block, openai_messages)
+ else:
+ tool_result_text = str(block.content) if block.content else ""
+ content_parts.append(
+ {"type": "text", "text": f"Tool result: {tool_result_text}"}
+ )
+
+ @classmethod
+ def _convert_user_tool_result(
+ cls, block, openai_messages: list[dict[str, Any]]
+ ) -> None:
+ """Convert user tool_result with text and image support"""
+ tool_text = ""
+ tool_image_urls: list[str] = []
+
+ if isinstance(block.content, str):
+ tool_text = block.content
+ elif isinstance(block.content, list):
+ text_parts: list[str] = []
+ for item in block.content:
+ if not isinstance(item, dict):
+ continue
+ item_type = item.get("type")
+ if item_type == "text":
+ text_parts.append(item.get("text", ""))
+ elif item_type == "image":
+ source = item.get("source", {})
+ url = cls._convert_image_source_to_url(source)
+ if url:
+ tool_image_urls.append(url)
+ tool_text = "\n".join(text_parts)
+
+ openai_messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": block.tool_use_id or "",
+ "content": tool_text or "",
+ }
+ )
+
+ if tool_image_urls:
+ openai_messages.append(
+ {
+ "role": "user",
+ "content": [ # type: ignore[dict-item]
+ {"type": "image_url", "image_url": {"url": img}}
+ for img in tool_image_urls
+ ],
+ }
+ )
+
+ @classmethod
+ def _build_base_request(
+ cls,
+ anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+ openai_messages: list[dict[str, Any]],
+ ) -> ChatCompletionRequest:
+ """Build base ChatCompletionRequest"""
+ if isinstance(anthropic_request, AnthropicCountTokensRequest):
+ return ChatCompletionRequest(
+ model=anthropic_request.model,
+ messages=openai_messages,
+ )
+
+ return ChatCompletionRequest(
model=anthropic_request.model,
messages=openai_messages,
max_tokens=anthropic_request.max_tokens,
@@ -183,19 +318,40 @@ class AnthropicServingMessages(OpenAIServingChat):
top_k=anthropic_request.top_k,
)
+ @classmethod
+ def _handle_streaming_options(
+ cls,
+ req: ChatCompletionRequest,
+ anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+ ) -> None:
+ """Handle streaming configuration"""
+ if isinstance(anthropic_request, AnthropicCountTokensRequest):
+ return
if anthropic_request.stream:
req.stream = anthropic_request.stream
- req.stream_options = StreamOptions.validate(
+ req.stream_options = StreamOptions.model_validate(
{"include_usage": True, "continuous_usage_stats": True}
)
+ @classmethod
+ def _convert_tool_choice(
+ cls,
+ anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+ req: ChatCompletionRequest,
+ ) -> None:
+ """Convert Anthropic tool_choice to OpenAI format"""
if anthropic_request.tool_choice is None:
req.tool_choice = None
- elif anthropic_request.tool_choice.type == "auto":
+ return
+
+ tool_choice_type = anthropic_request.tool_choice.type
+ if tool_choice_type == "auto":
req.tool_choice = "auto"
- elif anthropic_request.tool_choice.type == "any":
+ elif tool_choice_type == "any":
req.tool_choice = "required"
- elif anthropic_request.tool_choice.type == "tool":
+ elif tool_choice_type == "none":
+ req.tool_choice = "none"
+ elif tool_choice_type == "tool":
req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate(
{
"type": "function",
@@ -203,9 +359,17 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
- tools = []
+ @classmethod
+ def _convert_tools(
+ cls,
+ anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest,
+ req: ChatCompletionRequest,
+ ) -> None:
+ """Convert Anthropic tools to OpenAI format"""
if anthropic_request.tools is None:
- return req
+ return
+
+ tools = []
for tool in anthropic_request.tools:
tools.append(
ChatCompletionToolsParam.model_validate(
@@ -219,10 +383,10 @@ class AnthropicServingMessages(OpenAIServingChat):
}
)
)
+
if req.tool_choice is None:
req.tool_choice = "auto"
req.tools = tools
- return req
async def create_messages(
self,
@@ -263,23 +427,32 @@ class AnthropicServingMessages(OpenAIServingChat):
output_tokens=generator.usage.completion_tokens,
),
)
- if generator.choices[0].finish_reason == "stop":
+ choice = generator.choices[0]
+ if choice.finish_reason == "stop":
result.stop_reason = "end_turn"
- elif generator.choices[0].finish_reason == "length":
+ elif choice.finish_reason == "length":
result.stop_reason = "max_tokens"
- elif generator.choices[0].finish_reason == "tool_calls":
+ elif choice.finish_reason == "tool_calls":
result.stop_reason = "tool_use"
- content: list[AnthropicContentBlock] = [
- AnthropicContentBlock(
- type="text",
- text=generator.choices[0].message.content
- if generator.choices[0].message.content
- else "",
+ content: list[AnthropicContentBlock] = []
+ if choice.message.reasoning:
+ content.append(
+ AnthropicContentBlock(
+ type="thinking",
+ thinking=choice.message.reasoning,
+ signature=uuid.uuid4().hex,
+ )
+ )
+ if choice.message.content:
+ content.append(
+ AnthropicContentBlock(
+ type="text",
+ text=choice.message.content,
+ )
)
- ]
- for tool_call in generator.choices[0].message.tool_calls:
+ for tool_call in choice.message.tool_calls:
anthropic_tool_call = AnthropicContentBlock(
type="tool_use",
id=tool_call.id,
@@ -297,10 +470,85 @@ class AnthropicServingMessages(OpenAIServingChat):
generator: AsyncGenerator[str, None],
) -> AsyncGenerator[str, None]:
try:
+
+ class _ActiveBlockState:
+ def __init__(self) -> None:
+ self.content_block_index = 0
+ self.block_type: str | None = None
+ self.block_index: int | None = None
+ self.block_signature: str | None = None
+ self.signature_emitted: bool = False
+ self.tool_use_id: str | None = None
+
+ def reset(self) -> None:
+ self.block_type = None
+ self.block_index = None
+ self.block_signature = None
+ self.signature_emitted = False
+ self.tool_use_id = None
+
+ def start(self, block: AnthropicContentBlock) -> None:
+ self.block_type = block.type
+ self.block_index = self.content_block_index
+ if block.type == "thinking":
+ self.block_signature = uuid.uuid4().hex
+ self.signature_emitted = False
+ self.tool_use_id = None
+ elif block.type == "tool_use":
+ self.block_signature = None
+ self.signature_emitted = True
+ self.tool_use_id = block.id
+ else:
+ self.block_signature = None
+ self.signature_emitted = True
+ self.tool_use_id = None
+
first_item = True
finish_reason = None
- content_block_index = 0
- content_block_started = False
+ state = _ActiveBlockState()
+ # Map from tool call index to tool_use_id
+ tool_index_to_id: dict[int, str] = {}
+
+ def stop_active_block():
+ events: list[str] = []
+ if state.block_type is None:
+ return events
+ if (
+ state.block_type == "thinking"
+ and state.block_signature is not None
+ and not state.signature_emitted
+ ):
+ chunk = AnthropicStreamEvent(
+ index=state.block_index,
+ type="content_block_delta",
+ delta=AnthropicDelta(
+ type="signature_delta",
+ signature=state.block_signature,
+ ),
+ )
+ data = chunk.model_dump_json(exclude_unset=True)
+ events.append(wrap_data_with_event(data, "content_block_delta"))
+ state.signature_emitted = True
+ stop_chunk = AnthropicStreamEvent(
+ index=state.block_index,
+ type="content_block_stop",
+ )
+ data = stop_chunk.model_dump_json(exclude_unset=True)
+ events.append(wrap_data_with_event(data, "content_block_stop"))
+ state.reset()
+ state.content_block_index += 1
+ return events
+
+ def start_block(block: AnthropicContentBlock):
+ chunk = AnthropicStreamEvent(
+ index=state.content_block_index,
+ type="content_block_start",
+ content_block=block,
+ )
+ data = chunk.model_dump_json(exclude_unset=True)
+ event = wrap_data_with_event(data, "content_block_start")
+ state.start(block)
+ return event
async for item in generator:
if item.startswith("data:"):
@@ -326,6 +574,8 @@ class AnthropicServingMessages(OpenAIServingChat):
id=origin_chunk.id,
content=[],
model=origin_chunk.model,
+ stop_reason=None,
+ stop_sequence=None,
usage=AnthropicUsage(
input_tokens=origin_chunk.usage.prompt_tokens
if origin_chunk.usage
@@ -341,13 +591,8 @@ class AnthropicServingMessages(OpenAIServingChat):
# last chunk including usage info
if len(origin_chunk.choices) == 0:
- if content_block_started:
- stop_chunk = AnthropicStreamEvent(
- index=content_block_index,
- type="content_block_stop",
- )
- data = stop_chunk.model_dump_json(exclude_unset=True)
- yield wrap_data_with_event(data, "content_block_stop")
+ for event in stop_active_block():
+ yield event
stop_reason = self.stop_reason_map.get(
finish_reason or "stop"
)
@@ -369,96 +614,139 @@ class AnthropicServingMessages(OpenAIServingChat):
if origin_chunk.choices[0].finish_reason is not None:
finish_reason = origin_chunk.choices[0].finish_reason
- continue
-
- # content
- if origin_chunk.choices[0].delta.content is not None:
- if not content_block_started:
- chunk = AnthropicStreamEvent(
- index=content_block_index,
- type="content_block_start",
- content_block=AnthropicContentBlock(
- type="text", text=""
- ),
- )
- data = chunk.model_dump_json(exclude_unset=True)
- yield wrap_data_with_event(data, "content_block_start")
- content_block_started = True
-
- if origin_chunk.choices[0].delta.content == "":
- continue
- chunk = AnthropicStreamEvent(
- index=content_block_index,
- type="content_block_delta",
- delta=AnthropicDelta(
- type="text_delta",
- text=origin_chunk.choices[0].delta.content,
- ),
- )
- data = chunk.model_dump_json(exclude_unset=True)
- yield wrap_data_with_event(data, "content_block_delta")
- continue
-
- # tool calls
- elif len(origin_chunk.choices[0].delta.tool_calls) > 0:
- tool_call = origin_chunk.choices[0].delta.tool_calls[0]
- if tool_call.id is not None:
- if content_block_started:
- stop_chunk = AnthropicStreamEvent(
- index=content_block_index,
- type="content_block_stop",
- )
- data = stop_chunk.model_dump_json(
- exclude_unset=True
- )
- yield wrap_data_with_event(
- data, "content_block_stop"
- )
- content_block_started = False
- content_block_index += 1
-
- chunk = AnthropicStreamEvent(
- index=content_block_index,
- type="content_block_start",
- content_block=AnthropicContentBlock(
- type="tool_use",
- id=tool_call.id,
- name=tool_call.function.name
- if tool_call.function
- else None,
- input={},
- ),
- )
- data = chunk.model_dump_json(exclude_unset=True)
- yield wrap_data_with_event(data, "content_block_start")
- content_block_started = True
- if tool_call.function and tool_call.function.arguments:
- chunk = AnthropicStreamEvent(
- index=content_block_index,
- type="content_block_delta",
- delta=AnthropicDelta(
- type="input_json_delta",
- partial_json=tool_call.function.arguments,
- ),
- )
- data = chunk.model_dump_json(exclude_unset=True)
- yield wrap_data_with_event(
- data, "content_block_delta"
- )
+ # thinking / text content
+ reasoning_delta = origin_chunk.choices[0].delta.reasoning
+ if reasoning_delta is not None:
+ if reasoning_delta == "":
+ pass
else:
+ if state.block_type != "thinking":
+ for event in stop_active_block():
+ yield event
+ start_event = start_block(
+ AnthropicContentBlock(
+ type="thinking", thinking=""
+ )
+ )
+ yield start_event
chunk = AnthropicStreamEvent(
- index=content_block_index,
+ index=(
+ state.block_index
+ if state.block_index is not None
+ else state.content_block_index
+ ),
type="content_block_delta",
delta=AnthropicDelta(
- type="input_json_delta",
- partial_json=tool_call.function.arguments
- if tool_call.function
- else None,
+ type="thinking_delta",
+ thinking=reasoning_delta,
),
)
data = chunk.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "content_block_delta")
+
+ if origin_chunk.choices[0].delta.content is not None:
+ if origin_chunk.choices[0].delta.content == "":
+ pass
+ else:
+ if state.block_type != "text":
+ for event in stop_active_block():
+ yield event
+ start_event = start_block(
+ AnthropicContentBlock(type="text", text="")
+ )
+ yield start_event
+ chunk = AnthropicStreamEvent(
+ index=(
+ state.block_index
+ if state.block_index is not None
+ else state.content_block_index
+ ),
+ type="content_block_delta",
+ delta=AnthropicDelta(
+ type="text_delta",
+ text=origin_chunk.choices[0].delta.content,
+ ),
+ )
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield wrap_data_with_event(data, "content_block_delta")
+
+ # tool calls - process all tool calls in the delta
+ if len(origin_chunk.choices[0].delta.tool_calls) > 0:
+ for tool_call in origin_chunk.choices[0].delta.tool_calls:
+ if tool_call.id is not None:
+ # Update mapping for incremental updates
+ tool_index_to_id[tool_call.index] = tool_call.id
+ # Only create new block if different tool call
+ # AND has a name
+ tool_name = (
+ tool_call.function.name
+ if tool_call.function
+ else None
+ )
+ if (
+ state.tool_use_id != tool_call.id
+ and tool_name is not None
+ ):
+ for event in stop_active_block():
+ yield event
+ start_event = start_block(
+ AnthropicContentBlock(
+ type="tool_use",
+ id=tool_call.id,
+ name=tool_name,
+ input={},
+ )
+ )
+ yield start_event
+ # Handle initial arguments if present
+ if (
+ tool_call.function
+ and tool_call.function.arguments
+ and state.tool_use_id == tool_call.id
+ ):
+ chunk = AnthropicStreamEvent(
+ index=(
+ state.block_index
+ if state.block_index is not None
+ else state.content_block_index
+ ),
+ type="content_block_delta",
+ delta=AnthropicDelta(
+ type="input_json_delta",
+ partial_json=tool_call.function.arguments,
+ ),
+ )
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield wrap_data_with_event(
+ data, "content_block_delta"
+ )
+ else:
+ # Incremental update - use index to find tool_use_id
+ tool_use_id = tool_index_to_id.get(tool_call.index)
+ if (
+ tool_use_id is not None
+ and tool_call.function
+ and tool_call.function.arguments
+ and state.tool_use_id == tool_use_id
+ ):
+ chunk = AnthropicStreamEvent(
+ index=(
+ state.block_index
+ if state.block_index is not None
+ else state.content_block_index
+ ),
+ type="content_block_delta",
+ delta=AnthropicDelta(
+ type="input_json_delta",
+ partial_json=tool_call.function.arguments,
+ ),
+ )
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield wrap_data_with_event(
+ data, "content_block_delta"
+ )
continue
else:
error_response = AnthropicStreamEvent(
@@ -481,3 +769,31 @@ class AnthropicServingMessages(OpenAIServingChat):
data = error_response.model_dump_json(exclude_unset=True)
yield wrap_data_with_event(data, "error")
yield "data: [DONE]\n\n"
+
+ async def count_tokens(
+ self,
+ request: AnthropicCountTokensRequest,
+ raw_request: Request | None = None,
+ ) -> AnthropicCountTokensResponse | ErrorResponse:
+ """Implements Anthropic's messages.count_tokens endpoint."""
+ chat_req = self._convert_anthropic_to_openai_request(request)
+ result = await self.render_chat_request(chat_req)
+ if isinstance(result, ErrorResponse):
+ return result
+
+ _, engine_prompts = result
+
+ input_tokens = sum( # type: ignore
+ len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc]
+ for prompt in engine_prompts
+ if "prompt_token_ids" in prompt
+ )
+
+ response = AnthropicCountTokensResponse(
+ input_tokens=input_tokens,
+ context_management=AnthropicContextManagement(
+ original_input_tokens=input_tokens
+ ),
+ )
+
+ return response
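
For clarity, the image-source normalization introduced above behaves as in this sketch; the values are illustrative:

from vllm.entrypoints.anthropic.serving import AnthropicServingMessages

# Base64 sources become data URIs...
AnthropicServingMessages._convert_image_source_to_url(
    {"type": "base64", "media_type": "image/png", "data": "iVBORw0KGgo..."}
)  # -> "data:image/png;base64,iVBORw0KGgo..."

# ...while URL sources pass through unchanged.
AnthropicServingMessages._convert_image_source_to_url(
    {"type": "url", "url": "https://example.com/cat.png"}
)  # -> "https://example.com/cat.png"
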
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index c48d7be..1d10aa6 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -7,6 +7,7 @@ import warnings
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from collections.abc import Awaitable, Callable, Iterable
+from dataclasses import dataclass
from functools import cached_property, lru_cache, partial
from itertools import accumulate
from pathlib import Path
@@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser):
self._add_placeholder("video", placeholder)
+@dataclass
+class ChatTemplateConfig:
+ chat_template: str | None = None
+ chat_template_content_format: ChatTemplateContentFormatOption = "auto"
+ trust_request_chat_template: bool = False
+
+
def validate_chat_template(chat_template: Path | str | None):
"""Raises if the provided chat template appears invalid."""
if chat_template is None:
diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py
index ff5a928..704d94d 100644
--- a/vllm/entrypoints/cli/__init__.py
+++ b/vllm/entrypoints/cli/__init__.py
@@ -1,12 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand
+from vllm.entrypoints.cli.benchmark.mm_processor import (
+ BenchmarkMMProcessorSubcommand,
+)
+from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand
+from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand
+from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand
+from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand
-# Keep this package init import-free.
-#
-# The `vllm` console script imports `vllm.entrypoints.cli.main`, which causes
-# Python to import this package before loading the `main` submodule.
-# Eagerly importing benchmark subcommands here makes every `vllm serve ...`
-# startup depend on optional benchmark-only modules.
-#
-# Benchmark subcommands are loaded on demand in
-# `vllm.entrypoints.cli.benchmark.main`.
+__all__: list[str] = [
+ "BenchmarkLatencySubcommand",
+ "BenchmarkMMProcessorSubcommand",
+ "BenchmarkServingSubcommand",
+ "BenchmarkStartupSubcommand",
+ "BenchmarkSweepSubcommand",
+ "BenchmarkThroughputSubcommand",
+]
diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py
index ae490ba..48f34fc 100644
--- a/vllm/entrypoints/cli/benchmark/main.py
+++ b/vllm/entrypoints/cli/benchmark/main.py
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
-import importlib
-import logging
import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
@@ -15,30 +13,6 @@ if typing.TYPE_CHECKING:
else:
FlexibleArgumentParser = argparse.ArgumentParser
-logger = logging.getLogger(__name__)
-
-
-def _load_benchmark_subcommands() -> None:
- modules = [
- "vllm.entrypoints.cli.benchmark.latency",
- "vllm.entrypoints.cli.benchmark.mm_processor",
- "vllm.entrypoints.cli.benchmark.serve",
- "vllm.entrypoints.cli.benchmark.startup",
- "vllm.entrypoints.cli.benchmark.sweep",
- "vllm.entrypoints.cli.benchmark.throughput",
- ]
-
- for module_name in modules:
- try:
- importlib.import_module(module_name)
- except ModuleNotFoundError as e:
- logger.warning(
- "Skipping benchmark subcommand module %s because an optional "
- "dependency could not be imported: %r",
- module_name,
- e,
- )
-
class BenchmarkSubcommand(CLISubcommand):
"""The `bench` subcommand for the vLLM CLI."""
@@ -64,8 +38,6 @@ class BenchmarkSubcommand(CLISubcommand):
)
bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type")
- _load_benchmark_subcommands()
-
for cmd_cls in BenchmarkSubcommandBase.__subclasses__():
cmd_subparser = bench_subparsers.add_parser(
cmd_cls.name,
diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index c12cc7f..944fb88 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace):
num_api_servers: int = args.api_server_count
assert num_api_servers > 0
+ if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False):
+ # TODO(wentao): remove this once well tested
+ raise ValueError(
+            "--use-gpu-for-pooling-score cannot currently be used with api_server_count > 1"
+ )
+
if num_api_servers > 1:
setup_multiprocess_prometheus()
@@ -246,8 +252,12 @@ def run_multi_api_server(args: argparse.Namespace):
api_server_manager: APIServerProcessManager | None = None
+ from vllm.v1.engine.utils import get_engine_zmq_addresses
+
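+    # Resolve the engine ZMQ addresses up front so they can be passed to
+    # launch_core_engines below.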
+ addresses = get_engine_zmq_addresses(vllm_config, num_api_servers)
+
with launch_core_engines(
- vllm_config, executor_class, log_stats, num_api_servers
+ vllm_config, executor_class, log_stats, addresses, num_api_servers
) as (local_engine_manager, coordinator, addresses):
# Construct common args for the APIServerProcessManager up-front.
api_server_manager_kwargs = dict(
diff --git a/vllm/entrypoints/grpc_server.py b/vllm/entrypoints/grpc_server.py
index 1fc3354..ec8f480 100644
--- a/vllm/entrypoints/grpc_server.py
+++ b/vllm/entrypoints/grpc_server.py
@@ -101,11 +101,15 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
sampling_params = self._sampling_params_from_proto(
request.sampling_params, stream=request.stream
)
+ tokenization_kwargs = self._tokenization_kwargs_from_proto(
+ request.sampling_params
+ )
async for output in self.async_llm.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
+ tokenization_kwargs=tokenization_kwargs,
):
# Convert vLLM output to protobuf
# For streaming, always send chunks
@@ -308,9 +312,6 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
seed=params.seed if params.HasField("seed") else None,
include_stop_str_in_output=params.include_stop_str_in_output,
logit_bias=dict(params.logit_bias) if params.logit_bias else None,
- truncate_prompt_tokens=params.truncate_prompt_tokens
- if params.HasField("truncate_prompt_tokens")
- else None,
structured_outputs=structured_outputs,
# detokenize must be True if stop strings are used
detokenize=bool(stop),
@@ -319,6 +320,14 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
else RequestOutputKind.FINAL_ONLY,
)
+ @staticmethod
+ def _tokenization_kwargs_from_proto(
+ params: vllm_engine_pb2.SamplingParams,
+ ) -> dict[str, int] | None:
+ if params.HasField("truncate_prompt_tokens"):
+ return {"truncate_prompt_tokens": params.truncate_prompt_tokens}
+ return None
+
@staticmethod
def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
"""
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index ee78d4d..d5a51a6 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
-import warnings
from collections.abc import Callable, Iterable, Sequence
+from pathlib import Path
from typing import TYPE_CHECKING, Any
import cloudpickle
@@ -41,8 +41,11 @@ from vllm.distributed.weight_transfer.base import (
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
+ ChatTemplateConfig,
ChatTemplateContentFormatOption,
+ load_chat_template,
)
+from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
from vllm.entrypoints.pooling.score.utils import (
ScoreData,
ScoreMultiModalParam,
@@ -146,6 +149,7 @@ class LLM:
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
+ chat_template: The chat template to apply.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
@@ -233,6 +237,7 @@ class LLM:
quantization: QuantizationMethods | None = None,
revision: str | None = None,
tokenizer_revision: str | None = None,
+ chat_template: Path | str | None = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
@@ -385,9 +390,16 @@ class LLM:
self.model_config = self.llm_engine.model_config
self.renderer = self.llm_engine.renderer
+ self.chat_template = load_chat_template(chat_template)
self.io_processor = self.llm_engine.io_processor
self.input_processor = self.llm_engine.input_processor
-
+ self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
+ self.init_pooling_io_processors = init_pooling_io_processors(
+ supported_tasks=supported_tasks,
+ model_config=self.model_config,
+ renderer=self.renderer,
+ chat_template_config=self.chat_template_config,
+ )
# Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None
@@ -1030,7 +1042,6 @@ class LLM:
prompts: PromptType | Sequence[PromptType] | DataPrompt,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
*,
- truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
pooling_task: PoolingTask | None = None,
@@ -1088,21 +1099,7 @@ class LLM:
"pooling model."
)
- if truncate_prompt_tokens is not None:
- warnings.warn(
- "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
- "is deprecated and will be removed in v0.16. "
- "Please pass it via `tokenization_kwargs` instead.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- tokenization_kwargs = merge_kwargs(
- tokenization_kwargs,
- dict(truncate_prompt_tokens=truncate_prompt_tokens),
- )
-
- if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
+ if isinstance(prompts, dict) and "data" in prompts:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
@@ -1136,6 +1133,31 @@ class LLM:
for p in params_seq:
if p.task is None:
p.task = "plugin"
+
+ outputs = self._run_completion(
+ prompts=prompts_seq,
+ params=params_seq,
+ output_type=PoolingRequestOutput,
+ use_tqdm=use_tqdm,
+ lora_request=lora_request,
+ tokenization_kwargs=tokenization_kwargs,
+ )
+
+ # get the post-processed model outputs
+ assert self.io_processor is not None
+ processed_outputs = self.io_processor.post_process(outputs)
+
+ return [
+ PoolingRequestOutput[Any](
+ request_id="",
+ outputs=processed_outputs,
+ num_cached_tokens=getattr(
+ processed_outputs, "num_cached_tokens", 0
+ ),
+ prompt_token_ids=[],
+ finished=True,
+ )
+ ]
else:
if pooling_params is None:
# Use default pooling params.
@@ -1153,39 +1175,42 @@ class LLM:
)
raise ValueError(msg)
- outputs = self._run_completion(
- prompts=prompts_seq,
- params=params_seq,
- output_type=PoolingRequestOutput,
- use_tqdm=use_tqdm,
- lora_request=lora_request,
- tokenization_kwargs=tokenization_kwargs,
- )
-
- if use_io_processor:
- # get the post-processed model outputs
- assert self.io_processor is not None
- processed_outputs = self.io_processor.post_process(outputs)
-
- return [
- PoolingRequestOutput[Any](
- request_id="",
- outputs=processed_outputs,
- num_cached_tokens=getattr(
- processed_outputs, "num_cached_tokens", 0
- ),
- prompt_token_ids=[],
- finished=True,
+ if pooling_task in self.init_pooling_io_processors:
+ io_processor = self.init_pooling_io_processors[pooling_task]
+ processor_inputs = io_processor.pre_process_offline(
+ prompts_seq, tokenization_kwargs
)
- ]
+ seq_lora_requests = self._lora_request_to_seq(
+ lora_request, len(prompts_seq)
+ )
+ seq_priority = self._priority_to_seq(None, len(prompts))
+ self._render_and_add_requests(
+ prompts=processor_inputs,
+ params=params_seq,
+ lora_requests=seq_lora_requests,
+ priorities=seq_priority,
+ )
+
+ outputs = self._run_engine(
+ use_tqdm=use_tqdm, output_type=PoolingRequestOutput
+ )
+ outputs = io_processor.post_process(outputs)
+ else:
+ outputs = self._run_completion(
+ prompts=prompts_seq,
+ params=params_seq,
+ output_type=PoolingRequestOutput,
+ use_tqdm=use_tqdm,
+ lora_request=lora_request,
+ tokenization_kwargs=tokenization_kwargs,
+ )
return outputs
def embed(
self,
prompts: PromptType | Sequence[PromptType],
*,
- truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
lora_request: list[LoRARequest] | LoRARequest | None = None,
@@ -1221,12 +1246,6 @@ class LLM:
"Try converting the model using `--convert embed`."
)
- if truncate_prompt_tokens is not None:
- tokenization_kwargs = merge_kwargs(
- tokenization_kwargs,
- dict(truncate_prompt_tokens=truncate_prompt_tokens),
- )
-
items = self.encode(
prompts,
use_tqdm=use_tqdm,
@@ -1294,7 +1313,6 @@ class LLM:
/,
*,
pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
- truncate_prompt_tokens: int | None = None,
use_tqdm: bool | Callable[..., tqdm] = True,
lora_request: list[LoRARequest] | LoRARequest | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
@@ -1319,13 +1337,11 @@ class LLM:
A list of `PoolingRequestOutput` objects containing the
pooled hidden states in the same order as the input prompts.
"""
-
return self.encode(
prompts,
use_tqdm=use_tqdm,
lora_request=lora_request,
pooling_params=pooling_params,
- truncate_prompt_tokens=truncate_prompt_tokens,
pooling_task="token_classify",
tokenization_kwargs=tokenization_kwargs,
)
@@ -1771,23 +1787,15 @@ class LLM:
seq_prompts = prompt_to_seq(prompts)
seq_params = self._params_to_seq(params, len(seq_prompts))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
- seq_tok_kwargs = [
- merge_kwargs(
- tokenization_kwargs,
- dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
- )
- for param in seq_params
- ]
seq_priority = self._priority_to_seq(priority, len(prompts))
return self._render_and_add_requests(
prompts=(
- self._preprocess_cmpl_one(prompt, tok_kwargs)
- for prompt, tok_kwargs in zip(
- maybe_tqdm(
- seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
- ),
- seq_tok_kwargs,
+ self._preprocess_cmpl_one(prompt, tokenization_kwargs)
+ for prompt in maybe_tqdm(
+ seq_prompts,
+ use_tqdm=use_tqdm,
+ desc="Rendering prompts",
)
),
params=seq_params,
@@ -1841,13 +1849,6 @@ class LLM:
seq_convs = conversation_to_seq(messages)
seq_params = self._params_to_seq(params, len(seq_convs))
seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs))
- seq_tok_kwargs = [
- merge_kwargs(
- tokenization_kwargs,
- dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
- )
- for param in seq_params
- ]
return self._render_and_run_requests(
prompts=(
@@ -1859,16 +1860,13 @@ class LLM:
add_generation_prompt=add_generation_prompt,
continue_final_message=continue_final_message,
tools=tools,
- tokenization_kwargs=tok_kwargs,
+ tokenization_kwargs=tokenization_kwargs,
mm_processor_kwargs=mm_processor_kwargs,
)
- for conversation, tok_kwargs in zip(
- maybe_tqdm(
- seq_convs,
- use_tqdm=use_tqdm,
- desc="Rendering conversations",
- ),
- seq_tok_kwargs,
+ for conversation in maybe_tqdm(
+ seq_convs,
+ use_tqdm=use_tqdm,
+ desc="Rendering conversations",
)
),
params=seq_params,
diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py
index c9e8093..c2a77fb 100644
--- a/vllm/entrypoints/logger.py
+++ b/vllm/entrypoints/logger.py
@@ -18,6 +18,20 @@ class RequestLogger:
def __init__(self, *, max_log_len: int | None) -> None:
self.max_log_len = max_log_len
+ if not logger.isEnabledFor(logging.INFO):
+ logger.warning_once(
+ "`--enable-log-requests` is set but "
+ "the minimum log level is higher than INFO. "
+ "No request information will be logged."
+ )
+ elif not logger.isEnabledFor(logging.DEBUG):
+ logger.info_once(
+ "`--enable-log-requests` is set but "
+ "the minimum log level is higher than DEBUG. "
+ "Only limited information will be logged to minimize overhead. "
+ "To view more details, set `VLLM_LOGGING_LEVEL=DEBUG`."
+ )
+
def log_inputs(
self,
request_id: str,
diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py
index 1bf0de5..0abe85a 100644
--- a/vllm/entrypoints/openai/chat_completion/protocol.py
+++ b/vllm/entrypoints/openai/chat_completion/protocol.py
@@ -38,6 +38,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
from vllm.sampling_params import (
BeamSearchParams,
+ RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -336,6 +337,16 @@ class ChatCompletionRequest(OpenAIBaseModel):
),
)
+ repetition_detection: RepetitionDetectionParams | None = Field(
+ default=None,
+ description="Parameters for detecting repetitive N-gram patterns "
+ "in output tokens. If such repetition is detected, generation will "
+ "be ended early. LLMs can sometimes generate repetitive, unhelpful "
+ "token patterns, stopping only when they hit the maximum output length "
+ "(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
+ "can detect such behavior and terminate early, saving time and tokens.",
+ )
+
# --8<-- [end:chat-completion-extra-params]
def build_chat_params(
@@ -490,7 +501,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -500,8 +510,37 @@ class ChatCompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
+ repetition_detection=self.repetition_detection,
)
+ @model_validator(mode="before")
+ @classmethod
+ def validate_response_format(cls, data):
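+        """Ensure a `json_schema` body is provided when `response_format.type` is `json_schema`."""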
+ response_format = data.get("response_format")
+ if response_format is None:
+ return data
+
+ rf_type = (
+ response_format.get("type")
+ if isinstance(response_format, dict)
+ else getattr(response_format, "type", None)
+ )
+
+ if rf_type == "json_schema":
+ json_schema = (
+ response_format.get("json_schema")
+ if isinstance(response_format, dict)
+ else getattr(response_format, "json_schema", None)
+ )
+ if json_schema is None:
+ raise VLLMValidationError(
+ "When response_format type is 'json_schema', the "
+ "'json_schema' field must be provided.",
+ parameter="response_format",
+ )
+
+ return data
+
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 39f8635..06b16cd 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
)
# get the expected call based on partial JSON
- # parsing which "autocompletes" the JSON
- expected_call = json.dumps(
- tool_parser.prev_tool_call_arr[index].get(
- "arguments", {}
- ),
- ensure_ascii=False,
+ # parsing which "autocompletes" the JSON.
+ # Tool parsers (e.g. Qwen3Coder) store
+ # arguments as a JSON string in
+ # prev_tool_call_arr. Calling json.dumps()
+ # on an already-serialized string would
+ # double-serialize it (e.g. '{"k":1}' becomes
+ # '"{\\"k\\":1}"'), which then causes the
+ # replace() below to fail and append the
+ # entire double-serialized string as a
+ # spurious final delta.
+ args = tool_parser.prev_tool_call_arr[index].get(
+ "arguments", {}
)
+ if isinstance(args, str):
+ expected_call = args
+ else:
+ expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments
# for the current tool
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index eac581e..d3a66c1 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -143,7 +143,8 @@ class BaseFrontendArgs:
templates and other tokenizer configuration."""
enable_log_outputs: bool = False
"""If set to True, log model outputs (generations).
- Requires --enable-log-requests."""
+    Requires `--enable-log-requests`. As with `--enable-log-requests`,
+    nothing is logged unless the log level is INFO or lower."""
enable_log_deltas: bool = True
"""If set to False, output deltas will not be logged. Relevant only if
--enable-log-outputs is set.
@@ -277,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs):
Enable offline FastAPI documentation for air-gapped environments.
Uses vendored static assets bundled with vLLM.
"""
+ use_gpu_for_pooling_score: bool = False
+ """If set, run pooling score MaxSim on GPU in the API server process.
+ Can significantly improve late-interaction scoring performance.
+ https://github.com/vllm-project/vllm/pull/35330"""
@classmethod
def _customize_cli_kwargs(
diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py
index 226dd6c..af13204 100644
--- a/vllm/entrypoints/openai/completion/protocol.py
+++ b/vllm/entrypoints/openai/completion/protocol.py
@@ -26,6 +26,7 @@ from vllm.logprobs import Logprob
from vllm.renderers import TokenizeParams
from vllm.sampling_params import (
BeamSearchParams,
+ RepetitionDetectionParams,
RequestOutputKind,
SamplingParams,
StructuredOutputsParams,
@@ -166,6 +167,16 @@ class CompletionRequest(OpenAIBaseModel):
),
)
+ repetition_detection: RepetitionDetectionParams | None = Field(
+ default=None,
+ description="Parameters for detecting repetitive N-gram patterns "
+ "in output tokens. If such repetition is detected, generation will "
+ "be ended early. LLMs can sometimes generate repetitive, unhelpful "
+ "token patterns, stopping only when they hit the maximum output length "
+ "(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature "
+ "can detect such behavior and terminate early, saving time and tokens.",
+ )
+
# --8<-- [end:completion-extra-params]
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
@@ -259,7 +270,7 @@ class CompletionRequest(OpenAIBaseModel):
structured_outputs_kwargs["json"] = json_schema.json_schema
elif response_format.type == "structural_tag":
structural_tag = response_format
- assert structural_tag is not None and isinstance(
+ assert isinstance(
structural_tag,
(
LegacyStructuralTagResponseFormat,
@@ -302,7 +313,6 @@ class CompletionRequest(OpenAIBaseModel):
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream
else RequestOutputKind.FINAL_ONLY,
@@ -311,8 +321,37 @@ class CompletionRequest(OpenAIBaseModel):
allowed_token_ids=self.allowed_token_ids,
extra_args=extra_args or None,
skip_clone=True, # Created fresh per request, safe to skip clone
+ repetition_detection=self.repetition_detection,
)
+ @model_validator(mode="before")
+ @classmethod
+ def validate_response_format(cls, data):
+ response_format = data.get("response_format")
+ if response_format is None:
+ return data
+
+ rf_type = (
+ response_format.get("type")
+ if isinstance(response_format, dict)
+ else getattr(response_format, "type", None)
+ )
+
+ if rf_type == "json_schema":
+ json_schema = (
+ response_format.get("json_schema")
+ if isinstance(response_format, dict)
+ else getattr(response_format, "json_schema", None)
+ )
+ if json_schema is None:
+ raise VLLMValidationError(
+ "When response_format type is 'json_schema', the "
+ "'json_schema' field must be provided.",
+ parameter="response_format",
+ )
+
+ return data
+
@model_validator(mode="before")
@classmethod
def check_structured_outputs_count(cls, data):
diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py
index 3e376ba..e864f56 100644
--- a/vllm/entrypoints/openai/engine/serving.py
+++ b/vllm/entrypoints/openai/engine/serving.py
@@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranscriptionResponse,
TranslationRequest,
)
-from vllm.entrypoints.pooling.classify.protocol import (
- ClassificationChatRequest,
- ClassificationCompletionRequest,
- ClassificationResponse,
-)
from vllm.entrypoints.pooling.embed.protocol import (
EmbeddingBytesResponse,
EmbeddingChatRequest,
@@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = (
| TokenizeCompletionRequest
| DetokenizeRequest
| EmbeddingCompletionRequest
- | ClassificationCompletionRequest
| RerankRequest
| ScoreRequest
| PoolingCompletionRequest
@@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = (
ChatCompletionRequest
| TokenizeChatRequest
| EmbeddingChatRequest
- | ClassificationChatRequest
| PoolingChatRequest
)
@@ -194,12 +187,10 @@ AnyResponse: TypeAlias = (
| TranscriptionResponse
| TokenizeResponse
| PoolingResponse
- | ClassificationResponse
| ScoreResponse
| GenerateResponse
)
-
RequestT = TypeVar("RequestT", bound=AnyRequest)
@@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]):
class OpenAIServing:
request_id_prefix: ClassVar[str] = """
- A short string prepended to every request’s ID (e.g. "embd", "classify")
- so you can easily tell “this ID came from Embedding vs Classification.”
+ A short string prepended to every request’s ID (e.g. "embd")
+ so you can easily tell “this ID came from Embedding.”
"""
def __init__(
@@ -456,7 +447,7 @@ class OpenAIServing:
) -> ErrorResponse | None:
"""
Default preprocessing hook. Subclasses may override
- to prepare `ctx` (classification, embedding, etc.).
+ to prepare `ctx` (embedding, etc.).
"""
return None
@@ -817,7 +808,7 @@ class OpenAIServing:
token_num = len(input_ids)
max_model_len = self.model_config.max_model_len
- # Note: EmbeddingRequest, ClassificationRequest,
+ # Note: EmbeddingRequest,
# and ScoreRequest doesn't have max_tokens
if isinstance(
request,
@@ -828,8 +819,6 @@ class OpenAIServing:
ScoreTextRequest,
ScoreQueriesDocumentsRequest,
RerankRequest,
- ClassificationCompletionRequest,
- ClassificationChatRequest,
),
):
# Note: input length can be up to the entire model context length
@@ -839,8 +828,6 @@ class OpenAIServing:
ScoreDataRequest: "score",
ScoreTextRequest: "score",
ScoreQueriesDocumentsRequest: "score",
- ClassificationCompletionRequest: "classification",
- ClassificationChatRequest: "classification",
}
operation = operations.get(type(request), "embedding generation")
raise VLLMValidationError(
diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py
index b0ffd03..1ec88cc 100644
--- a/vllm/entrypoints/openai/responses/protocol.py
+++ b/vllm/entrypoints/openai/responses/protocol.py
@@ -328,8 +328,9 @@ class ResponsesRequest(OpenAIBaseModel):
# Also check text.format for OpenAI-style json_schema
if self.text is not None and self.text.format is not None:
if structured_outputs is not None:
- raise ValueError(
- "Cannot specify both structured_outputs and text.format"
+ raise VLLMValidationError(
+ "Cannot specify both structured_outputs and text.format",
+ parameter="structured_outputs",
)
response_format = self.text.format
if (
@@ -378,14 +379,19 @@ class ResponsesRequest(OpenAIBaseModel):
)
@model_validator(mode="before")
+ @classmethod
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
- raise ValueError("background can only be used when `store` is true")
+ raise VLLMValidationError(
+ "background can only be used when `store` is true",
+ parameter="background",
+ )
return data
@model_validator(mode="before")
+ @classmethod
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise VLLMValidationError(
@@ -394,16 +400,19 @@ class ResponsesRequest(OpenAIBaseModel):
return data
@model_validator(mode="before")
+ @classmethod
def check_cache_salt_support(cls, data):
if data.get("cache_salt") is not None and (
not isinstance(data["cache_salt"], str) or not data["cache_salt"]
):
- raise ValueError(
- "Parameter 'cache_salt' must be a non-empty string if provided."
+ raise VLLMValidationError(
+ "Parameter 'cache_salt' must be a non-empty string if provided.",
+ parameter="cache_salt",
)
return data
@model_validator(mode="before")
+ @classmethod
def function_call_parsing(cls, data):
"""Parse function_call dictionaries into ResponseFunctionToolCall objects.
This ensures Pydantic can properly resolve union types in the input field.
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index b9d526e..3cfb6ff 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -85,6 +85,8 @@ from vllm.entrypoints.openai.responses.protocol import (
ResponseCreatedEvent,
ResponseInProgressEvent,
ResponseInputOutputMessage,
+ ResponseReasoningPartAddedEvent,
+ ResponseReasoningPartDoneEvent,
ResponsesRequest,
ResponsesResponse,
ResponseUsage,
@@ -1339,6 +1341,19 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
+ yield _increment_sequence_number_and_return(
+ ResponseReasoningPartAddedEvent(
+ type="response.reasoning_part.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ content_index=current_content_index,
+ part=ResponseReasoningTextContent(
+ text="",
+ type="reasoning_text",
+ ),
+ )
+ )
else:
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
@@ -1354,22 +1369,21 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
- yield _increment_sequence_number_and_return(
- ResponseContentPartAddedEvent(
- type="response.content_part.added",
- sequence_number=-1,
- output_index=current_output_index,
- item_id=current_item_id,
- content_index=current_content_index,
- part=ResponseOutputText(
- type="output_text",
- text="",
- annotations=[],
- logprobs=[],
- ),
+ yield _increment_sequence_number_and_return(
+ ResponseContentPartAddedEvent(
+ type="response.content_part.added",
+ sequence_number=-1,
+ output_index=current_output_index,
+ item_id=current_item_id,
+ content_index=current_content_index,
+ part=ResponseOutputText(
+ type="output_text",
+ text="",
+ annotations=[],
+ logprobs=[],
+ ),
+ )
)
- )
- current_content_index += 1
first_delta_sent = True
# todo(kebe7jun) tool call support
@@ -1397,6 +1411,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
+ yield _increment_sequence_number_and_return(
+ ResponseReasoningPartDoneEvent(
+ type="response.reasoning_part.done",
+ sequence_number=-1,
+ item_id=current_item_id,
+ output_index=current_output_index,
+ content_index=current_content_index,
+ part=ResponseReasoningTextContent(
+ text=reason_content,
+ type="reasoning_text",
+ ),
+ )
+ )
current_content_index = 0
reasoning_item = ResponseReasoningItem(
type="reasoning",
@@ -1418,6 +1445,8 @@ class OpenAIServingResponses(OpenAIServing):
item=reasoning_item,
)
)
+ current_output_index += 1
+ current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseOutputItemAddedEvent(
type="response.output_item.added",
@@ -1432,8 +1461,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
- current_output_index += 1
- current_item_id = str(uuid.uuid4())
yield _increment_sequence_number_and_return(
ResponseContentPartAddedEvent(
type="response.content_part.added",
@@ -1449,7 +1476,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
- current_content_index += 1
# reset previous delta messages
previous_delta_messages = []
@@ -1485,7 +1511,6 @@ class OpenAIServingResponses(OpenAIServing):
),
)
)
- current_content_index += 1
previous_delta_messages.append(delta_message)
if previous_delta_messages:
@@ -1505,7 +1530,19 @@ class OpenAIServingResponses(OpenAIServing):
text=reason_content,
)
)
- current_content_index += 1
+ yield _increment_sequence_number_and_return(
+ ResponseReasoningPartDoneEvent(
+ type="response.reasoning_part.done",
+ sequence_number=-1,
+ item_id=current_item_id,
+ output_index=current_output_index,
+ content_index=current_content_index,
+ part=ResponseReasoningTextContent(
+ text=reason_content,
+ type="reasoning_text",
+ ),
+ )
+ )
reasoning_item = ResponseReasoningItem(
type="reasoning",
content=[
@@ -1543,7 +1580,6 @@ class OpenAIServingResponses(OpenAIServing):
item_id=current_item_id,
)
)
- current_content_index += 1
part = ResponseOutputText(
text=final_content,
type="output_text",
@@ -1559,7 +1595,6 @@ class OpenAIServingResponses(OpenAIServing):
part=part,
)
)
- current_content_index += 1
item = ResponseOutputMessage(
type="message",
role="assistant",
diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py
index 966e6d4..1c56f09 100644
--- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py
+++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py
@@ -11,6 +11,7 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast
import numpy as np
from fastapi import Request
+from soundfile import LibsndfileError
from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
@@ -57,6 +58,14 @@ try:
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
+# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile
+# being librosa's main backend. Used to validate if an audio loading error is due to a
+# server error vs a client error (invalid audio file).
+# 1 = unrecognised format (file is not a supported audio container)
+# 3 = malformed file (corrupt or structurally invalid audio)
+# 4 = unsupported encoding (codec not supported by this libsndfile build)
+_BAD_SF_CODES = {1, 3, 4}
+
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
@@ -315,9 +324,15 @@ class OpenAISpeechToText(OpenAIServing):
)
with io.BytesIO(audio_data) as bytes_:
- # NOTE resample to model SR here for efficiency. This is also a
- # pre-requisite for chunking, as it assumes Whisper SR.
- y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
+ try:
+ # NOTE resample to model SR here for efficiency. This is also a
+ # pre-requisite for chunking, as it assumes Whisper SR.
+ y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
+ except LibsndfileError as exc:
+ # Distinguish client errors (invalid audio) from server errors
+ if exc.code in _BAD_SF_CODES:
+ raise ValueError("Invalid or unsupported audio file.") from exc
+ raise
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (
diff --git a/vllm/entrypoints/openai/translations/__init__.py b/vllm/entrypoints/openai/translations/__init__.py
deleted file mode 100644
index cf210d5..0000000
--- a/vllm/entrypoints/openai/translations/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-warnings.warn(
- "The 'vllm.entrypoints.openai.translations' module has been renamed to "
- "'vllm.entrypoints.openai.speech_to_text'. Please update your imports. "
- "This backward-compatible alias will be removed in version 0.17+.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/vllm/entrypoints/openai/translations/api_router.py b/vllm/entrypoints/openai/translations/api_router.py
deleted file mode 100644
index 4a43bf8..0000000
--- a/vllm/entrypoints/openai/translations/api_router.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-warnings.warn(
- "'vllm.entrypoints.openai.translations.api_router' has been moved to "
- "'vllm.entrypoints.openai.speech_to_text.api_router'. Please update your "
- "imports. This backward-compatible alias will be removed in version 0.17+.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-from vllm.entrypoints.openai.speech_to_text.api_router import * # noqa: F401,F403,E402
diff --git a/vllm/entrypoints/openai/translations/protocol.py b/vllm/entrypoints/openai/translations/protocol.py
deleted file mode 100644
index c8ec156..0000000
--- a/vllm/entrypoints/openai/translations/protocol.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-warnings.warn(
- "'vllm.entrypoints.openai.translations.protocol' has been moved to "
- "'vllm.entrypoints.openai.speech_to_text.protocol'. Please update your "
- "imports. This backward-compatible alias will be removed in version 0.17+.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-from vllm.entrypoints.openai.speech_to_text.protocol import * # noqa: F401,F403,E402
diff --git a/vllm/entrypoints/openai/translations/serving.py b/vllm/entrypoints/openai/translations/serving.py
deleted file mode 100644
index 1749d61..0000000
--- a/vllm/entrypoints/openai/translations/serving.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-warnings.warn(
- "'vllm.entrypoints.openai.translations.serving' has been moved to "
- "'vllm.entrypoints.openai.speech_to_text.serving'. Please update your "
- "imports. This backward-compatible alias will be removed in version 0.17+.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-from vllm.entrypoints.openai.speech_to_text.serving import * # noqa: F401,F403,E402
diff --git a/vllm/entrypoints/openai/translations/speech_to_text.py b/vllm/entrypoints/openai/translations/speech_to_text.py
deleted file mode 100644
index eb26c6a..0000000
--- a/vllm/entrypoints/openai/translations/speech_to_text.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import warnings
-
-warnings.warn(
- "'vllm.entrypoints.openai.translations.speech_to_text' has been moved to "
- "'vllm.entrypoints.openai.speech_to_text.speech_to_text'. Please update "
- "your imports. This backward-compatible alias will be removed in version "
- "0.17+.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-from vllm.entrypoints.openai.speech_to_text.speech_to_text import * # noqa: F401,F403,E402
diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py
index 1108be1..3ba131d 100644
--- a/vllm/entrypoints/pooling/__init__.py
+++ b/vllm/entrypoints/pooling/__init__.py
@@ -115,6 +115,7 @@ def init_pooling_state(
request_logger=request_logger,
score_template=resolved_chat_template,
log_error_stack=args.log_error_stack,
+ use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False),
)
if any(t in supported_tasks for t in ("embed", "score", "token_embed"))
else None
diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py
new file mode 100644
index 0000000..254c3d6
--- /dev/null
+++ b/vllm/entrypoints/pooling/base/io_processor.py
@@ -0,0 +1,189 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Callable, Sequence
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Final
+
+from vllm import PoolingRequestOutput, PromptType
+from vllm.config import ModelConfig
+from vllm.entrypoints.chat_utils import (
+ ChatCompletionMessageParam,
+ ChatTemplateConfig,
+ ChatTemplateContentFormatOption,
+ ConversationMessage,
+)
+from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest
+from vllm.inputs import ProcessorInputs, SingletonPrompt
+from vllm.renderers import BaseRenderer, merge_kwargs
+from vllm.renderers.inputs import TokPrompt
+from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq
+from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers import ToolParser
+from vllm.utils.mistral import is_mistral_tokenizer
+
+
+class PoolingIOProcessor:
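+    """Shared request pre/post-processing for pooling endpoints.
+
+    Subclasses turn endpoint requests (online) or raw prompts (offline) into
+    engine prompts via the `pre_process_*` methods, and may override
+    `post_process` to transform the raw pooling outputs.
+    """
+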
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ renderer: BaseRenderer,
+ chat_template_config: ChatTemplateConfig,
+ ):
+ self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
+
+ self.model_config = model_config
+ self.renderer = renderer
+
+ self.chat_template = chat_template_config.chat_template
+ self.chat_template_content_format: Final = (
+ chat_template_config.chat_template_content_format
+ )
+ self.trust_request_chat_template = (
+ chat_template_config.trust_request_chat_template
+ )
+
+ def pre_process_online(self, *args, **kwargs):
+ raise NotImplementedError
+
+ async def pre_process_online_async(self, *args, **kwargs):
+ return self.pre_process_online(*args, **kwargs)
+
+ def pre_process_offline(self, *args, **kwargs):
+ raise NotImplementedError
+
+ async def pre_process_offline_async(self, *args, **kwargs):
+ return self.pre_process_offline(*args, **kwargs)
+
+ def post_process(
+ self, outputs: list[PoolingRequestOutput]
+ ) -> list[PoolingRequestOutput]:
+ return outputs
+
+ async def post_process_async(
+ self, outputs: list[PoolingRequestOutput]
+ ) -> list[PoolingRequestOutput]:
+ return self.post_process(outputs)
+
+ def create_pooling_params(self, request):
+ return request.to_pooling_params()
+
+ def _preprocess_completion_online(
+ self,
+ request: RendererRequest,
+ prompt_input: str | list[str] | list[int] | list[list[int]] | None,
+ prompt_embeds: bytes | list[bytes] | None,
+ ) -> list[TokPrompt]:
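+        """Render completion-style inputs (text, token ids, or embeds) into engine prompts."""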
+ renderer = self.renderer
+ model_config = self.model_config
+
+ prompts = list[SingletonPrompt | bytes]()
+ if prompt_embeds is not None: # embeds take higher priority
+ prompts.extend(prompt_to_seq(prompt_embeds))
+ if prompt_input is not None:
+ prompts.extend(prompt_to_seq(prompt_input))
+
+ parsed_prompts = [
+ (
+ prompt
+ if isinstance(prompt, bytes)
+ else parse_model_prompt(model_config, prompt)
+ )
+ for prompt in prompts
+ ]
+ tok_params = request.build_tok_params(model_config)
+
+ return renderer.render_cmpl(
+ parsed_prompts,
+ tok_params,
+ prompt_extras={
+ k: v
+ for k in ("mm_processor_kwargs", "cache_salt")
+ if (v := getattr(request, k, None)) is not None
+ },
+ )
+
+ def _preprocess_chat_online(
+ self,
+ request: RendererChatRequest,
+ messages: list[ChatCompletionMessageParam],
+ default_template: str | None,
+ default_template_content_format: ChatTemplateContentFormatOption,
+ default_template_kwargs: dict[str, Any] | None,
+ tool_dicts: list[dict[str, Any]] | None = None,
+ tool_parser: Callable[[TokenizerLike], ToolParser] | None = None,
+ ) -> tuple[list[ConversationMessage], list[TokPrompt]]:
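+        """Render a chat-style request into a conversation and a single engine prompt."""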
+ renderer = self.renderer
+
+ default_template_kwargs = merge_kwargs(
+ default_template_kwargs,
+ dict(
+ tools=tool_dicts,
+ tokenize=is_mistral_tokenizer(renderer.tokenizer),
+ ),
+ )
+
+ tok_params = request.build_tok_params(self.model_config)
+ chat_params = request.build_chat_params(
+ default_template, default_template_content_format
+ ).with_defaults(default_template_kwargs)
+
+ (conversation,), (engine_prompt,) = renderer.render_chat(
+ [messages],
+ chat_params,
+ tok_params,
+ prompt_extras={
+ k: v
+ for k in ("mm_processor_kwargs", "cache_salt")
+ if (v := getattr(request, k, None)) is not None
+ },
+ )
+
+ return conversation, [engine_prompt]
+
+ def _preprocess_completion_offline(
+ self,
+ prompts: PromptType | Sequence[PromptType],
+ tokenization_kwargs: dict[str, Any] | None = None,
+ ) -> Sequence[ProcessorInputs]:
+ renderer = self.renderer
+ model_config = self.model_config
+
+ prompts = prompt_to_seq(prompts)
+
+ parsed_prompts = [
+ (
+ prompt
+ if isinstance(prompt, bytes)
+ else parse_model_prompt(model_config, prompt)
+ )
+ for prompt in prompts
+ ]
+ tok_params = renderer.default_cmpl_tok_params.with_kwargs(
+ **(tokenization_kwargs or {})
+ )
+
+ return renderer.render_cmpl(
+ parsed_prompts,
+ tok_params,
+ )
+
+ def _validate_chat_template(
+ self,
+ request_chat_template: str | None,
+ chat_template_kwargs: dict[str, Any] | None,
+ trust_request_chat_template: bool,
+ ):
+ if not trust_request_chat_template and (
+ request_chat_template is not None
+ or (
+ chat_template_kwargs
+ and chat_template_kwargs.get("chat_template") is not None
+ )
+ ):
+ raise ValueError(
+                "A chat template was passed with the request, but "
+                "--trust-request-chat-template is not set. "
+                "Refusing the request with an untrusted chat template."
+ )
+ return None
diff --git a/vllm/entrypoints/pooling/base/protocol.py b/vllm/entrypoints/pooling/base/protocol.py
index 86dc12c..5394510 100644
--- a/vllm/entrypoints/pooling/base/protocol.py
+++ b/vllm/entrypoints/pooling/base/protocol.py
@@ -190,10 +190,6 @@ class EmbedRequestMixin(EncodingRequestMixin):
description="Whether to use activation for the pooler outputs. "
"`None` uses the pooler's default, which is `True` in most cases.",
)
- normalize: bool | None = Field(
- default=None,
- description="Deprecated; please pass `use_activation` instead",
- )
# --8<-- [end:embed-extra-params]
diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py
new file mode 100644
index 0000000..813282d
--- /dev/null
+++ b/vllm/entrypoints/pooling/base/serving.py
@@ -0,0 +1,378 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import time
+from collections.abc import AsyncGenerator, Mapping
+from dataclasses import dataclass, field
+from http import HTTPStatus
+from typing import ClassVar, Generic, TypeVar
+
+from fastapi import Request
+from pydantic import ConfigDict
+from starlette.datastructures import Headers
+from starlette.responses import JSONResponse
+
+from vllm import (
+ PoolingParams,
+ PoolingRequestOutput,
+ PromptType,
+ SamplingParams,
+ envs,
+)
+from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
+from vllm.entrypoints.chat_utils import (
+ ChatTemplateConfig,
+ ChatTemplateContentFormatOption,
+)
+from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.engine.protocol import ErrorResponse
+from vllm.entrypoints.openai.models.serving import OpenAIServingModels
+from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse
+from vllm.inputs import ProcessorInputs
+from vllm.lora.request import LoRARequest
+from vllm.renderers import BaseRenderer
+from vllm.renderers.inputs.preprocess import extract_prompt_components
+from vllm.sampling_params import BeamSearchParams
+from vllm.tracing import (
+ contains_trace_headers,
+ extract_trace_headers,
+ log_tracing_disabled_warning,
+)
+from vllm.utils import random_uuid
+from vllm.utils.async_utils import merge_async_iterators
+
+from ...utils import create_error_response
+from .io_processor import PoolingIOProcessor
+
+PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest)
+
+
+@dataclass(kw_only=True)
+class PoolingServeContext(Generic[PoolingRequestT]):
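+    """Per-request state threaded through the pooling serving pipeline."""
+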
+ request: PoolingRequestT
+ raw_request: Request | None = None
+ model_name: str
+ request_id: str
+ created_time: int = field(default_factory=lambda: int(time.time()))
+ lora_request: LoRARequest | None = None
+ engine_prompts: list[ProcessorInputs] | None = None
+
+ result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = (
+ None
+ )
+ final_res_batch: list[PoolingRequestOutput] = field(default_factory=list)
+
+ model_config = ConfigDict(arbitrary_types_allowed=True)
+
+
+class PoolingServing:
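+    """Base serving class for pooling endpoints.
+
+    Subclasses provide an IO processor via `init_io_processor` and build the
+    endpoint-specific response in `_build_response`.
+    """
+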
+ request_id_prefix: ClassVar[str]
+
+ def __init__(
+ self,
+ engine_client: EngineClient,
+ models: OpenAIServingModels,
+ *,
+ request_logger: RequestLogger | None,
+ chat_template: str | None = None,
+ chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+ trust_request_chat_template: bool = False,
+ return_tokens_as_token_ids: bool = False,
+ log_error_stack: bool = False,
+ ):
+ super().__init__()
+ self.engine_client = engine_client
+ self.models = models
+ self.model_config = models.model_config
+ self.max_model_len = self.model_config.max_model_len
+ self.request_logger = request_logger
+ self.return_tokens_as_token_ids = return_tokens_as_token_ids
+ self.log_error_stack = log_error_stack
+ self.chat_template_config = ChatTemplateConfig(
+ chat_template=chat_template,
+ chat_template_content_format=chat_template_content_format,
+ trust_request_chat_template=trust_request_chat_template,
+ )
+ self.io_processor = self.init_io_processor(
+ model_config=models.model_config,
+ renderer=models.renderer,
+ chat_template_config=self.chat_template_config,
+ )
+
+ def init_io_processor(
+ self,
+ model_config: ModelConfig,
+ renderer: BaseRenderer,
+ chat_template_config: ChatTemplateConfig,
+ ) -> PoolingIOProcessor:
+ raise NotImplementedError
+
+ async def __call__(
+ self,
+ request: AnyPoolingRequest,
+ raw_request: Request,
+ ) -> JSONResponse:
+ try:
+ model_name = self.models.model_name()
+ request_id = (
+ f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
+ )
+
+ await self._check_model(request)
+
+ ctx = PoolingServeContext(
+ request=request,
+ raw_request=raw_request,
+ model_name=model_name,
+ request_id=request_id,
+ )
+
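+            # Shared pipeline: validate the request, resolve LoRA adapters,
+            # render prompts, fan the prompts out to the engine, collect the
+            # batch, and build the endpoint-specific response.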
+ self._validate_request(ctx)
+ self._maybe_get_adapters(ctx)
+ await self._preprocess(ctx)
+ await self._prepare_generators(ctx)
+ await self._collect_batch(ctx)
+ response = await self._build_response(ctx)
+ return JSONResponse(content=response.model_dump())
+ except Exception as e:
+ error_response = create_error_response(e)
+ return JSONResponse(
+ content=error_response.model_dump(),
+ status_code=error_response.error.code,
+ )
+
+ async def _preprocess(
+ self,
+ ctx: PoolingServeContext,
+ ):
+ ctx.engine_prompts = await self.io_processor.pre_process_online_async(
+ ctx.request
+ )
+
+ async def _prepare_generators(
+ self,
+ ctx: PoolingServeContext,
+ ):
+ if ctx.engine_prompts is None:
+ raise ValueError("Engine prompts not available")
+
+ generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
+
+ trace_headers = (
+ None
+ if ctx.raw_request is None
+ else await self._get_trace_headers(ctx.raw_request.headers)
+ )
+
+ pooling_params = self.io_processor.create_pooling_params(ctx.request)
+
+ for i, engine_prompt in enumerate(ctx.engine_prompts):
+ request_id_item = f"{ctx.request_id}-{i}"
+
+ self._log_inputs(
+ request_id_item,
+ engine_prompt,
+ params=pooling_params,
+ lora_request=ctx.lora_request,
+ )
+
+ generator = self.engine_client.encode(
+ engine_prompt,
+ pooling_params,
+ request_id_item,
+ lora_request=ctx.lora_request,
+ trace_headers=trace_headers,
+ priority=getattr(ctx.request, "priority", 0),
+ )
+
+ generators.append(generator)
+
+ ctx.result_generator = merge_async_iterators(*generators)
+
+ async def _collect_batch(
+ self,
+ ctx: PoolingServeContext,
+ ):
+ if ctx.engine_prompts is None:
+ raise ValueError("Engine prompts not available")
+
+ if ctx.result_generator is None:
+ raise ValueError("Result generator not available")
+
+ num_prompts = len(ctx.engine_prompts)
+ final_res_batch: list[PoolingRequestOutput | None]
+ final_res_batch = [None] * num_prompts
+
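+        # Results can arrive in any order; slot each one back in by prompt index.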
+ async for i, res in ctx.result_generator:
+ final_res_batch[i] = res
+
+ if None in final_res_batch:
+ raise ValueError("Failed to generate results for all prompts")
+
+ ctx.final_res_batch = [res for res in final_res_batch if res is not None]
+
+ async def _build_response(
+ self,
+ ctx: PoolingServeContext,
+ ) -> AnyPoolingResponse:
+ raise NotImplementedError
+
+ @staticmethod
+ def _base_request_id(
+ raw_request: Request | None, default: str | None = None
+ ) -> str | None:
+ """Pulls the request id to use from a header, if provided"""
+ if raw_request is not None and (
+ (req_id := raw_request.headers.get("X-Request-Id")) is not None
+ ):
+ return req_id
+
+ return random_uuid() if default is None else default
+
+ def _is_model_supported(self, model_name: str | None) -> bool:
+ if not model_name:
+ return True
+ return self.models.is_base_model(model_name)
+
+ async def _check_model(
+ self,
+ request: AnyPoolingRequest,
+ ) -> ErrorResponse | None:
+ if self._is_model_supported(request.model):
+ return None
+ if request.model in self.models.lora_requests:
+ return None
+ if (
+ envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING
+ and request.model
+ and (load_result := await self.models.resolve_lora(request.model))
+ ):
+ if isinstance(load_result, LoRARequest):
+ return None
+ if (
+ isinstance(load_result, ErrorResponse)
+ and load_result.error.code == HTTPStatus.BAD_REQUEST.value
+ ):
+ raise ValueError(load_result.error.message)
+ return None
+
+ def _validate_request(self, ctx: PoolingServeContext) -> None:
+ truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None)
+
+ if (
+ truncate_prompt_tokens is not None
+ and truncate_prompt_tokens > self.max_model_len
+ ):
+ raise ValueError(
+                "truncate_prompt_tokens value is greater than max_model_len. "
+                "Please select a smaller truncation size."
+ )
+ return None
+
+ async def _get_trace_headers(
+ self,
+ headers: Headers,
+ ) -> Mapping[str, str] | None:
+ is_tracing_enabled = await self.engine_client.is_tracing_enabled()
+
+ if is_tracing_enabled:
+ return extract_trace_headers(headers)
+
+ if contains_trace_headers(headers):
+ log_tracing_disabled_warning()
+
+ return None
+
+ def _maybe_get_adapters(
+ self,
+ ctx: PoolingServeContext,
+ supports_default_mm_loras: bool = False,
+ ):
+ request = ctx.request
+ if request.model in self.models.lora_requests:
+ ctx.lora_request = self.models.lora_requests[request.model]
+
+ # Currently only support default modality specific loras
+ # if we have exactly one lora matched on the request.
+ if supports_default_mm_loras:
+ default_mm_lora = self._get_active_default_mm_loras(request)
+ if default_mm_lora is not None:
+ ctx.lora_request = default_mm_lora
+
+ if self._is_model_supported(request.model):
+ return None
+
+ # if _check_model has been called earlier, this will be unreachable
+ raise ValueError(f"The model `{request.model}` does not exist.")
+
+ def _get_active_default_mm_loras(
+ self, request: AnyPoolingRequest
+ ) -> LoRARequest | None:
+ """Determine if there are any active default multimodal loras."""
+ # TODO: Currently this is only enabled for chat completions
+ # to be better aligned with only being enabled for .generate
+ # when run offline. It would be nice to support additional
+        # task types in the future.
+ message_types = self._get_message_types(request)
+ default_mm_loras = set()
+
+ for lora in self.models.lora_requests.values():
+ # Best effort match for default multimodal lora adapters;
+ # There is probably a better way to do this, but currently
+ # this matches against the set of 'types' in any content lists
+ # up until '_', e.g., to match audio_url -> audio
+ if lora.lora_name in message_types:
+ default_mm_loras.add(lora)
+
+ # Currently only support default modality specific loras if
+ # we have exactly one lora matched on the request.
+ if len(default_mm_loras) == 1:
+ return default_mm_loras.pop()
+ return None
+
+ def _get_message_types(self, request: AnyPoolingRequest) -> set[str]:
+ """Retrieve the set of types from message content dicts up
+ until `_`; we use this to match potential multimodal data
+ with default per modality loras.
+ """
+ message_types: set[str] = set()
+
+ if not hasattr(request, "messages"):
+ return message_types
+
+ messages = request.messages
+ if messages is None or isinstance(messages, (str, bytes)):
+ return message_types
+
+ for message in messages:
+ if (
+ isinstance(message, dict)
+ and "content" in message
+ and isinstance(message["content"], list)
+ ):
+ for content_dict in message["content"]:
+ if "type" in content_dict:
+ message_types.add(content_dict["type"].split("_")[0])
+ return message_types
+
+ def _log_inputs(
+ self,
+ request_id: str,
+ inputs: PromptType | ProcessorInputs,
+ params: SamplingParams | PoolingParams | BeamSearchParams | None,
+ lora_request: LoRARequest | None,
+ ) -> None:
+ if self.request_logger is None:
+ return
+
+ components = extract_prompt_components(self.model_config, inputs)
+
+ self.request_logger.log_inputs(
+ request_id,
+ components.text,
+ components.token_ids,
+ components.embeds,
+ params=params,
+ lora_request=lora_request,
+ )
diff --git a/vllm/entrypoints/pooling/classify/api_router.py b/vllm/entrypoints/pooling/classify/api_router.py
index 8a1513e..0e99a86 100644
--- a/vllm/entrypoints/pooling/classify/api_router.py
+++ b/vllm/entrypoints/pooling/classify/api_router.py
@@ -3,16 +3,17 @@
from fastapi import APIRouter, Depends, Request
from starlette.responses import JSONResponse
-from typing_extensions import assert_never
-from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.protocol import (
ClassificationRequest,
- ClassificationResponse,
)
from vllm.entrypoints.pooling.classify.serving import ServingClassification
-from vllm.entrypoints.utils import load_aware_call, with_cancellation
+from vllm.entrypoints.utils import (
+ create_error_response,
+ load_aware_call,
+ with_cancellation,
+)
router = APIRouter()
@@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None:
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
-async def create_classify(request: ClassificationRequest, raw_request: Request):
+async def create_classify(
+ request: ClassificationRequest, raw_request: Request
+) -> JSONResponse:
handler = classify(raw_request)
if handler is None:
- base_server = raw_request.app.state.openai_serving_tokenization
- return base_server.create_error_response(
+ error_response = create_error_response(
message="The model does not support Classification API"
)
-
- try:
- generator = await handler.create_classify(request, raw_request)
- except Exception as e:
- generator = handler.create_error_response(e)
-
- if isinstance(generator, ErrorResponse):
return JSONResponse(
- content=generator.model_dump(), status_code=generator.error.code
+ content=error_response.model_dump(),
+ status_code=error_response.error.code,
)
- elif isinstance(generator, ClassificationResponse):
- return JSONResponse(content=generator.model_dump())
-
- assert_never(generator)
+ return await handler(request, raw_request)
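For reference, a minimal client-side sketch of the resulting /classify endpoint, assuming a locally running server, the `requests` package, and a hypothetical model name:

    import requests

    resp = requests.post(
        "http://localhost:8000/classify",
        json={"model": "my-classifier", "input": "vLLM makes serving easy"},
    )
    resp.raise_for_status()
    print(resp.json()["data"][0])  # per-input classification entry (label, probs, ...)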
diff --git a/vllm/entrypoints/pooling/classify/io_processor.py b/vllm/entrypoints/pooling/classify/io_processor.py
new file mode 100644
index 0000000..90d5b0e
--- /dev/null
+++ b/vllm/entrypoints/pooling/classify/io_processor.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from collections.abc import Sequence
+from typing import Any
+
+from vllm import PromptType
+from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
+from vllm.entrypoints.pooling.classify.protocol import (
+ ClassificationChatRequest,
+ ClassificationCompletionRequest,
+)
+from vllm.inputs import ProcessorInputs
+from vllm.renderers.inputs import TokPrompt
+
+
+class ClassifyIOProcessor(PoolingIOProcessor):
+ def pre_process_online(
+ self, request: ClassificationCompletionRequest | ClassificationChatRequest
+ ) -> list[TokPrompt] | None:
+ if isinstance(request, ClassificationChatRequest):
+ self._validate_chat_template(
+ request_chat_template=request.chat_template,
+ chat_template_kwargs=request.chat_template_kwargs,
+ trust_request_chat_template=self.trust_request_chat_template,
+ )
+ _, engine_prompts = self._preprocess_chat_online(
+ request,
+ request.messages,
+ default_template=self.chat_template,
+ default_template_content_format=self.chat_template_content_format,
+ default_template_kwargs=None,
+ )
+ elif isinstance(request, ClassificationCompletionRequest):
+ engine_prompts = self._preprocess_completion_online(
+ request,
+ prompt_input=request.input,
+ prompt_embeds=None,
+ )
+ else:
+ raise ValueError("Invalid classification request type")
+ return engine_prompts
+
+ def pre_process_offline(
+ self,
+ prompts: PromptType | Sequence[PromptType],
+ tokenization_kwargs: dict[str, Any] | None = None,
+ ) -> Sequence[ProcessorInputs]:
+ return self._preprocess_completion_offline(
+ prompts=prompts, tokenization_kwargs=tokenization_kwargs
+ )
diff --git a/vllm/entrypoints/pooling/classify/protocol.py b/vllm/entrypoints/pooling/classify/protocol.py
index 3c4bbd8..bfc38eb 100644
--- a/vllm/entrypoints/pooling/classify/protocol.py
+++ b/vllm/entrypoints/pooling/classify/protocol.py
@@ -40,7 +40,6 @@ class ClassificationCompletionRequest(
def to_pooling_params(self):
return PoolingParams(
task="classify",
- truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
@@ -63,7 +62,6 @@ class ClassificationChatRequest(
def to_pooling_params(self):
return PoolingParams(
task="classify",
- truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py
index 8cdbbde..efd4be7 100644
--- a/vllm/entrypoints/pooling/classify/serving.py
+++ b/vllm/entrypoints/pooling/classify/serving.py
@@ -1,116 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Final, TypeAlias
+from typing import TypeAlias
-import jinja2
import numpy as np
-from fastapi import Request
-from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
-from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo
-from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext
-from vllm.entrypoints.openai.models.serving import OpenAIServingModels
-from vllm.entrypoints.pooling.classify.protocol import (
- ClassificationChatRequest,
- ClassificationCompletionRequest,
+from vllm import ClassificationOutput
+from vllm.config import ModelConfig
+from vllm.entrypoints.chat_utils import ChatTemplateConfig
+from vllm.entrypoints.openai.engine.protocol import UsageInfo
+from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing
+from vllm.logger import init_logger
+from vllm.renderers import BaseRenderer
+
+from .io_processor import ClassifyIOProcessor
+from .protocol import (
ClassificationData,
ClassificationRequest,
ClassificationResponse,
)
-from vllm.logger import init_logger
-from vllm.outputs import ClassificationOutput
logger = init_logger(__name__)
-ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest]
+ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest]
-class ServingClassification(OpenAIServing):
+class ServingClassification(PoolingServing):
request_id_prefix = "classify"
- def __init__(
+ def init_io_processor(
self,
- engine_client: EngineClient,
- models: OpenAIServingModels,
- *,
- request_logger: RequestLogger | None,
- chat_template: str | None = None,
- chat_template_content_format: ChatTemplateContentFormatOption = "auto",
- trust_request_chat_template: bool = False,
- log_error_stack: bool = False,
- ) -> None:
- super().__init__(
- engine_client=engine_client,
- models=models,
- request_logger=request_logger,
- log_error_stack=log_error_stack,
+ model_config: ModelConfig,
+ renderer: BaseRenderer,
+ chat_template_config: ChatTemplateConfig,
+ ) -> ClassifyIOProcessor:
+ return ClassifyIOProcessor(
+ model_config=model_config,
+ renderer=renderer,
+ chat_template_config=chat_template_config,
)
- self.chat_template = chat_template
- self.chat_template_content_format: Final = chat_template_content_format
- self.trust_request_chat_template = trust_request_chat_template
-
- async def _preprocess(
+ async def _build_response(
self,
ctx: ClassificationServeContext,
- ) -> ErrorResponse | None:
- """
- Process classification inputs: tokenize text, resolve adapters,
- and prepare model-specific inputs.
- """
- try:
- ctx.lora_request = self._maybe_get_adapters(ctx.request)
+ ) -> ClassificationResponse:
+ final_res_batch_checked = await self.io_processor.post_process_async(
+ ctx.final_res_batch
+ )
- if isinstance(ctx.request, ClassificationChatRequest):
- error_check_ret = self._validate_chat_template(
- request_chat_template=ctx.request.chat_template,
- chat_template_kwargs=ctx.request.chat_template_kwargs,
- trust_request_chat_template=self.trust_request_chat_template,
- )
- if error_check_ret:
- return error_check_ret
-
- _, ctx.engine_prompts = await self._preprocess_chat(
- ctx.request,
- ctx.request.messages,
- default_template=self.chat_template,
- default_template_content_format=self.chat_template_content_format,
- default_template_kwargs=None,
- )
- elif isinstance(ctx.request, ClassificationCompletionRequest):
- ctx.engine_prompts = await self._preprocess_completion(
- ctx.request,
- prompt_input=ctx.request.input,
- prompt_embeds=None,
- )
- else:
- return self.create_error_response("Invalid classification request type")
-
- return None
-
- except (ValueError, TypeError, jinja2.TemplateError) as e:
- logger.exception("Error in preprocessing prompt inputs")
- return self.create_error_response(str(e))
-
- def _build_response(
- self,
- ctx: ClassificationServeContext,
- ) -> ClassificationResponse | ErrorResponse:
- """
- Convert model outputs to a formatted classification response
- with probabilities and labels.
- """
id2label = getattr(self.model_config.hf_config, "id2label", {})
-
- items: list[ClassificationData] = []
num_prompt_tokens = 0
-
- final_res_batch_checked = ctx.final_res_batch
-
+ items: list[ClassificationData] = []
for idx, final_res in enumerate(final_res_batch_checked):
classify_res = ClassificationOutput.from_base(final_res.outputs)
@@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing):
data=items,
usage=usage,
)
-
- async def create_classify(
- self,
- request: ClassificationRequest,
- raw_request: Request,
- ) -> ClassificationResponse | ErrorResponse:
- model_name = self.models.model_name()
- request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}"
-
- ctx = ClassificationServeContext(
- request=request,
- raw_request=raw_request,
- model_name=model_name,
- request_id=request_id,
- )
-
- return await self.handle(ctx) # type: ignore[return-value]
diff --git a/vllm/entrypoints/pooling/embed/protocol.py b/vllm/entrypoints/pooling/embed/protocol.py
index 4f83105..4b47c65 100644
--- a/vllm/entrypoints/pooling/embed/protocol.py
+++ b/vllm/entrypoints/pooling/embed/protocol.py
@@ -14,12 +14,9 @@ from vllm.entrypoints.pooling.base.protocol import (
EmbedRequestMixin,
PoolingBasicRequestMixin,
)
-from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.utils import random_uuid
-logger = init_logger(__name__)
-
def _get_max_total_output_tokens(
model_config: ModelConfig,
@@ -60,18 +57,10 @@ class EmbeddingCompletionRequest(
)
def to_pooling_params(self):
- if self.normalize is not None:
- logger.warning_once(
- "`normalize` is deprecated and will be removed in v0.17. "
- "Please pass `use_activation` instead."
- )
- self.use_activation = self.normalize
-
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@@ -97,18 +86,10 @@ class EmbeddingChatRequest(
)
def to_pooling_params(self):
- if self.normalize is not None:
- logger.warning_once(
- "`normalize` is deprecated and will be removed in v0.17. "
- "Please pass `use_activation` instead."
- )
- self.use_activation = self.normalize
-
return PoolingParams(
task="embed",
dimensions=self.dimensions,
use_activation=self.use_activation,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
)
diff --git a/vllm/entrypoints/pooling/io_processor_factories.py b/vllm/entrypoints/pooling/io_processor_factories.py
new file mode 100644
index 0000000..9747676
--- /dev/null
+++ b/vllm/entrypoints/pooling/io_processor_factories.py
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+from vllm.config import ModelConfig
+from vllm.entrypoints.chat_utils import ChatTemplateConfig
+from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor
+from vllm.renderers import BaseRenderer
+from vllm.tasks import SupportedTask
+
+
+def init_pooling_io_processors(
+ supported_tasks: tuple[SupportedTask, ...],
+ model_config: ModelConfig,
+ renderer: BaseRenderer,
+ chat_template_config: ChatTemplateConfig,
+) -> dict[str, PoolingIOProcessor]:
+ pooling_io_processors: dict[str, PoolingIOProcessor] = {}
+
+ if "classify" in supported_tasks:
+ from vllm.entrypoints.pooling.classify.io_processor import (
+ ClassifyIOProcessor,
+ )
+
+ pooling_io_processors["classify"] = ClassifyIOProcessor(
+ model_config=model_config,
+ renderer=renderer,
+ chat_template_config=chat_template_config,
+ )
+
+ return pooling_io_processors
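A minimal wiring sketch for the factory, assuming `model_config`, `renderer`, and `chat_template_config` have already been constructed by the caller:

    io_processors = init_pooling_io_processors(
        supported_tasks=("classify",),
        model_config=model_config,
        renderer=renderer,
        chat_template_config=chat_template_config,
    )
    classify_io = io_processors.get("classify")  # present only when the task is supported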
diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py
index a8c1c59..b99f989 100644
--- a/vllm/entrypoints/pooling/pooling/protocol.py
+++ b/vllm/entrypoints/pooling/pooling/protocol.py
@@ -16,13 +16,10 @@ from vllm.entrypoints.pooling.base.protocol import (
EncodingRequestMixin,
PoolingBasicRequestMixin,
)
-from vllm.logger import init_logger
from vllm.renderers import TokenizeParams
from vllm.tasks import PoolingTask
from vllm.utils import random_uuid
-logger = init_logger(__name__)
-
class PoolingCompletionRequest(
PoolingBasicRequestMixin,
@@ -45,16 +42,8 @@ class PoolingCompletionRequest(
)
def to_pooling_params(self):
- if self.normalize is not None:
- logger.warning_once(
- "`normalize` is deprecated and will be removed in v0.17. "
- "Please pass `use_activation` instead."
- )
- self.use_activation = self.normalize
-
return PoolingParams(
task=self.task,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
)
@@ -78,16 +67,8 @@ class PoolingChatRequest(
)
def to_pooling_params(self):
- if self.normalize is not None:
- logger.warning_once(
- "`normalize` is deprecated and will be removed in v0.17. "
- "Please pass `use_activation` instead."
- )
- self.use_activation = self.normalize
-
return PoolingParams(
task=self.task,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
dimensions=self.dimensions,
)
diff --git a/vllm/entrypoints/pooling/score/protocol.py b/vllm/entrypoints/pooling/score/protocol.py
index a85ed5d..643eeed 100644
--- a/vllm/entrypoints/pooling/score/protocol.py
+++ b/vllm/entrypoints/pooling/score/protocol.py
@@ -37,7 +37,6 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin):
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
@@ -113,7 +112,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin):
def to_pooling_params(self, task: PoolingTask = "score"):
return PoolingParams(
task=task,
- truncate_prompt_tokens=self.truncate_prompt_tokens,
use_activation=self.use_activation,
)
diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py
index 3fe18ca..60d6db6 100644
--- a/vllm/entrypoints/pooling/score/serving.py
+++ b/vllm/entrypoints/pooling/score/serving.py
@@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import (
ScoreInputs,
_cosine_similarity,
compress_token_type_ids,
- compute_maxsim_score,
+ compute_maxsim_scores,
get_score_prompt,
parse_score_data_single,
validate_score_input,
@@ -56,6 +56,7 @@ class ServingScores(OpenAIServing):
request_logger: RequestLogger | None,
score_template: str | None = None,
log_error_stack: bool = False,
+ use_gpu_for_pooling_score: bool = False,
) -> None:
super().__init__(
engine_client=engine_client,
@@ -64,6 +65,7 @@ class ServingScores(OpenAIServing):
log_error_stack=log_error_stack,
)
self.score_template = score_template
+ self.use_gpu_for_pooling_score = use_gpu_for_pooling_score
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
@@ -311,19 +313,18 @@ class ServingScores(OpenAIServing):
# Compute MaxSim scores
from vllm.outputs import PoolingOutput
+ maxsim_scores = compute_maxsim_scores(
+ [emb.outputs.data for emb in emb_data_1],
+ [emb.outputs.data for emb in emb_data_2],
+ use_gpu_for_pooling_score=self.use_gpu_for_pooling_score,
+ )
+
scores: list[PoolingRequestOutput] = []
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]
- for emb_1, emb_2 in zip(emb_data_1, emb_data_2):
- # emb_1.outputs.data: [query_len, dim]
- # emb_2.outputs.data: [doc_len, dim]
- q_emb = emb_1.outputs.data
- d_emb = emb_2.outputs.data
-
- maxsim_score = compute_maxsim_score(q_emb, d_emb)
-
+ for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores):
tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
scores.append(
diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py
index 60e71ff..65611dc 100644
--- a/vllm/entrypoints/pooling/score/utils.py
+++ b/vllm/entrypoints/pooling/score/utils.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence
from typing import Any, TypeAlias, cast
import torch
@@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt
from vllm.model_executor.models.interfaces import supports_score_template
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict
from vllm.outputs import PoolingRequestOutput
+from vllm.platforms import current_platform
from vllm.renderers.hf import safe_apply_chat_template
from vllm.tokenizers import TokenizerLike
@@ -53,6 +54,91 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens
return token_scores.amax(dim=-1).sum()
+def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool:
+ return use_gpu_for_pooling_score and not current_platform.is_cpu()
+
+
+def compute_maxsim_scores(
+ q_embs: Sequence[torch.Tensor],
+ d_embs: Sequence[torch.Tensor],
+ max_batch_size: int = 16,
+ max_score_matrix_elements: int = 16_000_000,
+ use_gpu_for_pooling_score: bool = False,
+) -> list[torch.Tensor]:
+ """Compute ColBERT MaxSim scores in padded mini-batches."""
+ if len(q_embs) != len(d_embs):
+ raise ValueError("q_embs and d_embs must have the same length")
+
+ num_pairs = len(q_embs)
+ if num_pairs == 0:
+ return []
+
+ for q_emb, d_emb in zip(q_embs, d_embs):
+ if q_emb.ndim != 2 or d_emb.ndim != 2:
+ raise ValueError("Each embedding tensor must be 2-D")
+ if q_emb.shape[1] != d_emb.shape[1]:
+ raise ValueError("Query and document embeddings must have same dim")
+
+ compute_device = torch.device(
+ current_platform.device_type
+ if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score)
+ else "cpu"
+ )
+ scores: list[torch.Tensor] = []
+ start = 0
+ while start < num_pairs:
+ end = min(start + max_batch_size, num_pairs)
+ max_q = max(int(x.shape[0]) for x in q_embs[start:end])
+ max_d = max(int(x.shape[0]) for x in d_embs[start:end])
+
+ # keep score matrix bounded to avoid oversized allocations.
+ while (
+ end - start > 1
+ and (end - start) * max_q * max_d > max_score_matrix_elements
+ ):
+ end -= 1
+ max_q = max(int(x.shape[0]) for x in q_embs[start:end])
+ max_d = max(int(x.shape[0]) for x in d_embs[start:end])
+
+ batch_q = q_embs[start:end]
+ batch_d = d_embs[start:end]
+ batch_size = end - start
+ dim = int(batch_q[0].shape[1])
+ dtype = batch_q[0].dtype
+
+ q_batch = torch.zeros(
+ (batch_size, max_q, dim), dtype=dtype, device=compute_device
+ )
+ d_batch = torch.zeros(
+ (batch_size, max_d, dim), dtype=dtype, device=compute_device
+ )
+ q_mask = torch.zeros(
+ (batch_size, max_q), dtype=torch.bool, device=compute_device
+ )
+ d_mask = torch.zeros(
+ (batch_size, max_d), dtype=torch.bool, device=compute_device
+ )
+
+ # copy to padded tensors
+ for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)):
+ q_len = int(q_emb.shape[0])
+ d_len = int(d_emb.shape[0])
+ q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype)
+ d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype)
+ q_mask[i, :q_len] = True
+ d_mask[i, :d_len] = True
+
+ token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2))
+ token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf"))
+ max_per_query = token_scores.amax(dim=-1)
+ max_per_query.masked_fill_(~q_mask, 0)
+ batch_scores = max_per_query.sum(dim=-1).to("cpu")
+ scores.extend(batch_scores.unbind(0))
+ start = end
+
+ return [cast(torch.Tensor, score) for score in scores]
+
+
class ScoreMultiModalParam(TypedDict, total=False):
"""
A specialized parameter type for scoring multimodal content
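A minimal usage sketch of the batched MaxSim helper, with random stand-in embeddings and assuming `compute_maxsim_scores` is imported from `vllm.entrypoints.pooling.score.utils`:

    import torch

    q_embs = [torch.randn(5, 64), torch.randn(7, 64)]    # [query_len, dim] per pair
    d_embs = [torch.randn(20, 64), torch.randn(13, 64)]  # [doc_len, dim] per pair

    scores = compute_maxsim_scores(q_embs, d_embs)       # CPU scalar tensors, one per pair
    # Each score is sum_i max_j <q_i, d_j>, i.e. the ColBERT MaxSim for that pair.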
diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py
new file mode 100644
index 0000000..87d6487
--- /dev/null
+++ b/vllm/entrypoints/pooling/typing.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import TypeAlias
+
+from vllm.entrypoints.pooling.classify.protocol import (
+ ClassificationChatRequest,
+ ClassificationCompletionRequest,
+ ClassificationResponse,
+)
+from vllm.entrypoints.pooling.embed.protocol import (
+ EmbeddingBytesResponse,
+ EmbeddingChatRequest,
+ EmbeddingCompletionRequest,
+ EmbeddingResponse,
+)
+from vllm.entrypoints.pooling.pooling.protocol import (
+ IOProcessorRequest,
+ PoolingChatRequest,
+ PoolingCompletionRequest,
+ PoolingResponse,
+)
+from vllm.entrypoints.pooling.score.protocol import (
+ RerankRequest,
+ ScoreRequest,
+ ScoreResponse,
+)
+
+PoolingCompletionLikeRequest: TypeAlias = (
+ EmbeddingCompletionRequest
+ | ClassificationCompletionRequest
+ | RerankRequest
+ | ScoreRequest
+ | PoolingCompletionRequest
+)
+
+PoolingChatLikeRequest: TypeAlias = (
+ EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest
+)
+
+AnyPoolingRequest: TypeAlias = (
+ PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest
+)
+
+AnyPoolingResponse: TypeAlias = (
+ ClassificationResponse
+ | EmbeddingResponse
+ | EmbeddingBytesResponse
+ | PoolingResponse
+ | ScoreResponse
+)
diff --git a/vllm/entrypoints/sagemaker/api_router.py b/vllm/entrypoints/sagemaker/api_router.py
index 1138225..32faaa0 100644
--- a/vllm/entrypoints/sagemaker/api_router.py
+++ b/vllm/entrypoints/sagemaker/api_router.py
@@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.utils import validate_json_request
+from vllm.entrypoints.pooling.base.serving import PoolingServing
from vllm.entrypoints.serve.instrumentator.basic import base
from vllm.entrypoints.serve.instrumentator.health import health
from vllm.tasks import POOLING_TASKS, SupportedTask
@@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
-GetHandlerFn = Callable[[Request], OpenAIServing | None]
+GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None]
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]
diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py
index 34df85f..6390a72 100644
--- a/vllm/entrypoints/utils.py
+++ b/vllm/entrypoints/utils.py
@@ -5,7 +5,10 @@ import asyncio
import dataclasses
import functools
import os
+import sys
+import traceback
from argparse import Namespace
+from http import HTTPStatus
from logging import Logger
from string import Template
from typing import TYPE_CHECKING
@@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks
from vllm import envs
from vllm.engine.arg_utils import EngineArgs
+from vllm.exceptions import VLLMValidationError
from vllm.logger import current_formatter_type, init_logger
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
if TYPE_CHECKING:
- from vllm.entrypoints.openai.engine.protocol import StreamOptions
+ from vllm.entrypoints.openai.engine.protocol import (
+ ErrorInfo,
+ ErrorResponse,
+ StreamOptions,
+ )
from vllm.entrypoints.openai.models.protocol import LoRAModulePath
else:
- StreamOptions = object
+ ErrorResponse = object
+ ErrorInfo = object
LoRAModulePath = object
-
+ StreamOptions = object
logger = init_logger(__name__)
@@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None:
message = logo_template.substitute(colors)
lgr.info(message, version, model_name)
+
+
+def create_error_response(
+ message: str | Exception,
+ err_type: str = "BadRequestError",
+ status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+ param: str | None = None,
+ log_error_stack: bool = False,
+) -> "ErrorResponse":
+ exc: Exception | None = None
+
+ from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse
+
+ if isinstance(message, Exception):
+ exc = message
+
+ if isinstance(exc, VLLMValidationError):
+ err_type = "BadRequestError"
+ status_code = HTTPStatus.BAD_REQUEST
+ param = exc.parameter
+ elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)):
+ # Common validation errors from user input
+ err_type = "BadRequestError"
+ status_code = HTTPStatus.BAD_REQUEST
+ param = None
+ elif isinstance(exc, NotImplementedError):
+ err_type = "NotImplementedError"
+ status_code = HTTPStatus.NOT_IMPLEMENTED
+ param = None
+    elif any(c.__name__ == "TemplateError" for c in type(exc).__mro__):
+        # jinja2.TemplateError and its subclasses (avoid importing jinja2)
+ err_type = "BadRequestError"
+ status_code = HTTPStatus.BAD_REQUEST
+ param = None
+ else:
+ err_type = "InternalServerError"
+ status_code = HTTPStatus.INTERNAL_SERVER_ERROR
+ param = None
+
+ message = str(exc)
+
+ if log_error_stack:
+ exc_type, _, _ = sys.exc_info()
+ if exc_type is not None:
+ traceback.print_exc()
+ else:
+ traceback.print_stack()
+
+ return ErrorResponse(
+ error=ErrorInfo(
+ message=sanitize_message(message),
+ type=err_type,
+ code=status_code.value,
+ param=param,
+ )
+ )
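A quick sketch of the exception-to-status mapping above, assuming `create_error_response` is imported from `vllm.entrypoints.utils`:

    from http import HTTPStatus

    resp = create_error_response(ValueError("prompt too long"))
    assert resp.error.type == "BadRequestError"
    assert resp.error.code == HTTPStatus.BAD_REQUEST.value

    resp = create_error_response(NotImplementedError("beam search"))
    assert resp.error.code == HTTPStatus.NOT_IMPLEMENTED.value

    resp = create_error_response("The model does not support Classification API")
    assert resp.error.code == HTTPStatus.BAD_REQUEST.value  # plain strings keep the defaults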
diff --git a/vllm/env_override.py b/vllm/env_override.py
index 181d000..2799221 100644
--- a/vllm/env_override.py
+++ b/vllm/env_override.py
@@ -482,3 +482,44 @@ if is_torch_equal("2.9.0"):
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
GraphLowering._update_scheduler = _update_scheduler_patched
+
+# ===================================================
+# torch 2.11 Inductor constrain_to_fx_strides monkeypatch
+# ===================================================
+# Patch the inductor's `constrain_to_fx_strides` to handle opaque
+# (non-tensor) arguments. The original calls `.stride()` on every FX
+# arg's meta value, which crashes on FakeScriptObject (the compile-time
+# proxy for hoisted opaque types). The patched version skips args
+# whose meta value is not a torch.Tensor.
+# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
+
+from vllm.utils.torch_utils import is_torch_equal_or_newer
+
+if is_torch_equal_or_newer("2.11.0.dev"):
+ import torch._inductor.ir as _ir
+ import torch._inductor.lowering as _lowering
+ from torch._inductor.virtualized import V as _V
+
+ _orig_constrain = _lowering.constrain_to_fx_strides
+
+ def _patched_constrain_to_fx_strides(fx_node, *args, **kwargs):
+ def apply_constraint(arg, fx_arg):
+ if isinstance(arg, _ir.IRNode):
+ meta_val = fx_arg.meta.get("val")
+ if isinstance(meta_val, torch.Tensor):
+ stride_order = _ir.get_stride_order(
+ meta_val.stride(), _V.graph.sizevars.shape_env
+ )
+ return _ir.ExternKernel.require_stride_order(arg, stride_order)
+ return arg
+ if isinstance(arg, dict):
+ return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
+ return arg
+
+ args = tuple(
+ apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
+ )
+ kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
+ return args, kwargs
+
+ _lowering.constrain_to_fx_strides = _patched_constrain_to_fx_strides
diff --git a/vllm/envs.py b/vllm/envs.py
index be3a568..9f1f5d8 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -26,6 +26,9 @@ if TYPE_CHECKING:
VLLM_ENGINE_READY_TIMEOUT_S: int = 600
VLLM_API_KEY: str | None = None
VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False
+ VLLM_ENABLE_PP_MIX_ILU_SCHEDULING: bool = False
+ VLLM_ENABLE_PP_ILU_OPT: bool = False
+ VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE: int = 0
S3_ACCESS_KEY_ID: str | None = None
S3_SECRET_ACCESS_KEY: str | None = None
S3_ENDPOINT_URL: str | None = None
@@ -35,7 +38,7 @@ if TYPE_CHECKING:
VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
VLLM_NO_USAGE_STATS: bool = False
VLLM_DO_NOT_TRACK: bool = False
- VLLM_USAGE_SOURCE: str = ""
+ VLLM_USAGE_SOURCE: str = "production"
VLLM_CONFIGURE_LOGGING: bool = True
VLLM_LOGGING_LEVEL: str = "INFO"
VLLM_LOGGING_PREFIX: str = ""
@@ -48,7 +51,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_SAMPLER: bool | None = None
VLLM_PP_LAYER_PARTITION: str | None = None
VLLM_CPU_KVCACHE_SPACE: int | None = 0
- VLLM_CPU_OMP_THREADS_BIND: str = ""
+ VLLM_CPU_OMP_THREADS_BIND: str = "auto"
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
VLLM_CPU_SGL_KERNEL: bool = False
VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache")
@@ -89,13 +92,14 @@ if TYPE_CHECKING:
VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None
VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None
VLLM_USE_AOT_COMPILE: bool = False
- VLLM_USE_BYTECODE_HOOK: bool = False
+ VLLM_USE_BYTECODE_HOOK: bool = True
VLLM_FORCE_AOT_LOAD: bool = False
VLLM_USE_MEGA_AOT_ARTIFACT: bool = False
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
+ VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True
VLLM_DISABLE_PYNCCL: bool = False
VLLM_USE_OINK_OPS: bool = False
VLLM_ROCM_USE_AITER: bool = False
@@ -106,7 +110,7 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
- VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True
+ VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_AITER_FP4BMM: bool = True
VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False
@@ -168,7 +172,7 @@ if TYPE_CHECKING:
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
"latency"
)
- VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
+ VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm"
VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -231,7 +235,7 @@ if TYPE_CHECKING:
VLLM_USE_FBGEMM: bool = False
VLLM_GC_DEBUG: str = ""
VLLM_DEBUG_WORKSPACE: bool = False
- VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False
+ VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = True
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False
@@ -243,17 +247,33 @@ if TYPE_CHECKING:
VLLM_LORA_DISABLE_PDL: bool = False
VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
+ VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
+ VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
+ VLLM_FLAT_LOGPROBS: bool = False
# optional envs we add.
VLLM_W8A8_MOE_USE_W4A8: bool = False
+ VLLM_WNA16_MOE_USE_W4A8: bool = False
+ VLLM_W8A8_FORMAT: str = "TN"
VLLM_W4A8_FORMAT: str = "TN"
VLLM_W4A8_VERSION: int = 2
VLLM_MIX_QUANTIZATION_TYPE: str = ""
VLLM_MLA_CUSTOMIZE: bool = True
VLLM_USE_INT8_MLA: bool = False
- VLLM_W8A8_LINEAR_USE_W4A8: bool = False
- VLLM_FORCE_NCCL_COMM: bool =False
- VLLM_KV_DISABLE_CROSS_GROUP_SHARE: bool = False
+    # Support for Iluvatar IxServer
+ VLLM_ATTN_OPT_LEVEL: int = 0
+ VLLM_MOE_OPT_LEVEL: int = 0
+ VLLM_LINEAR_OPT_LEVEL: int = 0
+ VLLM_OPT_EXCLUDE_LAYERS: str = ""
+ VLLM_LINEAR_SPECIFIED_LAYERS: str = ""
+ VLLM_LINEAR_SPECIFIED_KEYS: str = ""
+ VLLM_LINEAR_SPECIFIED_OPT_LEVEL: int = 0
+
+ VLLM_USE_LORA_FUSION: bool = False
+
+ VLLM_USE_SILU_QUANT_FUSION: bool = False
+ # static quant for attention
+ VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH: str = ""
def get_default_cache_root():
return os.getenv(
@@ -478,6 +498,8 @@ def get_env_or_set_default(
logger = logging.getLogger(__name__)
+IGNORED_UNKNOWN_VARS = {"VLLM_ENFORCE_CUDA_GRAPH"}
+
environment_variables: dict[str, Callable[[], Any]] = {
# ================== Installation Time Env Vars ==================
# Target device of vLLM, supporting [cuda (by default),
@@ -648,6 +670,22 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False"
).lower()
== "true",
+ # When set to 1, scheduler.schedule() delegates to schedule_opt()
+ # for PP mix ILU scheduling.
+ "VLLM_ENABLE_PP_MIX_ILU_SCHEDULING": lambda: os.environ.get(
+ "VLLM_ENABLE_PP_MIX_ILU_SCHEDULING", "0"
+ ) == "1",
+ # When set to 1, use step_with_batch_queue_ilu_opt (async batch queue with
+ # background thread). Batch queue size can be controlled via
+ # VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE.
+ "VLLM_ENABLE_PP_ILU_OPT": lambda: os.environ.get(
+ "VLLM_ENABLE_PP_ILU_OPT", "0"
+ ) == "1",
+ # Batch queue size used when VLLM_ENABLE_PP_ILU_OPT=1.
+ # If <= 0, EngineCore falls back to base_batch_queue_size * 2.
+ "VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE": lambda: int(
+ os.environ.get("VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE", "0")
+ ),
# S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
"S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
@@ -862,7 +900,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Time in ms for the zmq client to wait for a response from the backend
# server for simple data operations
- "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")),
+ "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "1000000")),
# Timeout in seconds for keeping HTTP connections alive in API server
"VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int(
os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5")
@@ -907,6 +945,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLED_KERNELS": lambda: []
if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
+ "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
+ int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1"))
+ ),
# Disable pynccl (using torch.distributed instead)
"VLLM_DISABLE_PYNCCL": lambda: (
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")
@@ -957,9 +998,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1")
),
# Whether to use aiter rope.
- # By default is enabled.
+ # By default is disabled.
"VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: (
- os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1")
+ os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1")
),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
@@ -1305,9 +1346,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Flashinfer fused allreduce backend.
# "auto" will default to "mnnvl", which performs mostly same/better than "trtllm".
# But "mnnvl" backend does not support fuse with quantization.
+ # TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph:
+ # https://github.com/vllm-project/vllm/issues/35772
+ # Should switch back to "auto" if the issue is resolved.
"VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_ALLREDUCE_BACKEND",
- "auto",
+ "trtllm",
["auto", "trtllm", "mnnvl"],
),
# Control the workspace buffer size for the FlashInfer backend.
@@ -1478,41 +1522,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None),
- # vLLM do not support W4A8 and W8A8, we add it. For MOE, we default use W8A8, If set to true, we use W4A8.
-
- "VLLM_W8A8_LINEAR_USE_W4A8":
- lambda: os.environ.get("VLLM_W8A8_LINEAR_USE_W4A8", "0").lower() in
- ("1", "true"),
-
- "VLLM_W8A8_MOE_USE_W4A8":
- lambda: os.environ.get("VLLM_W8A8_MOE_USE_W4A8", "0").lower() in
- ("1", "true"),
-
- "VLLM_FORCE_NCCL_COMM":
- lambda: os.environ.get("VLLM_FORCE_NCCL_COMM", "0").lower() in
- ("1", "true"),
- # If set to true, we use int8 mla attention for decode stage.
-
- "VLLM_USE_INT8_MLA":
- lambda: os.environ.get("VLLM_USE_INT8_MLA", "0").lower() in
- ("1", "true"),
-
- # For W4A8 MOE, we default use TN gemm format, choices: [TN, NN].
- "VLLM_W4A8_FORMAT":
- lambda: os.environ.get("VLLM_W4A8_FORMAT", "TN").upper(),
-
- "VLLM_W4A8_VERSION":
- # For W4A8 MOE, we default use version 2, choices: [1, 2].
- lambda: int(os.environ.get("VLLM_W4A8_VERSION", "2")),
-
- # temp param to support compressed-tensor's multi-quantization
- "VLLM_MIX_QUANTIZATION_TYPE":
- lambda: os.environ.get("VLLM_MIX_QUANTIZATION_TYPE", "").upper(),
-
- # Use Customize mlp impl for faster speed and less gpu memory usage.
- "VLLM_MLA_CUSTOMIZE":
- lambda: os.environ.get("VLLM_MLA_CUSTOMIZE", "1").lower() in
- ("1", "true"),
# Valid values are container,code_interpreter,web_search_preview
# ex VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter
# If the server_label of your mcp tool is not in this list it will
@@ -1607,7 +1616,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))),
# Disables parallel execution of shared_experts via separate cuda stream
"VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool(
- int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0"))
+ int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
# Limits when we run shared_experts in a separate stream.
# We found out that for large batch sizes, the separate stream
@@ -1662,10 +1671,93 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get(
"VLLM_CUDA_COMPATIBILITY_PATH", None
),
- # control kv cache share cross group
- "VLLM_KV_DISABLE_CROSS_GROUP_SHARE":
- lambda: os.environ.get("VLLM_KV_DISABLE_CROSS_GROUP_SHARE", "0").lower() in
- ("1", "true")
+    # Whether this is a scale-up launch of the engine for elastic EP.
+ # Should only be set by EngineCoreClient.
+ "VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool(
+ int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0"))
+ ),
+ # Whether to wait for all requests to drain before sending the
+ # scaling command in elastic EP.
+ "VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
+ int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
+ ),
+    # Flag to enable FlatLogprobs, whose GC overhead is significantly smaller
+    # than that of the original list[dict[int, Logprob]] approach.
+    # When enabled, PromptLogprobs and SampleLogprobs are populated as
+    # FlatLogprobs.
+ "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))),
+
+    # vLLM does not support W4A8 and W8A8, so we add them. For MoE, W8A8 is used
+    # by default; if this is set to true, W4A8 is used instead.
+
+ "VLLM_W8A8_MOE_USE_W4A8":
+ lambda: os.environ.get("VLLM_W8A8_MOE_USE_W4A8", "0").lower() in
+ ("1", "true"),
+
+ "VLLM_WNA16_MOE_USE_W4A8":
+ lambda: os.environ.get("VLLM_WNA16_MOE_USE_W4A8", "0").lower() in
+ ("1", "true"),
+
+    # If set to true, use int8 MLA attention for the decode stage.
+
+ "VLLM_USE_INT8_MLA":
+ lambda: os.environ.get("VLLM_USE_INT8_MLA", "0").lower() in
+ ("1", "true"),
+
+    # Attention opt level: 0 for f16qkv, 1 for i8qkv, 2 for i8qkf16v.
+ "VLLM_ATTN_OPT_LEVEL":
+ lambda: int(os.environ.get("VLLM_ATTN_OPT_LEVEL", "0")),
+
+    # For W8A8 MoE, the default GEMM format is TN; choices: [TN, NN].
+    # However, GEMV (CUDA Graph) only supports NN.
+ "VLLM_W8A8_FORMAT":
+ lambda: os.environ.get("VLLM_W8A8_FORMAT", "TN").upper(),
+
+    # For W4A8 MoE, the default GEMM format is TN; choices: [TN, NN].
+ "VLLM_W4A8_FORMAT":
+ lambda: os.environ.get("VLLM_W4A8_FORMAT", "TN").upper(),
+
+ "VLLM_W4A8_VERSION":
+ # For W4A8 MOE, we default use version 2, choices: [1, 2].
+ lambda: int(os.environ.get("VLLM_W4A8_VERSION", "2")),
+
+    # Temporary param to support compressed-tensors multi-quantization.
+ "VLLM_MIX_QUANTIZATION_TYPE":
+ lambda: os.environ.get("VLLM_MIX_QUANTIZATION_TYPE", "").upper(),
+
+    # Use a customized MLP impl for faster speed and lower GPU memory usage.
+ "VLLM_MLA_CUSTOMIZE":
+ lambda: os.environ.get("VLLM_MLA_CUSTOMIZE", "1").lower() in
+ ("1", "true"),
+
+    # Support for Iluvatar IxServer, a distributed inference framework.
+ "VLLM_MOE_OPT_LEVEL":
+ lambda: int(os.getenv("VLLM_MOE_OPT_LEVEL", "0")),
+
+ "VLLM_LINEAR_OPT_LEVEL":
+ lambda: int(os.getenv("VLLM_LINEAR_OPT_LEVEL", "0")),
+
+ "VLLM_OPT_EXCLUDE_LAYERS":
+ lambda: os.environ.get("VLLM_OPT_EXCLUDE_LAYERS", "").upper(),
+
+ "VLLM_LINEAR_SPECIFIED_LAYERS":
+ lambda: os.environ.get("VLLM_LINEAR_SPECIFIED_LAYERS", "").upper(),
+
+ "VLLM_LINEAR_SPECIFIED_KEYS":
+ lambda: os.environ.get("VLLM_LINEAR_SPECIFIED_KEYS", "").lower(),
+
+ "VLLM_LINEAR_SPECIFIED_OPT_LEVEL":
+ lambda: int(os.getenv("VLLM_LINEAR_SPECIFIED_OPT_LEVEL", "0")),
+
+ "VLLM_USE_LORA_FUSION":
+ lambda: os.environ.get("VLLM_USE_LORA_FUSION", "0").lower() in
+ ("1", "true"),
+
+ "VLLM_USE_SILU_QUANT_FUSION":
+ lambda: os.environ.get("VLLM_USE_SILU_QUANT_FUSION", "0").lower() in
+ ("1", "true"),
+
+ "VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH":
+ lambda: os.environ.get("VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH", ""),
}
@@ -1738,6 +1830,8 @@ def is_set(name: str):
def validate_environ(hard_fail: bool) -> None:
for env in os.environ:
if env.startswith("VLLM_") and env not in environment_variables:
+ if env in IGNORED_UNKNOWN_VARS:
+ continue
if hard_fail:
raise ValueError(f"Unknown vLLM environment variable detected: {env}")
else:
@@ -1801,10 +1895,7 @@ def compile_factors() -> dict[str, object]:
"VLLM_ENABLE_V1_MULTIPROCESSING",
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
"VLLM_CPU_KVCACHE_SPACE",
- "VLLM_CPU_OMP_THREADS_BIND",
- "VLLM_CPU_NUM_OF_RESERVED_CPU",
"VLLM_CPU_MOE_PREPACK",
- "VLLM_CPU_SGL_KERNEL",
"VLLM_TEST_FORCE_LOAD_FORMAT",
"VLLM_ENABLE_CUDA_COMPATIBILITY",
"VLLM_CUDA_COMPATIBILITY_PATH",
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index a0753b1..15e3263 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -241,7 +241,7 @@ class ForwardContext:
additional_kwargs: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
- assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
+ assert self.cudagraph_runtime_mode.is_valid_runtime_mode(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
)
@@ -347,7 +347,6 @@ def set_forward_context(
num_tokens_unpadded=num_tokens,
parallel_config=vllm_config.parallel_config,
allow_microbatching=False,
- allow_dp_padding=False,
)
assert num_tokens_across_dp is not None
dp_metadata = DPMetadata.make(
diff --git a/vllm/kernels/helion/register.py b/vllm/kernels/helion/register.py
index 3114631..cd0ef83 100644
--- a/vllm/kernels/helion/register.py
+++ b/vllm/kernels/helion/register.py
@@ -31,8 +31,8 @@ by key matches the config returned by the autotuner.
Key Classes
-----------
-- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops
-- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op
+- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
+- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
- PresetConfigSearch: Custom autotuner that returns pre-tuned configs
"""
@@ -53,10 +53,27 @@ if not has_helion():
)
import helion
+from helion._compat import requires_torch_version
from helion.autotuner.base_search import BaseAutotuner
from helion.runtime.config import Config
from helion.runtime.settings import default_autotuner_fn
+# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op,
+# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11.
+_HOP_AVAILABLE = requires_torch_version("2.11")
+
+if _HOP_AVAILABLE:
+ import torch.utils._pytree as pytree
+ from helion._compiler._dynamo.higher_order_ops import (
+ helion_kernel_side_table,
+ helion_kernel_wrapper_mutation,
+ )
+ from helion._compiler._dynamo.variables import infer_output_spec
+ from torch.fx.experimental.proxy_tensor import (
+ disable_proxy_modes_tracing,
+ get_proxy_mode,
+ )
+
logger = init_logger(__name__)
vllm_helion_lib = Library("vllm_helion", "FRAGMENT") # noqa
@@ -233,7 +250,7 @@ class ConfiguredHelionKernel:
class HelionKernelWrapper:
- """Wrapper for Helion kernels that creates config-specific PyTorch custom ops."""
+ """Wrapper for Helion kernels with pre-tuned config selection and HOP support."""
def __init__(
self,
@@ -252,11 +269,86 @@ class HelionKernelWrapper:
self._config_picker: (
Callable[[tuple[Any, ...], list[str]], str | None] | None
) = None
+ self._configured_kernel: ConfiguredHelionKernel | None = None
self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None
def __call__(self, *args, **kwargs):
- configured_op = self.get_configured_op()
- return configured_op(*args, **kwargs)
+ # CustomOp fallback: register as torch custom op for torch.compile
+ # compatibility on older PyTorch lacking HOP/EffectType support
+ if not _HOP_AVAILABLE:
+ custom_op = self._get_or_register_custom_op()
+ return custom_op(*args, **kwargs)
+ # HOP tracing: record HigherOrderOp in the FX graph
+ if get_proxy_mode() is not None:
+ return self._call_via_hop(args, kwargs)
+ # Eager: run the configured kernel directly
+ return self.get_configured_op()(*args, **kwargs)
+
+ def _call_via_hop(
+ self,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+ ) -> Any:
+ kernel = self.get_configured_op()._decorated_kernel
+ kernel_idx = helion_kernel_side_table.add_kernel(kernel)
+
+ constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
+
+ all_named = {**constant_args, **tensor_args}
+ full_args = tuple(
+ all_named.get(n, p.default)
+ for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined]
+ if n in all_named or p.default is not p.empty
+ )
+
+ with disable_proxy_modes_tracing():
+ output_spec = infer_output_spec(kernel, full_args)
+
+ hop_result = helion_kernel_wrapper_mutation(
+ kernel_idx=kernel_idx,
+ constant_args=constant_args,
+ tensor_args=tensor_args,
+ output_spec=output_spec,
+ )
+
+ tree_spec_str = output_spec.get("tree_spec_str")
+ if tree_spec_str is None:
+ return None
+ tree_spec = pytree.treespec_loads(tree_spec_str)
+
+ hop_iter = iter(hop_result)
+ reconstructed = []
+ for spec in output_spec["leaf_specs"]:
+ is_constant_scalar = spec["type"] == "scalar" and not isinstance(
+ spec.get("scalar_value"), torch.SymInt
+ )
+ if is_constant_scalar:
+ reconstructed.append(spec["scalar_value"])
+ else:
+ reconstructed.append(next(hop_iter))
+ return pytree.tree_unflatten(reconstructed, tree_spec)
+
+ @staticmethod
+ def _partition_args(
+ kernel: Any,
+ args: tuple[Any, ...],
+ kwargs: dict[str, Any],
+ ) -> tuple[dict[str, Any], dict[str, Any]]:
+ constant_args: dict[str, Any] = {}
+ tensor_args: dict[str, Any] = {}
+ params = list(kernel.signature.parameters.keys())
+ for i, val in enumerate(args):
+ name = params[i]
+ if isinstance(val, torch.Tensor):
+ tensor_args[name] = val
+ else:
+ constant_args[name] = val
+ for name, val in kwargs.items():
+ if isinstance(val, torch.Tensor):
+ tensor_args[name] = val
+ else:
+ constant_args[name] = val
+ return constant_args, tensor_args
def register_config_picker(
self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
@@ -309,29 +401,32 @@ class HelionKernelWrapper:
)
return autotune_kernel.autotune(inputs)
- def get_configured_op(self) -> Any:
+ def get_configured_op(self) -> ConfiguredHelionKernel:
assert self._config_picker is not None, (
f"No config picker registered for kernel '{self.op_name}'. "
f"Use @{self.op_name}.register_config_picker to register one."
)
+ if self._configured_kernel is None:
+ self._configured_kernel = ConfiguredHelionKernel(
+ op_name=self.op_name,
+ config_picker=self._config_picker,
+ raw_kernel_func=self.raw_kernel_func,
+ helion_settings=self.helion_settings,
+ )
+
+ return self._configured_kernel
+
+ def _get_or_register_custom_op(self) -> Any:
if hasattr(torch.ops.vllm_helion, self.op_name):
- logger.debug("Op vllm_helion::%s already registered", self.op_name)
return getattr(torch.ops.vllm_helion, self.op_name)
- configured_kernel = ConfiguredHelionKernel(
- op_name=self.op_name,
- config_picker=self._config_picker,
- raw_kernel_func=self.raw_kernel_func,
- helion_settings=self.helion_settings,
- )
+ configured_kernel = self.get_configured_op()
logger.info("Registering op: vllm_helion::%s", self.op_name)
direct_register_custom_op(
op_name=self.op_name,
- op_func=configured_kernel._decorated_kernel, # Register decorated kernel
- # TODO(gmagogsfm): Implement automatic mutation/aliasing detection
- # for Helion kernels.
+ op_func=configured_kernel._decorated_kernel,
mutates_args=None,
fake_impl=self._fake_impl,
target_lib=vllm_helion_lib,
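A small sketch of how the HOP path splits kernel arguments, assuming a hypothetical Helion kernel object `kernel` whose signature is `(x, scale, out)`:

    import torch

    x = torch.randn(4, 4)
    out = torch.empty_like(x)

    constant_args, tensor_args = HelionKernelWrapper._partition_args(
        kernel, (x, 2.0), {"out": out}
    )
    # tensor_args   == {"x": x, "out": out}
    # constant_args == {"scale": 2.0}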
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index e08dcc8..eff05b5 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -32,10 +32,10 @@ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
- FusedMoEModularKernel,
+ FusedMoEKernel,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
+ MoEPrepareAndFinalizeNoDPEPModular,
)
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
@@ -83,7 +83,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
hidden_size = layer.hidden_size
- intermediate_size = layer.intermediate_size_per_partition
+ intermediate_size = (
+ self.w2_lora_a_stacked[0].shape[-1]
+ if op_prefix == "w2"
+ else self.w13_lora_b_stacked[0].shape[-2]
+ )
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
max_loras=num_loras,
@@ -132,7 +136,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
# Use the existing modular kernel from the quant method
- m_fused_moe_fn = self.base_layer.quant_method.moe_mk
+ m_fused_moe_fn = self.base_layer.quant_method.moe_kernel
# Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream.
m_fused_moe_fn.shared_experts = None
@@ -140,8 +144,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
# Create a new modular kernel via select_gemm_impl.
# Don't pass shared_experts to the kernel so the runner can
# overlap them with routed experts via a separate CUDA stream.
- prepare_finalize = MoEPrepareAndFinalizeNoEP()
- m_fused_moe_fn = FusedMoEModularKernel(
+ prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular()
+ m_fused_moe_fn = FusedMoEKernel(
prepare_finalize,
self.base_layer.quant_method.select_gemm_impl(
prepare_finalize, self.base_layer
@@ -150,10 +154,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
if quant_config.use_mxfp4_w4a16:
assert isinstance(
- m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
+ m_fused_moe_fn.impl.fused_experts,
+ (MarlinExperts, UnfusedOAITritonExperts),
)
else:
- assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts)
+ assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts)
def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
@@ -333,9 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA):
return wrapper
- fused_experts = m_fused_moe_fn.fused_experts
+ fused_experts = m_fused_moe_fn.impl.fused_experts
- m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
+ m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply)
fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation
)
diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py
index 217c46f..237a61e 100644
--- a/vllm/lora/layers/logits_processor.py
+++ b/vllm/lora/layers/logits_processor.py
@@ -88,10 +88,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config: PretrainedConfig | None = None,
) -> None:
# TODO: Verify if this condition can be further relaxed
- if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048:
- raise ValueError(
- "When using LoRA, vocab size must be > 32000 and <= 258048"
- )
+ if self.base_layer.vocab_size > 258048:
+ raise ValueError("When using LoRA, vocab size must be <= 258048")
self.lora_a_stacked = torch.zeros(
(
max_loras,
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index c9c85c1..7fc49d8 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -8,9 +8,10 @@ from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.triton_utils import tl, triton
+from vllm.triton_utils.allocation import set_triton_allocator
from vllm.utils.torch_utils import direct_register_custom_op
-from .utils import supports_pdl
+from .utils import supports_pdl, supports_tma
@triton.jit
@@ -70,6 +71,37 @@ def _get_token_offs(
)
+@triton.jit
+def _get_c_ptrs(
+ cur_c_ptr,
+ lora_id,
+ pid_m,
+ offs,
+ offs_token,
+ offs_cn,
+ stride_cm,
+ stride_cn,
+ EM: tl.constexpr,
+ BLOCK_SIZE_M: tl.constexpr,
+ sort_c: tl.constexpr,
+):
+ # When sort_c is true, store the output in c_ptr using token order defined
+ # in sorted_token_ids_ptr; otherwise, use the original token order from the prompt
+ if sort_c:
+ offs_token_id = pid_m * BLOCK_SIZE_M + offs
+ c_ptrs = (
+ cur_c_ptr
+ + lora_id * EM * stride_cm
+ + stride_cm * offs_token_id[:, None]
+ + stride_cn * offs_cn[None, :]
+ )
+ else:
+ c_ptrs = (
+ cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+ )
+ return c_ptrs
+
+
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -95,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
def _adjust_kernel_inputs(
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
):
@@ -109,7 +141,7 @@ def _adjust_kernel_inputs(
else:
stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0)
- grid_lora_dim = num_active_loras
+ grid_lora_dim = num_active_loras.item()
return grid_lora_dim, stride_tl, stride_el
@@ -125,7 +157,9 @@ def _adjust_kernel_inputs(
)
def _fused_moe_lora_kernel(
a_ptr,
+ a_desc,
b_ptr,
+ b_desc,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
@@ -177,6 +211,18 @@ def _fused_moe_lora_kernel(
USE_GDC: tl.constexpr,
launch_pdl: tl.constexpr,
IS_PRIMARY: tl.constexpr,
+ USE_TMA: tl.constexpr,
+ # sort_c determines whether tokens are stored in C in the order determined
+ # by sorted_token_ids to enable later TMA loads from this tensor.
+ #
+ # When USE_TMA is enabled, the parameter combinations are:
+ # a_desc | b_desc | sort_c | Use Case
+ # --------|---------|--------|-----------------------------
+ # yes | yes | False | expand kernel (num_slices=1)
+ # no | yes | True | shrink kernel (num_slices=1)
+ # yes | no | False | expand kernel (num_slices>1)
+ # no | no | True | shrink kernel (num_slices>1)
+ sort_c: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
@@ -250,58 +296,90 @@ def _fused_moe_lora_kernel(
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
- # remove modulo wrap-around
- offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
token_mask = offs_token < num_valid_tokens
- # get a_ptrs,b_ptrs
- a_ptrs = cur_a_ptr + (
- offs_token[:, None] // token_mapping_factor * stride_am
- + offs_k[None, :] * stride_ak
- )
+ if USE_TMA and a_desc is not None:
+ # Expand path - with TMA enabled, load from A using TMA descriptor
+ offs_am = (
+ slice_id * max_loras * EM
+ + lora_id * EM
+ + pid_m * BLOCK_SIZE_M // token_mapping_factor
+ )
+ offs_ak = pid_sk * BLOCK_SIZE_K
+ else:
+ # Shrink path - load hidden states based on order defined in
+ # 'sorted_token_ids_ptr' then store them in c_ptr in this same sorted order
+ tl.static_assert(a_desc is None, "a_desc must be none")
+ a_ptrs = cur_a_ptr + (
+ offs_token[:, None] // token_mapping_factor * stride_am
+ + offs_k[None, :] * stride_ak
+ )
- b_ptrs = (
- cur_b_ptr
- + lora_id * stride_bl
- + expert_id * stride_be
- + offs_k[:, None] * stride_bk
- + offs_bn[None, :] * stride_bn
- )
+ if USE_TMA:
+ offs_bn = pid_n * BLOCK_SIZE_N
+ offs_bk = pid_sk * BLOCK_SIZE_K
+ if b_desc is None:
+ # Note(@gnovack) - Allocation of TMA descriptors on-device
+ # can cause conflicts when running in parallel via PDL
+ if USE_GDC and not IS_PRIMARY:
+ tl.extra.cuda.gdc_wait()
+
+ b_desc = tl.make_tensor_descriptor(
+ cur_b_ptr,
+ shape=[max_loras, num_experts, N, K],
+ strides=[stride_bl, stride_be, stride_bn, stride_bk],
+ block_shape=[1, 1, BLOCK_SIZE_N, BLOCK_SIZE_K],
+ )
+ else:
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
+ b_ptrs = (
+ cur_b_ptr
+ + lora_id * stride_bl
+ + expert_id * stride_be
+ + offs_k[:, None] * stride_bk
+ + offs_bn[None, :] * stride_bn
+ )
if USE_GDC and IS_PRIMARY:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents()
- # accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
for k in range(0, grid_k):
- k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
- # GDC wait waits for ALL programs in the prior kernel to complete
- # before continuing.
+ cur_k_offset = k * (BLOCK_SIZE_K * SPLIT_K)
+ k_remaining = K - cur_k_offset
# pre-fetch lora weight
- # add (offs_bn < N) mask; optional .ca for B
- b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
- if USE_B_L2_CACHE:
- b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
+ if b_desc is not None:
+ b = (
+ b_desc.load([lora_id, expert_id, offs_bn, offs_bk + cur_k_offset])
+ .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
+ .T
+ )
else:
- b = tl.load(b_ptrs, mask=b_mask, other=0.0)
+ # add (offs_bn < N) mask; optional .ca for B
+ b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
+ if USE_B_L2_CACHE:
+ b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
+ else:
+ b = tl.load(b_ptrs, mask=b_mask, other=0.0)
+ b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
+
+ if a_desc is not None:
+ a = a_desc.load([offs_am, offs_ak + cur_k_offset])
+ else:
+ a = tl.load(
+ a_ptrs,
+ mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
+ other=0.0,
+ )
+ a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
- if USE_GDC and not IS_PRIMARY:
- tl.extra.cuda.gdc_wait()
- a = tl.load(
- a_ptrs,
- mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
- other=0.0,
- )
accumulator += tl.dot(a, b)
- # Advance the ptrs to the next K block.
- a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
- b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
@@ -309,7 +387,19 @@ def _fused_moe_lora_kernel(
accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
- c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+ c_ptrs = _get_c_ptrs(
+ cur_c_ptr,
+ lora_id,
+ pid_m,
+ offs,
+ offs_token,
+ offs_cn,
+ stride_cm,
+ stride_cn,
+ EM,
+ BLOCK_SIZE_M,
+ sort_c,
+ )
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
if SPLIT_K == 1:
@@ -354,9 +444,10 @@ def _fused_moe_lora_shrink(
num_warps: int,
num_stages: int,
split_k: int,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
use_gdc: bool = False,
+ use_tma: bool = False,
) -> None:
w1_lora_a_stacked = lora_a_stacked[0]
shrink_config = {
@@ -369,6 +460,7 @@ def _fused_moe_lora_shrink(
"SPLIT_K": split_k,
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
+ "USE_TMA": use_tma,
}
b_ptr = _get_ptr(lora_a_stacked, device)
@@ -383,9 +475,20 @@ def _fused_moe_lora_shrink(
len(lora_a_stacked),
grid_lora_dim,
)
+
+ a_desc = None
+ b_desc = None
+ if use_tma and num_slices == 1:
+ b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+ lora_a_stacked[0],
+ [1, 1, shrink_config["BLOCK_SIZE_N"], shrink_config["BLOCK_SIZE_K"]],
+ )
+
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
+ a_desc,
b_ptr,
+ b_desc,
a_intermediate_cache1,
topk_weights,
sorted_token_ids,
@@ -407,8 +510,8 @@ def _fused_moe_lora_shrink(
w1_lora_a_stacked.stride(1),
w1_lora_a_stacked.stride(3),
w1_lora_a_stacked.stride(2),
- a_intermediate_cache1.stride(2),
- a_intermediate_cache1.stride(3),
+ a_intermediate_cache1.stride(-2),
+ a_intermediate_cache1.stride(-1),
stride_tl,
stride_el,
slice_a_size=qcurr_hidden_states.numel(),
@@ -419,7 +522,8 @@ def _fused_moe_lora_shrink(
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=False,
ADD_INPUTS=False,
- USE_B_L2_CACHE=True, # new
+ USE_B_L2_CACHE=True,
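+        # Store the shrink output in sorted-token order so the expand kernel can
+        # later load it through a TMA descriptor (see the sort_c table above).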
+ sort_c=use_tma and sorted_token_ids is not None,
IS_PRIMARY=True,
**shrink_config,
)
@@ -458,10 +562,11 @@ def _fused_moe_lora_expand(
num_warps: int,
num_stages: int,
split_k: int,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
+ use_tma: bool = False,
) -> None:
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
@@ -470,7 +575,7 @@ def _fused_moe_lora_expand(
w1_lora_b_stacked = lora_b_stacked[0]
a_intermediate_cache1 = a_intermediate_cache1.view(
- -1, a_intermediate_cache1.shape[3]
+ -1, a_intermediate_cache1.shape[-1]
)
expand_config = {
@@ -483,6 +588,7 @@ def _fused_moe_lora_expand(
"SPLIT_K": 1, # Set split_k = 1 for expand calls
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
+ "USE_TMA": use_tma,
}
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
@@ -498,10 +604,27 @@ def _fused_moe_lora_expand(
# Fast path: directly accumulate into the corresponding slice interval of output.
out_view = output[:, :, offset : offset + num_slices * N]
slice_c_size = N * out_view.stride(2)
+ a_desc = None
+ b_desc = None
+ if use_tma:
+ if sorted_token_ids is not None:
+ a_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+ a_intermediate_cache1,
+ [expand_config["BLOCK_SIZE_M"], expand_config["BLOCK_SIZE_K"]],
+ )
+ if num_slices == 1:
+ b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor(
+ lora_b_stacked[0],
+ [1, 1, expand_config["BLOCK_SIZE_N"], expand_config["BLOCK_SIZE_K"]],
+ )
+ else:
+ b_desc = None
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
+ a_desc,
b_ptr,
+ b_desc,
out_view,
topk_weights,
sorted_token_ids,
@@ -535,7 +658,8 @@ def _fused_moe_lora_expand(
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=mul_routed_weight,
ADD_INPUTS=True,
- USE_B_L2_CACHE=True, # new
+ USE_B_L2_CACHE=True,
+ sort_c=False,
IS_PRIMARY=False,
**expand_config,
)
@@ -559,7 +683,7 @@ def _fused_moe_lora(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@@ -616,8 +740,34 @@ def _fused_moe_lora(
else num_tokens * shrink_block_size_m
)
+    # TMA is not currently compatible with fully_sharded due to the non-determinism
+ # of token id sorting across ranks.
+ use_tma = supports_tma(device) and not fully_sharded
+
+ intermediate_cache_shape = (
+ num_slices,
+ M,
+ top_k_num,
+ max_lora_rank,
+ )
+ if use_tma:
+ if num_slices > 1:
+ # if num_slices > 1, we construct TMA descriptors for LoRA
+ # weights within the kernel, which requires us to first set an allocator
+ set_triton_allocator(device)
+
+ # When storing intermediate data in sorted order for TMA, we
+ # need an extra 'num_active_loras' dim in the cache to avoid conflicts
+ if sorted_token_ids is not None:
+ intermediate_cache_shape = (
+ num_slices,
+ sorted_token_ids.shape[0],
+ EM,
+ max_lora_rank,
+ )
+
a_intermediate_cache1 = torch.zeros(
- (num_slices, M, top_k_num, max_lora_rank),
+ intermediate_cache_shape,
dtype=output.dtype,
device=device,
)
@@ -654,6 +804,7 @@ def _fused_moe_lora(
num_active_loras,
mul_routed_weight,
use_gdc=use_gdc,
+ use_tma=use_tma,
)
if fully_sharded:
@@ -703,6 +854,7 @@ def _fused_moe_lora(
mul_routed_weight,
offset,
use_gdc=use_gdc,
+ use_tma=use_tma,
)
@@ -719,7 +871,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
@@ -769,9 +921,10 @@ def _fused_moe_lora_shrink_fake(
num_warps: int,
num_stages: int,
split_k: int,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
use_gdc: bool = False,
+ use_tma: bool = False,
) -> None:
return
@@ -805,10 +958,11 @@ def _fused_moe_lora_expand_fake(
num_warps: int,
num_stages: int,
split_k: int,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
+ use_tma: bool = False,
) -> None:
return
diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py
index 1557d37..343e0c8 100644
--- a/vllm/lora/ops/triton_ops/lora_expand_op.py
+++ b/vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -138,7 +138,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
- num_active_loras: int, # number of active LoRAs (unused here, for API compat)
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
@@ -235,7 +235,7 @@ def _lora_expand(
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
- num_active_loras,
+ num_active_loras.item(),
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
@@ -289,7 +289,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
index 1fec1d5..dd7c2c7 100644
--- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
+++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py
@@ -29,9 +29,16 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
- # Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
- # Stored as a Python int to avoid GPU->CPU sync during forward pass
- num_active_loras: int = 0
+ # Number of active LoRAs (unique non-(-1) values in token_lora_mapping).
+ # Stored as a CPU tensor (not a Python int) so that torch.compile treats
+ # it as a dynamic value rather than baking it as a constant at trace time.
+ # This follows the same pattern as no_lora_flag_cpu above.
+ num_active_loras_cpu: torch.Tensor
+
+ # Default num_active_loras value (max_loras + 1) as a CPU tensor,
+ # used when specialize_active_lora is False to avoid allocating a
+ # new tensor on every meta_args() call.
+ default_num_active_loras_cpu: torch.Tensor
# Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up
@@ -73,6 +80,11 @@ class LoRAKernelMeta:
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
+ num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")
+ default_num_active_loras_cpu = torch.tensor(
+ [max_loras + 1], dtype=torch.int32, device="cpu"
+ )
+
return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
@@ -80,6 +92,8 @@ class LoRAKernelMeta:
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu,
+ num_active_loras_cpu=num_active_loras_cpu,
+ default_num_active_loras_cpu=default_num_active_loras_cpu,
captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts
else [],
@@ -90,8 +104,7 @@ class LoRAKernelMeta:
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
- self.num_active_loras = 0
- self.captured_lora_counts = []
+ self.num_active_loras_cpu.fill_(0)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
@@ -137,14 +150,16 @@ class LoRAKernelMeta:
num_tokens_per_lora, non_blocking=True
)
- self.num_active_loras = lora_ids.size(0)
+ num_active_loras = lora_ids.size(0)
# Round up num_active_loras to match cudagraph capture keys.
# This ensures the kernel grid dimension matches the captured graph.
- if self.captured_lora_counts and self.num_active_loras > 0:
- idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
+ if self.captured_lora_counts and num_active_loras > 0:
+ idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
if idx < len(self.captured_lora_counts):
- self.num_active_loras = self.captured_lora_counts[idx]
+ num_active_loras = self.captured_lora_counts[idx]
+
+ self.num_active_loras_cpu[0] = num_active_loras
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
@@ -163,7 +178,7 @@ class LoRAKernelMeta:
torch.Tensor,
torch.Tensor,
torch.Tensor,
- int,
+ torch.Tensor,
]:
"""
This function returns the kernel metadata required for the current
@@ -175,7 +190,10 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward
pass of the kernel.
"""
- max_loras = self.active_lora_ids.size(0) - 1
+ if specialize_active_lora:
+ num_active_loras = self.num_active_loras_cpu
+ else:
+ num_active_loras = self.default_num_active_loras_cpu
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
@@ -183,5 +201,5 @@ class LoRAKernelMeta:
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
- self.num_active_loras if specialize_active_lora else max_loras + 1,
+ num_active_loras,
)
diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index 8dbd988..ea850ba 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -134,7 +134,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
- num_active_loras: int, # number of active LoRAs (unused here, for API compat)
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float,
) -> None:
"""
@@ -157,6 +157,9 @@ def _lora_shrink(
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
+ num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the
+ number of active LoRAs. Stored as a tensor (not int) so
+ torch.compile treats it as dynamic rather than a constant.
scaling (float): Scaling factor.
"""
@@ -215,7 +218,7 @@ def _lora_shrink(
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
- num_active_loras,
+ num_active_loras.item(),
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
@@ -267,7 +270,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
- num_active_loras: int,
+ num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float,
) -> None:
return
diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py
index c7ac591..a863b97 100644
--- a/vllm/lora/ops/triton_ops/utils.py
+++ b/vllm/lora/ops/triton_ops/utils.py
@@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool:
and current_platform.has_device_capability(90)
and not envs.VLLM_LORA_DISABLE_PDL
)
+
+
+@lru_cache
+def supports_tma(device: torch.device | None = None) -> bool:
+ # TMA requires compute capability SM90 or above
+ return current_platform.is_cuda() and current_platform.has_device_capability(90)
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index b75d297..2964a0d 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -233,6 +233,17 @@ class PunicaWrapperGPU(PunicaWrapperBase):
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
+ import vllm.envs as env
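+        # Optional fast path: for a single LoRA slice whose problem shape passes
+        # the ixformer GEMV heuristic, dispatch to the fused add_lora_linear
+        # kernel and skip the Triton shrink/expand path below.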
+ if env.VLLM_USE_LORA_FUSION:
+ import ixformer.inference.functions as ops
+
+ num_token, m = x.size(0), x.size(-1)
+ k, n = lora_b_stacked[0].size(-1), y.size(-1)
+ if len(lora_a_stacked) == 1 and ops.lora_gemv_optim_condition(num_token, m, k, n):
+                ops.add_lora_linear(y, x, lora_a_stacked, lora_b_stacked,
+                                    lora_bias_stacked=None, scale=1.0,
+                                    output_slices=(1,))
+ return
+
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
@@ -351,6 +362,8 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
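+        # With fewer token-expert assignments than experts, at most
+        # topk_ids.numel() experts can be non-empty, so topk_ids.numel() *
+        # block_size is a tighter upper bound that avoids over-allocating
+        # sorted_ids.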
+ if topk_ids.numel() < num_experts:
+ max_num_tokens_padded = topk_ids.numel() * block_size
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py
index 1b4b7dc..cbdc53d 100644
--- a/vllm/model_executor/kernels/linear/__init__.py
+++ b/vllm/model_executor/kernels/linear/__init__.py
@@ -127,10 +127,10 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
+ MarlinLinearKernel,
CutlassW4A8LinearKernel,
MacheteLinearKernel,
AllSparkLinearKernel,
- MarlinLinearKernel,
ConchLinearKernel,
ExllamaLinearKernel,
],
diff --git a/vllm/model_executor/kernels/linear/mixed_precision/machete.py b/vllm/model_executor/kernels/linear/mixed_precision/machete.py
index 7953ed5..b756c8a 100644
--- a/vllm/model_executor/kernels/linear/mixed_precision/machete.py
+++ b/vllm/model_executor/kernels/linear/mixed_precision/machete.py
@@ -69,7 +69,6 @@ class MacheteLinearKernel(MPLinearKernel):
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
-
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
@@ -86,19 +85,17 @@ class MacheteLinearKernel(MPLinearKernel):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
- x_unpacked = unpack_quantized_values_into_int32(
- x.data, c.weight_type, packed_dim=0
- )
+ x_unpacked = unpack_quantized_values_into_int32(x.data,
+ c.weight_type,
+ packed_dim=0)
x_perm = x_unpacked[perm, :]
- x.data = pack_quantized_values_into_int32(
- x_perm, c.weight_type, packed_dim=0
- )
- x.data = ops.machete_prepack_B(
- x.data.t().contiguous().t(),
- a_type=c.act_type,
- b_type=c.weight_type,
- group_scales_type=c.act_type,
- )
+ x.data = pack_quantized_values_into_int32(x_perm,
+ c.weight_type,
+ packed_dim=0)
+ x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
+ a_type=c.act_type,
+ b_type=c.weight_type,
+ group_scales_type=c.act_type)
return x
def transform_w_s(x):
@@ -144,16 +141,14 @@ class MacheteLinearKernel(MPLinearKernel):
else:
w_zp = None
- output = ops.machete_mm(
- a=x_2d,
- b_q=w_q,
- b_type=c.weight_type,
- b_group_zeros=w_zp,
- b_group_scales=w_s,
- b_group_size=c.group_size,
- )
+ output = ops.machete_mm(a=x_2d,
+ b_q=w_q,
+ b_type=c.weight_type,
+ b_group_zeros=w_zp,
+ b_group_scales=w_s,
+ b_group_size=c.group_size)
if bias is not None:
output.add_(bias) # In-place add
- return output.reshape(out_shape)
+ return output.reshape(out_shape)
\ No newline at end of file
diff --git a/vllm/model_executor/kernels/linear/mixed_precision/marlin.py b/vllm/model_executor/kernels/linear/mixed_precision/marlin.py
index eb14f9e..b955324 100644
--- a/vllm/model_executor/kernels/linear/mixed_precision/marlin.py
+++ b/vllm/model_executor/kernels/linear/mixed_precision/marlin.py
@@ -23,9 +23,95 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+from vllm.scalar_type import ScalarType, scalar_types
+import ixformer.inference.functions as ixf_ops
+from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.logger import init_logger
+logger = init_logger(__name__)
+
+
+def unpack_rows(packed_w: torch.Tensor, num_bits: int) -> torch.Tensor:
+ """
+ Efficient vectorized unpacking.
+ Converts [K // pack_factor, N] int32 tensor → [K, N] int8 tensor.
+
+ Args:
+ packed_w: torch.int32 tensor of shape [K // pack_factor, N].
+ num_bits: Number of bits per packed element (e.g., 4).
+
+ Returns:
+ unpacked: torch.int8 tensor of shape [K, N].
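+
+    Example (illustrative):
+        With num_bits=4 (pack_factor=8), the int32 value 0x76543210 stored at
+        packed_w[0, 0] unpacks, least-significant nibble first, into
+        unpacked[0:8, 0] == [0, 1, 2, 3, 4, 5, 6, 7].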
+ """
+ pack_factor = 32 // num_bits
+ k_packed, n = packed_w.shape
+ k = k_packed * pack_factor
+
+ mask = (1 << num_bits) - 1
+
+ # [pack_factor, 1, 1]
+ shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1)
+
+ # [pack_factor, k_packed, n]
+ packed_expanded = packed_w.unsqueeze(0)
+
+ # Extract each group of num_bits using bitwise ops
+ unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
+ # [pack_factor, k_packed, n] → [k, n]
+ unpacked = unpacked_groups.permute(1, 0, 2).reshape(k, n)
+
+ return unpacked
+
+
+def pack_cols(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
+ """
+ Efficient vectorized version: pack int4 values (0–15) into int32.
+ Each int32 element contains `pack_num` 4-bit values.
+
+ Args:
+ x: Tensor of shape [rows, cols * pack_num], dtype=int32.
+ Represents unpacked int4 values.
+ pack_num: Number of 4-bit elements to pack into each int32.
+ order_map: Index mapping defining the order of 4-bit packing,
+ must match the unpack order used in `unpack_tensor`.
+
+ Returns:
+ Tensor of shape [rows, cols], dtype=int32 — packed result.
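+
+    Example (illustrative):
+        With pack_num=8 and order_map=[0, 2, 4, 6, 1, 3, 5, 7], the int4 row
+        [0, 1, 2, 3, 4, 5, 6, 7] packs into the single int32 0x75316420:
+        even-indexed inputs land in the low 16 bits, odd-indexed in the high 16.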
+ """
+ # Default sequential order if none provided
+ if order_map is None:
+ order_map = list(range(pack_num))
+ order_map = torch.tensor(order_map, device=x.device)
+
+ # Number of bits per packed element (e.g., 32 / 8 = 4 bits)
+ unit = 32 // pack_num
+ rows, cols_pack = x.shape
+ assert cols_pack % pack_num == 0, "Number of columns must be a multiple of pack_num"
+ cols = cols_pack // pack_num
+
+ # Reshape input into groups of `pack_num` int4 values
+ # Shape: [rows, cols, pack_num]
+ x_reshape = x.view(rows, cols, pack_num)
+
+ # Reorder elements according to order_map
+ # order_map is broadcasted to match shape [rows, cols, pack_num]
+ x_reorder = torch.gather(x_reshape, 2, order_map.view(1, 1, -1).expand(rows, cols, -1))
+
+ # Keep only the lower 4 bits of each value
+ x_reorder = x_reorder & 0xF
+
+ # Compute bit shifts for each position (e.g., [0, 4, 8, 12, 16, 20, 24, 28])
+ shifts = (unit * torch.arange(pack_num, device=x.device)).view(1, 1, -1)
+
+ # Shift and combine (bitwise OR) along the last dimension
+ # Using sum() is safe since bits don't overlap between 4-bit slots
+ res = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
+
+ return res
class MarlinLinearKernel(MPLinearKernel):
@classmethod
@@ -79,96 +165,133 @@ class MarlinLinearKernel(MPLinearKernel):
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)
+        assert c.weight_type.size_bits == 4, (
+            "MarlinLinearKernel currently only supports uint4/uint4b8 weights; "
+            f"got weight_type {c.weight_type}"
+        )
+
+ # device = getattr(layer, self.w_q_name).device
+
- row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
- self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
+ # row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
+ # self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
- self.workspace = marlin_make_workspace_new(device)
+ # self.workspace = marlin_make_workspace_new(device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
- if self.w_gidx_name is None:
- self.w_gidx_name = "g_idx"
- if self.w_zp_name is None:
- self.w_zp_name = "w_zp"
+ # if self.w_gidx_name is None:
+ # self.w_gidx_name = "g_idx"
+ # if self.w_zp_name is None:
+ # self.w_zp_name = "w_zp"
+ if c.has_g_idx:
+ assert self.w_gidx_name is not None
+ perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
+
+ self.act_perm = lambda x: x[:, perm]
def transform_w_q(x):
- assert isinstance(x, BasevLLMParameter)
- permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
- x.data = ops.gptq_marlin_repack(
- x.data.contiguous(),
- perm=layer.g_idx_sort_indices,
- size_k=c.partition_weight_shape[0],
- size_n=c.partition_weight_shape[1],
- num_bits=c.weight_type.size_bits,
- is_a_8bit=is_a_8bit,
- )
+ # assert isinstance(x, BasevLLMParameter)
+ # permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+ # x.data = ops.gptq_marlin_repack(
+ # x.data.contiguous(),
+ # perm=layer.g_idx_sort_indices,
+ # size_k=c.partition_weight_shape[0],
+ # size_n=c.partition_weight_shape[1],
+ # num_bits=c.weight_type.size_bits,
+ # is_a_8bit=is_a_8bit,
+ # )
+ assert x.data.ndim == 2
+ if x._packed_dim == 1: #CompressedTensorsWNA16
+                # [oc, ic // 8] -> [oc, ic]
+ x_unpacked = unpack_quantized_values_into_int32(x.data,
+ c.weight_type,
+ packed_dim=1)
+ if c.has_g_idx:
+ x_unpacked = x_unpacked[:,perm]
+ #[oc, ic] -> [ic, oc]
+ x_unpacked = x_unpacked.t().contiguous()
+
+ elif x._packed_dim == 0: #GPTQMarlinLinearMethod
+
+                # [ic // 8, oc] -> [ic, oc]
+                x_unpacked = unpack_rows(x.data, c.weight_type.size_bits)
+                if c.has_g_idx:
+                    x_unpacked = x_unpacked[perm, :]
+                    raise NotImplementedError(
+                        "GPTQMarlinLinearMethod with has_g_idx is untested; verify "
+                        "the model's inference results before removing this guard."
+                    )
+ else:
+ raise NotImplementedError(f"transform_w_q pack_dim {x._packed_dim} not implement")
+
+ #[ic, oc]-> [ic, oc//8]
+ x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
+ x.data = x_packed.contiguous()
+ x._packed_dim = 1
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
- x.data = marlin_permute_scales(
- x.data.contiguous(),
- size_k=c.partition_weight_shape[0],
- size_n=c.partition_weight_shape[1],
- group_size=c.group_size,
- is_a_8bit=is_a_8bit,
- )
+ x.data = x.data.contiguous()
+ return x.to(dtype=c.act_type)
- if c.group_size == -1:
- num_groups = 1
- else:
- num_groups = c.partition_weight_shape[0] // c.group_size
-
- if c.act_type == torch.int8 and num_groups > 1:
- x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
- layer.register_parameter(
- "input_global_scale",
- torch.nn.Parameter(input_global_scale, requires_grad=False),
- )
- else:
- layer.input_global_scale = None
+ # if c.has_g_idx:
+ # g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+ # getattr(layer, self.w_gidx_name)
+ # )
+ # self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+ # layer.g_idx_sort_indices = g_idx_sort_indices
+ # else:
+ # setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+ # layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+ def transform_w_zp(x):
+ grouped_k = (c.partition_weight_shape[0] //
+ c.group_size if c.group_size != -1 else 1)
+ x_unpacked = unpack_cols(x.clone().t(), c.weight_type.size_bits, grouped_k, c.partition_weight_shape[1])
+ x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
+ x.data = x_packed.contiguous()
return x
-
- if c.has_g_idx:
- g_idx, g_idx_sort_indices = marlin_sort_g_idx(
- getattr(layer, self.w_gidx_name)
- )
- self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
- layer.g_idx_sort_indices = g_idx_sort_indices
- else:
- setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
- layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
if c.zero_points:
- grouped_k = (
- c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
- )
- self._transform_param(
- layer,
- self.w_zp_name,
- lambda x: marlin_zero_points(
- unpack_cols(
- x.t(),
- c.weight_type.size_bits,
- grouped_k,
- c.partition_weight_shape[1],
- ),
- size_k=grouped_k,
- size_n=c.partition_weight_shape[1],
- num_bits=c.weight_type.size_bits,
- is_a_8bit=is_a_8bit,
- ),
- )
+ # grouped_k = (
+ # c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
+ # )
+ # self._transform_param(
+ # layer,
+ # self.w_zp_name,
+ # lambda x: marlin_zero_points(
+ # unpack_cols(
+ # x.t(),
+ # c.weight_type.size_bits,
+ # grouped_k,
+ # c.partition_weight_shape[1],
+ # ),
+ # size_k=grouped_k,
+ # size_n=c.partition_weight_shape[1],
+ # num_bits=c.weight_type.size_bits,
+ # ),
+ # )
+ self._transform_param(layer, self.w_zp_name, transform_w_zp)
else:
- setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+ # setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+            # weight_type is uint4b8: use c.weight_type.bias as the zero point,
+            # per the quantization method.
+            # [ic, oc] -> [ic, oc // 8]
+ w_zp = torch.full_like(getattr(layer, self.w_s_name), c.weight_type.bias, dtype=torch.int32)
+ w_zp_pack = pack_cols(w_zp, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous()
+ weight_zero_point = torch.nn.Parameter(
+ w_zp_pack,
+ requires_grad=False)
+
+ if hasattr(layer, self.w_zp_name):
+ replace_parameter(layer, self.w_zp_name, weight_zero_point) #GPTQMarlinLinearMethod
+ else:
+ layer.register_parameter("weight_zero_point", weight_zero_point) #CompressedTensorsWNA16
+
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
- if hasattr(layer, "bias") and layer.bias is not None:
- layer.bias.data = marlin_permute_bias(layer.bias)
+ # if hasattr(layer, "bias") and layer.bias is not None:
+ # layer.bias.data = marlin_permute_bias(layer.bias)
def apply_weights(
self,
@@ -179,22 +302,39 @@ class MarlinLinearKernel(MPLinearKernel):
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
- # `process_weights_after_loading` will ensure w_zp and w_gidx are not
- # None for marlin
+ pack_factor = 32 // c.weight_type.size_bits
+
+ out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
+ x_2d = x.reshape(-1, x.shape[-1])
+
+ if c.has_g_idx:
+ x_2d = self.act_perm(x_2d)
+
+        out = ops.custom_gptq_marlin_gemm(input=x_2d,
+                                          qweight=w_q,
+                                          scales=w_s,
+                                          qzeros=w_zp,
+                                          pack_factor=pack_factor,
+                                          group_size=c.group_size,
+                                          bias=bias)
+ out = out.reshape(out_shape)
+ # if bias is not None:
+ # out.add_(bias)
+ return out
+
- return apply_gptq_marlin_linear(
- input=x,
- weight=w_q,
- weight_scale=w_s,
- weight_zp=w_zp, # type: ignore
- g_idx=w_gidx, # type: ignore
- g_idx_sort_indices=layer.g_idx_sort_indices,
- workspace=self.workspace,
- wtype=c.weight_type,
- input_size_per_partition=c.partition_weight_shape[0],
- output_size_per_partition=c.partition_weight_shape[1],
- is_k_full=self.is_k_full,
- input_global_scale=getattr(layer, "input_global_scale", None),
- bias=bias,
- input_dtype=c.act_type,
- )
+ # # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+ # # None for marlin
+ # return apply_gptq_marlin_linear(
+ # input=x,
+ # weight=w_q,
+ # weight_scale=w_s,
+ # weight_zp=w_zp, # type: ignore
+ # g_idx=w_gidx, # type: ignore
+ # g_idx_sort_indices=layer.g_idx_sort_indices,
+ # workspace=self.workspace,
+ # wtype=c.weight_type,
+ # input_size_per_partition=c.partition_weight_shape[0],
+ # output_size_per_partition=c.partition_weight_shape[1],
+ # is_k_full=self.is_k_full,
+ # bias=bias)
diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
index e4b14ea..46dd4c6 100644
--- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
+++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
@@ -18,7 +18,6 @@ from .ScaledMMLinearKernel import (
Int8ScaledMMLinearLayerConfig,
)
-import vllm.envs as envs
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
@@ -38,13 +37,28 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
config = self.config
# WEIGHT
# Cutlass kernels need transposed weight.
- weight = getattr(layer, w_q_name)
- replace_parameter(
- layer,
- w_q_name,
- # torch.nn.Parameter(weight.t().data, requires_grad=False),
- torch.nn.Parameter(weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data, requires_grad=False),
- )
+ weight = getattr(layer, w_q_name)
+ if layer.scheme.is_w4a8_linear:
+ self.format = "NN"
+ replace_parameter(layer, w_q_name, torch.nn.Parameter(weight.data.contiguous(), requires_grad=False))
+ else:
+ self.format = "TN" #默认weight都是按T排布
+ m, k = weight.shape
+ if(m % 64 == 0 and k % 64 == 0):
+ self.format= "NN"
+ replace_parameter(
+ layer, w_q_name,
+ torch.nn.Parameter(weight.t().data.contiguous(), requires_grad=False))#原始排布是T[m,k] 处理完后是N[k, m]
+ else:
+ if k % 64 != 0:
+ pad_k = (k // 64 + 1) * 64
+ weight_pad = torch.empty((m, pad_k), dtype=weight.dtype, device=weight.device)
+ _weight = weight_pad[:, :k]
+ _weight.copy_(weight)
+ weight = _weight
+ replace_parameter(
+ layer, w_q_name,
+ torch.nn.Parameter(weight.t(), requires_grad=False))
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
@@ -114,6 +128,7 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
+ is_w4a8_linear: bool = False,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
@@ -121,9 +136,15 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
- x_q, x_s, x_zp = ops.scaled_int8_quant(
- x.contiguous(), i_s, i_zp, symmetric=symmetric
- )
+ if isinstance(x, tuple):
+ x_q, x_s, out_dtype = x
+ x_zp = None
+ else:
+ out_dtype = x.dtype
+ x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
+ i_s,
+ i_zp,
+ symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
@@ -134,14 +155,21 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
w_q,
scale_a=x_s,
scale_b=w_s,
- out_dtype=x.dtype,
+ out_dtype=out_dtype,
azp_adj=azp_adj,
azp=azp,
bias=bias,
)
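+        # process_weights_after_loading may have padded the weight K dimension up
+        # to a multiple of 64; zero-pad the quantized activation so the GEMM
+        # shapes line up.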
+ if self.format == "NN" and x_q.shape[-1] != w_q.shape[0]:
+ padding = w_q.shape[0] - x_q.shape[-1]
+ x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
+ elif self.format == "TN" and x_q.shape[-1] != w_q.shape[-1]:
+ padding = w_q.shape[-1] - x_q.shape[-1]
+ x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0)
+ else:
+ x_align = x_q
return ops.cutlass_scaled_mm(
- # x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
- x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias, format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN"
+ x_align, w_q, scale_a=x_s, scale_b=w_s, out_dtype=out_dtype, bias=bias, format=self.format, is_w4a8_linear=is_w4a8_linear
)
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 11cdc73..aee1621 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -8,6 +8,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
+from vllm import envs
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
@@ -130,13 +131,12 @@ class SiluAndMul(CustomOp):
def __init__(self, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
- if current_platform.is_cuda_alike():
+ if current_platform.is_cuda_alike() or current_platform.is_xpu():
from vllm import _custom_ops as ops
- self.op = ops.silu_and_mul
- elif current_platform.is_xpu():
- from vllm._ipex_ops import ipex_ops
-
- self.op = ipex_ops.silu_and_mul
+ if envs.VLLM_USE_SILU_QUANT_FUSION:
+ self.op = ops.silu_and_mul_quant
+ else:
+ self.op = ops.silu_and_mul
elif current_platform.is_cpu():
self._forward_method = self.forward_native
@@ -146,11 +146,15 @@ class SiluAndMul(CustomOp):
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
- def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+ def forward_cuda(self, x: torch.Tensor, out_dim: int = 0) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
- out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
- self.op(out, x)
+ if envs.VLLM_USE_SILU_QUANT_FUSION:
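+            # Fused path: return (quantized activation, scales, original dtype);
+            # the downstream int8 scaled-mm kernel unpacks this tuple and skips
+            # its own activation quantization.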
+ quant_out, out_scales = self.op(x, out_dim)
+ out = (quant_out, out_scales, x.dtype)
+ else:
+ out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+ self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@@ -174,7 +178,6 @@ class MulAndSilu(CustomOp):
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_xpu():
- # self.op = torch.ops._C.mul_and_silu
from vllm import _custom_ops as ops
self.op = ops.mul_and_silu
elif current_platform.is_cpu():
@@ -397,7 +400,6 @@ class NewGELU(CustomOp):
or current_platform.is_cpu()
or current_platform.is_xpu()
):
- # self.op = torch.ops._C.gelu_new
from vllm import _custom_ops as ops
self.op = ops.gelu_new
@@ -427,7 +429,8 @@ class FastGELU(CustomOp):
or current_platform.is_cpu()
or current_platform.is_xpu()
):
- self.op = torch.ops._C.gelu_fast
+ from vllm import _custom_ops as ops
+ self.op = ops.gelu_fast
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
@@ -455,7 +458,6 @@ class QuickGELU(CustomOp):
or current_platform.is_cpu()
or current_platform.is_xpu()
):
- # self.op = torch.ops._C.gelu_quick
from vllm import _custom_ops as ops
self.op = ops.gelu_quick
diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py
index 38f1099..2948d46 100644
--- a/vllm/model_executor/layers/attention/attention.py
+++ b/vllm/model_executor/layers/attention/attention.py
@@ -12,7 +12,7 @@ from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.kv_transfer_utils import (
- maybe_transfer_kv_layer,
+ maybe_transfer_kv_layer
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
@@ -40,6 +40,9 @@ from vllm.v1.kv_cache_interface import (
KVCacheSpec,
SlidingWindowSpec,
)
+from .extra_cache import StaticQuantManager
+from ixformer.core import config
+_USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS
if TYPE_CHECKING:
from vllm.model_executor.layers.attention import MLAAttention
@@ -202,6 +205,7 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
head_size_v: int | None = None,
+        extra_cache_para: dict | None = None,
**extra_impl_args,
) -> None:
"""
@@ -258,6 +262,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.num_heads = num_heads
self.head_size = head_size
+ self.hidden_size = head_size * num_heads
self.head_size_v = self.head_size if head_size_v is None else head_size_v
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
@@ -326,6 +331,15 @@ class Attention(nn.Module, AttentionLayerBase):
kv_sharing_target_layer_name,
**extra_impl_args,
)
+ if extra_cache_para is not None:
+ self.quant_manager = StaticQuantManager(
+ layer_id=extra_cache_para.get("layer_id", None),
+ shape=(self.num_kv_heads, self.head_size_v),
+ dtype=torch.float32,
+ total_layer_num=extra_cache_para.get("total_layer_num", None)
+ )
+ else:
+ self.quant_manager = None
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype
@@ -333,7 +347,10 @@ class Attention(nn.Module, AttentionLayerBase):
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
- self.use_direct_call = not current_platform.opaque_attention_op()
+        self.use_direct_call = not _USE_TORCH_OPS
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config
@@ -349,14 +366,26 @@ class Attention(nn.Module, AttentionLayerBase):
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-
- # use a placeholder kv cache tensor during init, which will be replaced
- # by bind_kv_cache
- # this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
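+        # VLLM_ATTN_OPT_LEVEL selects the quantized-attention path:
+        #   1 -> int8 Q / int8 K / int8 V (separate K and V cache scale tensors)
+        #   2 -> int8 Q / int8 K / fp16 V (a single KV cache scale tensor)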
+ self.is_i8qi8ki8v = envs.VLLM_ATTN_OPT_LEVEL == 1
+ self.is_i8qi8kf16v = envs.VLLM_ATTN_OPT_LEVEL == 2
+ if self.is_i8qi8kf16v:
+ self.kv_cache_scale = [
+ torch.tensor([]) for _ in range(get_current_vllm_config(
+ ).parallel_config.pipeline_parallel_size)
+ ]
+ elif self.is_i8qi8ki8v:
+ self.kv_cache_scale = [
+ [torch.tensor([]), torch.tensor([])] for _ in range(get_current_vllm_config(
+ ).parallel_config.pipeline_parallel_size)
+ ]
+
+ # use a placeholder kv cache tensor during init, which will be replaced
+ # by bind_kv_cache
+ # this variable will not be accessed if use_direct_call is True
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
@@ -396,6 +425,7 @@ class Attention(nn.Module, AttentionLayerBase):
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
+ optional_args = {}
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
@@ -412,15 +442,8 @@ class Attention(nn.Module, AttentionLayerBase):
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
- if output_shape is None:
- # Handle both 2D [num_tokens, hidden] and
- # 3D [num_tokens, heads, head_dim] query
- num_tokens = query.shape[0]
- output_shape = torch.Size(
- (num_tokens, self.num_heads * self.head_size_v)
- )
+ output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
- hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
@@ -430,46 +453,50 @@ class Attention(nn.Module, AttentionLayerBase):
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
- kv_cache_dummy_dep = None
if self.use_direct_call:
- # Skip this if sharing KV cache with an earlier attention layer.
- if (
- not self.attn_backend.forward_includes_kv_cache_update
- and self.kv_sharing_target_layer_name is None
- and key is not None
- and value is not None
- ):
- kv_cache_dummy_dep = unified_kv_cache_update(
- key, value, self.layer_name
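+            # Wrap the attention call so maybe_transfer_kv_layer can run its
+            # KV-transfer hooks around this layer's forward (see the call below).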
+ def direct_forward(layer_name: str, output: torch.Tensor):
+ forward_context: ForwardContext = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ if isinstance(attn_metadata, dict):
+ attn_metadata = attn_metadata[layer_name]
+ self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+ # Skip this if sharing KV cache with an earlier attention layer.
+ if self.is_i8qi8ki8v or self.is_i8qi8kf16v:
+ optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine]
+ output = self.impl.forward(
+ self,
+ query,
+ key,
+ value,
+ self_kv_cache,
+ attn_metadata,
+ output=output,
+ **optional_args
)
- unified_attention_with_output(
- query,
- key,
- value,
- output,
- self.layer_name,
- kv_cache_dummy_dep=kv_cache_dummy_dep,
- )
+ return output
+ return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output)
else:
- # Skip this if sharing KV cache with an earlier attention layer.
- if (
- not self.attn_backend.forward_includes_kv_cache_update
- and self.kv_sharing_target_layer_name is None
- and key is not None
- and value is not None
- ):
- kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
- key, value, self.layer_name
- )
+ if self.is_i8qi8ki8v:
+ forward_context: ForwardContext = get_forward_context()
+ kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][0]
+ v_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][1]
+ elif self.is_i8qi8kf16v:
+ forward_context: ForwardContext = get_forward_context()
+ kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine]
+ v_cache_scale = None
+ else:
+ kv_cache_scale = None
+ v_cache_scale = None
torch.ops.vllm.unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
- kv_cache_dummy_dep=kv_cache_dummy_dep,
+ kv_cache_scale,
+ v_cache_scale
)
- return output.view(-1, hidden_size)
+ return output.view(-1, self.hidden_size)
else:
assert self.attn_backend.forward_includes_kv_cache_update, (
"Split KV cache update not supported when output tensor not provided."
@@ -521,6 +548,7 @@ class Attention(nn.Module, AttentionLayerBase):
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
+        # TODO: the kernel does not support a sliding-window KV cache; use FullAttentionSpec instead.
if self.sliding_window is not None:
assert not vllm_config.model_config.use_mla, (
"MLA is not supported for slidingwindow"
@@ -689,6 +717,8 @@ def unified_attention_with_output(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
+ kv_cache_scale: torch.Tensor | None = None,
+ v_cache_scale: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
@@ -696,9 +726,7 @@ def unified_attention_with_output(
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
- del kv_cache_dummy_dep
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
-
self.impl.forward(
self,
query,
@@ -707,6 +735,7 @@ def unified_attention_with_output(
kv_cache,
attn_metadata,
output=output,
+        kv_cache_scale=[kv_cache_scale, v_cache_scale] if envs.VLLM_ATTN_OPT_LEVEL == 1 else kv_cache_scale,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
@@ -718,6 +747,8 @@ def unified_attention_with_output_fake(
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
+ kv_cache_scale: torch.Tensor | None = None,
+ v_cache_scale: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
diff --git a/vllm/model_executor/layers/attention/extra_cache.py b/vllm/model_executor/layers/attention/extra_cache.py
new file mode 100644
index 0000000..36d2c1f
--- /dev/null
+++ b/vllm/model_executor/layers/attention/extra_cache.py
@@ -0,0 +1,131 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
+import torch
+from filelock import FileLock
+
+import vllm.envs as envs
+from vllm.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class StaticQuantManager:
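+    """Collect per-layer KV quantization scales during inference.
+
+    Each attention layer tracks the running element-wise max of the scale data
+    passed to update_data(), periodically snapshots it to a per-(layer, tp_rank)
+    file under a shared work directory, and the rank-0 / layer-0 instance merges
+    all snapshots into the final file at file_save_path. The manager disables
+    itself when no save path is configured or the merged file already exists.
+    """
+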
+ def __init__(
+ self,
+ layer_id: int,
+ shape: tuple,
+ dtype: torch.dtype,
+ total_layer_num: int,
+        device: str | None = None,
+        tp_size: int | None = None,
+        tp_rank: int | None = None,
+        file_save_path: str | None = None,
+ save_step: int = 100,
+ info_step: int = 100,
+ ):
+        # fill in default parameters
+ if tp_size is None:
+ tp_size = get_tensor_model_parallel_world_size()
+ if tp_rank is None:
+ tp_rank = get_tensor_model_parallel_rank()
+ if file_save_path is None:
+ file_save_path = envs.VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH
+ if device is None:
+ device = "cuda"
+
+        # validate parameters
+ if file_save_path in [None, ""]:
+ self.disable = True
+ return
+
+ para_dir = os.path.dirname(file_save_path)
+ assert os.path.exists(para_dir), (
+ f"StaticQuantManager workdir {para_dir} not exist!"
+ )
+ self.disable = os.path.exists(file_save_path)
+ if self.disable:
+ return
+
+ assert layer_id is not None
+ assert total_layer_num is not None
+
+ world_rank = torch.distributed.get_rank()
+ work_dir = os.path.join(para_dir, "StaticQuantManagerWorkdir")
+ self.operator = world_rank == 0 and layer_id == 0
+ if not os.path.exists(work_dir):
+ if self.operator:
+ logger.debug(f"StaticQuantManager Creat {work_dir}!")
+            os.makedirs(work_dir, exist_ok=True)  # tolerate creation races between ranks
+ self.file_save_path = file_save_path
+ self.work_dir = work_dir
+
+ self.tp_size = tp_size
+ self.tp_rank = tp_rank
+ self.world_rank = world_rank
+ self.layer_id = layer_id
+ self.total_layer_num = total_layer_num
+ self.save_step = save_step
+ self.info_step = info_step
+ self.update_count = 0
+ self.save_flag = False
+ self.scales = torch.zeros(shape, dtype=dtype, device=device)
+ logger.debug(
+ f"StaticQuantManager info: world_rank:{self.world_rank} tp_rank:{self.tp_rank} layer_id:{self.layer_id} scale shape:{shape} self.scales:{self.scales.device}"
+ )
+
+ def check_enable(self):
+ return not self.disable
+
+ def update_data(self, data):
+ if self.disable:
+ return
+
+ self.scales = torch.max(data, self.scales)
+
+ # save file
+ self.update_count += 1
+ if self.update_count % self.info_step == 0 and self.operator:
+ logger.info(f"StaticQuantManager run update_data {self.update_count} step")
+
+ if self.update_count % self.save_step == 0:
+ # step1: save to disk
+ save_file_path = os.path.join(
+ self.work_dir, f"{self.layer_id}_{self.tp_rank}.pt"
+ )
+ lock_file_path = os.path.join(
+ self.work_dir, f"{self.layer_id}_{self.tp_rank}.lock"
+ )
+ lock = FileLock(lock_file_path)
+ cpu_data = self.scales.cpu()
+ with lock:
+ torch.save(cpu_data, save_file_path)
+
+ # step2: merge and save
+ if self.save_flag and self.operator:
+ save_dict = {}
+ for idx in range(self.total_layer_num):
+ tp_datas = []
+ for tp_rank in range(self.tp_size):
+ load_file = os.path.join(self.work_dir, f"{idx}_{tp_rank}.pt")
+ lock_file_path = os.path.join(
+ self.work_dir, f"{idx}_{tp_rank}.lock"
+ )
+ lock = FileLock(lock_file_path)
+ with lock:
+ cur_data = torch.load(load_file)
+ tp_datas.append(cur_data)
+
+ layer_data = torch.concat(tp_datas)
+ save_dict[f"layer_{idx}"] = layer_data
+
+ torch.save(save_dict, self.file_save_path)
+ logger.info(
+ f"StaticQuantManager save to {self.file_save_path} with {self.update_count} step"
+ )
+ self.save_flag = True
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index f9cfcd2..2b74e56 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -191,7 +191,7 @@ import functools
from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
-from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast
+from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast, Any
if TYPE_CHECKING:
from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
@@ -221,8 +221,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
- UnquantizedLinearMethod,
- LinearBase,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
@@ -230,6 +228,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
get_and_maybe_dequant_weights,
)
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ LinearBase,
+ UnquantizedLinearMethod,
+)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory
from vllm.utils.math_utils import cdiv, round_down
@@ -245,14 +248,16 @@ from vllm.v1.attention.backend import (
AttentionType,
CommonAttentionMetadata,
MLAAttentionImpl,
+ SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_per_layer_parameters,
infer_global_hyperparameters,
- split_decodes_and_prefills,
+ split_decodes_and_prefills
)
+from vllm.v1.attention.backend import AttentionCGSupport
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.attention.selector import get_attn_backend
@@ -262,14 +267,15 @@ from vllm.v1.kv_cache_interface import (
MLAAttentionSpec,
)
-from vllm.v1.attention.backend import AttentionCGSupport
-
logger = init_logger(__name__)
class MLAAttention(nn.Module, AttentionLayerBase):
"""Multi-Head Latent Attention layer.
+ NOTE: Please read the comment at the top of the file before trying to
+ understand this class
+
This class takes query, and compressed key/value tensors as input.
The class does the following:
@@ -293,6 +299,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
prefix: str = "",
use_sparse: bool = False,
indexer: object | None = None,
+ rotary_emb: object | None = None,
**extra_impl_args,
):
super().__init__()
@@ -303,8 +310,20 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
+ self.kv_b_proj = kv_b_proj
self.head_size = kv_lora_rank + qk_rope_head_dim
self.layer_name = prefix
+ self.indexer = indexer
+ self.rotary_emb = rotary_emb
+ speculative_config = get_current_vllm_config().speculative_config
+ self.use_spec_decode = (
+ speculative_config is not None
+ and speculative_config.num_speculative_tokens is not None
+ and speculative_config.num_speculative_tokens > 0
+ )
+
+ self.num_kv_heads = 1
+ self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
@@ -329,6 +348,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
block_size,
use_mla=True,
use_sparse=use_sparse,
+ num_heads=self.num_heads,
)
if (
@@ -370,8 +390,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
indexer=indexer,
**extra_impl_args,
)
-
- # self.use_direct_call = not current_platform.opaque_attention_op()
+ self.q_pad_num_heads = getattr(self.impl, "q_pad_num_heads", None)
self.use_direct_call = True
compilation_config = get_current_vllm_config().compilation_config
@@ -385,6 +404,12 @@ class MLAAttention(nn.Module, AttentionLayerBase):
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
]
+ if envs.VLLM_USE_INT8_MLA:
+ self.kv_cache_scale = [
+ torch.tensor([]) for _ in range(get_current_vllm_config(
+ ).parallel_config.pipeline_parallel_size)
+ ]
+ self.is_int8_mla = envs.VLLM_USE_INT8_MLA
self.use_sparse = use_sparse
@@ -393,85 +418,76 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
- def forward_corex(
- self,
- q: torch.Tensor,
- kv_c_normed: torch.Tensor,
- k_pe: torch.Tensor,
- output_shape: torch.Size | None = None,
- ) -> torch.Tensor:
- if self.calculate_kv_scales:
- torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
+ self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
- optional_args = {}
- output_shape = output_shape if output_shape is not None else q.shape
- output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
+ # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
+ self.is_aiter_triton_fp4_bmm_enabled = (
+ rocm_aiter_ops.is_fp4bmm_enabled()
+ and self.kv_b_proj.weight.dtype == torch.bfloat16
+ )
- if self.use_direct_call:
- forward_context: ForwardContext = get_forward_context()
- attn_metadata = forward_context.attn_metadata
- if isinstance(attn_metadata, dict):
- attn_metadata = attn_metadata[self.layer_name]
- self_kv_cache = self.kv_cache[forward_context.virtual_engine]
- if self.attn_backend.accept_output_buffer:
- output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
- output = self.impl.forward(
- self,
- q,
- kv_c_normed,
- k_pe,
- self_kv_cache,
- attn_metadata,
- output=output,
- **optional_args,
- )
- else:
- output = self.impl.forward(
- self,
- q,
- kv_c_normed,
- k_pe,
- self_kv_cache,
- attn_metadata,
- **optional_args,
- )
- return output
+ # Attributes for forward_impl method
+ self.chunked_prefill_workspace_size = (
+ MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+ get_current_vllm_config()
+ )
+ )
+ self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
+ static=True,
+ group_shape=GroupShape.PER_TENSOR,
+ compile_native=True,
+ )
def forward(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
+ position: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
- return self.forward_corex(q, kv_c_normed, k_pe, output_shape)
+ optional_args = {}
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name)
if self.use_direct_call:
- forward_context: ForwardContext = get_forward_context()
- attn_metadata = forward_context.attn_metadata
- if isinstance(attn_metadata, dict):
- attn_metadata = attn_metadata[self.layer_name]
- self_kv_cache = self.kv_cache[forward_context.virtual_engine]
-
- if self.attn_backend.accept_output_buffer:
- output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
- self.impl.forward(
- self,
- q,
- kv_c_normed,
- k_pe,
- self_kv_cache,
- attn_metadata,
- output=output,
- )
- return output
- else:
- return self.impl.forward(
- self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata
- )
+ def direct_forward(layer_name: str, output_shape: torch.Size | None):
+ forward_context: ForwardContext = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+ if isinstance(attn_metadata, dict):
+ attn_metadata = attn_metadata[layer_name]
+ self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+ if self.is_int8_mla:
+ optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine]
+ if self.attn_backend.accept_output_buffer:
+ output_shape = output_shape if output_shape is not None else q.shape
+ output = torch.zeros(output_shape, dtype=q.dtype, device=q.device)
+ output = self.forward_impl(
+ q,
+ kv_c_normed,
+ k_pe,
+ self_kv_cache,
+ attn_metadata,
+ output=output,
+ positions=position,
+ **optional_args
+ )
+ return output
+ else:
+ return self.forward_impl(
+ q, kv_c_normed, k_pe, self_kv_cache, attn_metadata, positions=position
+ )
+ return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output_shape)
else:
+ kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update(
+ kv_c_normed,
+ k_pe,
+ self.layer_name,
+ self.kv_cache_dtype,
+ self._k_scale,
+ )
if self.attn_backend.accept_output_buffer:
output = torch.empty(output_shape, dtype=q.dtype, device=q.device)
torch.ops.vllm.unified_mla_attention_with_output(
@@ -480,6 +496,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
k_pe,
output,
self.layer_name,
+ kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output
else:
@@ -488,9 +505,210 @@ class MLAAttention(nn.Module, AttentionLayerBase):
kv_c_normed,
k_pe,
self.layer_name,
+ kv_cache_dummy_dep=kv_cache_dummy_dep,
)
+ def forward_impl(
+ self,
+ q: torch.Tensor,
+ k_c_normed: torch.Tensor, # key in unified attn
+ k_pe: torch.Tensor, # value in unified attn
+ kv_cache: torch.Tensor,
+ attn_metadata: "MLACommonMetadata",
+ output: torch.Tensor | None = None,
+ positions: torch.Tensor | None = None,
+ kv_cache_scale: torch.Tensor | None = None,
+ output_scale: torch.Tensor | None = None,
+ output_block_scale: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ assert output is not None, "Output tensor must be provided."
+ assert positions is not None, "positions tensor must be provided."
+
+ if output_scale is not None or output_block_scale is not None:
+ raise NotImplementedError(
+ "fused output quantization is not yet supported for MLA"
+ )
+
+ if attn_metadata is None:
+ # During the profile run, try to simulate the worst-case output size
+ # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
+ # since this can be large
+ _ = torch.empty(
+ (
+ self.chunked_prefill_workspace_size,
+ self.num_heads,
+ self.qk_nope_head_dim + self.v_head_dim,
+ ),
+ device=k_c_normed.device,
+ dtype=k_c_normed.dtype,
+ )
+
+ # The zero fill is required when used with DP + EP
+ # to ensure all ranks within a DP group compute the
+ # same expert outputs.
+ output = torch.zeros(
+ output.shape[0],
+ self.v_head_dim * self.num_heads,
+ device=q.device,
+ dtype=q.dtype,
+ )
+ return output
+
+ if self.impl.dcp_world_size == -1:
+ self.impl.dcp_world_size = get_dcp_group().world_size
+
+ fp8_attention = self.kv_cache_dtype.startswith("fp8")
+
+ is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl)
+
+ if not is_sparse_impl:
+ assert (
+ attn_metadata.num_decodes is not None
+ and attn_metadata.num_prefills is not None
+ and attn_metadata.num_decode_tokens is not None
+ )
+
+ has_decode = attn_metadata.num_decodes > 0
+ has_prefill = attn_metadata.num_prefills > 0
+ num_decode_tokens = attn_metadata.num_decode_tokens
+
+ decode_q = q[:num_decode_tokens]
+ k_pe = k_pe.unsqueeze(1)
+
+ prefill_q = q[num_decode_tokens:]
+ prefill_k_pe = k_pe[num_decode_tokens:]
+ prefill_k_c_normed = k_c_normed[num_decode_tokens:]
+ prefill_k = torch.empty_like(prefill_q)
+
+ # write the latent and rope to kv cache
+ write_kv_cache = (None, None)
+ if kv_cache.numel() > 0:
+ if has_decode:
+ decode_q_pe, decode_k_pe = self.rotary_emb(
+ positions[:num_decode_tokens],
+ decode_q[..., self.qk_nope_head_dim:],
+ k_pe[:num_decode_tokens],
+ )
+ if envs.VLLM_USE_INT8_MLA:
+ k_c_normed_int8, k_c_normed_scale, _ = ops.scaled_int8_quant(
+ k_c_normed[:num_decode_tokens]
+ )
+ decode_k_pe_int8, decode_k_pe_scale, _ = ops.scaled_int8_quant(
+ decode_k_pe.contiguous()
+ )
+ ops.concat_and_cache_mla_int8(
+ kv_c_int8=k_c_normed_int8,
+ kv_c_scale=k_c_normed_scale[..., 0],
+ k_pe_int8=decode_k_pe_int8,
+ k_pe_scale=decode_k_pe_scale[..., 0].view(-1, decode_k_pe_int8.shape[-2]),
+ kv_cache=kv_cache,
+ kv_cache_scale=kv_cache_scale,
+ slot_mapping=attn_metadata.slot_mapping.flatten()[:num_decode_tokens],
+ kv_cache_dtype=self.kv_cache_dtype,
+ scale=self._k_scale,
+ )
+ else:
+ if self.impl.dcp_world_size > 1 or self.use_spec_decode:
+ ops.concat_and_cache_mla(
+ k_c_normed[:num_decode_tokens],
+ decode_k_pe,
+ kv_cache,
+ attn_metadata.slot_mapping.flatten()[:num_decode_tokens],
+ kv_cache_dtype=self.kv_cache_dtype,
+ scale=self._k_scale,
+ )
+ else:
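+ # Defer the decode-token cache write: hand the latent/rope tensors to
+ # forward_mqa below via *write_kv_cache; the backend is assumed to
+ # write them as part of the decode attention kernel.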
+ write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe)
+ if has_prefill:
+ ixf_ops.mla_rope(
+ positions[num_decode_tokens:],
+ prefill_q[..., self.qk_nope_head_dim:],
+ prefill_k_pe.squeeze(1),
+ prefill_k[..., self.qk_nope_head_dim:],
+ self.rotary_emb.cos_sin_cache,
+ )
+ if envs.VLLM_USE_INT8_MLA:
+ prefill_k_c_normed_int8, prefill_k_c_normed_scale, _ = ops.scaled_int8_quant(
+ prefill_k_c_normed
+ )
+ prefill_k_pe_int8, prefill_k_pe_scale, _ = ops.scaled_int8_quant(
+ prefill_k[..., self.qk_nope_head_dim:].contiguous()
+ )
+ ops.concat_and_cache_mla_int8(
+ prefill_k_c_normed_int8,
+ prefill_k_c_normed_scale[..., 0],
+ prefill_k_pe_int8,
+ prefill_k_pe_scale[..., 0].view(-1, prefill_k_pe_int8.shape[-2]),
+ kv_cache,
+ kv_cache_scale,
+ attn_metadata.slot_mapping.flatten()[num_decode_tokens:],
+ self.kv_cache_dtype,
+ self._k_scale,
+ )
+ else:
+ ops.concat_and_cache_mla(
+ prefill_k_c_normed,
+ prefill_k[..., self.qk_nope_head_dim:],
+ kv_cache,
+ attn_metadata.slot_mapping.flatten()[num_decode_tokens:],
+ kv_cache_dtype=self.kv_cache_dtype,
+ scale=self._k_scale,
+ )
+ output = torch.empty(
+ output.shape[0],
+ self.num_heads,
+ self.v_head_dim,
+ device=q.device,
+ dtype=q.dtype,
+ )
+
+ if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla":
+ kv_cache = kv_cache.view(current_platform.fp8_dtype())
+
+ # Sparse MLA impls only support forward_mqa (decode-style attention)
+
+ if has_prefill:
+ output[num_decode_tokens:] = self.impl.forward_mha(
+ prefill_q,
+ prefill_k_c_normed,
+ prefill_k,
+ kv_cache,
+ attn_metadata,
+ kv_c_and_k_pe_cache_scale=kv_cache_scale
+ )
+
+ if has_decode:
+ # For sparse impl, we always use forward_mqa for all tokens
+ # For non-sparse impl, we only use forward_mqa for decode tokens
+ assert attn_metadata.decode is not None
+ mqa_q = decode_q
+
+ attn_out, lse = self.impl.forward_mqa(
+ mqa_q[..., :self.qk_nope_head_dim], decode_q_pe, kv_cache, attn_metadata,
+ *write_kv_cache, kv_c_and_k_pe_cache_scale=kv_cache_scale)
+ if self.impl.dcp_world_size > 1:
+ assert lse is not None
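+ # With decode context parallelism, each rank holds a partial softmax over
+ # its KV shard; cp_lse_ag_out_rs merges the partial outputs across the DCP
+ # group using their log-sum-exp weights.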
+ attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
+ output[:num_decode_tokens] = self.impl._v_up_proj(attn_out)
+ else:
+ assert lse is None
+ output[:num_decode_tokens] = attn_out
+ return output.view(output.shape[0], self.v_head_dim * self.num_heads)
+ else:
+ num_actual_toks = attn_metadata.num_actual_tokens
+ # Inputs and outputs may be padded for CUDA graphs
+ k_pe = k_pe.unsqueeze(1)
+ q = q[:num_actual_toks, ...]
+ k_c_normed = k_c_normed[:num_actual_toks, ...]
+ k_pe = k_pe[:num_actual_toks, ...]
+ positions = positions[:num_actual_toks, ...]
+ num_mqa_tokens = q.size(0)
+ if num_mqa_tokens > 0:
+ q = q[:num_mqa_tokens]
+ q_nope, q_pe = q.split(
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+ )
+ q_pe, k_pe = self.rotary_emb(positions[:num_actual_toks], q_pe, k_pe)
+
+ q_nope = self.impl._k_up_proj(q_nope)
+ q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)
+ q = torch.cat([q_nope, q_pe], dim=-1)
+ if kv_cache.numel() > 0:
+ ops.concat_and_cache_mla(
+ k_c_normed,
+ k_pe,
+ kv_cache,
+ attn_metadata.slot_mapping.flatten(),
+ kv_cache_dtype=self.kv_cache_dtype,
+ scale=self._k_scale,
+ )
+ attn_out = self.impl.forward_mqa(q, kv_cache, attn_metadata, self)
+ output = torch.empty(
+ output.shape[0],
+ self.num_heads,
+ self.v_head_dim,
+ device=q.device,
+ dtype=q.dtype,
+ )
+ output[:num_actual_toks] = self.impl._v_up_proj(attn_out)
+
+ return output.view(output.shape[0], self.v_head_dim * self.num_heads)
+
+
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading(act_dtype)
@@ -543,15 +761,21 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
+
@maybe_transfer_kv_layer
def unified_mla_attention(
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
+ kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
- attn_metadata, self, kv_cache = get_attention_context(layer_name)
- output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata)
+ # kv_cache_dummy_dep is not used but accepting it creates a data dependency
+ # that ensures torch.compile preserves ordering between KV cache update and
+ # attention forward.
+ del kv_cache_dummy_dep
+ attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
+ output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata)
return output
@@ -561,6 +785,7 @@ def unified_mla_attention_fake(
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
layer_name: str,
+ kv_cache_dummy_dep: torch.Tensor | None = None,
) -> torch.Tensor:
return torch.empty_like(q).contiguous()
@@ -574,6 +799,56 @@ direct_register_custom_op(
)
+def unified_mla_kv_cache_update(
+ kv_c_normed: torch.Tensor,
+ k_pe: torch.Tensor,
+ layer_name: str,
+ kv_cache_dtype: str,
+ k_scale: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Returns a dummy that is passed to unified_attention to signal a side effect and
+ the data dependency between them to ensure torch.compile preserves ordering.
+ """
+ forward_context = get_forward_context()
+ attn_layer = forward_context.no_compile_layers[layer_name]
+ kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+
+ slot_mapping = forward_context.slot_mapping
+ assert isinstance(slot_mapping, dict), (
+ f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
+ )
+ layer_slot_mapping = slot_mapping.get(layer_name)
+ if layer_slot_mapping is not None:
+ attn_layer.impl.do_kv_cache_update(
+ kv_c_normed,
+ k_pe,
+ kv_cache,
+ layer_slot_mapping,
+ kv_cache_dtype,
+ k_scale,
+ )
+
+ return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+def unified_mla_kv_cache_update_fake(
+ kv_c_normed: torch.Tensor,
+ k_pe: torch.Tensor,
+ layer_name: str,
+ kv_cache_dtype: str,
+ k_scale: torch.Tensor,
+) -> torch.Tensor:
+ return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype)
+
+
+direct_register_custom_op(
+ op_name="unified_mla_kv_cache_update",
+ op_func=unified_mla_kv_cache_update,
+ fake_impl=unified_mla_kv_cache_update_fake,
+)
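+
+
+# Usage note (mirrors the non-direct-call path in MLAAttention.forward above,
+# shown here only as an illustrative sketch): the dummy tensor returned by the
+# cache-update op is threaded into the attention op so torch.compile cannot
+# reorder the KV cache write past the attention read:
+#
+# dep = torch.ops.vllm.unified_mla_kv_cache_update(
+# kv_c_normed, k_pe, layer_name, kv_cache_dtype, k_scale
+# )
+# torch.ops.vllm.unified_mla_attention_with_output(
+# q, kv_c_normed, k_pe, output, layer_name, kv_cache_dummy_dep=dep
+# )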
+
+
@maybe_transfer_kv_layer
def unified_mla_attention_with_output(
q: torch.Tensor,
@@ -583,10 +858,14 @@ def unified_mla_attention_with_output(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
+ kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
- attn_metadata, self, kv_cache = get_attention_context(layer_name)
- self.impl.forward(
- self,
+ # kv_cache_dummy_dep is not used but accepting it creates a data dependency
+ # that ensures torch.compile preserves ordering between KV cache update and
+ # attention forward.
+ del kv_cache_dummy_dep
+ attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name)
+ layer.forward_impl(
q,
kv_c_normed,
k_pe,
@@ -606,6 +885,7 @@ def unified_mla_attention_with_output_fake(
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
+ kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
return
@@ -637,20 +917,23 @@ class QueryLenSupport(Enum):
try:
- # from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
- # flash_attn_varlen_func,
- # )
- from ixformer.contrib.vllm_flash_attn import (
- flash_attn_varlen_func,
- merge_attn_states,
- )
-
+ from ixformer.contrib.vllm_flash_attn import (
+ flash_attn_varlen_func,
+ merge_attn_states,
+ )
is_vllm_fa = True
except ImportError:
- # For rocm use upstream flash attention
- if current_platform.is_rocm():
- from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
is_vllm_fa = False
+ flash_attn_varlen_func = None # type: ignore[assignment]
+ # On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead.
+ # On CUDA, vllm_flash_attn should always be available (built with vLLM),
+ # so we don't attempt the fallback there.
+ if current_platform.is_rocm():
+ try:
+ from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
+ except ImportError:
+ logger.debug(
+ "flash_attn not available on ROCm; "
+ "MLA models using TRITON_MLA will require flash_attn. "
+ "AITER_MLA backends use aiter kernels instead."
+ )
def dynamic_per_batched_tensor_quant(
@@ -667,7 +950,10 @@ def dynamic_per_batched_tensor_quant(
logger = init_logger(__name__)
-@CustomOp.register("mla_decode_concat_quant_fp8")
+@CustomOp.register(
+ "mla_decode_concat_quant_fp8",
+ dynamic_arg_dims={"decode_ql_nope": 0, "decode_q_pe": 0},
+)
class _DecodeConcatQuantFP8(QuantFP8):
"""
QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before
@@ -704,7 +990,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_deq
import ixformer.inference.functions as ixf_ops
import numpy as np
-
class MLACommonBackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -775,6 +1060,7 @@ class MLACommonPrefillMetadata:
query_seq_lens: torch.Tensor | None = None
workspace_buffer: torch.Tensor | None = None
q_data_type: torch.dtype | None = None
+ output_dtype: torch.dtype | None = None
@dataclass
@@ -870,6 +1156,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
+@functools.cache
def use_flashinfer_prefill() -> bool:
# For blackwell default to flashinfer prefill if it's available since
# it is faster than FA2.
@@ -887,6 +1174,7 @@ def use_flashinfer_prefill() -> bool:
return is_deepseek_r1_mla_compatible(vllm_config)
+@functools.cache
def use_cudnn_prefill() -> bool:
from vllm.config import get_current_vllm_config
@@ -899,6 +1187,7 @@ def use_cudnn_prefill() -> bool:
)
+@functools.cache
def use_trtllm_ragged_deepseek_prefill() -> bool:
"""Check if TRT-LLM ragged DeepSeek prefill should be used."""
from vllm.config import get_current_vllm_config
@@ -935,6 +1224,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims:
)
+@functools.cache
+def backend_supports_prefill_query_quantization() -> bool:
+ """Check if the selected MLA backend supports prefill query quantization.
+
+ Currently supported backends:
+ - FlashInfer prefill
+ - TRT-LLM ragged DeepSeek prefill
+
+ Not supported:
+ - cuDNN Prefill
+ - FlashAttention
+ - Non-GB200 devices (FP8 prefill requires device capability 100)
+ """
+ # FP8 prefill query quantization requires GB200 (device capability 100)
+ # for the necessary FP8 kernels at the moment.
+ if not current_platform.is_device_capability_family(100):
+ return False
+
+ return use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill()
+
+
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
@@ -947,7 +1257,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
# If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
# speculative decoding is enabled.
- query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
+ query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
# The threshold for reordering the batch into decode and prefill requests.
# If > 1, the batch will be reordered such that requests with
@@ -956,7 +1266,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# when speculative decoding is enabled.
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = \
- AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+ AttentionCGSupport.UNIFORM_BATCH
@staticmethod
def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
@@ -989,6 +1299,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
return chunked_prefill_workspace_size
+ @staticmethod
+ def determine_prefill_query_data_type(
+ vllm_config: VllmConfig,
+ model_dtype: torch.dtype,
+ ) -> torch.dtype:
+ """
+ Determine the query data type for prefill queries.
+ Return FP8 dtype if cache is FP8 and prefill query quantization
+ is enabled, else model dtype.
+ """
+ use_fp8 = (
+ vllm_config.cache_config.cache_dtype.startswith("fp8")
+ and vllm_config.attention_config.use_prefill_query_quantization
+ and backend_supports_prefill_query_quantization()
+ )
+
+ if use_fp8:
+ fp8_dtype = current_platform.fp8_dtype()
+ logger.info_once(
+ "FP8 prefill attention enabled: query data type is FP8", scope="local"
+ )
+ return fp8_dtype
+ elif vllm_config.attention_config.use_prefill_query_quantization:
+ logger.info_once(
+ "Unable to perform FP8 prefill attention when"
+ " use_prefill_query_quantization is enabled. Please"
+ " ensure that --kv-cache-dtype is set to fp8 and your prefill"
+ " backend is compatible with FP8 attention.",
+ scope="local",
+ )
+ return model_dtype
+
+ return model_dtype
+
def __init__(
self,
kv_cache_spec: AttentionSpec,
@@ -1006,12 +1350,24 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.compilation_config = vllm_config.compilation_config
+ self.decode_use_graph = (
+ vllm_config.compilation_config.cudagraph_mode.decode_use_graph()
+ )
+ self.use_full_cuda_graph = (
+ vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs()
+ )
self.vllm_config = vllm_config
self.device = device
self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
self.aot_schedule = current_platform.is_cuda()
+
+ self.kv_cache_spec = kv_cache_spec
+ self.q_data_type = self.determine_prefill_query_data_type(
+ vllm_config, self.model_config.dtype
+ )
+
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -1052,7 +1408,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.chunked_prefill_workspace_size,
self.model_config.get_head_size(),
),
- dtype=self.model_config.dtype,
+ dtype=self.q_data_type,
device=device,
)
@@ -1106,6 +1462,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
f"reorder_batch_threshold must be 1 when query_len_support is "
f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
)
+
+ # Spec-decode expands block_table / seq_lens via index_select; CUDA graph
+ # replay expects stable tensor storage. Copy expanded rows into these
+ # buffers when decode or full CUDA graph is enabled.
+ self._spec_decode_expand_block_table_buf: torch.Tensor | None = None
+ self._spec_decode_expand_seq_lens_buf: torch.Tensor | None = None
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
@@ -1162,7 +1524,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
- q_data_type=self.model_config.dtype,
+ q_data_type=self.q_data_type,
+ o_data_type=prefill.output_dtype,
)
# Prepare context prefills
@@ -1181,7 +1544,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
sm_scale=self._global_hyperparameters.sm_scale,
window_left=self._global_hyperparameters.window_left,
logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
- q_data_type=self.model_config.dtype,
+ q_data_type=self.q_data_type,
+ o_data_type=prefill.output_dtype,
)
prefill.prefill_main = self._fi_prefill_main
@@ -1195,16 +1559,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int,
- max_decode_seq_len: int,
- use_cuda_graph: bool,
dcp_tot_seq_lens_device: torch.Tensor | None,
+ max_decode_seq_len: int,
+ use_cuda_graph: bool,
) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
+ dcp_tot_seq_lens=dcp_tot_seq_lens_device,
max_decode_seq_len=max_decode_seq_len,
use_cuda_graph=use_cuda_graph,
- dcp_tot_seq_lens=dcp_tot_seq_lens_device,
)
def build_for_cudagraph_capture(
@@ -1246,7 +1611,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
seq_lens = common_attn_metadata.seq_lens
dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
-
+ seq_lens_np = common_attn_metadata.seq_lens_np
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
@@ -1440,6 +1805,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
query_start_loc=prefill_query_start_loc,
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
+ output_dtype=self.model_config.dtype,
+ q_data_type=self.q_data_type,
)
if self._use_cudnn_prefill:
@@ -1471,17 +1838,135 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
max_seq_len = (
(max_seq_len + num_partitions - 1) // num_partitions
) * self.cp_kv_cache_interleave_size
+
+ decode_block_table = block_table_tensor[:num_decodes, ...]
+ decode_seq_lens_np = seq_lens_np[:num_decodes]
+ decode_seq_lens_device = seq_lens[:num_decodes]
+ decode_query_start_loc_cpu = query_start_loc_cpu[: num_decodes + 1]
+ decode_query_start_loc = query_start_loc[: num_decodes + 1]
+ max_decode_seq_len = np.max(decode_seq_lens_np).item()
+
+ speculative_config = self.vllm_config.speculative_config
+ use_spec_decode = (
+ speculative_config is not None
+ and speculative_config.num_speculative_tokens is not None
+ and speculative_config.num_speculative_tokens > 0
+ )
+ # For non-DCP speculative decode, expand decode metadata from
+ # request-level rows to token-level rows so block_table/seq_lens
+ # align with num_decode_tokens.
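+ # Worked example (hypothetical numbers): two decode requests with
+ # per-request seq_lens [12, 23] and query lens [2, 3] (context lens
+ # [10, 20]) give req_ids [0, 0, 1, 1, 1]; the block table rows are
+ # repeated per token and the per-token seq_lens become
+ # [11, 12, 21, 22, 23].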
+ if (
+ use_spec_decode
+ and self.dcp_world_size == 1
+ and num_decode_tokens > num_decodes
+ ):
+ decode_lens_cpu = (
+ decode_query_start_loc_cpu[1:] - decode_query_start_loc_cpu[:-1]
+ )
+ decode_lens_device_long = decode_lens_cpu.to(
+ device=decode_block_table.device, dtype=torch.long
+ )
+ req_ids = torch.repeat_interleave(
+ torch.arange(
+ num_decodes, device=decode_block_table.device, dtype=torch.long
+ ),
+ decode_lens_device_long,
+ )
+ if req_ids.numel() != num_decode_tokens:
+ # Unexpected token/request expansion mismatch; keep the
+ # unexpanded request-level decode metadata below.
+ pass
+ else:
+ expanded_block_table = decode_block_table.index_select(
+ 0, req_ids
+ )
+
+ decode_query_start_loc_long = decode_query_start_loc.to(
+ device=expanded_block_table.device, dtype=torch.long
+ )
+ token_starts = torch.repeat_interleave(
+ decode_query_start_loc_long[:-1], decode_lens_device_long
+ )
+ token_pos_in_req = (
+ torch.arange(
+ num_decode_tokens,
+ device=expanded_block_table.device,
+ dtype=torch.long,
+ )
+ - token_starts
+ )
+
+ decode_seq_lens_device_long = decode_seq_lens_device.to(
+ device=decode_block_table.device, dtype=torch.long
+ )
+ context_lens_long = (
+ decode_seq_lens_device_long - decode_lens_device_long
+ )
+ decode_seq_lens_device_long = (
+ context_lens_long.index_select(0, req_ids)
+ + token_pos_in_req
+ + 1
+ )
+
+ decode_seq_lens_device = decode_seq_lens_device_long.to(
+ dtype=seq_lens.dtype
+ )
+
+ if self.decode_use_graph or self.use_full_cuda_graph:
+ n = expanded_block_table.shape[0]
+ ncols = expanded_block_table.shape[1]
+ max_rows = (
+ self.vllm_config.scheduler_config.max_num_batched_tokens
+ )
+ assert n <= max_rows, (
+ f"spec-decode expanded decode rows {n} > "
+ f"max_num_batched_tokens {max_rows}"
+ )
+ if (
+ self._spec_decode_expand_block_table_buf is None
+ or self._spec_decode_expand_block_table_buf.shape[0]
+ < max_rows
+ or self._spec_decode_expand_block_table_buf.shape[1]
+ != ncols
+ ):
+ self._spec_decode_expand_block_table_buf = torch.zeros(
+ max_rows,
+ ncols,
+ dtype=block_table_tensor.dtype,
+ device=self.device,
+ )
+ self._spec_decode_expand_seq_lens_buf = torch.zeros(
+ max_rows,
+ dtype=seq_lens.dtype,
+ device=self.device,
+ )
+ self._spec_decode_expand_block_table_buf[:n].copy_(
+ expanded_block_table
+ )
+ self._spec_decode_expand_seq_lens_buf[:n].copy_(
+ decode_seq_lens_device
+ )
+ decode_block_table = self._spec_decode_expand_block_table_buf[
+ :n
+ ]
+ decode_seq_lens_device = (
+ self._spec_decode_expand_seq_lens_buf[:n]
+ )
+ else:
+ decode_block_table = expanded_block_table
+ decode_seq_lens_cpu = decode_seq_lens_device.to(
+ device="cpu", dtype=seq_lens.dtype
+ )
+ max_decode_seq_len = torch.max(decode_seq_lens_cpu).item()
decode_metadata = self._build_decode(
- block_table_tensor=block_table_tensor[:num_decodes, ...],
- seq_lens_device=seq_lens[:num_decodes],
+ block_table_tensor=decode_block_table,
+ seq_lens_device=decode_seq_lens_device,
max_seq_len=max_seq_len,
- query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
- query_start_loc_device=query_start_loc[: num_decodes + 1],
+ query_start_loc_cpu=decode_query_start_loc_cpu,
+ query_start_loc_device=decode_query_start_loc,
num_decode_tokens=num_decode_tokens,
- max_decode_seq_len=torch.max(seq_lens[:num_decodes]).item(),
- use_cuda_graph=False,
dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
+ max_decode_seq_len=max_decode_seq_len,
+ use_cuda_graph=num_prefills == 0 and self.decode_use_graph,
)
attn_metadata = self.metadata_cls(
@@ -1580,9 +2065,7 @@ def reorg_kvcache(
return reorganized_kv_c_normed, reorganized_k_pe
-# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
-# and MLACommonImpl -> MLACommonDenseImpl or somthing like that
-class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
+class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
@@ -1608,9 +2091,8 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
- indexer=None,
+ indexer: object | None = None,
q_pad_num_heads: int | None = None,
- rotary_emb = None,
) -> None:
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported for MLA")
@@ -1630,23 +2112,321 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
self.kv_b_proj = kv_b_proj
self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads
- self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
- self.rotary_emb = rotary_emb
+ self.supports_quant_query_input = True
- # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
- self.is_aiter_triton_fp4_bmm_enabled = (
- rocm_aiter_ops.is_fp4bmm_enabled()
- and self.kv_b_proj.weight.dtype == torch.bfloat16
+ # Use flashinfer's optimized concat_mla_k kernel when available.
+ # The kernel is optimized for DeepSeek V3 dimensions:
+ # num_heads=128, nope_dim=128, rope_dim=64
+ self._use_flashinfer_concat_mla_k = (
+ has_flashinfer()
+ and (self.num_heads == 128)
+ and (self.qk_nope_head_dim == 128)
+ and (self.qk_rope_head_dim == 64)
)
- def process_weights_after_loading(self, act_dtype: torch.dtype):
- # we currently do not have quantized bmm's which are needed for
- # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
- # the bmm's in 16-bit, the extra memory overhead of this is fairly low
- # kv_b_proj_weight = get_and_maybe_dequant_weights(
- # self.kv_b_proj, out_dtype=act_dtype
- # ).T
+ if use_trtllm_ragged_deepseek_prefill():
+ logger.info_once(
+ "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
+ )
+ self._run_prefill_context_chunk = (
+ self._run_prefill_context_chunk_trtllm_ragged
+ )
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
+ self._pad_v = False
+ elif use_flashinfer_prefill():
+ logger.info_once("Using FlashInfer prefill for MLA", scope="local")
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
+ self._pad_v = False
+ elif use_cudnn_prefill():
+ logger.info_once("Using CUDNN prefill for MLA", scope="local")
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
+ self._pad_v = False
+ else: # Use FlashAttention
+ if flash_attn_varlen_func is None:
+ raise RuntimeError(
+ "MLA attention requires FlashAttention but it is not "
+ "available. Please install flash_attn or use "
+ "--attention-backend ROCM_AITER_MLA."
+ )
+ logger.info_once("Using FlashAttention prefill for MLA", scope="local")
+ self.positions = None
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
+ # Handle the differences between the flash_attn_varlen from
+ # flash_attn and the one from vllm_flash_attn. The former is used on
+ # ROCm and the latter has an additional parameter to control
+ # FA2 vs FA3
+ self.flash_attn_varlen_func = flash_attn_varlen_func
+ self.vllm_flash_attn_version = get_flash_attn_version()
+ # if self.vllm_flash_attn_version is not None:
+ # self.flash_attn_varlen_func = functools.partial(
+ # flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
+ # )
+
+ # For MLA the v head dim is smaller than qk head dim so we pad out
+ # v with 0s to match the qk head dim for attention backends that do
+ # not support different headdims
+ # We don't need to pad V if we are on a hopper system with FA3
+ # self._pad_v = self.vllm_flash_attn_version is None or not (
+ # self.vllm_flash_attn_version == 3
+ # and current_platform.get_device_capability()[0] == 9
+ # )
+ self._pad_v = False
+
+ self.dcp_world_size: int = -1
+
+ self.cp_kv_cache_interleave_size: int = (
+ get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
+ )
+
+ def _flash_attn_varlen_diff_headdims(
+ self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
+ ):
+ maybe_padded_v = v
+ if self._pad_v:
+ maybe_padded_v = torch.nn.functional.pad(
+ v, [0, q.shape[-1] - v.shape[-1]], value=0
+ )
+
+ if is_vllm_fa:
+ kwargs["return_softmax_lse"] = return_softmax_lse
+ else:
+ # ROCm leverages the upstream flash_attn, which takes a parameter
+ # called "return_attn_probs" instead of return_softmax_lse
+ kwargs["return_attn_probs"] = return_softmax_lse
+ if vllm_is_batch_invariant():
+ kwargs["num_splits"] = 1
+
+ attn_out = self.flash_attn_varlen_func(
+ q=q,
+ k=k,
+ v=maybe_padded_v,
+ softmax_scale=softmax_scale,
+ **kwargs,
+ )
+
+ # Unpack the output if there is multiple results
+ lse = None
+ if isinstance(attn_out, tuple):
+ attn_out, lse = attn_out[0], attn_out[1]
+
+ # Remain consistent with old `flash_attn_varlen_func` where there
+ # is only one output tensor if `return_softmax_lse` is False.
+ if return_softmax_lse:
+ return attn_out, lse
+ return attn_out
+
+ def _run_prefill_new_tokens_fa(
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse, out
+ ):
+ return self._flash_attn_varlen_diff_headdims(
+ q=q,
+ k=k,
+ v=v,
+ cu_seqlens_q=prefill.query_start_loc,
+ cu_seqlens_k=prefill.query_start_loc,
+ max_seqlen_q=prefill.max_query_len,
+ max_seqlen_k=prefill.max_query_len,
+ softmax_scale=self.scale,
+ causal=True,
+ return_softmax_lse=return_softmax_lse,
+ out=out,
+ )
+
+ def _run_prefill_new_tokens_fi(
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+ ):
+ assert isinstance(prefill, FlashInferPrefillMetadata)
+ assert prefill.prefill_main is not None
+
+ ret = prefill.prefill_main.run(
+ q=q,
+ k=k,
+ v=v,
+ return_lse=return_softmax_lse,
+ )
+
+ if isinstance(ret, tuple):
+ return ret[0], ret[1].transpose(0, 1).contiguous()
+ return ret
+
+ def _run_prefill_new_tokens_cudnn(
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+ ):
+ assert isinstance(prefill, CudnnPrefillMetadata)
+ assert prefill.query_seq_lens is not None
+ from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
+
+ output, lse = cudnn_batch_prefill_with_kv_cache(
+ q=q,
+ k_cache=k,
+ v_cache=v,
+ scale=self.scale,
+ workspace_buffer=prefill.cudnn_workspace,
+ max_token_per_sequence=prefill.max_query_len,
+ max_sequence_kv=prefill.max_query_len,
+ actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+ actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
+ causal=True,
+ # Do not support False for now
+ return_lse=True,
+ # Indicates actual_seq_lens are on GPU or CPU.
+ is_cuda_graph_compatible=True,
+ )
+ if return_softmax_lse:
+ return output, lse
+ return output
+
+ def _run_prefill_context_chunk_fa(
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v, out
+ ):
+ assert prefill.chunked_context is not None
+ return self._flash_attn_varlen_diff_headdims(
+ q=q,
+ k=k,
+ v=v,
+ cu_seqlens_q=prefill.query_start_loc,
+ cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
+ max_seqlen_q=prefill.max_query_len,
+ max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
+ softmax_scale=self.scale,
+ causal=False, # Context is unmasked
+ return_softmax_lse=True,
+ out=out,
+ )
+
+ def _run_prefill_context_chunk_fi(
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+ ):
+ assert isinstance(prefill, FlashInferPrefillMetadata)
+
+ attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
+ q=q,
+ k=k,
+ v=v,
+ return_lse=True,
+ )
+
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
+ return attn_out, lse.transpose(0, 1).contiguous()
+
+ def _run_prefill_context_chunk_cudnn(
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+ ):
+ assert isinstance(prefill, CudnnPrefillMetadata)
+ assert prefill.chunked_context is not None
+ assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+ assert prefill.query_seq_lens is not None
+ from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
+
+ return cudnn_batch_prefill_with_kv_cache(
+ q=q,
+ k_cache=k,
+ v_cache=v,
+ scale=self.scale,
+ workspace_buffer=prefill.cudnn_workspace,
+ max_token_per_sequence=prefill.max_query_len,
+ max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
+ actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
+ actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view(
+ -1, 1, 1, 1
+ ),
+ causal=False,
+ return_lse=True,
+ # Indicates actual_seq_lens are on GPU or CPU.
+ is_cuda_graph_compatible=True,
+ )
+
+ def _run_prefill_new_tokens_trtllm_ragged(
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
+ ):
+ """TRT-LLM ragged attention for new tokens (causal)."""
+ from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+ assert prefill.query_seq_lens is not None
+ assert prefill.workspace_buffer is not None
+ # allocate BF16 / FP16 output tensor for TRT-LLM ragged attention
+ out = torch.empty(
+ q.shape[0],
+ q.shape[1],
+ v.shape[2],
+ device=q.device,
+ dtype=prefill.output_dtype,
+ )
+
+ ret = trtllm_ragged_attention_deepseek(
+ query=q,
+ key=k,
+ value=v,
+ workspace_buffer=prefill.workspace_buffer,
+ seq_lens=prefill.query_seq_lens,
+ max_q_len=prefill.max_query_len,
+ max_kv_len=prefill.max_query_len,
+ bmm1_scale=self.scale,
+ bmm2_scale=1.0,
+ o_sf_scale=1.0,
+ batch_size=prefill.query_seq_lens.shape[0],
+ window_left=-1,
+ cum_seq_lens_q=prefill.query_start_loc,
+ cum_seq_lens_kv=prefill.query_start_loc,
+ enable_pdl=False,
+ is_causal=True,
+ return_lse=return_softmax_lse,
+ out=out,
+ )
+
+ if isinstance(ret, tuple):
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
+ return ret[0], ret[1].transpose(0, 1).contiguous()
+ return ret
+
+ def _run_prefill_context_chunk_trtllm_ragged(
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
+ ):
+ """TRT-LLM ragged attention for context chunks (non-causal)."""
+ from flashinfer.prefill import trtllm_ragged_attention_deepseek
+
+ assert prefill.chunked_context is not None
+ assert prefill.chunked_context.seq_lens[chunk_idx] is not None
+ assert prefill.workspace_buffer is not None
+
+ out = torch.zeros(
+ q.shape[0],
+ q.shape[1],
+ v.shape[2],
+ device=q.device,
+ dtype=prefill.output_dtype,
+ )
+ prefill.workspace_buffer.fill_(0)
+
+ attn_out, lse = trtllm_ragged_attention_deepseek(
+ query=q,
+ key=k,
+ value=v,
+ workspace_buffer=prefill.workspace_buffer,
+ seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
+ max_q_len=prefill.max_query_len,
+ max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
+ bmm1_scale=self.scale,
+ bmm2_scale=1.0,
+ o_sf_scale=1.0,
+ batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
+ window_left=-1,
+ cum_seq_lens_q=prefill.query_start_loc,
+ cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
+ enable_pdl=False,
+ is_causal=False,
+ return_lse=True,
+ out=out,
+ )
+
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
+ return attn_out, lse.transpose(0, 1).contiguous()
+
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
@@ -1657,15 +2437,17 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
)
def get_and_maybe_dequant_weights(layer: LinearBase):
- if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+ if layer.quant_method is not None and not isinstance(
+ layer.quant_method, UnquantizedLinearMethod
+ ):
# we already prepare a dequant weight for GGUF, skip it.
if layer.quant_method.__class__.__name__ == "GGUFLinearMethod":
weight = layer.weight.clone()
del layer.weight
return weight
if layer.quant_method.__class__.__name__ == "AWQMarlinLinearMethod":
- from ixformer.inference.functions import ref_wui4a16
- return ref_wui4a16(None, layer.qweight, layer.scales, layer.qzeros, None, layer.quant_method.quant_config.group_size, only_return_weight=True)
+ from ixformer.inference.functions import wui4a16
+ return wui4a16(
+ None,
+ layer.qweight,
+ layer.scales,
+ layer.qzeros,
+ None,
+ layer.quant_method.quant_config.group_size,
+ only_return_weight=True,
+ )
# for W8A8, we directly dequantize it here to avoiding quantization errors
if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" and not layer.scheme.is_static_input_scheme:
quant_weight = layer.weight.T # output, input
@@ -1682,27 +2464,12 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
# standardize to (output, input)
return dequant_weights.T
return layer.weight
-
- # assert kv_b_proj_weight.shape == (
- # self.kv_lora_rank,
- # self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
- # ), (
- # f"{kv_b_proj_weight.shape=}, "
- # f"{self.kv_lora_rank=}, "
- # f"{self.num_heads=}, "
- # f"{self.qk_nope_head_dim=}, "
- # f"{self.v_head_dim=}"
- # )
- # kv_b_proj_weight = kv_b_proj_weight.view(
- # self.kv_lora_rank,
- # self.num_heads,
- # self.qk_nope_head_dim + self.v_head_dim,
- # )
-
- # W_UK, W_UV = kv_b_proj_weight.split(
- # [self.qk_nope_head_dim, self.v_head_dim], dim=-1
- # )
+
back_to_vllm = False
+ weight_dtype = get_layer_weight(self.kv_b_proj).dtype
+
+ # When the customized forward path is used, the specific data types and
+ # shapes do not matter here; we only reset the _v_up_proj_and_o_proj and
+ # _q_proj_and_k_up_proj functions so they produce correct results.
if envs.VLLM_MLA_CUSTOMIZE:
layer = self.kv_b_proj
quant_method = layer.quant_method.__class__.__name__
@@ -1866,19 +2633,18 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
back_to_vllm = True
if back_to_vllm:
- # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
+ # we currently do not have quantized bmm's which are needed for
+ # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
- self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
- ), (
- f"{kv_b_proj_weight.shape=}, "
- f"{self.kv_lora_rank=}, "
- f"{self.num_heads=}, "
- f"{self.qk_nope_head_dim=}, "
- f"{self.v_head_dim=}"
- )
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
+ f"{kv_b_proj_weight.shape=}, "
+ f"{self.kv_lora_rank=}, "
+ f"{self.num_heads=}, "
+ f"{self.qk_nope_head_dim=}, "
+ f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
@@ -1886,478 +2652,59 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
)
W_UK, W_UV = kv_b_proj_weight.split(
- [self.qk_nope_head_dim, self.v_head_dim], dim=-1
- )
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- if self.is_aiter_triton_fp8_bmm_enabled:
- W_K = W_UK.transpose(0, 1) # 16 512 128
- W_V = W_UV.permute(1, 2, 0) # 16 128 512
- self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
- W_K, dtype=current_platform.fp8_dtype()
- )
- self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
- W_V, dtype=current_platform.fp8_dtype()
- )
-
- # The kernel operates on non-padded inputs. Hence, pre-compiling
- # triton kernel to avoid runtime compilation for unseen batch sizes
- # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
- # On DS-R1, this step adds roughly 50s to the model loading time.
- max_batch_size = 1024 # [ToDo] Find the optimal upper limit
- pre_compilation_list = list(range(1, max_batch_size + 1))
- if is_global_first_rank():
- pre_compilation_list = tqdm(
- pre_compilation_list,
- desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
- total=max_batch_size,
- )
-
- for m in pre_compilation_list:
- x = torch.empty(
- (self.W_K.shape[0], m, self.W_K.shape[2]),
- dtype=torch.bfloat16,
- device=self.W_K.device,
- )
- rocm_aiter_ops.triton_fp8_bmm(
- x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
- )
-
- x = torch.empty(
- (self.W_V.shape[0], m, self.W_V.shape[2]),
- dtype=torch.bfloat16,
- device=self.W_V.device,
- )
- rocm_aiter_ops.triton_fp8_bmm(
- x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
- )
- else:
- # # Convert from (L, N, V) to (N, L, V)
- # self.W_UV = W_UV.transpose(0, 1)
- # # Convert from (L, N, P) to (N, P, L)
- # self.W_UK_T = W_UK.permute(1, 2, 0)
- self.W_UV = W_UV
- self.W_UK = W_UK
-
- # # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported
+ # # Convert from (L, N, V) to (N, L, V)
+ # self.W_UV = W_UV.transpose(0, 1)
+ # # Convert from (L, N, P) to (N, P, L)
+ # self.W_UK_T = W_UK.permute(1, 2, 0)
+ self.W_UV = W_UV
+ self.W_UK = W_UK
+
+ def _v_up_proj(self, x: torch.Tensor):
+ # Convert from (B, N, L) to (N, B, L)
+ # x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+ # out = out.view(-1, self.num_heads, self.v_head_dim)
# if self.is_aiter_triton_fp4_bmm_enabled:
- # from vllm.model_executor.layers.quantization.quark.utils import (
- # quark_quantize_weight_to_mxfp4,
- # )
-
- # self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK)
- # # Convert from (L, N, P) to (N, L, P)
- # self.W_K = self.W_K.transpose(0, 1)
- # self.W_K_scale = self.W_K_scale.transpose(0, 1)
-
- # self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4(
- # W_UV.permute(1, 2, 0)
+ # out = rocm_aiter_ops.batched_gemm_a16wfp4(
+ # x,
+ # self.W_V,
+ # self.W_V_scale,
+ # out,
+ # transpose_bm=True,
+ # prequant=True,
+ # y_scale=None,
# )
+ # x = out.view(-1, self.num_heads * self.v_head_dim)
# elif self.is_aiter_triton_fp8_bmm_enabled:
- # W_K = W_UK.transpose(0, 1) # 16 512 128
- # W_V = W_UV.permute(1, 2, 0) # 16 128 512
- # self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
- # W_K, dtype=current_platform.fp8_dtype()
+ # # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
+ # x = rocm_aiter_ops.triton_fp8_bmm(
+ # x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
# )
- # self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
- # W_V, dtype=current_platform.fp8_dtype()
- # )
-
- # # The kernel operates on non-padded inputs. Hence, pre-compiling
- # # triton kernel to avoid runtime compilation for unseen batch sizes
- # # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
- # # On DS-R1, this step adds roughly 50s to the model loading time.
- # max_batch_size = 1024 # [ToDo] Find the optimal upper limit
- # pre_compilation_list = list(range(1, max_batch_size + 1))
- # if is_global_first_rank():
- # pre_compilation_list = tqdm(
- # pre_compilation_list,
- # desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
- # total=max_batch_size,
- # )
-
- # for m in pre_compilation_list:
- # x = torch.empty(
- # (self.W_K.shape[0], m, self.W_K.shape[2]),
- # dtype=torch.bfloat16,
- # device=self.W_K.device,
- # )
- # rocm_aiter_ops.triton_fp8_bmm(
- # x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
- # )
-
- # x = torch.empty(
- # (self.W_V.shape[0], m, self.W_V.shape[2]),
- # dtype=torch.bfloat16,
- # device=self.W_V.device,
- # )
- # rocm_aiter_ops.triton_fp8_bmm(
- # x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
- # )
# else:
- # # Convert from (L, N, V) to (N, L, V)
- # self.W_UV = W_UV.transpose(0, 1)
- # # Convert from (L, N, P) to (N, P, L)
- # self.W_UK_T = W_UK.permute(1, 2, 0)
+ # # Convert from (B, N * V) to (N, B, V)
+ # out = out.transpose(0, 1)
- def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
- # Convert from (B, N, L) to (N, B, L)
- x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
- out = out.view(-1, self.num_heads, self.v_head_dim)
- if self.is_aiter_triton_fp4_bmm_enabled:
- out = rocm_aiter_ops.batched_gemm_a16wfp4(
- x,
- self.W_V,
- self.W_V_scale,
- out,
- transpose_bm=True,
- prequant=True,
- y_scale=None,
- )
- x = out.view(-1, self.num_heads * self.v_head_dim)
- elif self.is_aiter_triton_fp8_bmm_enabled:
- # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
- x = rocm_aiter_ops.triton_fp8_bmm(
- x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
- )
- else:
- # Convert from (B, N * V) to (N, B, V)
- out = out.transpose(0, 1)
+ # # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+ # torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
- # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
- torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
+ # # Convert from (N, B, V) to (B, N * V)
+ # out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
- # Convert from (N, B, V) to (B, N * V)
- out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+ # # Adjust output buffer shape back to the original (B, N * V)
+ # N, B, V = out.shape
+ # out.resize_((B, N * V))
+ # out.copy_(out_new) # Copy result
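+ # x: (B, N, L) attention output in the latent space; W_UV: (L, N, V).
+ # The einsum applies the per-head V up-projection and returns (B, N, V).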
+ return torch.einsum("bnl,lnv->bnv", x, self.W_UV)
+ def _k_up_proj(self, q_nope):
+ # # Convert from (B, N, P) to (N, B, P)
+ # q_nope = q_nope.transpose(0, 1)
+ # # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+ # ql_nope = torch.bmm(q_nope, self.W_UK_T)
+ # # Convert from (N, B, L) to (B, N, L)
+ # return ql_nope.transpose(0, 1), q_pe
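+ # q_nope: (B, N, P) nope part of the query; W_UK: (L, N, P). The einsum
+ # folds the K up-projection into the query, producing latent-space queries
+ # of shape (B, N, L) for MQA against the compressed KV cache.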
+ return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank)
- # Adjust output buffer shape back to the original (B, N * V)
- N, B, V = out.shape
- out.resize_((B, N * V))
- out.copy_(out_new) # Copy result
-
-
-class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
- """
- NOTE: Please read the comment at the top of the file before trying to
- understand this class
- """
-
- def __init__(self, *args, **kwargs) -> None:
- super().__init__(*args, **kwargs)
-
- if use_trtllm_ragged_deepseek_prefill():
- logger.info_once(
- "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
- )
- self._run_prefill_context_chunk = (
- self._run_prefill_context_chunk_trtllm_ragged
- )
- self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
- self._pad_v = False
- elif use_flashinfer_prefill():
- logger.info_once("Using FlashInfer prefill for MLA", scope="local")
- self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
- self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
- self._pad_v = False
- elif use_cudnn_prefill():
- logger.info_once("Using CUDNN prefill for MLA", scope="local")
- self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
- self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
- self._pad_v = False
- else: # Use FlashAttention
- logger.info_once("Using FlashAttention prefill for MLA", scope="local")
- self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
- self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
-
- # Handle the differences between the flash_attn_varlen from
- # flash_attn and the one from vllm_flash_attn. The former is used on
- # RoCM and the latter has an additional parameter to control
- # FA2 vs FA3
- self.flash_attn_varlen_func = flash_attn_varlen_func
- self.vllm_flash_attn_version = get_flash_attn_version()
- # if self.vllm_flash_attn_version is not None:
- # self.flash_attn_varlen_func = functools.partial(
- # flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
- # )
-
- # For MLA the v head dim is smaller than qk head dim so we pad out
- # v with 0s to match the qk head dim for attention backends that do
- # not support different headdims
- # We don't need to pad V if we are on a hopper system with FA3
- # device_capability = current_platform.get_device_capability()
- # self._pad_v = self.vllm_flash_attn_version is None or not (
- # self.vllm_flash_attn_version == 3
- # and device_capability is not None
- # and device_capability[0] == 9
- # )
- self._pad_v = False
-
- self.dcp_world_size: int = -1
-
- self.chunked_prefill_workspace_size = (
- MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
- get_current_vllm_config()
- )
- )
- self.cp_kv_cache_interleave_size: int = (
- get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
- )
- self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
- static=True,
- group_shape=GroupShape.PER_TENSOR,
- compile_native=True,
- )
-
- def _flash_attn_varlen_diff_headdims(
- self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
- ):
- maybe_padded_v = v
- if self._pad_v:
- maybe_padded_v = torch.nn.functional.pad(
- v, [0, q.shape[-1] - v.shape[-1]], value=0
- )
-
- if is_vllm_fa:
- kwargs["return_softmax_lse"] = return_softmax_lse
- else:
- # ROCm leverages the upstream flash_attn, which takes a parameter
- # called "return_attn_probs" instead of return_softmax_lse
- kwargs["return_attn_probs"] = return_softmax_lse
- if vllm_is_batch_invariant():
- kwargs["num_splits"] = 1
-
- attn_out = self.flash_attn_varlen_func(
- q=q,
- k=k,
- v=maybe_padded_v,
- softmax_scale=softmax_scale,
- **kwargs,
- )
-
- # Unpack the output if there is multiple results
- lse = None
- if isinstance(attn_out, tuple):
- attn_out, lse = attn_out[0], attn_out[1]
-
- # Remain consistent with old `flash_attn_varlen_func` where there
- # is only one output tensor if `return_softmax_lse` is False.
- if return_softmax_lse:
- return attn_out, lse
- return attn_out
-
- def _run_prefill_new_tokens_fa(
- self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse, out
- ):
- return self._flash_attn_varlen_diff_headdims(
- q=q,
- k=k,
- v=v,
- cu_seqlens_q=prefill.query_start_loc,
- cu_seqlens_k=prefill.query_start_loc,
- max_seqlen_q=prefill.max_query_len,
- max_seqlen_k=prefill.max_query_len,
- softmax_scale=self.scale,
- causal=True,
- return_softmax_lse=return_softmax_lse,
- out=out
- )
-
- def _run_prefill_new_tokens_fi(
- self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
- ):
- assert isinstance(prefill, FlashInferPrefillMetadata)
- assert prefill.prefill_main is not None
-
- ret = prefill.prefill_main.run(
- q=q,
- k=k,
- v=v,
- return_lse=return_softmax_lse,
- )
-
- if isinstance(ret, tuple):
- return ret[0], ret[1].transpose(0, 1).contiguous()
- return ret
-
- def _run_prefill_new_tokens_cudnn(
- self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
- ):
- assert isinstance(prefill, CudnnPrefillMetadata)
- assert prefill.query_seq_lens is not None
- from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
-
- output, lse = cudnn_batch_prefill_with_kv_cache(
- q=q,
- k_cache=k,
- v_cache=v,
- scale=self.scale,
- workspace_buffer=prefill.cudnn_workspace,
- max_token_per_sequence=prefill.max_query_len,
- max_sequence_kv=prefill.max_query_len,
- actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
- actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
- causal=True,
- # Do not support False for now
- return_lse=True,
- # Indicates actual_seq_lens are on GPU or CPU.
- is_cuda_graph_compatible=True,
- )
- if return_softmax_lse:
- return output, lse
- return output
-
- def _run_prefill_context_chunk_fa(
- self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v, out
- ):
- assert prefill.chunked_context is not None
- return self._flash_attn_varlen_diff_headdims(
- q=q,
- k=k,
- v=v,
- cu_seqlens_q=prefill.query_start_loc,
- cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
- max_seqlen_q=prefill.max_query_len,
- max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
- softmax_scale=self.scale,
- causal=False, # Context is unmasked
- return_softmax_lse=True,
- out=out
- )
-
- def _run_prefill_context_chunk_fi(
- self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
- ):
- assert isinstance(prefill, FlashInferPrefillMetadata)
-
- attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
- q=q,
- k=k,
- v=v,
- return_lse=True,
- )
-
- # Convert from (q_len, num_heads) to (num_heads, q_len)
- return attn_out, lse.transpose(0, 1).contiguous()
-
- def _run_prefill_context_chunk_cudnn(
- self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
- ):
- assert isinstance(prefill, CudnnPrefillMetadata)
- assert prefill.chunked_context is not None
- assert prefill.chunked_context.seq_lens[chunk_idx] is not None
- assert prefill.query_seq_lens is not None
- from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache
-
- return cudnn_batch_prefill_with_kv_cache(
- q=q,
- k_cache=k,
- v_cache=v,
- scale=self.scale,
- workspace_buffer=prefill.cudnn_workspace,
- max_token_per_sequence=prefill.max_query_len,
- max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
- actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
- actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view(
- -1, 1, 1, 1
- ),
- causal=False,
- return_lse=True,
- # Indicates actual_seq_lens are on GPU or CPU.
- is_cuda_graph_compatible=True,
- )
-
- def _v_up_proj(self, x):
- # Convert from (B, N, L) to (N, B, L)
- x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
- if is_rocm_aiter_fp8bmm_enabled():
- # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
- x = aiter_triton_fp8_bmm(
- x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
- )
- # Convert from (B, N, V) to (B, N * V)
- x = x.reshape(-1, self.num_heads * self.v_head_dim)
- else:
- # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
- x = torch.bmm(x, self.W_UV)
- # Convert from (N, B, V) to (B, N * V)
- x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
- return x
-
- def _run_prefill_new_tokens_trtllm_ragged(
- self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
- ):
- """TRT-LLM ragged attention for new tokens (causal)."""
- from flashinfer.prefill import trtllm_ragged_attention_deepseek
-
- assert prefill.query_seq_lens is not None
- assert prefill.workspace_buffer is not None
-
- ret = trtllm_ragged_attention_deepseek(
- query=q,
- key=k,
- value=v,
- workspace_buffer=prefill.workspace_buffer,
- seq_lens=prefill.query_seq_lens,
- max_q_len=prefill.max_query_len,
- max_kv_len=prefill.max_query_len,
- bmm1_scale=self.scale,
- bmm2_scale=1.0,
- o_sf_scale=1.0,
- batch_size=prefill.query_seq_lens.shape[0],
- window_left=-1,
- cum_seq_lens_q=prefill.query_start_loc,
- cum_seq_lens_kv=prefill.query_start_loc,
- enable_pdl=False,
- is_causal=True,
- return_lse=return_softmax_lse,
- )
-
- if isinstance(ret, tuple):
- # Convert from (q_len, num_heads) to (num_heads, q_len)
- return ret[0], ret[1].transpose(0, 1).contiguous()
- return ret
-
- def _run_prefill_context_chunk_trtllm_ragged(
- self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
- ):
- """TRT-LLM ragged attention for context chunks (non-causal)."""
- from flashinfer.prefill import trtllm_ragged_attention_deepseek
-
- assert prefill.chunked_context is not None
- assert prefill.chunked_context.seq_lens[chunk_idx] is not None
- assert prefill.workspace_buffer is not None
-
- out = torch.zeros(
- q.shape[0],
- q.shape[1],
- v.shape[2],
- device=q.device,
- dtype=q.dtype,
- )
- prefill.workspace_buffer.fill_(0)
-
- attn_out, lse = trtllm_ragged_attention_deepseek(
- query=q,
- key=k,
- value=v,
- workspace_buffer=prefill.workspace_buffer,
- seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
- max_q_len=prefill.max_query_len,
- max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
- bmm1_scale=self.scale,
- bmm2_scale=1.0,
- o_sf_scale=1.0,
- batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
- window_left=-1,
- cum_seq_lens_q=prefill.query_start_loc,
- cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
- enable_pdl=False,
- is_causal=False,
- return_lse=True,
- out=out,
- )
-
- # Convert from (q_len, num_heads) to (num_heads, q_len)
- return attn_out, lse.transpose(0, 1).contiguous()
def _concat_k_nope_k_pe(
self, k_nope: torch.Tensor, k_pe: torch.Tensor
@@ -2381,9 +2728,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
dtype=k_nope.dtype,
device=k_nope.device,
)
- # Direct copies with efficient broadcasting
- k[..., : k_nope.shape[-1]] = k_nope
- k[..., k_nope.shape[-1] :] = k_pe
+
+ if self._use_flashinfer_concat_mla_k:
+ torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
+ else:
+ # Fallback: Direct copies with efficient broadcasting
+ k[..., : k_nope.shape[-1]] = k_nope
+ k[..., k_nope.shape[-1] :] = k_pe
return k
def _compute_prefill_context(
@@ -2392,34 +2743,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_and_k_pe_cache: torch.Tensor,
kv_c_and_k_pe_cache_scale: torch.Tensor,
attn_metadata: MLACommonMetadata,
- k_scale: torch.Tensor | None = None,
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
assert prefill_metadata.chunked_context is not None
+ use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
+
output = None
iters = len(prefill_metadata.chunked_context.seq_tot)
workspace = prefill_metadata.chunked_context.workspace
+
+ if use_fp8_prefill:
+ q = q.to(prefill_metadata.q_data_type)
+
for i in range(iters):
toks = prefill_metadata.chunked_context.seq_tot[i]
- # ops.gather_and_maybe_dequant_cache(
- # src_cache=kv_c_and_k_pe_cache,
- # dst=workspace,
- # block_table=prefill_metadata.block_table,
- # cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
- # token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
- # num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
- # kv_cache_dtype=self.kv_cache_dtype,
- # scale=k_scale,
- # seq_starts=prefill_metadata.chunked_context.starts[i],
- # )
-
if envs.VLLM_USE_INT8_MLA:
ops.gather_cache_int8(
src_cache=kv_c_and_k_pe_cache,
src_cache_scale=kv_c_and_k_pe_cache_scale,
                     kv_lora_rank=self.kv_lora_rank,
dst=workspace,
block_table=prefill_metadata.block_table,
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
@@ -2439,20 +2783,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
kv_c_normed = workspace[:toks][..., : self.kv_lora_rank].contiguous()
k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
+
+            # TODO: Use the epilogue of kv_b_proj to generate fp8 kv_nope directly.
+ if use_fp8_prefill:
+ kv_nope = kv_nope.to(prefill_metadata.q_data_type)
+ k_pe = k_pe.to(prefill_metadata.q_data_type)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- # k = self._concat_k_nope_k_pe(k_nope, k_pe)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+
             attn_output = torch.empty(
                 q.shape[0],
                 self.num_heads,
                 self.v_head_dim,
                 dtype=q.dtype,
                 device=q.device,
             )
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
@@ -2467,24 +2811,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
output = attn_output
output_lse = attn_softmax_lse
else:
- # output_tmp = torch.empty_like(output)
- # output_lse_tmp = torch.empty_like(output_lse)
- # merge_attn_states(
- # output=output_tmp,
- # output_lse=output_lse_tmp,
- # prefix_output=output,
- # prefix_lse=output_lse,
- # suffix_output=attn_output,
- # suffix_lse=attn_softmax_lse,
- # )
- # output = output_tmp
- # output_lse = output_lse_tmp
- merge_attn_states(
+            output, output_lse = merge_attn_states(
prefix_output=output,
prefix_lse=output_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
- return_lse=True
+ return_lse=True,
)
return output, output_lse
@@ -2564,8 +2896,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- # k = self._concat_k_nope_k_pe(k_nope, k_pe)
- k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+ k = self._concat_k_nope_k_pe(k_nope, k_pe)
+        attn_output = torch.empty(
+            q.shape[0],
+            self.num_heads,
+            self.v_head_dim,
+            dtype=q.dtype,
+            device=q.device,
+        )
attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
prefill=prefill_metadata,
@@ -2573,67 +2905,54 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
q=q,
k=k,
v=v,
+ out=attn_output,
)
if output is None:
output = attn_output
output_lse = attn_softmax_lse
else:
- output_tmp = torch.empty_like(output)
- output_lse_tmp = torch.empty_like(output_lse)
-            merge_attn_states(
+            output, output_lse = merge_attn_states(
- output=output_tmp,
- output_lse=output_lse_tmp,
prefix_output=output,
prefix_lse=output_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
+ return_lse=True,
)
- output = output_tmp
- output_lse = output_lse_tmp
-
return output, output_lse
def forward_mha(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
- # k_pe: torch.Tensor,
k: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
- # k_scale: torch.Tensor,
- # output: torch.Tensor,
- kv_c_and_k_pe_cache_scale: torch.Tensor,
- output: torch.Tensor | None = None,
- ) -> None:
+        kv_c_and_k_pe_cache_scale: torch.Tensor | None = None,
+ ) -> torch.Tensor:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
assert self.dcp_world_size != -1
- has_context = attn_metadata.prefill.chunked_context is not None
+ prefill_metadata = attn_metadata.prefill
+ use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype()
+
+ # Convert q to FP8 if FP8 prefill attention is enabled
+ if use_fp8_prefill:
+ q = q.to(prefill_metadata.q_data_type)
+
+ has_context = prefill_metadata.chunked_context is not None
+
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
)
- # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-
- # k = self._concat_k_nope_k_pe(k_nope, k_pe)
-
- # output_prefill = self._run_prefill_new_tokens(
- # prefill=attn_metadata.prefill,
- # q=q,
- # k=k,
- # v=v,
- # return_softmax_lse=has_context,
- # )
k_nope, v_nope = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
v = v_nope
         k[..., : self.qk_nope_head_dim] = k_nope

         attn_output = torch.empty(
             q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device
         )

prefill=attn_metadata.prefill,
q=q,
@@ -2644,7 +2963,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
if has_context:
- # suffix_output, suffix_lse = output_prefill
suffix_output, suffix_lse = output
if self.dcp_world_size > 1:
context_output, context_lse = (
@@ -2657,26 +2975,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
)
else:
- # context_output, context_lse = self._compute_prefill_context(
- # q, kv_c_and_k_pe_cache, attn_metadata, k_scale
- # )
- context_output, context_lse = self._compute_prefill_context(
- q, kv_c_and_k_pe_cache, kv_c_and_k_pe_cache_scale, attn_metadata
- )
+                context_output, context_lse = self._compute_prefill_context(
+                    q, kv_c_and_k_pe_cache, kv_c_and_k_pe_cache_scale, attn_metadata
+                )
- # unpad if necessary
- if self._pad_v:
- context_output = context_output[..., : v.shape[-1]]
- suffix_output = suffix_output[..., : v.shape[-1]]
-
- # output = output.view(-1, self.num_heads, self.v_head_dim)
- # merge_attn_states(
- # output=output,
- # prefix_output=context_output,
- # prefix_lse=context_lse,
- # suffix_output=suffix_output,
- # suffix_lse=suffix_lse,
- # )
output = torch.empty_like(suffix_output)
output = merge_attn_states(
output=output,
@@ -2685,322 +2986,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
suffix_output=suffix_output,
suffix_lse=suffix_lse,
)
- # else:
- # output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
- # output.copy_(output_prefill)
+
+ # unpad if necessary
+ if self._pad_v:
+ output = output[..., : v.shape[-1]]
+
return output
@abstractmethod
def forward_mqa(
self,
- q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+ ql_nope: torch.Tensor,
+ q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M,
- layer: AttentionLayer,
+ k_c_normed: torch.Tensor | None,
+ k_pe: torch.Tensor | None,
+ kv_c_and_k_pe_cache_scale: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
raise NotImplementedError
-
- def forward_prepare(
- self,
- positions: torch.Tensor,
- ) -> None:
- self.positions = positions
-
- def forward(
- self,
- layer: AttentionLayer,
- q: torch.Tensor,
- k_c_normed: torch.Tensor, # key in unified attn
- k_pe: torch.Tensor, # value in unified attn
- kv_cache: torch.Tensor,
- attn_metadata: M,
- output: torch.Tensor | None = None,
- kv_cache_scale: torch.Tensor | None = None,
- output_scale: torch.Tensor | None = None,
- output_block_scale: torch.Tensor | None = None,
- ) -> torch.Tensor:
- assert output is not None, "Output tensor must be provided."
-
- if output_scale is not None or output_block_scale is not None:
- raise NotImplementedError(
- "fused output quantization is not yet supported for MLACommonImpl"
- )
-
- if attn_metadata is None:
- # During the profile run try to simulate to worse case output size
- # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
- # since this can be large
- _ = torch.empty(
- (
- self.chunked_prefill_workspace_size,
- self.num_heads,
- self.qk_nope_head_dim + self.v_head_dim,
- ),
- device=k_c_normed.device,
- dtype=k_c_normed.dtype,
- )
-
- # The zero fill is required when used with DP + EP
- # to ensure all ranks within a DP group compute the
- # same expert outputs.
- # return output.fill_(0)
- output = torch.empty(
- output.shape[0],
- self.v_head_dim * self.num_heads,
- device=q.device,
- dtype=q.dtype,
- )
- return output
-
- if self.dcp_world_size == -1:
- self.dcp_world_size = get_dcp_group().world_size
-
- fp8_attention = self.kv_cache_dtype.startswith("fp8")
-
- # num_actual_toks = attn_metadata.num_actual_tokens
-
- # Inputs and outputs may be padded for CUDA graphs
- # output_padded = output
- # output = output[:num_actual_toks, ...]
- # q = q[:num_actual_toks, ...]
- # k_c_normed = k_c_normed[:num_actual_toks, ...]
- # k_pe = k_pe[:num_actual_toks, ...]
-
- assert (
- attn_metadata.num_decodes is not None
- and attn_metadata.num_prefills is not None
- and attn_metadata.num_decode_tokens is not None
- )
-
- has_decode = attn_metadata.num_decodes > 0
- has_prefill = attn_metadata.num_prefills > 0
- num_decode_tokens = attn_metadata.num_decode_tokens
-
- decode_q = q[:num_decode_tokens]
-
- k_pe = k_pe.unsqueeze(1)
- prefill_q = q[num_decode_tokens:]
- prefill_k_pe = k_pe[num_decode_tokens:]
- prefill_k_c_normed = k_c_normed[num_decode_tokens:]
- prefill_k = torch.empty_like(prefill_q)
-
- # write the latent and rope to kv cache
- write_kv_cache = (None, None)
- if kv_cache.numel() > 0:
- # ops.concat_and_cache_mla(
- # k_c_normed,
- # k_pe.squeeze(1),
- # kv_cache,
- # attn_metadata.slot_mapping.flatten(),
- # kv_cache_dtype=self.kv_cache_dtype,
- # scale=layer._k_scale,
- # )
- if has_decode:
- decode_q_pe, decode_k_pe = self.rotary_emb(
- self.positions[:num_decode_tokens],
- decode_q[..., self.qk_nope_head_dim :],
- k_pe[:num_decode_tokens],
- )
- if envs.VLLM_USE_INT8_MLA:
- write_kv_cache = (None, None)
- k_c_normed_int8, k_c_normed_scale, _ = ops.scaled_int8_quant(
- k_c_normed[:num_decode_tokens]
- )
- decode_k_pe_int8, decode_k_pe_scale, _ = ops.scaled_int8_quant(
- decode_k_pe
- )
- ops.concat_and_cache_mla_int8(
- kv_c_int8=k_c_normed_int8,
- kv_c_scale=k_c_normed_scale[..., 0],
- k_pe_int8=decode_k_pe_int8,
- k_pe_scale=decode_k_pe_scale[..., 0].view(
- -1, decode_k_pe_int8.shape[-2]
- ),
- kv_cache=kv_cache,
- kv_cache_scale=kv_cache_scale,
- slot_mapping=attn_metadata.slot_mapping.flatten()[
- :num_decode_tokens
- ],
- kv_cache_dtype=self.kv_cache_dtype,
- scale=layer._k_scale,
- )
- else:
- # write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe)
- if self.dcp_world_size > 1:
- ops.concat_and_cache_mla(
- k_c_normed[:num_decode_tokens],
- decode_k_pe,
- kv_cache,
- attn_metadata.slot_mapping.flatten()[:num_decode_tokens],
- kv_cache_dtype=self.kv_cache_dtype,
- scale=layer._k_scale,
- )
- else:
- write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe)
- if has_prefill:
- ixf_ops.mla_rope(
- self.positions[num_decode_tokens:],
- prefill_q[..., self.qk_nope_head_dim :],
- prefill_k_pe.squeeze(1),
- prefill_k[..., self.qk_nope_head_dim :],
- self.rotary_emb.cos_sin_cache,
- )
- if envs.VLLM_USE_INT8_MLA:
- prefill_k_c_normed_int8, prefill_k_c_normed_scale, _ = (
- ops.scaled_int8_quant(prefill_k_c_normed)
- )
- prefill_k_pe_int8, prefill_k_pe_scale, _ = ops.scaled_int8_quant(
- prefill_k[..., self.qk_nope_head_dim :].contiguous()
- )
-
- ops.concat_and_cache_mla_int8(
- prefill_k_c_normed_int8,
- prefill_k_c_normed_scale[..., 0],
- prefill_k_pe_int8,
- prefill_k_pe_scale[..., 0].view(
- -1, prefill_k_pe_int8.shape[-2]
- ),
- kv_cache,
- kv_cache_scale,
- attn_metadata.slot_mapping.flatten()[num_decode_tokens:],
- kv_cache_dtype=self.kv_cache_dtype,
- scale=layer._k_scale,
- )
- else:
- ops.concat_and_cache_mla(
- prefill_k_c_normed,
- prefill_k[..., self.qk_nope_head_dim :],
- kv_cache,
- attn_metadata.slot_mapping.flatten()[num_decode_tokens:],
- kv_cache_dtype=self.kv_cache_dtype,
- scale=layer._k_scale,
- )
- output = torch.empty(
- output.shape[0],
- self.num_heads,
- self.v_head_dim,
- device=q.device,
- dtype=q.dtype,
- )
-
- if fp8_attention:
- kv_cache = kv_cache.view(current_platform.fp8_dtype())
-
- if has_prefill:
- # self.forward_mha(
- # prefill_q,
- # prefill_k_c_normed,
- # prefill_k_pe,
- # kv_cache,
- # attn_metadata,
- # layer._k_scale,
- # output=output[num_decode_tokens:],
- # )
- output[num_decode_tokens:] = self.forward_mha(
- prefill_q,
- prefill_k_c_normed,
- prefill_k,
- kv_cache,
- attn_metadata,
- kv_c_and_k_pe_cache_scale=kv_cache_scale,
- )
-
- if has_decode:
- assert attn_metadata.decode is not None
-
- # decode_q_nope, decode_q_pe = decode_q.split(
- # [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
- # )
-
- # # Convert from (B, N, P) to (N, B, P)
- # decode_q_nope = decode_q_nope.transpose(0, 1)
-
- # if self.q_pad_num_heads is not None:
- # B, N, L = decode_q_pe.shape
- # decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
- # decode_pe_padded.resize_((B, N, L))
- # decode_pe_padded.copy_(decode_q_pe)
- # decode_q_pe = decode_pe_padded
-
- # if self.is_aiter_triton_fp4_bmm_enabled:
- # from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4
-
- # decode_ql_nope = batched_gemm_a16wfp4(
- # decode_q_nope,
- # self.W_K,
- # self.W_K_scale,
- # transpose_bm=True,
- # prequant=True,
- # y_scale=layer._q_scale if fp8_attention else None,
- # )
- # elif self.is_aiter_triton_fp8_bmm_enabled:
- # # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
- # decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
- # decode_q_nope,
- # self.W_K,
- # self.W_K_scale,
- # group_size=128,
- # transpose_bm=True,
- # )
- # else:
- # # Pads the head_dim if necessary (for the underlying kernel)
- # N, B, P = decode_q_nope.shape
- # _, _, L = self.W_UK_T.shape
-
- # if self.q_pad_num_heads is not None:
- # decode_ql_nope = decode_q_nope.new_empty(
- # (self.q_pad_num_heads, B, L)
- # )
- # decode_ql_nope.resize_((N, B, L))
- # else:
- # decode_ql_nope = decode_q_nope.new_empty((N, B, L))
-
- # # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
- # torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
-
- # # Convert from (N, B, L) to (B, N, L)
- # decode_ql_nope = decode_ql_nope.transpose(0, 1)
-
- # if fp8_attention:
- # assert decode_ql_nope.shape[0] == decode_q_pe.shape[0]
- # assert decode_ql_nope.shape[1] == decode_q_pe.shape[1]
- # decode_q = self._decode_concat_quant_fp8_op(
- # decode_ql_nope, decode_q_pe, layer._q_scale
- # )
- # else:
- # decode_q = (decode_ql_nope, decode_q_pe)
- # if self.dcp_world_size > 1:
- # assert not fp8_attention, "DCP not support fp8 kvcache now."
- # # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
- # decode_q = torch.cat(decode_q, dim=-1)
- # # decode_q do allgather in head dim.
- # decode_q = get_dcp_group().all_gather(decode_q, dim=1)
-
- # call decode attn
- # decode_q, kv_cache, attn_metadata, layer
-
- attn_out, lse = self.forward_mqa(
- decode_q[..., :self.qk_nope_head_dim], decode_q_pe, kv_cache, attn_metadata,
- *write_kv_cache, kv_c_and_k_pe_cache_scale=kv_cache_scale
- )
-
- # correct dcp attn_out with lse.
- if self.dcp_world_size > 1:
- assert lse is not None
- attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
- output[:num_decode_tokens] = self._v_up_proj(attn_out)
- else:
- assert lse is None
- output[:num_decode_tokens] = attn_out
-
- return output.view(output.shape[0], self.v_head_dim * self.num_heads)
- # attn_out = cp_lse_ag_out_rs(
- # attn_out,
- # lse,
- # get_dcp_group(),
- # is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
- # )
-
- # # v_up projection
- # self._v_up_proj(attn_out, out=output[:num_decode_tokens])
- # return output_padded
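
Reviewer note: the MLA hunks above repeatedly call merge_attn_states to combine the
chunked-context attention result with the new-token attention result. A minimal
eager-mode sketch of the log-sum-exp merge it performs, kept here only as a reference
(shapes follow the (tokens, heads, v_dim) outputs and (heads, tokens) LSE convention
used above; this is not the vLLM kernel itself):

    import torch

    def merge_attn_states_ref(prefix_output, prefix_lse, suffix_output, suffix_lse):
        # prefix_output, suffix_output: (tokens, heads, v_dim)
        # prefix_lse, suffix_lse: (heads, tokens), natural-log softmax normalizers
        max_lse = torch.maximum(prefix_lse, suffix_lse)
        p = torch.exp(prefix_lse - max_lse)
        s = torch.exp(suffix_lse - max_lse)
        merged_lse = max_lse + torch.log(p + s)
        w_p = (p / (p + s)).transpose(0, 1).unsqueeze(-1)  # (tokens, heads, 1)
        w_s = (s / (p + s)).transpose(0, 1).unsqueeze(-1)
        merged_output = w_p * prefix_output + w_s * suffix_output
        return merged_output, merged_lse
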
diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py
index e59806a..64370a6 100644
--- a/vllm/model_executor/layers/attention/mm_encoder_attention.py
+++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py
@@ -2,21 +2,94 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
+from vllm.utils.math_utils import round_up
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
+ vit_flashinfer_wrapper,
vit_torch_sdpa_wrapper,
vit_triton_attn_wrapper,
)
+import ixformer.contrib.vllm_flash_attn as ops
logger = init_logger(__name__)
+# Batch buckets for cuDNN graph caching.
+# Graphs use batch size and max sequence length as cache key.
+# This avoids creating a new graph for each unique set of
+# batch size and max sequence length at runtime.
+# From the cuDNN team's performance measurements, there
+# is no significant kernel performance difference between padding
+# to a smaller batch size/seq length and padding to larger
+# ones. The bucketing here is solely used to avoid memory
+# operation overhead, which won't be needed if we have CUDA
+# graph support in the future.
+# TODO: Remove buckets after issue #34763
+# (cuda graph support) is addressed.
+FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64]
+FLASHINFER_MAX_SEQLEN_BUCKETS = [
+ 1 * 1024,
+ 2 * 1024,
+ 4 * 1024,
+ 8 * 1024,
+ 16 * 1024,
+ 32 * 1024,
+ 64 * 1024,
+ 128 * 1024,
+]
+
+# Workspace buffer for FlashInfer CuDNN backend
+FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024
+_flashinfer_workspace_buffer: torch.Tensor | None = None
+
+
+def _get_flashinfer_workspace_buffer() -> torch.Tensor:
+ global _flashinfer_workspace_buffer
+ if _flashinfer_workspace_buffer is None:
+ _flashinfer_workspace_buffer = torch.zeros(
+ FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ return _flashinfer_workspace_buffer
+
+
+def add_padding_to_seqlens(
+ seq: np.ndarray,
+ batch_size: int,
+ padding_value: int,
+) -> np.ndarray:
+ batch_size_padded = next(
+ (b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size),
+ round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]),
+ )
+ if batch_size_padded == batch_size:
+ return seq
+ return np.concatenate(
+ [
+ seq,
+ np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype),
+ ]
+ )
+
+
+def bucket_flashinfer_max_seqlen(
+ real_max_seqlen: int,
+) -> int:
+ if real_max_seqlen <= 0:
+ return FLASHINFER_MAX_SEQLEN_BUCKETS[0]
+ return next(
+ (s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen),
+ round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]),
+ )
+
# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
@@ -24,6 +97,67 @@ class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
# --8<-- [end:mm_encoder_attn]
+ @classmethod
+ def compute_max_seqlen(
+ cls,
+ attn_backend: AttentionBackendEnum,
+ cu_seqlens: np.ndarray,
+ ) -> int:
+ max_seqlen = 0
+ if (
+ attn_backend
+ in (
+ AttentionBackendEnum.FLASH_ATTN,
+ AttentionBackendEnum.ROCM_AITER_FA,
+ AttentionBackendEnum.TRITON_ATTN,
+ AttentionBackendEnum.FLASHINFER,
+ )
+ and len(cu_seqlens) >= 2
+ ):
+ max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
+ if attn_backend == AttentionBackendEnum.FLASHINFER:
+ max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen)
+ return max_seqlen
+
+ @classmethod
+ def maybe_compute_sequence_lengths(
+ cls,
+ attn_backend: AttentionBackendEnum,
+ cu_seqlens: np.ndarray,
+ ) -> np.ndarray | None:
+ if attn_backend != AttentionBackendEnum.FLASHINFER:
+ return None
+ sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+ sequence_lengths = add_padding_to_seqlens(
+ sequence_lengths, len(sequence_lengths), 0
+ )
+ return sequence_lengths
+
+ @classmethod
+ def maybe_recompute_cu_seqlens(
+ cls,
+ attn_backend: AttentionBackendEnum,
+ cu_seqlens: np.ndarray,
+ hidden_size: int,
+ tp_size: int,
+ ) -> np.ndarray:
+ if attn_backend != AttentionBackendEnum.FLASHINFER:
+ return cu_seqlens
+
+ batch_size = len(cu_seqlens) - 1
+ scale = hidden_size // tp_size
+ cu_seqlens = cu_seqlens * scale
+
+ cu_seqlens_qko = cu_seqlens
+ cu_seqlens_v = cu_seqlens * 3
+
+ cu_seqlens_qko = add_padding_to_seqlens(
+ cu_seqlens_qko, batch_size, cu_seqlens_qko[-1]
+ )
+ cu_seqlens_v = add_padding_to_seqlens(
+ cu_seqlens_v, batch_size, cu_seqlens_v[-1]
+ )
+ return np.concatenate([cu_seqlens_qko, cu_seqlens_v])
def __init__(
self,
@@ -46,10 +180,9 @@ class MMEncoderAttention(CustomOp):
self.num_heads = num_heads
self.head_size = head_size
- self.scale = scale
+ self.scale = 1.0 / (head_size**0.5) if scale is None else scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
-
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
@@ -72,9 +205,14 @@ class MMEncoderAttention(CustomOp):
}
self._fa_version = (
- get_flash_attn_version() if self.is_flash_attn_backend else None
+ get_flash_attn_version(head_size=head_size)
+ if self.is_flash_attn_backend
+ else None
)
+ if self.attn_backend == AttentionBackendEnum.FLASHINFER:
+ _get_flashinfer_workspace_buffer()
+
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
@@ -148,23 +286,27 @@ class MMEncoderAttention(CustomOp):
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
+ query = query.view(bsz * q_len, self.num_heads, self.head_size)
+ key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
+ value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
- query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
-
- output = vit_flash_attn_wrapper(
- q=query,
- k=key,
- v=value,
- batch_size=bsz,
- is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
- fa_version=self._fa_version,
- scale=self.scale,
- cu_seqlens=cu_seqlens,
- max_seqlen=max_seqlen,
+        cu_q = torch.tensor(
+            [0] + [q_len] * bsz, device=query.device, dtype=torch.int32
+        ).cumsum(dim=0, dtype=torch.int32)
+        cu_kv = torch.tensor(
+            [0] + [kv_len] * bsz, device=query.device, dtype=torch.int32
+        ).cumsum(dim=0, dtype=torch.int32)
+ out = ops.flash_attn_varlen_func(
+ query,
+ key,
+ value,
+ cu_q,
+ cu_kv,
+ q_len,
+ kv_len,
+ softmax_scale=self.scale,
+ causal=False,
)
+ out = out.view(bsz, q_len, self.num_heads, self.head_size)
if is_reshaped:
- output = output.reshape(bsz, q_len, -1)
- return output
+ out = out.reshape(bsz, q_len, -1)
+ return out
def _forward_triton(
self,
@@ -201,6 +343,27 @@ class MMEncoderAttention(CustomOp):
output = output.reshape(bsz, q_len, -1)
return output
+ def _forward_flashinfer(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens: torch.Tensor | None = None,
+ max_seqlen: torch.Tensor | None = None,
+ sequence_lengths: torch.Tensor
+ | None = None, # Only used for FlashInfer CuDNN backend
+ ) -> torch.Tensor:
+ return vit_flashinfer_wrapper(
+ q=query,
+ k=key,
+ v=value,
+ scale=self.scale,
+ workspace_buffer=_get_flashinfer_workspace_buffer(),
+ cu_seqlens=cu_seqlens,
+ max_seqlen=max_seqlen,
+ sequence_lengths=sequence_lengths,
+ )
+
def forward_native(
self,
query: torch.Tensor,
@@ -208,6 +371,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
+ sequence_lengths: torch.Tensor
+ | None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
@@ -218,11 +383,17 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
+ sequence_lengths: torch.Tensor
+ | None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
+ elif self.attn_backend == AttentionBackendEnum.FLASHINFER:
+ return self._forward_flashinfer(
+ query, key, value, cu_seqlens, max_seqlen, sequence_lengths
+ )
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
@@ -238,6 +409,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
+ sequence_lengths: torch.Tensor
+ | None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
@@ -248,6 +421,8 @@ class MMEncoderAttention(CustomOp):
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
+ sequence_lengths: torch.Tensor
+ | None = None, # Only used for FlashInfer CuDNN backend
) -> torch.Tensor:
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
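
Reviewer note: a quick illustration of the bucketing helpers introduced in
mm_encoder_attention.py above. The numbers are hypothetical and the import assumes the
patched module layout; it only shows how batch size and max sequence length get snapped
to the FLASHINFER_* buckets so cached cuDNN graphs can be reused:

    import numpy as np

    from vllm.model_executor.layers.attention.mm_encoder_attention import (
        add_padding_to_seqlens,
        bucket_flashinfer_max_seqlen,
    )

    # A real max sequence length of 3000 snaps up to the 4K bucket.
    assert bucket_flashinfer_max_seqlen(3000) == 4 * 1024

    # A batch of 10 sequences is padded to the next batch bucket (16) with zeros.
    seq_lens = np.array([100, 200, 300, 150, 80, 90, 60, 70, 110, 120], dtype=np.int32)
    padded = add_padding_to_seqlens(seq_lens, batch_size=10, padding_value=0)
    assert padded.shape == (16,)
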
diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py
index c19cc14..e52387a 100644
--- a/vllm/model_executor/layers/fla/ops/__init__.py
+++ b/vllm/model_executor/layers/fla/ops/__init__.py
@@ -7,11 +7,17 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
-from .fused_recurrent import fused_recurrent_gated_delta_rule
+from .fused_recurrent import (
+ fused_recurrent_gated_delta_rule,
+ fused_recurrent_gated_delta_rule_packed_decode,
+)
+from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from .layernorm_guard import RMSNormGated
__all__ = [
"RMSNormGated",
"chunk_gated_delta_rule",
"fused_recurrent_gated_delta_rule",
+ "fused_recurrent_gated_delta_rule_packed_decode",
+ "fused_sigmoid_gating_delta_rule_update",
]
diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py
index 40f8c3c..9261885 100644
--- a/vllm/model_executor/layers/fla/ops/chunk.py
+++ b/vllm/model_executor/layers/fla/ops/chunk.py
@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
- cu_seqlens (torch.LongTensor):
+ cu_seqlens (torch.Tensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
- >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
>>> o_var, ht_var = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
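
Reviewer note: the torch.LongTensor -> torch.Tensor signature changes across these fla
ops track the switch to int32 cu_seqlens shown in the updated docstring example. A
minimal sketch of building such a tensor from per-sequence lengths (illustrative only):

    import torch

    seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)
    cu_seqlens = torch.cat(
        [torch.zeros(1, dtype=torch.int32), seq_lens.cumsum(0, dtype=torch.int32)]
    )
    # cu_seqlens -> tensor([ 0,  5,  8, 15], dtype=torch.int32)
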
diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
index 98a3d61..ce60ca4 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py
@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py
index 16890af..1307812 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_o.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_o.py
@@ -89,7 +89,7 @@ def chunk_fwd_kernel_o(
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
@@ -145,7 +145,7 @@ def chunk_fwd_o(
h: torch.Tensor,
g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
index 7724fa5..31bd489 100644
--- a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
+++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
- cu_seqlens (torch.LongTensor):
+ cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py
index 67d77e8..920efa4 100644
--- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py
+++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py
@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
- # Load state index and check for PAD_SLOT_ID (-1)
+ # Load state index and check for invalid entries
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
- # Skip if state index is invalid (PAD_SLOT_ID = -1)
- if state_idx < 0:
+ # Skip if state index is invalid (NULL_BLOCK_ID=0)
+ if state_idx <= 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
- # Load state index and check for PAD_SLOT_ID (-1)
+ # Load state index and check for invalid entries
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
- # Only store if state index is valid (not PAD_SLOT_ID)
- if final_state_idx >= 0:
+ # Only store if state index is valid (not NULL_BLOCK_ID=0)
+ if final_state_idx > 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -252,6 +252,232 @@ def fused_recurrent_gated_delta_rule_fwd(
return o, final_state
+@triton.jit
+def fused_recurrent_gated_delta_rule_packed_decode_kernel(
+ mixed_qkv,
+ a,
+ b,
+ A_log,
+ dt_bias,
+ o,
+ h0,
+ ht,
+ ssm_state_indices,
+ scale,
+ stride_mixed_qkv_tok: tl.constexpr,
+ stride_a_tok: tl.constexpr,
+ stride_b_tok: tl.constexpr,
+ stride_init_state_token: tl.constexpr,
+ stride_final_state_token: tl.constexpr,
+ stride_indices_seq: tl.constexpr,
+ H: tl.constexpr,
+ HV: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ SOFTPLUS_THRESHOLD: tl.constexpr,
+ USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+):
+ i_v, i_nh = tl.program_id(0), tl.program_id(1)
+ i_n, i_hv = i_nh // HV, i_nh % HV
+ i_h = i_hv // (HV // H)
+
+ o_k = tl.arange(0, BK)
+ o_v = i_v * BV + tl.arange(0, BV)
+ mask_k = o_k < K
+ mask_v = o_v < V
+ mask_h = mask_v[:, None] & mask_k[None, :]
+
+ state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
+ p_o = o + (i_n * HV + i_hv) * V + o_v
+
+ # Skip if state index is invalid (NULL_BLOCK_ID=0)
+ if state_idx <= 0:
+ zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
+ tl.store(p_o, zero, mask=mask_v)
+ return
+
+ p_h0 = h0 + state_idx * stride_init_state_token
+ p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+ b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+ p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
+ q_off = i_h * K + o_k
+ k_off = (H * K) + i_h * K + o_k
+ v_off = (2 * H * K) + i_hv * V + o_v
+ b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
+ b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
+ b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
+
+ if USE_QK_L2NORM_IN_KERNEL:
+ b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
+ b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
+ b_q = b_q * scale
+
+ a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
+ b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
+ A_log_val = tl.load(A_log + i_hv).to(tl.float32)
+ dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
+ x = a_val + dt_bias_val
+ softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
+ g_val = -tl.exp(A_log_val) * softplus_x
+ beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)
+
+ b_h *= exp(g_val)
+ b_v -= tl.sum(b_h * b_k[None, :], 1)
+ b_v *= beta_val
+ b_h += b_v[:, None] * b_k[None, :]
+ b_o = tl.sum(b_h * b_q[None, :], 1)
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+ p_ht = ht + state_idx * stride_final_state_token
+ p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+
+def fused_recurrent_gated_delta_rule_packed_decode(
+ mixed_qkv: torch.Tensor,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ A_log: torch.Tensor,
+ dt_bias: torch.Tensor,
+ scale: float,
+ initial_state: torch.Tensor,
+ out: torch.Tensor,
+ ssm_state_indices: torch.Tensor,
+ use_qk_l2norm_in_kernel: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ if mixed_qkv.ndim != 2:
+ raise ValueError(
+ f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
+ )
+ if mixed_qkv.stride(-1) != 1:
+ raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
+ if a.ndim != 2 or b.ndim != 2:
+ raise ValueError(
+ f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
+ )
+ if a.stride(-1) != 1 or b.stride(-1) != 1:
+ raise ValueError("`a`/`b` must be contiguous in the last dim.")
+ if A_log.ndim != 1 or dt_bias.ndim != 1:
+ raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
+ if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
+ raise ValueError("`A_log`/`dt_bias` must be contiguous.")
+ if ssm_state_indices.ndim != 1:
+ raise ValueError(
+ f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
+ )
+ if not out.is_contiguous():
+ raise ValueError("`out` must be contiguous.")
+
+ dev = mixed_qkv.device
+ if (
+ a.device != dev
+ or b.device != dev
+ or A_log.device != dev
+ or dt_bias.device != dev
+ or initial_state.device != dev
+ or out.device != dev
+ or ssm_state_indices.device != dev
+ ):
+ raise ValueError("All inputs must be on the same device.")
+
+ B = mixed_qkv.shape[0]
+ if a.shape[0] != B or b.shape[0] != B:
+ raise ValueError(
+ "Mismatched batch sizes: "
+ f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
+ )
+ if ssm_state_indices.shape[0] != B:
+ raise ValueError(
+ f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
+ )
+
+ if initial_state.ndim != 4:
+ raise ValueError(
+ f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
+ )
+ if initial_state.stride(-1) != 1:
+ raise ValueError("`initial_state` must be contiguous in the last dim.")
+ HV, V, K = initial_state.shape[-3:]
+ if a.shape[1] != HV or b.shape[1] != HV:
+ raise ValueError(
+ f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
+ )
+ if A_log.numel() != HV or dt_bias.numel() != HV:
+ raise ValueError(
+ f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
+ )
+ if out.shape != (B, 1, HV, V):
+ raise ValueError(
+ f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
+ )
+
+ qkv_dim = mixed_qkv.shape[1]
+ qk_dim = qkv_dim - HV * V
+ if qk_dim <= 0 or qk_dim % 2 != 0:
+ raise ValueError(
+ f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
+ )
+ q_dim = qk_dim // 2
+ if q_dim % K != 0:
+ raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
+ H = q_dim // K
+ if H <= 0 or HV % H != 0:
+ raise ValueError(
+ f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
+ )
+
+ BK = triton.next_power_of_2(K)
+ if triton.cdiv(K, BK) != 1:
+ raise ValueError(
+ f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
+ )
+ BV = min(triton.next_power_of_2(V), 32)
+ num_stages = 3
+ num_warps = 1
+
+ stride_mixed_qkv_tok = mixed_qkv.stride(0)
+ stride_a_tok = a.stride(0)
+ stride_b_tok = b.stride(0)
+ stride_init_state_token = initial_state.stride(0)
+ stride_final_state_token = initial_state.stride(0)
+ stride_indices_seq = ssm_state_indices.stride(0)
+
+ NV = triton.cdiv(V, BV)
+ grid = (NV, B * HV)
+ fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
+ mixed_qkv=mixed_qkv,
+ a=a,
+ b=b,
+ A_log=A_log,
+ dt_bias=dt_bias,
+ o=out,
+ h0=initial_state,
+ ht=initial_state,
+ ssm_state_indices=ssm_state_indices,
+ scale=scale,
+ stride_mixed_qkv_tok=stride_mixed_qkv_tok,
+ stride_a_tok=stride_a_tok,
+ stride_b_tok=stride_b_tok,
+ stride_init_state_token=stride_init_state_token,
+ stride_final_state_token=stride_final_state_token,
+ stride_indices_seq=stride_indices_seq,
+ H=H,
+ HV=HV,
+ K=K,
+ V=V,
+ BK=BK,
+ BV=BV,
+ SOFTPLUS_THRESHOLD=20.0,
+ USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return out, initial_state
+
+
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(
@@ -264,7 +490,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -296,7 +522,7 @@ def fused_recurrent_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -324,7 +550,7 @@ def fused_recurrent_gated_delta_rule(
inplace_final_state: bool:
Whether to store the final state in-place to save memory.
Default: `True`.
- cu_seqlens (torch.LongTensor):
+ cu_seqlens (torch.Tensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
ssm_state_indices (Optional[torch.Tensor]):
@@ -358,7 +584,7 @@ def fused_recurrent_gated_delta_rule(
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
- >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
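
Reviewer note: a plain-PyTorch restatement of the per-token update performed by the
packed decode kernel added above, for one sequence and one value head. Names and shapes
are illustrative, and the optional q/k L2 normalization is omitted; this is a
readability aid, not the Triton implementation:

    import torch
    import torch.nn.functional as F

    def gated_delta_rule_decode_step(h, q, k, v, a, b, A_log, dt_bias, scale):
        # h: (V, K) recurrent state; q, k: (K,); v: (V,); a, b, A_log, dt_bias: scalars
        g = -torch.exp(A_log) * F.softplus(a + dt_bias)  # log decay (threshold 20)
        beta = torch.sigmoid(b)                          # update gate
        q = q * scale
        h = h * torch.exp(g)                             # decay the state
        v = v - h @ k                                    # delta rule: subtract prediction
        v = v * beta
        h = h + torch.outer(v, k)                        # rank-1 state update
        o = h @ q                                        # read out along K
        return o, h
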
diff --git a/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py b/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
new file mode 100644
index 0000000..7e0c7e0
--- /dev/null
+++ b/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
@@ -0,0 +1,279 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import torch
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.heuristics(
+ {
+ "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
+ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+ "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
+ "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
+ }
+)
+@triton.jit(do_not_specialize=["N", "T"])
+def fused_sigmoid_gating_delta_rule_update_kernel(
+ A_log,
+ a,
+ b,
+ dt_bias,
+ beta,
+ threshold,
+ q,
+ k,
+ v,
+ o,
+ h0,
+ ht,
+ cu_seqlens,
+ ssm_state_indices,
+ num_accepted_tokens,
+ scale,
+ N: tl.int64, # num of sequences
+ T: tl.int64, # num of tokens
+ B: tl.constexpr,
+ H: tl.constexpr,
+ HV: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ stride_init_state_token: tl.constexpr,
+ stride_final_state_token: tl.constexpr,
+ stride_indices_seq: tl.constexpr,
+ stride_indices_tok: tl.constexpr,
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
+ INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
+ USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ IS_CONTINUOUS_BATCHING: tl.constexpr,
+ IS_SPEC_DECODING: tl.constexpr,
+ IS_KDA: tl.constexpr,
+):
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_n, i_hv = i_nh // HV, i_nh % HV
+ i_h = i_hv // (HV // H)
+ if IS_VARLEN:
+ bos, eos = (
+ tl.load(cu_seqlens + i_n).to(tl.int64),
+ tl.load(cu_seqlens + i_n + 1).to(tl.int64),
+ )
+ all = T
+ T = eos - bos
+ else:
+ bos, eos = i_n * T, i_n * T + T
+ all = B * T
+
+ if T == 0:
+ # no tokens to process for this sequence
+ return
+
+ o_k = i_k * BK + tl.arange(0, BK)
+ o_v = i_v * BV + tl.arange(0, BV)
+
+ p_q = q + (bos * H + i_h) * K + o_k
+ p_k = k + (bos * H + i_h) * K + o_k
+ p_v = v + (bos * HV + i_hv) * V + o_v
+
+ p_A_log = A_log + i_hv
+ if not IS_KDA:
+ p_a = a + bos * HV + i_hv
+ p_dt_bias = dt_bias + i_hv
+ else:
+ p_a = a + (bos * HV + i_hv) * K + o_k
+ p_dt_bias = dt_bias + i_hv * K + o_k
+
+ p_b = b + bos * HV + i_hv
+ p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+
+ mask_k = o_k < K
+ mask_v = o_v < V
+ mask_h = mask_v[:, None] & mask_k[None, :]
+
+ b_h = tl.zeros([BV, BK], dtype=tl.float32)
+ if USE_INITIAL_STATE:
+ if IS_CONTINUOUS_BATCHING:
+ if IS_SPEC_DECODING:
+ i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
+ else:
+ i_t = 0
+ # Load state index and check for invalid entries
+ state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
+ tl.int64
+ )
+ # Skip if state index is invalid (NULL_BLOCK_ID=0)
+ if state_idx <= 0:
+ return
+ p_h0 = h0 + state_idx * stride_init_state_token
+ else:
+ p_h0 = h0 + bos * HV * V * K
+ p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+ for i_t in range(0, T):
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+ b_b = tl.load(p_b).to(tl.float32)
+
+ # If the model is loaded in fp16, without the .float() here, A might be -inf
+ x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
+ softplus_x = tl.where(
+ beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
+ )
+ b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x
+
+ # compute beta_output = sigmoid(b)
+ b_beta = tl.sigmoid(b_b.to(tl.float32))
+
+ if USE_QK_L2NORM_IN_KERNEL:
+ b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
+ b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
+ b_q = b_q * scale
+ # [BV, BK]
+ if not IS_KDA:
+ b_h *= tl.exp(b_g)
+ else:
+ b_h *= tl.exp(b_g[None, :])
+ # [BV]
+ b_v -= tl.sum(b_h * b_k[None, :], 1)
+ b_v *= b_beta
+ # [BV, BK]
+ b_h += b_v[:, None] * b_k[None, :]
+ # [BV]
+ b_o = tl.sum(b_h * b_q[None, :], 1)
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+ # keep the states for multi-query tokens
+ if INPLACE_FINAL_STATE:
+ # Load state index and check for invalid entries
+ final_state_idx = tl.load(
+ ssm_state_indices + i_n * stride_indices_seq + i_t
+ ).to(tl.int64)
+ # Only store if state index is valid (not NULL_BLOCK_ID=0)
+ if final_state_idx > 0:
+ p_ht = ht + final_state_idx * stride_final_state_token
+ p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+ else:
+ p_ht = ht + (bos + i_t) * stride_final_state_token
+ p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+ # Update pointers for next timestep
+ p_q += H * K
+ p_k += H * K
+ p_o += HV * V
+ p_v += HV * V
+ p_b += HV
+ p_a += HV
+
+
+def fused_sigmoid_gating_delta_rule_update(
+ A_log: torch.Tensor,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ dt_bias: torch.Tensor,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ beta: float = 1.0,
+ threshold: float = 20.0,
+ scale: float = None,
+ initial_state: torch.Tensor = None,
+ inplace_final_state: bool = True,
+ cu_seqlens: torch.Tensor | None = None,
+ ssm_state_indices: torch.Tensor | None = None,
+ num_accepted_tokens: torch.Tensor | None = None,
+ use_qk_l2norm_in_kernel: bool = False,
+ is_kda: bool = False,
+):
+ """
+ Fused triton implementation of sigmoid gating delta rule update.
+ This function uses a single fused kernel that combines both sigmoid gating
+ computation and the recurrent delta rule update for better performance.
+ """
+ B, T, H, K, V = *k.shape, v.shape[-1]
+ HV = v.shape[2]
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
+ BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+ assert NK == 1, "NK > 1 is not supported yet"
+ num_stages = 3
+ num_warps = 4
+
+ if cu_seqlens is not None and q.shape[0] != 1:
+ raise ValueError(
+ f"The batch size is expected to be 1 rather than {q.shape[0]}"
+ f" when using `cu_seqlens`. Please flatten variable-length"
+ f" inputs before processing."
+ )
+ if scale is None:
+ scale = k.shape[-1] ** -0.5
+ else:
+ assert scale > 0, "scale must be positive"
+
+ o = q.new_empty(NK, *v.shape)
+ if inplace_final_state:
+ final_state = initial_state
+ else:
+ final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
+
+ stride_init_state_token = initial_state.stride(0)
+ stride_final_state_token = final_state.stride(0)
+
+ if ssm_state_indices is None:
+ stride_indices_seq, stride_indices_tok = 1, 1
+ elif ssm_state_indices.ndim == 1:
+ stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
+ else:
+ stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
+
+ grid = (NK, NV, N * HV)
+ fused_sigmoid_gating_delta_rule_update_kernel[grid](
+ A_log=A_log,
+ a=a.contiguous(),
+ b=b.contiguous(),
+ dt_bias=dt_bias,
+ beta=beta,
+ threshold=threshold,
+ q=q.contiguous(),
+ k=k.contiguous(),
+ v=v.contiguous(),
+ o=o,
+ h0=initial_state,
+ ht=final_state,
+ cu_seqlens=cu_seqlens,
+ ssm_state_indices=ssm_state_indices,
+ num_accepted_tokens=num_accepted_tokens,
+ scale=scale,
+ N=N,
+ T=T,
+ B=B,
+ H=H,
+ HV=HV,
+ K=K,
+ V=V,
+ BK=BK,
+ BV=BV,
+ stride_init_state_token=stride_init_state_token,
+ stride_final_state_token=stride_final_state_token,
+ stride_indices_seq=stride_indices_seq,
+ stride_indices_tok=stride_indices_tok,
+ INPLACE_FINAL_STATE=inplace_final_state,
+ USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+ IS_KDA=is_kda,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ o = o.squeeze(0)
+ return o, final_state
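
Reviewer note: the gating fused into fused_sigmoid_gating_delta_rule_update_kernel is
the same computation as torch.nn.functional.softplus with matching beta/threshold
arguments, followed by a sigmoid gate. A small reference for the gate values
(illustrative only; tensor shapes as in the kernel, [B, T, HV] for a and b):

    import torch
    import torch.nn.functional as F

    def sigmoid_gating_ref(a, b, A_log, dt_bias, beta=1.0, threshold=20.0):
        # g: per-head log decay applied to the recurrent state
        # gate: sigmoid "beta" that scales the delta-rule update
        g = -torch.exp(A_log.float()) * F.softplus(
            a.float() + dt_bias.float(), beta=beta, threshold=threshold
        )
        gate = torch.sigmoid(b.float())
        return g, gate
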
diff --git a/vllm/model_executor/layers/fla/ops/index.py b/vllm/model_executor/layers/fla/ops/index.py
index f023e13..810d32c 100644
--- a/vllm/model_executor/layers/fla/ops/index.py
+++ b/vllm/model_executor/layers/fla/ops/index.py
@@ -15,14 +15,12 @@ from .utils import tensor_cache
@tensor_cache
-def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache
-def prepare_chunk_indices(
- cu_seqlens: torch.LongTensor, chunk_size: int
-) -> torch.LongTensor:
+def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
indices = torch.cat(
[
torch.arange(n)
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
@tensor_cache
-def prepare_chunk_offsets(
- cu_seqlens: torch.LongTensor, chunk_size: int
-) -> torch.LongTensor:
+def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
return torch.cat(
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
).cumsum(-1)
diff --git a/vllm/model_executor/layers/fla/ops/kda.py b/vllm/model_executor/layers/fla/ops/kda.py
index 7145933..e686471 100644
--- a/vllm/model_executor/layers/fla/ops/kda.py
+++ b/vllm/model_executor/layers/fla/ops/kda.py
@@ -37,7 +37,7 @@ def fused_recurrent_kda_fwd(
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -115,7 +115,7 @@ def fused_recurrent_kda(
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
use_qk_l2norm_in_kernel: bool = True,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -692,7 +692,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
gk: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
scale: float | None = None,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -706,7 +706,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
The beta tensor of shape `[B, T, H]`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
- cu_seqlens (torch.LongTensor):
+ cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
@@ -936,7 +936,7 @@ def recompute_w_u_fwd(
A: torch.Tensor,
q: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
@@ -1104,7 +1104,7 @@ def chunk_gla_fwd_o_gk(
h: torch.Tensor,
o: torch.Tensor,
scale: float,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
):
B, T, H, K, V = *q.shape, v.shape[-1]
@@ -1148,7 +1148,7 @@ def chunk_kda_fwd(
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
):
chunk_size = 64
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
@@ -1208,7 +1208,7 @@ def chunk_kda(
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
- cu_seqlens: torch.LongTensor | None = None,
+ cu_seqlens: torch.Tensor | None = None,
**kwargs,
):
if scale is None:
diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py
index 74c08e0..3abfbff 100644
--- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py
+++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
+ ACTIVATION: tl.constexpr,
):
# Map the program id to the starting row of X and Y it should compute.
row_start = tl.program_id(0) * ROWS_PER_BLOCK
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and not NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
- x *= z * tl.sigmoid(z)
+ if ACTIVATION == "swish" or ACTIVATION == "silu":
+ x *= z * tl.sigmoid(z)
+ elif ACTIVATION == "sigmoid":
+ x *= tl.sigmoid(z)
# Compute mean and variance per row (reduce along axis 1)
if not IS_RMS_NORM:
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
- y *= z * tl.sigmoid(z)
+ if ACTIVATION == "swish" or ACTIVATION == "silu":
+ y *= z * tl.sigmoid(z)
+ elif ACTIVATION == "sigmoid":
+ y *= tl.sigmoid(z)
# Write output
tl.store(Y_base, y, mask=mask)
@@ -178,6 +185,7 @@ def layer_norm_fwd(
group_size: int = None,
norm_before_gate: bool = True,
is_rms_norm: bool = False,
+ activation: str = "swish",
):
M, N = x.shape
if group_size is None:
@@ -232,9 +240,12 @@ def layer_norm_fwd(
eps,
BLOCK_N=BLOCK_N,
ROWS_PER_BLOCK=rows_per_block,
+ HAS_BIAS=bias is not None,
+ HAS_Z=z is not None,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
+ ACTIVATION=activation,
)
return out, mean, rstd
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
+ activation: str = "swish",
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
+ activation=activation,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
+ ctx.activation = activation
return y.reshape(x_shape_og)
@@ -296,17 +310,25 @@ def layernorm_fn(
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
+ activation: str = "swish",
):
return LayerNormFn.apply(
- x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
+ x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
)
def rmsnorm_fn(
- x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
+ x,
+ weight,
+ bias,
+ z=None,
+ eps=1e-6,
+ group_size=None,
+ norm_before_gate=True,
+ activation: str = "swish",
):
return LayerNormFn.apply(
- x, weight, bias, z, eps, group_size, norm_before_gate, True
+ x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
)
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
+ activation: str = "swish",
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
+ self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
+ activation=self.activation,
)
diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py
index a66ec1d..6baa08a 100644
--- a/vllm/model_executor/layers/fla/ops/wy_fast.py
+++ b/vllm/model_executor/layers/fla/ops/wy_fast.py
@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
beta: torch.Tensor,
g_cumsum: torch.Tensor,
A: torch.Tensor,
- cu_seqlens: torch.LongTensor | None,
+ cu_seqlens: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index c6cb31b..d26ac6e 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -22,12 +22,13 @@ from vllm.model_executor.layers.fused_moe.layer import (
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
- FusedMoEPermuteExpertsUnpermute,
- FusedMoEPrepareAndFinalize,
+ FusedMoEExpertsModular,
+ FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
+from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
@@ -61,9 +62,10 @@ __all__ = [
"MoEActivation",
"UnquantizedFusedMoEMethod",
"FusedMoeWeightScaleSupported",
- "FusedMoEPermuteExpertsUnpermute",
+ "FusedMoEExpertsModular",
"FusedMoEActivationFormat",
- "FusedMoEPrepareAndFinalize",
+ "FusedMoEPrepareAndFinalizeModular",
+ "GateLinear",
"RoutingMethodType",
"SharedFusedMoE",
"ZeroExpertFusedMoE",
@@ -137,4 +139,4 @@ else:
raise NotImplementedError(f"{method} is not implemented as lack of triton.")
fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
- fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
+ fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/activation.py b/vllm/model_executor/layers/fused_moe/activation.py
index c11f16d..021b066 100644
--- a/vllm/model_executor/layers/fused_moe/activation.py
+++ b/vllm/model_executor/layers/fused_moe/activation.py
@@ -6,8 +6,7 @@ from enum import Enum
import torch
import torch.nn.functional as F
-
-from vllm._custom_ops import silu_and_mul, gelu_and_mul, swigluoai_and_mul
+from vllm import _custom_ops as ops
class MoEActivation(Enum):
@@ -114,14 +113,11 @@ def apply_moe_activation(
# Activations with gated multiplication (gate × activation(up))
if activation == MoEActivation.SILU:
- # torch.ops._C.silu_and_mul(output, input)
- silu_and_mul(output, input)
+ ops.silu_and_mul(output, input)
elif activation == MoEActivation.GELU:
- # torch.ops._C.gelu_and_mul(output, input)
- gelu_and_mul(output, input)
+ ops.gelu_and_mul(output, input)
elif activation == MoEActivation.SWIGLUOAI:
- # torch.ops._C.swigluoai_and_mul(output, input)
- swigluoai_and_mul(output, input)
+ ops.swigluoai_and_mul(output, input)
elif activation == MoEActivation.SWIGLUSTEP:
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py
index bf8ec2d..47ca95e 100644
--- a/vllm/model_executor/layers/fused_moe/all2all_utils.py
+++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
import torch
@@ -20,20 +21,15 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNaiveEP,
- MoEPrepareAndFinalizeNoEP,
+ make_moe_prepare_and_finalize_naive_dp_ep,
+ make_moe_prepare_and_finalize_no_dp_ep,
)
from vllm.platforms import current_platform
-from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
+from vllm.utils.import_utils import has_deep_ep, has_mori
logger = init_logger(__name__)
if current_platform.is_cuda_alike():
- if has_pplx():
- from .pplx_prepare_finalize import (
- PplxPrepareAndFinalize,
- pplx_hidden_dim_scale_bytes,
- )
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (
@@ -81,6 +77,7 @@ def maybe_make_prepare_finalize(
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
+ use_monolithic: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
@@ -106,65 +103,25 @@ def maybe_make_prepare_finalize(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
- return MoEPrepareAndFinalizeNaiveEP(
+ return make_moe_prepare_and_finalize_naive_dp_ep(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
+ use_monolithic=use_monolithic,
)
else:
- return MoEPrepareAndFinalizeNoEP()
+ return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
- if moe.use_pplx_kernels:
- assert quant_config is not None
-
- hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
- moe.max_num_tokens,
- moe.hidden_dim,
- moe.in_dtype,
- quant_config.quant_dtype,
- per_act_token_quant=quant_config.per_act_token_quant,
- block_shape=quant_config.block_shape,
- )
-
- all_to_all_args = dict(
- max_num_tokens=moe.max_num_tokens,
- num_experts=moe.num_experts,
- experts_per_token=moe.experts_per_token, # topk
- rank=all2all_manager.rank,
- world_size=all2all_manager.world_size,
- # dp_size actually means tp_size, bug in pplx kernels
- dp_size=all2all_manager.tp_group.world_size,
- hidden_dim=moe.hidden_dim,
- hidden_dim_bytes=hidden_dim_bytes,
- hidden_dim_scale_bytes=hidden_scale_bytes,
- )
-
- num_dispatchers = (
- all2all_manager.world_size // all2all_manager.tp_group.world_size
- )
-
- # Intranode pplx a2a takes a group name while internode does not.
- if not all2all_manager.internode:
- all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
-
- handle = all2all_manager.get_handle(all_to_all_args)
-
- prepare_finalize = PplxPrepareAndFinalize(
- handle,
- max_num_tokens=moe.max_num_tokens,
- num_local_experts=moe.num_local_experts,
- num_dispatchers=num_dispatchers,
- )
- elif moe.use_deepep_ht_kernels:
+ if moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
- all_to_all_args = dict()
+ all_to_all_args: dict[str, Any] = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
@@ -246,8 +203,9 @@ def maybe_make_prepare_finalize(
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
- prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
- is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
+ prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep(
+ use_monolithic=use_monolithic,
+ is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=all2all_manager.world_size,
)
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 405965c..5397125 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant(
return y_q, y_s
-class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 87e1e24..e0ed913 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -228,6 +228,7 @@ class FusedMoEQuantConfig:
_a2: FusedMoEQuantDesc
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc
+ is_nvfp4_scale_swizzled: bool = True
def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
@@ -475,6 +476,7 @@ class FusedMoEQuantConfig:
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None,
+ is_nvfp4_scale_swizzled: bool = True,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
@@ -504,6 +506,7 @@ class FusedMoEQuantConfig:
- w2_bias: Optional biases for w1 (GPT OSS Triton).
- w1_zp: Optional w1 zero points for int4/int8 quantization.
- w2_zp: Optional w2 zero points for int4/int8 quantization.
+        - is_nvfp4_scale_swizzled: Whether the nvfp4 scales are swizzled.
"""
assert not isinstance(quant_dtype, str) or quant_dtype in {
"nvfp4",
@@ -536,6 +539,7 @@ class FusedMoEQuantConfig:
_w2=FusedMoEQuantDesc(
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
+ is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
@@ -737,6 +741,7 @@ def nvfp4_moe_quant_config(
w2_scale: torch.Tensor,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
+ is_nvfp4_scale_swizzled: bool = True,
) -> FusedMoEQuantConfig:
"""
     Construct a quant config for mxfp4 activations and nvfp4 weights.
@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config(
per_act_token_quant=False,
per_out_ch_quant=False,
block_shape=None,
+ is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled,
)
@@ -939,10 +945,6 @@ class FusedMoEParallelConfig:
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
- @property
- def use_pplx_kernels(self):
- return self.use_all2all_kernels and self.all2all_backend == "pplx"
-
@property
def use_deepep_ht_kernels(self):
return (
@@ -962,7 +964,7 @@ class FusedMoEParallelConfig:
@property
def use_batched_activation_format(self):
- return self.use_deepep_ll_kernels or self.use_pplx_kernels
+ return self.use_deepep_ll_kernels
@property
def use_naive_all2all_kernels(self):
@@ -1221,10 +1223,6 @@ class FusedMoEConfig:
def use_ep(self):
return self.moe_parallel_config.use_ep
- @property
- def use_pplx_kernels(self):
- return self.moe_parallel_config.use_pplx_kernels
-
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
new file mode 100644
index 0000000..620fe93
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json
@@ -0,0 +1,147 @@
+{
+ "triton_version": "3.6.0",
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 5
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 2
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json
new file mode 100644
index 0000000..fc7dda8
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json
@@ -0,0 +1,147 @@
+{
+ "triton_version": "3.6.0",
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 8,
+ "num_stages": 3
+ }
+}
diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
index ae9430d..64848bf 100644
--- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_unpermute,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
+ MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
@@ -166,7 +166,7 @@ def run_cutlass_moe_fp8(
problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device)
- ops.get_cutlass_pplx_moe_mm_data(
+ ops.get_cutlass_batched_moe_mm_data(
expert_offsets,
problem_sizes1,
problem_sizes2,
@@ -262,7 +262,7 @@ def run_cutlass_moe_fp8(
)
-class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
+class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
@@ -661,7 +661,7 @@ def run_cutlass_moe_fp4(
return
-class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
+class CutlassExpertsFp4(mk.FusedMoEExpertsModular):
"""CUTLASS FP4 fused MoE expert implementation."""
@property
@@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8(
)
-class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute):
+class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular):
def __init__(
self,
out_dtype: torch.dtype | None,
@@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8(
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0)
- fn = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
+ fn = mk.FusedMoEKernel(
+ MoEPrepareAndFinalizeNoDPEPModular(),
CutlassExpertsW4A8Fp8(
out_dtype=a.dtype,
a_strides1=a_strides1,
@@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8(
quant_config=quant_config,
group_size=group_size,
),
- inplace=False,
)
- return fn(
+ return fn.apply(
a,
w1_q,
w2_q,
diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
index 69ca7c9..8af439a 100644
--- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -113,7 +113,7 @@ def _valid_deep_gemm(
return True
-class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class DeepGemmExperts(mk.FusedMoEExpertsModular):
"""DeepGemm-based fused MoE expert implementation."""
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
index 514aa20..1f9e79c 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
@@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import (
)
-class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
Prepare/Finalize using DeepEP High-Throughput kernels.
"""
@@ -123,7 +123,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
is_token_in_rank,
event,
) = self.buffer.get_dispatch_layout(
- topk_idx=rank_topk_ids,
+ topk_idx=rank_topk_ids.long(),
num_experts=num_experts,
previous_event=previous_event,
async_finish=False,
@@ -148,7 +148,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=dispatch_expert_num_tokens,
- topk_idx=rank_topk_ids,
+ topk_idx=rank_topk_ids.long(),
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
# to this value.
@@ -169,7 +169,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
event,
has_scales,
token_data,
- expert_topk_ids,
+ expert_topk_ids.int(),
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
@@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape,
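+            # Pass through whether the target MoE kernel expects swizzled nvfp4 scales.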
+ is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
)
return (
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
index a4cee76..d95e481 100644
--- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
@@ -49,7 +49,7 @@ def dequant_fp8(
return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size())
-class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
Prepare/Finalize using DeepEP low-latency kernels.
"""
@@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# time. This setting is handled by post_init_setup.
self.use_ue8m0_dispatch = False
- def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute):
+ def post_init_setup(self, fused_experts: mk.FusedMoEExperts):
if not fused_experts.supports_packed_ue8m0_act_scales():
# Early exit.
return
@@ -297,12 +297,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
- dispatch_topk_ids,
+ dispatch_topk_ids.long(),
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
- round_scale=self.use_ue8m0_dispatch,
- use_ue8m0=self.use_ue8m0_dispatch,
+ # round_scale=self.use_ue8m0_dispatch,
+ # use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
@@ -398,7 +398,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
- combine_topk_ids,
+ combine_topk_ids.long(),
combine_topk_weights,
handle,
async_finish=False,
diff --git a/vllm/model_executor/layers/fused_moe/experts/__init__.py b/vllm/model_executor/layers/fused_moe/experts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
new file mode 100644
index 0000000..febb3b2
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py
@@ -0,0 +1,335 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+ RoutingMethodType,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+ activation_to_flashinfer_int,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kFp8Dynamic128Sym,
+ kFp8Static128BlockSym,
+ kFp8StaticTensorSym,
+)
+from vllm.platforms import current_platform
+
+
+class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic):
+ """
+ Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface.
+ """
+
+ def __init__(
+ self,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ ):
+ super().__init__(moe_config, quant_config)
+
+ if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor:
+ raise NotImplementedError(
+ "EP parallelism is not supported with TRTLLM"
+ "per-tensor FP8 quantization."
+ )
+
+ self.routing_method_type = moe_config.routing_method
+ self.topk = moe_config.experts_per_token
+ self.intermediate_size_per_partition = (
+ moe_config.intermediate_size_per_partition
+ )
+ self.hidden_dim = moe_config.hidden_dim
+ self.local_num_experts = moe_config.num_local_experts
+ self.ep_rank = moe_config.moe_parallel_config.ep_rank
+
+ # Make additional scales for per-tensor interface.
+ if self.quant_config.is_per_tensor:
+ w1_scale = self.quant_config.w1_scale
+ assert w1_scale is not None
+ a1_scale = self.quant_config.a1_scale
+ assert a1_scale is not None
+ w2_scale = self.quant_config.w2_scale
+ assert w2_scale is not None
+ a2_scale = self.quant_config.a2_scale
+ assert a2_scale is not None
+
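+            # Per-tensor dequant scales (weight_scale * act_scale) for each GEMM;
+            # g1_scale_c additionally folds in the 1 / a2_scale requantization of
+            # the intermediate activations consumed by GEMM2.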
+ self._g1_alphas = (w1_scale * a1_scale).squeeze()
+ self._g2_alphas = (w2_scale * a2_scale).squeeze()
+ self._g1_scale_c = (
+ self._g1_alphas / self.quant_config.a2_scale
+ if moe_config.is_act_and_mul
+ else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale
+ )
+
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ """Supports only Blackwell-family GPUs."""
+ p = current_platform
+        # TODO: add a check that FlashInfer TRTLLM kernels are available.
+ return p.is_cuda() and p.is_device_capability_family(100)
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ """Does not support non-gated MoE (i.e. Nanotron-3-Nano)."""
+ return True
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ """Supports Fp8 per-tensor and Fp8 block."""
+ SUPPORTED_W_A = [
+ (kFp8Static128BlockSym, kFp8Dynamic128Sym),
+ (kFp8StaticTensorSym, kFp8StaticTensorSym),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: MoEActivation) -> bool:
+ """Supports only SiLU and RELU^2 non-gated activation."""
+ return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
+
+ @staticmethod
+ def _supports_routing_method(
+ routing_method: RoutingMethodType,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ """Monolithic kernels need to express router support."""
+ # NOTE(dbari): TopK routing could also be enabled, but need to validate models
+ # NOTE(dbari): Default is not implemented and should not be enabled until it is
+ if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
+ # NOTE(rob): potentially allow others here. This is a conservative list.
+ return routing_method in [
+ RoutingMethodType.DeepSeekV3,
+ RoutingMethodType.Renormalize,
+ RoutingMethodType.RenormalizeNaive,
+ ]
+ elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
+ # NOTE(dbari): as above, potentially allow others here.
+ return routing_method in [
+ RoutingMethodType.DeepSeekV3,
+ RoutingMethodType.Llama4,
+ RoutingMethodType.Renormalize,
+ RoutingMethodType.RenormalizeNaive,
+ ]
+ else:
+ raise ValueError("Unsupported quantization scheme.")
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """Monolithic kernel so only use with naive DP/EP and TP."""
+ return (
+ not moe_parallel_config.use_all2all_kernels
+ or moe_parallel_config.use_naive_all2all_kernels
+ ) and not moe_parallel_config.enable_eplb
+
+ @staticmethod
+ def _supports_router_logits_dtype(
+ router_logits_dtype: torch.dtype | None,
+ routing_method: RoutingMethodType,
+ ) -> bool:
+ """
+ The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
+ Only DeepSeekV3 routing supports float32 router_logits (which is converted
+ internally in the kernel).
+ """
+ if router_logits_dtype == torch.float32:
+ # Only DeepSeekV3 routing handles float32 logits
+ # https://github.com/flashinfer-ai/flashinfer/issues/2469
+ return routing_method == RoutingMethodType.DeepSeekV3
+ return True
+
+ def supports_chunking(self) -> bool:
+ return False
+
+ def supports_expert_map(self) -> bool:
+ return False
+
+ def _apply_per_block(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ router_logits: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ # grouped topk + fused topk bias parameters
+ num_expert_group: int | None = None,
+ e_score_correction_bias: torch.Tensor | None = None,
+ routed_scaling_factor: float | None = None,
+ topk_group: int | None = None,
+ ) -> torch.Tensor:
+ # Delay import for non-CUDA.
+ import flashinfer
+
+ assert not apply_router_weight_on_input
+ assert activation == MoEActivation.SILU
+
+ if e_score_correction_bias is not None:
+ e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype)
+
+ if self.routing_method_type == RoutingMethodType.DeepSeekV3:
+ router_logits = router_logits.to(torch.float32)
+
+ assert self.topk <= global_num_experts
+ assert self.topk <= 10
+ assert global_num_experts % 4 == 0
+ assert self.quant_config.block_shape == [128, 128]
+ # Routing kernel expects #experts <= #threads 512
+ assert global_num_experts <= 512
+
+ # Kernel requires transposed hidden state scales
+ # TODO: fuse into the quant kernel.
+ assert a1q_scale is not None
+ a1q_scale_t = a1q_scale.t().contiguous()
+
+ return flashinfer.fused_moe.trtllm_fp8_block_scale_moe(
+ routing_logits=router_logits,
+ routing_bias=e_score_correction_bias,
+ hidden_states=hidden_states,
+ hidden_states_scale=a1q_scale_t,
+ gemm1_weights=w1,
+ gemm1_weights_scale=self.quant_config.w1_scale,
+ gemm2_weights=w2,
+ gemm2_weights_scale=self.quant_config.w2_scale,
+ num_experts=global_num_experts,
+ top_k=self.topk,
+ n_group=(num_expert_group or 0),
+ topk_group=(topk_group or 0),
+ intermediate_size=self.intermediate_size_per_partition,
+ local_expert_offset=self.ep_rank * self.local_num_experts,
+ local_num_experts=self.local_num_experts,
+ routed_scaling_factor=routed_scaling_factor,
+ routing_method_type=self.routing_method_type,
+ use_shuffled_weight=False,
+ )
+
+ def _apply_per_tensor(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ router_logits: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ # grouped topk + fused topk bias parameters
+ num_expert_group: int | None = None,
+ e_score_correction_bias: torch.Tensor | None = None,
+ routed_scaling_factor: float | None = None,
+ topk_group: int | None = None,
+ ) -> torch.Tensor:
+ # Delay import for non-CUDA.
+ import flashinfer
+ from flashinfer.fused_moe.core import ActivationType
+
+ # Confirm supported activation function.
+ assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
+
+ activation_type = ActivationType(activation_to_flashinfer_int(activation))
+
+ # Confirm Llama-4 routing is proper.
+ if self.routing_method_type == RoutingMethodType.Llama4:
+ assert apply_router_weight_on_input
+ else:
+ assert not apply_router_weight_on_input
+
+ # The DeepSeekV3 routing method requires float32 router logits.
+ if self.routing_method_type == RoutingMethodType.DeepSeekV3:
+ router_logits = router_logits.to(torch.float32)
+
+ out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe(
+ routing_logits=router_logits,
+ routing_bias=e_score_correction_bias,
+ hidden_states=hidden_states,
+ gemm1_weights=w1,
+ output1_scales_scalar=self._g1_scale_c,
+ output1_scales_gate_scalar=self._g1_alphas,
+ gemm2_weights=w2,
+ output2_scales_scalar=self._g2_alphas,
+ num_experts=global_num_experts,
+ top_k=self.topk,
+ n_group=num_expert_group or 0,
+ topk_group=topk_group or 0,
+ intermediate_size=self.intermediate_size_per_partition,
+ local_expert_offset=self.ep_rank * self.local_num_experts,
+ local_num_experts=self.local_num_experts,
+ routed_scaling_factor=routed_scaling_factor,
+ use_routing_scales_on_input=apply_router_weight_on_input,
+ routing_method_type=self.routing_method_type,
+ activation_type=activation_type,
+ )
+ return out
+
+ def apply(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ router_logits: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ # grouped topk + fused topk bias parameters
+ num_expert_group: int | None = None,
+ e_score_correction_bias: torch.Tensor | None = None,
+ routed_scaling_factor: float | None = None,
+ topk_group: int | None = None,
+ ) -> torch.Tensor:
+ if self.quant_config.block_shape is not None:
+ return self._apply_per_block(
+ hidden_states,
+ w1,
+ w2,
+ router_logits,
+ activation,
+ global_num_experts,
+ expert_map,
+ a1q_scale,
+ apply_router_weight_on_input,
+ num_expert_group=num_expert_group,
+ e_score_correction_bias=e_score_correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
+ topk_group=topk_group,
+ )
+ elif self.quant_config.is_per_tensor:
+ return self._apply_per_tensor(
+ hidden_states,
+ w1,
+ w2,
+ router_logits,
+ activation,
+ global_num_experts,
+ expert_map,
+ a1q_scale,
+ apply_router_weight_on_input,
+ num_expert_group=num_expert_group,
+ e_score_correction_bias=e_score_correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
+ )
+ else:
+ raise NotImplementedError(
+ "Only per-block and per-tensor quantization are supported in "
+ f"{self.__class__.__name__}."
+ )
diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
new file mode 100644
index 0000000..5026717
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py
@@ -0,0 +1,326 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import flashinfer
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.fused_moe.config import (
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ FusedMoEQuantConfig,
+ RoutingMethodType,
+)
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+ TopKWeightAndReduceNoOP,
+)
+from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
+ activation_to_flashinfer_int,
+)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+ QuantKey,
+ kNvfp4Dynamic,
+ kNvfp4Static,
+)
+from vllm.platforms import current_platform
+
+
+class TrtLlmNvFp4ExpertsBase:
+ """
+ NvFp4 TRTLLM-Gen MoE kernels. Supports modular and monolithic interface.
+ """
+
+ def __init__(
+ self,
+ moe_config: FusedMoEConfig,
+ quant_config: FusedMoEQuantConfig,
+ ):
+ self.moe_config = moe_config
+ self.quant_config = quant_config
+
+ self.routing_method_type = self.moe_config.routing_method
+ self.topk = moe_config.experts_per_token
+ self.intermediate_size_per_partition = (
+ moe_config.intermediate_size_per_partition
+ )
+ self.hidden_dim = moe_config.hidden_dim
+ self.local_num_experts = moe_config.num_local_experts
+ self.ep_rank = moe_config.moe_parallel_config.ep_rank
+
+ assert self.quant_config.g1_alphas is not None
+ assert self.quant_config.a2_gscale is not None
+ if moe_config.is_act_and_mul:
+ # g1_alpha_s = a13_scale * w13_scale_2
+ # a2_gscale = (1 / a2_scale)
+ # g1_scale_c = a13_scale * w13_scale_2 / a2_scale
+ self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale
+ else:
+ self.g1_scale_c = (
+ torch.ones_like(self.quant_config.a1_gscale)
+ * self.quant_config.a2_gscale
+ )
+
+ @staticmethod
+ def _supports_current_device() -> bool:
+ """Supports only Blackwell-family GPUs."""
+ p = current_platform
+ return p.is_cuda() and p.is_device_capability_family(100)
+
+ @staticmethod
+ def _supports_no_act_and_mul() -> bool:
+ """Supports non-gated MoE (i.e. Nemotron-Nano)."""
+ return True
+
+ @staticmethod
+ def _supports_quant_scheme(
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ """Supports Nvfp4 quantization."""
+ SUPPORTED_W_A = [
+ (kNvfp4Static, kNvfp4Dynamic),
+ ]
+ return (weight_key, activation_key) in SUPPORTED_W_A
+
+ @staticmethod
+ def _supports_activation(activation: MoEActivation) -> bool:
+ """Supports only SiLU and RELU^2 non-gated activation."""
+ return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
+
+ @staticmethod
+ def _supports_shape(hidden_dim: int) -> bool:
+ """Requires hidden dim to be multiple of 512."""
+ return hidden_dim % 512 == 0
+
+ @staticmethod
+ def activation_format() -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def supports_chunking(self) -> bool:
+ return False
+
+ def supports_expert_map(self) -> bool:
+ return False
+
+
+class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular):
+ """
+ Modular version of the implementation (just the experts).
+ """
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """The modular implementation supports all parallel configs."""
+ return True
+
+ def workspace_shapes(
+ self,
+ M: int,
+ N: int,
+ K: int,
+ topk: int,
+ global_num_experts: int,
+ local_num_experts: int,
+ expert_tokens_meta: mk.ExpertTokensMetadata | None,
+ activation: MoEActivation,
+ ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
+ # The workspaces for this implementation are managed by flashinfer.
+ workspace1 = (0,)
+ workspace2 = (0,)
+
+ # Hidden states are Nvfp4, packed into int8 dtype, so we
+ # need to multiply K by 2 to get the output shape right.
+ assert self.hidden_dim == K * 2
+ output = (M, self.hidden_dim)
+
+ return (workspace1, workspace2, output)
+
+ def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+ return TopKWeightAndReduceNoOP()
+
+ def apply(
+ self,
+ output: torch.Tensor,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ a2_scale: torch.Tensor | None,
+ workspace13: torch.Tensor,
+ workspace2: torch.Tensor,
+ expert_tokens_meta: mk.ExpertTokensMetadata | None,
+ apply_router_weight_on_input: bool,
+ ):
+ assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
+ assert a1q_scale is not None
+ assert self.quant_config.w1_scale is not None
+ assert self.quant_config.w2_scale is not None
+
+ # Pack topk ids and weights into format expected by the kernel.
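+        # Each int32 entry carries the expert id in its upper 16 bits and the
+        # raw bfloat16 bit pattern of the routing weight in its lower 16 bits.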
+ packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
+ torch.bfloat16
+ ).view(torch.int16)
+
+ # trtllm_fp4_block_scale_routed_moe does not support autotuning
+ # so skip this kernel during dummy run for autotuning.
+ import vllm.utils.flashinfer as fi_utils
+
+ if fi_utils._is_fi_autotuning:
+ return hidden_states
+
+ # Invoke kernel.
+ flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
+ topk_ids=packed_tensor,
+ routing_bias=None,
+ hidden_states=hidden_states,
+ hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
+ *hidden_states.shape[:-1], -1
+ ),
+ gemm1_weights=w1,
+ gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
+ gemm1_bias=None,
+ gemm1_alpha=None,
+ gemm1_beta=None,
+ gemm1_clamp_limit=None,
+ gemm2_weights=w2,
+ gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
+ gemm2_bias=None,
+ output1_scale_scalar=self.g1_scale_c,
+ output1_scale_gate_scalar=self.quant_config.g1_alphas,
+ output2_scale_scalar=self.quant_config.g2_alphas,
+ num_experts=global_num_experts,
+ top_k=self.topk,
+ n_group=0,
+ topk_group=0,
+ intermediate_size=self.intermediate_size_per_partition,
+ local_expert_offset=self.ep_rank * self.local_num_experts,
+ local_num_experts=self.local_num_experts,
+ routed_scaling_factor=None,
+ routing_method_type=1,
+ do_finalize=True,
+ activation_type=activation_to_flashinfer_int(activation),
+ output=output,
+ )
+
+
+class TrtLlmNvFp4ExpertsMonolithic(
+ TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic
+):
+ """
+ Monolithic version of the kernel (router + experts).
+ """
+
+ @staticmethod
+ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
+ """The modular implementation should be used for the Dp/Ep or EPLB case."""
+ return (
+ not moe_parallel_config.use_all2all_kernels
+ and not moe_parallel_config.enable_eplb
+ )
+
+ @staticmethod
+ def _supports_routing_method(
+ routing_method_type: RoutingMethodType,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ # NOTE(rob): this is a conservative list.
+ return routing_method_type in [
+ RoutingMethodType.DeepSeekV3,
+ RoutingMethodType.Renormalize,
+ RoutingMethodType.RenormalizeNaive,
+ RoutingMethodType.Llama4,
+ ]
+
+ @staticmethod
+ def _supports_router_logits_dtype(
+ router_logits_dtype: torch.dtype | None,
+ routing_method: RoutingMethodType,
+ ) -> bool:
+ """
+ The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default.
+ Only DeepSeekV3 routing supports float32 router_logits (which is converted
+ internally in the kernel).
+ """
+ if router_logits_dtype == torch.float32:
+ # Only DeepSeekV3 routing handles float32 logits
+ # https://github.com/flashinfer-ai/flashinfer/issues/2469
+ return routing_method == RoutingMethodType.DeepSeekV3
+ return True
+
+ def apply(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ router_logits: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ # grouped topk + fused topk bias parameters
+ num_expert_group: int | None = None,
+ e_score_correction_bias: torch.Tensor | None = None,
+ routed_scaling_factor: float | None = None,
+ topk_group: int | None = None,
+ ) -> torch.Tensor:
+ assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
+ assert a1q_scale is not None
+ assert self.quant_config.w1_scale is not None
+ assert self.quant_config.w2_scale is not None
+ assert (
+ apply_router_weight_on_input
+ and self.routing_method_type == RoutingMethodType.Llama4
+ ) or (
+ not apply_router_weight_on_input
+ and self.routing_method_type != RoutingMethodType.Llama4
+ )
+
+ # Prepare routing bias into kernel format.
+ routing_bias = e_score_correction_bias
+ if routing_bias is not None:
+ routing_bias = routing_bias.to(torch.bfloat16)
+ router_logits = (
+ router_logits.to(torch.float32)
+ if self.routing_method_type == RoutingMethodType.DeepSeekV3
+ else router_logits
+ )
+
+ # Invoke kernel.
+ return flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
+ routing_logits=router_logits,
+ routing_bias=routing_bias,
+ hidden_states=hidden_states,
+ hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape(
+ *hidden_states.shape[:-1], -1
+ ),
+ gemm1_weights=w1,
+ gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn),
+ gemm1_bias=None,
+ gemm1_alpha=None,
+ gemm1_beta=None,
+ gemm1_clamp_limit=None,
+ gemm2_weights=w2,
+ gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn),
+ gemm2_bias=None,
+ output1_scale_scalar=self.g1_scale_c,
+ output1_scale_gate_scalar=self.quant_config.g1_alphas,
+ output2_scale_scalar=self.quant_config.g2_alphas,
+ num_experts=global_num_experts,
+ top_k=self.topk,
+ n_group=(num_expert_group or 0),
+ topk_group=(topk_group or 0),
+ intermediate_size=self.intermediate_size_per_partition,
+ local_expert_offset=self.ep_rank * self.local_num_experts,
+ local_num_experts=self.local_num_experts,
+ routed_scaling_factor=routed_scaling_factor,
+ routing_method_type=self.routing_method_type,
+ do_finalize=True,
+ )[0]
diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py
index 4b6458e..403a71e 100644
--- a/vllm/model_executor/layers/fused_moe/fallback.py
+++ b/vllm/model_executor/layers/fused_moe/fallback.py
@@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
-class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
+class FallbackExperts(mk.FusedMoEExpertsModular, ABC):
"""Base class for runtime dispatching of expert implementations."""
def __init__(
self,
- experts: mk.FusedMoEPermuteExpertsUnpermute,
- fallback_experts: mk.FusedMoEPermuteExpertsUnpermute,
+ experts: mk.FusedMoEExpertsModular,
+ fallback_experts: mk.FusedMoEExpertsModular,
):
super().__init__(
moe_config=experts.moe_config, quant_config=experts.quant_config
@@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
@staticmethod
def get_clses() -> tuple[
- type[mk.FusedMoEPermuteExpertsUnpermute],
- type[mk.FusedMoEPermuteExpertsUnpermute],
+ type[mk.FusedMoEExpertsModular],
+ type[mk.FusedMoEExpertsModular],
]:
"""
Get the cls for the experts and fallback experts.
@@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
raise NotImplementedError
def apply(
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
index 39b3738..465d0ae 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py
@@ -18,7 +18,7 @@ def get_local_sizes():
return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank()
-class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""Base class for FlashInfer MoE prepare and finalize operations."""
def __init__(
@@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch(
ep_size,
)
- # Swizzle after the A2A if nvfp4.
- if quant_config.quant_dtype == "nvfp4":
+ # Swizzle after the A2A if MoE kernel expects swizzled scales.
+ if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
if x_sf.element_size() == 1:
x_sf = x_sf.view(torch.uint8)
x_sf = nvfp4_block_scale_interleave(x_sf)
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
index d0cf753..730dc0c 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py
@@ -30,7 +30,7 @@ from vllm.utils.flashinfer import (
logger = init_logger(__name__)
-class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
index b9566a3..02c31fd 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py
@@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe(
return True
-class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class FlashInferExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: mk.FusedMoEConfig,
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
index 732ab8e..6765e36 100644
--- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py
@@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEParallelConfig,
RoutingMethodType,
)
-from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
-from vllm.model_executor.layers.quantization.utils.fp8_utils import (
- per_token_group_quant_fp8,
-)
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
- QuantKey,
- kFp8Dynamic128Sym,
- kFp8Static128BlockSym,
- kFp8StaticTensorSym,
-)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool:
return True
-def _supports_quant_scheme(
- weight_key: QuantKey | None,
- activation_key: QuantKey | None,
-) -> bool:
- """Supports Fp8 per-tensor and Fp8 block."""
- SUPPORTED_W_A = [
- (kFp8Static128BlockSym, kFp8Dynamic128Sym),
- (kFp8StaticTensorSym, kFp8StaticTensorSym),
- ]
- return (weight_key, activation_key) in SUPPORTED_W_A
-
-
def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
-def _supports_routing_method(
- weight_key: QuantKey | None,
- activation_key: QuantKey | None,
- routing_method: RoutingMethodType,
-) -> bool:
- """Monolithic kernels need to express router support."""
- # NOTE(dbari): TopK routing could also be enabled, but need to validate models
- # NOTE(dbari): Default is not implemented and should not be enabled until it is
- if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym):
- # NOTE(rob): potentially allow others here. This is a conservative list.
- return routing_method in [
- RoutingMethodType.DeepSeekV3,
- RoutingMethodType.Renormalize,
- RoutingMethodType.RenormalizeNaive,
- ]
- elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym):
- # NOTE(dbari): as above, potentially allow others here.
- return routing_method in [
- RoutingMethodType.DeepSeekV3,
- RoutingMethodType.Llama4,
- RoutingMethodType.Renormalize,
- RoutingMethodType.RenormalizeNaive,
- ]
- else:
- raise ValueError("Unsupported quantization scheme.")
-
-
def _supports_routing_method_bf16(
routing_method: RoutingMethodType,
) -> bool:
@@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo
return not moe_parallel_config.enable_eplb
-def _supports_router_logits_dtype(
- router_logits_dtype: torch.dtype | None,
- routing_method: RoutingMethodType,
-) -> bool:
- """
- The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default.
- Only DeepSeekV3 routing supports float32 router_logits (which is converted
- internally in the kernel).
- """
- if router_logits_dtype == torch.float32:
- # Only DeepSeekV3 routing handles float32 logits
- # https://github.com/flashinfer-ai/flashinfer/issues/2469
- return routing_method == RoutingMethodType.DeepSeekV3
- return True
-
-
-def is_supported_config_trtllm_fp8(
- moe_config: FusedMoEConfig,
- weight_key: QuantKey | None,
- activation_key: QuantKey | None,
- activation_format: mk.FusedMoEActivationFormat,
-) -> tuple[bool, str | None]:
- """
- This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
- """
-
- def _make_reason(reason: str) -> str:
- return f"kernel does not support {reason}"
-
- if not _supports_current_device():
- return False, _make_reason(f"current device {current_platform.device_name}")
- elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
- return False, _make_reason("no act_and_mul MLP layer")
- elif not _supports_activation(moe_config.activation):
- return False, _make_reason(f"{moe_config.activation} activation")
- elif not _supports_quant_scheme(weight_key, activation_key):
- return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
- elif not _supports_parallel_config(moe_config.moe_parallel_config):
- return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
- elif not _supports_routing_method(
- weight_key, activation_key, moe_config.routing_method
- ):
- return False, _make_reason(f"routing method {moe_config.routing_method}")
- elif activation_format != mk.FusedMoEActivationFormat.Standard:
- return False, _make_reason(f"activation format {activation_format}")
- elif not _supports_router_logits_dtype(
- moe_config.router_logits_dtype, moe_config.routing_method
- ):
- return False, _make_reason(
- "float32 router_logits with non-DeepSeekV3 routing "
- f"{moe_config.router_logits_dtype}x{moe_config.routing_method}"
- )
-
- return True, None
-
-
def is_supported_config_trtllm_bf16(
moe_config: FusedMoEConfig,
activation_format: mk.FusedMoEActivationFormat,
@@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16(
return True, None
-def flashinfer_fused_moe_blockscale_fp8(
- routing_logits: torch.Tensor,
- routing_bias: torch.Tensor | None,
- x: torch.Tensor,
- w13_weight: torch.Tensor,
- w13_weight_scale_inv: torch.Tensor,
- w2_weight: torch.Tensor,
- w2_weight_scale_inv: torch.Tensor,
- global_num_experts: int,
- top_k: int,
- num_expert_group: int | None,
- topk_group: int | None,
- intermediate_size: int,
- expert_offset: int,
- local_num_experts: int,
- block_shape: list[int],
- routing_method_type: int,
- routed_scaling: float | None = 1.0,
-) -> torch.Tensor:
- from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
-
- num_expert_group = num_expert_group if num_expert_group is not None else 0
- topk_group = topk_group if topk_group is not None else 0
- assert top_k <= global_num_experts
- assert top_k <= 10
- assert global_num_experts % 4 == 0
- assert block_shape == [128, 128]
- # Routing kernel expects #experts <= #threads 512
- assert global_num_experts <= 512
-
- # The DeepSeekV3 routing method requires float32 router logits.
- if routing_method_type == RoutingMethodType.DeepSeekV3:
- routing_logits = routing_logits.to(torch.float32)
-
- if routing_bias is not None:
- routing_bias = routing_bias.to(x.dtype)
-
- a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
- # NOTE: scales of hidden states have to be transposed!
- a_sf_t = a_sf.t().contiguous()
- return flashinfer_trtllm_fp8_block_scale_moe(
- routing_logits=routing_logits,
- routing_bias=routing_bias,
- hidden_states=a_q,
- hidden_states_scale=a_sf_t,
- gemm1_weights=w13_weight,
- gemm1_weights_scale=w13_weight_scale_inv,
- gemm2_weights=w2_weight,
- gemm2_weights_scale=w2_weight_scale_inv,
- num_experts=global_num_experts,
- top_k=top_k,
- n_group=num_expert_group,
- topk_group=topk_group,
- intermediate_size=intermediate_size,
- local_expert_offset=expert_offset,
- local_num_experts=local_num_experts,
- routed_scaling_factor=routed_scaling,
- routing_method_type=routing_method_type,
- use_shuffled_weight=False,
- )
-
-
-def flashinfer_fused_moe_blockscale_fp8_fake(
- routing_logits: torch.Tensor,
- routing_bias: torch.Tensor | None,
- x: torch.Tensor,
- w13_weight: torch.Tensor,
- w13_weight_scale_inv: torch.Tensor,
- w2_weight: torch.Tensor,
- w2_weight_scale_inv: torch.Tensor,
- global_num_experts: int,
- top_k: int,
- num_expert_group: int,
- topk_group: int,
- intermediate_size: int,
- expert_offset: int,
- local_num_experts: int,
- block_shape: list[int],
- routing_method_type: int,
- routed_scaling: float = 1.0,
-) -> torch.Tensor:
- return torch.empty_like(x)
-
-
-# TODO(bnell): Does this really need to be a torch.op?
-direct_register_custom_op(
- op_name="flashinfer_fused_moe_blockscale_fp8",
- op_func=flashinfer_fused_moe_blockscale_fp8,
- fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
- tags=(torch.Tag.needs_fixed_stride_order,),
-)
-
-
-def fi_trtllm_fp8_per_tensor_moe(
- routing_logits: torch.Tensor,
- routing_bias: torch.Tensor | None,
- hidden_states: torch.Tensor,
- input_scale: torch.Tensor,
- gemm1_weights: torch.Tensor,
- gemm2_weights: torch.Tensor,
- output1_scales_scalar: torch.Tensor,
- output1_scales_gate_scalar: torch.Tensor,
- output2_scales_scalar: torch.Tensor,
- num_experts: int,
- top_k: int,
- num_expert_group: int | None,
- topk_group: int | None,
- intermediate_size: int,
- local_expert_offset: int,
- local_num_experts: int,
- use_routing_scales_on_input: bool,
- routing_method_type: int,
- activation_type: int,
- routed_scaling_factor: float = 1.0,
-) -> torch.Tensor:
- num_expert_group = num_expert_group if num_expert_group is not None else 0
- topk_group = topk_group if topk_group is not None else 0
-
- quant_hidden_states, _ = moe_kernel_quantize_input(
- hidden_states,
- input_scale,
- quant_dtype=torch.float8_e4m3fn,
- per_act_token_quant=False,
- )
-
- from flashinfer.fused_moe.core import ActivationType
-
- from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe
-
- # The DeepSeekV3 routing method requires float32 router logits.
- if routing_method_type == RoutingMethodType.DeepSeekV3:
- routing_logits = routing_logits.to(torch.float32)
-
- return flashinfer_trtllm_fp8_per_tensor_scale_moe(
- routing_logits=routing_logits,
- routing_bias=routing_bias,
- hidden_states=quant_hidden_states,
- gemm1_weights=gemm1_weights,
- output1_scales_scalar=output1_scales_scalar,
- output1_scales_gate_scalar=output1_scales_gate_scalar,
- gemm2_weights=gemm2_weights,
- output2_scales_scalar=output2_scales_scalar,
- num_experts=num_experts,
- top_k=top_k,
- n_group=num_expert_group,
- topk_group=topk_group,
- intermediate_size=intermediate_size,
- local_expert_offset=local_expert_offset,
- local_num_experts=local_num_experts,
- routed_scaling_factor=routed_scaling_factor,
- use_routing_scales_on_input=use_routing_scales_on_input,
- routing_method_type=routing_method_type,
- # TODO: enum type Required for flashinfer==0.6.3, remove with update
- # https://github.com/flashinfer-ai/flashinfer/pull/2508
- activation_type=ActivationType(activation_type),
- )
-
-
-def fi_trtllm_fp8_per_tensor_moe_fake(
- routing_logits: torch.Tensor,
- routing_bias: torch.Tensor | None,
- hidden_states: torch.Tensor,
- input_scale: torch.Tensor,
- gemm1_weights: torch.Tensor,
- gemm2_weights: torch.Tensor,
- output1_scales_scalar: torch.Tensor,
- output1_scales_gate_scalar: torch.Tensor,
- output2_scales_scalar: torch.Tensor,
- num_experts: int,
- top_k: int,
- num_expert_group: int | None,
- topk_group: int | None,
- intermediate_size: int,
- local_expert_offset: int,
- local_num_experts: int,
- use_routing_scales_on_input: bool,
- routing_method_type: int,
- activation_type: int,
- routed_scaling_factor: float = 1.0,
-) -> torch.Tensor:
- return torch.empty_like(hidden_states)
-
-
-# TODO(bnell): Does this really need to be a torch.op?
-direct_register_custom_op(
- op_name="fi_trtllm_fp8_per_tensor_moe",
- op_func=fi_trtllm_fp8_per_tensor_moe,
- mutates_args=["hidden_states"],
- fake_impl=fi_trtllm_fp8_per_tensor_moe_fake,
- tags=(torch.Tag.needs_fixed_stride_order,),
-)
-
-
def flashinfer_fused_moe_bf16(
routing_logits: torch.Tensor,
routing_bias: torch.Tensor | None,
diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
index fbd47f8..68393f7 100644
--- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -489,11 +489,11 @@ def invoke_moe_batched_triton_kernel(
)
-class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
A reference prepare/finalize class that reorganizes the tokens into
expert batched format, i.e. E x max_num_tokens x K. This is the format
- that the PPLX dispatch/combine kernels use.
+ that the batched dispatch/combine kernels use.
"""
def __init__(
@@ -645,10 +645,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
-class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class NaiveBatchedExperts(mk.FusedMoEExpertsModular):
"""
A reference MoE expert class that operates on expert batched format,
- i.e. E x max_num_tokens x K. This is the format that the pplx
+ i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use.
"""
@@ -877,10 +877,10 @@ def batched_moe_kernel_quantize_input(
return A_q, A_q_scale
-class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class BatchedTritonExperts(mk.FusedMoEExpertsModular):
"""
A Triton based MoE expert class that operates on expert batched format,
- i.e. E x max_num_tokens x K. This is the format that the pplx
+ i.e. E x max_num_tokens x K. This is the format that the batched
dispatch/combine kernels use.
"""
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 4a8f312..280d090 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -526,7 +526,7 @@ def batched_fused_marlin_moe(
return output
-class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
+class MarlinExpertsBase(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 64fa9c9..27dcd40 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -53,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
-
+import vllm._custom_ops as ops
+import ixformer.inference.functions as ixfops
+from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.distributed import get_ep_group
logger = init_logger(__name__)
@@ -575,56 +578,6 @@ def fused_moe_kernel(
tl.store(c_ptrs, accumulator, mask=c_mask)
-def invoke_fused_moe_kernel(
- A: torch.Tensor,
- B: torch.Tensor,
- C: torch.Tensor,
- A_scale: torch.Tensor | None,
- B_scale: torch.Tensor | None,
- B_zp: torch.Tensor | None,
- topk_weights: torch.Tensor | None,
- topk_ids: torch.Tensor,
- sorted_token_ids: torch.Tensor,
- expert_ids: torch.Tensor,
- num_tokens_post_padded: torch.Tensor,
- mul_routed_weight: bool,
- top_k: int,
- config: dict[str, Any],
- compute_type: tl.dtype,
- use_fp8_w8a8: bool,
- use_int8_w8a8: bool,
- use_int8_w8a16: bool,
- use_int4_w4a16: bool,
- per_channel_quant: bool,
- block_shape: list[int] | None = None,
- B_bias: torch.Tensor | None = None,
-) -> None:
- assert topk_weights is not None or not mul_routed_weight
- assert topk_weights is None or topk_weights.stride(1) == 1
- assert sorted_token_ids.stride(0) == 1
- ops.invoke_fused_moe_kernel(
- A,
- B,
- C,
- A_scale,
- B_scale,
- topk_weights,
- topk_ids,
- sorted_token_ids,
- expert_ids,
- num_tokens_post_padded,
- mul_routed_weight,
- top_k,
- config,
- compute_type,
- use_fp8_w8a8,
- use_int8_w8a16,
- block_shape,
- )
- # ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16,block_shape,B_bias)
- return
-
-
# NOTE(zyongye): we can remove all the wna16 kernel
# once we drop off sm75 support
def invoke_fused_moe_wna16_cuda_kernel(
@@ -782,6 +735,7 @@ def invoke_fused_moe_triton_kernel(
A_scale: torch.Tensor | None,
B_scale: torch.Tensor | None,
topk_weights: torch.Tensor | None,
+ topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
@@ -799,7 +753,9 @@ def invoke_fused_moe_triton_kernel(
):
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
- assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
+ assert sorted_token_ids.stride(0) == 1
+    ops.invoke_fused_moe_kernel(
+        A, B, C, A_scale, B_scale, topk_weights, topk_ids, sorted_token_ids,
+        expert_ids, num_tokens_post_padded, mul_routed_weight, top_k, config,
+        compute_type, use_fp8_w8a8, use_int8_w8a16, block_shape, B_bias,
+    )
+ return
if use_fp8_w8a8 or use_int8_w8a8:
assert B_scale is not None
@@ -910,32 +866,6 @@ def dispatch_fused_moe_kernel(
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
) -> None:
- invoke_fused_moe_kernel(
- A,
- B,
- C,
- A_scale,
- B_scale,
- B_zp,
- topk_weights,
- topk_ids,
- sorted_token_ids,
- expert_ids,
- num_tokens_post_padded,
- mul_routed_weight,
- top_k,
- config,
- compute_type,
- use_fp8_w8a8,
- use_int8_w8a8,
- use_int8_w8a16,
- use_int4_w4a16,
- per_channel_quant,
- block_shape,
- B_bias
- )
- return
-
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1
@@ -999,6 +929,7 @@ def dispatch_fused_moe_kernel(
A_scale,
B_scale,
topk_weights,
+ topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
@@ -1397,14 +1328,13 @@ def get_default_config(
"num_warps": num_warps,
"num_stages": num_stages,
}
- # TODO
numel = M * topk
if numel <= 64:
- config["BLOCK_SIZE_M"] = 32
+ config['BLOCK_SIZE_M'] = 32
elif numel <= 1024:
- config["BLOCK_SIZE_M"] = 64
+ config['BLOCK_SIZE_M'] = 64
else:
- config["BLOCK_SIZE_M"] = 256
+ config['BLOCK_SIZE_M'] = 256
return config
@@ -1424,14 +1354,12 @@ def try_get_optimal_moe_config(
else:
# First try to load optimal config from the file
E, _, N = w2_shape
- if dtype == "int4_w4a16":
- N = N * 2
- block_n = block_shape[0] if block_shape else 0
- block_k = block_shape[1] if block_shape else 0
- configs = get_moe_configs(E, N, dtype, block_n, block_k)
+ # block_n = block_shape[0] if block_shape else 0
+ # block_k = block_shape[1] if block_shape else 0
+ # configs = get_moe_configs(E, N, dtype, block_n, block_k)
configs = None
-
+
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
@@ -1560,13 +1488,12 @@ def outplace_fused_experts(
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
) -> torch.Tensor:
- return fused_experts_impl(
+ return fused_experts_impl_opt(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
- False,
activation,
apply_router_weight_on_input,
use_fp8_w8a8,
@@ -1626,14 +1553,12 @@ direct_register_custom_op(
def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor:
- # torch.ops.vllm.inplace_fused_experts(**kwargs)
inplace_fused_experts(**kwargs)
- hidden_states = kwargs["hidden_states"]
+ hidden_states = kwargs['hidden_states']
return hidden_states
def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
- # return torch.ops.vllm.outplace_fused_experts(**kwargs)
return outplace_fused_experts(**kwargs)
@@ -1661,7 +1586,6 @@ def fused_experts(
"""Run fused MoE expert computation using Triton kernels."""
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
-
assert not inplace or not disable_inplace()
return dispatch_fused_experts_func(inplace)(
@@ -1691,6 +1615,245 @@ def fused_experts(
w2_bias=quant_config.w2_bias,
)
+def fused_experts_impl_opt(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: str = "silu",
+ apply_router_weight_on_input: bool = False,
+ use_fp8_w8a8: bool = False,
+ use_int8_w8a8: bool = False,
+ use_int8_w8a16: bool = False,
+ use_int4_w4a16: bool = False,
+ ocp_mx_scheme: str | None = None,
+ per_channel_quant: bool = False,
+ global_num_experts: int = -1,
+ expert_map: torch.Tensor | None = None,
+ w1_scale: torch.Tensor | None = None,
+ w2_scale: torch.Tensor | None = None,
+ w1_zp: torch.Tensor | None = None,
+ w2_zp: torch.Tensor | None = None,
+ a1_scale: torch.Tensor | None = None,
+ a2_scale: torch.Tensor | None = None,
+ block_shape: torch.Tensor | None = None,
+ w1_bias: torch.Tensor | None = None,
+ w2_bias: torch.Tensor | None = None,
+ output: torch.Tensor | None = None
+) -> torch.Tensor:
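+    """
+    Unquantized fused-MoE implementation backed by ixformer grouped GEMM/GEMV
+    kernels. Decode-only batches without expert parallelism take a grouped
+    GEMV fast path; all other batches use grouped GEMM. When ``output`` is
+    provided, the result is written into it in place.
+    """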
+    # Check constraints: this implementation only supports unquantized MoE.
+    if (use_fp8_w8a8 or use_int8_w8a8 or use_int8_w8a16 or use_int4_w4a16
+            or w1_scale is not None or w2_scale is not None
+            or w1_zp is not None or w2_zp is not None
+            or a1_scale is not None or a2_scale is not None):
+        raise ValueError("Quantized MoE is not supported")
+
+ attn_metadata = get_forward_context().attn_metadata
+ use_ep = expert_map is not None
+
+    # The decode-only fast path below does not support expert parallelism yet.
+    if attn_metadata:
+        only_decode = not use_ep and all(
+            t.num_decodes > 0 and t.num_prefills == 0
+            for t in attn_metadata.values()
+        )
+ else:
+ only_decode = False
+
+ assert topk_weights.size() == topk_ids.size(), "topk shape mismatch"
+ assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+ assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
+ assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
+ assert hidden_states.dtype in [
+ torch.float32, torch.float16, torch.bfloat16
+ ]
+
+ num_tokens = hidden_states.size(0)
+ num_experts = w1.size(0)
+ top_k = topk_weights.size(1)
+
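+    # Compute the index mappings between the token-major layout and the
+    # expert-batched layout (restricted to this rank's experts under EP).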
+ if use_ep:
+ local_num_experts = w1.size(0)
+ start_eid = get_ep_group().device_group.rank() * local_num_experts
+ end_eid = min((get_ep_group().device_group.rank() + 1) * local_num_experts, global_num_experts)
+ hidden_size = hidden_states.shape[1]
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ expand_tokens,
+ ) = ixfops.moe_compute_token_index_ep(
+ topk_ids=topk_ids,
+ num_experts=global_num_experts,
+ start_expert_id=start_eid,
+ end_expert_id=end_eid,
+ )
+        if expert_sizes_cpu.sum() == 0:
+            # No tokens were routed to this rank's local experts.
+            if output is not None:
+                output.zero_()
+                return output
+            return torch.zeros(
+                (num_tokens, hidden_size),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+ else:
+ expand_tokens = num_tokens * top_k
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ ) = ixfops.moe_compute_token_index(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ )
+
+ if only_decode:
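+        # Decode-only fast path: per-expert token counts stay on the GPU and
+        # the grouped matmuls run as GEMVs.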
+ # expand + reorder
+ hidden_states = ixfops.moe_expand_input(
+ hidden_states=hidden_states,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=top_k,
+ src_to_dst=src_to_dst,
+ )
+
+ # group gemm 1
+ pt_output_1 = ixfops.moe_w16a16_group_gemv(
+ input=hidden_states,
+ weight=w1,
+ output_dtype=hidden_states.dtype,
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ dst_to_src=None,
+ bias=w1_bias,
+ format="TN",
+ )
+
+ # act
+ if activation == "silu":
+ pt_output_2 = ixfops.silu_and_mul(pt_output_1)
+ elif activation == "gelu":
+ pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
+ elif activation == "swigluoai":
+ pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
+ elif activation == "swiglustep":
+ from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
+ output_dim = pt_output_1.shape[1]
+ pt_output_2 = torch.empty(
+ (num_tokens * top_k, output_dim//2),
+ device=pt_output_1.device,
+ dtype=pt_output_1.dtype,
+ )
+ swiglustep_and_mul_triton(pt_output_2, pt_output_1)
+ else:
+ raise ValueError(f"Unsupported activation: {activation}")
+
+ # group gemm 2 + reorder
+ pt_output_3 = ixfops.moe_w16a16_group_gemv(
+ input=pt_output_2,
+ weight=w2,
+ output_dtype=hidden_states.dtype,
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ dst_to_src=sorted_token_ids,
+ bias=w2_bias,
+ format="TN",
+ )
+
+ # mul + reduce_sum
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, top_k, -1),
+ topk_weight=topk_weights,
+ )
+
+ else:
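+        # Prefill / mixed batches: the grouped GEMMs need per-expert token
+        # counts on the host.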
+ expert_sizes_cpu = expert_sizes_gpu.cpu()
+ # expand + reorder
+ hidden_states = ixfops.moe_expand_input(
+ hidden_states=hidden_states,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=top_k,
+ src_to_dst=src_to_dst,
+ )
+ # group gemm 1
+ pt_output_1 = ixfops.moe_w16a16_group_gemm(
+ input=hidden_states,
+ weight=w1,
+ output_dtype=hidden_states.dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=None,
+ bias=w1_bias,
+ format="TN",
+ )
+
+ # act
+ if activation == "silu":
+ pt_output_2 = ixfops.silu_and_mul(pt_output_1)
+ elif activation == "gelu":
+ pt_output_2 = ixfops.gelu_and_mul(pt_output_1)
+ elif activation == "swigluoai":
+ pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1)
+ elif activation == "swiglustep":
+ from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
+ output_dim = pt_output_1.shape[1]
+ pt_output_2 = torch.empty(
+ (num_tokens * top_k, output_dim//2),
+ device=pt_output_1.device,
+ dtype=pt_output_1.dtype,
+ )
+ swiglustep_and_mul_triton(pt_output_2, pt_output_1)
+ else:
+ raise ValueError(f"Unsupported activation: {activation}")
+
+ if use_ep:
+ pt_output_3 = torch.empty(
+ (num_tokens * top_k, hidden_size),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ # group gemm 2 + reorder
+ pt_output_3 = ixfops.moe_w16a16_group_gemm(
+ input=pt_output_2,
+ weight=w2,
+ output_dtype=hidden_states.dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=sorted_token_ids,
+ format="TN",
+ bias=w2_bias,
+ output=pt_output_3,
+ )
+
+ # mul + reduce_sum
+ reduce_mask = src_to_dst == -1
+            if output is not None:
+ ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, top_k, -1),
+ topk_weight=topk_weights,
+ output=output,
+ mask=reduce_mask,
+ )
+ else:
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, top_k, -1),
+ topk_weight=topk_weights,
+ mask=reduce_mask,
+ )
+ else:
+ # group gemm 2 + reorder
+ pt_output_3 = ixfops.moe_w16a16_group_gemm(
+ input=pt_output_2,
+ weight=w2,
+ output_dtype=hidden_states.dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=sorted_token_ids,
+ bias=w2_bias,
+ format="TN",
+ )
+
+ # mul + reduce_sum
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, top_k, -1),
+ topk_weight=topk_weights,
+ )
+
+    if output is None:
+        return final_hidden_states
+    if not use_ep:
+        # Non-EP branches build a fresh tensor; copy it into the caller's buffer.
+        output.copy_(final_hidden_states)
+    return output
+
def _get_config_quant_dtype(
use_fp8_w8a8: bool,
@@ -1825,7 +1988,7 @@ def fused_experts_impl(
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
- activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
+ activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation(
N, activation_enum
)
intermediate_cache2 = torch.empty(
@@ -1910,28 +2073,28 @@ def fused_experts_impl(
ocp_mx_scheme=ocp_mx_scheme,
)
- # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
- # activates only a small fraction of total experts
- SPARSITY_FACTOR = 4
- # block quantized code path is not implemented yet.
- naive_block_assignment = (
- expert_map is None
- and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
- and not (
- (use_int8_w8a16 or use_int4_w4a16)
- and block_shape is not None
- and block_shape[1] > 0
- )
- )
+ # # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k
+ # # activates only a small fraction of total experts
+ # SPARSITY_FACTOR = 4
+ # # block quantized code path is not implemented yet.
+ # naive_block_assignment = (
+ # expert_map is None
+ # and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts
+ # and not (
+ # (use_int8_w8a16 or use_int4_w4a16)
+ # and block_shape is not None
+ # and block_shape[1] > 0
+ # )
+ # )
# if not naive_block_assignment:
- # sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
- # curr_topk_ids,
- # config["BLOCK_SIZE_M"],
- # global_num_experts,
- # expert_map,
- # ignore_invalid_experts=True,
- # )
+ sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+ curr_topk_ids,
+ config["BLOCK_SIZE_M"],
+ global_num_experts,
+ expert_map,
+ ignore_invalid_experts=True,
+ )
# else:
# max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"]
# expert_ids = curr_topk_ids.view(-1)
@@ -1941,14 +2104,6 @@ def fused_experts_impl(
# num_tokens_post_padded.fill_(max_num_tokens_padded)
# sorted_token_ids = None
- sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
- curr_topk_ids,
- config["BLOCK_SIZE_M"],
- global_num_experts,
- expert_map,
- ignore_invalid_experts=True,
- )
-
dispatch_fused_moe_kernel(
qcurr_hidden_states,
w1,
@@ -2015,20 +2170,14 @@ def fused_experts_impl(
B_bias=w2_bias,
)
- # ops.moe_sum(
- # intermediate_cache3.view(*intermediate_cache3.size()),
- # out_hidden_states[begin_chunk_idx:end_chunk_idx],
- # )
- torch.sum(
- intermediate_cache3.view(*intermediate_cache3.shape),
- dim=1,
- out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
- )
+ torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
+ dim=1,
+ out=out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
-class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class TritonExperts(mk.FusedMoEExpertsModular):
"""Triton-based fused MoE expert implementation."""
def __init__(
@@ -2091,8 +2240,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- # return not moe_parallel_config.use_fi_all2allv_kernels
- return True
+ return not moe_parallel_config.use_fi_all2allv_kernels
def supports_chunking(self) -> bool:
return True
@@ -2138,157 +2286,31 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
- # Check constraints.
- if self.quant_config.use_int4_w4a16:
- assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
- else:
- assert hidden_states.size(-1) == w1.size(2), (
- f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
- )
-
- assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
- assert hidden_states.dim() == 2
- assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
- assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
- assert hidden_states.dtype in [
- torch.float32,
- torch.float16,
- torch.bfloat16,
- torch.float8_e4m3fn,
- torch.float8_e4m3fnuz,
- ]
-
- E, num_tokens, N, K, top_k_num = self.moe_problem_size(
- hidden_states, w1, w2, topk_ids
- )
-
- if global_num_experts == -1:
- global_num_experts = E
-
- config = try_get_optimal_moe_config(
- w1.size(),
- w2.size(),
- top_k_num,
- self.quant_config.config_name(hidden_states.dtype),
- num_tokens,
- block_shape=self.block_shape,
- )
-
- if hidden_states.dtype == torch.bfloat16:
- compute_type = tl.bfloat16
- elif hidden_states.dtype == torch.float16:
- compute_type = tl.float16
- elif hidden_states.dtype == torch.float32:
- compute_type = tl.float32
- elif (
- hidden_states.dtype == torch.float8_e4m3fn
- or hidden_states.dtype == torch.float8_e4m3fnuz
- ):
- compute_type = tl.bfloat16
- else:
- raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
-
- # Note that the output tensor might be in workspace1
- intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
- cache2_dim = self.adjust_N_for_activation(N, activation)
- intermediate_cache2 = _resize_cache(
- workspace13, (num_tokens * top_k_num, cache2_dim)
- )
- intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))
-
- sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
- topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
- )
-
- invoke_fused_moe_triton_kernel(
- hidden_states,
- w1,
- intermediate_cache1,
- a1q_scale,
- self.w1_scale,
- None, # topk_weights
- sorted_token_ids,
- expert_ids,
- num_tokens_post_padded,
- False, # mul_routed_weights
- top_k_num,
- config,
- compute_type=compute_type,
- use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
- use_int8_w8a8=self.quant_config.use_int8_w8a8,
- use_int8_w8a16=self.quant_config.use_int8_w8a16,
- use_int4_w4a16=self.quant_config.use_int4_w4a16,
- per_channel_quant=self.per_act_token_quant,
- block_shape=self.block_shape,
- B_bias=self.w1_bias,
- )
-
- self.activation(
- activation, intermediate_cache2, intermediate_cache1.view(-1, N)
- )
-
- a2q_scale: torch.Tensor | None = None
-
- qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
- intermediate_cache2,
- a2_scale,
- self.quant_dtype,
- self.per_act_token_quant,
- self.block_shape,
- )
-
- # invoke_fused_moe_triton_kernel(
- # qintermediate_cache2,
- # w2,
- # intermediate_cache3,
- # a2q_scale,
- # self.w2_scale,
- # topk_weights,
- # sorted_token_ids,
- # expert_ids,
- # num_tokens_post_padded,
- # not apply_router_weight_on_input,
- # 1,
- # config,
- # compute_type=compute_type,
- # use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
- # use_int8_w8a8=self.quant_config.use_int8_w8a8,
- # use_int8_w8a16=self.quant_config.use_int8_w8a16,
- # use_int4_w4a16=self.quant_config.use_int4_w4a16,
- # per_channel_quant=self.per_act_token_quant,
- # block_shape=self.block_shape,
- # B_bias=self.w2_bias,
- # )
-
- invoke_fused_moe_kernel(
- qintermediate_cache2,
- w2,
- intermediate_cache3,
- a2q_scale,
- self.w2_scale,
- self.w2_zp,
- topk_weights,
- sorted_token_ids,
- expert_ids,
- num_tokens_post_padded,
- not apply_router_weight_on_input,
- 1,
- config,
- compute_type=compute_type,
- use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
- use_int8_w8a8=self.quant_config.use_int8_w8a8,
- use_int8_w8a16=self.quant_config.use_int8_w8a16,
- use_int4_w4a16=self.quant_config.use_int4_w4a16,
- per_channel_quant=self.per_act_token_quant,
- block_shape=self.block_shape,
- B_bias=self.w2_bias,
- )
-
- # separate function is required for MoE + LoRA
- self.moe_sum(intermediate_cache3, output)
-
- def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
- ops.moe_sum(input, output)
+        fused_experts_impl_opt(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=self.quant_config.use_fp8_w8a8,
+            use_int8_w8a8=self.quant_config.use_int8_w8a8,
+            use_int8_w8a16=self.quant_config.use_int8_w8a16,
+            use_int4_w4a16=self.quant_config.use_int4_w4a16,
+            ocp_mx_scheme=self.quant_config.ocp_mx_scheme,
+            per_channel_quant=self.quant_config.per_act_token_quant,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            w1_scale=self.quant_config.w1_scale,
+            w2_scale=self.quant_config.w2_scale,
+            w1_zp=self.quant_config.w1_zp,
+            w2_zp=self.quant_config.w2_zp,
+            a1_scale=self.quant_config.a1_scale,
+            a2_scale=self.quant_config.a2_scale,
+            block_shape=self.quant_config.block_shape,
+            w1_bias=self.quant_config.w1_bias,
+            w2_bias=self.quant_config.w2_bias,
+            output=output,
+        )
class TritonWNA16Experts(TritonExperts):
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
index ac7c71e..88cd173 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py
@@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
- FusedMoEPermuteExpertsUnpermute,
- FusedMoEPrepareAndFinalize,
+ FusedMoEExpertsModular,
+ FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
@@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase):
super().__init__()
self.moe: FusedMoEConfig = moe
self.moe_quant_config: FusedMoEQuantConfig | None = None
- self.moe_mk: mk.FusedMoEModularKernel | None = None
+ self.moe_kernel: mk.FusedMoEKernel | None = None
@property
def supports_internal_mk(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
- return self.moe_mk is not None
+ return self.moe_kernel is not None
@property
def mk_owns_shared_expert(self) -> bool:
# NOTE(rob): temporary attribute to indicate support for
# completed migration to the new internal MK interface.
- return self.moe_mk is not None and self.moe_mk.shared_experts is not None
+ return (
+ self.moe_kernel is not None and self.moe_kernel.shared_experts is not None
+ )
@abstractmethod
def create_weights(
@@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> FusedMoEPrepareAndFinalize | None:
+ ) -> FusedMoEPrepareAndFinalizeModular | None:
from .all2all_utils import maybe_make_prepare_finalize
- return maybe_make_prepare_finalize(
+ pf = maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables
)
+ assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular)
+ return pf
def select_gemm_impl(
self,
- prepare_finalize: FusedMoEPrepareAndFinalize,
+ prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
- ) -> FusedMoEPermuteExpertsUnpermute:
+ ) -> FusedMoEExpertsModular:
# based on the all2all implementation, select the appropriate
# gemm implementation
- raise NotImplementedError(
- f"{self.__class__.__name__} must select appropriate gemm "
- "implementation based on the prepare_finalize"
- )
-
- def prepare_dp_allgather_tensor(
- self,
- layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- ) -> tuple[torch.Tensor, list[torch.Tensor]]:
- """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
- raise NotImplementedError(
- "Method 'prepare_dp_allgather_tensor' is not implemented in "
- f"{self.__class__.__name__}."
+ raise ValueError(
+ f"{self.__class__.__name__} uses the new modular kernel initialization "
+ "logic. This function should not be called."
)
@abstractmethod
@@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property
def topk_indices_dtype(self) -> torch.dtype | None:
- if self.moe_mk is not None:
- return self.moe_mk.prepare_finalize.topk_indices_dtype()
+ if self.moe_kernel is not None:
+ return self.moe_kernel.prepare_finalize.topk_indices_dtype()
return None
@property
@@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase):
@property
def is_monolithic(self) -> bool:
- return False
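+        # Before the modular kernel is constructed, defer to the experts class
+        # (when one is attached) to report monolithic support.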
+ if self.moe_kernel is None:
+ if hasattr(self, "experts_cls"):
+ return self.experts_cls.is_monolithic()
+ else:
+ return False
+ return self.moe_kernel.is_monolithic
def apply(
self,
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
index 187464c..0065c11 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py
@@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
- FusedMoEModularKernel,
- FusedMoEPrepareAndFinalize,
+ FusedMoEKernel,
+ FusedMoEPrepareAndFinalizeModular,
)
logger = init_logger(__name__)
@@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
# --8<-- [end:modular_fused_moe]
def __init__(
- self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel
+ self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel
):
super().__init__(old_quant_method.moe)
self.moe_quant_config = old_quant_method.moe_quant_config
- self.moe_mk = experts
+ self.moe_kernel = moe_kernel
self.disable_expert_map = getattr(
old_quant_method,
"disable_expert_map",
- not self.moe_mk.supports_expert_map(),
+ not self.moe_kernel.supports_expert_map(),
)
self.old_quant_method = old_quant_method
logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
@@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
def make(
moe_layer: torch.nn.Module,
old_quant_method: FusedMoEMethodBase,
- prepare_finalize: FusedMoEPrepareAndFinalize,
+ prepare_finalize: FusedMoEPrepareAndFinalizeModular,
shared_experts: torch.nn.Module | None,
inplace: bool = False,
) -> "FusedMoEModularMethod":
return FusedMoEModularMethod(
old_quant_method,
- FusedMoEModularKernel(
+ FusedMoEKernel(
prepare_finalize,
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts,
@@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.moe_mk is not None
- return self.moe_mk(
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
index 5617156..8d6f716 100644
--- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
@@ -6,6 +6,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
@@ -178,7 +179,40 @@ def triton_kernel_moe_forward(
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
+ unpadded_N_w1=None,
+ unpadded_K_w1=None,
+ unpadded_N_w2=None,
+ unpadded_K_w2=None,
) -> torch.Tensor:
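+    # When AITER is enabled, mxfp4-weight / fp8-activation models are routed to
+    # the dedicated a8w4 grouped-GEMM path below.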
+ if (
+ quant_config is not None
+ and quant_config.use_mxfp4_w4a8
+ and rocm_aiter_ops.is_enabled()
+ ):
+ from aiter.ops.triton.moe_routing.routing import routing as aiter_routing
+
+ routing_data, gather_idx, scatter_idx = aiter_routing(
+ gating_output, topk, sm_first=not renormalize
+ )
+ return triton_kernel_fused_mxfp4_w4a8_experts(
+ None,
+ hidden_states,
+ w1,
+ w2,
+ routing_data,
+ gather_idx,
+ scatter_idx,
+ activation=activation.value,
+ quant_config=quant_config,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ unpadded_N_w1=unpadded_N_w1,
+ unpadded_K_w1=unpadded_K_w1,
+ unpadded_N_w2=unpadded_N_w2,
+ unpadded_K_w2=unpadded_K_w2,
+ )
+
if expert_map is not None:
# With expert parallelism, legacy_routing produces routing data
# using global expert IDs which don't correspond to local weight
@@ -210,6 +244,9 @@ def triton_kernel_moe_forward(
effective_global_num_experts = global_num_experts
output = torch.empty_like(hidden_states)
+ effective_quant_config = (
+ quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG
+ )
return triton_kernel_fused_experts(
output,
@@ -221,7 +258,7 @@ def triton_kernel_moe_forward(
scatter_idx,
topk=topk,
activation=activation,
- quant_config=quant_config,
+ quant_config=effective_quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=effective_global_num_experts,
expert_map=effective_expert_map,
@@ -252,8 +289,7 @@ def triton_kernel_fused_experts(
assert activation == MoEActivation.SWIGLUOAI, (
"Only SWIGLUOAI activation is supported"
)
- if quant_config is None:
- quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
+ assert quant_config is not None
# type check, uint8 means mxfp4
assert hidden_states.dtype == torch.bfloat16
@@ -330,6 +366,98 @@ def triton_kernel_fused_experts(
return output_tensor
+# Triton implementation of fused experts for mxfp4 weights with fp8 activations (w4a8).
+def triton_kernel_fused_mxfp4_w4a8_experts(
+ output_tensor: torch.Tensor,
+ hidden_states: torch.Tensor,
+ w1, # Tensor or triton_kernels.Tensor
+ w2, # Tensor or triton_kernels.Tensor
+ routing_data, # RoutingData
+ gather_indx, # GatherIndx
+ scatter_indx, # ScatterIndx
+ activation: str = "silu",
+ quant_config: FusedMoEQuantConfig | None = None,
+ swiglu_alpha: float = 1.702,
+ swiglu_limit: float = 7.0,
+ apply_router_weight_on_input: bool = False,
+ global_num_experts: int = -1,
+ expert_map: torch.Tensor | None = None,
+ a1q_scale: torch.Tensor | None = None,
+ unpadded_N_w1=None,
+ unpadded_K_w1=None,
+ unpadded_N_w2=None,
+ unpadded_K_w2=None,
+) -> torch.Tensor:
+ assert quant_config is not None
+ # type check, uint8 means mxfp4
+ assert hidden_states.dtype == torch.bfloat16
+ assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32
+ assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32
+
+ # Shape check, only check non-mxfp4
+ assert hidden_states.shape[-1] == w1.shape[-2]
+ assert w2.shape[-1] == w1.shape[1]
+
+ E, _, N = w1.shape
+
+ if global_num_experts == -1:
+ global_num_experts = E
+
+ gammas = routing_data.gate_scal if routing_data else None
+
+ from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4
+ from aiter.ops.triton.quant_moe import downcast_to_static_fp8
+
+ assert quant_config.w1_precision is not None, (
+ "w1_precision in quant config can't be None"
+ )
+ assert quant_config.w2_precision is not None, (
+ "w2_precision in quant config can't be None"
+ )
+
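+    # Statically quantize the activations to fp8 so both grouped GEMMs run in
+    # the a8w4 (fp8 activation, mxfp4 weight) path.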
+ hidden_states = downcast_to_static_fp8(
+ hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale
+ )
+
+ intermediate_cache1 = moe_gemm_a8w4(
+ hidden_states,
+ w1.storage.data,
+ None,
+ quant_config.w1_precision.weight_scale.storage.data,
+ quant_config.w1_precision.flex_ctx.lhs_data.scale,
+ quant_config.w2_precision.flex_ctx.lhs_data.scale,
+ quant_config.w1_bias,
+ routing_data,
+ gather_indx=gather_indx,
+ gammas=gammas if apply_router_weight_on_input else None,
+ swizzle_mx_scale="CDNA4_SCALE",
+ out_dtype=torch.float8_e4m3fn,
+ apply_swiglu=True,
+ alpha=swiglu_alpha,
+ limit=swiglu_limit,
+ unpadded_N=unpadded_N_w1,
+ unpadded_K=unpadded_K_w1,
+ )
+
+ intermediate_cache3 = moe_gemm_a8w4(
+ intermediate_cache1,
+ w2.storage.data,
+ None,
+ quant_config.w2_precision.weight_scale.storage.data,
+ quant_config.w2_precision.flex_ctx.lhs_data.scale,
+ None,
+ quant_config.w2_bias,
+ routing_data,
+ scatter_indx=scatter_indx,
+ gammas=None if apply_router_weight_on_input else gammas,
+ swizzle_mx_scale="CDNA4_SCALE",
+ unpadded_N=unpadded_N_w2,
+ unpadded_K=unpadded_K_w2,
+ )
+
+ return intermediate_cache3
+
+
def make_routing_data(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
@@ -383,7 +511,7 @@ def make_routing_data(
return routing_data, gather_indx, scatter_indx
-class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class BaseOAITritonExperts(mk.FusedMoEExpertsModular):
@staticmethod
def _supports_current_device() -> bool:
raise NotImplementedError(
@@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
+ if self.quant_config is None:
+ self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG
+
if expert_map is not None:
topk_ids = expert_map[topk_ids]
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 6ac1c19..69c9c17 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -5,8 +5,8 @@ from collections.abc import Callable, Iterable
from enum import Enum
from typing import Literal, cast, get_args, overload
+import ast, re
import torch
-import torch.nn.functional as F
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@@ -54,10 +54,14 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
+from vllm.model_executor.layers.utils import (
+ parse_opt_exclude_layers,
+ weight_quant_l1,
+ weight_quant_l2,
+)
logger = init_logger(__name__)
-
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
@@ -333,6 +337,7 @@ class FusedMoE(CustomOp):
gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None,
+ fused_shared_output: bool = False,
):
super().__init__()
@@ -483,6 +488,8 @@ class FusedMoE(CustomOp):
(expert_mask == 0) | (expert_mask == 1)
), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s."
+ self.hidden_size = hidden_size
+ self.num_experts = num_experts
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
@@ -526,16 +533,18 @@ class FusedMoE(CustomOp):
# Round up hidden size before creating moe_config.
# This way moe_config is created with the correct hidden_size from the start.
+ unpadded_hidden_size = hidden_size
+ self.model_type = (
+ self.vllm_config.model_config.hf_config.model_type
+ if self.vllm_config.model_config is not None
+ else None
+ )
hidden_size = maybe_roundup_hidden_size(
hidden_size=hidden_size,
act_dtype=moe_in_dtype,
moe_parallel_config=self.moe_parallel_config,
is_lora_enabled=vllm_config.lora_config is not None,
- model_type=(
- self.vllm_config.model_config.hf_config.model_type
- if self.vllm_config.model_config is not None
- else None
- ),
+ model_type=self.model_type,
is_mxfp4_quant=(
quant_config is not None and quant_config.is_mxfp4_quant(prefix, self)
),
@@ -581,14 +590,27 @@ class FusedMoE(CustomOp):
"""
quant_method = None
if self.quant_config is not None:
+ self.opt_level = 0
quant_method = self.quant_config.get_quant_method(self, prefix)
if quant_method is None:
- quant_method = UnquantizedFusedMoEMethod(self.moe_config)
+                from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+                    CompressedTensorsL1OptMoEMethod,
+                    CompressedTensorsL2OptMoEMethod,
+                )
+ if self.opt_level == 1:
+ quant_method = CompressedTensorsL1OptMoEMethod(self.moe_config)
+ elif self.opt_level == 2:
+ quant_method = CompressedTensorsL2OptMoEMethod(self.moe_config)
+ else:
+ quant_method = UnquantizedFusedMoEMethod(self.moe_config)
assert isinstance(quant_method, FusedMoEMethodBase)
return quant_method
# Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first.
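+        # VLLM_MOE_OPT_LEVEL selects on-the-fly MoE weight quantization:
+        # 0 disables it; 1 and 2 pick the L1/L2 optimized methods above.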
+ self.opt_level = envs.VLLM_MOE_OPT_LEVEL
+ if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, prefix):
+ self.opt_flag = False
+            logger.info("Excluding layer %s from optimization", prefix)
+
self.quant_method: FusedMoEMethodBase = _get_quant_method()
if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike():
@@ -611,6 +633,7 @@ class FusedMoE(CustomOp):
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
+ "unpadded_hidden_size": unpadded_hidden_size,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
@@ -625,6 +648,7 @@ class FusedMoE(CustomOp):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
+ self.base_quant_method = self.quant_method
# Disable shared expert overlap if:
# - we are using eplb with non-default backend, because of correctness issues
@@ -638,7 +662,10 @@ class FusedMoE(CustomOp):
)
and self._shared_experts is not None
)
-
+        if fused_shared_output:
+            assert not self.use_ep, (
+                "Fused shared output is only supported when EP is disabled."
+            )
+            assert shared_experts is not None, (
+                "Shared experts must be provided when fused_shared_output is True."
+            )
+        self.fused_shared_output = fused_shared_output
self.runner = self._init_runner()
def _init_runner(self):
@@ -655,6 +682,7 @@ class FusedMoE(CustomOp):
quant_method=self.quant_method,
reduce_results=self.reduce_results,
enable_dbo=self.vllm_config.parallel_config.enable_dbo,
+ fused_shared_output=self.fused_shared_output,
)
# TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py
@@ -681,7 +709,7 @@ class FusedMoE(CustomOp):
# routing_tables only needed for round-robin expert placement with
# DeepEP all2all backend.
routing_tables = self._maybe_init_expert_routing_tables()
- prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
+ prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize(
routing_tables=routing_tables
)
if prepare_finalize is not None:
@@ -691,7 +719,7 @@ class FusedMoE(CustomOp):
self._replace_quant_method(
FusedMoEModularMethod.make(
self,
- self.quant_method,
+ self.base_quant_method,
prepare_finalize,
self.shared_experts,
inplace=not self.moe_config.disable_inplace,
@@ -959,11 +987,7 @@ class FusedMoE(CustomOp):
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
- try:
- expert_data.copy_(loaded_weight)
- except Exception as e:
- print(expert_data.shape, expert_data.dtype, loaded_weight.shape, loaded_weight.dtype)
- raise e
+ expert_data.copy_(loaded_weight)
def _load_w2(
self,
@@ -976,7 +1000,7 @@ class FusedMoE(CustomOp):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
- shard_size = expert_data.shape[shard_dim]
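+        # Derive the shard size from the checkpoint weight rather than the
+        # (possibly padded) parameter, and load into the leading slice below.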
+ shard_size = loaded_weight.shape[shard_dim] // self.tp_size
# Only narrow if the loaded_weight is not a scalar (0-dim tensor)
# and we're not loading the full weight
if not load_full and loaded_weight.ndim > 0:
@@ -984,7 +1008,55 @@ class FusedMoE(CustomOp):
shard_dim, shard_size * tp_rank, shard_size
)
# w2, down_proj: Load into only logical weight of w2.
- expert_data.copy_(loaded_weight)
+ expert_data.narrow(shard_dim, 0, shard_size).copy_(loaded_weight)
+
+ def _load_model_opt_weight_or_group_weight_scale(self,
+ shard_dim: int,
+ shard_dim_scale: int,
+ expert_data: torch.Tensor,
+ scale_data: torch.Tensor,
+ shard_id: str,
+ loaded_weight: torch.Tensor,
+ tp_rank: int,
+ opt_level: int,
+ load_full_w2: bool = False):
+ """
+        Quantize a checkpoint weight on the fly according to the requested
+        optimization level and load the quantized weight and its scale into
+        the expert parameters.
+        :param shard_dim: dimension along which the weight is sharded
+        :param shard_dim_scale: dimension along which the weight scale is sharded
+        :param expert_data: weight parameter for a particular expert
+        :param scale_data: weight-scale parameter for a particular expert
+        :param shard_id: either w1, w2, or w3
+        :param loaded_weight: checkpoint weight to load into the param
+        :param tp_rank: tensor parallel rank
+        :param opt_level: 1 applies weight_quant_l1, 2 applies weight_quant_l2
+        :param load_full_w2: whether or not the loaded w2 should be sharded
+ """
+
+ assert opt_level in [1, 2]
+ if opt_level == 1:
+ weight, scale = weight_quant_l1(loaded_weight)
+ else:
+ weight, scale = weight_quant_l2(loaded_weight)
+ scale = scale.view(1, -1)
+
+ if shard_id == "w2":
+ # In the case where we have actorder/g_idx, we do not partition the
+ # w2 scales, as indicated by `load_full` argument, for all tp cases
+ self._load_w2(shard_dim=shard_dim,
+ loaded_weight=weight,
+ expert_data=expert_data,
+ tp_rank=tp_rank,
+ load_full=load_full_w2)
+ scale_data.copy_(scale)
+ elif shard_id in ("w1", "w3"):
+ self._load_w13(shard_id=shard_id,
+ shard_dim=shard_dim,
+ loaded_weight=weight,
+ expert_data=expert_data,
+ tp_rank=tp_rank)
+ self._load_w13(shard_id=shard_id,
+ shard_dim=shard_dim_scale,
+ loaded_weight=scale,
+ expert_data=scale_data,
+ tp_rank=tp_rank)
def _load_single_value(
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
@@ -1147,7 +1219,6 @@ class FusedMoE(CustomOp):
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = int(not shard_dim)
-
shard_dim_force = getattr(param, "shard_dim", None)
shard_dim = shard_dim_force if shard_dim_force is not None else shard_dim
@@ -1309,13 +1380,28 @@ class FusedMoE(CustomOp):
# Case model weights
if "weight" in weight_name:
- self._load_model_weight_or_group_weight_scale(
- shard_id=shard_id,
- shard_dim=shard_dim,
- loaded_weight=loaded_weight,
- expert_data=expert_data,
- tp_rank=self.tp_rank,
- )
+ if self.opt_level != 0:
+                scale_name = weight_name.split(".")[-1] + "_scale"
+                params_dict = dict(self.named_parameters())
+                scale_param = params_dict[scale_name]
+                shard_dim_scale = getattr(scale_param, "shard_dim", None)
+                scale_expert_data = (
+                    scale_param.data if full_load else scale_param.data[expert_id]
+                )
+ self._load_model_opt_weight_or_group_weight_scale(
+ shard_id=shard_id,
+ shard_dim=shard_dim,
+ shard_dim_scale=shard_dim_scale,
+ loaded_weight=loaded_weight,
+ expert_data=expert_data,
+ scale_data=scale_expert_data,
+ opt_level=self.opt_level,
+ tp_rank=self.tp_rank)
+ else:
+ self._load_model_weight_or_group_weight_scale(
+ shard_id=shard_id,
+ shard_dim=shard_dim,
+ loaded_weight=loaded_weight,
+ expert_data=expert_data,
+ tp_rank=self.tp_rank)
return True if return_success else None
return False if return_success else None
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 9a5f4a2..7b49282 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
+ RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache,
@@ -56,25 +57,25 @@ logger = init_logger(__name__)
# MoE kernel implementations.
#
# The following main classes are defined:
-# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE
+# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE
# inputs (e.g. quantization, distribution) and finalization of Moe outputs.
# The prepare method must take care of any needed quantization and the
-# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method,
+# finalize method, informed by the FusedMoEExpertsModular method,
# may apply weights and/or do the final reduction of the output.
-# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused
+# * FusedMoEExpertsModular - an abstract base class for the main fused
# MoE operation, i.e matmul + act_mul + optionally quant + matmul.
-# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do
+# Some FusedMoEExpertsModular implementations may choose to do
# the weight application and/or reduction. The class communicates this
# to [Finalize] via a TopKWeightAndReduce object.
# * FusedMoEModularKernel - an interface class that combines a
-# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to
+# FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to
# provide the standard fused MoE kernel interface.
# * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen
-# by the FusedMoEPermuteExpertsUnpermute implementation that is passed
+# by the FusedMoEExpertsModular implementation that is passed
# on to [Finalize].
#
# [Quantize-Prepare] and [Finalize] functionality are bundled into a single
-# class `FusedMoEPrepareAndFinalize` since they could use collective
+# class `FusedMoEPrepareAndFinalizeModular` since they could use collective
# communication mechanisms that need to be consistent.
#
@@ -155,25 +156,96 @@ PrepareResultType = tuple[
torch.Tensor | None,
]
+#
+# PrepareMonolithicResultType is a tuple of:
+# - quantized + dispatched a1.
+# - quantized + dispatched a1_scales.
+# - dispatched router logits.
+#
+# See `FusedMoEPrepareAndFinalizeMonolithic.prepare` below.
+#
+PrepareMonolithicResultType = tuple[
+ torch.Tensor,
+ torch.Tensor | None,
+ torch.Tensor,
+]
+
ReceiverType = Callable[[], PrepareResultType]
+################################################################################
+# Prepare/Finalize
+################################################################################
+
-# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
described above.
+
+ There are two variants of this class:
+    * FusedMoEPrepareAndFinalizeModular - operates on topk ids and weights
+    * FusedMoEPrepareAndFinalizeMonolithic - operates on router_logits
"""
- def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"):
+ def post_init_setup(self, fused_experts: "FusedMoEExperts"):
"""
- Initialize FusedMoEPrepareAndFinalize settings that depend on
- FusedMoEPermuteExpertsUnpermute experts object.
- The FusedMoEPrepareAndFinalize implementations that have such
+ Initialize FusedMoEPrepareAndFinalizeModular settings that depend on
+ FusedMoEExpertsModular experts object.
+ The FusedMoEPrepareAndFinalizeModular implementations that have such
dependencies may choose to override this function.
"""
return
+ @property
+ @abstractmethod
+ def activation_format(self) -> FusedMoEActivationFormat:
+ """
+ A property indicating the output format of the activations for the
+ 'prepare' method.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ """
+ The PrepareFinalize All2All implementations generally constrain the
+ dtype of the topk_ids they support. This function returns the
+ required topk indices dtype so it can be respected.
+ Return None if there are no such restrictions.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def max_num_tokens_per_rank(self) -> int | None:
+ """
+ Some PrepareFinalize All2All implementations are batched, meaning
+ they can process only a set of tokens at a time. This
+ function returns the batch size, i.e. the maximum number of tokens
+ the implementation can process at a time.
+ Return None if there are no such restrictions.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def num_dispatchers(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def output_is_reduced(self) -> bool:
+ """
+ Indicates whether or not the output of finalize is reduced across all
+ ranks.
+ """
+ raise NotImplementedError
+
+
+# TODO: pass FusedMoEParallelConfig in as ctor parameter?
+class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize):
+ """
+ An abstract base class for the [Quantize-Prepare] and [Finalize] steps
+ described above for the Modular case.
+ """
+
@abstractmethod
def prepare(
self,
@@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC):
activations, before quantization + dispatching.
- quant_config: Quantization info provided by the fused experts.
- defer_input_quant: Runtime parameter indicating whether or not to
- defer input quantization to the FusedMoEPermuteExpertsUnpermute
+ defer input quantization to the FusedMoEExpertsModular
in cases where the compute kernel expects unquantized inputs
Returns a tuple of:
@@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC):
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
- defer_input_quant: Runtime parameter indicating whether or not to
- defer input quantization to the FusedMoEPermuteExpertsUnpermute
+ defer input quantization to the FusedMoEExpertsModular
in cases where the compute kernel expects unquantized inputs
Returns a callback or a hook callback pair that when invoked waits for
@@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
- @property
+
+class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize):
+ """
+ An abstract base class for the [Quantize-Prepare] and [Finalize] steps
+ described above for the monolithic case.
+ """
+
@abstractmethod
- def activation_format(self) -> FusedMoEActivationFormat:
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ router_logits: torch.Tensor,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> PrepareMonolithicResultType:
"""
- A property indicating the output format of the activations for the
- 'prepare' method.
+ Optional method for subclasses compatible with
+ FusedMoEExpertsMonolithic kernels.
+
+ Perform any quantization (and/or) dispatching needed for this kernel.
+ - a1: The (unquantized) input to the MoE layer.
+ - router_logits: The router logits for the MoE layer.
+ - quant_config: Quantization info provided by the fused experts.
+ - defer_input_quant: Runtime parameter indicating whether or not to
+ defer input quantization to the FusedMoEExpertsMonolithic
+
+ Returns a tuple of:
+ - quantized + dispatched a.
+ - Optional quantized + dispatched a1_scales.
+ - dispatched router logits.
"""
raise NotImplementedError
@abstractmethod
- def topk_indices_dtype(self) -> torch.dtype | None:
+ def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor:
"""
- The PrepareFinalize All2All implementations generally constrain the
- dtype of the topk_ids they support. This function returns the
- required topk indices dtype so it can be respected.
- Return None if there are no such restrictions.
+ Optional method for subclasses compatible with
+ FusedMoEExpertsMonolithic kernels.
+
+ Perform any combine, weight application, and reduction on the
+ fused experts output.
+ - fused_expert_output: The unweighted, unreduced output of the fused
+ experts; it will have (M, topk, K) shape.
"""
raise NotImplementedError
- @abstractmethod
- def max_num_tokens_per_rank(self) -> int | None:
- """
- Some PrepareFinalize All2All implementations are batched. Meaning,
- they can process only as set of tokens at a time. This
- function returns the batch size i.e the maximum number of tokens
- the implementation can process at a time.
- Return None if there are no such restrictions.
- """
- raise NotImplementedError
- @abstractmethod
- def num_dispatchers(self) -> int:
- raise NotImplementedError
-
- @abstractmethod
- def output_is_reduced(self) -> bool:
- """
- Indicates whether or not the output of finalize is reduced across all
- ranks.
- """
- raise NotImplementedError
+################################################################################
+# Experts
+################################################################################
# TODO: add supported activations method (return string)
-class FusedMoEPermuteExpertsUnpermute(ABC):
- """
- An abstract base class for the [Permute-Experts-Unpermute] step described
- above.
- """
-
+class FusedMoEExperts(ABC):
def __init__(
self,
moe_config: FusedMoEConfig,
@@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
+ @staticmethod
+ def is_monolithic() -> bool:
+ raise NotImplementedError("Implemented by subclasses.")
+
@property
def expects_unquantized_inputs(self) -> bool:
"""
@@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
- def moe_problem_size(
- self,
- a1: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_ids: torch.Tensor,
- ) -> tuple[int, int, int, int, int]:
- """
- Extract the MoE problem size from the given tensor arguments:
- - a: The hidden states, input to the MoE layer.
- - w1: The first set of expert weights.
- - w2: The second set of expert weights.
- - topk_ids: The topk ids.
-
- Note: extracting the problem shape from the weight and activation
- tensors is not obvious. It needs to be done this way specifically
- due to subtle issues with particular kernels, e.g. the int4 kernels
- divide the trailing dimension by two, so it's not "correct" to
- extract N or K from the trailing dimension of w1 or w2. Similarly,
- some kernels transpose the weights, so this needs to be kept in mind.
-
- Note: This implementation covers most cases. However, if experts
- require a specialized implementation, like MarlinExperts, they are free
- to override this function.
- """
- assert w1.dim() == 3 and w2.dim() == 3
- E, N, _ = w1.size()
- K = a1.size(-1)
-
- if a1.dim() == 2:
- # Make sure we are using the correct a1 (pre-permute).
- assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
- M = a1.size(0)
- else:
- assert a1.dim() == 3
- assert a1.size(0) == E, f"{a1.size(0)} == {E}"
- M = a1.size(1) # This is max_num_tokens
-
- assert topk_ids.dim() == 2
- topk = topk_ids.size(1)
-
- return E, M, N, K, topk
-
#
# Various helpers for registering support for various features.
# Used by the oracle to select a particular kernel for a deployment.
@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@staticmethod
def is_supported_config(
- cls: type["FusedMoEPermuteExpertsUnpermute"],
+ cls: type["FusedMoEExperts"],
moe_config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
@@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
return False, _make_reason(
f"parallel config {moe_config.moe_parallel_config}"
)
+ elif not cls._supports_routing_method(
+ moe_config.routing_method, weight_key, activation_key
+ ):
+ return False, _make_reason(f"routing method {moe_config.routing_method}")
+ elif not cls._supports_router_logits_dtype(
+ moe_config.router_logits_dtype,
+ moe_config.routing_method,
+ ):
+ return False, _make_reason(
+ f"router logits dtype {moe_config.router_logits_dtype}"
+ )
+ elif not cls._supports_shape(moe_config.hidden_dim):
+ return False, _make_reason(
+ f"{moe_config.hidden_dim} hidden dim is not supported"
+ )
elif activation_format != cls.activation_format():
return False, _make_reason(f"{activation_format.value} activation format")
return True, None
@@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@abstractmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
"""
- Whether the kernel supports deployment in expert parallel.
+ Whether the kernel supports deployment in a particular parallel config.
+
+ Can be overridden if a kernel does not support EP, SP, or some other
+ configuration.
"""
raise NotImplementedError
+ @staticmethod
+ def _supports_routing_method(
+ routing_method: RoutingMethodType,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ """
+ Whether the kernel supports a routing method (e.g. GroupedTopK).
+
+ Can be overridden by monolithic kernels that execute the router
+ in addition to the experts if certain routers are not supported.
+ """
+ return True
+
+ @staticmethod
+ def _supports_router_logits_dtype(
+ router_logits_dtype: torch.dtype | None,
+ routing_method: RoutingMethodType,
+ ) -> bool:
+ """
+ Whether a kernel supports a particular dtype for router logits input.
+
+ Can be overridden by monolithic kernels that execute the router
+ in addition to the experts if certain dtypes are not supported.
+ """
+ return True
+
+ @staticmethod
+ def _supports_shape(hidden_dim: int) -> bool:
+ """
+ Whether a kernel supports a particular shape. Can be overridden if a kernel
+ has specific shape requirements.
+ """
+ return True
+
#
# Various helpers for accessing quantization parameters from the
# quant_config.
@@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
return False
+ def enable_chunking(self):
+ return (
+ envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
+ )
+
+
+class FusedMoEExpertsModular(FusedMoEExperts):
+ """
+ An abstract base class for the [Permute-Experts-Unpermute] step described
+ above.
+ """
+
+ @staticmethod
+ def is_monolithic() -> bool:
+ return False
+
+ def moe_problem_size(
+ self,
+ a1: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_ids: torch.Tensor,
+ ) -> tuple[int, int, int, int, int]:
+ """
+ Extract the MoE problem size from the given tensor arguments:
+ - a: The hidden states, input to the MoE layer.
+ - w1: The first set of expert weights.
+ - w2: The second set of expert weights.
+ - topk_ids: The topk ids.
+
+ Note: extracting the problem shape from the weight and activation
+ tensors is not obvious. It needs to be done this way specifically
+ due to subtle issues with particular kernels, e.g. the int4 kernels
+ divide the trailing dimension by two, so it's not "correct" to
+ extract N or K from the trailing dimension of w1 or w2. Similarly,
+ some kernels transpose the weights, so this needs to be kept in mind.
+
+ Note: This implementation covers most cases. However, if experts
+ require a specialized implementation, like MarlinExperts, they are free
+ to override this function.
+ """
+ assert w1.dim() == 3 and w2.dim() == 3
+ E, N, _ = w1.size()
+ K = a1.size(-1)
+
+ if a1.dim() == 2:
+ # Make sure we are using the correct a1 (pre-permute).
+ assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
+ M = a1.size(0)
+ else:
+ assert a1.dim() == 3
+ assert a1.size(0) == E, f"{a1.size(0)} == {E}"
+ M = a1.size(1) # This is max_num_tokens
+
+ assert topk_ids.dim() == 2
+ topk = topk_ids.size(1)
+
+ return E, M, N, K, topk
+
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
"""
Workspace type: The dtype to use for the workspace tensors.
@@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
) -> None:
apply_moe_activation(activation, output, input)
- def enable_chunking(self):
- return (
- envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking()
- )
-
+ @abstractmethod
def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce:
raise NotImplementedError
@@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
raise NotImplementedError
+class FusedMoEExpertsMonolithic(FusedMoEExperts):
+ """
+ An abstract base class for the [Permute-Experts-Unpermute] step described
+ above, but with the monolithic interface (accepts router logits
+ rather than topk ids and weights).
+ """
+
+ @staticmethod
+ def _supports_routing_method(
+ routing_method: RoutingMethodType,
+ weight_key: QuantKey | None,
+ activation_key: QuantKey | None,
+ ) -> bool:
+ """
+ Whether the kernel supports a routing method (e.g. GroupedTopK).
+
+ Monolithic kernels must explicitly opt in to the routing methods they support.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def _supports_router_logits_dtype(
+ router_logits_dtype: torch.dtype | None,
+ routing_method: RoutingMethodType,
+ ) -> bool:
+ """
+ Whether the kernel supports a dtype for router logits.
+
+ Monolithic kernels must explicitly opt in to the router logits dtypes they support.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def is_monolithic() -> bool:
+ return True
+
+ def apply(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ router_logits: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ a1q_scale: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ # grouped topk + fused topk bias parameters
+ num_expert_group: int | None = None,
+ e_score_correction_bias: torch.Tensor | None = None,
+ routed_scaling_factor: float | None = None,
+ topk_group: int | None = None,
+ ) -> torch.Tensor:
+ """
+ Same as the modular apply(), except it uses router_logits
+ instead of topk_ids and topk_weights. This is useful for kernels
+ with a fused router and fused experts (e.g. FLASHINFER_TRTLLM).
+ """
+ raise NotImplementedError
+
+
def _slice_scales(
scales: torch.Tensor | None, start: int, end: int
) -> torch.Tensor | None:
@@ -802,75 +1006,32 @@ def _slice_scales(
return None
+################################################################################
+# Kernel
+################################################################################
+
+
@final
-class FusedMoEModularKernel(torch.nn.Module):
- """
- This class combines a FusedMoEPrepareAndFinalize instance and
- a FusedMoEPermuteExpertsUnpermute to provide an interface that
- is compatible with the `fused_experts` function in fused_moe.py.
-
- It takes care of managing any required scratch space.
-
- Note: Instances of this class should only be used for a single model
- layer due to any layer specific state that may be used by the component
- objects.
- """
-
+class FusedMoEKernelModularImpl:
def __init__(
self,
- prepare_finalize: FusedMoEPrepareAndFinalize,
- fused_experts: FusedMoEPermuteExpertsUnpermute,
+ prepare_finalize: FusedMoEPrepareAndFinalizeModular,
+ fused_experts: FusedMoEExpertsModular,
shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None,
inplace: bool = False,
):
- super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
+ self.moe_parallel_config = moe_parallel_config
self.inplace = inplace
-
- # prefer an explicit FusedMoEParallelConfig when available (from
- # FusedMoE layers / tests).
- # if not provided, assume this kernel is
- # running in a non-DP+EP context
- self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config
self.is_dp_ep = (
moe_parallel_config is not None
and moe_parallel_config.dp_size > 1
and moe_parallel_config.use_ep
)
- self._post_init_setup()
- assert (
- prepare_finalize.activation_format == fused_experts.activation_format()
- ), (
- f"{prepare_finalize.__class__.__name__}."
- f"{prepare_finalize.activation_format} == "
- f"{fused_experts.__class__.__name__}."
- f"{fused_experts.activation_format()}"
- )
-
- def _post_init_setup(self):
- """
- Resolve any leftover setup dependencies between self.prepare_finalize
- and self.fused_experts here.
- """
- self.prepare_finalize.post_init_setup(self.fused_experts)
-
- def supports_expert_map(self) -> bool:
- """
- A flag indicating whether or not this class supports expert maps.
- """
- return self.fused_experts.supports_expert_map()
-
- def output_is_reduced(self) -> bool:
- """
- Indicates whether or not the output of fused MoE kernel
- is reduced across all ranks.
- """
- return self.prepare_finalize.output_is_reduced()
-
def _chunk_info(self, M: int) -> tuple[int, int]:
"""
Compute number of chunks and chunk size for given M.
@@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace_dtype = self.fused_experts.workspace_dtype(out_dtype)
# Force worst-case allocation in profiling run for
- # "mk.FusedMoEModularKernel.Standard" formats where this is only bounded
+ # "mk.FusedMoEKernel.Standard" formats where this is only bounded
# by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with
# DP+EP due to the random token routing.
is_profile_run = (
@@ -1172,9 +1333,9 @@ class FusedMoEModularKernel(torch.nn.Module):
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
# incompatible all2all kernels like the DeepEP high-throughput
- # kernels. CUDAGraph compatible all2all kernels like the pplx
- # kernels and the DeepEP low-latency kernels are always batched
- # and can never run into the tensor.numel() == 0 case.
+ # kernels. CUDAGraph compatible all2all kernels like the DeepEP
+ # low-latency kernels are always batched and can never run into
+ # the tensor.numel() == 0 case.
if M_full == 0:
assert num_chunks == 0
workspace13 = None
@@ -1313,19 +1474,18 @@ class FusedMoEModularKernel(torch.nn.Module):
assert shared_output is not None
return shared_output, output
- def forward(
+ def apply(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+ topk_weights: torch.Tensor,
activation: MoEActivation = MoEActivation.SILU,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
shared_experts_input: torch.Tensor | None = None,
- **kwargs
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -1335,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module):
- hidden_states: (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- - topk_weights (torch.Tensor): The topk weights applied at the end of
- the layer.
+ - topk_weights (torch.Tensor): The topk weights applied at the end of the layer.
- topk_ids (torch.Tensor): A map of row to expert id.
- activation (MoEActivation): The activation function to apply after the first
MoE layer.
@@ -1355,23 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module):
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
- from .fused_moe import fused_experts as fused_experts_kernel
-
- result = fused_experts_kernel(
- hidden_states=hidden_states,
- w1=w1,
- w2=w2,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- inplace=True,
- activation=activation,
- quant_config=kwargs.get("quant_config", None),
- apply_router_weight_on_input=apply_router_weight_on_input,
- global_num_experts=global_num_experts,
- expert_map=expert_map,
- )
-
- return result
if self.inplace:
assert self.shared_experts is None
assert not disable_inplace()
@@ -1417,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input,
shared_experts_input=shared_experts_input,
)
+
+
+@final
+class FusedMoEKernelMonolithicImpl:
+ def __init__(
+ self,
+ prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic,
+ fused_experts: FusedMoEExpertsMonolithic,
+ ):
+ self.prepare_finalize = prepare_finalize
+ self.fused_experts = fused_experts
+
+ def apply(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ router_logits: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ # grouped topk + fused topk bias parameters
+ num_expert_group: int | None = None,
+ e_score_correction_bias: torch.Tensor | None = None,
+ routed_scaling_factor: float | None = None,
+ topk_group: int | None = None,
+ ) -> torch.Tensor:
+ """
+ Same as the modular apply(), except it uses router_logits
+ instead of topk_ids and topk_weights. This is used for kernels
+ that have a fused router + experts (e.g. FLASHINFER_TRTLLM).
+ """
+
+ # TODO(rob): add inplace support.
+ a1q, a1q_scale, router_logits = self.prepare_finalize.prepare(
+ hidden_states,
+ router_logits=router_logits,
+ quant_config=self.fused_experts.quant_config,
+ defer_input_quant=self.fused_experts.expects_unquantized_inputs,
+ )
+
+ fused_out = self.fused_experts.apply(
+ hidden_states=a1q,
+ w1=w1,
+ w2=w2,
+ router_logits=router_logits,
+ activation=activation,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ a1q_scale=a1q_scale,
+ # grouped topk + fused topk bias parameters
+ num_expert_group=num_expert_group,
+ e_score_correction_bias=e_score_correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
+ topk_group=topk_group,
+ )
+
+ output = self.prepare_finalize.finalize(fused_out)
+
+ return output
+
+
+@final
+class FusedMoEKernel:
+ def __init__(
+ self,
+ prepare_finalize: FusedMoEPrepareAndFinalize,
+ fused_experts: FusedMoEExperts,
+ shared_experts: torch.nn.Module | None = None,
+ moe_parallel_config: FusedMoEParallelConfig | None = None,
+ inplace: bool = False,
+ ):
+ super().__init__()
+ self.shared_experts = shared_experts # NOTE: check if we can remove
+
+ # Initialize the implementation (monolithic or modular).
+ self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl
+ if isinstance(
+ prepare_finalize, FusedMoEPrepareAndFinalizeModular
+ ) and isinstance(fused_experts, FusedMoEExpertsModular):
+ self.impl = FusedMoEKernelModularImpl(
+ prepare_finalize,
+ fused_experts,
+ shared_experts,
+ moe_parallel_config,
+ inplace,
+ )
+
+ elif isinstance(
+ prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic
+ ) and isinstance(fused_experts, FusedMoEExpertsMonolithic):
+ assert shared_experts is None
+ assert not inplace
+ self.impl = FusedMoEKernelMonolithicImpl(
+ prepare_finalize,
+ fused_experts,
+ )
+
+ else:
+ raise ValueError(
+ "prepare_finalize and fused_experts must both be either monolithic "
+ f"or non-monolithic but got {prepare_finalize.__class__.__name__} "
+ f"and {fused_experts.__class__.__name__}"
+ )
+
+ self._post_init_setup()
+
+ @property
+ def is_monolithic(self) -> bool:
+ return isinstance(self.impl, FusedMoEKernelMonolithicImpl)
+
+ @property
+ def prepare_finalize(self) -> FusedMoEPrepareAndFinalize:
+ return self.impl.prepare_finalize
+
+ @property
+ def fused_experts(self) -> FusedMoEExperts:
+ return self.impl.fused_experts
+
+ def _post_init_setup(self):
+ """
+ Resolve any leftover setup dependencies between self.prepare_finalize
+ and self.fused_experts here.
+ """
+ self.prepare_finalize.post_init_setup(self.impl.fused_experts)
+ assert (
+ self.prepare_finalize.activation_format
+ == self.fused_experts.activation_format()
+ )
+
+ def supports_expert_map(self) -> bool:
+ """
+ A flag indicating whether or not this class supports expert maps.
+ """
+ return self.fused_experts.supports_expert_map()
+
+ def output_is_reduced(self) -> bool:
+ """
+ Indicates whether or not the output of fused MoE kernel
+ is reduced across all ranks.
+ """
+ return self.prepare_finalize.output_is_reduced()
+
+ def apply_monolithic(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ # grouped topk + fused topk bias parameters
+ num_expert_group: int | None = None,
+ e_score_correction_bias: torch.Tensor | None = None,
+ routed_scaling_factor: float | None = None,
+ topk_group: int | None = None,
+ ) -> torch.Tensor:
+ assert isinstance(self.impl, FusedMoEKernelMonolithicImpl)
+ return self.impl.apply(
+ hidden_states=hidden_states,
+ w1=w1,
+ w2=w2,
+ router_logits=router_logits,
+ activation=activation,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ num_expert_group=num_expert_group,
+ e_score_correction_bias=e_score_correction_bias,
+ routed_scaling_factor=routed_scaling_factor,
+ topk_group=topk_group,
+ )
+
+ def apply(
+ self,
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ activation: MoEActivation,
+ global_num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ shared_experts_input: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ assert isinstance(self.impl, FusedMoEKernelModularImpl)
+ return self.impl.apply(
+ hidden_states=hidden_states,
+ w1=w1,
+ w2=w2,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ activation=activation,
+ global_num_experts=global_num_experts,
+ expert_map=expert_map,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ shared_experts_input=shared_experts_input,
+ )
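
As a standalone sketch (not vLLM code), the pairing check at the heart of the new FusedMoEKernel facade introduced above looks roughly like this; every class name below is an illustrative stand-in for the real FusedMoEPrepareAndFinalize* and FusedMoEExperts* hierarchy:

from abc import ABC


class PrepareFinalize(ABC):
    """Stand-in for FusedMoEPrepareAndFinalize."""


class PrepareFinalizeModular(PrepareFinalize):
    """Stand-in for the modular variant (topk ids and weights)."""


class PrepareFinalizeMonolithic(PrepareFinalize):
    """Stand-in for the monolithic variant (router logits)."""


class Experts(ABC):
    @staticmethod
    def is_monolithic() -> bool:
        raise NotImplementedError


class ExpertsModular(Experts):
    @staticmethod
    def is_monolithic() -> bool:
        return False


class ExpertsMonolithic(Experts):
    @staticmethod
    def is_monolithic() -> bool:
        return True


def select_impl(pf: PrepareFinalize, experts: Experts) -> str:
    # Both halves must use the same interface; mixed pairs are rejected,
    # mirroring the ValueError raised by FusedMoEKernel.__init__.
    if isinstance(pf, PrepareFinalizeModular) and not experts.is_monolithic():
        return "modular"
    if isinstance(pf, PrepareFinalizeMonolithic) and experts.is_monolithic():
        return "monolithic"
    raise ValueError("prepare_finalize and fused_experts must both be "
                     "modular or both monolithic")


assert select_impl(PrepareFinalizeModular(), ExpertsModular()) == "modular"
assert select_impl(PrepareFinalizeMonolithic(), ExpertsMonolithic()) == "monolithic"
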
diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
index dc0f32d..164605d 100644
--- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
-class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
"""
Prepare/Finalize using MoRI kernels.
"""
diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
index 6f961df..0ed159b 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py
@@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import (
fp8_w8a8_moe_quant_config,
fp8_w8a16_moe_quant_config,
)
-from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
- is_supported_config_trtllm_fp8,
-)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
get_flashinfer_moe_backend,
- make_fp8_moe_alpha_scales_for_fi,
prepare_fp8_moe_layer_for_fi,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@@ -103,9 +99,13 @@ def _get_priority_backends(
def backend_to_kernel_cls(
backend: Fp8MoeBackend,
-) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
+) -> type[mk.FusedMoEExperts]:
if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- raise NotImplementedError
+ from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501
+ TrtLlmFp8Experts,
+ )
+
+ return TrtLlmFp8Experts
elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
@@ -205,13 +205,11 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
allow_vllm_cutlass: bool = False,
-) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
+) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
- k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None
-
if config.is_lora_enabled:
return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON)
@@ -252,7 +250,7 @@ def select_fp8_moe_backend(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
- ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
+ ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]:
k_cls = backend_to_kernel_cls(backend)
supported, reason = k_cls.is_supported_config(
k_cls, config, weight_key, activation_key, activation_format
@@ -287,16 +285,6 @@ def select_fp8_moe_backend(
"vLLM CUTLASS FP8 MoE backend is disabled for this configuration."
)
- # Handle FLASHINFER_TRTLLM specially (no kernel class).
- if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- supported, reason = is_supported_config_trtllm_fp8(
- config, weight_key, activation_key, activation_format
- )
- if supported:
- logger.info_once(_make_log_backend(requested_backend))
- return requested_backend, None
- raise ValueError(_make_log_unsupported(requested_backend, reason))
-
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
@@ -311,51 +299,32 @@ def select_fp8_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
fi_backend = get_flashinfer_moe_backend()
-
- if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
- backend = Fp8MoeBackend.FLASHINFER_TRTLLM
- supported, reason = is_supported_config_trtllm_fp8(
- config, weight_key, activation_key, activation_format
- )
- if supported:
- logger.info_once(_make_log_backend(backend))
- return backend, None
- else:
- raise ValueError(_make_log_unsupported(backend, reason))
-
- elif fi_backend == FlashinferMoeBackend.CUTLASS:
+ if fi_backend == FlashinferMoeBackend.CUTLASS:
backend = Fp8MoeBackend.FLASHINFER_CUTLASS
- return _return_or_raise(
- backend, config, weight_key, activation_key, activation_format
- )
-
+ elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
+ backend = Fp8MoeBackend.FLASHINFER_TRTLLM
else:
- assert fi_backend == FlashinferMoeBackend.CUTEDSL
- raise ValueError("FlashInfer MaskedGEMM not supported for FP8")
-
+ raise ValueError(
+ f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE."
+ )
+ k_cls = backend_to_kernel_cls(backend)
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
else:
# If the user is not explicit about the backend, try both.
for backend in [
Fp8MoeBackend.FLASHINFER_TRTLLM,
Fp8MoeBackend.FLASHINFER_CUTLASS,
]:
- if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- k_cls = None
- supported, reason = is_supported_config_trtllm_fp8(
- config,
- weight_key,
- activation_key,
- activation_format,
- )
- else:
- k_cls = backend_to_kernel_cls(backend)
- supported, reason = k_cls.is_supported_config(
- k_cls,
- config,
- weight_key,
- activation_key,
- activation_format,
- )
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
@@ -408,23 +377,14 @@ def select_fp8_moe_backend(
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
- if backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- k_cls = None
- supported, reason = is_supported_config_trtllm_fp8(
- config,
- weight_key,
- activation_key,
- activation_format,
- )
- else:
- k_cls = backend_to_kernel_cls(backend)
- supported, reason = k_cls.is_supported_config(
- k_cls,
- config,
- weight_key,
- activation_key,
- activation_format,
- )
+ k_cls = backend_to_kernel_cls(backend)
+ supported, reason = k_cls.is_supported_config(
+ k_cls,
+ config,
+ weight_key,
+ activation_key,
+ activation_format,
+ )
if supported:
logger.info_once(_make_log_backend(backend), scope="local")
@@ -510,7 +470,7 @@ def make_fp8_moe_quant_config(
block_shape: list[int] | None = None,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
-) -> FusedMoEQuantConfig | None:
+) -> FusedMoEQuantConfig:
"""
Create FusedMoEQuantConfig for the specified FP8 Backend.
The FusedMoEQuantConfig holds the scales that are used
@@ -523,9 +483,6 @@ def make_fp8_moe_quant_config(
In a future PR, we will have this function should be
a method of the modular kernel itself.
"""
- # TRTLLM does not use Modular Kernel abstraction yet.
- if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- return None
# MARLIN is mixed precision W8A16 config.
if fp8_backend == Fp8MoeBackend.MARLIN:
@@ -539,12 +496,6 @@ def make_fp8_moe_quant_config(
# (alpha = w_scale * a_scale) and inverse a2 scale.
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None:
assert a1_scale is not None and a2_scale is not None
- g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
- w1_scale,
- a1_scale,
- w2_scale,
- a2_scale,
- )
return fp8_w8a8_moe_quant_config(
w1_scale=w1_scale,
w2_scale=w2_scale,
@@ -552,8 +503,8 @@ def make_fp8_moe_quant_config(
a2_scale=a2_scale,
a1_gscale=(1.0 / a1_scale),
a2_gscale=(1.0 / a2_scale),
- g1_alphas=g1_alphas,
- g2_alphas=g2_alphas,
+ g1_alphas=(w1_scale * a1_scale).squeeze(),
+ g2_alphas=(w2_scale * a2_scale).squeeze(),
)
# All other backends use normal config.
return fp8_w8a8_moe_quant_config(
@@ -570,17 +521,18 @@ def make_fp8_moe_quant_config(
def make_fp8_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
- experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
+ experts_cls: type[mk.FusedMoEExperts],
fp8_backend: Fp8MoeBackend,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
-) -> mk.FusedMoEModularKernel:
+) -> mk.FusedMoEKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
+ use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
)
assert prepare_finalize is not None
@@ -603,9 +555,9 @@ def make_fp8_moe_kernel(
)
# NOTE(rob): we only want the mk to control the shared_expert
- # if using all2all (for SBO). bnell is making this explict in
+ # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class.
- kernel = mk.FusedMoEModularKernel(
+ kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(
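
One detail of the fp8 oracle change above: for the per-tensor FLASHINFER_CUTLASS path, the alpha scales are now computed inline as the product of the weight and activation scales instead of via the removed make_fp8_moe_alpha_scales_for_fi helper. A minimal torch sketch of that arithmetic, using assumed toy shapes rather than the real layer shapes:

import torch

num_experts = 8
w1_scale = torch.rand(num_experts, 1, 1)  # per-expert weight scales (hypothetical shape)
a1_scale = torch.tensor(0.5)              # static per-tensor activation scale

# alpha = w_scale * a_scale, squeezed down to one value per expert
g1_alphas = (w1_scale * a1_scale).squeeze()
a1_gscale = 1.0 / a1_scale                # inverse activation scale passed alongside

print(g1_alphas.shape)  # torch.Size([8])
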
diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
index ee7db88..dd1a24d 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py
@@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import (
nvfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- is_supported_config_trtllm,
prepare_nvfp4_moe_layer_for_fi_or_cutlass,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
@@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool:
def backend_to_kernel_cls(
backend: NvFp4MoeBackend,
-) -> type[mk.FusedMoEPermuteExpertsUnpermute]:
+) -> list[type[mk.FusedMoEExperts]]:
if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- raise NotImplementedError(
- "FLASHINFER_TRTLLM doesn't support Modular Kernel Interface"
+ from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import (
+ TrtLlmNvFp4ExpertsModular,
+ TrtLlmNvFp4ExpertsMonolithic,
)
+ # NOTE: prefer Monolithic > Modular, so return Monolithic first.
+ return [
+ TrtLlmNvFp4ExpertsMonolithic,
+ TrtLlmNvFp4ExpertsModular,
+ ]
+
elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS:
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
- return FlashInferExperts
+ return [FlashInferExperts]
elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL:
from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import (
FlashInferCuteDSLExperts,
)
- return FlashInferCuteDSLExperts
+ return [FlashInferCuteDSLExperts]
elif backend == NvFp4MoeBackend.VLLM_CUTLASS:
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp4,
)
- return CutlassExpertsFp4
+ return [CutlassExpertsFp4]
elif backend == NvFp4MoeBackend.MARLIN:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
- return MarlinExperts
+ return [MarlinExperts]
else:
raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}")
@@ -125,7 +131,7 @@ def select_nvfp4_moe_backend(
config: FusedMoEConfig,
weight_key: QuantKey | None,
activation_key: QuantKey | None,
-) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
+) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
"""
Select the primary NvFP4 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
@@ -143,10 +149,7 @@ def select_nvfp4_moe_backend(
# NOTE(rob): this is kind of a hack. We need to peak into
# the prepare-finalize selection to determine if we are using
# the batched or standard expert format.
- use_batched = (
- config.moe_parallel_config.use_deepep_ll_kernels
- or config.moe_parallel_config.use_pplx_kernels
- )
+ use_batched = config.moe_parallel_config.use_deepep_ll_kernels
activation_format = (
mk.FusedMoEActivationFormat.BatchedExperts
if use_batched
@@ -178,29 +181,21 @@ def select_nvfp4_moe_backend(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
activation_format: mk.FusedMoEActivationFormat,
- ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]:
- k_cls = backend_to_kernel_cls(backend)
- supported, reason = k_cls.is_supported_config(
- k_cls, config, weight_key, activation_key, activation_format
- )
- if supported:
- logger.info_once(_make_log_backend(backend))
- return backend, k_cls
+ ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]:
+ for k_cls in backend_to_kernel_cls(backend):
+ supported, reason = k_cls.is_supported_config(
+ k_cls, config, weight_key, activation_key, activation_format
+ )
+ if supported:
+ logger.info_once(_make_log_backend(backend))
+ return backend, k_cls
+
raise ValueError(_make_log_unsupported(backend, reason))
# Handle explicit moe_backend from user.
runner_backend = config.moe_backend
if runner_backend != "auto":
requested_backend = map_nvfp4_backend(runner_backend)
- if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- supported, reason = is_supported_config_trtllm(
- config, weight_key, activation_key, activation_format
- )
- if supported:
- logger.info_once(_make_log_backend(requested_backend))
- return requested_backend, None
- raise ValueError(_make_log_unsupported(requested_backend, reason))
-
return _return_or_raise(
requested_backend, config, weight_key, activation_key, activation_format
)
@@ -213,36 +208,14 @@ def select_nvfp4_moe_backend(
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# If user is explicit about backend, validate it.
- fi_backend = get_flashinfer_moe_backend()
-
- if fi_backend == FlashinferMoeBackend.TENSORRT_LLM:
- backend = NvFp4MoeBackend.FLASHINFER_TRTLLM
- supported, reason = is_supported_config_trtllm(
- config, weight_key, activation_key, activation_format
- )
- if supported:
- logger.info_once(_make_log_backend(backend))
- return backend, None
- else:
- raise ValueError(_make_log_unsupported(backend, reason))
- else:
- backend = fi_2_vllm_backend_map[fi_backend]
- return _return_or_raise(
- backend, config, weight_key, activation_key, activation_format
- )
+ backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()]
+ return _return_or_raise(
+ backend, config, weight_key, activation_key, activation_format
+ )
else:
# If the user is not explicit about the backend, try each.
for backend in FLASHINFER_NVFP4_MOE_BACKENDS:
- if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- k_cls = None
- supported, reason = is_supported_config_trtllm(
- config,
- weight_key,
- activation_key,
- activation_format,
- )
- else:
- k_cls = backend_to_kernel_cls(backend)
+ for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
@@ -250,13 +223,13 @@ def select_nvfp4_moe_backend(
activation_key,
activation_format,
)
- if supported:
- logger.info_once(_make_log_backend(backend), scope="local")
- return backend, None
- else:
- logger.debug_once(
- _make_log_unsupported(backend, reason), scope="local"
- )
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ else:
+ logger.debug_once(
+ _make_log_unsupported(backend, reason), scope="local"
+ )
raise NotImplementedError(
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
@@ -271,16 +244,7 @@ def select_nvfp4_moe_backend(
# Select kernels in order of backend.
for backend in AVAILABLE_BACKENDS:
- if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- k_cls = None # type: ignore[assignment]
- supported, reason = is_supported_config_trtllm(
- config,
- weight_key,
- activation_key,
- activation_format,
- )
- else:
- k_cls = backend_to_kernel_cls(backend)
+ for k_cls in backend_to_kernel_cls(backend):
supported, reason = k_cls.is_supported_config(
k_cls,
config,
@@ -289,11 +253,11 @@ def select_nvfp4_moe_backend(
activation_format,
)
- if supported:
- logger.info_once(_make_log_backend(backend), scope="local")
- return backend, k_cls
- else:
- logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
+ if supported:
+ logger.info_once(_make_log_backend(backend), scope="local")
+ return backend, k_cls
+ else:
+ logger.debug_once(_make_log_unsupported(backend, reason), scope="local")
raise NotImplementedError(
"No NvFp4 MoE backend supports the deployment configuration."
@@ -401,12 +365,8 @@ def make_nvfp4_moe_quant_config(
w2_scale_2: torch.Tensor,
a13_scale: torch.Tensor,
a2_scale: torch.Tensor,
-) -> FusedMoEQuantConfig | None:
- UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM]
- if backend in UNSUPPORTED:
- return None
-
- elif backend == NvFp4MoeBackend.MARLIN:
+) -> FusedMoEQuantConfig:
+ if backend == NvFp4MoeBackend.MARLIN:
return nvfp4_w4a16_moe_quant_config(
g1_alphas=w13_scale_2,
g2_alphas=w2_scale_2,
@@ -423,22 +383,27 @@ def make_nvfp4_moe_quant_config(
a2_gscale=(1.0 / a2_scale),
w1_scale=w13_scale,
w2_scale=w2_scale,
+ # NOTE(rob): this is a hack until the MoE kernels
+ # create their own quant configs. TRTLLM kernel
+ # does not accept swizzled input quant scales.
+ is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM),
)
def make_nvfp4_moe_kernel(
moe_quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
- experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute],
+ experts_cls: type[mk.FusedMoEExperts],
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
shared_experts: torch.nn.Module | None = None,
-) -> mk.FusedMoEModularKernel:
+) -> mk.FusedMoEKernel:
# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
routing_tables=routing_tables,
allow_new_interface=True,
+ use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic),
)
assert prepare_finalize is not None
@@ -461,9 +426,9 @@ def make_nvfp4_moe_kernel(
)
# NOTE(rob): we only want the mk to control the shared_expert
- # if using all2all (for SBO). bnell is making this explict in
+ # if using all2all (for SBO). bnell is making this explicit in
# the new MoE runner class.
- kernel = mk.FusedMoEModularKernel(
+ kernel = mk.FusedMoEKernel(
prepare_finalize,
experts,
shared_experts=(
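
Since backend_to_kernel_cls now returns an ordered list of candidate kernel classes (monolithic variants first), selection walks that list and takes the first class whose is_supported_config check passes. A self-contained sketch of the pattern, with placeholder candidate classes and a placeholder support predicate:

from collections.abc import Callable


def pick_kernel(
    candidates: list[type],
    is_supported: Callable[[type], tuple[bool, str | None]],
) -> type:
    # Try candidates in priority order; keep the last rejection reason
    # for the error message, as the oracle selection loops above do.
    reason: str | None = None
    for k_cls in candidates:
        supported, reason = is_supported(k_cls)
        if supported:
            return k_cls
    raise ValueError(f"no supported kernel class: {reason}")


class MonolithicCandidate:   # preferred: fused router + experts
    pass


class ModularCandidate:      # fallback: topk ids/weights interface
    pass


# Example: the monolithic candidate is rejected (say, unsupported routing
# method), so selection falls back to the modular candidate.
chosen = pick_kernel(
    [MonolithicCandidate, ModularCandidate],
    lambda cls: (cls is ModularCandidate, "routing method not supported"),
)
assert chosen is ModularCandidate
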
diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
index 1c582bc..3d8d56e 100644
--- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
+++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py
@@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import (
is_supported_config_trtllm_bf16,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
- MoEPrepareAndFinalizeNoEP,
+ MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
swap_w13_to_w31,
@@ -209,7 +209,7 @@ def make_unquantized_moe_kernel(
backend: UnquantizedMoeBackend,
quant_config: FusedMoEQuantConfig,
moe_config: FusedMoEConfig,
-) -> mk.FusedMoEModularKernel | None:
+) -> mk.FusedMoEKernel | None:
if backend in UNSUPPORTED_BACKEND:
return None
@@ -218,8 +218,8 @@ def make_unquantized_moe_kernel(
FlashInferExperts,
)
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
+ kernel = mk.FusedMoEKernel(
+ MoEPrepareAndFinalizeNoDPEPModular(),
FlashInferExperts(
moe_config=moe_config,
quant_config=quant_config,
@@ -232,8 +232,8 @@ def make_unquantized_moe_kernel(
AiterExperts,
)
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
+ kernel = mk.FusedMoEKernel(
+ MoEPrepareAndFinalizeNoDPEPModular(),
AiterExperts(
moe_config=moe_config,
quant_config=quant_config,
@@ -241,25 +241,6 @@ def make_unquantized_moe_kernel(
inplace=not moe_config.disable_inplace,
)
elif backend == UnquantizedMoeBackend.TRITON:
- from vllm.model_executor.layers.fused_moe import TritonExperts
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- TritonExperts(
- moe_config=moe_config,
- quant_config=quant_config,
- ),
- inplace=not moe_config.disable_inplace,
- )
- elif backend == UnquantizedMoeBackend.XPU:
- from vllm.model_executor.layers.fused_moe import XPUExperts
-
- kernel = mk.FusedMoEModularKernel(
- MoEPrepareAndFinalizeNoEP(),
- XPUExperts(
- moe_config=moe_config,
- quant_config=quant_config,
- ),
- inplace=not moe_config.disable_inplace,
- )
+ from vllm.model_executor.layers.fused_moe import fused_experts
+ kernel = fused_experts
return kernel
diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
deleted file mode 100644
index 289ac0d..0000000
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ /dev/null
@@ -1,373 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable
-
-import pplx_kernels as pplx
-import torch
-
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
- TopKWeightAndReduceDelegate,
-)
-from vllm.model_executor.layers.fused_moe.utils import (
- _validate_scale_shape,
- moe_kernel_quantize_input,
-)
-from vllm.utils.math_utils import cdiv, round_up
-
-logger = init_logger(__name__)
-
-
-def pplx_hidden_dim_scale_bytes(
- max_num_tokens: int,
- hidden_dim: int,
- in_dtype: torch.dtype,
- quant_dtype: torch.dtype | str | None,
- per_act_token_quant: bool,
- block_shape: list[int] | None,
-):
- # All pplx byte sizes must be 16-byte aligned.
- align = 16
-
- # For blocked per token: set to
- # cdiv(hidden_dim, block_size) * sizeof(float32)
- # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
- if quant_dtype is not None:
- assert isinstance(quant_dtype, torch.dtype)
- assert quant_dtype.itemsize == 1
- hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
- elem_size = torch.float32.itemsize
-
- if per_act_token_quant:
- # per-token (M x 1)
- assert block_shape is None
- hidden_scale_bytes = elem_size
- elif block_shape is not None:
- # per-group (M x K_tiles)
- block_size = block_shape[1]
- num_blocks = cdiv(hidden_dim, block_size)
- hidden_scale_bytes = num_blocks * elem_size
- else:
- # per-tensor (1 x 1)
- hidden_scale_bytes = elem_size
- else:
- hidden_dim_bytes = hidden_dim * in_dtype.itemsize
- hidden_scale_bytes = 0
-
- return (
- round_up(hidden_dim_bytes, align),
- round_up(hidden_scale_bytes, align),
- )
-
-
-class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
- """PPLX-based prepare and finalize for expert parallelism."""
-
- def __init__(
- self,
- a2a: pplx.AllToAll,
- max_num_tokens: int,
- num_local_experts: int,
- num_dispatchers: int,
- ):
- super().__init__()
- assert max_num_tokens > 0
- assert num_local_experts > 0
- self.a2a = a2a
- self.max_num_tokens = max_num_tokens
- self.num_local_experts = num_local_experts
- self.num_dispatchers_ = num_dispatchers
-
- @property
- def activation_format(self) -> mk.FusedMoEActivationFormat:
- return mk.FusedMoEActivationFormat.BatchedExperts
-
- def max_num_tokens_per_rank(self) -> int | None:
- return self.max_num_tokens
-
- def topk_indices_dtype(self) -> torch.dtype | None:
- return torch.uint32
-
- def num_dispatchers(self) -> int:
- return self.num_dispatchers_
-
- def output_is_reduced(self) -> bool:
- return True
-
- def supports_async(self) -> bool:
- return True
-
- def prepare_async(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_experts: int,
- expert_map: torch.Tensor | None,
- apply_router_weight_on_input: bool,
- quant_config: FusedMoEQuantConfig,
- defer_input_quant: bool = False,
- ) -> tuple[Callable, mk.ReceiverType]:
- if defer_input_quant:
- raise NotImplementedError(
- f"{self.__class__.__name__} does not support defer_input_quant=True. "
- "Please select an MoE kernel that accepts quantized inputs."
- )
-
- num_tokens = a1.size(0) # M
- hidden_dim = a1.size(-1) # K
-
- assert topk_ids.size(0) == num_tokens
- # expert_map should be None because with expert map, -1 id is used for
- # non-local token; this causes error when casting ids to the
- # topk_indices_dtype() int32
- #
- if expert_map is not None:
- logger.warning_once(
- "The PPLX backend does not support expert mapping. "
- "The provided `expert_map` will be ignored."
- )
- expert_map = None # noqa: F841
-
- # Is this always going to be a1.device?
- device = a1.device
-
- if apply_router_weight_on_input:
- topk = topk_ids.size(1)
- # TODO: this only works for topK=1, will need to update for topK>1
- assert topk == 1, (
- "apply_router_weight_on_input is only implemented for topk=1"
- )
- a1 = a1 * topk_weights.to(a1.dtype)
-
- repeat_cols = 4
- repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
- # TODO(bnell): always pass quant_config.a1_scale?
- a1q, a1q_scale = moe_kernel_quantize_input(
- a1,
- (None if quant_config.per_act_token_quant else quant_config.a1_scale),
- quant_dtype=quant_config.quant_dtype,
- per_act_token_quant=quant_config.per_act_token_quant,
- block_shape=quant_config.block_shape,
- )
-
- _validate_scale_shape(
- a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape
- )
-
- orig_a_scale_block_shape: int | None = None
-
- if a1q_scale is not None:
- scalar_scales = a1q_scale.numel() == 1
-
- # pplx requires 2-d scales even for scalar scales
- if a1q_scale.dim() <= 1:
- assert scalar_scales
- a1q_scale = a1q_scale.view(1, 1)
-
- orig_a_scale_block_shape = a1q_scale.shape[-1]
-
- if not quant_config.is_block_quantized:
- # TODO (bnell): use group_broadcast instead?
- a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
-
- assert a1q_scale is None or a1q_scale.ndim == 2, (
- f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}"
- )
-
- expert_num_tokens = torch.empty(
- self.num_local_experts,
- dtype=torch.int32,
- device=device,
- )
-
- expert_x = torch.empty(
- (
- self.num_local_experts,
- self.max_num_tokens * self.num_dispatchers(),
- hidden_dim,
- ),
- dtype=a1q.dtype,
- device=device,
- )
-
- expert_x_scale: torch.Tensor | None = None
- if a1q.dtype.itemsize == 1:
- if quant_config.is_per_act_token:
- # (M x 1) -> (E x M x K)
- final_dim = expert_x.size(2)
- elif quant_config.is_per_tensor:
- # (1 x 1) -> (E x 1 x 1)
- final_dim = 1
- else:
- # (M x K_tiles) -> (E x M x K_tiles)
- assert quant_config.block_shape is not None
- num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1])
- final_dim = num_blocks
-
- expert_x_scale_shape = (
- self.num_local_experts,
- expert_x.size(1),
- round_up(final_dim, 4), # round up for alignment
- )
-
- expert_x_scale = torch.empty(
- expert_x_scale_shape,
- dtype=torch.float32,
- device=expert_x.device,
- )
-
- # This argument is optional, defaults to indices.size(0)
- # There's not much point setting this unless it is != indices.size(0)
- bound_m: torch.Tensor | None = None
-
- self.a2a.dispatch(
- out_expert_num_tokens=expert_num_tokens,
- out_expert_x=expert_x,
- out_expert_x_scale=expert_x_scale,
- dp_x=a1q,
- dp_x_scale=a1q_scale,
- indices=topk_ids,
- bound_m=bound_m,
- do_send=True,
- do_recv=False,
- )
-
- hook = lambda: self.a2a.dispatch(
- out_expert_num_tokens=expert_num_tokens,
- out_expert_x=expert_x,
- out_expert_x_scale=expert_x_scale,
- dp_x=a1q,
- dp_x_scale=a1q_scale,
- indices=topk_ids,
- bound_m=bound_m,
- do_send=False,
- do_recv=True,
- )
-
- return (
- hook,
- lambda: self._receiver(
- expert_num_tokens,
- expert_x,
- expert_x_scale,
- orig_a_scale_block_shape,
- ),
- )
-
- def _receiver(
- self,
- expert_num_tokens: torch.Tensor,
- expert_x: torch.Tensor,
- expert_x_scale: torch.Tensor | None,
- orig_a_scale_block_shape: int | None,
- ) -> mk.PrepareResultType:
- if expert_x_scale is not None:
- expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
- assert expert_x_scale.ndim == 3
-
- expert_tokens_meta = mk.ExpertTokensMetadata(
- expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
- )
-
- return expert_x, expert_x_scale, expert_tokens_meta, None, None
-
- def prepare(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_experts: int,
- expert_map: torch.Tensor | None,
- apply_router_weight_on_input: bool,
- quant_config: FusedMoEQuantConfig,
- defer_input_quant: bool = False,
- ) -> mk.PrepareResultType:
- hook, receiver = self.prepare_async(
- a1,
- topk_weights,
- topk_ids,
- num_experts,
- expert_map,
- apply_router_weight_on_input,
- quant_config,
- defer_input_quant=defer_input_quant,
- )
- hook()
- return receiver()
-
- def finalize_async(
- self,
- output: torch.Tensor,
- fused_expert_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- weight_and_reduce_impl: mk.TopKWeightAndReduce,
- ) -> Callable:
- assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), (
- "Weight application and reduction happens in the combine kernel."
- )
-
- # This argument is optional
- # There's not much point setting this unless it is != topk_ids.size(0)
- bound_m: torch.Tensor | None = None
-
- # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on
- # num_tokens = output.size(0) # M
- # assert topk_ids.size(0) == num_tokens, (
- # f"{topk_ids.size(0)} == {num_tokens}")
- assert topk_ids.size() == topk_weights.size(), (
- f"{topk_ids.size()} == {topk_weights.size()}"
- )
- assert output.size(0) <= self.max_num_tokens, (
- f"{output.size(0)} <= {self.max_num_tokens}"
- )
- assert output.size(1) == fused_expert_output.size(-1)
-
- # Set weights to 1 if we did them in dispatch. This is hacky.
- if apply_router_weight_on_input:
- topk_weights = torch.ones_like(topk_weights)
-
- topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
-
- self.a2a.combine(
- out_tokens=output,
- indices=topk_ids_u32,
- weights=topk_weights,
- expert_y=fused_expert_output,
- bound_m=bound_m,
- do_send=True,
- do_recv=False,
- )
-
- return lambda: self.a2a.combine(
- out_tokens=output,
- indices=topk_ids_u32,
- weights=topk_weights,
- expert_y=fused_expert_output,
- bound_m=bound_m,
- do_send=False,
- do_recv=True,
- )
-
- def finalize(
- self,
- output: torch.Tensor,
- fused_expert_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- weight_and_reduce_impl: mk.TopKWeightAndReduce,
- ) -> None:
- receiver = self.finalize_async(
- output,
- fused_expert_output,
- topk_weights,
- topk_ids,
- apply_router_weight_on_input,
- weight_and_reduce_impl,
- )
- receiver()
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py
deleted file mode 100644
index 7b8dd3b..0000000
--- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py
+++ /dev/null
@@ -1,209 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.distributed import get_ep_group
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
- TopKWeightAndReduceContiguous,
- TopKWeightAndReduceDelegate,
-)
-from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
-from vllm.utils.flashinfer import nvfp4_block_scale_interleave
-
-
-class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize):
- def __init__(
- self,
- is_sequence_parallel: bool = False,
- num_dispatchers: int = 1,
- ) -> None:
- super().__init__()
- self.is_sequence_parallel = is_sequence_parallel
- self._num_dispatchers = num_dispatchers
-
- @property
- def activation_format(self) -> mk.FusedMoEActivationFormat:
- return mk.FusedMoEActivationFormat.Standard
-
- def max_num_tokens_per_rank(self) -> int | None:
- return None
-
- def topk_indices_dtype(self) -> torch.dtype | None:
- return None
-
- def num_dispatchers(self) -> int:
- return self._num_dispatchers
-
- def output_is_reduced(self) -> bool:
- return False
-
- def prepare(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_experts: int,
- expert_map: torch.Tensor | None,
- apply_router_weight_on_input: bool,
- quant_config: FusedMoEQuantConfig,
- defer_input_quant: bool = False,
- ) -> mk.PrepareResultType:
- if apply_router_weight_on_input:
- topk = topk_ids.size(1)
- assert topk == 1, (
- "apply_router_weight_on_input is only implemented for topk=1"
- )
- # Note: do not use inplace for shared experts overlap
- a1 = a1 * topk_weights.to(a1.dtype)
-
- # Defer input quantization to the MoE kernel.
- use_nvfp4 = quant_config.use_nvfp4_w4a4
- if defer_input_quant:
- a1q = a1
- a1q_scale = None
- else:
- a1q, a1q_scale = moe_kernel_quantize_input(
- a1,
- quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale,
- quant_config.quant_dtype,
- quant_config.per_act_token_quant,
- quant_config.block_shape,
- # NOTE: swizzling pads the scales to multiple of 128
- # which makes the scales tensor different shape than
- # the hidden states, breaking the A2A kernel. So, we
- # delay the swizzling until after the A2A.
- is_fp4_scale_swizzled=False,
- )
-
- # Skip gathering scales if we have static quantization
- # (the scale is a scalar, replicated on all ranks) or
- # if quantization is deferred.
- skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
- scales = None if skip_gather_scales else [a1q_scale]
-
- res = get_ep_group().dispatch(
- a1q,
- topk_weights,
- topk_ids,
- is_sequence_parallel=self.is_sequence_parallel,
- extra_tensors=scales,
- )
- if skip_gather_scales:
- a1q, topk_weights, topk_ids = res
- else:
- a1q, topk_weights, topk_ids, scales = res
- assert scales is not None and len(scales) == 1
- a1q_scale = scales[0]
- if quant_config.quant_dtype == "nvfp4":
- assert a1q_scale is not None
- if a1q_scale.element_size() == 1:
- a1q_scale = a1q_scale.view(torch.uint8)
- a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
-
- return a1q, a1q_scale, None, topk_ids, topk_weights
-
- def finalize(
- self,
- output: torch.Tensor,
- fused_expert_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- weight_and_reduce_impl: mk.TopKWeightAndReduce,
- ) -> None:
- if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
- weight_and_reduce_impl = TopKWeightAndReduceContiguous()
-
- out = weight_and_reduce_impl.apply(
- output=None,
- fused_expert_output=fused_expert_output,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
-
- output.copy_(
- get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
- )
-
-
-class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
- """MoE prepare and finalize without expert parallelism."""
-
- @property
- def activation_format(self) -> mk.FusedMoEActivationFormat:
- return mk.FusedMoEActivationFormat.Standard
-
- def max_num_tokens_per_rank(self) -> int | None:
- return None
-
- def topk_indices_dtype(self) -> torch.dtype | None:
- return None
-
- def num_dispatchers(self) -> int:
- return 1
-
- def output_is_reduced(self) -> bool:
- return False
-
- def prepare(
- self,
- a1: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- num_experts: int,
- expert_map: torch.Tensor | None,
- apply_router_weight_on_input: bool,
- quant_config: FusedMoEQuantConfig,
- defer_input_quant: bool = False,
- ) -> mk.PrepareResultType:
- if apply_router_weight_on_input:
- topk = topk_ids.size(1)
- # TODO: this only works for topK=1, will need to update for topK>1
- assert topk == 1, (
- "apply_router_weight_on_input is only implemented for topk=1"
- )
- # Note: do not use inplace for shared experts overlap
- a1 = a1 * topk_weights.to(a1.dtype)
-
- # Defer input quant to moe kernel for backends (e.g. AITER, FI)
- # which use a single kernel call for quant + experts.
- if defer_input_quant:
- return a1, None, None, None, None
-
- input_sf = (
- quant_config.a1_gscale
- if quant_config.use_nvfp4_w4a4
- else quant_config.a1_scale
- )
- a1q, a1q_scale = moe_kernel_quantize_input(
- a1,
- input_sf,
- quant_config.quant_dtype,
- quant_config.per_act_token_quant,
- quant_config.block_shape,
- )
-
- return a1q, a1q_scale, None, None, None
-
- def finalize(
- self,
- output: torch.Tensor,
- fused_expert_output: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- apply_router_weight_on_input: bool,
- weight_and_reduce_impl: mk.TopKWeightAndReduce,
- ) -> None:
- if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
- weight_and_reduce_impl = TopKWeightAndReduceContiguous()
- weight_and_reduce_impl.apply(
- output=output,
- fused_expert_output=fused_expert_output,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- apply_router_weight_on_input=apply_router_weight_on_input,
- )
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py
new file mode 100644
index 0000000..03fea7c
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py
@@ -0,0 +1,22 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import (
+ MoEPrepareAndFinalizeNaiveDPEPModular,
+ MoEPrepareAndFinalizeNaiveDPEPMonolithic,
+ make_moe_prepare_and_finalize_naive_dp_ep,
+)
+from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
+ MoEPrepareAndFinalizeNoDPEPModular,
+ MoEPrepareAndFinalizeNoDPEPMonolithic,
+ make_moe_prepare_and_finalize_no_dp_ep,
+)
+
+__all__ = [
+ "MoEPrepareAndFinalizeNaiveDPEPMonolithic",
+ "MoEPrepareAndFinalizeNaiveDPEPModular",
+ "make_moe_prepare_and_finalize_naive_dp_ep",
+ "MoEPrepareAndFinalizeNoDPEPMonolithic",
+ "MoEPrepareAndFinalizeNoDPEPModular",
+ "make_moe_prepare_and_finalize_no_dp_ep",
+]
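
For orientation, a minimal sketch of how a caller might pick between the two factory helpers re-exported above. The wrapper function and its flags are hypothetical names for illustration, not part of this patch; only the two make_* factories and their keyword arguments come from the new modules.

from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    make_moe_prepare_and_finalize_naive_dp_ep,
    make_moe_prepare_and_finalize_no_dp_ep,
)

def select_prepare_finalize(use_dp_ep, use_monolithic,
                            is_sequence_parallel=False, num_dispatchers=1):
    # Hypothetical helper: the naive DP/EP variant dispatches/combines with
    # torch collectives, the no-DP/EP variant only quantizes on the local rank.
    if use_dp_ep:
        return make_moe_prepare_and_finalize_naive_dp_ep(
            use_monolithic,
            is_sequence_parallel=is_sequence_parallel,
            num_dispatchers=num_dispatchers,
        )
    return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)
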
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
new file mode 100644
index 0000000..6dc9f69
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py
@@ -0,0 +1,253 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.distributed import get_ep_group
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+ TopKWeightAndReduceContiguous,
+ TopKWeightAndReduceDelegate,
+)
+from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
+from vllm.utils.flashinfer import nvfp4_block_scale_interleave
+
+
+def _quantize_and_setup_dispatch(
+ a1: torch.Tensor,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+) -> tuple[torch.Tensor, list[torch.Tensor] | None]:
+ # Defer input quantization to the MoE kernel.
+ if defer_input_quant:
+ a1q = a1
+ a1q_scale = None
+ else:
+ input_sf = (
+ quant_config.a1_gscale
+ if quant_config.use_nvfp4_w4a4
+ else quant_config.a1_scale
+ )
+
+ # NOTE: swizzling pads the scales to multiple of 128
+ # which makes the scales tensor different shape than
+ # the hidden states, breaking the A2A kernel. So, we
+ # delay the swizzling until after the A2A.
+        a1q, a1q_scale = moe_kernel_quantize_input(
+ a1,
+ input_sf,
+ quant_dtype=quant_config.quant_dtype,
+ per_act_token_quant=quant_config.per_act_token_quant,
+ block_shape=quant_config.block_shape,
+ is_fp4_scale_swizzled=False,
+ )
+
+ # Skip gathering scales if we have static quantization
+ # (the scale is a scalar, replicated on all ranks) or
+ # if quantization is deferred.
+ skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0
+ scales = None if skip_gather_scales else [a1q_scale]
+
+ return a1q, scales
+
+
+def _unwrap_scale_and_prepare_for_moe(
+ scales: list[torch.Tensor] | None,
+ quant_config: FusedMoEQuantConfig,
+) -> torch.Tensor:
+ assert scales is not None and len(scales) == 1
+ a1q_scale = scales[0]
+ # Apply swizzling after a2a if the MoE kernel needs it.
+ if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled:
+ assert a1q_scale is not None
+ if a1q_scale.element_size() == 1:
+ a1q_scale = a1q_scale.view(torch.uint8)
+ a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
+
+ return a1q_scale
+
+
+class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
+ """
+    Naive prepare/finalize for the DP/EP case, used by modular kernels.
+
+    Uses torch AR/RS (or AR) collectives for dispatch/combine, applied
+    to the topk weights and ids.
+ """
+
+ def __init__(
+ self,
+ is_sequence_parallel: bool = False,
+ num_dispatchers: int = 1,
+ ) -> None:
+ super().__init__()
+ self.is_sequence_parallel = is_sequence_parallel
+ self._num_dispatchers = num_dispatchers
+
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return self._num_dispatchers
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> mk.PrepareResultType:
+ """Quantize and Dispatch Topk Weights and Topk Ids."""
+
+ if apply_router_weight_on_input:
+ topk = topk_ids.size(1)
+ assert topk == 1, (
+ "apply_router_weight_on_input is only implemented for topk=1"
+ )
+ # Note: do not use inplace for shared experts overlap
+ a1 = a1 * topk_weights.to(a1.dtype)
+
+ a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
+
+ res = get_ep_group().dispatch(
+ a1q,
+ topk_weights,
+ topk_ids,
+ is_sequence_parallel=self.is_sequence_parallel,
+ extra_tensors=scales,
+ )
+
+ if scales is None:
+ a1q, topk_weights, topk_ids = res
+ a1q_scale = None
+ else:
+ a1q, topk_weights, topk_ids, scales = res
+ a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
+
+ return a1q, a1q_scale, None, topk_ids, topk_weights
+
+ def finalize(
+ self,
+ output: torch.Tensor,
+ fused_expert_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ apply_router_weight_on_input: bool,
+ weight_and_reduce_impl: mk.TopKWeightAndReduce,
+ ) -> None:
+ if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+ weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+
+ out = weight_and_reduce_impl.apply(
+ output=None,
+ fused_expert_output=fused_expert_output,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ )
+
+ output.copy_(
+ get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel)
+ )
+
+
+class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
+ """
+    Naive prepare/finalize for the DP/EP case, used by monolithic kernels.
+
+    Uses torch AR/RS (or AR) collectives for dispatch/combine, applied to
+    the router logits (the MoE kernel runs the router internally).
+ """
+
+ def __init__(
+ self,
+ is_sequence_parallel: bool = False,
+ num_dispatchers: int = 1,
+ ) -> None:
+ super().__init__()
+ self.is_sequence_parallel = is_sequence_parallel
+ self._num_dispatchers = num_dispatchers
+
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return self._num_dispatchers
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ router_logits: torch.Tensor,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> mk.PrepareMonolithicResultType:
+ """Quantize and Dispatch Router Logits."""
+
+ a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant)
+
+ res = get_ep_group().dispatch_router_logits(
+ a1q,
+ router_logits,
+ is_sequence_parallel=self.is_sequence_parallel,
+ extra_tensors=scales,
+ )
+
+ if scales is None:
+ a1q, router_logits = res
+ a1q_scale = None
+ else:
+ a1q, router_logits, scales = res
+ a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config)
+
+ return a1q, a1q_scale, router_logits
+
+ def finalize(
+ self,
+ fused_expert_output: torch.Tensor,
+ ) -> torch.Tensor:
+ out = get_ep_group().combine(
+ fused_expert_output, is_sequence_parallel=self.is_sequence_parallel
+ )
+ return out
+
+
+def make_moe_prepare_and_finalize_naive_dp_ep(
+ use_monolithic: bool,
+ is_sequence_parallel: bool = False,
+ num_dispatchers: int = 1,
+) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic:
+ return (
+ MoEPrepareAndFinalizeNaiveDPEPMonolithic(
+ is_sequence_parallel=is_sequence_parallel,
+ num_dispatchers=num_dispatchers,
+ )
+ if use_monolithic
+ else MoEPrepareAndFinalizeNaiveDPEPModular(
+ is_sequence_parallel=is_sequence_parallel,
+ num_dispatchers=num_dispatchers,
+ )
+ )
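
A small sketch of the scale-gathering decision made by `_quantize_and_setup_dispatch` above: a deferred (None) or static scalar scale is already replicated on every rank, so it is not sent through the dispatch, while per-token/per-block scales ride along as `extra_tensors`. The helper below is illustrative only and is not part of the patch.

import torch

def needs_scale_gather(a1q_scale):
    # Mirrors skip_gather_scales: None (deferred quant) or a 0-dim scalar
    # (static quantization) means the scale is already available everywhere.
    return not (a1q_scale is None or a1q_scale.ndim == 0)

assert needs_scale_gather(None) is False
assert needs_scale_gather(torch.tensor(0.5)) is False   # static scalar scale
assert needs_scale_gather(torch.ones(16, 1)) is True    # per-token scale
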
diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
new file mode 100644
index 0000000..b9d57da
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py
@@ -0,0 +1,141 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
+from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
+ TopKWeightAndReduceContiguous,
+ TopKWeightAndReduceDelegate,
+)
+from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
+
+
+def _quantize_input(
+ a1: torch.Tensor,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # Defer input quant to moe kernel for backends (e.g. AITER, FI)
+ # which use a single kernel call for quant + experts.
+ if defer_input_quant:
+ return a1, None
+
+ input_sf = (
+ quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale
+ )
+ a1q, a1q_scale = moe_kernel_quantize_input(
+ a1,
+ input_sf,
+ quant_dtype=quant_config.quant_dtype,
+ per_act_token_quant=quant_config.per_act_token_quant,
+ block_shape=quant_config.block_shape,
+ is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
+ )
+
+ return a1q, a1q_scale
+
+
+class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular):
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return 1
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ num_experts: int,
+ expert_map: torch.Tensor | None,
+ apply_router_weight_on_input: bool,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> mk.PrepareResultType:
+ if apply_router_weight_on_input:
+ topk = topk_ids.size(1)
+ # TODO: this only works for topK=1, will need to update for topK>1
+ assert topk == 1, (
+ "apply_router_weight_on_input is only implemented for topk=1"
+ )
+ # Note: do not use inplace for shared experts overlap
+ a1 = a1 * topk_weights.to(a1.dtype)
+
+ a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
+
+ return a1q, a1q_scale, None, None, None
+
+ def finalize(
+ self,
+ output: torch.Tensor,
+ fused_expert_output: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ apply_router_weight_on_input: bool,
+ weight_and_reduce_impl: mk.TopKWeightAndReduce,
+ ) -> None:
+ if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+ weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+ weight_and_reduce_impl.apply(
+ output=output,
+ fused_expert_output=fused_expert_output,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ apply_router_weight_on_input=apply_router_weight_on_input,
+ )
+
+
+class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic):
+ @property
+ def activation_format(self) -> mk.FusedMoEActivationFormat:
+ return mk.FusedMoEActivationFormat.Standard
+
+ def max_num_tokens_per_rank(self) -> int | None:
+ return None
+
+ def topk_indices_dtype(self) -> torch.dtype | None:
+ return None
+
+ def num_dispatchers(self) -> int:
+ return 1
+
+ def output_is_reduced(self) -> bool:
+ return False
+
+ def prepare(
+ self,
+ a1: torch.Tensor,
+ router_logits: torch.Tensor,
+ quant_config: FusedMoEQuantConfig,
+ defer_input_quant: bool = False,
+ ) -> mk.PrepareMonolithicResultType:
+ a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant)
+ return a1q, a1q_scale, router_logits
+
+ def finalize(
+ self,
+ fused_expert_output: torch.Tensor,
+ ) -> torch.Tensor:
+ return fused_expert_output
+
+
+def make_moe_prepare_and_finalize_no_dp_ep(
+ use_monolithic: bool,
+) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic:
+ return (
+ MoEPrepareAndFinalizeNoDPEPMonolithic()
+ if use_monolithic
+ else MoEPrepareAndFinalizeNoDPEPModular()
+ )
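
A quick illustration of the `defer_input_quant=True` fast path in `_quantize_input` above: the activations pass through untouched and quantization happens later inside the MoE kernel. On this path no `quant_config` field is read, so passing None works purely for demonstration; this assumes the patched module is importable.

import torch
from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import (
    _quantize_input,
)

a1 = torch.randn(8, 128, dtype=torch.bfloat16)
# quant_config is never touched on the deferred path, so None is fine here.
a1q, a1q_scale = _quantize_input(a1, quant_config=None, defer_input_quant=True)
assert a1q is a1 and a1q_scale is None
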
diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
index 8c8439d..c550cad 100644
--- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
@@ -292,7 +292,7 @@ def rocm_aiter_fused_experts(
)
-class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class AiterExperts(mk.FusedMoEExpertsModular):
@property
def expects_unquantized_inputs(self) -> bool:
return True
diff --git a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
index 7608e06..b061b3d 100644
--- a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
+++ b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py
@@ -20,6 +20,7 @@ import torch
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.forward_context import get_forward_context
+from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
self._device_buffer = torch.zeros(
(max_num_batched_tokens, num_layers, num_experts_per_tok),
dtype=torch.int32,
- device="cuda",
+ device=current_platform.device_type,
)
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py
index 52005d4..d5c83ce 100644
--- a/vllm/model_executor/layers/fused_moe/router/base_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/base_router.py
@@ -64,7 +64,7 @@ if current_platform.is_cuda_alike():
# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
- # `FusedMoEPrepareAndFinalize` will return the expert
+ # `FusedMoEPrepareAndFinalizeModular` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
@@ -175,6 +175,7 @@ class BaseRouter(FusedMoERouter):
topk_ids = topk_ids.to(dtype=indices_type)
assert topk_ids.dtype == indices_type or indices_type is None
+ topk_ids = topk_ids.to(torch.int32)
return topk_ids
@abstractmethod
diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
index 584e044..3ceb5a8 100644
--- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py
@@ -31,7 +31,7 @@ def vllm_topk_softmax(
token_expert_indices,
gating_output,
renormalize,
- e_score_correction_bias,
+ e_score_correction_bias
)
return topk_weights, topk_indices
@@ -85,13 +85,14 @@ def fused_topk_bias(
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
+    gating_output_float = gating_output.float()  # TODO(woosuk): Optimize this.
if scoring_func == "softmax":
topk_weights, topk_ids = vllm_topk_softmax(
topk_weights,
topk_ids,
token_expert_indices,
- gating_output,
+ gating_output_float,
renormalize,
e_score_correction_bias,
)
@@ -186,7 +187,7 @@ class FusedTopKBiasRouter(BaseRouter):
indices_type=indices_type,
)
- if self.routed_scaling_factor != 1.0:
- topk_weights *= self.routed_scaling_factor
+ # if self.routed_scaling_factor != 1.0:
+ # topk_weights *= self.routed_scaling_factor
return topk_weights, topk_ids
diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
index 0ae867a..d0ad5c5 100644
--- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
@@ -26,8 +26,9 @@ def vllm_topk_softmax(
topk_indices,
token_expert_indices,
gating_output,
- renormalize,
)
+ if renormalize:
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
@@ -90,13 +91,14 @@ def fused_topk(
token_expert_indices = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
+ gating_output_float = gating_output.float()
if scoring_func == "softmax":
topk_func = dispatch_topk_softmax_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
- topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
+ topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
)
return topk_weights, topk_ids, token_expert_indices
@@ -105,7 +107,7 @@ def fused_topk(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
topk_weights, topk_ids = topk_func(
- topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize
+ topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
)
return topk_weights, topk_ids, token_expert_indices
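
The change above pulls renormalization out of the topk_softmax kernel call and applies it afterwards in Python. A tiny standalone sketch of that post-hoc step:

import torch

topk_weights = torch.tensor([[0.6, 0.2], [0.1, 0.3]])
renormalize = True
if renormalize:
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
print(topk_weights.sum(dim=-1))  # each row now sums to 1.0
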
diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py
new file mode 100644
index 0000000..26fedd5
--- /dev/null
+++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.custom_op import PluggableLayer
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.platforms import current_platform
+
+
+@PluggableLayer.register("gate_linear")
+class GateLinear(ReplicatedLinear):
+ """MoE gate linear layer with three-tier GEMM dispatch:
+
+ 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims)
+ 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype)
+ 3. F.linear via ReplicatedLinear (ultimate fallback)
+
+ The ``out_dtype`` attribute is mutable and can be set after init
+ (e.g. when the required dtype depends on the expert quantization
+ method which is only known later).
+ """
+
+ # Dimensions supported by the DSV3 specialized kernel
+ DSV3_SUPPORTED_NUM_EXPERTS = [256, 384]
+ DSV3_SUPPORTED_HIDDEN_SIZES = [7168]
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ bias: bool = False,
+ out_dtype: torch.dtype | None = None,
+ params_dtype: torch.dtype | None = None,
+ force_fp32_compute: bool = False,
+ prefix: str = "",
+ ):
+ is_hopper_or_blackwell = current_platform.is_device_capability(
+ (9, 0)
+ ) or current_platform.is_device_capability_family(100)
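+        # NOTE: specialized router GEMM kernels (Tiers 1-2) are disabled here,
+        # so every call currently falls through to the F.linear path (Tier 3).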
+ can_use_specialized_kernels = False
+
+ # If fp32 compute is required and no specialized kernel is available,
+ # store weights in fp32 so Tier 3 computes in fp32 natively.
+ if force_fp32_compute and not can_use_specialized_kernels:
+ params_dtype = torch.float32
+
+ super().__init__(
+ input_size,
+ output_size,
+ bias=bias,
+ params_dtype=params_dtype,
+ quant_config=None,
+ prefix=prefix,
+ )
+ self.out_dtype = out_dtype
+
+ # DSV3 specialized kernel eligibility (SM90+, exact dims)
+ self.allow_specialized_router_gemm = can_use_specialized_kernels
+ self.allow_dsv3_router_gemm = (
+ self.allow_specialized_router_gemm
+ and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS
+ and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES
+ )
+
+ # cuBLAS bf16→fp32 eligibility
+ self.allow_cublas_router_gemm = (
+ self.allow_specialized_router_gemm
+ and self.weight.dtype == torch.bfloat16
+ and self.out_dtype == torch.float32
+ )
+
+ def set_out_dtype(self, out_dtype: torch.dtype) -> None:
+ """Set output dtype for the router logits after init.
+
+ Useful when the required dtype depends on the expert quantization
+ method which is only known after the gate is constructed.
+ """
+ if self.out_dtype is not None:
+ raise ValueError("out_dtype has already been set")
+ self.out_dtype = out_dtype
+
+ if (
+ not self.allow_cublas_router_gemm
+ and self.allow_specialized_router_gemm
+ and out_dtype == torch.float32
+ ):
+ self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16
+
+ def forward(
+ self, x: torch.Tensor
+ ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
+ import vllm._custom_ops as ops
+
+ # Tier 1: DSV3 specialized kernel
+ if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
+ output = ops.dsv3_router_gemm(
+ hidden_states=x,
+ router_weight=self.weight,
+ output_dtype=self.out_dtype,
+ )
+ return output, None
+
+ # Tier 2: cuBLAS bf16→fp32
+ if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16:
+ output = ops.router_gemm_bf16_fp32(x, self.weight)
+ return output, None
+
+ # Tier 3: F.linear (ReplicatedLinear)
+ if self.out_dtype is not None and x.dtype != self.weight.dtype:
+ x = x.to(self.weight.dtype)
+ output, output_bias = super().forward(x)
+ if self.out_dtype is not None and output.dtype != self.out_dtype:
+ output = output.to(self.out_dtype)
+ return output, output_bias
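
Sketch of the intended GateLinear flow. Constructing vLLM layers standalone may require additional platform/device setup, so treat the device and dtype choices below as assumptions; the class, its constructor arguments, and set_out_dtype come from the new file above.

import torch
from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear

gate = GateLinear(input_size=7168, output_size=256, bias=False,
                  params_dtype=torch.bfloat16).cuda()
# The expert quant method may only know the required logits dtype later:
gate.set_out_dtype(torch.float32)

hidden = torch.randn(4, 7168, dtype=torch.bfloat16, device="cuda")
logits, _ = gate(hidden)  # dispatches through the tiers, here down to F.linear
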
diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
index 486610f..99d8891 100644
--- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
+++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py
@@ -92,77 +92,9 @@ def grouped_topk(
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
- if (
- envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK
- and current_platform.is_cuda()
- and num_expert_group <= 32
- and topk <= 32
- and e_score_correction_bias is not None
- ):
- return fused_grouped_topk(
- hidden_states=hidden_states,
- gating_output=gating_output,
- topk=topk,
- renormalize=renormalize,
- e_score_correction_bias=e_score_correction_bias,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- scoring_func=scoring_func,
- routed_scaling_factor=routed_scaling_factor,
- )
-
- assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"
-
- if scoring_func == "softmax":
- scores = torch.softmax(gating_output, dim=-1)
- elif scoring_func == "sigmoid":
- scores = gating_output.sigmoid()
- else:
- raise ValueError(f"Unsupported scoring function: {scoring_func}")
-
- num_token = scores.size(0)
- if e_score_correction_bias is not None:
- # Store original scores before applying correction bias. We use biased
- # scores for expert selection but original scores for routing weights
- original_scores = scores
- scores = scores + e_score_correction_bias.unsqueeze(0)
- group_scores = (
- scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
- )
- else:
- group_scores = (
- scores.view(num_token, num_expert_group, -1).max(dim=-1).values
- ) # [n, n_group]
-
- # For batch invariance, use sorted=True to ensure deterministic expert selection
- use_sorted = vllm_is_batch_invariant()
- group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
- 1
- ] # [n, top_k_group]
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
- score_mask = (
- group_mask.unsqueeze(-1)
- .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group)
- .reshape(num_token, -1)
- ) # [n, e]
- tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e]
-
- if e_score_correction_bias is not None:
- topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1]
- # Use original unbiased scores for the routing weights
- topk_weights = original_scores.gather(1, topk_ids)
- else:
- topk_weights, topk_ids = torch.topk(
- tmp_scores, k=topk, dim=-1, sorted=use_sorted
- )
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
- if routed_scaling_factor != 1.0:
- topk_weights = topk_weights * routed_scaling_factor
- return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    from ixformer.inference.functions import moe_grouped_topk
+    topk_weights, topk_ids = moe_grouped_topk(
+        gating_output, topk, num_expert_group, topk_group,
+        scoring_func, e_score_correction_bias, renormalize=renormalize)
+    return topk_weights, topk_ids
# --8<-- [start:grouped_topk]
@@ -246,7 +178,6 @@ class GroupedTopk(CustomOp):
hidden_states, gating_output, e_score_correction_bias
)
-from ixformer.inference.functions import moe_grouped_topk as grouped_topk
class GroupedTopKRouter(BaseRouter):
"""Router using grouped top-k routing (e.g., DeepSeekV2/V3)."""
@@ -316,8 +247,8 @@ class GroupedTopKRouter(BaseRouter):
topk=self.top_k,
renormalize=self.renormalize,
)
- if self.routed_scaling_factor != 1.0:
- topk_weights *= self.routed_scaling_factor
+ # if self.routed_scaling_factor != 1.0:
+ # topk_weights *= self.routed_scaling_factor
else:
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
@@ -340,14 +271,14 @@ class GroupedTopKRouter(BaseRouter):
grouped_topk_impl = grouped_topk
topk_weights, topk_ids = grouped_topk_impl(
- # hidden_states=hidden_states,
+ hidden_states=hidden_states,
gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
- # routed_scaling_factor=self.routed_scaling_factor,
+ routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
)
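
The reference grouped_topk implementation is replaced above by the ixformer kernel. A hedged usage sketch mirroring the new call site: it assumes the ixformer package from the CoreX stack is installed and that the positional argument order matches the call in this patch; the expert/group counts are made-up DeepSeek-style values.

import torch
from ixformer.inference.functions import moe_grouped_topk

gating_output = torch.randn(16, 256, device="cuda")
bias = torch.zeros(256, device="cuda")
topk_weights, topk_ids = moe_grouped_topk(
    gating_output, 8, 8, 4, "sigmoid", bias, renormalize=True
)
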
diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py
index a0733ba..11027e8 100644
--- a/vllm/model_executor/layers/fused_moe/router/router_factory.py
+++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py
@@ -44,7 +44,7 @@ def create_fused_moe_router(
# grouped topk + fused topk bias parameters
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
- # custom routing paramaters
+ # custom routing parameters
custom_routing_function: Callable | None = None,
# eplb parameters
enable_eplb: bool = False,
diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
index 087b703..b69744d 100644
--- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
+++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import nullcontext
+from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
from vllm.platforms import current_platform
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import (
+ HAS_OPAQUE_TYPE,
+ ModuleName,
aux_stream,
current_stream,
direct_register_custom_op,
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
return forward_context.no_compile_layers[layer_name]
+# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
+# on older versions it remains a plain str.
+if TYPE_CHECKING:
+ from typing import TypeAlias
+
+ _layer_name_type: TypeAlias = str | ModuleName
+else:
+ _layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
+
+
+def _resolve_layer_name(layer_name: str | ModuleName) -> str:
+ return layer_name.value if isinstance(layer_name, ModuleName) else layer_name
+
+
def _moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- layer_name: str,
+ layer_name: _layer_name_type,
) -> torch.Tensor:
- layer = get_layer_from_name(layer_name)
+ layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
@@ -74,7 +91,7 @@ def _moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- layer_name: str,
+ layer_name: _layer_name_type,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -83,9 +100,9 @@ def _moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- layer_name: str,
+ layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
- layer = get_layer_from_name(layer_name)
+ layer = get_layer_from_name(_resolve_layer_name(layer_name))
# TODO(bnell): this can be removed after MK migration is complete.
layer.ensure_moe_quant_config_init()
return layer.runner.forward_impl(
@@ -97,7 +114,7 @@ def _moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- layer_name: str,
+ layer_name: _layer_name_type,
) -> tuple[torch.Tensor, torch.Tensor]:
# Output shapes:
# - fused_out: same as hidden_states (routed experts use transformed size)
@@ -105,12 +122,10 @@ def _moe_forward_shared_fake(
# hidden_states
# (For latent MoE: shared experts use original hidden_size, not latent size)
fused_out = torch.empty_like(hidden_states)
-
if shared_experts_input is not None:
shared_out = torch.empty_like(shared_experts_input)
else:
shared_out = torch.empty_like(hidden_states)
-
return shared_out, fused_out
@@ -165,6 +180,7 @@ class DefaultMoERunner(MoERunner):
quant_method: FusedMoEMethodBase,
reduce_results: bool,
enable_dbo: bool,
+ fused_shared_output: bool = False,
):
super().__init__()
self.moe_config = moe_config
@@ -175,6 +191,9 @@ class DefaultMoERunner(MoERunner):
self.quant_method = quant_method
self.reduce_results = reduce_results
self.enable_dbo = enable_dbo
+ self.fused_shared_output = fused_shared_output
+ if self.fused_shared_output:
+            assert self.shared_experts is not None, (
+                "Shared experts must be provided when fused_shared_output is True."
+            )
# Allow disabling of the separate shared experts stream for
# debug purposes.
@@ -195,19 +214,19 @@ class DefaultMoERunner(MoERunner):
# Needed for string -> FusedMoE layer lookup in custom ops.
self.layer_name = layer.layer_name
- if current_platform.is_tpu() or current_platform.is_cpu():
+ # if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
- if self.shared_experts is None:
- self.moe_forward = _moe_forward
- else:
- self.moe_forward = _moe_forward_shared
+ if self.shared_experts is None:
+ self.moe_forward = _moe_forward
else:
- if self.shared_experts is None:
- self.moe_forward = torch.ops.vllm.moe_forward
- else:
- self.moe_forward = torch.ops.vllm.moe_forward_shared
+ self.moe_forward = _moe_forward_shared
+ # else:
+ # if self.shared_experts is None:
+ # self.moe_forward = torch.ops.vllm.moe_forward
+ # else:
+ # self.moe_forward = torch.ops.vllm.moe_forward_shared
# Chunked all2all staging tensor
self.batched_hidden_states: torch.Tensor | None = None
@@ -216,8 +235,7 @@ class DefaultMoERunner(MoERunner):
@property
def use_dp_chunking(self) -> bool:
return (
- self.moe_config.moe_parallel_config.use_pplx_kernels
- or self.moe_config.moe_parallel_config.use_deepep_ll_kernels
+ self.moe_config.moe_parallel_config.use_deepep_ll_kernels
or self.moe_config.moe_parallel_config.use_mori_kernels
or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
@@ -306,8 +324,8 @@ class DefaultMoERunner(MoERunner):
"""
assert self.quant_method is not None
return (
- self.quant_method.moe_mk is not None
- and self.quant_method.moe_mk.output_is_reduced()
+ self.quant_method.moe_kernel is not None
+ and self.quant_method.moe_kernel.output_is_reduced()
)
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
@@ -362,13 +380,15 @@ class DefaultMoERunner(MoERunner):
if isinstance(states, tuple):
return tuple(
- [func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)]
+                [
+                    None if s is None else func(s, trunc_size)
+                    for s, trunc_size in zip(states, trunc_sizes)
+                ]
)
else:
assert len(trunc_sizes) == 1
return func(states, trunc_sizes[0])
- def _encode_layer_name(self) -> str:
+ def _encode_layer_name(self) -> str | ModuleName:
+ if HAS_OPAQUE_TYPE:
+ return ModuleName(self.layer_name)
# Can be unavailable or None in unittests
if (
is_forward_context_available()
@@ -624,53 +644,27 @@ class DefaultMoERunner(MoERunner):
)
with sp_ctx:
- extra_tensors = None
- if do_naive_dispatch_combine:
- post_quant_allgather = (
- self.quant_method is not None
- and self.moe_config.dp_size > 1
- and self.moe_config.use_ep
- and getattr(self.quant_method, "do_post_quant_allgather", False)
- )
- if post_quant_allgather:
- hidden_states_to_dispatch, extra_tensors = (
- self.quant_method.prepare_dp_allgather_tensor(
- layer, hidden_states, router_logits
- )
- )
- else:
- hidden_states_to_dispatch = hidden_states
-
- dispatch_res = get_ep_group().dispatch_router_logits(
- hidden_states_to_dispatch,
- router_logits,
- self.moe_config.is_sequence_parallel,
- extra_tensors=extra_tensors,
- )
- if extra_tensors is not None:
- (
- orig_hidden_states,
- router_logits,
- extra_tensors_combined,
- ) = dispatch_res
- hidden_states_combined = (
- orig_hidden_states,
- extra_tensors_combined[0],
- )
- else:
- hidden_states_combined, router_logits = dispatch_res
- orig_hidden_states = hidden_states_combined
- else:
- orig_hidden_states = hidden_states
-
# Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states.
- if has_separate_shared_experts and not use_shared_experts_stream:
+            if has_separate_shared_experts:  # and not use_shared_experts_stream:
assert self.shared_experts is not None
shared_input = (
shared_input if shared_input is not None else hidden_states
)
shared_output = self.shared_experts(shared_input)
+ else:
+                assert not self.fused_shared_output, (
+                    "fused_shared_output is only supported when "
+                    "has_separate_shared_experts is True"
+                )
+ shared_output = None
+            # For the naive DP/EP dispatch/combine path, dispatch the hidden
+            # states and router logits across the expert-parallel ranks.
+            # NOTE: this will be removed once all kernels are migrated into the
+            # MoEKernel framework.
+ if do_naive_dispatch_combine:
+ hidden_states, router_logits = get_ep_group().dispatch_router_logits(
+ hidden_states,
+ router_logits,
+ self.moe_config.is_sequence_parallel,
+ )
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
@@ -685,42 +679,33 @@ class DefaultMoERunner(MoERunner):
dim=0,
)
- # TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014).
- # Figure out nicer way to do this.
- if do_naive_dispatch_combine:
- x = hidden_states_combined
- x_orig = orig_hidden_states
- else:
- x = hidden_states
- x_orig = hidden_states
-
# Matrix multiply.
if self.quant_method.is_monolithic:
final_hidden_states = self.quant_method.apply_monolithic(
layer=layer,
- x=x,
+ x=hidden_states,
router_logits=router_logits,
)
else:
topk_weights, topk_ids = self.router.select_experts(
- hidden_states=x_orig,
+ hidden_states=hidden_states,
router_logits=router_logits,
)
final_hidden_states = self.quant_method.apply(
layer=layer,
- x=x, # The type signture of this is wrong due to the hack.
+ x=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
- shared_experts_input=shared_input,
- router_logits=router_logits,
- top_k=topk_ids.shape[-1]
+                # shared_experts_input carries the shared experts output here
+                # so the kernel can fuse the addition (fused_shared_output).
+ shared_experts_input=shared_output if self.fused_shared_output else None,
)
if has_separate_shared_experts:
assert self.shared_experts is not None
if use_shared_experts_stream:
+                    assert not use_shared_experts_stream, (
+                        "Running shared experts in parallel with the main MoE "
+                        "execution is currently not supported!"
+                    )
# Run shared experts in parallel on a separate stream
# NOTE: We start the separate stream here and mark the
# sync end point immediately after it is done. This is
@@ -733,7 +718,7 @@ class DefaultMoERunner(MoERunner):
current_stream().wait_stream(self.shared_experts_stream)
final_hidden_states = (
- shared_output,
+ None if self.fused_shared_output else shared_output,
final_hidden_states,
)
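
The layer-name plumbing above accepts either a plain str (older torch) or an opaque ModuleName (torch >= 2.11, when HAS_OPAQUE_TYPE is set), with `_resolve_layer_name` normalizing both. A duck-typed sketch of the same idea, independent of the vLLM helpers; FakeModuleName is a stand-in, not the real torch type.

class FakeModuleName:
    # Stand-in for torch's opaque ModuleName wrapper (illustrative only).
    def __init__(self, value: str) -> None:
        self.value = value

def resolve_layer_name(layer_name) -> str:
    return layer_name.value if hasattr(layer_name, "value") else layer_name

assert resolve_layer_name("model.layers.0.mlp") == "model.layers.0.mlp"
assert resolve_layer_name(FakeModuleName("model.layers.0.mlp")) == "model.layers.0.mlp"
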
diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
index 99d4038..4cebe60 100644
--- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
+++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py
@@ -10,14 +10,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce):
"""
- Useful in the case when some FusedMoEPermuteExpertsUnpermute
+ Useful in the case when some FusedMoEExpertsModular
implementation does not perform weight application and reduction
but cannot address the needs of all the compatible PrepareAndFinalize
implementations.
- For example, BatchedTritonExperts is compatible with both
- PplxPrepareAndFinalize and BatchedPrepareAndFinalize. PplxPrepareAndFinalize
- does the weight-application + reduction as part of the pplx combine kernel.
- But the BatchedPrepareAndFinalize needs an implementation. To facilitate
+ For example, BatchedTritonExperts is compatible with both batched
+ PrepareAndFinalize implementations like DeepEPLLPrepareAndFinalize and
+ BatchedPrepareAndFinalize. Some PrepareAndFinalize implementations do
+ the weight-application + reduction as part of the combine kernel, while
+ BatchedPrepareAndFinalize needs an explicit implementation. To facilitate
this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate
so the PrepareAndFinalize implementations could choose how to
weight + reduce.
@@ -61,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce):
if output is None:
return fused_expert_output
- # MoEPrepareAndFinalizeNoEP needs the output to be in the `output`
+ # MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output`
# tensor.
assert output.size() == fused_expert_output.size(), (
"output shape is expected to match the fused_expert_output shape. "
diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
index 21a3d05..4aa396d 100644
--- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py
@@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
- type[mk.FusedMoEPermuteExpertsUnpermute],
- type[mk.FusedMoEPermuteExpertsUnpermute],
+ type[mk.FusedMoEExpertsModular],
+ type[mk.FusedMoEExpertsModular],
]:
return (CutlassExpertsFp8, TritonExperts)
@@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
# Small batch fallback for sm100.
if self.is_sm100 and hidden_states.shape[0] <= 8:
return self.fallback_experts
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index a3f2f59..b601806 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts):
@staticmethod
def get_clses() -> tuple[
- type[mk.FusedMoEPermuteExpertsUnpermute],
- type[mk.FusedMoEPermuteExpertsUnpermute],
+ type[mk.FusedMoEExpertsModular],
+ type[mk.FusedMoEExpertsModular],
]:
return (DeepGemmExperts, TritonExperts)
@@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts):
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2):
return self.experts
else:
diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
index 2bd4cd7..5160840 100644
--- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
-class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class TrtLlmGenExperts(mk.FusedMoEExpertsModular):
"""TensorRT-LLM-based fused MoE expert implementation."""
def __init__(
diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
index a293145..00d4124 100644
--- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
+++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py
@@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
- FusedMoEPermuteExpertsUnpermute,
- FusedMoEPrepareAndFinalize,
+ FusedMoEExpertsModular,
+ FusedMoEPrepareAndFinalizeModular,
)
from vllm.model_executor.layers.fused_moe.oracle.unquantized import (
UnquantizedMoeBackend,
@@ -42,9 +42,9 @@ from vllm.platforms.interface import CpuArchEnum
if current_platform.is_cuda_alike() or current_platform.is_xpu():
from .fused_batched_moe import BatchedTritonExperts
- from .fused_moe import TritonExperts
else:
TritonExperts = None # type: ignore
+ fused_experts = None
logger = init_logger(__name__)
@@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.rocm_aiter_moe_enabled = (
rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul
)
- self.kernel: mk.FusedMoEModularKernel | None = None
+ self.kernel: mk.FusedMoEKernel | None = None
self._is_monolithic = (
current_platform.is_cpu()
or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM
@@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> FusedMoEPrepareAndFinalize | None:
+ ) -> FusedMoEPrepareAndFinalizeModular | None:
if self.unquantized_backend == UnquantizedMoeBackend.AITER:
return None
else:
@@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def select_gemm_impl(
self,
- prepare_finalize: FusedMoEPrepareAndFinalize,
+ prepare_finalize: FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
- ) -> FusedMoEPermuteExpertsUnpermute:
+ ) -> FusedMoEExpertsModular:
assert self.moe_quant_config is not None
if (
prepare_finalize.activation_format
@@ -296,16 +296,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+        # Carries the shared experts output when fused_shared_output is enabled.
shared_experts_input: torch.Tensor | None,
- **kwargs
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- return self.forward(
+ result = self.forward(
layer=layer,
x=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
+            # Ignored by forward(); the shared output is added below instead.
shared_experts_input=shared_experts_input,
- )
+ ) * layer.routed_scaling_factor
+ if shared_experts_input is not None:
+ result += shared_experts_input
+ return result
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
if self.moe.has_bias:
@@ -333,10 +337,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
+ quant_config=self.moe_quant_config,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- shared_experts_input=shared_experts_input,
+ expert_map=layer.expert_map
)
def forward_monolithic_cuda(
diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
index e6f8b8e..0693a25 100644
--- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
@@ -23,7 +23,7 @@ if current_platform.is_xpu():
from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe
-class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute):
+class XPUExperts(mk.FusedMoEExpertsModular):
def __init__(
self,
moe_config: FusedMoEConfig,
diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py
index e55cd39..04e35d4 100644
--- a/vllm/model_executor/layers/layernorm.py
+++ b/vllm/model_executor/layers/layernorm.py
@@ -82,11 +82,12 @@ def fused_add_rms_norm(
return rms_norm_batch_invariant(
x + residual, weight, variance_epsilon
), x + residual
- ops.fused_add_rms_norm(
+ x, residual = ops.fused_add_rms_norm(
x,
residual,
weight,
variance_epsilon,
+ residual_alpha,
)
return x, residual
@@ -125,7 +126,7 @@ def dispatch_rocm_rmsnorm_func(
return fused_add_rms_norm
return rms_norm
-
+
def rms_norm_qk(
input_q: torch.Tensor,
input_k: torch.Tensor,
@@ -140,11 +141,7 @@ def rms_norm_qk(
output_q, output_k, input_q, input_k, weight_q, weight_k, epsilon
)
return output_q, output_k
-
-
-def dispatch_cuda_rmsnorm_qk_func() -> callable:
- return rms_norm_qk
-
+
@CustomOp.register("rms_norm_qk")
class RMSNormQK(CustomOp):
@@ -226,8 +223,7 @@ class RMSNormQK(CustomOp):
f"[RMSNormQK] Expected input_q and input_k to have same dtype, "
f"but got {input_q.dtype} vs {input_k.dtype}"
)
- norm_func = dispatch_cuda_rmsnorm_qk_func()
- return norm_func(
+ return rms_norm_qk(
input_q,
input_k,
weight_q,
@@ -264,7 +260,7 @@ class RMSNormQK(CustomOp):
f"eps={self.variance_epsilon}, "
)
-
+# --8<-- [start:rms_norm]
@CustomOp.register("rms_norm")
class RMSNorm(CustomOp):
"""Root mean square normalization.
@@ -375,7 +371,7 @@ class RMSNorm(CustomOp):
# otherwise Inductor eliminates the casts to and from f16,
# increasing memory usage (and complicating pattern matching)
x = x + residual
- residual = x.to(orig_dtype).contiguous()
+ residual = x.to(orig_dtype)
if x.shape[-1] != hidden_size:
raise ValueError(
@@ -425,6 +421,7 @@ class RMSNorm(CustomOp):
self,
x: torch.Tensor,
residual: torch.Tensor | None = None,
+ residual_alpha: float = 1.0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
@@ -499,7 +496,7 @@ class RMSNorm(CustomOp):
add_residual = residual is not None
if add_residual:
return fused_add_rms_norm(
- x, residual, self.weight.data, self.variance_epsilon
+                x, residual, self.weight.data, self.variance_epsilon, residual_alpha
)
else:
return rms_norm(x, self.weight.data, self.variance_epsilon)
@@ -649,6 +646,7 @@ class RMSNormGated(CustomOp):
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
+ activation: str = "swish",
):
"""Initialize RMSNormGated.
@@ -663,10 +661,12 @@ class RMSNormGated(CustomOp):
If False and z is provided: out = norm(x * silu(z))
device: Device to create parameters on
dtype: Data type for parameters
+ activation: Activation function name for gating
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
+ self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
@@ -693,6 +693,11 @@ class RMSNormGated(CustomOp):
- norm_before_gate=True: out = norm(x) * silu(z)
- norm_before_gate=False: out = norm(x * silu(z))
"""
+ orig_dtype = x.dtype
+ x = x.float()
+ weight = self.weight.float()
+ z = z.float() if z is not None else None
+
# Apply gating before normalization if needed
if z is not None and not self.norm_before_gate:
x = x * F.silu(z)
@@ -702,7 +707,7 @@ class RMSNormGated(CustomOp):
# Standard RMS norm across the last dimension
variance = x.pow(2).mean(dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(variance + self.eps)
- out = x_normed * self.weight
+ out = x_normed * weight
else:
# Group RMS norm
from einops import rearrange
@@ -710,13 +715,13 @@ class RMSNormGated(CustomOp):
x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size)
variance = x_group.pow(2).mean(dim=-1, keepdim=True)
x_normed = x_group * torch.rsqrt(variance + self.eps)
- out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight
+ out = rearrange(x_normed, "... g d -> ... (g d)") * weight
# Apply gating after normalization if needed
if z is not None and self.norm_before_gate:
out = out * F.silu(z)
- return out
+ return out.to(orig_dtype)
def forward_cuda(
self, x: torch.Tensor, z: torch.Tensor | None = None
@@ -731,6 +736,7 @@ class RMSNormGated(CustomOp):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
+ activation=self.activation,
)
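
A minimal numeric sketch of the fp32 upcast added to RMSNormGated.forward_native above: compute the norm in float32 and return in the original dtype. The helper name is illustrative, not part of the patch.

import torch

def rms_norm_fp32(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    orig_dtype = x.dtype
    xf = x.float()
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    out = xf * torch.rsqrt(variance + eps) * weight.float()
    return out.to(orig_dtype)

x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
print(rms_norm_fp32(x, w).dtype)  # torch.bfloat16
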
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index a3dcdf8..5d0bfe0 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -2,8 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
+import ast
+import re
from abc import abstractmethod
-from typing import Any
import torch
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -16,6 +16,7 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
+import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.batch_invariant import (
@@ -28,7 +29,9 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.utils import (
dispatch_unquantized_gemm,
- is_layer_moe_router_gate,
+ parse_opt_exclude_layers,
+ weight_quant_l1,
+ weight_quant_l2,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
@@ -41,12 +44,11 @@ from vllm.model_executor.parameter import (
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
-import vllm.envs as envs
+from compressed_tensors.quantization import QuantizationStrategy
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
- "UnquantizedLinearMethod",
"CompressedTensorsLinearMethod",
"CompressedTensorsLinearTransformMethod",
"AWQMarlinLinearMethod",
@@ -66,6 +68,14 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"PetitNvFp4LinearMethod",
]
+LINEAR_OPT_SUPPORTED = [
+ "ColumnParallelLinear",
+ "ReplicatedLinear",
+ "RowParallelLinear",
+ "QKVParallelLinear",
+ "MergedColumnParallelLinear",
+]
+
def adjust_marlin_shard(
param: Parameter,
@@ -135,44 +145,6 @@ def adjust_scalar_to_fused_array(
return param_data[shard_id], loaded_weight
-# TODO(Isotr0py): We might need a more flexible structure to handle
-# bitsandbytes shard offsets.
-def left_shift_bitsandbytes_4bit_shard(
- bnb_weight_attrs: dict[str, Any],
-) -> tuple[dict[str, Any], dict[str, Any]]:
- """
- Separate the BitsAndBytes 4-bit shard.
-
- For example, given bnb weight attributes as below:
- {
- 'bnb_shard_offsets': array([0, 4, 8, 16]),
- 'bnb_quant_state': {0: ..., 1: ..., 2: ...},
- }
-
- The function will return:
- {
- 'bnb_shard_offsets': array([0, 4]),
- 'bnb_quant_state': {0: ...},
- }
- and
- {
- 'bnb_shard_offsets': array([0, 4, 12]),
- 'bnb_quant_state': {0: ..., 1: ...},
- }
- """
- shard_offsets = bnb_weight_attrs["bnb_shard_offsets"]
- offset_l = shard_offsets[:2]
- offset_r = shard_offsets[1:] - shard_offsets[1]
- quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]}
- quant_state_r = {
- i - 1: bnb_weight_attrs["bnb_quant_state"][i]
- for i in range(1, len(shard_offsets) - 1)
- }
- left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l)
- right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r)
- return left, right
-
-
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@@ -231,17 +203,11 @@ class UnquantizedLinearMethod(LinearMethodBase):
# The weights are not quantized, and they are not sharded.
# The amount of memory allocated for the weights is
# sum(output_partition_sizes) * input_size_per_partition.
- weight_loader = extra_weight_attrs.pop("weight_loader")
- weight = ModelWeightParameter(
- data=torch.empty(
- sum(output_partition_sizes),
- input_size_per_partition,
- dtype=params_dtype,
- ),
- input_dim=1,
- output_dim=0,
- weight_loader=weight_loader,
- )
+        weight = Parameter(
+            torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
@@ -258,11 +224,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
- if (
- vllm_is_batch_invariant()
- and current_platform.is_cuda_alike()
- and is_layer_moe_router_gate(getattr(layer, "prefix", ""))
- ):
+ if vllm_is_batch_invariant() and current_platform.is_cuda_alike():
return linear_batch_invariant(x, layer.weight, bias)
return dispatch_unquantized_gemm()(layer, x, layer.weight, bias)
@@ -305,15 +267,31 @@ class LinearBase(PluggableLayer):
self.quant_config = quant_config
self.prefix = prefix
self.allow_fp8_block_shape_mismatch = False
- if quant_config is None:
+ self.opt_level = envs.VLLM_LINEAR_OPT_LEVEL
+        if parse_opt_exclude_layers(envs.VLLM_LINEAR_SPECIFIED_LAYERS, self.prefix) or (
+            envs.VLLM_LINEAR_SPECIFIED_KEYS != ""
+            and envs.VLLM_LINEAR_SPECIFIED_KEYS in self.prefix
+        ):
+            self.opt_level = envs.VLLM_LINEAR_SPECIFIED_OPT_LEVEL
+        self.opt_flag = (
+            quant_config is None
+            and self.opt_level != 0
+            and self.__class__.__name__ in LINEAR_OPT_SUPPORTED
+        )
+
+        if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, self.prefix):
+            self.opt_flag = False
+            logger.info("Excluding layer %s from optimization", self.prefix)
+
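+        # opt_level selects an on-the-fly weight quantization path for otherwise
+        # unquantized linear layers: level 1 routes through weight_quant_l1
+        # (per-channel int8), level 2 through weight_quant_l2 (packed 4-bit
+        # weights, hence the halved shard sizes in the loaders below); both reuse
+        # the compressed-tensors W8A8 int8 linear method. The exact numerics are
+        # defined by the corex weight_quant_l1/l2 helpers.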
+        if self.opt_flag:
+            from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
+                CompressedTensorsLinearMethod,
+            )
+            from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+                CompressedTensorsW8A8Int8,
+            )
+            self.quant_method: QuantizeMethodBase | None = CompressedTensorsLinearMethod(None)
+            self.scheme = CompressedTensorsW8A8Int8(
+                QuantizationStrategy.CHANNEL, False, True,
+                is_w4a8_linear=(self.opt_level == 2),
+            )
+ elif quant_config is None:
self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
self.return_bias = return_bias
+ self.output_padding_size = 0
self.disable_tp = disable_tp
self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0
self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1
- self.output_padding_size = 0
def update_param_tp_status(self):
for param in self.parameters():
@@ -402,7 +380,7 @@ class ReplicatedLinear(LinearBase):
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
# Special case for GGUF
-
+
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
@@ -419,7 +397,17 @@ class ReplicatedLinear(LinearBase):
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}"
)
+ if self.opt_flag:
+ if self.opt_level == 1:
+ loaded_weight, scale = weight_quant_l1(loaded_weight)
+ elif self.opt_level == 2:
+ loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
+
param.data.copy_(loaded_weight)
+ if self.opt_flag:
+ params_dict = dict(self.named_parameters())
+ scale_param = params_dict["weight_scale"]
+ scale_param.data.copy_(scale)
def forward(
self,
@@ -609,7 +597,18 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
+
+ if self.opt_flag:
+ if self.opt_level == 1:
+ loaded_weight, scale = weight_quant_l1(loaded_weight)
+ elif self.opt_level == 2:
+ loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
+
param.load_column_parallel_weight(loaded_weight=loaded_weight)
+ if self.opt_flag:
+ params_dict = dict(self.named_parameters())
+ scale_param = params_dict["weight_scale"]
+ scale_param.load_column_parallel_weight(loaded_weight=scale)
def forward(
self,
@@ -733,16 +732,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_shard_id: tuple[int, ...] | int | None = None,
):
self.validate_shard_id(loaded_shard_id)
- # FIXME(Isotr0py): Enable tuple shard_id for BNB quantization.
- if isinstance(loaded_shard_id, tuple):
- raise NotImplementedError(
- "Shard id with multiple indices is not supported in weight_loader, "
- "please use weight_loader_v2 instead."
- )
# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
+ if isinstance(loaded_shard_id, tuple) and (
+ is_gguf_weight or is_gguf_weight_type
+ ):
+ raise NotImplementedError(
+ "Shard id with multiple indices is not supported for GGUF."
+ )
if is_gguf_weight_type:
if loaded_shard_id is not None:
param.data[loaded_shard_id].copy_(loaded_weight)
@@ -770,7 +769,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
- if loaded_shard_id is None:
+ if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
# Loaded weight is already fused on disk (mlp).
# (e.g., Phi-3's gate_up_proj).
if output_dim is None:
@@ -782,10 +781,25 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
+
+ output_sizes = (
+ self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1]
+ if loaded_shard_id is not None
+ else self.output_sizes
+ )
current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+ if (
+ use_bitsandbytes_4bit
+ and isinstance(loaded_shard_id, tuple)
+ and self.tp_size > 1
+ ):
+ raise NotImplementedError(
+ "Shard id with multiple indices is not supported "
+ "for BNB quantization with TP yet."
+ )
shard_offsets: list[tuple[int, int, int]] = []
- for i, output_size in enumerate(self.output_sizes):
+ for i, output_size in enumerate(output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
@@ -850,9 +864,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
if use_bitsandbytes_4bit:
- shard_size = loaded_weight.shape[output_dim]
- shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
-
+ index = list(itertools.accumulate([0] + self.output_sizes))
+ orig_offsets = {
+ str(i): (index[i], size) for i, size in enumerate(self.output_sizes)
+ }
+ orig_offsets["total"] = (self.output_size, 0)
+ shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+ param, orig_offsets, str(loaded_shard_id)
+ )
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
start_idx = self.tp_rank * shard_size
if not is_sharded_weight:
@@ -921,12 +940,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
loaded_weight: torch.Tensor,
loaded_shard_id: tuple[int, ...] | int | None = None,
):
+ if self.opt_flag:
+ if self.opt_level == 1:
+ loaded_weight, scale = weight_quant_l1(loaded_weight)
+ elif self.opt_level == 2:
+ loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
self.validate_shard_id(loaded_shard_id)
- dtype = loaded_weight.dtype
- if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
- load_sizes = [self.output_sizes[i] // 2 for i in range(len(self.output_sizes))]
- else:
- load_sizes = self.output_sizes
if loaded_shard_id is None or isinstance(loaded_shard_id, tuple):
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
@@ -953,19 +972,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
- # shard_offset = sum(self.output_sizes[:loaded_shard_id])
- # shard_size = self.output_sizes[loaded_shard_id]
- shard_offset = sum(load_sizes[:loaded_shard_id])
- shard_size = load_sizes[loaded_shard_id]
+ shard_offset = sum(self.output_sizes[:loaded_shard_id])
+ shard_size = self.output_sizes[loaded_shard_id]
shard_offset //= self.tp_size
shard_size //= self.tp_size
+ scale_shard_offset = shard_offset
+ scale_shard_size = shard_size
+ if self.opt_flag and self.opt_level == 2:
+ shard_offset = shard_offset // 2
+ shard_size = shard_size // 2
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
shard_size, shard_offset = adjust_block_scale_shard(
weight_block_size, shard_size, shard_offset
)
-
param.load_merged_column_weight(
loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
@@ -973,6 +994,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size=shard_size,
tp_rank=self.tp_rank,
)
+ if self.opt_flag:
+ params_dict = dict(self.named_parameters())
+ scale_param = params_dict["weight_scale"]
+ scale_param.load_merged_column_weight(
+ loaded_weight=scale,
+ shard_id=loaded_shard_id,
+ shard_offset=scale_shard_offset,
+ shard_size=scale_shard_size,
+ tp_rank=self.tp_rank,
+ )
class QKVParallelLinear(ColumnParallelLinear):
@@ -1128,12 +1159,24 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset
)
-
- loaded_weight_shard = loaded_weight.narrow(
- param.output_dim, shard_offset, shard_size
- )
+ if self.opt_level == 2:
+ loaded_weight_shard = loaded_weight.narrow(
+ 0, shard_offset, shard_size
+ )
+ else:
+ loaded_weight_shard = loaded_weight.narrow(
+ param.output_dim, shard_offset, shard_size
+ )
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
+ def quant(self, loaded_weight: torch.Tensor):
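+        # Quantize a checkpoint weight according to opt_level (1: per-channel int8
+        # via weight_quant_l1; 2: packed 4-bit via weight_quant_l2, per the W4A8
+        # scheme selected in LinearBase) and return (weight, scale); when the
+        # optimization is disabled, return the weight unchanged with a None scale.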
+ if self.opt_flag:
+ if self.opt_level == 1:
+ return weight_quant_l1(loaded_weight)
+ elif self.opt_level == 2:
+ return weight_quant_l2(loaded_weight, format="NN")
+ return loaded_weight, None
+
def weight_loader_v2(
self,
param: BasevLLMParameter,
@@ -1141,15 +1184,27 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: str | None = None,
):
self.validate_shard_id(loaded_shard_id)
+ params_dict = dict(self.named_parameters())
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
+ loaded_weight, scale = self.quant(loaded_weight)
param.load_qkv_weight(
loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank
)
+ if self.opt_flag:
+ scale_param = params_dict["weight_scale"]
+ scale_param.load_qkv_weight(
+ loaded_weight=scale, shard_id=0, tp_rank=self.tp_rank
+ )
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
+ loaded_weight, scale = self.quant(loaded_weight)
param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank)
+ if self.opt_flag:
+ scale_param = params_dict["weight_scale"]
+ scale_param.load_qkv_weight(loaded_weight=scale, tp_rank=self.tp_rank)
return
+
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
@@ -1158,11 +1213,15 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
- dtype = loaded_weight.dtype
-        # w4a8 gemm needs the size divided by 2; the scale does not
- if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8:
- shard_offset //= 2
- shard_size //= 2
+
+ scale_shard_offset = shard_offset
+ scale_shard_size = shard_size
+
+ loaded_weight, scale = self.quant(loaded_weight)
+
+ if self.opt_flag and self.opt_level == 2:
+ shard_offset = shard_offset // 2
+ shard_size = shard_size // 2
if isinstance(param, BlockQuantScaleParameter):
weight_block_size = getattr(self, "weight_block_size", None)
@@ -1179,6 +1238,15 @@ class QKVParallelLinear(ColumnParallelLinear):
tp_rank=self.tp_rank,
)
+ if self.opt_flag:
+ scale_param = params_dict["weight_scale"]
+            scale_param.load_qkv_weight(
+                loaded_weight=scale,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=loaded_shard_id,
+                shard_offset=scale_shard_offset,
+                shard_size=scale_shard_size,
+                tp_rank=self.tp_rank,
+            )
+
def weight_loader(
self,
param: Parameter,
@@ -1525,7 +1593,17 @@ class RowParallelLinear(LinearBase):
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
+ if self.opt_flag:
+ if self.opt_level == 1:
+ loaded_weight, scale = weight_quant_l1(loaded_weight)
+ elif self.opt_level == 2:
+ loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN")
+
param.load_row_parallel_weight(loaded_weight=loaded_weight)
+ if self.opt_flag:
+ params_dict = dict(self.named_parameters())
+ scale_param = params_dict["weight_scale"]
+ scale_param.load_row_parallel_weight(loaded_weight=scale)
def forward(
self,
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index dd2a61b..2e6fd19 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -61,7 +61,10 @@ class LogitsProcessor(CustomOp):
logits = hidden_states
else:
# Get the logits for the next tokens.
- logits = self._get_logits(hidden_states, lm_head, embedding_bias)
+ if hidden_states.shape[0] > 0:
+ logits = self._get_logits(hidden_states, lm_head, embedding_bias)
+ else:
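+                # No tokens to compute logits for in this step (e.g., an
+                # all-prefill chunk); return an empty tensor of the right shape.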
+                logits = torch.empty(
+                    [0, lm_head.weight.shape[0]],
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
if logits is not None:
if self.soft_cap is not None:
logits = logits / self.soft_cap
diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py
index 8b5f80f..8021418 100644
--- a/vllm/model_executor/layers/mamba/linear_attn.py
+++ b/vllm/model_executor/layers/mamba/linear_attn.py
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
+from collections.abc import Callable
import torch
import torch.nn.functional as F
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
- return
@staticmethod
def weight_loader(
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
- return
def _forward(
self,
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
return q, k
+def clear_linear_attention_cache_for_new_sequences(
+ kv_cache: torch.Tensor,
+ state_indices_tensor: torch.Tensor,
+ attn_metadata: LinearAttentionMetadata,
+) -> None:
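+    # For each prefill request, a context length of zero means the sequence has
+    # no previously cached state, so its slot in the linear-attention cache is
+    # zeroed before the new tokens are written.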
+ num_prefills = getattr(attn_metadata, "num_prefills", 0)
+ if num_prefills <= 0:
+ return
+
+ num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
+ for prefill_idx in range(num_prefills):
+ q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
+ q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
+ query_len = q_end - q_start
+ context_len = (
+ attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len
+ )
+ if context_len == 0:
+ block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx]
+ kv_cache[block_to_clear, ...] = 0
+
+
+def linear_attention_decode(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ kv_cache: torch.Tensor,
+ slope_rate: torch.Tensor,
+ state_indices_tensor: torch.Tensor,
+ q_start: int = 0,
+ q_end: int | None = None,
+ slot_start: int = 0,
+ slot_end: int | None = None,
+ block_size: int = 32,
+) -> torch.Tensor:
+ q = q[q_start:q_end].unsqueeze(2).contiguous()
+ k = k[q_start:q_end].unsqueeze(2).contiguous()
+ v = v[q_start:q_end].unsqueeze(2).contiguous()
+ slot_id = state_indices_tensor[slot_start:slot_end]
+ return linear_decode_forward_triton(
+ q, k, v, kv_cache, slope_rate, slot_id, block_size
+ )
+
+
+def linear_attention_prefill_and_mix(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ kv_cache: torch.Tensor,
+ state_indices_tensor: torch.Tensor,
+ attn_metadata: LinearAttentionMetadata,
+ slope_rate: torch.Tensor,
+ block_size: int,
+ decode_fn: Callable[..., torch.Tensor],
+ prefix_fn: Callable[..., torch.Tensor],
+ layer_idx: int | None = None,
+) -> torch.Tensor:
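+    # Run prefix_fn over each prefill request's slice, then prepend the decode
+    # output (if any) so the results follow the decode-first token ordering used
+    # by the attention metadata.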
+ hidden = []
+ for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
+ if _prefill_idx >= len(attn_metadata.query_start_loc):
+ break
+ if _prefill_idx >= len(state_indices_tensor):
+ break
+ offset = attn_metadata.num_decode_tokens
+ _start = attn_metadata.query_start_loc[offset + _prefill_idx]
+ _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
+ slot_id = state_indices_tensor[offset + _prefill_idx]
+ qs = q[_start:_end].transpose(0, 1).contiguous()
+ ks = k[_start:_end].transpose(0, 1).contiguous()
+ vs = v[_start:_end].transpose(0, 1).contiguous()
+ slice_layer_cache = kv_cache[slot_id, ...]
+ out_slice = prefix_fn(
+ qs,
+ ks,
+ vs,
+ slice_layer_cache,
+ slope_rate,
+ block_size,
+ layer_idx=layer_idx,
+ )
+ hidden.append(out_slice.contiguous())
+
+ if attn_metadata.num_decode_tokens > 0:
+ hidden_decode = decode_fn(
+ q, k, v, kv_cache, state_indices_tensor, attn_metadata
+ )
+ hidden.insert(0, hidden_decode)
+
+ if not hidden:
+ return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
+
+ hidden = torch.concat(hidden, dim=0).contiguous()
+ return hidden
+
+
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def _prefill_and_mix_infer(
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
):
- hidden = []
- for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
- if _prefill_idx >= len(attn_metadata.query_start_loc):
- break
- if _prefill_idx >= len(state_indices_tensor):
- break
- offset = attn_metadata.num_decode_tokens
- _start = attn_metadata.query_start_loc[offset + _prefill_idx]
- _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
- slot_id = state_indices_tensor[offset + _prefill_idx]
- qs = q[_start:_end].transpose(0, 1).contiguous()
- ks = k[_start:_end].transpose(0, 1).contiguous()
- vs = v[_start:_end].transpose(0, 1).contiguous()
- slice_layer_cache = kv_cache[slot_id, ...]
-
- out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
- qs,
- ks,
- vs,
- slice_layer_cache,
- self.tp_slope,
- self.BLOCK,
- layer_idx=self.layer_idx,
- )
- hidden.append(out_slice.contiguous())
- if attn_metadata.num_decode_tokens > 0:
- hidden_decode = self._decode_infer(
- q, k, v, kv_cache, state_indices_tensor, attn_metadata
- )
- hidden.insert(0, hidden_decode)
-
- if not hidden:
- return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
-
- hidden = torch.concat(hidden, dim=0).contiguous()
- return hidden
+ return linear_attention_prefill_and_mix(
+ q=q,
+ k=k,
+ v=v,
+ kv_cache=kv_cache,
+ state_indices_tensor=state_indices_tensor,
+ attn_metadata=attn_metadata,
+ slope_rate=self.tp_slope,
+ block_size=self.BLOCK,
+ decode_fn=self._decode_infer,
+ prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
+ layer_idx=self.layer_idx,
+ )
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
- q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
- k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
- v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
- slot_id = state_indices_tensor[: attn_metadata.num_decodes]
- hidden = linear_decode_forward_triton(
- q, k, v, kv_cache, self.tp_slope, slot_id, 32
+ hidden = linear_attention_decode(
+ q,
+ k,
+ v,
+ kv_cache,
+ self.tp_slope,
+ state_indices_tensor,
+ q_start=0,
+ q_end=attn_metadata.num_decode_tokens,
+ slot_start=0,
+ slot_end=attn_metadata.num_decodes,
+ block_size=32,
)
return hidden
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
-
- num_prefills = getattr(attn_metadata, "num_prefills", 0)
- if num_prefills > 0:
- num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
- for prefill_idx in range(num_prefills):
- q_start = attn_metadata.query_start_loc[
- num_decode_tokens + prefill_idx
- ]
- q_end = attn_metadata.query_start_loc[
- num_decode_tokens + prefill_idx + 1
- ]
- query_len = q_end - q_start
- context_len = (
- attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
- - query_len
- )
- if context_len == 0:
- block_to_clear = state_indices_tensor[
- num_decode_tokens + prefill_idx
- ]
- kv_cache[block_to_clear, ...] = 0
+ clear_linear_attention_cache_for_new_sequences(
+ kv_cache, state_indices_tensor, attn_metadata
+ )
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 24e189a..6a33fc7 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer):
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p
+ cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
+ last_chunk_indices_p = attn_metadata.last_chunk_indices_p
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer):
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
initial_state_idx=block_idx_last_computed_token_p,
+ cu_chunk_seqlen=cu_chunk_seqlen_p,
+ last_chunk_indices=last_chunk_indices_p,
)
ssm_outputs.append(scan_out_p)
diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py
index fc8912f..1f6751f 100644
--- a/vllm/model_executor/layers/mamba/mamba_utils.py
+++ b/vllm/model_executor/layers/mamba/mamba_utils.py
@@ -289,9 +289,6 @@ def get_temporal_copy_spec(
)
-get_full_copy_spec = get_temporal_copy_spec
-
-
class MambaStateCopyFuncCalculator:
@classmethod
def linear_attention_state_copy_func(cls):
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 507a0a9..b0c1ffb 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -1159,7 +1159,7 @@ def causal_conv1d_update(
f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}"
)
- # assert num_cache_lines >= batch
+ assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
# adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o'
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index a0df65f..44e73dd 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -497,6 +497,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token=None,
block_idx_last_scheduled_token=None,
initial_state_idx=None,
+ cu_chunk_seqlen=None,
+ last_chunk_indices=None,
) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
@@ -588,6 +590,8 @@ def selective_scan_fn(
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
initial_state_idx,
+ cu_chunk_seqlen,
+ last_chunk_indices,
)
if z is None:
diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py
index e289727..ca09b23 100644
--- a/vllm/model_executor/layers/mla.py
+++ b/vllm/model_executor/layers/mla.py
@@ -9,7 +9,6 @@ from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention import MLAAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
-
@dataclass
class MLAModules:
"""Modules used in MLA."""
@@ -18,7 +17,7 @@ class MLAModules:
kv_b_proj: torch.nn.Module
rotary_emb: torch.nn.Module
o_proj: torch.nn.Module
- fused_qkv_a_proj: torch.nn.Module | None
+ q_a_proj: torch.nn.Module | None
kv_a_proj_with_mqa: torch.nn.Module | None
q_a_layernorm: torch.nn.Module | None
q_b_proj: torch.nn.Module | None
@@ -74,7 +73,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
- self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
+ self.q_a_proj = mla_modules.q_a_proj
self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
self.q_a_layernorm = mla_modules.q_a_layernorm
self.q_b_proj = mla_modules.q_b_proj
@@ -106,7 +105,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
kv_b_proj=self.kv_b_proj,
use_sparse=self.is_sparse,
indexer=self.indexer,
- rotary_emb=self.rotary_emb
+ rotary_emb=self.rotary_emb,
)
self.prefix = prefix
@@ -119,60 +118,47 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
) -> torch.Tensor:
q_c = None
kv_lora = None
-
if self.q_lora_rank is not None:
- assert self.fused_qkv_a_proj is not None, (
- "fused_qkv_a_proj is required when q_lora_rank is not None"
- )
- assert self.q_a_layernorm is not None, (
- "q_a_layernorm is required when q_lora_rank is not None"
- )
- assert self.q_b_proj is not None, (
- "q_b_proj is required when q_lora_rank is not None"
- )
- qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
- q_c, kv_lora = qkv_lora.split(
- [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
- dim=-1,
- )
- q_c = self.q_a_layernorm(q_c)
- q = self.q_b_proj(q_c)[0]
+ q = self.q_a_proj(hidden_states)[0]
+            kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=1
+            )
+ q = self.q_a_layernorm(q)
+ q = self.q_b_proj(q)[0].view(-1, self.num_heads, self.qk_head_dim)
+ kv_a = self.kv_a_layernorm(kv_a)
else:
- assert self.kv_a_proj_with_mqa is not None, (
- "kv_a_proj_with_mqa is required when q_lora_rank is None"
- )
- assert self.q_proj is not None, (
- "q_proj is required when q_lora_rank is None"
- )
- kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
- q = self.q_proj(hidden_states)[0]
-
- kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- kv_c_normed = self.kv_a_layernorm(kv_c)
-
- q = q.view(-1, self.num_heads, self.qk_head_dim)
- # Add head dim of 1 to k_pe
- # k_pe = k_pe.unsqueeze(1)
-
- # if self.rotary_emb is not None:
- # q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb(
- # positions, q[..., self.qk_nope_head_dim :], k_pe
- # )
-
- if self.indexer and self.is_sparse:
- _topk_indices = self.indexer(
- hidden_states, q_c, positions, self.indexer_rope_emb
- )
-
+ q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim)
+ latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0]
+ kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1)
+ kv_a = self.kv_a_layernorm(kv_a)
+
+        # NOTE: positions are not available in the attention metadata, so they
+        # are passed to the attention call explicitly here.
if llama_4_scaling is not None:
q *= llama_4_scaling
-
- self.mla_attn.impl.forward_prepare(positions)
- attn_out = self.mla_attn(
- q,
- kv_c_normed,
- k_pe,
- output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim),
- )
-
+ attn_out = self.mla_attn(q, kv_a, k_pe, positions)
return self.o_proj(attn_out)[0]
+
+ def forward_opt(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+        llama_4_scaling: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+ if self.q_lora_rank is not None:
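+            # In this optimized path, q_a_proj is assumed to be a fused projection
+            # producing [q_lora, kv_lora, k_pe, padding]; the trailing
+            # output_padding_size columns are discarded by the split below.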
+ q_latent_kpe = self.q_a_proj(hidden_states)[0]
+            q, kv_a, k_pe, _ = q_latent_kpe.split(
+                [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim,
+                 self.q_a_proj.output_padding_size], dim=1
+            )
+ q_c = self.q_a_layernorm(q)
+ q = self.q_b_proj(q_c)[0].view(-1, self.num_heads, self.qk_head_dim)
+ kv_a = self.kv_a_layernorm(kv_a)
+        else:
+            q_c = None
+            q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim)
+            latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0]
+            kv_a, k_pe = latent_kpe.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=1
+            )
+            kv_a = self.kv_a_layernorm(kv_a)
+ if self.indexer and self.is_sparse:
+        if self.indexer and self.is_sparse:
+            _topk_indices = self.indexer(
+                hidden_states, q_c, positions, self.rotary_emb
+            )
+
+        # NOTE: positions are not available in the attention metadata, so they
+        # are passed to the attention call explicitly here.
+ if llama_4_scaling is not None:
+ q *= llama_4_scaling
+ attn_out = self.mla_attn(q, kv_a, k_pe, positions)
+ return self.o_proj(attn_out)[0]
+
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index 09e67f5..8794647 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -18,6 +18,7 @@ QuantizationMethods = Literal[
"modelopt",
"modelopt_fp4",
"modelopt_mxfp8",
+ "modelopt_mixed",
"gguf",
"gptq_marlin",
"awq_marlin",
@@ -32,6 +33,7 @@ QuantizationMethods = Literal[
"mxfp4",
"petit_nvfp4",
"cpu_awq",
+    "w8a16",
]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
@@ -120,12 +122,18 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .inc import INCConfig
- from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config
+ from .modelopt import (
+ ModelOptFp8Config,
+ ModelOptMixedPrecisionConfig,
+ ModelOptMxFp8Config,
+ ModelOptNvFp4Config,
+ )
from .moe_wna16 import MoeWNA16Config
from .mxfp4 import Mxfp4Config
from .petit import PetitNvFp4Config
from .ptpc_fp8 import PTPCFp8Config
from .torchao import TorchAOConfig
+ from .w8a16 import W8a16Config
method_to_config: dict[str, type[QuantizationConfig]] = {
"awq": AWQConfig,
@@ -135,6 +143,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"modelopt": ModelOptFp8Config,
"modelopt_fp4": ModelOptNvFp4Config,
"modelopt_mxfp8": ModelOptMxFp8Config,
+ "modelopt_mixed": ModelOptMixedPrecisionConfig,
"gguf": GGUFConfig,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
@@ -151,6 +160,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp4": Mxfp4Config,
"petit_nvfp4": PetitNvFp4Config,
"cpu_awq": CPUAWQConfig,
+ "w8a16": W8a16Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index d80a9ba..791ad0c 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
@@ -9,6 +9,7 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
+from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
@@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
+import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
@@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig):
quant_method.input_dtype = get_marlin_input_dtype(prefix)
return quant_method
elif isinstance(layer, FusedMoE):
- from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
+ # from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
# if is_layer_skipped(
# prefix,
@@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig):
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
# layer, prefix
# )
- moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
- moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
- return moe_quant_method
+ # moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config)
+ # moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
+ # return moe_quant_method
+ return AWQMarlinMoEMethod(self, layer.moe_config)
return None
@classmethod
@@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase):
replace_parameter(layer, "qweight", pad_qweight)
replace_parameter(layer, "qzeros", pad_qzeros)
replace_parameter(layer, "scales", pad_scales)
- return
+
# TODO(gyf) Marlin format is not support for now..
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
-
+ return
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
@@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
self,
layer: FusedMoE,
x: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool = False,
- topk_group: int | None = None,
- num_expert_group: int | None = None,
- global_num_experts: int = -1,
- expert_map: torch.Tensor | None = None,
- custom_routing_function: Callable | None = None,
- scoring_func: str = "softmax",
- routed_scaling_factor: float = 1.0,
- e_score_correction_bias: torch.Tensor | None = None,
- apply_router_weight_on_input: bool = False,
- activation: str = "silu",
- enable_eplb: bool = False,
- expert_load_view: torch.Tensor | None = None,
- logical_to_physical_map: torch.Tensor | None = None,
- logical_replica_count: torch.Tensor | None = None,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ # Assign the value of shared_experts_output to variable shared_experts_input for fusion
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- # return fused_marlin_moe(
- # x,
- # layer.w13_qweight,
- # layer.w2_qweight,
- # getattr(layer, "w13_bias", None),
- # getattr(layer, "w2_bias", None),
- # layer.w13_scales,
- # layer.w2_scales,
- # topk_weights,
- # topk_ids,
- # input_global_scale1=getattr(layer, "w13_input_global_scale", None),
- # input_global_scale2=getattr(layer, "w2_input_global_scale", None),
- # quant_type_id=self.quant_type.id,
- # apply_router_weight_on_input=layer.apply_router_weight_on_input,
- # global_num_experts=layer.global_num_experts,
- # expert_map=layer.expert_map,
- # w1_zeros=layer.w13_qzeros,
- # w2_zeros=layer.w2_qzeros,
- # workspace=layer.workspace,
- # input_dtype=self.input_dtype,
- # inplace=not self.moe.disable_inplace,
- # )
- num_tokens, num_experts = router_logits.shape
+ assert layer.activation.value == "silu", "Only SiLU activation is supported."
+ use_ep = layer.expert_map is not None
+        attn_metadata = get_forward_context().attn_metadata
+        if attn_metadata:
+            if isinstance(attn_metadata, dict):
+                only_decode = not use_ep and all(
+                    t.num_decodes > 0 and t.num_prefills == 0
+                    for t in attn_metadata.values()
+                )
+            else:
+                only_decode = (
+                    not use_ep
+                    and attn_metadata.num_decodes > 0
+                    and attn_metadata.num_prefills == 0
+                )
+        else:
+            only_decode = False
+
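+        # Decode-only batches (no prefills and no expert parallelism) take the
+        # grouped GEMV path below; mixed or prefill batches use the grouped GEMM
+        # path instead.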
+ if use_ep:
+ start_eid = layer.ep_rank * layer.local_num_experts
+            end_eid = min(
+                (layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts
+            )
+ if layer.apply_router_weight_on_input:
+ raise NotImplementedError(
+                "Apply router weight on input is not supported for "
+                "fused Marlin MoE method."
+            )
+
+ num_tokens = topk_ids.shape[0]
+ num_experts = layer.global_num_experts
if use_ep:
hidden_size = x.shape[1]
(
@@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
dtype=x.dtype,
)
else:
- expand_tokens = num_tokens * top_k
+ expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
@@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
num_experts=num_experts,
)
- expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder
# TODO use kernel
@@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
- topk=top_k,
+ topk=layer.top_k,
src_to_dst=src_to_dst,
)
# w4a16 group gemm 1
# pt_output_1: (expand_tokens, 2n) dtype
- pt_output_1 = ixfops.moe_w4a16_group_gemm(
- input=expand_hidden_states,
- weight=layer.w13_qweight,
- w_scales=layer.w13_scales,
- quant_type="awq",
- tokens_per_experts=expert_sizes_cpu,
- w_zeros=layer.w13_qzeros,
- group_size=self.quant_config.group_size,
- dst_to_src=None,
- format="NN",
- tokens_per_experts_gpu=expert_sizes_gpu,
- )
-
- # act
- pt_output_2 = ixfops.silu_and_mul(pt_output_1)
-
- # w4a16 group gemm 2 + reorder
- # pt_output_3: (expand_tokens, k) dtype
- if use_ep:
- pt_output_3 = torch.empty(
- (num_tokens * top_k, hidden_size),
- device=x.device,
- dtype=x.dtype,
- )
-
- ixfops.moe_w4a16_group_gemm(
- input=pt_output_2,
- weight=layer.w2_qweight,
- w_scales=layer.w2_scales,
+ if only_decode:
+ pt_output_1 = ixfops.moe_w4a16_group_gemv(
+ input=expand_hidden_states,
+ weight=layer.w13_qweight,
+ w_scales=layer.w13_scales,
quant_type="awq",
- tokens_per_experts=expert_sizes_cpu,
- w_zeros=layer.w2_qzeros,
+ w_zeros=layer.w13_qzeros,
group_size=self.quant_config.group_size,
- dst_to_src=sorted_token_ids,
- format="NN",
- output=pt_output_3,
- )
-
- reduce_mask = src_to_dst == -1
- final_hidden_states = ixfops.moe_output_reduce_sum(
- input=pt_output_3.view(num_tokens, top_k, -1),
- topk_weight=topk_weights,
- scaling_factor=routed_scaling_factor,
- mask=reduce_mask,
- )
- else:
- pt_output_3 = ixfops.moe_w4a16_group_gemm(
- input=pt_output_2,
- weight=layer.w2_qweight,
- w_scales=layer.w2_scales,
- quant_type="awq",
- tokens_per_experts=expert_sizes_cpu,
- w_zeros=layer.w2_qzeros,
- group_size=self.quant_config.group_size,
- dst_to_src=sorted_token_ids,
+ dst_to_src=None,
format="NN",
tokens_per_experts_gpu=expert_sizes_gpu,
)
- # mul + reduce_sum
- # final_hidden_states: (num_tokens, k)
+ # act
+ pt_output_2 = ixfops.silu_and_mul(pt_output_1)
+
+ pt_output_3 = ixfops.moe_w4a16_group_gemv(
+ input=pt_output_2,
+ weight=layer.w2_qweight,
+ w_scales=layer.w2_scales,
+ quant_type="awq",
+ w_zeros=layer.w2_qzeros,
+ group_size=self.quant_config.group_size,
+ dst_to_src=sorted_token_ids,
+ format="NN",
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ )
+
+ # mul + reduce_sum
+ # final_hidden_states: (num_tokens, k)
final_hidden_states = ixfops.moe_output_reduce_sum(
- input=pt_output_3.view(num_tokens, top_k, -1),
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
- scaling_factor=routed_scaling_factor
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
)
+
+ else:
+ expert_sizes_cpu = expert_sizes_gpu.cpu()
+ pt_output_1 = ixfops.moe_w4a16_group_gemm(
+ input=expand_hidden_states,
+ weight=layer.w13_qweight,
+ w_scales=layer.w13_scales,
+ quant_type="awq",
+ tokens_per_experts=expert_sizes_cpu,
+ w_zeros=layer.w13_qzeros,
+ group_size=self.quant_config.group_size,
+ dst_to_src=None,
+ format="NN",
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ )
+
+ # act
+ pt_output_2 = ixfops.silu_and_mul(pt_output_1)
+
+ # w4a16 group gemm 2 + reorder
+ # pt_output_3: (expand_tokens, k) dtype
+ if use_ep:
+ pt_output_3 = torch.empty(
+ (num_tokens * layer.top_k, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+ ixfops.moe_w4a16_group_gemm(
+ input=pt_output_2,
+ weight=layer.w2_qweight,
+ w_scales=layer.w2_scales,
+ quant_type="awq",
+ tokens_per_experts=expert_sizes_cpu,
+ w_zeros=layer.w2_qzeros,
+ group_size=self.quant_config.group_size,
+ dst_to_src=sorted_token_ids,
+ format="NN",
+ output=pt_output_3,
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ )
+
+ reduce_mask = src_to_dst == -1
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ mask=reduce_mask,
+ )
+ else:
+ pt_output_3 = ixfops.moe_w4a16_group_gemm(
+ input=pt_output_2,
+ weight=layer.w2_qweight,
+ w_scales=layer.w2_scales,
+ quant_type="awq",
+ tokens_per_experts=expert_sizes_cpu,
+ w_zeros=layer.w2_qzeros,
+ group_size=self.quant_config.group_size,
+ dst_to_src=sorted_token_ids,
+ format="NN",
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ )
+
+ # mul + reduce_sum
+ # final_hidden_states: (num_tokens, k)
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
return final_hidden_states
+ # return torch.ops.vllm.fused_marlin_moe(
+ # x,
+ # layer.w13_qweight,
+ # layer.w2_qweight,
+ # layer.w13_scales,
+ # layer.w2_scales,
+ # router_logits,
+ # topk_weights,
+ # topk_ids,
+ # w1_zeros=layer.w13_qzeros,
+ # w2_zeros=layer.w2_qzeros,
+ # num_bits=self.quant_config.weight_bits,
+ # )
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
index 0654e0c..4a02c91 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -18,7 +18,6 @@ from compressed_tensors.quantization import (
)
from compressed_tensors.transform import TransformConfig
-import vllm.envs as envs
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -52,7 +51,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16,
- CompressedTensorsW4A8Int8
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod,
@@ -401,8 +399,8 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
- weight_quant.strategy == QuantizationStrategy.TENSOR.value
- or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+ weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+ or weight_quant.strategy == QuantizationStrategy.GROUP.value
)
is_tensor = (
weight_strategy
@@ -420,8 +418,8 @@ class CompressedTensorsConfig(QuantizationConfig):
) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
- weight_quant.strategy == QuantizationStrategy.TENSOR.value
- or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+ weight_quant.strategy == QuantizationStrategy.CHANNEL.value
+ or weight_quant.strategy == QuantizationStrategy.GROUP.value
)
is_token = (
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
@@ -663,12 +661,6 @@ class CompressedTensorsConfig(QuantizationConfig):
)
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
- if envs.VLLM_W8A8_LINEAR_USE_W4A8:
- return CompressedTensorsW4A8Int8(
- strategy=weight_quant.strategy,
- is_static_input_scheme=False,
- input_symmetric=input_quant.symmetric,
- )
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d83b5c8..4c8b865 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -3,6 +3,7 @@
import enum
from enum import Enum
+from typing import List, Optional, Tuple
import torch
from compressed_tensors import CompressionFormat
@@ -12,16 +13,16 @@ from compressed_tensors.quantization import (
QuantizationStrategy,
)
-import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEActivationFormat,
+ FusedMoEExpertsModular,
FusedMoEMethodBase,
- FusedMoEPermuteExpertsUnpermute,
FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod,
)
@@ -41,7 +42,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
- Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
@@ -60,18 +60,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress
WNA16_SUPPORTED_BITS,
WNA16_SUPPORTED_TYPES_MAP,
)
-from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- flashinfer_trtllm_fp4_moe,
- flashinfer_trtllm_fp4_routed_moe,
-)
from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import (
flashinfer_trtllm_mxint4_moe,
is_flashinfer_mxint4_moe_available,
prepare_static_weights_for_trtllm_mxint4_moe,
)
-from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- apply_fi_trtllm_fp8_per_tensor_moe,
-)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
process_fp8_input_tensor_strategy_moe,
process_fp8_weight_tensor_strategy_moe,
@@ -106,8 +99,6 @@ from vllm.platforms import CpuArchEnum, current_platform
logger = init_logger(__name__)
import ixformer.inference.functions as ixfops
import vllm.envs as envs
-from vllm.forward_context import ForwardContext, get_forward_context
-
class GPTQMarlinState(Enum):
@@ -126,6 +117,42 @@ __all__ = [
]
+def moe_w4a8_group_gemm_pad_k(
+ input: torch.Tensor,
+ weight: torch.Tensor,
+ i_scales: torch.Tensor,
+ w_scales: torch.Tensor,
+ output_dtype: torch.dtype,
+ tokens_per_experts: torch.Tensor,
+    w_i8scales: Optional[torch.Tensor] = None,
+    w_i8zeros: Optional[torch.Tensor] = None,
+    dst_to_src: Optional[torch.Tensor] = None,
+ bias: Optional[torch.Tensor] = None,
+ format: int = 0,
+ version: int = 2,
+ group_size: int = -1,
+ persistent: int = 0,
+    gdr_buffer_ptr: Optional[int] = None,
+    output: Optional[torch.Tensor] = None,
+):
+    # Pad the K dimension when the kernel constraints require it:
+    # format 0/1 needs k % 64 == 0 and k >= 256; format 2 needs k % 128 == 0 and k >= 256.
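+    # Worked examples of the alignment: for format 0/1, k=200 pads to
+    # (200+63)&~63 = 256 and k=300 to 320; for format 2, k=300 pads to
+    # (300+127)&~127 = 384. max(256, ...) enforces the kernel's minimum K of 256.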
+ if format in (0, 1):
+ k = weight.size(1)
+ k_padded = max(256, (k + 63) & ~63) # bitwise align to 64, avoids div/mul
+ else:
+ k = weight.size(2) << 1 # * 2
+ k_padded = max(256, (k + 127) & ~127) # bitwise align to 128
+ pad_k = k_padded - k
+ if pad_k > 0:
+        input = torch.nn.functional.pad(input, (0, pad_k), mode="constant", value=0)
+        if format in (0, 1):
+            weight = torch.nn.functional.pad(
+                weight, (0, 0, 0, pad_k), mode="constant", value=0
+            )
+        else:
+            weight = torch.nn.functional.pad(
+                weight, (0, pad_k >> 1, 0, 0), mode="constant", value=0
+            )
+
+    return ixfops.moe_w4a8_group_gemm(
+        input, weight, i_scales, w_scales, output_dtype, tokens_per_experts,
+        w_i8scales, w_i8zeros, dst_to_src, bias, format, version, group_size,
+        persistent, gdr_buffer_ptr, output,
+    )
+
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
@@ -182,9 +209,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
f"but got format: {CompressionFormat.pack_quantized.value} "
f" and bits: {weight_quant.num_bits}",
)
+ if envs.VLLM_WNA16_MOE_USE_W4A8:
+ assert weight_quant.num_bits == 4 and group_size == -1
+ logger.info_once("Using CompressedTensorsW4A8MoEMethod")
+            return CompressedTensorsW4A16MoEMethod(
+                weight_quant, input_quant, layer.moe_config
+            )
# Prefer to use the MarlinMoE kernel when it is supported.
- if (
+ elif (
not check_moe_marlin_supports_layer(layer, group_size)
or current_platform.is_rocm()
):
@@ -225,10 +256,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsW8A8Fp8MoEMethod(
weight_quant, input_quant, layer.moe_config
)
- elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
+        elif quant_config._is_dynamic_token_w8a8(
+            weight_quant, input_quant
+        ) or quant_config._is_static_tensor_w8a8(weight_quant, input_quant):
if envs.VLLM_W8A8_MOE_USE_W4A8:
- return CompressedTensorsW4A8MoEMethod(quant_config, layer.moe_config)
- return CompressedTensorsW8A8Int8MoEMethod(
+                return CompressedTensorsW4A8MoEMethod(
+                    weight_quant, input_quant, layer.moe_config
+                )
+ else:
+ return CompressedTensorsW8A8Int8MoEMethod(
weight_quant, input_quant, layer.moe_config
)
elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant):
@@ -343,7 +375,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config is not None:
- self.moe_mk = make_nvfp4_moe_kernel(
+ self.moe_kernel = make_nvfp4_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
experts_cls=self.experts_cls,
@@ -358,10 +390,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- **kwargs,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.moe_mk is not None
- return self.moe_mk(
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
@@ -570,43 +601,27 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = a13_scale
layer.w2_input_scale = a2_scale
- # Setup modular kernel for TP case and naive DP/EP case.
- # In non-naive DP/EP case, we will create a ModularKernelMethod.
- # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
- # in both cases.
+ # Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config:
- assert self.experts_cls is not None
- self.moe_mk = make_nvfp4_moe_kernel(
- moe_quant_config=self.moe_quant_config,
- moe_config=self.moe,
- experts_cls=self.experts_cls,
- shared_experts=layer.shared_experts,
- routing_tables=layer._maybe_init_expert_routing_tables(),
- )
+ assert self.experts_cls is not None
+ self.moe_kernel = make_nvfp4_moe_kernel(
+ moe_quant_config=self.moe_quant_config,
+ moe_config=self.moe,
+ experts_cls=self.experts_cls,
+ shared_experts=layer.shared_experts,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
+ )
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> mk.FusedMoEPrepareAndFinalize | None:
+ ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
- def select_gemm_impl(
- self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
- layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
- raise ValueError(
- f"{self.__class__.__name__} uses the new modular kernel initialization "
- "logic. This function should not be called."
- )
-
- def get_fused_moe_quant_config(
- self, layer: torch.nn.Module
- ) -> FusedMoEQuantConfig | None:
+ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
@@ -617,13 +632,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale,
)
- @property
- def is_monolithic(self) -> bool:
- return (
- self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
- and not self.moe.moe_parallel_config.enable_eplb
- )
-
def apply_monolithic(
self,
layer: FusedMoE,
@@ -631,24 +639,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
- assert layer.activation == MoEActivation.SILU, (
- f"Only SiLU activation is supported, not {layer.activation}."
- )
- assert (
- self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
- and not layer.enable_eplb
- )
- return flashinfer_trtllm_fp4_moe(
- layer=layer,
- x=x,
- router_logits=router_logits,
- top_k=layer.top_k,
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply_monolithic(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
- custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
+ routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
@@ -658,36 +662,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- **kwargs,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert not self.is_monolithic
-
- # EPLB path
- if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- assert layer.enable_eplb
- return flashinfer_trtllm_fp4_routed_moe(
- layer=layer,
- x=x,
- topk_ids=topk_ids,
- topk_weights=topk_weights,
- top_k=layer.top_k,
- activation=layer.activation,
- global_num_experts=layer.global_num_experts,
- )
- else:
- assert self.moe_mk is not None
- return self.moe_mk(
- x,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights,
- topk_ids,
- activation=layer.activation,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=shared_experts_input,
- )
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights,
+ topk_ids,
+ activation=layer.activation,
+ global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ shared_experts_input=shared_experts_input,
+ )
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
@@ -946,7 +934,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w13,
w13_scale,
shard_size=layer.intermediate_size_per_partition,
- num_experts=layer.num_local_experts,
+ num_experts=layer.local_num_experts,
is_act_and_mul=self.moe.is_act_and_mul,
)
@@ -975,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
- self.moe_mk = make_fp8_moe_kernel(
+ self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
@@ -987,94 +975,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> mk.FusedMoEPrepareAndFinalize | None:
+ ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
- def select_gemm_impl(
- self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
- layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
- raise ValueError(
- f"{self.__class__.__name__} uses the new modular kernel initialization "
- "logic. This function should not be called."
- )
-
- def get_fused_moe_quant_config(
- self, layer: torch.nn.Module
- ) -> FusedMoEQuantConfig | None:
- w1_scale = layer.w13_weight_scale
- w2_scale = layer.w2_weight_scale
- a1_scale = layer.w13_input_scale
- a2_scale = layer.w2_input_scale
-
+ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
+ is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN
return make_fp8_moe_quant_config(
fp8_backend=self.fp8_backend,
- w1_scale=w1_scale,
- w2_scale=w2_scale,
- a1_scale=a1_scale,
- a2_scale=a2_scale,
- per_act_token_quant=(
- self.input_quant.strategy == QuantizationStrategy.TOKEN
- ),
- per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN),
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
+ per_act_token_quant=is_per_token,
+ per_out_ch_quant=is_per_token,
block_shape=self.weight_block_size,
)
- @property
- def is_monolithic(self) -> bool:
- return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
-
def apply_monolithic(
self,
layer: FusedMoE,
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.is_monolithic
- assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
- assert layer.activation == MoEActivation.SILU, (
- f"Only SiLU activation is supported, not {layer.activation}."
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply_monolithic(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ router_logits,
+ activation=layer.activation,
+ global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ num_expert_group=layer.num_expert_group,
+ topk_group=layer.topk_group,
+ e_score_correction_bias=layer.e_score_correction_bias,
+ routed_scaling_factor=layer.routed_scaling_factor,
)
- if self.block_quant:
- import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
-
- return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
- routing_logits=router_logits,
- routing_bias=layer.e_score_correction_bias,
- x=x,
- w13_weight=layer.w13_weight,
- w13_weight_scale_inv=layer.w13_weight_scale,
- w2_weight=layer.w2_weight,
- w2_weight_scale_inv=layer.w2_weight_scale,
- global_num_experts=layer.global_num_experts,
- top_k=layer.top_k,
- num_expert_group=layer.num_expert_group,
- topk_group=layer.topk_group,
- intermediate_size=layer.intermediate_size_per_partition,
- expert_offset=layer.ep_rank * layer.local_num_experts,
- local_num_experts=layer.local_num_experts,
- block_shape=self.weight_block_size,
- routing_method_type=layer.routing_method_type,
- routed_scaling=layer.routed_scaling_factor,
- )
- else:
- return apply_fi_trtllm_fp8_per_tensor_moe(
- layer=layer,
- hidden_states=x,
- router_logits=router_logits,
- routing_bias=layer.e_score_correction_bias,
- global_num_experts=layer.global_num_experts,
- top_k=layer.top_k,
- num_expert_group=layer.num_expert_group,
- topk_group=layer.topk_group,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- )
-
def apply(
self,
layer: FusedMoE,
@@ -1082,11 +1023,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- **kwargs
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
- assert self.moe_mk is not None
- return self.moe_mk(
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
@@ -1115,8 +1055,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
layer_name: str | None = None,
):
super().__init__(moe)
+ self.has_bias = self.moe.has_bias
self.weight_quant = weight_quant
self.input_quant = input_quant
+        self.group_size = (
+            -1 if self.weight_quant.group_size is None else self.weight_quant.group_size)
+        self.gemm_format = envs.VLLM_W8A8_FORMAT
+        assert self.gemm_format in ["TN", "NN"], (
+            f"W8A8 INT8 MoE only supports TN or NN format, but got {self.gemm_format}")
+ self.padding_dim = -1 if self.gemm_format == "TN" else -2
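+        # padding_dim is the weight dimension the activations must be padded to match
+        # in apply(): the last dim for the TN layout, the second-to-last for NN.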
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
@@ -1130,11 +1075,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
)
self.static_input_scales = not self.input_quant.dynamic
- if self.static_input_scales:
- raise ValueError(
- "For INT8 Fused MoE layers, we require channelwise, "
- "dynamic per token quantization. Found static input scales."
- )
+        # Static input scales are now supported via the per-channel input-scale
+        # parameters created in create_weights, so the old restriction to dynamic
+        # per-token quantization no longer applies.
+
def create_weights(
self,
@@ -1146,50 +1091,88 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
**extra_weight_attrs,
):
params_dtype = torch.int8
+ w13_num_shards = 2 if self.moe.is_act_and_mul else 1
w13_remainder = hidden_size % 64
w2_remainder = intermediate_size_per_partition % 64
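+        # Pad both GEMM K dimensions up to a multiple of 64; the ixfops grouped GEMM
+        # kernels appear to require 64-aligned shapes (assumption based on this padding).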
+ if w13_remainder != 0:
+ hidden_size_padded = hidden_size + (64 - w13_remainder)
+ else:
+ hidden_size_padded = hidden_size
+ if w2_remainder != 0:
+ intermediate_size_per_partition_padded = intermediate_size_per_partition + (64 - w2_remainder)
+ else:
+ intermediate_size_per_partition_padded = intermediate_size_per_partition
- hidden_size_padded = hidden_size if w13_remainder == 0 else hidden_size + (64 - w13_remainder)
-
- intermediate_size_per_partition_padded = intermediate_size_per_partition if w2_remainder == 0 else intermediate_size_per_partition + (64 - w2_remainder)
-
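+        # Weight layout depends on the GEMM format: NN stores each expert's weight as
+        # (K, N), TN as (N, K).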
+ if self.gemm_format == "NN":
+ w13_shape = (num_experts, hidden_size_padded, w13_num_shards * intermediate_size_per_partition_padded)
+ w2_shape = (num_experts, intermediate_size_per_partition_padded, hidden_size_padded)
+ else:
+ w13_shape = (num_experts, w13_num_shards * intermediate_size_per_partition, hidden_size_padded)
+ w2_shape = (num_experts, hidden_size, intermediate_size_per_partition_padded)
+
# WEIGHTS
- w13_weight = torch.nn.Parameter(
- torch.empty(
- num_experts,
- 2 * intermediate_size_per_partition,
- hidden_size_padded,
- dtype=params_dtype,
- ),
- requires_grad=False,
- )
+ w13_weight = torch.nn.Parameter(torch.zeros(
+ w13_shape,
+ dtype=params_dtype), requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
- w2_weight = torch.nn.Parameter(
- torch.empty(
- num_experts,
- hidden_size,
- intermediate_size_per_partition_padded,
- dtype=params_dtype,
- ),
- requires_grad=False,
- )
+ w2_weight = torch.nn.Parameter(torch.zeros(
+ w2_shape,
+ dtype=params_dtype), requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
+ if self.gemm_format == "TN":
+ w13_bias_shape = (num_experts, 2 * intermediate_size_per_partition)
+ w2_bias_shape = (num_experts, hidden_size)
+ elif self.gemm_format == "NN":
+ w13_bias_shape = (num_experts, 2 * intermediate_size_per_partition_padded)
+ w2_bias_shape = (num_experts, hidden_size_padded)
+
+ if self.has_bias:
+ w13_bias = torch.nn.Parameter(
+ torch.zeros(
+ w13_bias_shape,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_bias", w13_bias)
+ set_weight_attrs(w13_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_bias", None)
+
+ if self.has_bias:
+ w2_bias = torch.nn.Parameter(
+ torch.zeros(
+ w2_bias_shape,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_bias", w2_bias)
+ set_weight_attrs(w2_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w2_bias", None)
+
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+ if self.gemm_format == "NN":
+ w13_weight_scale_shape = (num_experts, w13_num_shards * intermediate_size_per_partition_padded, 1)
+ w2_weight_scale_shape = (num_experts, hidden_size_padded, 1)
+ else:
+ w13_weight_scale_shape = (num_experts, w13_num_shards * intermediate_size_per_partition, 1)
+ w2_weight_scale_shape = (num_experts, hidden_size, 1)
+
w13_weight_scale = torch.nn.Parameter(
- torch.ones(
- num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
- ),
+ torch.ones(w13_weight_scale_shape, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(
- torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
+ torch.ones(w2_weight_scale_shape, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
@@ -1201,9 +1184,23 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
- assert not self.static_input_scales
- layer.w13_input_scale = None
- layer.w2_input_scale = None
+ if self.static_input_scales:
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+ w13_input_scale = torch.nn.Parameter(torch.ones(
+ num_experts, hidden_size, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_input_scale", w13_input_scale)
+ set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+ w2_input_scale = torch.nn.Parameter(torch.ones(
+ num_experts, intermediate_size_per_partition, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_input_scale", w2_input_scale)
+ set_weight_attrs(w2_input_scale, extra_weight_attrs)
+ else:
+ layer.w13_input_scale = None
+ layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
@@ -1225,24 +1222,40 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
- router_logits: torch.Tensor,
- expert_map: torch.Tensor | None = None,
- enable_eplb: bool = False,
- extra_residual: torch.Tensor = None,
- routed_scaling_factor: float = 1.0,
- *args,
- **kwargs
+        # shared_experts_input carries the shared experts' output so it can be fused
+        # into the final reduction.
+ shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- from vllm.model_executor.layers.fused_moe import fused_experts
- use_ep = expert_map is not None
+
+        attn_metadata = get_forward_context().attn_metadata
+        use_ep = layer.expert_map is not None
+        # The decode-only fast path below is not supported with expert parallelism yet.
+        if attn_metadata:
+            deepseek_instance = None
+            for value in attn_metadata.values():
+                if hasattr(value, "num_prefills") and hasattr(value, "num_decodes"):
+                    deepseek_instance = value
+                    break
+            value_types = {type(value).__name__ for value in attn_metadata.values()}
+            is_same_class = len(value_types) == 1
+            if is_same_class:
+                assert deepseek_instance
+                only_decode = not use_ep and all(
+                    t.num_decodes > 0 and t.num_prefills == 0
+                    for t in attn_metadata.values()
+                )
+            elif deepseek_instance:
+                only_decode = (
+                    not use_ep
+                    and deepseek_instance.num_decodes > 0
+                    and deepseek_instance.num_prefills == 0
+                )
+            else:
+                only_decode = False
+        else:
+            only_decode = False
+
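+        # only_decode gates the group-GEMV fast path below; it is only wired up for the
+        # NN layout and is never taken with expert parallelism.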
if use_ep:
start_eid = layer.ep_rank * layer.local_num_experts
- end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, global_num_experts)
-
-
- top_k = topk_ids.shape[-1]
+ end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
+
dtype = x.dtype
- num_tokens, num_experts = router_logits.shape
+ num_tokens = topk_ids.shape[0]
+ num_experts = layer.global_num_experts
+
if use_ep:
hidden_size = x.shape[1]
(
@@ -1264,7 +1277,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
dtype=x.dtype,
)
else:
- expand_tokens = num_tokens * top_k
+ expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
@@ -1281,108 +1294,1170 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
- topk=top_k,
+ topk=layer.top_k,
src_to_dst=src_to_dst,
topk_ids=None,
smooth_scales=layer.w13_input_scale,
)
- i8_hidden_states_align = i8_hidden_states
-
if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1]:
padding = layer.w13_weight.shape[-1] - i8_hidden_states.shape[-1]
i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
- # w8a8 group gemm 1
- pt_output_1 = ixfops.moe_w8a8_group_gemm(
- input=i8_hidden_states_align,
- weight=layer.w13_weight,
- i_scales=a_scale,
- w_scales=layer.w13_weight_scale,
- output_dtype=dtype,
- tokens_per_experts=expert_sizes_cpu,
- dst_to_src=None,
- format="TN",
- )
+ if only_decode and self.gemm_format == "NN":
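+            # Decode-only fast path: group-GEMV kernels driven by GPU-side expert sizes,
+            # so no device-to-host sync is needed.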
- # act + quant
- pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
- input=pt_output_1,
- bias=None,
- smooth_scales=layer.w2_input_scale,
- dst_to_src=sorted_token_ids,
- topk_ids=None,
- act_type="swiglu",
- )
-
- pt_output_2_align = pt_output_2
-
- if pt_output_2.shape[-1] != layer.w2_weight.shape[-1]:
- padding = layer.w2_weight.shape[-1] - pt_output_2.shape[-1]
- pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
-
- # w8a8 group gemm 2 + reorder
- if use_ep:
- pt_output_3 = torch.empty(
- (num_tokens * top_k, hidden_size),
- device=x.device,
- dtype=x.dtype,
+ # expand + reorder + quant
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
+ topk_ids=None,
+ smooth_scales=layer.w13_input_scale,
)
- ixfops.moe_w8a8_group_gemm(
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-2]:
+ padding = layer.w13_weight.shape[-2] - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+ # w8a8 group gemm 1
+ pt_output_1 = ixfops.moe_w8a8_group_gemv(
+ input=i8_hidden_states_align,
+ weight=layer.w13_weight,
+ i_scales=a_scale,
+ w_scales=layer.w13_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_gpu,
+ dst_to_src=None,
+ bias=layer.w13_bias,
+ format=0,
+ group_size=self.group_size,
+ )
+
+ # act + quant
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+ )
+
+
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[-2]:
+ padding = layer.w2_weight.shape[-2] - pt_output_2.shape[-1]
+ pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
+ else:
+ pt_output_2_align = pt_output_2
+
+ # w8a8 group gemm 2 + reorder
+ pt_output_3 = ixfops.moe_w8a8_group_gemv(
input=pt_output_2_align,
weight=layer.w2_weight,
i_scales=a2_scale,
w_scales=layer.w2_weight_scale,
output_dtype=dtype,
- tokens_per_experts=expert_sizes_cpu,
+ tokens_per_experts=expert_sizes_gpu,
dst_to_src=sorted_token_ids,
- format="TN",
- output=pt_output_3,
+ bias=layer.w2_bias,
+ format=0,
+ group_size=self.group_size,
)
-
- reduce_mask = src_to_dst == -1
+
final_hidden_states = ixfops.moe_output_reduce_sum(
- input=pt_output_3.view(num_tokens, top_k, -1),
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
- extra_residual=extra_residual,
- scaling_factor=routed_scaling_factor,
- mask=reduce_mask,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
)
else:
- pt_output_3 = ixfops.moe_w8a8_group_gemm(
+ expert_sizes_cpu = expert_sizes_gpu.cpu()
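+            # The general path uses grouped GEMMs that take host-side per-expert token
+            # counts, hence the sync to CPU above.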
+ # expand + reorder + quant
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
+ topk_ids=None,
+ smooth_scales=layer.w13_input_scale,
+ )
+
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[self.padding_dim]:
+ padding = layer.w13_weight.shape[self.padding_dim] - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+ # w8a8 group gemm 1
+ pt_output_1 = ixfops.moe_w8a8_group_gemm(
+ input=i8_hidden_states_align,
+ weight=layer.w13_weight,
+ i_scales=a_scale,
+ w_scales=layer.w13_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=None,
+ bias=layer.w13_bias,
+ format=self.gemm_format,
+ )
+
+ # act + quant
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+ )
+
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[self.padding_dim]:
+ padding = layer.w2_weight.shape[self.padding_dim] - pt_output_2.shape[-1]
+ pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
+ else:
+ pt_output_2_align = pt_output_2
+
+
+ # w8a8 group gemm 2 + reorder
+ if use_ep:
+ pt_output_3 = torch.empty(
+ (num_tokens * layer.top_k, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+ ixfops.moe_w8a8_group_gemm(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ output=pt_output_3,
+ )
+
+ reduce_mask = src_to_dst == -1
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ mask=reduce_mask,
+ )
+ else:
+ pt_output_3 = ixfops.moe_w8a8_group_gemm(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ )
+
+ # mul + reduce_sum
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+
+ return final_hidden_states
+
+
+class CompressedTensorsW4A16MoEMethod(CompressedTensorsMoEMethod):
+
+ def __init__(
+ self,
+ weight_quant: QuantizationArgs,
+ input_quant: QuantizationArgs,
+ moe: FusedMoEConfig,
+ layer_name: str | None = None,
+ ):
+ super().__init__(moe)
+ self.has_bias = self.moe.has_bias
+ self.weight_quant = weight_quant
+ self.input_quant = input_quant
+
+ self.pack_factor = 2
+ self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size
+ self.weight_symmetric = self.weight_quant.symmetric
+ self.gemm_format = envs.VLLM_W4A8_FORMAT
+ self.format_mapping = {"NN":0,"NT":1,"TN":2}
+ self.version = envs.VLLM_W4A8_VERSION
+ assert self.gemm_format in ["TN","NN"]
+ self.static_input_scales = False
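+        # format_mapping maps the layout string to the integer code used by the ixfops
+        # W4A8 kernels; gemm_format is converted to that code at the end of create_weights().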
+
+ def create_weights(self, layer: torch.nn.Module, num_experts: int,
+ hidden_size: int, intermediate_size_per_partition: int,
+ params_dtype: torch.dtype, **extra_weight_attrs):
+
+ params_dtype = torch.int8
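+        # Two INT4 values are packed per int8 byte, so the packed dimension is the
+        # original size divided by pack_factor, padded up to a multiple of 64 for the
+        # TN layout.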
+ w13_remainder = (hidden_size // self.pack_factor) % 64
+ w2_remainder = (intermediate_size_per_partition // self.pack_factor) % 64
+ if self.gemm_format == "TN":
+ if w13_remainder != 0:
+ w13_shape = (num_experts, 2 * intermediate_size_per_partition, (hidden_size // self.pack_factor) + 64 - w13_remainder)
+ else:
+ w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor)
+
+ if w2_remainder != 0:
+ w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - w2_remainder)
+ else:
+ w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor)
+ else:
+ w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor)
+ w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor)
+
+ # WEIGHTS
+        # use process_weights_after_loading to get the right layout if gemm_format is NN
+ w13_weight = torch.nn.Parameter(torch.empty(w13_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w13_weight", w13_weight)
+ set_weight_attrs(w13_weight, extra_weight_attrs)
+ if self.gemm_format == "NN":
+ setattr(w13_weight, "shard_dim", 1)
+
+ if self.has_bias:
+ w13_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_bias", w13_bias)
+ set_weight_attrs(w13_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_bias", None)
+
+ if w2_remainder != 0:
+ w2_weight = torch.nn.Parameter(torch.zeros(w2_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ else:
+ w2_weight = torch.nn.Parameter(torch.empty(w2_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w2_weight", w2_weight)
+ set_weight_attrs(w2_weight, extra_weight_attrs)
+ if self.gemm_format == "NN":
+ setattr(w2_weight, "shard_dim", 0)
+
+ if self.has_bias:
+ w2_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ hidden_size,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_bias", w2_bias)
+ set_weight_attrs(w2_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w2_bias", None)
+
+ # WEIGHT_SCALES
+ # Allocate 2 scales for w1 and w3 respectively.
+ # They will be combined to a single scale after weight loading.
+        # The following scales/zeros are permuted with permute(0, 2, 1) to get the
+        # right layout; they are initialized in this shape to avoid rewriting the
+        # data loader.
+ w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.version == 2 else 1 if self.group_size == -1 else hidden_size // self.group_size,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_weight_scale", w13_weight_scale)
+ setattr(w13_weight_scale, "shard_dim", 1)
+
+ w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.version == 2 else 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size,
+ hidden_size,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_weight_scale", w2_weight_scale)
+ setattr(w2_weight_scale, "shard_dim", 0)
+ # setattr(w2_weight_scale, "load_full_w2", True)
+
+        quant_method = (FusedMoeWeightScaleSupported.CHANNEL.value
+                        if self.version == 2 or self.group_size == -1
+                        else FusedMoeWeightScaleSupported.GROUP.value)
+        extra_weight_attrs.update({"quant_method": quant_method})
+ set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+ set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+ if self.version == 2:
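+            # Version 2 checkpoints carry an extra set of integer scales/zeros used to
+            # reconstruct the INT4 values from the INT8 packing (inferred from the
+            # parameter names below).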
+ # INT8 -> INT4 weight scales/zeros
+ if self.group_size != -1:
+ w13_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ hidden_size // self.group_size,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w13_i8_weight_scale", w13_i8_weight_scale)
+ setattr(w13_i8_weight_scale, "shard_dim", 1)
+ if not self.weight_symmetric:
+ w13_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.group_size == -1 else hidden_size // self.group_size,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w13_i8_weight_zero", w13_i8_weight_zero)
+ setattr(w13_i8_weight_zero, "shard_dim", 1)
+
+ if self.group_size != -1:
+ w2_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ intermediate_size_per_partition // self.group_size,
+ hidden_size,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w2_i8_weight_scale", w2_i8_weight_scale)
+ setattr(w2_i8_weight_scale, "shard_dim", 0)
+ if not self.weight_symmetric:
+ w2_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size,
+ hidden_size,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w2_i8_weight_zero", w2_i8_weight_zero)
+ setattr(w2_i8_weight_zero, "shard_dim", 0)
+
+ # Add the quantization method used (per tensor/grouped/channel)
+ # to ensure the weight scales are loaded in properly
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value})
+
+ if self.version == 2 and self.group_size != -1:
+ set_weight_attrs(w13_i8_weight_scale, extra_weight_attrs)
+ set_weight_attrs(w2_i8_weight_scale, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_i8_weight_scale", None)
+ setattr(layer, "w2_i8_weight_scale", None)
+ if self.version == 2 and not self.weight_symmetric:
+ set_weight_attrs(w13_i8_weight_zero, extra_weight_attrs)
+ set_weight_attrs(w2_i8_weight_zero, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_i8_weight_zero", None)
+ setattr(layer, "w2_i8_weight_zero", None)
+
+        # INPUT_SCALES: static input scales are not supported by this method
+        # (static_input_scales is always False), so only the dynamic per-token
+        # path below is taken.
+ if self.static_input_scales:
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+ w13_input_scale = torch.nn.Parameter(torch.ones(
+ num_experts, hidden_size, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_input_scale", w13_input_scale)
+ set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+ w2_input_scale = torch.nn.Parameter(torch.ones(
+ num_experts, intermediate_size_per_partition, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_input_scale", w2_input_scale)
+ set_weight_attrs(w2_input_scale, extra_weight_attrs)
+ else:
+ layer.w13_input_scale = None
+ layer.w2_input_scale = None
+
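+        # From here on gemm_format holds the integer layout code (e.g. 2 for "TN");
+        # apply() compares against these codes.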
+ self.gemm_format = self.format_mapping[self.gemm_format]
+
+ def get_fused_moe_quant_config(
+ self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+ return None
+
+ def apply(
+ self,
+ layer: FusedMoE,
+ x: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+        # shared_experts_input carries the shared experts' output so it can be fused
+        # into the final reduction.
+ shared_experts_input: torch.Tensor | None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        attn_metadata = get_forward_context().attn_metadata
+        use_ep = layer.expert_map is not None
+        # The decode-only fast path below is not supported with expert parallelism yet.
+        if attn_metadata:
+            deepseek_instance = None
+            for value in attn_metadata.values():
+                if hasattr(value, "num_prefills") and hasattr(value, "num_decodes"):
+                    deepseek_instance = value
+                    break
+            value_types = {type(value).__name__ for value in attn_metadata.values()}
+            is_same_class = len(value_types) == 1
+            if is_same_class:
+                assert deepseek_instance
+                only_decode = not use_ep and all(
+                    t.num_decodes > 0 and t.num_prefills == 0
+                    for t in attn_metadata.values()
+                )
+            elif deepseek_instance:
+                only_decode = (
+                    not use_ep
+                    and deepseek_instance.num_decodes > 0
+                    and deepseek_instance.num_prefills == 0
+                )
+            else:
+                only_decode = False
+        else:
+            only_decode = False
+
+ if use_ep:
+ start_eid = layer.ep_rank * layer.local_num_experts
+ end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
+
+ dtype = x.dtype
+ num_tokens = topk_ids.shape[0]
+ num_experts = layer.global_num_experts
+ if use_ep:
+ hidden_size = x.shape[1]
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ expand_tokens,
+ ) = ixfops.moe_compute_token_index_ep(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ start_expert_id=start_eid,
+ end_expert_id=end_eid,
+ )
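+            # No tokens were routed to this rank's local experts; skip the GEMMs and
+            # return zeros.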
+ if expert_sizes_cpu.sum() == 0:
+ return torch.zeros(
+ (num_tokens, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+ else:
+ expand_tokens = num_tokens * layer.top_k
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ ) = ixfops.moe_compute_token_index(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ )
+
+ if only_decode and self.gemm_format == 2:
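+            # Decode-only fast path (layout code 2, i.e. TN): group-GEMV kernels with
+            # GPU-side expert sizes.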
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
+ topk_ids=None,
+ smooth_scales=layer.w13_input_scale,
+                output_format=1,
+ )
+
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2:
+ padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+ pt_output_1 = ixfops.moe_w4a8_group_gemv(
+ input=i8_hidden_states_align,
+ weight=layer.w13_weight,
+ i_scales=a_scale,
+ w_scales=layer.w13_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_gpu,
+ w_i8scales=layer.w13_i8_weight_scale,
+ w_i8zeros=layer.w13_i8_weight_zero,
+ dst_to_src=None,
+ bias=layer.w13_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ )
+
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+                    output_format=1,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+                    output_format=1,
+ )
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2:
+ padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1]
+ pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
+ else:
+ pt_output_2_align = pt_output_2
+
+ pt_output_3 = ixfops.moe_w4a8_group_gemv(
input=pt_output_2_align,
weight=layer.w2_weight,
i_scales=a2_scale,
w_scales=layer.w2_weight_scale,
output_dtype=dtype,
- tokens_per_experts=expert_sizes_cpu,
+ tokens_per_experts=expert_sizes_gpu,
+ w_i8scales=layer.w2_i8_weight_scale,
+ w_i8zeros=layer.w2_i8_weight_zero,
dst_to_src=sorted_token_ids,
- format="TN",
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
)
-
- # mul + reduce_sum
+ # mul + reduce_sum
final_hidden_states = ixfops.moe_output_reduce_sum(
- input=pt_output_3.view(num_tokens, top_k, -1),
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
- extra_residual=extra_residual,
- scaling_factor=routed_scaling_factor
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+ else:
+ expert_sizes_cpu = expert_sizes_gpu.cpu()
+ # expand + reorder + quant
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
+ topk_ids=None,
+ smooth_scales=layer.w13_input_scale,
)
- return final_hidden_states
- # return fused_experts(
- # hidden_states=x,
- # w1=layer.w13_weight,
- # w2=layer.w2_weight,
- # topk_weights=topk_weights,
- # topk_ids=topk_ids,
- # inplace=not self.moe.disable_inplace,
- # activation=layer.activation,
- # apply_router_weight_on_input=layer.apply_router_weight_on_input,
- # global_num_experts=layer.global_num_experts,
- # expert_map=layer.expert_map,
- # quant_config=self.moe_quant_config,
- # )
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2:
+ padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+ # w4a8 group gemm 1
+ pt_output_1 = ixfops.moe_w4a8_group_gemm(
+ input=i8_hidden_states_align,
+ weight=layer.w13_weight,
+ i_scales=a_scale,
+ w_scales=layer.w13_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ w_i8scales=layer.w13_i8_weight_scale,
+ w_i8zeros=layer.w13_i8_weight_zero,
+ dst_to_src=None,
+ bias=layer.w13_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ version=self.version
+ )
+
+ # act + quant
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+ )
+
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2 and self.gemm_format == 2:
+ padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1]
+ pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
+ else:
+ pt_output_2_align = pt_output_2
+
+ # w4a8 group gemm 2 + reorder
+ if use_ep:
+ pt_output_3 = torch.empty(
+ (num_tokens * layer.top_k, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+ ixfops.moe_w4a8_group_gemm(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ w_i8scales=layer.w2_i8_weight_scale,
+ w_i8zeros=layer.w2_i8_weight_zero,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ version=self.version,
+ output=pt_output_3,
+ )
+
+ reduce_mask = src_to_dst == -1
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ mask=reduce_mask,
+ )
+ else:
+ pt_output_3 = ixfops.moe_w4a8_group_gemm(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ w_i8scales=layer.w2_i8_weight_scale,
+ w_i8zeros=layer.w2_i8_weight_zero,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ version=self.version
+ )
+
+ # mul + reduce_sum
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+ return final_hidden_states
+
+class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
+
+ def __init__(
+ self,
+ weight_quant: QuantizationArgs,
+ input_quant: QuantizationArgs,
+ moe: FusedMoEConfig,
+ layer_name: str | None = None,
+ ):
+ super().__init__(moe)
+ self.has_bias = self.moe.has_bias
+ self.weight_quant = weight_quant
+ self.input_quant = input_quant
+
+ self.pack_factor = 2
+ self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size
+ self.weight_symmetric = self.weight_quant.symmetric
+ self.gemm_format = envs.VLLM_W4A8_FORMAT
+ self.format_mapping = {"NN":0,"NT":1,"TN":2}
+ self.version = envs.VLLM_W4A8_VERSION
+ assert self.gemm_format in ["TN","NN"]
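+        # As in the W4A16 method, format_mapping converts the layout string to the
+        # integer code expected by the kernels once create_weights() has run.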
+
+ if not ((self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+ or self.weight_quant.strategy == QuantizationStrategy.GROUP)
+ and self.input_quant.strategy == QuantizationStrategy.TOKEN):
+ raise ValueError(
+ "For INT4 pack2 Fused MoE layers, only per-channel or group scales"
+ "for weights and per-token scales for activations are supported. Found "
+ f"{self.weight_quant}, {self.input_quant}")
+
+ self.static_input_scales = not self.input_quant.dynamic
+
+
+ def create_weights(self, layer: torch.nn.Module, num_experts: int,
+ hidden_size: int, intermediate_size_per_partition: int,
+ params_dtype: torch.dtype, **extra_weight_attrs):
+
+ params_dtype = torch.int8
+ w13_remainder = (hidden_size // self.pack_factor) % 64
+ w2_remainder = (intermediate_size_per_partition // self.pack_factor) % 64
+ if self.gemm_format == "TN":
+ if w13_remainder != 0:
+ w13_shape = (num_experts, 2 * intermediate_size_per_partition, (hidden_size // self.pack_factor) + 64 - w13_remainder)
+ else:
+ w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor)
+
+ if w2_remainder != 0:
+ w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - w2_remainder)
+ else:
+ w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor)
+ else:
+ w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor)
+ w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor)
+
+ # WEIGHTS
+        # use process_weights_after_loading to get the right layout if gemm_format is NN
+ w13_weight = torch.nn.Parameter(torch.empty(w13_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w13_weight", w13_weight)
+ set_weight_attrs(w13_weight, extra_weight_attrs)
+ if self.gemm_format == "NN":
+ setattr(w13_weight, "shard_dim", 1)
+
+ if self.has_bias:
+ w13_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_bias", w13_bias)
+ set_weight_attrs(w13_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_bias", None)
+
+ if w2_remainder != 0:
+ w2_weight = torch.nn.Parameter(torch.zeros(w2_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ else:
+ w2_weight = torch.nn.Parameter(torch.empty(w2_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w2_weight", w2_weight)
+ set_weight_attrs(w2_weight, extra_weight_attrs)
+ if self.gemm_format == "NN":
+ setattr(w2_weight, "shard_dim", 0)
+
+ if self.has_bias:
+ w2_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ hidden_size,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_bias", w2_bias)
+ set_weight_attrs(w2_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w2_bias", None)
+
+ # WEIGHT_SCALES
+ # Allocate 2 scales for w1 and w3 respectively.
+ # They will be combined to a single scale after weight loading.
+        # The following scales/zeros are permuted with permute(0, 2, 1) to get the
+        # right layout; they are initialized in this shape to avoid rewriting the
+        # data loader.
+ w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.version == 2 else 1 if self.group_size == -1 else hidden_size // self.group_size,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_weight_scale", w13_weight_scale)
+ setattr(w13_weight_scale, "shard_dim", 1)
+
+ w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.version == 2 else 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size,
+ hidden_size,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_weight_scale", w2_weight_scale)
+ setattr(w2_weight_scale, "shard_dim", 0)
+ # setattr(w2_weight_scale, "load_full_w2", True)
+
+        quant_method = (FusedMoeWeightScaleSupported.CHANNEL.value
+                        if self.version == 2 or self.group_size == -1
+                        else FusedMoeWeightScaleSupported.GROUP.value)
+        extra_weight_attrs.update({"quant_method": quant_method})
+ set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+ set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+ if self.version == 2:
+ # INT8 -> INT4 weight scales/zeros
+ if self.group_size != -1:
+ w13_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ hidden_size // self.group_size,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w13_i8_weight_scale", w13_i8_weight_scale)
+ setattr(w13_i8_weight_scale, "shard_dim", 1)
+ if not self.weight_symmetric:
+ w13_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.group_size == -1 else hidden_size // self.group_size,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w13_i8_weight_zero", w13_i8_weight_zero)
+ setattr(w13_i8_weight_zero, "shard_dim", 1)
+
+ if self.group_size != -1:
+ w2_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ intermediate_size_per_partition // self.group_size,
+ hidden_size,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w2_i8_weight_scale", w2_i8_weight_scale)
+ setattr(w2_i8_weight_scale, "shard_dim", 0)
+ if not self.weight_symmetric:
+ w2_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts,
+ 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size,
+ hidden_size,
+ dtype=torch.int32),
+ requires_grad=False)
+ layer.register_parameter("w2_i8_weight_zero", w2_i8_weight_zero)
+ setattr(w2_i8_weight_zero, "shard_dim", 0)
+
+ # Add the quantization method used (per tensor/grouped/channel)
+ # to ensure the weight scales are loaded in properly
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value})
+
+ if self.version == 2 and self.group_size != -1:
+ set_weight_attrs(w13_i8_weight_scale, extra_weight_attrs)
+ set_weight_attrs(w2_i8_weight_scale, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_i8_weight_scale", None)
+ setattr(layer, "w2_i8_weight_scale", None)
+ if self.version == 2 and not self.weight_symmetric:
+ set_weight_attrs(w13_i8_weight_zero, extra_weight_attrs)
+ set_weight_attrs(w2_i8_weight_zero, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_i8_weight_zero", None)
+ setattr(layer, "w2_i8_weight_zero", None)
+
+        # INPUT_SCALES (created only when the input quantization is static).
+ if self.static_input_scales:
+ extra_weight_attrs.update(
+ {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+ w13_input_scale = torch.nn.Parameter(torch.ones(
+ num_experts, hidden_size, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_input_scale", w13_input_scale)
+ set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+ w2_input_scale = torch.nn.Parameter(torch.ones(
+ num_experts, intermediate_size_per_partition, dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_input_scale", w2_input_scale)
+ set_weight_attrs(w2_input_scale, extra_weight_attrs)
+ else:
+ layer.w13_input_scale = None
+ layer.w2_input_scale = None
+
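+        # gemm_format now holds the integer layout code used by apply().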
+ self.gemm_format = self.format_mapping[self.gemm_format]
+
+ def get_fused_moe_quant_config(
+ self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+ return None
+
+ def apply(
+ self,
+ layer: FusedMoE,
+ x: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+        # shared_experts_input carries the shared experts' output so it can be fused
+        # into the final reduction.
+ shared_experts_input: torch.Tensor | None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        attn_metadata = get_forward_context().attn_metadata
+        use_ep = layer.expert_map is not None
+        # The decode-only fast path below is not supported with expert parallelism yet.
+        if attn_metadata:
+            deepseek_instance = None
+            for value in attn_metadata.values():
+                if hasattr(value, "num_prefills") and hasattr(value, "num_decodes"):
+                    deepseek_instance = value
+                    break
+            value_types = {type(value).__name__ for value in attn_metadata.values()}
+            is_same_class = len(value_types) == 1
+            if is_same_class:
+                assert deepseek_instance
+                only_decode = not use_ep and all(
+                    t.num_decodes > 0 and t.num_prefills == 0
+                    for t in attn_metadata.values()
+                )
+            elif deepseek_instance:
+                only_decode = (
+                    not use_ep
+                    and deepseek_instance.num_decodes > 0
+                    and deepseek_instance.num_prefills == 0
+                )
+            else:
+                only_decode = False
+        else:
+            only_decode = False
+
+ if use_ep:
+ start_eid = layer.ep_rank * layer.local_num_experts
+ end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
+
+ dtype = x.dtype
+ num_tokens = topk_ids.shape[0]
+ num_experts = layer.global_num_experts
+ if use_ep:
+ hidden_size = x.shape[1]
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ expand_tokens,
+ ) = ixfops.moe_compute_token_index_ep(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ start_expert_id=start_eid,
+ end_expert_id=end_eid,
+ )
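+            # Nothing was routed to this rank's local experts; return zeros without
+            # running the GEMMs.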
+ if expert_sizes_cpu.sum() == 0:
+ return torch.zeros(
+ (num_tokens, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+ else:
+ expand_tokens = num_tokens * layer.top_k
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ ) = ixfops.moe_compute_token_index(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ )
+
+ if only_decode and self.gemm_format == 2:
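+            # Decode-only fast path (TN layout): group-GEMV with GPU-side expert sizes,
+            # avoiding a host sync.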
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
+ topk_ids=None,
+ smooth_scales=layer.w13_input_scale,
+                output_format=1,
+ )
+
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2:
+ padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+ pt_output_1 = ixfops.moe_w4a8_group_gemv(
+ input=i8_hidden_states_align,
+ weight=layer.w13_weight,
+ i_scales=a_scale,
+ w_scales=layer.w13_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_gpu,
+ w_i8scales=layer.w13_i8_weight_scale,
+ w_i8zeros=layer.w13_i8_weight_zero,
+ dst_to_src=None,
+ bias=layer.w13_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ )
+
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+                    output_format=1,
+ )
+ elif layer.activation.value == "swiglustep":
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglustep",
+                    output_format=1,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+                    output_format=1,
+ )
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2:
+ padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1]
+ pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
+ else:
+ pt_output_2_align = pt_output_2
+
+ pt_output_3 = ixfops.moe_w4a8_group_gemv(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_gpu,
+ w_i8scales=layer.w2_i8_weight_scale,
+ w_i8zeros=layer.w2_i8_weight_zero,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ )
+ # mul + reduce_sum
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+ else:
+ expert_sizes_cpu = expert_sizes_gpu.cpu()
+ # expand + reorder + quant
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
+ topk_ids=None,
+ smooth_scales=layer.w13_input_scale,
+ )
+
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2:
+ padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+ # w4a8 group gemm 1
+ pt_output_1 = moe_w4a8_group_gemm_pad_k(
+ input=i8_hidden_states_align,
+ weight=layer.w13_weight,
+ i_scales=a_scale,
+ w_scales=layer.w13_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ w_i8scales=layer.w13_i8_weight_scale,
+ w_i8zeros=layer.w13_i8_weight_zero,
+ dst_to_src=None,
+ bias=layer.w13_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ version=self.version
+ )
+
+ # act + quant
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ )
+ elif layer.activation.value == "swiglustep":
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglustep",
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+ )
+
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2 and self.gemm_format == 2:
+ padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1]
+ pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
+ else:
+ pt_output_2_align = pt_output_2
+
+ # w4a8 group gemm 2 + reorder
+ if use_ep:
+ pt_output_3 = torch.empty(
+ (num_tokens * layer.top_k, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+ moe_w4a8_group_gemm_pad_k(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ w_i8scales=layer.w2_i8_weight_scale,
+ w_i8zeros=layer.w2_i8_weight_zero,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ version=self.version,
+ output=pt_output_3,
+ )
+
+ reduce_mask = src_to_dst == -1
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ mask=reduce_mask,
+ )
+ else:
+ pt_output_3 = moe_w4a8_group_gemm_pad_k(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ w_i8scales=layer.w2_i8_weight_scale,
+ w_i8zeros=layer.w2_i8_weight_zero,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format=self.gemm_format,
+ group_size=self.group_size,
+ version=self.version
+ )
+
+ # mul + reduce_sum
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+ return final_hidden_states
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
@@ -1806,9 +2881,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl(
self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+ prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
assert self.num_bits == 4, "only supporting w4"
layer.w13_weight = layer.w13_weight_packed
layer.w2_weight = layer.w2_weight_packed
@@ -1878,7 +2953,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- **kwargs
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.kernel_backend == "Marlin"
return fused_marlin_moe(
@@ -2098,9 +3172,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl(
self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+ prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
if self.moe.is_lora_enabled:
assert self.moe_quant_config is not None
from vllm.triton_utils import HAS_TRITON
@@ -2128,7 +3202,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- **kwargs
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -2683,7 +3756,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> mk.FusedMoEPrepareAndFinalize | None:
+ ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
return super().maybe_make_prepare_finalize(routing_tables)
def get_fused_moe_quant_config(
@@ -2704,9 +3777,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def select_gemm_impl(
self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+ prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
assert self.moe_quant_config is not None
assert (
prepare_finalize.activation_format == FusedMoEActivationFormat.Standard
@@ -2714,7 +3787,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8
- experts: FusedMoEPermuteExpertsUnpermute
+ experts: FusedMoEExpertsModular
logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__)
experts = CutlassExpertsW4A8Fp8(
@@ -2746,7 +3819,6 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
- **kwargs
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if layer.enable_eplb:
raise NotImplementedError(
@@ -2785,183 +3857,118 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def supports_eplb(self) -> bool:
return False
-
-# for corex
-class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
+class CompressedTensorsL1OptMoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
- quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig,
):
super().__init__(moe)
- self.quant_config = quant_config
- self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
- "weights")
- self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
- "input_activations")
- self.pack_factor = 2
- self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size
- self.weight_symmetric = self.weight_quant.symmetric
- self.gemm_format = envs.VLLM_W4A8_FORMAT
- self.format_mapping = {"NN":0,"NT":1,"TN":2}
- self.version = envs.VLLM_W4A8_VERSION
- assert self.gemm_format in ["TN","NN"]
+ self.has_bias = self.moe.has_bias
- if not ((self.weight_quant.strategy == QuantizationStrategy.CHANNEL
- or self.weight_quant.strategy == QuantizationStrategy.GROUP)
- and self.input_quant.strategy == QuantizationStrategy.TOKEN):
- raise ValueError(
- "For INT4 pack2 Fused MoE layers, only per-channel or group scales"
- "for weights and per-token scales for activations are supported. Found "
- f"{self.weight_quant}, {self.input_quant}")
-
- self.static_input_scales = not self.input_quant.dynamic
-
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.int8
- remainder = (intermediate_size_per_partition // self.pack_factor) % 64
- if self.gemm_format == "TN":
- w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor)
- if remainder != 0:
- w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - remainder)
- else:
- w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor)
+
+ w13_remainder = hidden_size % 64
+ w2_remainder = intermediate_size_per_partition % 64
+ if w13_remainder != 0:
+ hidden_size_padded = hidden_size + (64 - w13_remainder)
else:
- w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor)
- w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor)
-
+ hidden_size_padded = hidden_size
+ if w2_remainder != 0:
+ intermediate_size_per_partition_padded = intermediate_size_per_partition + (64 - w2_remainder)
+ else:
+ intermediate_size_per_partition_padded = intermediate_size_per_partition
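+        # Editorial example (not from the original source): hidden_size = 4000
+        # gives w13_remainder = 32, so hidden_size_padded = 4032; dimensions that
+        # are already multiples of 64 are left as-is.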
+
# WEIGHTS
- # use process_weights_after_loading to get get right layout if gemm_format is NN
- w13_weight = torch.nn.Parameter(torch.empty(w13_shape,
- dtype=params_dtype),
+ w13_weight = torch.nn.Parameter(torch.empty(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ hidden_size_padded,
+ dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
- if self.gemm_format == "NN":
- setattr(w13_weight, "shard_dim", 1)
- if remainder != 0:
- w2_weight = torch.nn.Parameter(torch.zeros(w2_shape,
- dtype=params_dtype),
- requires_grad=False)
+ if self.has_bias:
+ w13_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_bias", w13_bias)
+ set_weight_attrs(w13_bias, extra_weight_attrs)
else:
- w2_weight = torch.nn.Parameter(torch.empty(w2_shape,
- dtype=params_dtype),
- requires_grad=False)
+ setattr(layer, "w13_bias", None)
+
+ w2_weight = torch.nn.Parameter(torch.empty(
+ num_experts,
+ hidden_size,
+ intermediate_size_per_partition_padded,
+ dtype=params_dtype),
+ requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
- if self.gemm_format == "NN":
- setattr(w2_weight, "shard_dim", 0)
+
+ if self.has_bias:
+ w2_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ hidden_size,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_bias", w2_bias)
+ set_weight_attrs(w2_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w2_bias", None)
# WEIGHT_SCALES
- # Allocate 2 scales for w1 and w3 respectively.
- # They will be combined to a single scale after weight loading.
- # The following scale or zero will use permute(0,2,1) to get right layout, init here to avoid rewrite data_loader
- w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
- 1 if self.version == 2 else 1 if self.group_size == -1 else hidden_size // self.group_size,
- 2 * intermediate_size_per_partition,
- dtype=torch.float32),
+ w13_weight_scale = torch.nn.Parameter(torch.ones(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ 1,
+ dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
- setattr(w13_weight_scale, "shard_dim", 1)
-
- w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
- 1 if self.version == 2 else 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size,
- hidden_size,
- dtype=torch.float32),
+ setattr(w13_weight_scale, "shard_dim", 0)
+ w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
+ hidden_size,
+ 1,
+ dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
- setattr(w2_weight_scale, "shard_dim", 0)
- # setattr(w2_weight_scale, "load_full_w2", True)
-
+ setattr(w2_weight_scale, "shard_dim", 1)
+ # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
- {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.version == 2 or self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value})
+ {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
- if self.version == 2:
- # INT8 -> INT4 weight scales/zeros
- if self.group_size != -1:
- w13_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
- hidden_size // self.group_size,
- 2 * intermediate_size_per_partition,
- dtype=torch.int32),
- requires_grad=False)
- layer.register_parameter("w13_i8_weight_scale", w13_i8_weight_scale)
- setattr(w13_i8_weight_scale, "shard_dim", 1)
- if not self.weight_symmetric:
- w13_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts,
- 1 if self.group_size == -1 else hidden_size // self.group_size,
- 2 * intermediate_size_per_partition,
- dtype=torch.int32),
- requires_grad=False)
- layer.register_parameter("w13_i8_weight_zero", w13_i8_weight_zero)
- setattr(w13_i8_weight_zero, "shard_dim", 1)
-
- if self.group_size != -1:
- w2_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
- intermediate_size_per_partition // self.group_size,
- hidden_size,
- dtype=torch.int32),
- requires_grad=False)
- layer.register_parameter("w2_i8_weight_scale", w2_i8_weight_scale)
- setattr(w2_i8_weight_scale, "shard_dim", 0)
- if not self.weight_symmetric:
- w2_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts,
- 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size,
- hidden_size,
- dtype=torch.int32),
- requires_grad=False)
- layer.register_parameter("w2_i8_weight_zero", w2_i8_weight_zero)
- setattr(w2_i8_weight_zero, "shard_dim", 0)
-
- # Add the quantization method used (per tensor/grouped/channel)
- # to ensure the weight scales are loaded in properly
- extra_weight_attrs.update(
- {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value})
- if self.version == 2 and self.group_size != -1:
- set_weight_attrs(w13_i8_weight_scale, extra_weight_attrs)
- set_weight_attrs(w2_i8_weight_scale, extra_weight_attrs)
- else:
- setattr(layer, "w13_i8_weight_scale", None)
- setattr(layer, "w2_i8_weight_scale", None)
- if self.version == 2 and not self.weight_symmetric:
- set_weight_attrs(w13_i8_weight_zero, extra_weight_attrs)
- set_weight_attrs(w2_i8_weight_zero, extra_weight_attrs)
- else:
- setattr(layer, "w13_i8_weight_zero", None)
- setattr(layer, "w2_i8_weight_zero", None)
-
- # DO NOT SUPPORT INPUT_SCALES
- if self.static_input_scales:
- extra_weight_attrs.update(
- {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
- w13_input_scale = torch.nn.Parameter(torch.ones(
- num_experts, hidden_size, dtype=torch.float32),
- requires_grad=False)
- layer.register_parameter("w13_input_scale", w13_input_scale)
- set_weight_attrs(w13_input_scale, extra_weight_attrs)
+ layer.w13_input_scale = None
+ layer.w2_input_scale = None
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+ pass
- w2_input_scale = torch.nn.Parameter(torch.ones(
- num_experts, intermediate_size_per_partition, dtype=torch.float32),
- requires_grad=False)
- layer.register_parameter("w2_input_scale", w2_input_scale)
- set_weight_attrs(w2_input_scale, extra_weight_attrs)
- else:
- layer.w13_input_scale = None
- layer.w2_input_scale = None
-
- self.gemm_format = self.format_mapping[self.gemm_format]
-
def get_fused_moe_quant_config(
- self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
- return None
+ self, layer: torch.nn.Module
+ ) -> FusedMoEQuantConfig | None:
+ return int8_w8a8_moe_quant_config(
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
+ per_act_token_quant=True,
+ )
def apply(
self,
@@ -2969,45 +3976,23 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- expert_map: torch.Tensor | None = None,
- enable_eplb: bool = False,
- extra_residual: torch.Tensor = None,
- routed_scaling_factor: float = 1.0,
- **kwargs
- ) -> torch.Tensor:
- attn_metadata = get_forward_context().attn_metadata
- use_ep = expert_map is not None
- # unsupported ep now
- if attn_metadata:
- if isinstance(attn_metadata, dict):
- deepseek_instance = None
- for value in attn_metadata.values():
- if hasattr(value, 'num_prefills') and hasattr(value, 'num_decodes'):
- deepseek_instance = value
- break
- value_types = {type(value).__name__ for value in attn_metadata.values()}
- is_same_class = len(value_types) == 1
- if is_same_class:
- assert deepseek_instance
- only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values())))
- else:
- if deepseek_instance:
- only_decode = (use_ep == False and deepseek_instance.num_decodes > 0 and deepseek_instance.num_prefills ==0)
- else:
- only_decode = False
- else:
- only_decode = False
+        # shared_experts_input carries shared_experts_output so it can be fused into the final reduction
+ shared_experts_input: torch.Tensor | None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ if layer.enable_eplb:
+ raise NotImplementedError(
+                "EPLB not supported for `CompressedTensorsL1OptMoEMethod` yet."
+ )
- if enable_eplb:
- raise NotImplementedError("EPLB not supported for "
- "`CompressedTensorsW4A8MoEMethod` yet.")
+ use_ep = layer.expert_map is not None
if use_ep:
start_eid = layer.ep_rank * layer.local_num_experts
- end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, global_num_experts)
+ end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
+
dtype = x.dtype
- num_tokens, num_experts = router_logits.shape
+ num_tokens = topk_ids.shape[0]
+ num_experts = layer.global_num_experts
+
if use_ep:
hidden_size = x.shape[1]
(
@@ -3029,13 +4014,316 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
dtype=x.dtype,
)
else:
- expand_tokens = num_tokens * top_k
+ expand_tokens = num_tokens * layer.top_k
(
src_to_dst,
sorted_token_ids,
expert_sizes_gpu,
expert_sizes_cpu,
- ) = torch.ops.ixf_ops.moe_compute_token_index(
+ ) = ixfops.moe_compute_token_index(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ )
+ expert_sizes_cpu = expert_sizes_gpu.cpu()
+
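+        # Pipeline overview (editorial summary of the steps below): expand and
+        # int8-quantize the routed tokens, run the grouped w8a8 GEMM for w13,
+        # apply the activation with re-quantization, run the grouped GEMM for w2,
+        # then weight and reduce the top-k expert outputs, optionally fusing the
+        # shared-experts residual.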
+ # expand + reorder + quant
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
+ topk_ids=None,
+ smooth_scales=layer.w13_input_scale,
+ )
+
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1]:
+ padding = layer.w13_weight.shape[-1] - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+ # w8a8 group gemm 1
+ pt_output_1 = ixfops.moe_w8a8_group_gemm(
+ input=i8_hidden_states_align,
+ weight=layer.w13_weight,
+ i_scales=a_scale,
+ w_scales=layer.w13_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=None,
+ bias=layer.w13_bias,
+ format="TN",
+ )
+
+ # act + quant
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+ )
+
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[-1]:
+ padding = layer.w2_weight.shape[-1] - pt_output_2.shape[-1]
+ pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
+ else:
+ pt_output_2_align = pt_output_2
+
+ # w8a8 group gemm 2 + reorder
+ if use_ep:
+ pt_output_3 = torch.empty(
+ (num_tokens * layer.top_k, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+ ixfops.moe_w8a8_group_gemm(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format="TN",
+ output=pt_output_3,
+ )
+
+ reduce_mask = src_to_dst == -1
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ mask=reduce_mask,
+ )
+ else:
+ pt_output_3 = ixfops.moe_w8a8_group_gemm(
+ input=pt_output_2_align,
+ weight=layer.w2_weight,
+ i_scales=a2_scale,
+ w_scales=layer.w2_weight_scale,
+ output_dtype=dtype,
+ tokens_per_experts=expert_sizes_cpu,
+ dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
+ format="TN",
+ )
+
+ # mul + reduce_sum
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+ return final_hidden_states
+
+class CompressedTensorsL2OptMoEMethod(CompressedTensorsMoEMethod):
+
+ def __init__(
+ self,
+ moe: FusedMoEConfig,
+ ):
+ super().__init__(moe)
+ self.has_bias = self.moe.has_bias
+ self.pack_factor = 2
+ self.group_size = -1
+ self.version = 2
+ self.gemm_format = envs.VLLM_W4A8_FORMAT
+        self.format_mapping = {"NN": 0, "NT": 1, "TN": 2}
+        assert self.gemm_format in ["TN", "NN"]
+
+
+ def create_weights(self, layer: torch.nn.Module, num_experts: int,
+ hidden_size: int, intermediate_size_per_partition: int,
+ params_dtype: torch.dtype, **extra_weight_attrs):
+
+ params_dtype = torch.int8
+ w13_remainder = (hidden_size // self.pack_factor) % 64
+ w2_remainder = (intermediate_size_per_partition // self.pack_factor) % 64
+ if self.gemm_format == "TN":
+ if w13_remainder != 0:
+ w13_shape = (num_experts, 2 * intermediate_size_per_partition, (hidden_size // self.pack_factor) + 64 - w13_remainder)
+ else:
+ w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor)
+
+ if w2_remainder != 0:
+ w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - w2_remainder)
+ else:
+ w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor)
+ else:
+ w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor)
+ w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor)
+
+ # WEIGHTS
+        # use process_weights_after_loading to get the right layout if gemm_format is NN
+ w13_weight = torch.nn.Parameter(torch.empty(w13_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w13_weight", w13_weight)
+ set_weight_attrs(w13_weight, extra_weight_attrs)
+ if self.gemm_format == "NN":
+ setattr(w13_weight, "shard_dim", 1)
+
+ if self.has_bias:
+ w13_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w13_bias", w13_bias)
+ set_weight_attrs(w13_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w13_bias", None)
+
+ if w2_remainder != 0:
+ w2_weight = torch.nn.Parameter(torch.zeros(w2_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ else:
+ w2_weight = torch.nn.Parameter(torch.empty(w2_shape,
+ dtype=params_dtype),
+ requires_grad=False)
+ layer.register_parameter("w2_weight", w2_weight)
+ set_weight_attrs(w2_weight, extra_weight_attrs)
+ if self.gemm_format == "NN":
+ setattr(w2_weight, "shard_dim", 0)
+
+ if self.has_bias:
+ w2_bias = torch.nn.Parameter(
+ torch.zeros(
+ num_experts,
+ hidden_size,
+ dtype=torch.bfloat16,
+ ),
+ requires_grad=False,
+ )
+ layer.register_parameter("w2_bias", w2_bias)
+ set_weight_attrs(w2_bias, extra_weight_attrs)
+ else:
+ setattr(layer, "w2_bias", None)
+
+ # WEIGHT_SCALES
+ # Allocate 2 scales for w1 and w3 respectively.
+ # They will be combined to a single scale after weight loading.
+        # The following scales/zeros use permute(0, 2, 1) to get the right layout; initialized here to avoid rewriting the data loader
+ w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ 1,
+ 2 * intermediate_size_per_partition,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w13_weight_scale", w13_weight_scale)
+ setattr(w13_weight_scale, "shard_dim", 1)
+
+ w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts,
+ 1,
+ hidden_size,
+ dtype=torch.float32),
+ requires_grad=False)
+ layer.register_parameter("w2_weight_scale", w2_weight_scale)
+ setattr(w2_weight_scale, "shard_dim", 0)
+
+ extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+ set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+ set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+ setattr(layer, "w13_i8_weight_scale", None)
+ setattr(layer, "w2_i8_weight_scale", None)
+ setattr(layer, "w13_i8_weight_zero", None)
+ setattr(layer, "w2_i8_weight_zero", None)
+
+ layer.w13_input_scale = None
+ layer.w2_input_scale = None
+
+ self.gemm_format = self.format_mapping[self.gemm_format]
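+        # From here on, self.gemm_format is the integer code passed to the
+        # ixformer kernels ("NN" -> 0, "NT" -> 1, "TN" -> 2), not the string.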
+
+ def get_fused_moe_quant_config(
+ self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+ return None
+
+ def apply(
+ self,
+ layer: FusedMoE,
+ x: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+        # shared_experts_input carries shared_experts_output so it can be fused into the final reduction
+ shared_experts_input: torch.Tensor | None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+ attn_metadata = get_forward_context().attn_metadata
+ use_ep = layer.expert_map is not None
+        # EP is not supported for the decode-only path yet
+ if attn_metadata:
+ deepseek_instance = None
+ for value in attn_metadata.values():
+ if hasattr(value, 'num_prefills') and hasattr(value, 'num_decodes'):
+ deepseek_instance = value
+ break
+ value_types = {type(value).__name__ for value in attn_metadata.values()}
+ is_same_class = len(value_types) == 1
+ if is_same_class:
+ assert deepseek_instance
+                only_decode = (not use_ep and all(t.num_decodes > 0 and t.num_prefills == 0 for t in attn_metadata.values()))
+ else:
+ if deepseek_instance:
+                    only_decode = (not use_ep and deepseek_instance.num_decodes > 0 and deepseek_instance.num_prefills == 0)
+ else:
+ only_decode = False
+ else:
+ only_decode = False
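+        # only_decode flags pure decode batches (no prefills in the attention
+        # metadata) with expert parallelism disabled.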
+ if layer.enable_eplb:
+            raise NotImplementedError("EPLB not supported for `CompressedTensorsL2OptMoEMethod` yet.")
+ if use_ep:
+ start_eid = layer.ep_rank * layer.local_num_experts
+ end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts)
+
+ dtype = x.dtype
+ num_tokens = topk_ids.shape[0]
+ num_experts = layer.global_num_experts
+ if use_ep:
+ hidden_size = x.shape[1]
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ expand_tokens,
+ ) = ixfops.moe_compute_token_index_ep(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ start_expert_id=start_eid,
+ end_expert_id=end_eid,
+ )
+ if expert_sizes_cpu.sum() == 0:
+ return torch.zeros(
+ (num_tokens, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+ else:
+ expand_tokens = num_tokens * layer.top_k
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ ) = ixfops.moe_compute_token_index(
topk_ids=topk_ids,
num_experts=num_experts,
)
@@ -3045,15 +4333,21 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
- topk=top_k,
+ topk=layer.top_k,
src_to_dst=src_to_dst,
topk_ids=None,
smooth_scales=layer.w13_input_scale,
output_format = 1,
)
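+            # Editorial note: with the TN layout, w13_weight packs two int4 values
+            # per int8 byte along K, so the activation is padded up to
+            # 2 * w13_weight.shape[-1] before the grouped GEMV below.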
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2:
+ padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
pt_output_1 = ixfops.moe_w4a8_group_gemv(
- input=i8_hidden_states,
+ input=i8_hidden_states_align,
weight=layer.w13_weight,
i_scales=a_scale,
w_scales=layer.w13_weight_scale,
@@ -3062,19 +4356,30 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
w_i8scales=layer.w13_i8_weight_scale,
w_i8zeros=layer.w13_i8_weight_zero,
dst_to_src=None,
+ bias=layer.w13_bias,
format=self.gemm_format,
group_size=self.group_size,
)
- pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
- input=pt_output_1,
- bias=None,
- smooth_scales=layer.w2_input_scale,
- dst_to_src=sorted_token_ids,
- topk_ids=None,
- act_type="swiglu",
- output_format = 1,
- )
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ output_format = 1,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+ output_format = 1,
+ )
if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2:
padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1]
@@ -3092,32 +4397,39 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
w_i8scales=layer.w2_i8_weight_scale,
w_i8zeros=layer.w2_i8_weight_zero,
dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
format=self.gemm_format,
group_size=self.group_size,
)
# mul + reduce_sum
final_hidden_states = ixfops.moe_output_reduce_sum(
- input=pt_output_3.view(num_tokens, top_k, -1),
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
- extra_residual=extra_residual,
- scaling_factor=routed_scaling_factor
- )
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
else:
expert_sizes_cpu = expert_sizes_gpu.cpu()
# expand + reorder + quant
- i8_hidden_states, a_scale = torch.ops.ixf_ops.moe_expand_input_dynamic_scaled_int8(
+ i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8(
hidden_states=x,
dst_to_src=sorted_token_ids,
dst_tokens=expand_tokens,
- topk=top_k,
+ topk=layer.top_k,
src_to_dst=src_to_dst,
topk_ids=None,
smooth_scales=layer.w13_input_scale,
)
- # w4a8 group gemm 1
- pt_output_1 = torch.ops.ixf_ops.moe_w4a8_group_gemm(
- input=i8_hidden_states,
+ if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2:
+ padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1]
+ i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0)
+ else:
+ i8_hidden_states_align = i8_hidden_states
+
+
+ pt_output_1 = moe_w4a8_group_gemm_pad_k(
+ input=i8_hidden_states_align,
weight=layer.w13_weight,
i_scales=a_scale,
w_scales=layer.w13_weight_scale,
@@ -3126,22 +4438,32 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
w_i8scales=layer.w13_i8_weight_scale,
w_i8zeros=layer.w13_i8_weight_zero,
dst_to_src=None,
+ bias=layer.w13_bias,
format=self.gemm_format,
group_size=self.group_size,
version=self.version
)
# act + quant
- pt_output_2, a2_scale = torch.ops.ixf_ops.activation_dynamic_scaled_int8(
- input=pt_output_1,
- bias=None,
- smooth_scales=layer.w2_input_scale,
- dst_to_src=sorted_token_ids,
- topk_ids=None,
- act_type="swiglu",
- )
+ if layer.activation.value == "swigluoai":
+ pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ )
+ else:
+ pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8(
+ input=pt_output_1,
+ bias=None,
+ smooth_scales=layer.w2_input_scale,
+ dst_to_src=sorted_token_ids,
+ topk_ids=None,
+ act_type="swiglu",
+ )
- if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2:
+ if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2 and self.gemm_format == 2:
padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1]
pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0)
else:
@@ -3150,12 +4472,12 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
# w4a8 group gemm 2 + reorder
if use_ep:
pt_output_3 = torch.empty(
- (num_tokens * top_k, hidden_size),
+ (num_tokens * layer.top_k, hidden_size),
device=x.device,
dtype=x.dtype,
)
- torch.ops.ixf_ops.moe_w4a8_group_gemm(
+ moe_w4a8_group_gemm_pad_k(
input=pt_output_2_align,
weight=layer.w2_weight,
i_scales=a2_scale,
@@ -3165,6 +4487,7 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
w_i8scales=layer.w2_i8_weight_scale,
w_i8zeros=layer.w2_i8_weight_zero,
dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
format=self.gemm_format,
group_size=self.group_size,
version=self.version,
@@ -3172,15 +4495,14 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
)
reduce_mask = src_to_dst == -1
- final_hidden_states = torch.ops.ixf_ops.moe_output_reduce_sum(
- input=pt_output_3.view(num_tokens, top_k, -1),
- topk_weight=topk_weight,
- extra_residual=extra_residual,
- scaling_factor=routed_scaling_factor,
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
mask=reduce_mask,
)
else:
- pt_output_3 = torch.ops.ixf_ops.moe_w4a8_group_gemm(
+ pt_output_3 = moe_w4a8_group_gemm_pad_k(
input=pt_output_2_align,
weight=layer.w2_weight,
i_scales=a2_scale,
@@ -3190,16 +4512,17 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod):
w_i8scales=layer.w2_i8_weight_scale,
w_i8zeros=layer.w2_i8_weight_zero,
dst_to_src=sorted_token_ids,
+ bias=layer.w2_bias,
format=self.gemm_format,
group_size=self.group_size,
version=self.version
)
# mul + reduce_sum
- final_hidden_states = torch.ops.ixf_ops.moe_output_reduce_sum(
- input=pt_output_3.view(num_tokens, top_k, -1),
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
topk_weight=topk_weights,
- extra_residual=extra_residual,
- scaling_factor=routed_scaling_factor
- )
- return final_hidden_states
\ No newline at end of file
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+ return final_hidden_states
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
index 8157182..c9dd98d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -8,7 +8,7 @@ from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
-from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8, CompressedTensorsW4A8Int8
+from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
@@ -28,5 +28,4 @@ __all__ = [
"CompressedTensorsW4A4Fp4",
"CompressedTensorsW4A8Int",
"CompressedTensorsW4A8Fp8",
- "CompressedTensorsW4A8Int8"
]
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
index e660fc4..ed919a7 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
@@ -25,11 +25,18 @@ logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(
- self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
+ self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool, is_w4a8_linear: bool = False
):
- self.strategy = strategy
+ import vllm.envs as env
+ if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR":
+ self.strategy = QuantizationStrategy.TENSOR
+ elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL":
+ self.strategy = QuantizationStrategy.CHANNEL
+ else:
+ self.strategy = strategy
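+        # VLLM_MIX_QUANTIZATION_TYPE, when set, overrides the strategy parsed from
+        # the checkpoint with a global TENSOR or CHANNEL scale strategy.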
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
+ self.is_w4a8_linear = is_w4a8_linear
@classmethod
def get_min_capability(cls) -> int:
@@ -53,16 +60,32 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric=self.input_symmetric,
module_name=self.__class__.__name__,
)
+ remainder = input_size_per_partition % 64
+ if remainder != 0:
+ input_size_per_partition_padded = input_size_per_partition + (64 - remainder)
+ else:
+ input_size_per_partition_padded = input_size_per_partition
# WEIGHT
- weight = ModelWeightParameter(
- data=torch.empty(
- sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
- ),
- input_dim=1,
- output_dim=0,
- weight_loader=weight_loader,
- )
+ if self.is_w4a8_linear:
+ # only "NN" is supported
+ weight = ModelWeightParameter(data=torch.empty(
+ input_size_per_partition_padded,
+ sum(output_partition_sizes) // 2,
+ dtype=torch.int8),
+ input_dim=0,
+ output_dim=1,
+ weight_loader=weight_loader,
+ )
+ else:
+ weight = ModelWeightParameter(data=torch.empty(
+ sum(output_partition_sizes),
+ input_size_per_partition_padded,
+ dtype=torch.int8),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
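+        # Shape sketch (editorial, illustrative numbers): with
+        # input_size_per_partition = 4096 and a single output partition of 6144,
+        # the w4a8 branch allocates a [4096, 3072] int8 weight (two 4-bit values
+        # per byte along the output dim), while the plain int8 branch allocates
+        # [6144, 4096].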
layer.register_parameter("weight", weight)
@@ -109,104 +132,4 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
- return self.kernel.apply_weights(layer, x, bias)
-
-
-
-class CompressedTensorsW4A8Int8(CompressedTensorsScheme):
- def __init__(
- self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
- ):
- self.strategy = strategy
- self.is_static_input_scheme = is_static_input_scheme
- self.input_symmetric = input_symmetric
-
- @classmethod
- def get_min_capability(cls) -> int:
- # turing and up
- return 75
-
- def create_weights(
- self,
- layer: torch.nn.Module,
- output_partition_sizes: list[int],
- input_size_per_partition: int,
- params_dtype: torch.dtype,
- weight_loader: Callable,
- **kwargs,
- ):
- layer.logical_widths = output_partition_sizes
-
- self.kernel = init_int8_linear_kernel(
- is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
- is_static_input_scheme=self.is_static_input_scheme,
- input_symmetric=self.input_symmetric,
- module_name=self.__class__.__name__,
- )
-
- # WEIGHT
- # weight = ModelWeightParameter(
- # data=torch.empty(
- # sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
- # ),
- # input_dim=1,
- # output_dim=0,
- # weight_loader=weight_loader,
- # )
- weight = ModelWeightParameter(
- data=torch.empty(
- input_size_per_partition,
- sum(output_partition_sizes) // 2,
- dtype=torch.int8
- ),
- input_dim=0,
- output_dim=1,
- weight_loader=weight_loader,
- )
-
- layer.register_parameter("weight", weight)
-
- # WEIGHT SCALE
- if self.strategy == QuantizationStrategy.CHANNEL:
- weight_scale = ChannelQuantScaleParameter(
- data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
- output_dim=0,
- weight_loader=weight_loader,
- )
- else:
- assert self.strategy == QuantizationStrategy.TENSOR
- weight_scale = PerTensorScaleParameter(
- data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
- weight_loader=weight_loader,
- )
- layer.register_parameter("weight_scale", weight_scale)
-
- # INPUT SCALE
- input_zero_point = None
- input_scale = None
- if self.is_static_input_scheme:
- input_scale = BasevLLMParameter(
- data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
- )
- if not self.input_symmetric:
- # Note: compressed-tensors stores the zp using the same dtype
- # as the weights
- # AZP loaded as int8 but used as int32
- input_zero_point = BasevLLMParameter(
- data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
- )
-
- layer.register_parameter("input_zero_point", input_zero_point)
- layer.register_parameter("input_scale", input_scale)
- if not hasattr(layer, "azp_adj"):
- layer.register_parameter("azp_adj", None)
-
- # Checkpoints are serialized in compressed-tensors format, which is
- # different from the format the kernel may want. Handle repacking here.
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- self.kernel.process_weights_after_loading(layer)
-
- def apply_weights(
- self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
- ) -> torch.Tensor:
- return self.kernel.apply_weights(layer, x, bias)
\ No newline at end of file
+ return self.kernel.apply_weights(layer, x, bias, self.is_w4a8_linear)
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index e3174ba..5101347 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import (
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
- FusedMoEPermuteExpertsUnpermute,
- FusedMoEPrepareAndFinalize,
FusedMoeWeightScaleSupported,
- MoEActivation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
- Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
@@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- apply_fi_trtllm_fp8_per_tensor_moe,
-)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
@@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale)
replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale)
- # Setup modular kernel for TP case and naive DP/EP case.
- # In non-naive DP/EP case, we will create a ModularKernelMethod.
- # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
- # in both cases.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
if self.moe_quant_config:
assert self.experts_cls is not None
- self.moe_mk = make_fp8_moe_kernel(
+ self.moe_kernel = make_fp8_moe_kernel(
moe_quant_config=self.moe_quant_config,
moe_config=self.moe,
fp8_backend=self.fp8_backend,
@@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> mk.FusedMoEPrepareAndFinalize | None:
+ ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
)
- def select_gemm_impl(
- self,
- prepare_finalize: FusedMoEPrepareAndFinalize,
- layer: torch.nn.Module,
- ) -> FusedMoEPermuteExpertsUnpermute:
- raise ValueError(
- f"{self.__class__.__name__} uses the new modular kernel initialization "
- "logic. This function should not be called."
- )
-
- def get_fused_moe_quant_config(
- self, layer: torch.nn.Module
- ) -> FusedMoEQuantConfig | None:
- # TRTLLM does not use Modular Kernel.
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
- return None
-
+ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = getattr(layer, f"w13_{self.weight_scale_name}")
w2_scale = getattr(layer, f"w2_{self.weight_scale_name}")
a1_scale = layer.w13_input_scale
@@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
- @property
- def is_monolithic(self) -> bool:
- return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
-
def apply_monolithic(
self,
layer: FusedMoE,
@@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
- assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
-
- # TODO(rob): convert this to MK.
- if layer.enable_eplb:
- raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.")
- assert layer.activation == MoEActivation.SILU, (
- f"Expected 'silu' activation but got {layer.activation}"
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply_monolithic(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ router_logits,
+ activation=layer.activation,
+ global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ num_expert_group=layer.num_expert_group,
+ topk_group=layer.topk_group,
+ e_score_correction_bias=layer.e_score_correction_bias,
+ routed_scaling_factor=layer.routed_scaling_factor,
)
- if self.block_quant:
- import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
-
- return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
- routing_logits=router_logits,
- routing_bias=layer.e_score_correction_bias,
- x=x,
- w13_weight=layer.w13_weight,
- w13_weight_scale_inv=layer.w13_weight_scale_inv,
- w2_weight=layer.w2_weight,
- w2_weight_scale_inv=layer.w2_weight_scale_inv,
- global_num_experts=layer.global_num_experts,
- top_k=layer.top_k,
- num_expert_group=layer.num_expert_group,
- topk_group=layer.topk_group,
- intermediate_size=layer.intermediate_size_per_partition,
- expert_offset=layer.ep_rank * layer.local_num_experts,
- local_num_experts=layer.local_num_experts,
- block_shape=self.weight_block_size,
- routing_method_type=layer.routing_method_type,
- routed_scaling=layer.routed_scaling_factor,
- )
- else:
- return apply_fi_trtllm_fp8_per_tensor_moe(
- layer=layer,
- hidden_states=x,
- router_logits=router_logits,
- routing_bias=layer.e_score_correction_bias,
- global_num_experts=layer.global_num_experts,
- top_k=layer.top_k,
- num_expert_group=layer.num_expert_group,
- topk_group=layer.topk_group,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- )
-
def apply(
self,
layer: FusedMoE,
@@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids: torch.Tensor,
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- assert self.moe_mk is not None
assert not self.is_monolithic
- return self.moe_mk(
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
x,
layer.w13_weight,
layer.w2_weight,
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 8802334..c7dd67d 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -7,6 +7,7 @@ from typing import Any
import gguf
import torch
+import torch.nn.functional as F
from gguf import GGMLQuantizationType as WeightType
from torch.nn.parameter import Parameter, UninitializedParameter
@@ -234,7 +235,7 @@ try:
op_func=_fused_mul_mat_gguf,
fake_impl=_fused_mul_mat_gguf_fake,
)
- fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf
+ fused_mul_mat_gguf = _fused_mul_mat_gguf
except AttributeError as error:
raise error
@@ -365,7 +366,7 @@ try:
op_func=_fused_moe_gguf,
fake_impl=_fused_moe_gguf_fake,
)
- fused_moe_gguf = torch.ops.vllm._fused_moe_gguf
+ fused_moe_gguf = _fused_moe_gguf
except AttributeError as error:
raise error
@@ -410,7 +411,7 @@ try:
op_func=_apply_gguf_embedding,
fake_impl=_apply_gguf_embedding_fake,
)
- apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding
+ apply_gguf_embedding = _apply_gguf_embedding
except AttributeError as error:
raise error
@@ -451,6 +452,9 @@ class GGUFLinearMethod(LinearMethodBase):
"data_container": [],
"shard_id": [],
"shard_id_map": {},
+ "params_dtype": params_dtype,
+                "input_size_per_partition": input_size_per_partition,  # restore shapes for QKV/merged layers
+                "output_partition_sizes": output_partition_sizes,
},
)
set_weight_attrs(qweight, extra_weight_attrs)
@@ -664,6 +668,10 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
"""
def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
+ weight = layer.weight
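+        # Editorial note: `weight` is assumed to be a dense, dequantized tensor
+        # materialized by process_weights_after_loading below, so the lookup
+        # reduces to a plain F.embedding call.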
+ return F.embedding(x, weight)
+
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
qweight = layer.qweight
qweight_type = layer.qweight_type.weight_type
hidden_size = qweight.tensor_shape[1]
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index 2cc86a9..dc91f17 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -128,7 +128,7 @@ class GPTQConfig(QuantizationConfig):
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
- return [torch.half, torch.bfloat16]
+ return [torch.bfloat16, torch.half]
@classmethod
# Need to figure it out
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index d7b2a36..5dd352c 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -59,9 +59,164 @@ from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils.collection_utils import is_list_of
+import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
+# [B, K//8, N] -> [B, K, N]
+# uses less memory
+def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor:
+ """
+ Memory-efficient unpacking for 3D tensor.
+ Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor,
+ without broadcasting huge intermediate tensors (avoids OOM).
+
+ Args:
+ packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
+ num_bits: Number of bits per packed element (e.g., 4 or 2).
+ chunk_size: How many bit groups to unpack at once (tradeoff between speed and memory).
+
+ Returns:
+ unpacked: torch.int8 tensor of shape [B, K, N].
+ """
+ B, k_packed, N = packed_w.shape
+ pack_factor = 32 // num_bits
+ K = k_packed * pack_factor
+ mask = (1 << num_bits) - 1
+
+ # Allocate output tensor once
+ unpacked = torch.empty((B, K, N), dtype=torch.int8, device=packed_w.device)
+
+ # Process bit chunks iteratively to save memory
+ for i in range(0, pack_factor, chunk_size):
+ # Precompute shifts for this chunk
+ shift_vals = num_bits * torch.arange(i, min(i + chunk_size, pack_factor), device=packed_w.device)
+ # [chunk_size, 1, 1, 1]
+ shifts = shift_vals.view(-1, 1, 1, 1)
+ # Compute small chunk only
+ chunk = ((packed_w.unsqueeze(0) >> shifts) & mask).to(torch.int8)
+
+ # chunk: [chunk_size, B, k_packed, N]
+ # write into output
+ for j in range(chunk.shape[0]):
+ unpacked[:, (i + j)::pack_factor, :] = chunk[j]
+
+ del chunk # release memory early
+
+ return unpacked
+
+# uses more memory
+def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor:
+ """
+ Efficient vectorized unpacking for 3D tensor (batch version).
+ Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor.
+
+ Args:
+ packed_w: torch.int32 tensor of shape [B, K // pack_factor, N].
+ num_bits: Number of bits per packed element (e.g., 4).
+
+ Returns:
+ unpacked: torch.int8 tensor of shape [B, K, N].
+ """
+ B, k_packed, n = packed_w.shape
+ pack_factor = 32 // num_bits
+ k = k_packed * pack_factor
+
+ mask = (1 << num_bits) - 1
+
+ # [pack_factor, 1, 1, 1]
+ shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1)
+
+ # [1, B, k_packed, N]
+ packed_expanded = packed_w.unsqueeze(0)
+
+ # Extract each group of num_bits using bitwise ops
+ unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8)
+
+ # [pack_factor, B, k_packed, N] → [B, K, N]
+ unpacked = unpacked_groups.permute(1, 2, 0, 3).reshape(B, k, n)
+
+ return unpacked
+
+
+# [B, K, N] -> [B, K, N//8]
+# uses less memory
+def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor:
+ """
+ Memory-efficient batch packing with correct bit order.
+ [B, K, N] int4 -> [B, K, N//pack_num] int32.
+ """
+ B, K, N = x.shape
+ assert N % pack_num == 0, "N must be divisible by pack_num"
+ cols = N // pack_num
+ unit = 32 // pack_num
+
+ if order_map is None:
+ order_map = list(range(pack_num))
+ order_map = torch.tensor(order_map, device=x.device)
+
+ shifts = unit * torch.arange(pack_num, device=x.device) # always 0..unit*(pack_num-1)
+ packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device)
+ x_reshape = x.view(B, K, cols, pack_num) & 0xF
+
+ # process in chunks for memory efficiency
+ for start in range(0, pack_num, chunk_size):
+ end = min(start + chunk_size, pack_num)
+ idx_chunk = order_map[start:end]
+ shift_chunk = shifts[start:end]
+
+        vals = torch.gather(x_reshape, 3, idx_chunk.view(1, 1, 1, -1).expand(B, K, cols, -1)).to(torch.int32)
+ for j in range(vals.shape[-1]):
+ packed.add_(vals[..., j] << shift_chunk[j])
+
+ return packed
+
+# uses more memory
+def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor:
+ """
+ Efficient vectorized batch packing: [B, K, N] int4 -> [B, K, N//pack_num] int32.
+
+ Args:
+ x: torch.int32 tensor of shape [B, K, N], each element 0-15 (int4).
+ pack_num: Number of 4-bit elements per packed int32 (default=8).
+ order_map: Optional order of elements within each packed int32.
+
+ Returns:
+ torch.int32 tensor of shape [B, K, N//pack_num].
+ """
+
+ B, K, N = x.shape
+ assert N % pack_num == 0, "N must be divisible by pack_num"
+ cols = N // pack_num
+
+ if order_map is None:
+ order_map = list(range(pack_num))
+ order_map = torch.tensor(order_map, device=x.device)
+
+ unit = 32 // pack_num # number of bits per element
+
+ # reshape to [B, K, cols, pack_num]
+ pack_num_int = int(pack_num)
+
+ x_reshape = x.view(B, K, cols, pack_num_int)
+
+ # reorder according to order_map
+ x_reorder = torch.gather(
+ x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1)
+ )
+
+ # mask low 4 bits
+ x_reorder = x_reorder & 0xF
+
+ # bit shifts [pack_num] -> [1,1,1,pack_num] broadcastable
+ shifts = (unit * torch.arange(pack_num_int, device=x.device)).view(1, 1, 1, -1)
+
+ # shift and sum along last dimension to combine bits
+ packed = (x_reorder << shifts).sum(dim=-1).to(torch.int32)
+
+ return packed
+
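+# Shape sketch (editorial, illustrative numbers): with B=2, K=128, N=64 and the
+# default 4-bit packing,
+#   unpack_k_batch(int32 tensor of shape [2, 16, 64])       -> int8 of shape [2, 128, 64]
+#   pack_n_batch(int4-valued tensor of shape [2, 128, 64], 8) -> int32 of shape [2, 128, 8]
+# i.e. unpacking expands K by 32 // num_bits and packing shrinks N by pack_num.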
+
def get_moe_quant_method(
config: "GPTQMarlinConfig",
@@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
self.quant_config = quant_config
if self.quant_config.quant_type.size_bits == 4:
self.quant_type = scalar_types.uint4b8
- elif self.quant_config.quant_type.size_bits == 8:
- self.quant_type = scalar_types.uint8b128
+ # elif self.quant_config.quant_type.size_bits == 8:
+ # self.quant_type = scalar_types.uint8b128
else:
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
self.input_dtype = None
@@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size13,
2 * intermediate_size_per_partition // self.quant_config.pack_factor,
- dtype=params_dtype,
+ dtype=torch.int32,
),
requires_grad=False,
)
@@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
num_experts,
scales_size2,
hidden_size // self.quant_config.pack_factor,
- dtype=params_dtype,
+ dtype=torch.int32,
),
requires_grad=False,
)
@@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
device = layer.w13_qweight.device
- layer.workspace = marlin_make_workspace_new(device, 4)
+ # layer.workspace = marlin_make_workspace_new(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
@@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_scales.data = layer.w2_scales.data * 512
# Process act_order
- if self.quant_config.desc_act:
+ # if self.quant_config.desc_act:
# Get sorting based on g_idx
- num_experts = layer.w13_g_idx.shape[0]
- w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
- w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
- w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
- w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
- for e in range(num_experts):
- w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
- torch.int32
- )
- w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
- torch.int32
- )
- w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
- w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
- replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
- replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
- replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
- replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
- else:
- # Reset g_idx related tensors
- num_experts = layer.w13_g_idx.shape[0]
- device = layer.w13_g_idx.device
- layer.w13_g_idx = torch.nn.Parameter(
- torch.empty((num_experts, 0), dtype=torch.int32, device=device),
- requires_grad=False,
- )
- layer.w2_g_idx = torch.nn.Parameter(
- torch.empty((num_experts, 0), dtype=torch.int32, device=device),
- requires_grad=False,
- )
- layer.w13_g_idx_sort_indices = torch.nn.Parameter(
- torch.empty((num_experts, 0), dtype=torch.int32, device=device),
- requires_grad=False,
- )
- layer.w2_g_idx_sort_indices = torch.nn.Parameter(
- torch.empty((num_experts, 0), dtype=torch.int32, device=device),
- requires_grad=False,
- )
- # Repack weights
- marlin_w13_qweight = ops.gptq_marlin_moe_repack(
- layer.w13_qweight,
- layer.w13_g_idx_sort_indices,
- layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
- layer.w13_qweight.shape[2],
- self.quant_config.quant_type.size_bits,
- is_a_8bit=is_a_8bit,
- )
- replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
- marlin_w2_qweight = ops.gptq_marlin_moe_repack(
- layer.w2_qweight,
- layer.w2_g_idx_sort_indices,
- layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
- layer.w2_qweight.shape[2],
- self.quant_config.quant_type.size_bits,
- is_a_8bit=is_a_8bit,
- )
- replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+ # num_experts = layer.w13_g_idx.shape[0]
+ # w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx)
+ # w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx)
+ # w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
+ # w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
+ # for e in range(num_experts):
+ # w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(
+ # torch.int32
+ # )
+ # w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(
+ # torch.int32
+ # )
+ # w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
+ # w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
+ # replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
+ # replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
+ # replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+ # replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+ # else:
+ # # Reset g_idx related tensors
+ # num_experts = layer.w13_g_idx.shape[0]
+ # device = layer.w13_g_idx.device
+ # layer.w13_g_idx = torch.nn.Parameter(
+ # torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+ # requires_grad=False,
+ # )
+ # layer.w2_g_idx = torch.nn.Parameter(
+ # torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+ # requires_grad=False,
+ # )
+ # layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+ # torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+ # requires_grad=False,
+ # )
+ # layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+ # torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+ # requires_grad=False,
+ # )
+ # # Repack weights
+ # marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+ # layer.w13_qweight,
+ # layer.w13_g_idx_sort_indices,
+ # layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
+ # layer.w13_qweight.shape[2],
+ # self.quant_config.quant_type.size_bits,
+ # )
+ # replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
+ # marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+ # layer.w2_qweight,
+ # layer.w2_g_idx_sort_indices,
+ # layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
+ # layer.w2_qweight.shape[2],
+ # self.quant_config.quant_type.size_bits,
+ # )
+ # replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
+ # # Repack scales
+ # marlin_w13_scales = marlin_moe_permute_scales(
+ # s=layer.w13_scales,
+ # size_k=layer.intermediate_size_per_partition,
+ # size_n=layer.w13_scales.shape[2],
+ # group_size=self.quant_config.group_size,
+ # )
+ # replace_parameter(layer, "w13_scales", marlin_w13_scales)
+ # marlin_w2_scales = marlin_moe_permute_scales(
+ # s=layer.w2_scales,
+ # size_k=layer.w2_scales.shape[1]
+ # * (
+ # self.quant_config.group_size
+ # if self.quant_config.group_size != -1
+ # else self.quant_config.pack_factor
+ # ),
+ # size_n=layer.w2_scales.shape[2],
+ # group_size=self.quant_config.group_size,
+ # )
+ # replace_parameter(layer, "w2_scales", marlin_w2_scales)
- # The modular kernel expects w13_weight and w2_weight,
- # but GPTQ uses w13_qweight and w2_qweight
- # Alias for modular kernel
- layer.w13_weight = layer.w13_qweight
- # Alias for modular kernel
- layer.w2_weight = layer.w2_qweight
+ # if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
+ # layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
- # Repack scales
- marlin_w13_scales = marlin_moe_permute_scales(
- s=layer.w13_scales,
- size_k=layer.intermediate_size_per_partition,
- size_n=layer.w13_scales.shape[2],
- group_size=self.quant_config.group_size,
- is_a_8bit=is_a_8bit,
- )
- if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
- marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
- marlin_w13_scales
- )
- layer.register_parameter(
- "w13_input_global_scale",
- torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
- )
-
- replace_parameter(layer, "w13_scales", marlin_w13_scales)
- marlin_w2_scales = marlin_moe_permute_scales(
- s=layer.w2_scales,
- size_k=layer.w2_scales.shape[1]
- * (
- self.quant_config.group_size
- if self.quant_config.group_size != -1
- else self.quant_config.pack_factor
- ),
- size_n=layer.w2_scales.shape[2],
- group_size=self.quant_config.group_size,
- is_a_8bit=is_a_8bit,
- )
- if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
- marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
- marlin_w2_scales
- )
- layer.register_parameter(
- "w2_input_global_scale",
- torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
- )
-
- replace_parameter(layer, "w2_scales", marlin_w2_scales)
-
- if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
- layer.w13_bias.data = marlin_permute_bias(layer.w13_bias)
-
- if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
- layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
+ # if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
+ # layer.w2_bias.data = marlin_permute_bias(layer.w2_bias)
+        if self.quant_config.desc_act:
+            raise NotImplementedError(
+                "GPTQMarlinMoEMethod does not support desc_act yet."
+            )
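+        # Repack the GPTQ-packed weights for the ixformer W4A16 grouped GEMM:
+        # unpack along K, then re-pack along N using the interleave order_map below
+        # (the exact layout is defined by unpack_k_batch / pack_n_batch).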
+        w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight)
+        w13_qweight_repacked = pack_n_batch(
+            w13_qweight_unpacked,
+            self.quant_config.pack_factor,
+            order_map=[0, 2, 4, 6, 1, 3, 5, 7],
+        )
+ replace_parameter(layer, "w13_qweight", w13_qweight_repacked)
+
+        # quantize_weights (vllm/model_executor/layers/quantization/utils/quant_utils.py)
+        # adds the bias to the quantized weight when quant_type.has_bias():
+        #     if quant_type.has_bias():
+        #         w_q += quant_type.bias
+        # so quant_type.bias is used here as the zero point (supported by ixformer).
+        w13_zp = torch.full_like(
+            layer.w13_scales, self.quant_type.bias, dtype=torch.int32
+        )
+        w13_zp_pack = pack_n_batch(
+            w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]
+        ).contiguous()
+ replace_parameter(layer, "w13_qzeros", w13_zp_pack)
+
+        w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight)
+        w2_qweight_repacked = pack_n_batch(
+            w2_qweight_unpacked,
+            self.quant_config.pack_factor,
+            order_map=[0, 2, 4, 6, 1, 3, 5, 7],
+        )
+ replace_parameter(layer, "w2_qweight", w2_qweight_repacked)
+
+        w2_zp = torch.full_like(
+            layer.w2_scales, self.quant_type.bias, dtype=torch.int32
+        )
+        w2_zp_pack = pack_n_batch(
+            w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]
+        ).contiguous()
+ replace_parameter(layer, "w2_qzeros", w2_zp_pack)
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
+        # Carries the shared-experts output so it can be fused into the final reduction.
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
- return fused_marlin_moe(
- x,
- layer.w13_qweight,
- layer.w2_qweight,
- getattr(layer, "w13_bias", None),
- getattr(layer, "w2_bias", None),
- layer.w13_scales,
- layer.w2_scales,
- topk_weights,
- topk_ids,
- input_global_scale1=getattr(layer, "w13_input_global_scale", None),
- input_global_scale2=getattr(layer, "w2_input_global_scale", None),
- quant_type_id=self.quant_type.id,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- g_idx1=layer.w13_g_idx,
- g_idx2=layer.w2_g_idx,
- sort_indices1=layer.w13_g_idx_sort_indices,
- sort_indices2=layer.w2_g_idx_sort_indices,
- workspace=layer.workspace,
- is_k_full=self.is_k_full,
- input_dtype=self.input_dtype,
- inplace=not self.moe.disable_inplace,
+ assert layer.activation.value == "silu", "Only SiLU activation is supported."
+ use_ep = layer.expert_map is not None
+
+ if use_ep:
+ start_eid = layer.ep_rank * layer.local_num_experts
+            end_eid = min(
+                (layer.ep_rank + 1) * layer.local_num_experts,
+                layer.global_num_experts,
+            )
+
+ if layer.apply_router_weight_on_input:
+            raise NotImplementedError(
+                "GPTQMarlinMoEMethod does not support applying the router weight "
+                "on the input for the fused Marlin MoE method."
+            )
+
+        if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (
+            hasattr(layer, "w2_bias") and layer.w2_bias is not None
+        ):
+            raise NotImplementedError(
+                "GPTQMarlinMoEMethod: moe_w4a16_group_gemm does not support bias yet."
+            )
+
+ num_tokens = topk_ids.shape[0]
+ num_experts = layer.global_num_experts
+
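+        # Build the token-to-expert routing indices. With expert parallelism, only
+        # experts in [start_eid, end_eid) are local to this rank; without it, every
+        # token simply expands into top_k rows.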
+ if use_ep:
+ hidden_size = x.shape[1]
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ expand_tokens,
+ ) = ixfops.moe_compute_token_index_ep(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ start_expert_id=start_eid,
+ end_expert_id=end_eid,
+ )
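+            # No tokens were routed to this rank's experts: skip the GEMMs entirely.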
+ if expert_sizes_cpu.sum() == 0:
+ return torch.zeros(
+ (num_tokens, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+ else:
+ expand_tokens = num_tokens * layer.top_k
+ (
+ src_to_dst,
+ sorted_token_ids,
+ expert_sizes_gpu,
+ expert_sizes_cpu,
+ ) = ixfops.moe_compute_token_index(
+ topk_ids=topk_ids,
+ num_experts=num_experts,
+ )
+ expert_sizes_cpu = expert_sizes_gpu.cpu()
+
+ # expand + reorder
+ # TODO use kernel
+ expand_hidden_states = ixfops.moe_expand_input(
+ hidden_states=x,
+ dst_to_src=sorted_token_ids,
+ dst_tokens=expand_tokens,
+ topk=layer.top_k,
+ src_to_dst=src_to_dst,
)
+
+        # w4a16 group gemm 1
+        # pt_output_1: shape (expand_tokens, 2n), same dtype as x
+ pt_output_1 = ixfops.moe_w4a16_group_gemm(
+ input=expand_hidden_states,
+ weight=layer.w13_qweight,
+ w_scales=layer.w13_scales,
+ quant_type="awq",
+ tokens_per_experts=expert_sizes_cpu,
+ w_zeros=layer.w13_qzeros,
+ group_size=self.quant_config.group_size,
+ dst_to_src=None,
+ format="NN",
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ )
+
+ # act
+ pt_output_2 = ixfops.silu_and_mul(pt_output_1)
+
+        # w4a16 group gemm 2 + reorder
+        # pt_output_3: shape (expand_tokens, k), same dtype as x
+ if use_ep:
+ pt_output_3 = torch.empty(
+ (num_tokens * layer.top_k, hidden_size),
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+ ixfops.moe_w4a16_group_gemm(
+ input=pt_output_2,
+ weight=layer.w2_qweight,
+ w_scales=layer.w2_scales,
+ quant_type="awq",
+ tokens_per_experts=expert_sizes_cpu,
+ w_zeros=layer.w2_qzeros,
+ group_size=self.quant_config.group_size,
+ dst_to_src=sorted_token_ids,
+ format="NN",
+ output=pt_output_3,
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ )
+
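+            # Entries of -1 in src_to_dst correspond to tokens routed to experts on
+            # other EP ranks; they are masked out of the weighted reduction below.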
+ reduce_mask = src_to_dst == -1
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ mask=reduce_mask,
+ )
+ else:
+ pt_output_3 = ixfops.moe_w4a16_group_gemm(
+ input=pt_output_2,
+ weight=layer.w2_qweight,
+ w_scales=layer.w2_scales,
+ quant_type="awq",
+ tokens_per_experts=expert_sizes_cpu,
+ w_zeros=layer.w2_qzeros,
+ group_size=self.quant_config.group_size,
+ dst_to_src=sorted_token_ids,
+ format="NN",
+ tokens_per_experts_gpu=expert_sizes_gpu,
+ )
+
+ # mul + reduce_sum
+ # final_hidden_states: (num_tokens, k)
+ final_hidden_states = ixfops.moe_output_reduce_sum(
+ input=pt_output_3.view(num_tokens, layer.top_k, -1),
+ topk_weight=topk_weights,
+ scaling_factor=layer.routed_scaling_factor,
+ extra_residual=shared_experts_input,
+ )
+ return final_hidden_states
+
+ # return torch.ops.vllm.fused_marlin_moe(
+ # x,
+ # layer.w13_qweight,
+ # layer.w2_qweight,
+ # getattr(layer, "w13_bias", None),
+ # getattr(layer, "w2_bias", None),
+ # layer.w13_scales,
+ # layer.w2_scales,
+ # router_logits,
+ # topk_weights,
+ # topk_ids,
+ # quant_type_id=self.quant_type.id,
+ # apply_router_weight_on_input=apply_router_weight_on_input,
+ # global_num_experts=global_num_experts,
+ # expert_map=expert_map,
+ # g_idx1=layer.w13_g_idx,
+ # g_idx2=layer.w2_g_idx,
+ # sort_indices1=layer.w13_g_idx_sort_indices,
+ # sort_indices2=layer.w2_g_idx_sort_indices,
+ # workspace=layer.workspace,
+ # is_k_full=self.is_k_full)
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 4c059da..f167e21 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -12,8 +12,7 @@ from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
-from vllm.model_executor.layers.attention import Attention
-from vllm.model_executor.layers.fused_moe.activation import MoEActivation
+from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.oracle.fp8 import (
- Fp8MoeBackend,
convert_to_fp8_moe_kernel_format,
make_fp8_moe_kernel,
make_fp8_moe_quant_config,
select_fp8_moe_backend,
)
from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import (
- NvFp4MoeBackend,
convert_to_nvfp4_moe_kernel_format,
is_global_sf_supported_for_nvfp4_backend,
make_nvfp4_moe_kernel,
@@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
- flashinfer_trtllm_fp4_moe,
- flashinfer_trtllm_fp4_routed_moe,
-)
-from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- apply_fi_trtllm_fp8_per_tensor_moe,
-)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
process_fp8_input_tensor_strategy_moe,
@@ -114,6 +104,8 @@ QUANT_ALGOS = [
"NVFP4",
# MXFP8
"MXFP8",
+    # MIXED_PRECISION
+ "MIXED_PRECISION",
]
KV_CACHE_QUANT_ALGOS = ["FP8"]
@@ -181,7 +173,7 @@ class ModelOptQuantConfigBase(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> "QuantizeMethodBase | None":
# handle kv-cache first so we can focus only on weight quantization thereafter
- if isinstance(layer, Attention):
+ if isinstance(layer, (Attention, MLAAttention)):
return self.KVCacheMethodCls(self)
# handle exclusion
@@ -235,6 +227,26 @@ class ModelOptQuantConfigBase(QuantizationConfig):
self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules)
+ @staticmethod
+ def _extract_modelopt_quant_algo(
+ hf_quant_cfg: dict[str, Any] | None,
+ ) -> str | None:
+ """Extract upper-cased quant_algo from a modelopt config.
+
+ Returns the quant_algo string (upper-cased), or None if the config
+ is not a modelopt config.
+ """
+ if hf_quant_cfg is None:
+ return None
+ if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
+ return None
+ if "quantization" in hf_quant_cfg:
+ quant_config = hf_quant_cfg["quantization"]
+ if isinstance(quant_config, dict):
+ return str(quant_config.get("quant_algo", "")).upper()
+ return None
+ return str(hf_quant_cfg.get("quant_algo", "")).upper()
+
@staticmethod
def get_config_filenames() -> list[str]:
return ["hf_quant_config.json"]
@@ -272,10 +284,20 @@ class ModelOptQuantConfigBase(QuantizationConfig):
# "exclude_modules" is the key in the legacy hf_quant_config.json
exclude_modules = quant_config.get("exclude_modules", [])
else:
- # Compressed-tensors style format:
+ # Compressed-tensors style format (config.json quantization_config):
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method = config.get("quant_algo")
- kv_cache_quant_method = config.get("kv_cache_quant_algo")
+
+            # This format provides "kv_cache_scheme" (a dict) instead of
+            # "kv_cache_quant_algo" (a string).
+ kv_cache_scheme = config.get("kv_cache_scheme")
+ if isinstance(kv_cache_scheme, dict) and (
+ kv_cache_scheme.get("type") == "float"
+ and kv_cache_scheme.get("num_bits") == 8
+ ):
+ kv_cache_quant_method = "FP8"
+ else:
+ kv_cache_quant_method = None
+
# "ignore" is the key in config.json
exclude_modules = config.get("ignore", [])
group_size_raw = config.get("group_size")
@@ -379,32 +401,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
- """Detect if this ModelOpt config should be used based on
- quantization config."""
-
- if hf_quant_cfg is None:
- return None
-
- # Use the community standard 'quant_method'
- quant_method = hf_quant_cfg.get("quant_method", "").lower()
-
- # Only proceed if the method is explicitly "modelopt"
- if quant_method != "modelopt":
- return None
-
- # Look for ModelOpt-specific config structure
- if "quantization" in hf_quant_cfg:
- quant_config = hf_quant_cfg["quantization"]
- if isinstance(quant_config, dict):
- quant_algo = str(quant_config.get("quant_algo", ""))
- if quant_algo.upper() == "FP8":
- return "modelopt"
- else:
- # Check for compressed-tensors style config with specific quant_algo
- quant_algo = str(hf_quant_cfg.get("quant_algo", ""))
- if quant_algo.upper() == "FP8":
- return "modelopt"
-
+ algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
+ if algo is not None and algo == "FP8":
+ return "modelopt"
return None
@classmethod
@@ -737,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> mk.FusedMoEPrepareAndFinalize | None:
+ ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -745,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+ prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -862,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config:
- assert self.experts_cls is not None
- self.moe_mk = make_fp8_moe_kernel(
- moe_quant_config=self.moe_quant_config,
- moe_config=self.moe,
- fp8_backend=self.fp8_backend,
- experts_cls=self.experts_cls,
- routing_tables=layer._maybe_init_expert_routing_tables(),
- shared_experts=layer.shared_experts,
- )
+ assert self.experts_cls is not None
+ self.moe_kernel = make_fp8_moe_kernel(
+ moe_quant_config=self.moe_quant_config,
+ moe_config=self.moe,
+ fp8_backend=self.fp8_backend,
+ experts_cls=self.experts_cls,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
+ shared_experts=layer.shared_experts,
+ )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13 = layer.w13_weight
@@ -904,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
)
- def get_fused_moe_quant_config(
- self, layer: torch.nn.Module
- ) -> FusedMoEQuantConfig | None:
+ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
a1_scale = layer.w13_input_scale
@@ -920,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
a2_scale=a2_scale,
)
- @property
- def is_monolithic(self) -> bool:
- return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
-
def apply_monolithic(
self,
layer: FusedMoE,
@@ -931,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
- assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM
- if layer.enable_eplb:
- raise NotImplementedError(
- "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend."
- )
- # TODO(rob): this validation should happen at kernel selection
- # time in the oracle rather than here.
- SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
- assert layer.activation in SUPPORTED_ACTIVATIONS, (
- f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
- f"TRTLLM FP4 MoE, {layer.activation} found instead."
- )
- return apply_fi_trtllm_fp8_per_tensor_moe(
- layer=layer,
- hidden_states=x,
- router_logits=router_logits,
- routing_bias=layer.e_score_correction_bias,
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply_monolithic(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ router_logits,
+ activation=layer.activation,
global_num_experts=layer.global_num_experts,
- top_k=layer.top_k,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ e_score_correction_bias=layer.e_score_correction_bias,
+ routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
@@ -964,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
-
- # TODO(rob): this validation should happen at kernel selection
- # time in the oracle rather than here.
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
- assert layer.activation in (
- MoEActivation.SILU,
- MoEActivation.RELU2_NO_MUL,
- ), (
- "Expected activation to be in ('silu', 'relu2_no_mul'),"
- f"but got {layer.activation}"
- )
-
- assert self.moe_mk is not None
- return self.moe_mk(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights,
+ topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
@@ -1031,32 +1003,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
- """Detect if this ModelOpt FP4 config should be used based on
- quantization config."""
- if hf_quant_cfg is None:
- return None
-
- # Use the community standard 'quant_method'
- quant_method = hf_quant_cfg.get("quant_method", "").lower()
-
- # Only proceed if the method is explicitly "modelopt"
- if quant_method != "modelopt":
- return None
-
- # Look for ModelOpt-specific config structure
- if "quantization" in hf_quant_cfg:
- quant_config = hf_quant_cfg["quantization"]
- if isinstance(quant_config, dict):
- quant_algo = quant_config.get("quant_algo", "")
- if "NVFP4" in quant_algo:
- return "modelopt_fp4"
- else:
- # Check for compressed-tensors style config with specific
- # quant_algo field
- quant_algo = hf_quant_cfg.get("quant_algo", "")
- if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
- return "modelopt_fp4"
-
+ algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
+ if algo is not None and ("NVFP4" in algo or "FP4" in algo):
+ return "modelopt_fp4"
return None
@classmethod
@@ -1249,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
- ) -> mk.FusedMoEPrepareAndFinalize | None:
- raise ValueError(
- f"{self.__class__.__name__} uses the new modular kernel initialization "
- "logic. This function should not be called."
- )
-
- def select_gemm_impl(
- self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
- layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEPrepareAndFinalizeModular | None:
raise ValueError(
f"{self.__class__.__name__} uses the new modular kernel initialization "
"logic. This function should not be called."
@@ -1434,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
replace_parameter(layer, "w2_weight_scale_2", w2_scale_2)
replace_parameter(layer, "w2_input_scale", a2_scale)
- # Setup modular kernel for TP case and naive DP/EP case.
- # In non-naive DP/EP case, we will create a ModularKernelMethod.
- # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel
- # in both cases.
+ # Setup modular kernel.
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
- if self.moe_quant_config:
- assert self.experts_cls is not None
- self.moe_mk = make_nvfp4_moe_kernel(
- moe_quant_config=self.moe_quant_config,
- moe_config=self.moe,
- experts_cls=self.experts_cls,
- shared_experts=layer.shared_experts,
- routing_tables=layer._maybe_init_expert_routing_tables(),
- )
-
- @property
- def do_post_quant_allgather(self):
- return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
-
- def prepare_dp_allgather_tensor(
- self,
- layer: FusedMoE,
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- ) -> tuple[torch.Tensor, list[torch.Tensor]]:
- """Optionally prepare extra tensors to carry through DP allgather/EP."""
- if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
- raise RuntimeError(
- "prepare_dp_allgather_tensor is only supported for "
- "FlashInfer TRTLLM NVFP4 MoE backend."
- )
-
- import flashinfer
-
- hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
- hidden_states,
- layer.a1_gscale,
- is_sf_swizzled_layout=False,
+ assert self.experts_cls is not None
+ self.moe_kernel = make_nvfp4_moe_kernel(
+ moe_quant_config=self.moe_quant_config,
+ moe_config=self.moe,
+ experts_cls=self.experts_cls,
+ shared_experts=layer.shared_experts,
+ routing_tables=layer._maybe_init_expert_routing_tables(),
)
- extra_tensors: list[torch.Tensor] = [hidden_states_sf]
- return hidden_states_fp4, extra_tensors
- def get_fused_moe_quant_config(
- self, layer: torch.nn.Module
- ) -> FusedMoEQuantConfig | None:
+ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
return make_nvfp4_moe_quant_config(
backend=self.nvfp4_backend,
w13_scale=layer.w13_weight_scale,
@@ -1493,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
def supports_eplb(self) -> bool:
return True
- @property
- def is_monolithic(self) -> bool:
- return (
- self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
- and not self.moe.moe_parallel_config.enable_eplb
- )
-
def apply_monolithic(
self,
layer: FusedMoE,
@@ -1507,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert self.is_monolithic
- assert (
- self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM
- and not layer.enable_eplb
- )
-
- return flashinfer_trtllm_fp4_moe(
- layer=layer,
- x=x,
- router_logits=router_logits,
- top_k=layer.top_k,
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply_monolithic(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ router_logits,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
num_expert_group=layer.num_expert_group,
topk_group=layer.topk_group,
- custom_routing_function=layer.custom_routing_function,
e_score_correction_bias=layer.e_score_correction_bias,
+ routed_scaling_factor=layer.routed_scaling_factor,
)
def apply(
@@ -1534,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not self.is_monolithic
-
- # EPLB path
- if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM:
- assert layer.enable_eplb
- return flashinfer_trtllm_fp4_routed_moe(
- layer=layer,
- x=x,
- topk_ids=topk_ids,
- topk_weights=topk_weights,
- top_k=layer.top_k,
- activation=layer.activation,
- global_num_experts=layer.global_num_experts,
- )
- else:
- assert self.moe_mk is not None
- return self.moe_mk(
- hidden_states=x,
- w1=layer.w13_weight,
- w2=layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- activation=layer.activation,
- global_num_experts=layer.global_num_experts,
- expert_map=layer.expert_map,
- apply_router_weight_on_input=layer.apply_router_weight_on_input,
- shared_experts_input=shared_experts_input,
- )
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights,
+ topk_ids,
+ activation=layer.activation,
+ global_num_experts=layer.global_num_experts,
+ expert_map=layer.expert_map,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ shared_experts_input=shared_experts_input,
+ )
ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod
@@ -1619,31 +1502,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase):
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
- """Detect if this ModelOpt MXFP8 config should be used based on
- quantization config."""
- if hf_quant_cfg is None:
- return None
-
- # Use the community standard 'quant_method'
- quant_method = hf_quant_cfg.get("quant_method", "").lower()
-
- # Only proceed if the method is explicitly "modelopt"
- if quant_method != "modelopt":
- return None
-
- # Look for ModelOpt-specific config structure
- if "quantization" in hf_quant_cfg:
- quant_config = hf_quant_cfg["quantization"]
- if isinstance(quant_config, dict):
- quant_algo = str(quant_config.get("quant_algo", "")).upper()
- if "MXFP8" in quant_algo:
- return "modelopt_mxfp8"
- else:
- # Check for compressed-tensors style config with specific quant_algo
- quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper()
- if "MXFP8" in quant_algo:
- return "modelopt_mxfp8"
-
+ algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
+ if algo is not None and "MXFP8" in algo:
+ return "modelopt_mxfp8"
return None
@classmethod
@@ -1841,3 +1702,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
# Register the method classes for ModelOptMxFp8Config
ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod
ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
+
+
+class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase):
+ """Config class for ModelOpt MIXED_PRECISION.
+
+ Supports checkpoints where different layers use different quantization
+ algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts).
+ The per-layer algorithm is specified in the ``quantized_layers`` dict
+ inside ``config.json``'s ``quantization_config`` (preferred) or the
+ legacy ``hf_quant_config.json``.
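+
+    An illustrative (not checkpoint-exact) ``quantized_layers`` entry::
+
+        {
+            "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
+            "model.layers.0.mlp.experts": {"quant_algo": "NVFP4", "group_size": 16},
+        }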
+ """
+
+ def __init__(
+ self,
+ kv_cache_quant_method: str | None,
+ exclude_modules: list[str],
+ quantized_layers: dict[str, dict[str, Any]],
+ fp8_config: ModelOptFp8Config,
+ nvfp4_config: ModelOptNvFp4Config,
+ ) -> None:
+ super().__init__(exclude_modules)
+ self.kv_cache_quant_method = kv_cache_quant_method
+ self.quantized_layers = quantized_layers
+ self.fp8_config = fp8_config
+ self.nvfp4_config = nvfp4_config
+
+ def get_name(self) -> QuantizationMethods:
+ return "modelopt_mixed"
+
+ def get_supported_act_dtypes(self) -> list[torch.dtype]:
+ return [torch.bfloat16, torch.half]
+
+ @classmethod
+ def get_min_capability(cls) -> int:
+ return 89
+
+ @classmethod
+ def override_quantization_method(
+ cls, hf_quant_cfg, user_quant
+ ) -> QuantizationMethods | None:
+ algo = cls._extract_modelopt_quant_algo(hf_quant_cfg)
+ if algo is not None and algo == "MIXED_PRECISION":
+ return "modelopt_mixed"
+ return None
+
+ @classmethod
+ def _from_config(
+ cls,
+ *,
+ quant_method: str,
+ kv_cache_quant_method: str | None,
+ exclude_modules: list[str],
+ original_config: dict[str, Any],
+ group_size: int | None,
+ **kwargs: Any,
+ ) -> "ModelOptMixedPrecisionConfig":
+ if "quantization" in original_config:
+ quantized_layers = original_config["quantization"].get(
+ "quantized_layers", {}
+ )
+ else:
+ quantized_layers = original_config.get("quantized_layers", {})
+
+ if not quantized_layers:
+ raise ValueError(
+ "MIXED_PRECISION quant_algo requires a non-empty "
+ "'quantized_layers' mapping in the quantization config."
+ )
+
+ # Determine group_size from the first NVFP4 entry if not provided.
+ if group_size is None:
+ for layer_info in quantized_layers.values():
+ if layer_info.get("quant_algo", "").upper() == "NVFP4":
+ group_size = layer_info.get("group_size", 16)
+ break
+ if group_size is None:
+ group_size = 16
+
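+        # Per-layer handling is delegated to full FP8 / NVFP4 sub-configs; module
+        # exclusion is resolved at the mixed-precision level, so the sub-configs
+        # get empty exclude lists.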
+ fp8_config = ModelOptFp8Config(
+ quant_method="FP8",
+ is_checkpoint_fp8_serialized=True,
+ kv_cache_quant_method=kv_cache_quant_method,
+ exclude_modules=[],
+ )
+ nvfp4_config = ModelOptNvFp4Config(
+ is_checkpoint_nvfp4_serialized=True,
+ kv_cache_quant_algo=kv_cache_quant_method,
+ exclude_modules=[],
+ group_size=group_size,
+ )
+
+ return cls(
+ kv_cache_quant_method=kv_cache_quant_method,
+ exclude_modules=exclude_modules,
+ quantized_layers=quantized_layers,
+ fp8_config=fp8_config,
+ nvfp4_config=nvfp4_config,
+ )
+
+ def _resolve_quant_algo(self, prefix: str) -> str | None:
+ """Look up the quant_algo for a vLLM-side layer prefix.
+
+ Tries three strategies in order:
+ 1. Direct lookup in ``quantized_layers``.
+ 2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``).
+ 3. Prefix-based lookup for FusedMoE (any child key starts with
+ ``prefix + "."``).
+
+ Returns the upper-cased quant_algo string, or *None* if the prefix
+ is not found.
+ """
+ # 1. Direct lookup
+ if prefix in self.quantized_layers:
+ return self.quantized_layers[prefix]["quant_algo"].upper()
+
+ # 2. Packed / fused layer lookup
+ proj_name = prefix.rsplit(".", 1)[-1]
+ if self.packed_modules_mapping and proj_name in self.packed_modules_mapping:
+ algos: set[str] = set()
+ base = prefix.rsplit(".", 1)[0]
+ for shard_name in self.packed_modules_mapping[proj_name]:
+ shard_prefix = f"{base}.{shard_name}"
+ if shard_prefix in self.quantized_layers:
+ algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper())
+ if len(algos) == 1:
+ return algos.pop()
+ if len(algos) > 1:
+ raise ValueError(
+ f"Mixed quant_algo within fused layer {prefix}: "
+ f"{algos}. All shards must use the same quantization."
+ )
+
+ # 3. Prefix-based lookup (for FusedMoE / parent modules)
+ prefix_dot = prefix + "."
+ for key, info in self.quantized_layers.items():
+ if key.startswith(prefix_dot):
+ return info["quant_algo"].upper()
+
+ return None
+
+ def get_quant_method(
+ self, layer: torch.nn.Module, prefix: str
+ ) -> "QuantizeMethodBase | None":
+ """Return quantize-method based on layer."""
+ # KV-cache quantization
+ if isinstance(layer, Attention):
+ if self.kv_cache_quant_method:
+ return ModelOptFp8KVCacheMethod(self)
+ return None
+
+ # Excluded layers
+ if self.is_layer_excluded(prefix):
+ if isinstance(layer, LinearBase):
+ return UnquantizedLinearMethod()
+ return None
+
+ quant_algo = self._resolve_quant_algo(prefix)
+
+ if isinstance(layer, LinearBase):
+ if quant_algo == "FP8":
+ return ModelOptFp8LinearMethod(self.fp8_config)
+ if quant_algo == "NVFP4":
+ return ModelOptNvFp4LinearMethod(self.nvfp4_config)
+ # Layer not in quantized_layers — leave unquantized
+ return UnquantizedLinearMethod()
+
+ if isinstance(layer, FusedMoE):
+ if quant_algo == "FP8":
+ return ModelOptFp8MoEMethod(
+ quant_config=self.fp8_config,
+ moe_config=layer.moe_config,
+ )
+ if quant_algo == "NVFP4":
+ return ModelOptNvFp4FusedMoE(
+ quant_config=self.nvfp4_config,
+ moe_config=layer.moe_config,
+ )
+ return None
+
+ return None
+
+ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+ super().apply_vllm_mapper(hf_to_vllm_mapper)
+ if self.quantized_layers:
+ self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index d81f0f8..97d6017 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -6,6 +6,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
@@ -77,6 +78,8 @@ class Mxfp4Backend(Enum):
# Triton Backend
TRITON = 6
+ CK = 7
+
def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
"""
@@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
elif current_platform.is_xpu():
logger.info_once("Using xpu backend on XPU")
return Mxfp4Backend.MARLIN
- elif current_platform.is_rocm() and has_triton_kernels():
- logger.info_once("Using Triton backend")
- return Mxfp4Backend.TRITON
+ elif current_platform.is_rocm():
+ from vllm.platforms.rocm import on_gfx950
+
+ if rocm_aiter_ops.is_enabled() and on_gfx950():
+ logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)")
+ return Mxfp4Backend.CK
+ elif has_triton_kernels():
+ logger.info_once("Using Triton backend")
+ return Mxfp4Backend.TRITON
return Mxfp4Backend.NONE
@@ -257,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
- self.moe_mk: mk.FusedMoEModularKernel | None = None
+ self.moe_kernel: mk.FusedMoEKernel | None = None
def create_weights(
self,
@@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.intermediate_size = intermediate_size_per_partition_after_pad
self.hidden_size = hidden_size
+ self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0)
+ self.intermediate_pad = (
+ intermediate_size_per_partition_after_pad - intermediate_size_per_partition
+ )
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(
torch.zeros(
@@ -427,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
- self.moe_mk = mk.FusedMoEModularKernel(
+ self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
MarlinExperts(
self.moe,
@@ -776,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
assert prepare_finalize is not None
- self.moe_mk = mk.FusedMoEModularKernel(
+ self.moe_kernel = mk.FusedMoEKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
@@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
),
shared_experts=None,
)
+ elif self.mxfp4_backend == Mxfp4Backend.CK:
+ if layer.w13_bias is not None:
+ layer.w13_bias.data = layer.w13_bias.data.to(torch.float32)
+            if layer.w2_bias is not None:
+ layer.w2_bias.data = layer.w2_bias.data.to(torch.float32)
+
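+            # Reorder the fused gate/up halves of w13 (and its scales) into the
+            # interleaving the CK kernel expects, then shuffle weights and scales
+            # into the AITER a16w4 layout (layout inferred from the reshapes below).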
+ e, n, k = layer.w13_weight.shape
+ layer.w13_weight.view(torch.uint8).copy_(
+ layer.w13_weight.data.view(torch.uint8)
+ .view(e, n // 2, 2, k)
+ .permute(0, 2, 1, 3)
+ .contiguous()
+ .view(e, n, k)
+ )
+ layer.w13_weight_scale.data = (
+ layer.w13_weight_scale.data.view(e, n // 2, 2, -1)
+ .permute(0, 2, 1, 3)
+ .contiguous()
+ .view(e, n, -1)
+ )
+ layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2)
+ layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2)
+
+ layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
+ layer.w13_weight, 16, True
+ )
+ shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4(
+ layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]),
+ self.num_experts,
+ True,
+ )
+
+ layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(
+ layer.w2_weight, 16, False
+ )
+ shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4(
+ layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]),
+ self.num_experts,
+ False,
+ )
+
+ layer.w13_bias.data = (
+ layer.w13_bias.data.view(-1, n // 2, 2)
+ .permute(0, 2, 1)
+ .contiguous()
+ .view(-1, n)
+ )
+
+ layer.w13_weight_scale = torch.nn.Parameter(
+ shuffled_w13_scale, requires_grad=False
+ )
+ layer.w2_weight_scale = torch.nn.Parameter(
+ shuffled_w2_scale, requires_grad=False
+ )
+ # replace_parameter(layer, "w13_bias", w13_bias)
+ # replace_parameter(layer, "w13_weight_scale", w13_weight_scale)
+ # replace_parameter(layer, "w2_weight_scale", w2_weight_scale)
+ # replace_parameter(layer, "w13_weight", w13_weight)
+ # replace_parameter(layer, "w2_weight", w2_weight)
+
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
@@ -792,18 +865,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w13_bias = Parameter(w13_bias, requires_grad=False)
layer.w2_bias = Parameter(w2_bias, requires_grad=False)
-
# Ideally we'd use FusedMoEModularKernel.prepare_finalize object
# (stored in self.fused_experts) to determine if the MoE has a
# batched activation format. As self.fused_experts is not
# initialized at this point, we resort to checking the MoE config
# directly.
- is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels
+ is_batched_moe = self.moe.use_deepep_ll_kernels
if is_batched_moe:
num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
else:
num_warps = 8
-
w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
layer.w13_weight, layer.w13_weight_scale, num_warps
)
@@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.w2_precision_config = PrecisionConfig(
weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex)
)
-
self.w13_weight = w13_weight
self.w2_weight = w2_weight
del layer.w13_weight
del layer.w2_weight
layer.w13_weight = w13_weight
layer.w2_weight = w2_weight
+
else:
raise ValueError(
f"Unsupported mxfp4_backend: {self.mxfp4_backend}: "
@@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16,
+ Mxfp4Backend.CK,
]:
return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias,
@@ -882,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
def select_gemm_impl(
self,
- prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+ prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular,
layer: torch.nn.Module,
- ) -> mk.FusedMoEPermuteExpertsUnpermute:
+ ) -> mk.FusedMoEExpertsModular:
if (
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
@@ -929,10 +1001,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
@property
def is_monolithic(self) -> bool:
+ if self.moe.is_lora_enabled:
+ return False
return (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM
or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.TRITON
+ or self.mxfp4_backend == Mxfp4Backend.CK
)
def apply(
@@ -968,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
or self.mxfp4_backend == Mxfp4Backend.MARLIN
)
- assert self.moe_mk is not None
- return self.moe_mk(
+ assert self.moe_kernel is not None
+ return self.moe_kernel.apply(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -1054,6 +1129,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
tune_max_num_tokens=max(self.max_capture_size, 1),
)[0]
return trtllm_gen_output
+ elif self.mxfp4_backend == Mxfp4Backend.CK:
+ topk_weights, topk_ids = rocm_aiter_ops.fused_topk(
+ x, router_logits, layer.top_k, True
+ )
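+            # Padding amounts are rounded down to multiples of 128 (hidden) and
+            # 64 (intermediate, doubled presumably for the fused gate/up projection)
+            # before being handed to the AITER kernel.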
+ output = rocm_aiter_ops.fused_moe(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights,
+ topk_ids,
+ activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"),
+ quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"),
+ w1_scale=layer.w13_weight_scale,
+ w2_scale=layer.w2_weight_scale,
+ doweight_stage1=False,
+ hidden_pad=self.hidden_pad // 128 * 128,
+ intermediate_pad=self.intermediate_pad // 64 * 64 * 2,
+ bias1=layer.w13_bias,
+ bias2=layer.w2_bias,
+ )
+ return output
elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
triton_kernel_moe_forward,
@@ -1162,7 +1258,7 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
topk_weights=routing_weights,
topk_ids=selected_experts,
n_experts_per_token=layer.top_k,
- activation=layer.activation,
+ activation=layer.activation.value,
num_experts=layer.local_num_experts,
is_mxfp4=True,
)
diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py
index 76410f2..5d7b7b5 100644
--- a/vllm/model_executor/layers/quantization/ptpc_fp8.py
+++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py
@@ -7,7 +7,6 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
-from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
@@ -26,10 +25,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
)
from vllm.platforms import current_platform
-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
-logger = init_logger(__name__)
-
class PTPCFp8Config(Fp8Config):
"""Config class for Per-Token-Per-Channel Dynamic Quantization Fp8."""
diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py
index 36f20c8..dedc7db 100644
--- a/vllm/model_executor/layers/quantization/quark/quark.py
+++ b/vllm/model_executor/layers/quantization/quark/quark.py
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import (
)
from vllm.model_executor.models.utils import WeightsMapper
from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_config
if TYPE_CHECKING:
from vllm.model_executor.models.utils import WeightsMapper
@@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig):
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
+ self.dynamic_mxfp4_quant = False
+
+ def maybe_update_config(self, model_name: str, revision: str | None = None):
+ self.hf_config = get_config(
+ model=model_name,
+            trust_remote_code=False,  # or get from model_config if available
+ revision=revision,
+ config_format="auto",
+ )
+
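+        # Enable on-the-fly MXFP4 weight quantization only for FP4 DeepSeek-V3
+        # checkpoints; all other models keep the default Quark behaviour.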
+ quant_config = getattr(self.hf_config, "quantization_config", None)
+ if quant_config is not None:
+ quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"]
+ model_type = self.hf_config.model_type
+ if quant_dtype == "fp4" and model_type == "deepseek_v3":
+ self.dynamic_mxfp4_quant = True
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
@@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig):
if should_ignore_layer(
prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
):
- return UnquantizedLinearMethod()
+ if (
+                "self_attn" not in prefix  # only quantize attention projections
+                or not getattr(self, "dynamic_mxfp4_quant", False)
+                or not isinstance(layer, LinearBase)  # ignore non-linear layers
+ ):
+ return UnquantizedLinearMethod()
+
+ scheme = self.get_scheme(
+ layer=layer,
+ layer_name=prefix,
+ dynamic_mxfp4_quant=True,
+ )
+ layer.scheme = scheme
+ return QuarkLinearMethod(self)
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
@@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig):
)
return global_quant_config
- def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
+ def _get_scheme_from_config(
+ self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False
+ ) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
@@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig):
input_symmetric=input_config.get("symmetric"),
)
elif self._is_w_ocp_mx_a_x(weight_config, input_config):
- return QuarkOCP_MX(weight_config, input_config)
+ return QuarkOCP_MX(
+ weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
+ )
raise NotImplementedError(
"No quark compatible scheme was found. "
@@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig):
f"Input config: {input_config}"
)
- def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
+ def get_scheme(
+ self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False
+ ) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
- scheme = self._get_scheme_from_config(layer_quant_config)
+ scheme = self._get_scheme_from_config(
+ layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant
+ )
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index 8394857..b2abbce 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -5,8 +5,8 @@ from typing import Any
import torch
-import vllm.envs as envs
from vllm import _custom_ops as ops
+from vllm import envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
+from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
@@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up
logger = init_logger(__name__)
-__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]
+__all__ = [
+ "QuarkMoEMethod",
+ "QuarkOCP_MX_MoEMethod",
+ "QuarkOCP_MX_MoEMethod_OSS",
+]
class QuarkMoEMethod(FusedMoEMethodBase):
@@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase):
"output_tensors and bias "
"quantized are not supported"
)
+
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
+
if quant_config._is_fp8_w4a8(weight_config, input_config):
return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config):
- return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
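+            # Prefer the Triton-based OSS path when the platform natively supports MX
+            # and the checkpoint uses static fp8_e4m3 activation scales; otherwise fall
+            # back to the generic OCP MX method (which may emulate in higher precision).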
+ emulate = not current_platform.supports_mx() or not (
+ rocm_aiter_ops.is_fused_moe_enabled()
+ )
+ if (
+ input_config.get("dtype") == "fp8_e4m3"
+ and not input_config.get("is_dynamic")
+ and not emulate
+ ):
+ return QuarkOCP_MX_MoEMethod_OSS(
+ weight_config, input_config, module.moe_config
+ )
+ else:
+ return QuarkOCP_MX_MoEMethod(
+ weight_config, input_config, module.moe_config
+ )
else:
raise RuntimeError("Unsupported FusedMoe scheme")
@@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
get_current_vllm_config().model_config.hf_config, "model_type", None
)
- self._emulate = (
+ self.emulate = (
not current_platform.supports_mx()
or not self.ocp_mx_scheme.startswith("w_mxfp4")
) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe)
- self.emulate = True if self.model_type == "gpt_oss" else self._emulate
-
if self.emulate:
logger.warning_once(
f"The current mode (supports_mx={current_platform.supports_mx()}, "
@@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
)
params_dtype = torch.uint8
+ self.intermediate_size_per_partition = intermediate_size_per_partition
if self.model_type == "gpt_oss":
if current_platform.is_rocm():
intermediate_size_per_partition_after_pad = round_up(
@@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
else:
intermediate_size_per_partition_after_pad = intermediate_size_per_partition
+ self.unpadded_hidden_size = extra_weight_attrs.get(
+ "unpadded_hidden_size", hidden_size
+ )
+
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
@@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
shared_experts_input: torch.Tensor | None,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
if not self.emulate:
- if (
- self.model_type == "gpt_oss"
- and self.mxfp4_backend == Mxfp4Backend.TRITON
- ):
- raise NotImplementedError(
- "Triton kernel implemented fused MoE for GPT_OSS model "
- "in Quark(MoE) format is not integrated or provided yet."
- )
+ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+ rocm_aiter_fused_experts,
+ )
- else:
- from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
- rocm_aiter_fused_experts,
- )
-
- return rocm_aiter_fused_experts(
- x,
- layer.w13_weight,
- layer.w2_weight,
- topk_weights=topk_weights,
- topk_ids=topk_ids,
- activation=layer.activation,
- quant_config=self.moe_quant_config,
- expert_map=layer.expert_map,
- )
+ return rocm_aiter_fused_experts(
+ x,
+ layer.w13_weight,
+ layer.w2_weight,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ activation=layer.activation,
+ quant_config=self.moe_quant_config,
+ expert_map=layer.expert_map,
+ )
else:
from vllm.model_executor.layers.fused_moe import fused_experts
@@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
expert_map=layer.expert_map,
quant_config=self.moe_quant_config,
)
+
+
+class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod):
+ def __init__(
+ self,
+ weight_config: dict[str, Any],
+ input_config: dict[str, Any],
+ moe: FusedMoEConfig,
+ ):
+ super().__init__(weight_config, input_config, moe)
+
+ def process_weights_after_loading(self, layer):
+ from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
+
+ w13_bias = layer.w13_bias.to(torch.float32)
+ w2_bias = layer.w2_bias.to(torch.float32)
+
+ layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False)
+ layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False)
+
+        # FIXME: num_warps may need to be adjusted based on the batch size;
+        # this only applies to batched mode.
+ if self.moe.use_ep:
+ num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8
+ else:
+ num_warps = 8
+
+ w13_weight, w13_flex, w13_scale = _swizzle_mxfp4(
+ layer.w13_weight, layer.w13_weight_scale, num_warps
+ )
+ w2_weight, w2_flex, w2_scale = _swizzle_mxfp4(
+ layer.w2_weight, layer.w2_weight_scale, num_warps
+ )
+
+ self.w13_weight_triton_tensor = w13_weight
+ self.w2_weight_triton_tensor = w2_weight
+
+ # need to delete the original weights to save memory on single GPU
+ del layer.w13_weight
+ del layer.w2_weight
+ layer.w13_weight = None
+ layer.w2_weight = None
+ torch.cuda.empty_cache()
+
+ if self.static_input_scales:
+ if layer.w13_input_scale is None or layer.w2_input_scale is None:
+ raise ValueError(
+ "QuantConfig has static quantization, but found "
+ "activation scales are None."
+ )
+ if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+ layer.w2_input_scale
+ ):
+ logger.warning_once(
+ "Found input_scales that are not equal for "
+ "fp8 MoE layer. Using the maximum across experts "
+ "for each layer."
+ )
+
+ layer.w13_input_scale = torch.nn.Parameter(
+ layer.w13_input_scale.max().to(torch.float32), requires_grad=False
+ )
+ layer.w2_input_scale = torch.nn.Parameter(
+ layer.w2_input_scale.max().to(torch.float32), requires_grad=False
+ )
+
+ from triton_kernels.numerics import InFlexData
+
+ lhs_data13 = InFlexData(scale=layer.w13_input_scale)
+ lhs_data2 = InFlexData(scale=layer.w2_input_scale)
+
+ self.w13_precision_config = PrecisionConfig(
+ weight_scale=w13_scale,
+ flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13),
+ )
+
+ self.w2_precision_config = PrecisionConfig(
+ weight_scale=w2_scale,
+ flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2),
+ )
+
+ def get_fused_moe_quant_config(
+ self, layer: torch.nn.Module
+ ) -> FusedMoEQuantConfig | None:
+ return mxfp4_w4a8_moe_quant_config(
+ w1_scale=self.w13_precision_config,
+ w2_scale=self.w2_precision_config,
+ a1_scale=layer.w13_input_scale,
+ a2_scale=layer.w2_input_scale,
+ w1_bias=layer.w13_bias,
+ w2_bias=layer.w2_bias,
+ block_shape=None,
+ )
+
+ @property
+ def is_monolithic(self) -> bool:
+ return True
+
+ def apply_monolithic(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ router_logits: torch.Tensor,
+ expert_map: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        if layer.enable_eplb:
+            raise NotImplementedError(
+                "EPLB not supported for `QuarkOCP_MX_MoEMethod_OSS` yet."
+ )
+
+ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501
+ triton_kernel_moe_forward,
+ )
+
+ return triton_kernel_moe_forward(
+ hidden_states=x,
+ w1=self.w13_weight_triton_tensor,
+ w2=self.w2_weight_triton_tensor,
+ gating_output=router_logits,
+ topk=layer.top_k,
+ renormalize=layer.renormalize,
+ global_num_experts=layer.global_num_experts,
+ expert_map=expert_map,
+ quant_config=self.moe_quant_config,
+ apply_router_weight_on_input=layer.apply_router_weight_on_input,
+ unpadded_N_w1=self.intermediate_size_per_partition * 2,
+ unpadded_K_w1=self.unpadded_hidden_size,
+ unpadded_N_w2=self.unpadded_hidden_size,
+ unpadded_K_w2=self.intermediate_size_per_partition,
+ )
diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
index c5f5012..6917bb6 100644
--- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
+++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py
@@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
)
-from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from vllm.model_executor.parameter import (
+ GroupQuantScaleParameter,
+ ModelWeightParameter,
+ PackedvLLMParameter,
+)
+from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from .quark_scheme import QuarkScheme
@@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError):
class QuarkOCP_MX(QuarkScheme):
def __init__(
- self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
+ self,
+ weight_quant_spec: dict[str, Any],
+ input_quant_spec: dict[str, Any],
+ dynamic_mxfp4_quant: bool = False,
):
self.out_dtype = torch.get_default_dtype()
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
-
+ self.dynamic_mxfp4_quant = dynamic_mxfp4_quant
self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")
@@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme):
layer.weight_scale.data, requires_grad=False
)
else:
- if self.rocm_use_aiter_fp4_asm_gemm:
+ if self.dynamic_mxfp4_quant:
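+                # Quantize the loaded full-precision weight to MXFP4 on the fly.
+                # Scales are stored transposed, as the downstream GEMM appears to expect.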
+ w_q, w_s = dynamic_mxfp4_quant(layer.weight)
+ layer.weight_scale = torch.nn.Parameter(
+ w_s.T.contiguous(), requires_grad=False
+ )
+ layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
+ elif self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data
sm, sn = weight_scale_shuffle.shape
@@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme):
weight_loader: Callable,
**kwargs,
):
- output_size_per_partition = sum(output_partition_sizes)
- layer.logical_widths = output_partition_sizes
+ if self.dynamic_mxfp4_quant:
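+            # Weights stay in the checkpoint dtype here and are quantized to MXFP4
+            # later, in process_weights_after_loading.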
+ weight = ModelWeightParameter(
+ data=torch.empty(
+ sum(output_partition_sizes),
+ input_size_per_partition,
+ dtype=params_dtype,
+ ),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
- # WEIGHT
- weight = PackedvLLMParameter(
- data=torch.empty(
- output_size_per_partition,
- self.get_packed_dim(input_size_per_partition, self.weight_dtype),
- dtype=torch.uint8,
- ),
- input_dim=1,
- output_dim=0,
- packed_dim=1,
- packed_factor=self.packed_factor,
- weight_loader=weight_loader,
- )
- layer.register_parameter("weight", weight)
+ layer.register_parameter("weight", weight)
+ set_weight_attrs(weight, kwargs)
+ else:
+ output_size_per_partition = sum(output_partition_sizes)
+ layer.logical_widths = output_partition_sizes
- # WEIGHT SCALE
- weight_scale = GroupQuantScaleParameter(
- data=torch.empty(
- output_size_per_partition,
- input_size_per_partition // OCP_MX_BLOCK_SIZE,
- dtype=torch.uint8,
- ),
- input_dim=1,
- output_dim=0,
- weight_loader=weight_loader,
- )
- layer.register_parameter("weight_scale", weight_scale)
+ # WEIGHT
+ weight = PackedvLLMParameter(
+ data=torch.empty(
+ output_size_per_partition,
+ self.get_packed_dim(input_size_per_partition, self.weight_dtype),
+ dtype=torch.uint8,
+ ),
+ input_dim=1,
+ output_dim=0,
+ packed_dim=1,
+ packed_factor=self.packed_factor,
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("weight", weight)
+
+ # WEIGHT SCALE
+ weight_scale = GroupQuantScaleParameter(
+ data=torch.empty(
+ output_size_per_partition,
+ input_size_per_partition // OCP_MX_BLOCK_SIZE,
+ dtype=torch.uint8,
+ ),
+ input_dim=1,
+ output_dim=0,
+ weight_loader=weight_loader,
+ )
+ layer.register_parameter("weight_scale", weight_scale)
def apply_weights(
self,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
index fadf56b..42677a5 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py
@@ -6,28 +6,18 @@ from typing import TYPE_CHECKING
import torch
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm import _custom_ops as ops
+import vllm.envs as envs
from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.activation import MoEActivation
-from vllm.model_executor.layers.fused_moe.config import (
- FusedMoEConfig,
- FusedMoEParallelConfig,
- RoutingMethodType,
-)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
- activation_to_flashinfer_int,
align_fp4_moe_weights_for_fi,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
swizzle_blockscale,
)
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
- QuantKey,
- kNvfp4Dynamic,
- kNvfp4Static,
-)
from vllm.platforms import current_platform
+from vllm.utils.flashinfer import (
+ has_flashinfer_cutlass_fused_moe,
+)
if TYPE_CHECKING:
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
@@ -42,92 +32,15 @@ __all__ = [
"reorder_w1w3_to_w3w1",
]
-#
-# Methods used by the oracle for kernel selection.
-#
-
-def _supports_current_device() -> bool:
- """Supports only Blackwell-family GPUs."""
- p = current_platform
- return p.is_cuda() and p.is_device_capability_family(100)
-
-
-def _supports_no_act_and_mul() -> bool:
- """Supports non-gated MoE."""
- return True
-
-
-def _supports_quant_scheme(
- weight_key: QuantKey | None,
- activation_key: QuantKey | None,
-) -> bool:
- """Supports Nvfp4 quantization."""
- SUPPORTED_W_A = [
- (kNvfp4Static, kNvfp4Dynamic),
- ]
- return (weight_key, activation_key) in SUPPORTED_W_A
-
-
-def _supports_activation(activation: MoEActivation) -> bool:
- return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
-
-
-def _supports_routing_method(
- routing_method: RoutingMethodType,
-) -> bool:
- """Monolithic kernels need to express router support."""
- # NOTE(rob): potentially allow others here. This is a conservative list.
- return routing_method in [
- RoutingMethodType.DeepSeekV3,
- RoutingMethodType.Renormalize,
- RoutingMethodType.RenormalizeNaive,
- RoutingMethodType.Llama4,
- ]
-
-
-def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
- """
- TRTLLM is a monolithic kernel that requires dispatch_router_logits() for
- the naive dispatch/combine path. DeepEP HT only implements dispatch() for
- the modular kernel path, so TRTLLM is incompatible with DeepEP HT.
- """
- return not moe_parallel_config.use_deepep_ht_kernels
-
-
-def is_supported_config_trtllm(
- moe_config: FusedMoEConfig,
- weight_key: QuantKey | None,
- activation_key: QuantKey | None,
- activation_format: mk.FusedMoEActivationFormat,
-) -> tuple[bool, str | None]:
- """
- This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config
- """
-
- def _make_reason(reason: str) -> str:
- return f"kernel does not support {reason}"
-
- if not _supports_current_device():
- return False, _make_reason(f"current device {current_platform.device_name}")
- elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()):
- return False, _make_reason("no act_and_mul MLP layer")
- elif not _supports_activation(moe_config.activation):
- return False, _make_reason(f"{moe_config.activation} activation")
- elif not _supports_quant_scheme(weight_key, activation_key):
- return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}")
- elif not _supports_parallel_config(moe_config.moe_parallel_config):
- return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}")
- elif not _supports_routing_method(moe_config.routing_method):
- return False, _make_reason(f"routing method {moe_config.routing_method}")
- elif activation_format != mk.FusedMoEActivationFormat.Standard:
- return False, _make_reason(f"activation format {activation_format}")
- elif moe_config.hidden_dim % 512 != 0:
- return False, _make_reason(
- f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}"
- )
-
- return True, None
+def is_flashinfer_fp4_cutlass_moe_available() -> bool:
+ """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
+ return (
+ envs.VLLM_USE_FLASHINFER_MOE_FP4
+ and has_flashinfer_cutlass_fused_moe()
+ and current_platform.is_cuda()
+ and current_platform.has_device_capability(100)
+ )
def reorder_w1w3_to_w3w1(
@@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe(
)
-def flashinfer_trtllm_fp4_moe(
- layer: torch.nn.Module,
- x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
- router_logits: torch.Tensor,
- top_k: int,
- activation: MoEActivation,
- global_num_experts: int,
- num_expert_group: int | None,
- topk_group: int | None,
- custom_routing_function: object | None,
- e_score_correction_bias: torch.Tensor | None,
-) -> torch.Tensor:
- """
- Apply FlashInfer TensorRT-LLM FP4 MoE kernel.
-
- Args:
- layer: The MoE layer with weights and scales
- x: Input tensor
- router_logits: Router logits for expert selection
- top_k: Number of experts to select per token
- activation: Activation function to use
- global_num_experts: Total number of experts across all ranks
- num_expert_group: Number of expert groups (for grouped routing)
- topk_group: Top-k within each group
- custom_routing_function: Custom routing function (e.g., Llama4)
- e_score_correction_bias: Optional routing bias correction
-
- Returns:
- Output tensor from the MoE layer
- """
- import flashinfer
-
- from vllm.model_executor.models.llama4 import Llama4MoE
-
- SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL]
- assert activation in SUPPORTED_ACTIVATIONS, (
- f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer "
- f"TRTLLM FP4 MoE, {activation} found instead."
- )
-
- # Quantize input to FP4
- if isinstance(x, tuple):
- hidden_states_fp4, hidden_states_scale_linear_fp4 = x
- else:
- # hidden_states is the already quantized
- (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
- x, layer.a1_gscale, is_sf_swizzled_layout=False
- )
-
- # Determine routing method type
- use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
- routing_method_type = layer.routing_method_type
- if use_llama4_routing:
- routing_method_type = flashinfer.RoutingMethodType.Llama4
-
- # Cast to Fp32 (required by kernel).
- router_logits = (
- router_logits.to(torch.float32)
- if routing_method_type == RoutingMethodType.DeepSeekV3
- else router_logits
- )
-
- # Determine activation type
- activation_type = activation_to_flashinfer_int(layer.activation)
-
- # Call TRT-LLM FP4 block-scale MoE kernel
- out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe(
- routing_logits=router_logits,
- routing_bias=e_score_correction_bias,
- hidden_states=hidden_states_fp4,
- hidden_states_scale=hidden_states_scale_linear_fp4.view(
- torch.float8_e4m3fn
- ).reshape(*hidden_states_fp4.shape[:-1], -1),
- gemm1_weights=layer.w13_weight.data,
- gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
- gemm1_bias=None,
- gemm1_alpha=None,
- gemm1_beta=None,
- gemm1_clamp_limit=None,
- gemm2_weights=layer.w2_weight.data,
- gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
- gemm2_bias=None,
- output1_scale_scalar=layer.g1_scale_c.data,
- output1_scale_gate_scalar=layer.g1_alphas.data,
- output2_scale_scalar=layer.g2_alphas.data,
- num_experts=global_num_experts,
- top_k=top_k,
- n_group=num_expert_group if num_expert_group is not None else 0,
- topk_group=topk_group if topk_group is not None else 0,
- intermediate_size=layer.intermediate_size_per_partition,
- local_expert_offset=layer.ep_rank * layer.local_num_experts,
- local_num_experts=layer.local_num_experts,
- routed_scaling_factor=None,
- routing_method_type=routing_method_type,
- do_finalize=True,
- activation_type=activation_type,
- )[0]
-
- return out
-
-
-def flashinfer_trtllm_fp4_routed_moe(
- layer: torch.nn.Module,
- x: torch.Tensor,
- topk_ids: torch.Tensor,
- topk_weights: torch.Tensor,
- top_k: int,
- activation: MoEActivation,
- global_num_experts: int,
-) -> torch.Tensor:
- """
- Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed
- input top k expert indices and scores rather than computing
- top k expert indices from scores.
-
- Args:
- layer: The MoE layer with weights and scales
- x: Input tensor
- topk_ids: Ids of selected experts
- top_k: Number of experts to select per token
- activation: Activation function to use
- global_num_experts: Total number of experts across all ranks
-
- Returns:
- Output tensor from the MoE layer
- """
- import flashinfer
-
- # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535
- assert activation == MoEActivation.SILU, (
- "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. "
- f"{activation} found instead."
- )
-
- # Pack top k ids and expert weights into a single int32 tensor, as
- # required by TRT-LLM
- packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
- torch.bfloat16
- ).view(torch.int16)
-
- if isinstance(x, tuple):
- # Hidden_states is the already quantized
- hidden_states_fp4, hidden_states_scale_linear_fp4 = x
- else:
- # Quantize input to FP4
- (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
- x, layer.a1_gscale, is_sf_swizzled_layout=False
- )
-
- # Call TRT-LLM FP4 block-scale MoE kernel
- out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
- topk_ids=packed_tensor,
- routing_bias=None,
- hidden_states=hidden_states_fp4,
- hidden_states_scale=hidden_states_scale_linear_fp4.view(
- torch.float8_e4m3fn
- ).reshape(*hidden_states_fp4.shape[:-1], -1),
- gemm1_weights=layer.w13_weight.data,
- gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn),
- gemm1_bias=None,
- gemm1_alpha=None,
- gemm1_beta=None,
- gemm1_clamp_limit=None,
- gemm2_weights=layer.w2_weight.data,
- gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn),
- gemm2_bias=None,
- output1_scale_scalar=layer.g1_scale_c.data,
- output1_scale_gate_scalar=layer.g1_alphas.data,
- output2_scale_scalar=layer.g2_alphas.data,
- num_experts=global_num_experts,
- top_k=top_k,
- n_group=0,
- topk_group=0,
- intermediate_size=layer.intermediate_size_per_partition,
- local_expert_offset=layer.ep_rank * layer.local_num_experts,
- local_num_experts=layer.local_num_experts,
- routed_scaling_factor=None,
- routing_method_type=1,
- do_finalize=True,
- )[0]
-
- return out
-
-
def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
backend: "NvFp4MoeBackend",
layer: "FusedMoE",
@@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass(
)
)
layer.intermediate_size_per_partition = padded_intermediate
+ layer.moe_config.intermediate_size_per_partition = padded_intermediate
w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe(
w13,
diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index 3d7d8e6..a8be1d6 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
+from typing import TYPE_CHECKING
import torch
@@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.platforms import current_platform
from vllm.utils.math_utils import round_up
+if TYPE_CHECKING:
+ from flashinfer.fused_moe.core import ActivationType
+
logger = init_logger(__name__)
@@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum):
def activation_to_flashinfer_int(activation: MoEActivation) -> int:
+ return activation_to_flashinfer_type(activation).value
+
+
+def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType":
from flashinfer.fused_moe.core import ActivationType
# silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively
@@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int:
MoEActivation.GELU: ActivationType.Geglu,
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
}
- return ACTIVATION_TO_FI_ACTIVATION[activation].value
+ return ACTIVATION_TO_FI_ACTIVATION[activation]
def swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
@@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(
)
-def register_scales_for_trtllm_fp8_per_tensor_moe(
- layer: torch.nn.Module,
- w13_scale: torch.Tensor,
- w13_input_scale: torch.Tensor,
- w2_scale: torch.Tensor,
- w2_input_scale: torch.Tensor,
-) -> None:
- """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel"""
- g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi(
- w13_scale=w13_scale,
- w13_input_scale=w13_input_scale,
- w2_scale=w2_scale,
- w2_input_scale=w2_input_scale,
- )
- layer.w2_input_scale_inv = 1.0 / w2_input_scale
- layer.output1_scales_gate_scalar = g1_alphas
-
- if layer.activation.is_gated:
- layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv
- else:
- layer.output1_scales_scalar = (
- torch.ones_like(g1_alphas) * layer.w2_input_scale_inv
- )
- layer.output2_scales_scalar = g2_alphas
-
-
-def apply_fi_trtllm_fp8_per_tensor_moe(
- layer: torch.nn.Module,
- hidden_states: torch.Tensor,
- router_logits: torch.Tensor,
- routing_bias: torch.Tensor | None,
- top_k: int,
- num_expert_group: int | None,
- topk_group: int | None,
- global_num_experts: int,
- apply_router_weight_on_input: bool,
-) -> torch.Tensor:
- from flashinfer.fused_moe import RoutingMethodType
-
- import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401
- from vllm.model_executor.models.llama4 import Llama4MoE
-
- # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe
- assert (
- hasattr(layer, "output1_scales_scalar")
- and hasattr(layer, "output1_scales_gate_scalar")
- and hasattr(layer, "output2_scales_scalar")
- )
-
- if layer.routing_method_type == RoutingMethodType.Llama4:
- assert (
- not layer.renormalize
- and layer.custom_routing_function == Llama4MoE.custom_routing_function
- ), (
- "FusedMoE flashinfer kernels with Llama4 routing method are only "
- "supported for Llama4"
- )
- else:
- assert layer.custom_routing_function is None, (
- "Custom routing function is only supported for Llama4"
- )
- activation_type = activation_to_flashinfer_int(layer.activation)
-
- return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe(
- routing_logits=router_logits,
- routing_bias=routing_bias,
- hidden_states=hidden_states,
- input_scale=layer.w13_input_scale,
- gemm1_weights=layer.w13_weight,
- gemm2_weights=layer.w2_weight,
- output1_scales_scalar=layer.output1_scales_scalar,
- output1_scales_gate_scalar=layer.output1_scales_gate_scalar,
- output2_scales_scalar=layer.output2_scales_scalar,
- num_experts=global_num_experts,
- top_k=top_k,
- num_expert_group=num_expert_group,
- topk_group=topk_group,
- intermediate_size=layer.intermediate_size_per_partition,
- local_expert_offset=layer.ep_rank * layer.local_num_experts,
- local_num_experts=layer.local_num_experts,
- use_routing_scales_on_input=apply_router_weight_on_input,
- routing_method_type=layer.routing_method_type,
- activation_type=activation_type,
- )
-
-
-def make_fp8_moe_alpha_scales_for_fi(
- w13_scale: torch.Tensor,
- w13_input_scale: torch.Tensor,
- w2_scale: torch.Tensor,
- w2_input_scale: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
- g1_alphas = (w13_scale * w13_input_scale).squeeze()
- g2_alphas = (w2_scale * w2_input_scale).squeeze()
-
- return g1_alphas, g2_alphas
-
-
def get_flashinfer_moe_backend() -> FlashinferMoeBackend:
backend_map = {
"throughput": FlashinferMoeBackend.CUTLASS,
@@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi(
min_alignment,
)
layer.intermediate_size_per_partition = new_intermediate
+ layer.moe_config.intermediate_size_per_partition = new_intermediate
# FI kernels require W31 layout rather than W13.
if layer.moe_config.is_act_and_mul:
@@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi(
w13_scale = swap_w13_to_w31(w13_scale)
# FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle
- # and registration of alpha scales. Note that we do not register
- # as nn.Parameters since they are not needed for weight-reloading.
+ # and registration of alpha scales.
if is_trtllm and not block_quant:
assert w13_input_scale is not None
assert w2_input_scale is not None
rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated)
- register_scales_for_trtllm_fp8_per_tensor_moe(
- layer,
- w13_scale=w13_scale,
- w13_input_scale=w13_input_scale,
- w2_scale=w2_scale,
- w2_input_scale=w2_input_scale,
- )
# Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel.
# Some FP8 models have near-zero block scales (~1e-23) for dead/unused
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index ee3f2ce..ccd5cce 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -53,7 +53,10 @@ logger = init_logger(__name__)
def is_fp8(x: torch.dtype | torch.Tensor) -> bool:
if isinstance(x, torch.Tensor):
x = x.dtype
- return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
+ try:
+ return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
+    except (AttributeError, TypeError):  # float8 dtypes may be missing in this build
+ return False
# We need to pass in the is_hopper flag as argument because the function
diff --git a/vllm/model_executor/layers/quantization/utils/gguf_utils.py b/vllm/model_executor/layers/quantization/utils/gguf_utils.py
new file mode 100644
index 0000000..79b34e2
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/utils/gguf_utils.py
@@ -0,0 +1,373 @@
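+"""Pure-Python (NumPy) GGML/GGUF dequantization helpers plus an AWQ repacking utility."""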
+import torch
+import numpy as np
+from gguf.constants import GGMLQuantizationType
+
+def get_awq_format(w, group_size=128, w_bit=4):
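+    """Fake-quantize ``w`` into 4-bit groups of ``group_size`` and repack it into
+    AWQ-style ``(qweight, qzeros, scales)`` tensors (eight 4-bit values per int32 word)."""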
+ org_w_shape = w.shape
+ ori_w_dtype = torch.get_default_dtype()
+ assert w_bit == 4
+ assert w.shape[1] % group_size == 0
+
+ in_features = org_w_shape[1]
+ w = w.reshape(-1, group_size)
+ assert torch.isnan(w).sum() == 0
+
+ max_val = w.amax(dim=1, keepdim=True)
+ min_val = w.amin(dim=1, keepdim=True)
+ max_int = 2**w_bit - 1
+ min_int = 0
+ scales = (max_val - min_val).clamp(min=1e-5) / max_int
+ zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+ w = (
+ torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
+ ) * scales
+ zeros = zeros.view(org_w_shape[0], -1)
+ scales = scales.view(org_w_shape[0], -1)
+ w = w.reshape(org_w_shape)
+ assert torch.isnan(scales).sum() == 0
+ assert torch.isnan(w).sum() == 0
+
+ scales = scales.t().contiguous() # input // group, o
+ zeros = zeros.t().contiguous() # input // group, o
+
+ # from auto awq
+ scale_zeros = zeros * scales
+ scales = scales.clone().to(ori_w_dtype)
+
+ pack_num = 32 // w_bit
+ intweight = []
+ for idx in range(in_features):
+ intweight.append(
+ torch.round(
+ (w[:, idx] + scale_zeros[idx // group_size])
+ / scales[idx // group_size]
+ ).to(torch.int)[:, None]
+ )
+ intweight = torch.cat(intweight, dim=1)
+ intweight = intweight.t().contiguous()
+ intweight = intweight.to(dtype=torch.int32)
+
+ qweight = torch.zeros(
+ (intweight.shape[0], intweight.shape[1] // 32 * w_bit),
+ dtype=torch.int32,
+ device=intweight.device,
+ )
+
+ for col in range(intweight.shape[1] // pack_num):
+        order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ interleave order (w_bit is asserted to be 4)
+ for i in range(pack_num):
+ qweight_col = intweight[:, col * pack_num + order_map[i]]
+ qweight[:, col] |= qweight_col << (i * w_bit)
+
+ zeros = zeros.to(dtype=torch.int32, device=qweight.device)
+
+ qzeros = torch.zeros(
+ (zeros.shape[0], zeros.shape[1] // 32 * w_bit),
+ dtype=torch.int32,
+ device=zeros.device,
+ )
+
+ for col in range(zeros.shape[1] // pack_num):
+        order_map = [0, 2, 4, 6, 1, 3, 5, 7]  # AWQ interleave order (w_bit is asserted to be 4)
+ for i in range(pack_num):
+ qzero_col = zeros[:, col * pack_num + order_map[i]]
+ qzeros[:, col] |= qzero_col << (i * w_bit)
+
+ return qweight, qzeros, scales
+
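+# Bytes per GGML quantization block: Q4_0/Q5_0/Q8_0 store 32 weights per block,
+# the K-quants and IQ4_XS store 256; the F32/F16 entries are bytes per element.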
+GGML_BLOCK_SIZES = {
+ "F32": 4,
+ "F16": 2,
+ "Q4_0": 2 + 16,
+ "Q5_0": 2 + 4 + 16,
+ "Q8_0": 2 + 32,
+ "Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
+ "Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
+ "Q4_K": 2 + 2 + 12 + 256 // 2,
+ "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
+ "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
+ "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64,
+}
+
+def dequantize_f32(data):
+ return np.frombuffer(data, dtype=np.float32)
+
+def dequantize_f16(data):
+ return np.frombuffer(data, dtype=np.float16)
+
+def dequantize_q4_0(data):
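+    # Block layout: one fp16 scale followed by 16 bytes of packed 4-bit values
+    # (32 weights), dequantized as scale * (q - 8).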
+ num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"]
+
+ scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32)
+ qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:]
+
+ return np.concatenate([
+ scales * ((qs & 0xf).astype(np.int8) - 8),
+ scales * ((qs >> 4).astype(np.int8) - 8),
+ ], axis=1)
+
+def dequantize_q5_0(data):
+ num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"]
+
+ scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32)
+ qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4]
+ qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:]
+
+ bits = np.unpackbits(qh, axis=-1, bitorder="little")
+
+ x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16
+ x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16
+
+ return np.concatenate([
+ scales * x0,
+ scales * x1,
+ ], axis=1)
+
+def dequantize_q8_0(data):
+ num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
+
+ scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
+ qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
+ return scales * qs
+
+def dequantize_q2_k(data):
+ block_size = GGML_BLOCK_SIZES["Q2_K"]
+ num_blocks = len(data) // block_size
+
+ data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
+ data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
+
+ dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
+ d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
+ scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
+ qs = data_u8[:, 16:80].reshape(num_blocks, 64)
+
+ tmp = np.stack([
+ qs[:, 00:16] >> 0,
+ qs[:, 16:32] >> 0,
+ qs[:, 00:16] >> 2,
+ qs[:, 16:32] >> 2,
+ qs[:, 00:16] >> 4,
+ qs[:, 16:32] >> 4,
+ qs[:, 00:16] >> 6,
+ qs[:, 16:32] >> 6,
+ qs[:, 32:48] >> 0,
+ qs[:, 48:64] >> 0,
+ qs[:, 32:48] >> 2,
+ qs[:, 48:64] >> 2,
+ qs[:, 32:48] >> 4,
+ qs[:, 48:64] >> 4,
+ qs[:, 32:48] >> 6,
+ qs[:, 48:64] >> 6,
+ ], axis=1)
+
+ return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
+
+
+def dequantize_q3_k(data):
+ block_size = GGML_BLOCK_SIZES["Q3_K"]
+ num_blocks = len(data) // block_size
+
+ data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
+ data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
+
+ d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
+ bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
+ bits = 4 ^ (bits << 2)
+ qs = data_u8[:, 32:32 + 64].astype(np.int16)
+ a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
+ scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
+ scales[:, 0] = (a & 15) | ((c & 3) << 4)
+ scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
+ scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
+ scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
+ scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
+
+ return d * (scales - 32) * np.stack([
+ (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
+ (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
+ (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
+ (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
+ (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
+ (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
+ (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
+ (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
+ (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
+ (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
+ (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
+ (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
+ (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
+ (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
+ (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
+ (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
+ ], axis=1)
+
+def dequantize_q4_k(data, device=None):
+ block_size = GGML_BLOCK_SIZES["Q4_K"]
+ num_blocks = len(data) // block_size
+ data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
+ data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
+ # Casting to float32 because float16 is very slow on CPU
+ scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
+ scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
+ qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
+ qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
+ # Dequantize scales and offsets (6 bits and 4 + 2 bits)
+ factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
+ offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
+ # Interleave low and high quantized bits
+ qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
+ # Dequantize final weights using scales and offsets
+ weight = factors * qs2 - offsets
+ if device is None:
+ return weight
+ return torch.from_numpy(weight).to(device=device)
+
+def dequantize_q5_k(data):
+ block_size = GGML_BLOCK_SIZES["Q5_K"]
+ num_blocks = len(data) // block_size
+
+ data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
+ data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
+
+ d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
+ dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
+ scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
+ qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)
+ qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)
+
+ bits = np.unpackbits(qh, axis=-1, bitorder="little")
+
+ qs_hi_4 = qs >> 4
+ qs_lo_4 = qs & 15
+
+ scales_lo_6 = scales[:, :8] & 63
+ scales_hi_6 = scales[:, :8] >> 6
+ scales_lo_4 = scales[:, 8:] & 15
+ scales_hi_4 = scales[:, 8:] >> 4
+
+ m1 = dmin * scales_lo_6[:, 4]
+ m2 = dmin * scales_lo_6[:, 5]
+ m3 = dmin * scales_lo_6[:, 6]
+ m4 = dmin * scales_lo_6[:, 7]
+ m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
+ m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
+ m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
+ m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
+
+ d1 = d * scales_lo_6[:, 0]
+ d2 = d * scales_lo_6[:, 1]
+ d3 = d * scales_lo_6[:, 2]
+ d4 = d * scales_lo_6[:, 3]
+ d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
+ d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
+ d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
+ d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
+
+ return np.concatenate([
+ d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
+ d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
+ d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
+ d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
+ d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
+ d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
+ d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
+ d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
+ ], axis=1)
+
+def dequantize_q6_k(data, device = None):
+ block_size = GGML_BLOCK_SIZES["Q6_K"]
+ num_blocks = len(data) // block_size
+
+ data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
+ data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
+ data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
+
+ scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
+ # TODO use uint8 and cast later?
+ ql = data_u8[:, :128].astype(np.int16)
+ qh = data_u8[:, 128:192].astype(np.int16)
+ sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
+
+ # Unpack bits, subtraction requires signed data type
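+    # NOTE: "lo | hi - 32" parses as "lo | (hi - 32)"; because lo and hi occupy
+    # disjoint bits and the arrays are signed int16, this equals (lo | hi) - 32.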
+ q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
+ q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
+ q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
+ q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
+ q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
+ q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
+ q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
+ q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
+
+ # Dequantize
+ weight = scales * np.concatenate([
+ sc[:, 0] * q1[:, :16],
+ sc[:, 1] * q1[:, 16:],
+ sc[:, 2] * q2[:, :16],
+ sc[:, 3] * q2[:, 16:],
+ sc[:, 4] * q3[:, :16],
+ sc[:, 5] * q3[:, 16:],
+ sc[:, 6] * q4[:, :16],
+ sc[:, 7] * q4[:, 16:],
+ sc[:, 8] * q5[:, :16],
+ sc[:, 9] * q5[:, 16:],
+ sc[:, 10] * q6[:, :16],
+ sc[:, 11] * q6[:, 16:],
+ sc[:, 12] * q7[:, :16],
+ sc[:, 13] * q7[:, 16:],
+ sc[:, 14] * q8[:, :16],
+ sc[:, 15] * q8[:, 16:],
+ ], axis=1)
+
+ if device is None:
+ return weight
+ return torch.from_numpy(weight).to(device=device)
+
+QK_K = 256
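+# Non-linear IQ4 codebook: maps each 4-bit index to its int8 reconstruction value.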
+kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8)
+
+def dequantize_iq4_xs(data):
+ block_size = GGML_BLOCK_SIZES["IQ4_XS"]
+ num_blocks = len(data) // block_size
+
+ d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1)
+ scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1)
+ data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:]
+ scales_l = data_u8[:, :4].reshape(num_blocks, 4)
+ qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8)
+
+ ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8)
+ for ib in range(QK_K // 32):
+ ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4)
+
+ dl = (d * (ls - 32)).reshape(num_blocks, -1, 1)
+
+ qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf
+ qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4
+
+ y = np.zeros((num_blocks, QK_K), dtype=np.float32)
+ for ib in range(QK_K // 32):
+ y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]]
+ y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]]
+
+ return y.flatten()
+
+GGML_DEQUANTIZE = {
+ int(GGMLQuantizationType.F32): dequantize_f32,
+ int(GGMLQuantizationType.F16): dequantize_f16,
+ int(GGMLQuantizationType.Q4_0): dequantize_q4_0,
+ int(GGMLQuantizationType.Q5_0): dequantize_q5_0,
+ int(GGMLQuantizationType.Q8_0): dequantize_q8_0,
+ int(GGMLQuantizationType.Q2_K): dequantize_q2_k,
+ int(GGMLQuantizationType.Q3_K): dequantize_q3_k,
+ int(GGMLQuantizationType.Q4_K): dequantize_q4_k,
+ int(GGMLQuantizationType.Q5_K): dequantize_q5_k,
+ int(GGMLQuantizationType.Q6_K): dequantize_q6_k,
+ int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs,
+}
+
+
+def dequant_gguf(data, type, shape):
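+    """Dequantize a raw GGML-quantized buffer of GGMLQuantizationType ``type`` and view it as ``shape``."""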
+ values = GGML_DEQUANTIZE[type](data)
+ values = torch.from_numpy(values).view(shape)
+ return values
\ No newline at end of file
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index c114772..23ccfc5 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -255,18 +255,6 @@ def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tenso
return w2_packed.size(1) * marlin_tile_size
-def marlin_make_workspace(
- output_size_per_partition: int, device: torch.device
-) -> torch.Tensor:
- max_workspace_size = (
- output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
- ) * GPTQ_MARLIN_MAX_PARALLEL
-
- return torch.zeros(
- max_workspace_size, dtype=torch.int, device=device, requires_grad=False
- )
-
-
def marlin_make_workspace_new(
device: torch.device, max_blocks_per_sm: int = 1
) -> torch.Tensor:
@@ -297,12 +285,6 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
)
-def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
- return torch.nn.Parameter(
- torch.empty(0, dtype=torch.int, device=device), requires_grad=False
- )
-
-
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
return g_idx[g_idx_sort_indices], g_idx_sort_indices
diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
index 9dbfc6e..1bd2d75 100644
--- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
@@ -175,7 +175,7 @@ try:
op_func=_dequant_mxfp4,
fake_impl=_dequant_mxfp4_fake,
)
- dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4
+ dequant_mxfp4 = None
except AttributeError as error:
raise error
@@ -185,6 +185,6 @@ try:
op_func=_quant_dequant_mxfp4,
fake_impl=_quant_dequant_mxfp4_fake,
)
- quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4
+ quant_dequant_mxfp4 = None
except AttributeError as error:
raise error
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index 12a1799..5709bfb 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -271,12 +271,12 @@ def scaled_quantize(
If None, uses input dtype. Use torch.float32 for higher precision.
"""
group_shape = _normalize_quant_group_shape(x, group_shape)
- assert quant_dtype.is_floating_point, (
- "currently `scaled_quantize` only supports floating point dtypes "
- "but could be extended to support other dtypes"
- )
+    # NOTE: unlike upstream, integer quant dtypes are accepted here; the
+    # float-only assert is intentionally relaxed.
- finfo = torch.finfo(quant_dtype)
+ finfo = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype)
# Convert to compute dtype if specified
x_compute = x if compute_dtype is None else x.to(compute_dtype)
diff --git a/vllm/model_executor/layers/quantization/w8a16.py b/vllm/model_executor/layers/quantization/w8a16.py
new file mode 100644
index 0000000..6c42ce7
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/w8a16.py
@@ -0,0 +1,114 @@
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization.base_config import (
+ QuantizationConfig)
+from vllm.model_executor.parameter import (GroupQuantScaleParameter,
+ PackedvLLMParameter)
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class W8a16Config(QuantizationConfig):
+ """Config class for W8a16.
+
+ """
+
+    def __init__(self) -> None:
+        pass
+
+ def __repr__(self) -> str:
+ return ("W8a16Config")
+
+ def get_name(self) -> str:
+ return "w8a16"
+
+ def get_supported_act_dtypes(self) -> List[torch.dtype]:
+ return [torch.half, torch.bfloat16]
+
+ def get_min_capability(self) -> int:
+ return 75
+
+ @staticmethod
+ def get_config_filenames():
+ return []
+
+ @classmethod
+ def from_config(cls, config: Dict[str, Any]) -> "W8a16Config":
+ return cls()
+
+ def get_quant_method(self, layer: torch.nn.Module,
+ prefix: str) -> Optional["W8a16LinearMethod"]:
+ if isinstance(layer, LinearBase):
+ return W8a16LinearMethod(self)
+ return None
+
+
+ def get_scaled_act_names(self) -> List[str]:
+ return []
+
+
+class W8a16LinearMethod(LinearMethodBase):
+ """Linear method for w8a16.
+
+ """
+
+ def __init__(self, quant_config: W8a16Config):
+ self.quant_config = quant_config
+
+ def create_weights(self, layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: List[int], input_size: int,
+ output_size: int, params_dtype: torch.dtype,
+ **extra_weight_attrs):
+ output_size_per_partition = sum(output_partition_sizes)
+ weight = Parameter(
+ torch.empty(
+ output_size_per_partition,
+ input_size_per_partition,
+ dtype=torch.int8,
+ ),
+ requires_grad=False,
+ )
+ set_weight_attrs(
+ weight, {
+ "input_dim": 1,
+ "output_dim": 0,
+ })
+
+ scales = Parameter(
+ torch.empty(
+ 1,
+ output_size_per_partition,
+ dtype=params_dtype,
+ ),
+ requires_grad=False,
+ )
+ set_weight_attrs(scales, {
+ "input_dim": None,
+ "output_dim": 1,
+ })
+
+ layer.register_parameter("weight", weight)
+ set_weight_attrs(weight, extra_weight_attrs)
+ layer.register_parameter("scales", scales)
+ set_weight_attrs(scales, extra_weight_attrs)
+
+
+ def apply(self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ qweight = layer.weight
+ scales = layer.scales
+ out_shape = (x.shape[:-1] + (qweight.shape[-2],))
+ reshaped_x = x.reshape(-1, x.shape[-1])
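+        # ops.linear_w8a16: int8-weight GEMM ("TN" layout) against fp16/bf16
+        # activations, scaled per output channel.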
+ out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN")
+ if bias is not None:
+ out = out + bias
+ return out.reshape(out_shape)
\ No newline at end of file
diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index 1374334..0a8fd19 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -227,6 +227,7 @@ class RotaryEmbedding(RotaryEmbeddingBase):
self.head_size,
cos_sin_cache,
self.is_neox_style,
+ self.rotary_dim,
)
return query, key
diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py
index 54c6684..2668a70 100644
--- a/vllm/model_executor/layers/rotary_embedding/common.py
+++ b/vllm/model_executor/layers/rotary_embedding/common.py
@@ -229,22 +229,7 @@ class ApplyRotaryEmb(CustomOp):
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
- # from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
return self.forward_native(x, cos, sin)
- x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin)
-
- """
- Arguments of apply_rotary_emb() in vllm_flash_attn:
- x: [batch_size, seq_len, nheads, headdim]
- cos, sin: [seqlen_rotary, rotary_dim / 2]
- interleaved: defalut as False (Neox-style).
- ...
- """
- interleaved = not self.is_neox_style
- output = apply_rotary_emb(x, cos, sin, interleaved)
-
- output = self._post_process(output, origin_shape, origin_dtype)
- return output
def forward_hip(
self,
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index c3abdc1..c263843 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -8,7 +8,7 @@ import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
-from .base import RotaryEmbeddingBase
+from .base import RotaryEmbedding
from .common import (
rotate_gptj,
rotate_neox,
@@ -23,7 +23,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
-class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
+class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
@@ -110,73 +110,3 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
sin = freqs.sin() * self.mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
-
- def forward_native(
- self,
- positions: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor | None = None,
- offsets: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- """PyTorch-native implementation equivalent to forward()."""
- assert key is not None
- cos_sin_cache = self._match_cos_sin_cache_dtype(query)
- query_rot = query[..., : self.rotary_dim]
- key_rot = key[..., : self.rotary_dim]
- if self.rotary_dim < self.head_size:
- query_pass = query[..., self.rotary_dim :]
- key_pass = key[..., self.rotary_dim :]
-
- cos_sin = cos_sin_cache[
- torch.add(positions, offsets) if offsets is not None else positions
- ]
- cos, sin = cos_sin.chunk(2, dim=-1)
- if self.is_neox_style:
- # NOTE(woosuk): Here we assume that the positions tensor has the
- # shape [batch_size, seq_len].
- cos = cos.repeat(1, 1, 2).unsqueeze(-2)
- sin = sin.repeat(1, 1, 2).unsqueeze(-2)
- else:
- cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
- sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
-
- rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
- query_rot = query_rot * cos + rotate_fn(query_rot) * sin
- key_rot = key_rot * cos + rotate_fn(key_rot) * sin
-
- if self.rotary_dim < self.head_size:
- query = torch.cat((query_rot, query_pass), dim=-1)
- key = torch.cat((key_rot, key_pass), dim=-1)
- else:
- query = query_rot
- key = key_rot
- return query, key
-
- def forward_hip(
- self,
- positions: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor | None = None,
- offsets: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- return self.forward_native(positions, query, key, offsets)
-
- def forward_cuda(
- self,
- positions: torch.Tensor,
- query: torch.Tensor,
- key: torch.Tensor | None = None,
- offsets: torch.Tensor | None = None,
- ) -> tuple[torch.Tensor, torch.Tensor | None]:
- if self.use_flashinfer:
- torch.ops.vllm.flashinfer_rotary_embedding(
- torch.add(positions, offsets) if offsets is not None else positions,
- query,
- key,
- self.head_size,
- self.cos_sin_cache,
- self.is_neox_style,
- )
- return query, key
- else:
- return self.forward_native(positions, query, key, offsets)
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index 3c946dd..27635dd 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -330,48 +330,46 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
+ from vllm import _custom_ops as ops
- cos_sin_cache = self._match_cos_sin_cache_dtype(query)
- num_tokens = positions.shape[-1]
- cos_sin = cos_sin_cache[positions]
- cos, sin = cos_sin.chunk(2, dim=-1)
- query_shape = query.shape
- key_shape = key.shape
- if positions.ndim == 2:
- assert self.mrope_section
+ self._match_cos_sin_cache_dtype(query)
+
+ if self.mrope_interleaved:
+ num_tokens = positions.shape[-1]
+ cos_sin = self.cos_sin_cache[positions]
+ cos, sin = cos_sin.chunk(2, dim=-1)
+ query_shape = query.shape
+ key_shape = key.shape
+ if positions.ndim == 2:
+ assert self.mrope_section
+ q, k = triton_mrope(
+ query,
+ key,
+ cos,
+ sin,
+ self.mrope_section,
+ self.head_size,
+ self.rotary_dim,
+ self.mrope_interleaved,
+ )
- q, k = triton_mrope(
- query,
- key,
- cos,
- sin,
- self.mrope_section,
- self.head_size,
- self.rotary_dim,
- self.mrope_interleaved,
- )
-
- return q.reshape(query_shape), k.reshape(key_shape)
-
- query = query.view(num_tokens, -1, self.head_size)
- query_rot = query[..., : self.rotary_dim]
- query_pass = query[..., self.rotary_dim :]
- query_rot = self.apply_rotary_emb(
- query_rot,
- cos,
- sin,
- )
- query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
-
- key = key.view(num_tokens, -1, self.head_size)
- key_rot = key[..., : self.rotary_dim]
- key_pass = key[..., self.rotary_dim :]
- key_rot = self.apply_rotary_emb(
- key_rot,
- cos,
- sin,
- )
- key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
+ return q.reshape(query_shape), k.reshape(key_shape)
+
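+        # Non-interleaved path: 1-D positions use the standard fused rope op;
+        # 2-D mrope positions use the fused m_rotary_embedding op (neox layout
+        # only), otherwise fall back to the native implementation.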
+ if positions.ndim == 1:
+ ops.rotary_embedding(positions, query, key, self.head_size,
+ self.cos_sin_cache, self.is_neox_style)
+ else:
+ if self.is_neox_style:
+ ops.m_rotary_embedding(positions.contiguous(), query, key, self.head_size,
+ self.cos_sin_cache,
+ torch.tensor(self.mrope_section, dtype=torch.int),
+ self.is_neox_style)
+ else:
+ query, key = self.forward_native(
+ positions, query, key
+ )
+
return query, key
def forward_cpu(
diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
index e58c978..5e519cf 100644
--- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
@@ -12,6 +12,7 @@ from .common import rotate_neox
logger = init_logger(__name__)
+import ixformer.inference.functions as ixops  # noqa: E402
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
@@ -133,27 +134,18 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
- if self.use_long_rope:
- k = self.original_max_position_embeddings
- long_prompt_offset = torch.full_like(positions, k).long()
- idx = torch.add(positions, long_prompt_offset)
- else:
- idx = positions
- idx = torch.add(idx, offsets) if offsets is not None else idx
- cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+ k = self.original_max_position_embeddings
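+    # long_prompt_offset is now a single flag (not a per-token offset): True when
+    # any position exceeds the original context length, letting the fused kernel
+    # pick the long-RoPE half of long_short_cos_sin_cache (offset by k).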
+ long_prompt_offset = torch.any(positions > k)
+
+ ixops.vllm_rotary_embedding_phi(
+ positions,
+ query,
+ key,
+ self.head_size,
+ self.long_short_cos_sin_cache,
+ long_prompt_offset,
+ k,
+ offsets
+ )
- cos, sin = cos_sin.chunk(2, dim=-1)
- cos = cos.repeat(1, 2).unsqueeze(-2)
- sin = sin.repeat(1, 2).unsqueeze(-2)
-
- query_rot = query[..., : self.rotary_dim]
- query_pass = query[..., self.rotary_dim :]
- query_rot = query_rot * cos + rotate_neox(query_rot) * sin
- query = torch.cat((query_rot, query_pass), dim=-1)
-
- key_rot = key[..., : self.rotary_dim]
- key_pass = key[..., self.rotary_dim :]
- key_rot = key_rot * cos + rotate_neox(key_rot) * sin
- key = torch.cat((key_rot, key_pass), dim=-1)
-
- return query.flatten(-2), key.flatten(-2)
+ return query, key
diff --git a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py
new file mode 100644
index 0000000..e1e3d18
--- /dev/null
+++ b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py
@@ -0,0 +1,56 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Optional
+
+import torch
+
+from vllm.distributed import tensor_model_parallel_all_reduce
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+
+# TODO(bnell): Add shared + fused combo function? e.g. +
+class SharedFusedMoE(FusedMoE):
+ """
+ A FusedMoE operation that also computes the results of shared experts.
+ If an all2all communicator is being used the shared expert computation
+ can be interleaved with the fused all2all dispatch communication step.
+ """
+
+ def __init__(
+ self,
+ shared_experts: torch.nn.Module,
+ use_overlapped: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self._shared_experts = shared_experts
+ self.use_overlapped = use_overlapped
+
+ @property
+ def shared_experts(self) -> Optional[torch.nn.Module]:
+ return self._shared_experts if self.use_overlapped else None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ router_logits: torch.Tensor,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
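+        # Returns a (shared expert output, routed expert output) pair, which the
+        # calling model combines (typically by summation).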
+ if not self.use_overlapped:
+ shared_out = self._shared_experts(hidden_states)
+
+ # Reduce outputs if necessary, since the MLP should
+ # have been created with reduce_results=False.
+ if (self.reduce_results and self.tp_size > 1
+ and self.must_reduce_shared_expert_outputs()):
+ shared_out = tensor_model_parallel_all_reduce(shared_out)
+
+ fused_out = super().forward(
+ hidden_states=hidden_states,
+ router_logits=router_logits,
+ )
+ else:
+ shared_out, fused_out = super().forward(
+ hidden_states=hidden_states,
+ router_logits=router_logits,
+ )
+ return shared_out, fused_out
diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 826caa5..26e79f9 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -9,58 +9,108 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
-from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
-from vllm.utils.import_utils import has_deep_gemm
+from vllm.utils.deep_gemm import (
+ fp8_mqa_logits,
+ fp8_mqa_logits_torch,
+ fp8_paged_mqa_logits,
+ fp8_paged_mqa_logits_torch,
+ is_deep_gemm_supported,
+)
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager
-
+from vllm.utils.math_utils import cdiv
+
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops
+import ixformer.inference.functions as ixfops  # noqa: E402
+
logger = init_logger(__name__)
+@torch.inference_mode()
+def cp_gather_indexer_k_quant_cache(
+ kv_cache, # [num_blocks, block_size, head_dim]
+ dst_value, # [cu_seq_lens[-1], head_dim]
+ block_table, # [batch_size, num_blocks]
+ cu_seq_lens, # [batch_size + 1, ]
+ batch_size,
+):
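+    """Gather each request's cached indexer K rows from the paged cache into the
+    contiguous ``dst_value`` buffer (bf16 view; scale handling is disabled)."""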
+ num_blocks, block_size, _ = kv_cache.shape
+ head_dim = dst_value.shape[-1]
+ kv_cache = kv_cache.view(num_blocks, -1)
+
+ expected_value = []
+ # expected_scale = []
+ for b in range(batch_size):
+ s = cu_seq_lens[b + 1] - cu_seq_lens[b]
+ if s == 0:
+ continue
+ tot = cdiv(s, block_size)
+ blocks = block_table[b, :tot]
+
+ value = []
+ scale = []
+ full_block = torch.arange(tot - 1,
+ device=kv_cache.device,
+ dtype=torch.int32)
+ non_remaining_value = kv_cache[blocks[full_block], :block_size *
+ head_dim].view(-1, head_dim)
+ # non_remaining_scale = kv_cache[blocks[full_block],
+ # block_size * head_dim:].view(-1, 4)
+
+ remaining = s - (tot - 1) * block_size
+
+ value = torch.cat([
+ non_remaining_value,
+ kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)
+ ],
+ dim=0)
+ # scale = torch.cat([
+ # non_remaining_scale,
+ # kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
+ # remaining * 4].view(-1, 4)
+ # ],
+ # dim=0)
+
+ expected_value.append(value)
+ # expected_scale.append(scale)
+
+ gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
+ # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
+ gather_value = gather_value.view(torch.bfloat16)
+ # gather_scale = gather_scale.view(torch.float32)
+ dst_value.copy_(gather_value)
+ # dst_scale.copy_(gather_scale)
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
- q_fp8: torch.Tensor,
+ q: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
- quant_block_size: int,
- scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
total_seq_lens: int,
- topk_indices_buffer: torch.Tensor,
+ topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
- fp8_dtype = current_platform.fp8_dtype()
-
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
- # Reserve workspace for indexer during profiling run
- current_workspace_manager().get_simultaneous(
- ((total_seq_lens, head_dim), torch.float8_e4m3fn),
- ((total_seq_lens, 4), torch.uint8),
- )
return sparse_attn_indexer_fake(
hidden_states,
k_cache_prefix,
kv_cache,
- q_fp8,
+ q,
k,
weights,
- quant_block_size,
- scale_fmt,
topk_tokens,
head_dim,
max_model_len,
@@ -74,12 +124,118 @@ def sparse_attn_indexer(
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
- ops.indexer_k_quant_and_cache(
+ ops.indexer_k_cache(
+ k,
+ kv_cache,
+ slot_mapping
+ )
+
+ # topk_indices_buffer[: hidden_states.shape[0]] = -1
+ if has_prefill:
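+        # Per prefill chunk: compute indexer MQA logits against the paged K cache
+        # and write the per-token top-k context indices into the output buffer.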
+ prefill_metadata = attn_metadata.prefill
+ for chunk in prefill_metadata.chunks:
+ logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
+ q[chunk.token_start:chunk.token_end],
+ chunk.cu_seqlens_q,
+ chunk.cu_seq_lens,
+ kv_cache,
+ chunk.block_table,
+ weights[chunk.token_start : chunk.token_end],
+ max_q_len=chunk.max_q_len,
+ max_kv_len=chunk.max_kv_len,
+ max_context_len=chunk.max_context_len
+ )
+ ixfops.dsa_update_topk_indices(
+ logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, topk_tokens,
+ topk_indices_buffer[chunk.token_start:chunk.token_end]
+ )
+
+ if has_decode:
+ decode_metadata = attn_metadata.decode
+ # TODO: support speculative decode
+ if decode_metadata.requires_padding:
+ raise NotImplementedError(
+ "Sparse attention indexer does not support requires_padding"
+ )
+
+ # Use dsa_indexer_mqa_logits_with_blocks similar to prefill
+ logits = ixfops.dsa_indexer_mqa_logits_with_blocks(
+ q[:num_decode_tokens],
+ decode_metadata.cu_seqlens_q,
+ decode_metadata.cu_seqlens_kv,
+ kv_cache,
+ decode_metadata.block_table,
+ weights[:num_decode_tokens],
+ max_q_len=decode_metadata.max_q_len,
+ max_kv_len=decode_metadata.max_kv_len,
+ max_context_len=decode_metadata.max_context_len,
+ )
+
+ ixfops.dsa_update_topk_indices(
+ logits,
+ decode_metadata.cu_seqlen_ks,
+ decode_metadata.cu_seqlen_ke,
+ topk_tokens,
+ topk_indices_buffer[:num_decode_tokens],
+ )
+
+ return topk_indices_buffer
+
+
+def sparse_attn_indexer_original(
+ hidden_states: torch.Tensor,
+ k_cache_prefix: str,
+ kv_cache: torch.Tensor,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ weights: torch.Tensor,
+ topk_tokens: int,
+ head_dim: int,
+ max_model_len: int,
+ total_seq_lens: int,
+ topk_indices_buffer: torch.Tensor,
+) -> torch.Tensor:
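+    # Reference implementation retained alongside the fused sparse_attn_indexer
+    # above; it gathers K into a bf16 workspace and scores with ops.ref_mqa_logits
+    # / ops.ref_paged_mqa_logits instead of the ixformer kernels.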
+ # careful! this will be None in dummy run
+ attn_metadata = get_forward_context().attn_metadata
+ # fp8_dtype = current_platform.fp8_dtype()
+
+ # assert isinstance(attn_metadata, dict)
+ if not isinstance(attn_metadata, dict):
+ # Reserve workspace for indexer during profiling run
+ current_workspace_manager().get_simultaneous(
+ ((total_seq_lens, head_dim), torch.float8_e4m3fn),
+ ((total_seq_lens, 4), torch.uint8),
+ )
+ return sparse_attn_indexer_fake(
+ hidden_states,
+ k_cache_prefix,
+ kv_cache,
+ q,
+ k,
+ weights,
+ topk_tokens,
+ head_dim,
+ max_model_len,
+ total_seq_lens,
+ topk_indices_buffer,
+ )
+ attn_metadata = attn_metadata[k_cache_prefix]
+ assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
+ slot_mapping = attn_metadata.slot_mapping
+ has_decode = attn_metadata.num_decodes > 0
+ has_prefill = attn_metadata.num_prefills > 0
+ num_decode_tokens = attn_metadata.num_decode_tokens
+
+ # During speculative decoding, k may be padded to the CUDA graph batch
+ # size while slot_mapping only covers actual tokens. Truncate k to avoid
+ # out-of-bounds reads in the kernel.
+ num_tokens = slot_mapping.shape[0]
+ k = k[:num_tokens]
+
+ ops.indexer_k_cache(
k,
kv_cache,
slot_mapping,
- quant_block_size,
- scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
@@ -88,44 +244,42 @@ def sparse_attn_indexer(
# Get the full shared workspace buffers once (will allocate on first use)
workspace_manager = current_workspace_manager()
- k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
- ((total_seq_lens, head_dim), fp8_dtype),
- ((total_seq_lens, 4), torch.uint8),
- )
+ k_full = workspace_manager.get_simultaneous(
+ ((total_seq_lens, head_dim), torch.bfloat16),
+ )[0]
for chunk in prefill_metadata.chunks:
- k_fp8 = k_fp8_full[: chunk.total_seq_lens]
- k_scale = k_scale_full[: chunk.total_seq_lens]
- ops.cp_gather_indexer_k_quant_cache(
+ k = k_full[: chunk.total_seq_lens]
+ # k_scale = k_scale_full[: chunk.total_seq_lens]
+ cp_gather_indexer_k_quant_cache(
kv_cache,
- k_fp8,
- k_scale,
+ k,
chunk.block_table,
chunk.cu_seq_lens,
+ chunk.num_reqs,
)
- logits = fp8_mqa_logits(
- q_fp8[chunk.token_start : chunk.token_end],
- (k_fp8, k_scale.view(torch.float32).flatten()),
+ logits = ops.ref_mqa_logits(
+ q[chunk.token_start:chunk.token_end],
+ k,
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
- clean_logits=False,
- )
- num_rows = logits.shape[0]
-
- topk_indices = topk_indices_buffer[
- chunk.token_start : chunk.token_end, :topk_tokens
- ]
- torch.ops._C.top_k_per_row_prefill(
- logits,
- chunk.cu_seqlen_ks,
- chunk.cu_seqlen_ke,
- topk_indices,
- num_rows,
- logits.stride(0),
- logits.stride(1),
- topk_tokens,
)
+            topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]),
+                                       dim=-1)[1]
+            topk_indices -= chunk.cu_seqlen_ks[:, None]
+            mask_lo = topk_indices >= 0
+            mask_hi = topk_indices < (chunk.cu_seqlen_ke -
+                                      chunk.cu_seqlen_ks)[:, None]
+            mask = mask_lo & mask_hi
+            topk_indices = topk_indices.masked_fill(~mask, -1)
+            topk_indices_buffer[
+                chunk.token_start:chunk.token_end,
+                :topk_indices.shape[-1]] = topk_indices.to(dtype=torch.int32)
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
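The prefill branch above replaces the fused `top_k_per_row_prefill` kernel with plain torch ops: it takes a top-k over the full logits row, shifts the indices by each row's `cu_seqlen_ks` start so they become range-relative, and masks anything outside `[0, ke - ks)` to -1. A minimal standalone sketch of that index arithmetic (synthetic tensors, not the vLLM metadata objects):

```python
import torch

# Two query rows over a KV span of length 6; valid ranges are [0, 4) and [2, 5).
logits = torch.tensor([[0.1, 0.9, 0.3, 0.8, 0.2, 0.7],
                       [0.5, 0.4, 0.6, 0.1, 0.0, 0.2]])
cu_seqlen_ks = torch.tensor([0, 2])  # per-row start of the valid KV range
cu_seqlen_ke = torch.tensor([4, 5])  # per-row end (exclusive)
topk_tokens = 3

topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), dim=-1)[1]
topk_indices -= cu_seqlen_ks[:, None]            # make indices range-relative
mask = (topk_indices >= 0) & (
    topk_indices < (cu_seqlen_ke - cu_seqlen_ks)[:, None])
topk_indices = topk_indices.masked_fill(~mask, -1)
print(topk_indices)  # -> [[1, 3, -1], [0, -1, -1]]
```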
@@ -147,63 +301,50 @@ def sparse_attn_indexer(
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
- padded_q_fp8_decode_tokens = pack_seq_triton(
- q_fp8[:num_decode_tokens], decode_lens
- )
+ padded_q_decode_tokens = pack_seq_triton(
+ q[:num_decode_tokens], decode_lens)
else:
- padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
- decode_lens.shape[0], -1, *q_fp8.shape[1:]
- )
+ padded_q_decode_tokens = q[:num_decode_tokens].reshape(
+ decode_lens.shape[0], -1, *q.shape[1:])
# TODO: move and optimize below logic with triton kernels
- batch_size = padded_q_fp8_decode_tokens.shape[0]
- next_n = padded_q_fp8_decode_tokens.shape[1]
+ batch_size = padded_q_decode_tokens.shape[0]
+ next_n = padded_q_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
- logits = fp8_paged_mqa_logits(
- padded_q_fp8_decode_tokens,
+ logits = ops.ref_paged_mqa_logits(
+ padded_q_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
decode_metadata.seq_lens,
decode_metadata.block_table,
- decode_metadata.schedule_metadata,
max_model_len=max_model_len,
clean_logits=False,
)
- num_rows = logits.shape[0]
- topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
-
- if decode_metadata.use_large_context_topk:
- if next_n == 1:
- lengths = decode_metadata.seq_lens
- else:
- # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
- lengths = (
- decode_metadata.seq_lens.unsqueeze(1)
- - next_n
- + 1
- + decode_metadata.offsets
- ).flatten()
-
- torch.ops._C.large_context_topk(
- logits,
- topk_indices,
- lengths,
- None,
- )
- else:
- torch.ops._C.top_k_per_row_decode(
- logits,
- next_n,
- decode_metadata.seq_lens,
- topk_indices,
- num_rows,
- logits.stride(0),
- logits.stride(1),
- topk_tokens,
- )
-
+ # padded query len
+ current_device = padded_q_decode_tokens.device
+ padded_num_tokens = batch_size * next_n
+ positions = torch.arange(max_model_len,
+ device=current_device).unsqueeze(0).expand(
+ batch_size * next_n, -1)
+ row_indices = torch.arange(padded_num_tokens,
+ device=current_device) // next_n
+ next_n_offset = torch.arange(
+ padded_num_tokens,
+ device=padded_q_decode_tokens.device) % next_n
+ index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
+ next_n_offset).unsqueeze(1)
+ # index_end_pos: [B * N, 1]
+ mask = positions <= index_end_pos
+ # mask: [B * N, L]
+ logits = logits.masked_fill(~mask, float('-inf'))
+ topk_indices = logits.topk(topk_tokens,
+ dim=-1)[1].to(torch.int32) # [B * N, K]
+            # Ensure we don't emit top-k indices that are out of range
+            # (those positions are already masked); this happens when the
+            # context length is shorter than K.
+            topk_indices[topk_indices > index_end_pos] = -1
if decode_metadata.requires_padding:
# if padded, we need to unpack
# the topk indices removing padded tokens
@@ -211,9 +352,8 @@ def sparse_attn_indexer(
topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
decode_lens,
)
- topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
- topk_indices
- )
+    topk_indices_buffer[:num_decode_tokens, :topk_indices.shape[-1]] = (
+        topk_indices.to(dtype=torch.int32))
return topk_indices_buffer
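Similarly, the decode branch builds its validity mask directly from `seq_lens`: for the i-th padded query of a request, only KV positions up to `seq_len - next_n + i` are attendable, everything past that is filled with -inf before the top-k, and any surviving out-of-range index is set to -1 (this happens when the context is shorter than `topk_tokens`). A small self-contained sketch of that mask construction, with made-up shapes:

```python
import torch

batch_size, next_n, max_len, topk_tokens = 2, 2, 8, 4
seq_lens = torch.tensor([3, 7])  # request 0 is shorter than topk_tokens
logits = torch.randn(batch_size * next_n, max_len)

positions = torch.arange(max_len).unsqueeze(0).expand(batch_size * next_n, -1)
row = torch.arange(batch_size * next_n) // next_n   # owning request per row
off = torch.arange(batch_size * next_n) % next_n    # offset within the request
index_end_pos = (seq_lens[row] - next_n + off).unsqueeze(1)  # [B * N, 1]

mask = positions <= index_end_pos                   # [B * N, L]
logits = logits.masked_fill(~mask, float("-inf"))
topk_indices = logits.topk(topk_tokens, dim=-1)[1].to(torch.int32)
topk_indices[topk_indices > index_end_pos] = -1     # out-of-range -> -1
print(topk_indices)  # rows of request 0 contain -1 entries
```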
@@ -222,11 +362,9 @@ def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: str,
kv_cache: torch.Tensor,
- q_fp8: torch.Tensor,
+ q: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
- quant_block_size: int,
- scale_fmt: str | None,
topk_tokens: int,
head_dim: int,
max_model_len: int,
@@ -278,9 +416,12 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len = max_model_len
self.max_total_seq_len = max_total_seq_len
self.topk_indices_buffer = topk_indices_buffer
- if current_platform.is_cuda() and not has_deep_gemm():
- raise RuntimeError(
- "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
+ if current_platform.is_cuda() and not is_deep_gemm_supported():
+ logger.warning_once(
+ "DeepGEMM is not supported or available. SparseAttnIndexer will use a "
+ "less efficient PyTorch implementation. "
+ "Please make sure you have the required hardware and software setup "
+ "for DeepGEMM to achieve optimal performance."
)
def forward_native(
@@ -303,7 +444,7 @@ class SparseAttnIndexer(CustomOp):
def forward_cuda(
self,
hidden_states: torch.Tensor,
- q_fp8: torch.Tensor,
+ q: torch.Tensor,
k: torch.Tensor,
weights: torch.Tensor,
):
@@ -311,11 +452,9 @@ class SparseAttnIndexer(CustomOp):
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
- q_fp8,
+ q,
k,
weights,
- self.quant_block_size,
- self.scale_fmt,
self.topk_tokens,
self.head_dim,
self.max_model_len,
diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py
index bc51b0e..11a4ff0 100644
--- a/vllm/model_executor/layers/utils.py
+++ b/vllm/model_executor/layers/utils.py
@@ -3,6 +3,8 @@
"""Utility methods for model layers."""
from collections.abc import Callable
+import ast
+import re
import torch
@@ -13,6 +15,7 @@ from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op
+import ixformer.inference.functions as IXF
logger = init_logger(__name__)
@@ -31,27 +34,6 @@ def is_layer_moe_router_gate(prefix: str) -> bool:
return prefix.rsplit(".", 1)[-1] in MOE_LAYER_ROUTER_GATE_SUFFIXES
-def shuffle_weight(w: torch.Tensor) -> torch.Tensor:
- # Shuffle weight along the last dimension so that
- # we folded the weights to adjance location
- # Example:
- # input:
- # [[1, 2, 3, 4, 5, 6],
- # [7, 8, 9, 10, 11, 12]]
- # output:
- # [[1, 4, 2, 5, 3, 6],
- # [7, 10, 8, 11, 9, 12]]
- # This will be used together with triton swiglu kernel
- shape = w.shape
- N = shape[-1]
- first = w[..., : N // 2]
- second = w[..., N // 2 :]
-
- stacked = torch.stack((first, second), dim=-1)
- w_shuffled = stacked.reshape(shape)
- return w_shuffled
-
-
def get_token_bin_counts_and_mask(
tokens: torch.Tensor,
vocab_size: int,
@@ -116,7 +98,11 @@ def default_unquantized_gemm(
weight: torch.Tensor,
bias: torch.Tensor | None = None,
):
- return torch.nn.functional.linear(x, weight, bias)
+    if (bias is None and x.dtype in [torch.half, torch.bfloat16]
+            and weight.dtype == torch.float32):
+        return IXF.mixed_type_linear(input=x, weight=layer.weight)
+    if x.dtype == torch.float32:
+        return torch.nn.functional.linear(x, weight, bias)
+    return IXF.linear(x, weight, bias)
def use_aiter_triton_gemm(n, m, k, dtype):
@@ -191,7 +177,6 @@ def rocm_unquantized_gemm_impl(
and on_gfx9()
and x.dtype in [torch.float16, torch.bfloat16]
and k % 8 == 0
- and x.is_contiguous()
)
if use_skinny is not True:
@@ -302,3 +287,72 @@ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
return cpu_unquantized_gemm
else:
return default_unquantized_gemm
+
+def weight_quant_l1(input: torch.Tensor):
+    """Symmetric per-row INT8 quantization -> (int8 weights, fp32 scales)."""
+    qmax = 127.0
+    input = input.to(device="cuda")
+    # Per-row abs-max sets the scale so values map into [-127, 127].
+    abs_max = torch.abs(input).max(dim=1, keepdim=True)[0]
+    scale = abs_max / qmax
+    assert scale.shape == (input.shape[0], 1)
+    quantized = torch.round(input / scale)
+    quantized = torch.clamp(quantized, -qmax, qmax)
+    return quantized.to(torch.int8), scale.to(torch.float32)
+
+def weight_quant_l2(input: torch.Tensor, format: str = "TN"):
+    """Per-row symmetric quant repacked to INT4 via IXF.quant_repack_int4."""
+    qmax = 127.0
+    input = input.to(device="cuda")
+    abs_max = torch.abs(input).max(dim=1, keepdim=True)[0]  # [rows, 1]
+    scale = abs_max / qmax  # [rows, 1]
+    assert scale.shape == (input.shape[0], 1)
+    quantized = torch.round(input / scale)
+    quantized = torch.clamp(quantized, -qmax, qmax)
+
+    i4_weights, i8scales, i8zeros = IXF.quant_repack_int4(
+        quantized.to(torch.int8).unsqueeze_(0), -1, 2, format, False)
+    return i4_weights.squeeze(0), scale.to(torch.float32)
+
+
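`weight_quant_l1` above is a plain symmetric per-row INT8 quantizer (scale = row abs-max / 127), and `weight_quant_l2` repacks the same INT8 result to INT4 through `IXF.quant_repack_int4`. A minimal round-trip check of the per-row scheme, reimplemented with pure torch so it runs on CPU without ixformer (illustrative only, not the vLLM code path):

```python
import torch

def per_row_int8_quant(w: torch.Tensor):
    # Same math as weight_quant_l1, but device-agnostic.
    qmax = 127.0
    scale = w.abs().max(dim=1, keepdim=True)[0] / qmax
    q = torch.clamp(torch.round(w / scale), -qmax, qmax)
    return q.to(torch.int8), scale.to(torch.float32)

w = torch.randn(4, 16)
q, s = per_row_int8_quant(w)
w_hat = q.to(torch.float32) * s  # dequantize
# Rounding error is bounded by half a quantization step (scale / 2) per row.
assert (w - w_hat).abs().max() <= (s / 2 + 1e-6).max()
```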
+def parse_opt_exclude_layers(
+ opt_exclude_layers_str: str,
+ prefix: str,
+) -> bool:
+ """
+ Parses the VLLM_OPT_EXCLUDE_LAYERS environment variable to determine if
+ the current layer should be excluded from optimization.
+
+ Args:
+ opt_exclude_layers_str: The string value from the
+ VLLM_OPT_EXCLUDE_LAYERS environment variable.
+ prefix: The prefix of the current layer (e.g.,
+ "model.layers.12.qkv_proj").
+
+ Returns:
+ A boolean indicating whether the layer should be excluded.
+ """
+ if not opt_exclude_layers_str:
+ return False
+
+ try:
+ # Safely evaluate the string to a Python object
+ excluded_layers = ast.literal_eval(opt_exclude_layers_str)
+
+ # If a single integer is provided, convert it to a set
+ if isinstance(excluded_layers, int):
+ excluded_layers = {excluded_layers}
+ elif not isinstance(excluded_layers, (set, tuple, list)):
+ raise TypeError
+
+ excluded_layers: set[int] = set(excluded_layers)
+
+ # Extract layer number from the prefix string
+ layer_match = re.search(r"\.(\d+)", prefix)
+ if layer_match and int(layer_match.group(1)) in excluded_layers:
+ return True # Exclude this layer
+ except (ValueError, SyntaxError, TypeError):
+ logger.warning(
+ "Failed to parse VLLM_OPT_EXCLUDE_LAYERS: %s. "
+ "Expected a string representation of an integer or a "
+ "tuple/list/set of integers.",
+ opt_exclude_layers_str,
+ )
+
+ return False # Do not exclude this layer
\ No newline at end of file
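For reference, this is how `parse_opt_exclude_layers` behaves for a few representative inputs, mirroring the docstring above (hypothetical usage; the import path is the module this hunk patches):

```python
# VLLM_OPT_EXCLUDE_LAYERS="(0, 12, 31)" would exclude layers 0, 12 and 31.
from vllm.model_executor.layers.utils import parse_opt_exclude_layers

assert parse_opt_exclude_layers("(0, 12, 31)", "model.layers.12.qkv_proj")
assert not parse_opt_exclude_layers("12", "model.layers.3.mlp.gate_up_proj")
assert not parse_opt_exclude_layers("", "model.layers.3.mlp.gate_up_proj")
# A value that is not a Python literal logs a warning and excludes nothing.
assert not parse_opt_exclude_layers("not-a-literal", "model.layers.3.mlp")
```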
diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py
index ff95d5b..30ac389 100644
--- a/vllm/model_executor/model_loader/__init__.py
+++ b/vllm/model_executor/model_loader/__init__.py
@@ -23,6 +23,10 @@ from vllm.model_executor.model_loader.utils import (
get_model_architecture,
get_model_cls,
)
+from vllm.model_executor.model_loader.weight_utils import (
+ padding_weight_loader
+)
+
logger = init_logger(__name__)
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index ebd3a15..f75d252 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -13,6 +13,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
+from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
@@ -287,8 +290,7 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
- all_weights = self.get_all_weights(model_config, model)
- loaded_weights = model.load_weights(all_weights)
+ loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
@@ -298,7 +300,8 @@ class DefaultModelLoader(BaseModelLoader):
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
- if model_config.quantization is None and loaded_weights is not None:
+ opt_flag = envs.VLLM_MOE_OPT_LEVEL != 0 or envs.VLLM_LINEAR_OPT_LEVEL != 0
+ if model_config.quantization is None and loaded_weights is not None and not opt_flag:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py
index f3d0b03..0bcf1cf 100644
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -39,8 +39,6 @@ from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.import_utils import PlaceholderModule
-import ixformer.inference.functions as ixfop
-
try:
from runai_model_streamer import SafetensorsStreamer
except ImportError:
@@ -289,7 +287,17 @@ def get_quant_config(
)
if hf_quant_config is not None:
- return quant_cls.from_config(hf_quant_config)
+ # For modelopt_mixed, config.json's quantization_config may or may
+ # not contain the per-layer quantized_layers map. Newer checkpoints
+ # embed it directly; older ones keep it only in hf_quant_config.json.
+ # If it is missing, fall through to the file-based loading path.
+ if (
+ model_config.quantization == "modelopt_mixed"
+ and "quantized_layers" not in hf_quant_config
+ ):
+ pass # fall through to file-based loading below
+ else:
+ return quant_cls.from_config(hf_quant_config)
# if hf_quant_config is None, we will try to get config from
# hf_overrides
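In other words, the inline `quantization_config` is used unless the checkpoint is `modelopt_mixed` and lacks the per-layer `quantized_layers` map, in which case loading falls through to `hf_quant_config.json`. A compact sketch of that rule (the helper name is hypothetical):

```python
def use_inline_quant_config(quantization: str, hf_quant_config: dict) -> bool:
    # Older modelopt_mixed checkpoints ship the per-layer map only in
    # hf_quant_config.json, so defer to the file-based loading path.
    if quantization == "modelopt_mixed" and "quantized_layers" not in hf_quant_config:
        return False
    return True

assert not use_inline_quant_config("modelopt_mixed", {"quant_algo": "FP8"})
assert use_inline_quant_config("modelopt_mixed", {"quantized_layers": {}})
assert use_inline_quant_config("modelopt", {"quant_algo": "FP8"})
```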
@@ -367,8 +375,8 @@ def get_quant_config(
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_config.model
- elif model_config.quantization == "modelopt":
- if config["producer"]["name"] == "modelopt":
+ elif model_config.quantization in ("modelopt", "modelopt_mixed"):
+ if config.get("producer", {}).get("name") == "modelopt":
return quant_cls.from_config(config)
else:
raise ValueError(
@@ -697,13 +705,6 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param)
-FP_TYPES = {
- "torch.bfloat16",
- "torch.float16",
- "torch.float32",
- "torch.half",
-}
-
def safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
@@ -714,12 +715,6 @@ def safetensors_weights_iterator(
if safetensors_load_strategy == "eager":
loading_desc += " (eager)"
- CUSTOM_QUANT_CONFIG = os.environ.get("CUSTOM_QUANT_CONFIG", None)
- try:
- with open(f"{CUSTOM_QUANT_CONFIG}/quant_map.json", "r") as f:
- quant_map = json.load(f)
- except Exception as e:
- quant_map = None
leftover_state_dict: dict[str, torch.Tensor] = {}
for st_file in tqdm(
sorted(hf_weights_files, key=_natural_sort_key),
@@ -763,160 +758,9 @@ def safetensors_weights_iterator(
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
- if not quant_map:
- yield name, param
- continue
- quant_type = quant_map.get(name)
-
- if quant_type is None or quant_type in FP_TYPES:
- yield name, param
- continue
-
- qtype, qformat = quant_type.split("-")
-
- qname = name
- qscale_name = f"{name}_scale"
- is_expert = ("expert" in name and "shared" not in name)
-
- # INT8
- if qtype == "int8":
- if param.ndim == 2:
- param.unsqueeze_(0)
-
- qweight, qscale = weight_quant_bf16_to_int8(param)
-
- A = qweight.shape[0]
-
- if is_expert:
- qscale = qscale.view(A, 1, -1).transpose(-2, -1).contiguous()
-
- if A == 1:
- qweight = qweight.squeeze(0)
- qscale = qscale.squeeze(0)
-
- yield qname, qweight
- yield qscale_name, qscale
- continue
-
- # INT4
- if qtype == "int4":
- i8scales, i8zeros = None, None
-
- if param.ndim == 2:
- param.unsqueeze_(0)
-
- qweight, qscale, i8scales, i8zeros = weight_quant_bf16_to_int4pack8(
- param,
- format=qformat,
- symmetric=True,
- )
-
- A = qweight.shape[0]
-
- if is_expert:
- qscale = qscale.view(A, 1, -1).contiguous()
-
- if A == 1:
- qweight = qweight.squeeze(0)
- qscale = qscale.squeeze(0)
-
- yield qname, qweight
- yield qscale_name, qscale
-
- if i8scales is not None:
- yield f"{name}_i8_weight_scale", i8scales.squeeze_(0)
-
- if i8zeros is not None:
- yield f"{name}_i8_weight_zero", i8zeros.squeeze_(0)
-
- continue
-
yield name, param
-def weight_quant_bf16_to_int8(inputs: torch.Tensor):
- device = current_platform.current_device()
-
- assert inputs.dim() == 3, f"inputs shape is [batch, output_dim, input_dim], but got {inputs.dim()}"
-
- ori_device = inputs.device
- if inputs.device != device:
- inputs = inputs.to(device)
-
- qmax = 127.0
- abs_max = torch.abs(inputs).max(dim=2, keepdim=True)[0]
- scale = abs_max / qmax
-
- assert scale.shape == (*inputs.shape[:2], 1)
-
- quantized = torch.round(inputs / scale)
- quantized = torch.clamp(quantized, -qmax, qmax)
- return quantized.to(torch.int8).to(ori_device), scale.to(torch.float32).to(ori_device)
-
-
-def weight_quant_bf16_to_int4pack8(
- v: torch.Tensor, # [B, R, C]
- block_size: int = 128,
- group_size: int = -1,
- format: str = "TN",
- symmetric: bool = True,
- version: int = 2,
-):
- """
-    Batch version of INT4 quantization + packing.
-
- Args:
- v: [batch, rows, cols], float Tensor
-
- Returns:
- i4_weights: [batch, rows, packed_cols]
- scale: [batch, rows, 1]
-        i8scales: from ixfop.quant_repack_int4
-        i8zeros: from ixfop.quant_repack_int4
- """
- device = current_platform.current_device()
- ori_device = inputs.device
- if inputs.device != device:
- inputs = inputs.to(device)
- assert v.dim() == 3, f"expected [batch, rows, cols], got {v.shape}"
-
- B, R, C = v.shape
-
- qmax = 127.0
-
- # abs_max: [B, R, 1]
- abs_max = torch.abs(v).amax(dim=2, keepdim=True)
- scale = abs_max / qmax # [B, R, 1]
-
- # quantized: [B, R, C]
- quantized = torch.round(v / scale)
- quantized = torch.clamp(quantized, -qmax, qmax).to(torch.int8)
-
-    # ixfop.quant_repack_int4 expects [batch, rows, cols]
-    # The input is already batch-first, so it can be passed in directly.
-    # The returned shapes are typically:
-    #   i4_weights: [B, R, packed_C]
-    #   i8scales:   [B, R, groups]
-    #   i8zeros:    [B, R, groups]
-    i4_weights, i8scales, i8zeros = ixfop.quant_repack_int4(
-        quantized,   # no unsqueeze needed, it is already [B, R, C]
- group_size,
- version,
- format,
- symmetric,
- )
- if i8scales is not None:
- i8scales = i8scales.to(ori_device)
- i8zeros = i8zeros.to(ori_device)
- return (
- i4_weights.to(ori_device), # [B, R, packed_C]
- scale.to(torch.float32).to(ori_device), # [B, R, 1]
-        i8scales,                                # from repack
- i8zeros
- )
-
-
-
def multi_thread_safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
@@ -1426,3 +1270,50 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
# If there were no matches, return the untouched param name
return name
+
+
+def padding_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+ """Weight loader that allows padding on the last (and optionally middle) dims.
+
+ If shapes match: copy directly.
+ If shapes differ: copy the overlapping slice (min along each dimension).
+ Special-cases MoE weights that have expert dim in front (2D/3D).
+ """
+ # Fast path: exact match
+ if param.shape == loaded_weight.shape:
+ param.data.copy_(loaded_weight)
+ return
+
+ # Basic sanity checks
+ if param.ndim != loaded_weight.ndim:
+ raise ValueError(
+ f"Cannot load weight with different ndim: param.ndim={param.ndim}, "
+ f"loaded.ndim={loaded_weight.ndim}, param.shape={tuple(param.shape)}, "
+ f"loaded.shape={tuple(loaded_weight.shape)}"
+ )
+
+ dims = param.ndim
+ if dims not in (2, 3):
+ raise ValueError(
+ f"padding_weight_loader only supports 2D/3D tensors, got {dims}D. "
+ f"param.shape={tuple(param.shape)}, loaded.shape={tuple(loaded_weight.shape)}"
+ )
+
+ # For MoE tensors, dim0 is num_experts and must match.
+ if param.shape[0] != loaded_weight.shape[0]:
+ raise AssertionError(
+ f"Mismatch in number of experts: param={param.shape[0]}, loaded={loaded_weight.shape[0]}"
+ )
+
+ # Copy the overlapping region: [:, :min(dim1), :min(dim2)]
+ if dims == 2:
+ copy_d1 = min(param.shape[1], loaded_weight.shape[1])
+ param.data[:, :copy_d1].copy_(loaded_weight[:, :copy_d1])
+ return
+
+ # dims == 3
+ copy_d1 = min(param.shape[1], loaded_weight.shape[1])
+ copy_d2 = min(param.shape[2], loaded_weight.shape[2])
+ param.data[:, :copy_d1, :copy_d2].copy_(
+ loaded_weight[:, :copy_d1, :copy_d2]
+ )
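A quick illustration of what `padding_weight_loader` does when a parameter was padded beyond the checkpoint shape (shapes here are made up):

```python
import torch

param = torch.zeros(8, 4, 128)    # e.g. an expert weight padded on the last dim
loaded = torch.randn(8, 4, 120)   # checkpoint tensor before padding

# Equivalent to padding_weight_loader(param, loaded): only the overlapping
# [:, :4, :120] region is copied; the padded tail stays zero.
d1 = min(param.shape[1], loaded.shape[1])
d2 = min(param.shape[2], loaded.shape[2])
param[:, :d1, :d2].copy_(loaded[:, :d1, :d2])

assert torch.equal(param[..., :120], loaded)
assert param[..., 120:].abs().sum() == 0
```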
diff --git a/vllm/model_executor/models/AXK1.py b/vllm/model_executor/models/AXK1.py
new file mode 100644
index 0000000..f5ed440
--- /dev/null
+++ b/vllm/model_executor/models/AXK1.py
@@ -0,0 +1,1168 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only A.X K1 model."""
+
+import typing
+from collections.abc import Callable, Iterable
+from itertools import islice
+
+import torch
+from torch import nn
+
+from vllm._aiter_ops import rocm_aiter_ops
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, ParallelConfig, VllmConfig
+from vllm.distributed import (
+ get_ep_group,
+ get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ tensor_model_parallel_all_gather,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.deepseek_v2 import (
+ DeepseekAttention,
+ DeepseekV2MLP,
+ yarn_get_mscale,
+)
+from vllm.model_executor.models.utils import sequence_parallel_chunk
+from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.configs.AXK1 import AXK1Config
+
+from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
+from .utils import (
+ PPMissingLayer,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
+ make_layers,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+class AXK1MLP(DeepseekV2MLP):
+ pass
+
+
+class AXK1MoE(nn.Module):
+ def __init__(
+ self,
+ config: AXK1Config,
+ parallel_config: ParallelConfig,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ self.ep_group = get_ep_group().device_group
+ self.ep_rank = get_ep_group().rank_in_group
+ self.ep_size = self.ep_group.size()
+ self.n_routed_experts: int = config.n_routed_experts
+ self.n_shared_experts: int = config.n_shared_experts
+
+ self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
+
+ if config.hidden_act != "silu":
+ raise ValueError(
+ f"Unsupported activation: {config.hidden_act}. "
+ "Only silu is supported for now."
+ )
+
+ self.gate = ReplicatedLinear(
+ config.hidden_size,
+ config.n_routed_experts,
+ bias=False,
+ quant_config=None,
+ prefix=f"{prefix}.gate",
+ )
+ if config.topk_method == "noaux_tc":
+ self.gate.e_score_correction_bias = nn.Parameter(
+ torch.empty(config.n_routed_experts, dtype=torch.float32)
+ )
+ else:
+ self.gate.e_score_correction_bias = None
+
+ # Load balancing settings.
+ eplb_config = parallel_config.eplb_config
+ self.enable_eplb = parallel_config.enable_eplb
+
+ self.n_redundant_experts = eplb_config.num_redundant_experts
+ self.n_logical_experts = self.n_routed_experts
+ self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
+ self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
+ self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
+ self.physical_expert_end = (
+ self.physical_expert_start + self.n_local_physical_experts
+ )
+
+ self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
+ self.is_fusion_moe_shared_experts_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
+ if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
+ self.shared_experts = None
+ else:
+ intermediate_size = config.moe_intermediate_size * config.n_shared_experts
+
+ self.shared_experts = AXK1MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ is_sequence_parallel=self.is_sequence_parallel,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_experts",
+ )
+
+ self.experts = SharedFusedMoE(
+ shared_experts=self.shared_experts,
+ gate=self.gate,
+ num_experts=config.n_routed_experts,
+ top_k=config.num_experts_per_tok,
+ hidden_size=config.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=config.norm_topk_prob,
+ quant_config=quant_config,
+ use_grouped_topk=True,
+ num_expert_group=config.n_group,
+ topk_group=config.topk_group,
+ prefix=f"{prefix}.experts",
+ scoring_func=config.scoring_func,
+ # we do scaling outside, set factor to 1.0 to avoid double mul
+ # aiter applies routed_scaling_factor internally
+ routed_scaling_factor=1.0
+ if not self.is_rocm_aiter_moe_enabled
+ else self.routed_scaling_factor,
+ e_score_correction_bias=self.gate.e_score_correction_bias,
+ enable_eplb=self.enable_eplb,
+ num_redundant_experts=self.n_redundant_experts,
+ is_sequence_parallel=self.is_sequence_parallel,
+ n_shared_experts=config.n_shared_experts
+ if self.is_fusion_moe_shared_experts_enabled
+ else None,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_dim = hidden_states.shape
+ hidden_states = hidden_states.view(-1, hidden_dim)
+
+ # Chunk the hidden states so they aren't replicated across TP ranks.
+ # This avoids duplicate computation in self.experts.
+ # TODO: We can replace the all_reduce at the end of attn with a
+ # reduce_scatter instead of chunking here.
+ if self.is_sequence_parallel:
+ hidden_states = sequence_parallel_chunk(hidden_states)
+
+ if self.experts.is_internal_router:
+ # In this case, the gate/router runs inside the FusedMoE class
+ fused_moe_out = self.experts(
+ hidden_states=hidden_states, router_logits=hidden_states
+ )
+ else:
+ # router_logits: (num_tokens, n_experts)
+ router_logits, _ = self.gate(hidden_states)
+ fused_moe_out = self.experts(
+ hidden_states=hidden_states, router_logits=router_logits
+ )
+
+ shared_output, final_hidden_states = fused_moe_out
+ if self.shared_experts is None:
+ assert shared_output is None
+
+ # Fix FP16 overflow
+ # See AXK1DecoderLayer for more details.
+ if hidden_states.dtype != torch.float16:
+ if not self.is_rocm_aiter_moe_enabled:
+ final_hidden_states *= self.routed_scaling_factor
+ elif self.shared_experts is not None:
+ assert shared_output is not None
+ shared_output *= 1.0 / self.routed_scaling_factor
+
+ if self.shared_experts is not None:
+ assert shared_output is not None
+ final_hidden_states += shared_output
+
+ if self.is_sequence_parallel:
+ final_hidden_states = tensor_model_parallel_all_gather(
+ final_hidden_states, 0
+ )
+ final_hidden_states = final_hidden_states[:num_tokens]
+ elif self.tp_size > 1:
+ final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states
+ )
+
+ return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+def _get_llama_4_scaling(
+ original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
+) -> torch.Tensor:
+ scaling = 1 + scaling_beta * torch.log(
+ 1 + torch.floor(positions / original_max_position_embeddings)
+ )
+ # Broadcast over num_heads and head_dim
+ return scaling[..., None, None]
+
+
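`_get_llama_4_scaling` implements `scaling = 1 + beta * log(1 + floor(pos / original_max_position_embeddings))`, broadcast over heads and head dim so it can multiply the query tensor directly. A tiny numeric check with illustrative values:

```python
import torch

def llama_4_scaling(original_max_position_embeddings, beta, positions):
    scaling = 1 + beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings))
    return scaling[..., None, None]

positions = torch.tensor([0.0, 4095.0, 4096.0, 8192.0])
s = llama_4_scaling(4096, 0.1, positions).flatten()
# Inside the original window the scale stays 1.0; longer contexts grow
# logarithmically: 1 + 0.1*log(2) ~= 1.069 and 1 + 0.1*log(3) ~= 1.110.
print(s)
```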
+class AXK1Attention(nn.Module):
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ config: AXK1Config,
+ hidden_size: int,
+ num_heads: int,
+ qk_nope_head_dim: int,
+ qk_rope_head_dim: int,
+ v_head_dim: int,
+ q_lora_rank: int,
+ kv_lora_rank: int,
+ max_position_embeddings: int = 8192,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ topk_indices_buffer: torch.Tensor | None = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+ self.q_lora_rank = q_lora_rank
+ self.kv_lora_rank = kv_lora_rank
+ self.num_heads = num_heads
+ tp_size = get_tensor_model_parallel_world_size()
+ assert num_heads % tp_size == 0
+ self.num_local_heads = num_heads // tp_size
+ self.scaling = self.qk_head_dim**-0.5
+ self.max_position_embeddings = max_position_embeddings
+        assert topk_indices_buffer is None, (
+            "topk_indices_buffer is not supported for AXK1Attention"
+        )
+
+ if self.q_lora_rank is not None:
+ self.q_a_proj = ReplicatedLinear(
+ self.hidden_size,
+ self.q_lora_rank,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_a_proj",
+ )
+ self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+ self.q_b_proj = ColumnParallelLinear(
+ q_lora_rank,
+ self.num_heads * self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_b_proj",
+ )
+ else:
+ self.q_proj = ColumnParallelLinear(
+ self.hidden_size,
+ self.num_heads * self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_proj",
+ )
+
+ self.kv_a_proj_with_mqa = ReplicatedLinear(
+ self.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_a_proj_with_mqa",
+ )
+ self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+ self.kv_b_proj = ColumnParallelLinear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_b_proj",
+ )
+ # O projection.
+ self.o_proj = RowParallelLinear(
+ self.num_heads * self.v_head_dim,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+ if config.rope_parameters["rope_type"] != "default":
+ config.rope_parameters["rope_type"] = (
+ "deepseek_yarn"
+ if config.rope_parameters.get("apply_yarn_scaling", True)
+ else "deepseek_llama_scaling"
+ )
+
+ self.rotary_emb = get_rope(
+ qk_rope_head_dim,
+ max_position=max_position_embeddings,
+ rope_parameters=config.rope_parameters,
+ is_neox_style=False,
+ )
+
+ if config.rope_parameters["rope_type"] == "deepseek_yarn":
+ mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
+ scaling_factor = config.rope_parameters["factor"]
+ mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+ self.scaling = self.scaling * mscale * mscale
+
+ self.attn = Attention(
+ self.num_local_heads,
+ self.qk_head_dim,
+ self.scaling,
+ num_kv_heads=self.num_local_heads,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn",
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ llama_4_scaling: torch.Tensor | None,
+ ) -> torch.Tensor:
+ if self.q_lora_rank is not None:
+ q = self.q_a_proj(hidden_states)[0]
+ q = self.q_a_layernorm(q)
+ q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+ else:
+ q = self.q_proj(hidden_states)[0].view(
+ -1, self.num_local_heads, self.qk_head_dim
+ )
+ q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+ kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+ latent_cache = latent_cache.unsqueeze(1)
+ kv_a = self.kv_a_layernorm(kv_a)
+ kv = self.kv_b_proj(kv_a)[0]
+ kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+ k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+ k_pe = latent_cache[:, :, self.kv_lora_rank :]
+
+ q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+ q[..., self.qk_nope_head_dim :] = q_pe
+ k = torch.empty_like(q)
+ k[..., : self.qk_nope_head_dim] = k_nope
+ k[..., self.qk_nope_head_dim :] = k_pe
+
+ # Apply llama 4 scaling if provided
+ if llama_4_scaling is not None:
+ q *= llama_4_scaling
+
+ # padding value to qk_head_dim for alignment
+ v = torch.nn.functional.pad(
+ v, [0, self.qk_head_dim - self.v_head_dim], value=0
+ ).view(-1, self.num_local_heads * self.qk_head_dim)
+ attn_output = self.attn(q, k, v)
+ attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
+ ..., : self.v_head_dim
+ ].reshape(-1, self.num_local_heads * self.v_head_dim)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+class AXK1MLAAttention(nn.Module):
+ """
+ Main reference: DeepseekV2 paper, and FlashInfer Implementation
+ (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
+
+ For more info see MLACommonImpl in:
+ vllm/v1/attention/backends/mla/utils.py
+ """
+
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ config: AXK1Config,
+ hidden_size: int,
+ num_heads: int,
+ qk_nope_head_dim: int,
+ qk_rope_head_dim: int,
+ v_head_dim: int,
+ q_lora_rank: int | None,
+ kv_lora_rank: int,
+ max_position_embeddings: int = 8192,
+ cache_config: CacheConfig | None = None,
+ quant_config: QuantizationConfig | None = None,
+ prefix: str = "",
+ topk_indices_buffer: torch.Tensor | None = None,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.qk_nope_head_dim = qk_nope_head_dim
+ self.qk_rope_head_dim = qk_rope_head_dim
+ self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+ self.v_head_dim = v_head_dim
+
+ self.q_lora_rank = q_lora_rank
+ self.kv_lora_rank = kv_lora_rank
+
+ self.num_heads = num_heads
+ tp_size = get_tensor_model_parallel_world_size()
+ assert num_heads % tp_size == 0
+ self.num_local_heads = num_heads // tp_size
+
+ self.scaling = self.qk_head_dim**-0.5
+ self.max_position_embeddings = max_position_embeddings
+
+ if self.q_lora_rank is not None:
+ self.fused_qkv_a_proj = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fused_qkv_a_proj",
+ disable_tp=True,
+ )
+ else:
+ self.kv_a_proj_with_mqa = ReplicatedLinear(
+ self.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_a_proj_with_mqa",
+ )
+
+ if self.q_lora_rank is not None:
+ self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+ self.q_b_proj = ColumnParallelLinear(
+ self.q_lora_rank,
+ self.num_heads * self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_b_proj",
+ )
+ else:
+ self.q_proj = ColumnParallelLinear(
+ self.hidden_size,
+ self.num_heads * self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_proj",
+ )
+ self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+ self.kv_b_proj = ColumnParallelLinear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_b_proj",
+ )
+ self.o_proj = RowParallelLinear(
+ self.num_heads * self.v_head_dim,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ if config.rope_parameters["rope_type"] != "default":
+ config.rope_parameters["rope_type"] = (
+ "deepseek_yarn"
+ if config.rope_parameters.get("apply_yarn_scaling", True)
+ else "deepseek_llama_scaling"
+ )
+
+ self.rotary_emb = get_rope(
+ qk_rope_head_dim,
+ max_position=max_position_embeddings,
+ rope_parameters=config.rope_parameters,
+ is_neox_style=False,
+ )
+
+ if config.rope_parameters["rope_type"] == "deepseek_yarn":
+ mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
+ scaling_factor = config.rope_parameters["factor"]
+ mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+ self.scaling = self.scaling * mscale * mscale
+
+ mla_modules = MLAModules(
+ kv_a_layernorm=self.kv_a_layernorm,
+ kv_b_proj=self.kv_b_proj,
+ rotary_emb=self.rotary_emb,
+ o_proj=self.o_proj,
+ fused_qkv_a_proj=self.fused_qkv_a_proj
+ if self.q_lora_rank is not None
+ else None,
+ kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
+ if self.q_lora_rank is None
+ else None,
+ q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None,
+ q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
+ q_proj=self.q_proj if self.q_lora_rank is None else None,
+ indexer=None,
+ indexer_rotary_emb=None,
+ is_sparse=False,
+ topk_indices_buffer=topk_indices_buffer,
+ )
+
+ self.mla_attn = MultiHeadLatentAttentionWrapper(
+ self.hidden_size,
+ self.num_local_heads,
+ self.scaling,
+ self.qk_nope_head_dim,
+ self.qk_rope_head_dim,
+ self.v_head_dim,
+ self.q_lora_rank,
+ self.kv_lora_rank,
+ mla_modules,
+ cache_config,
+ quant_config,
+ prefix,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ llama_4_scaling: torch.Tensor | None,
+ ) -> torch.Tensor:
+ return self.mla_attn(positions, hidden_states, llama_4_scaling)
+
+
+class AXK1DecoderLayer(nn.Module):
+ def __init__(
+ self,
+ vllm_config: VllmConfig,
+ prefix: str,
+ config: AXK1Config | None = None,
+ ) -> None:
+ super().__init__()
+
+ if config is None:
+ config = vllm_config.model_config.hf_config
+ model_config = vllm_config.model_config
+ cache_config = vllm_config.cache_config
+ quant_config = vllm_config.quant_config
+ parallel_config = vllm_config.parallel_config
+ self.config = config
+
+ self.hidden_size = config.hidden_size
+ max_position_embeddings = config.max_position_embeddings
+ # DecoderLayers are created with `make_layers` which passes the prefix
+ # with the layer's index.
+ layer_idx = int(prefix.split(sep=".")[-1])
+ self.layer_idx = layer_idx
+
+ # verify MLA attention specific fields
+ qk_nope_head_dim = config.qk_nope_head_dim
+ qk_rope_head_dim = config.qk_rope_head_dim
+ v_head_dim = config.v_head_dim
+ kv_lora_rank = config.kv_lora_rank
+ use_mha = all(dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim))
+ self.use_mha = use_mha
+
+ if use_mha:
+ attn_cls = DeepseekAttention
+ elif model_config.use_mla:
+ attn_cls = AXK1MLAAttention
+ else:
+ attn_cls = AXK1Attention
+ self.self_attn = attn_cls(
+ vllm_config=vllm_config,
+ config=config,
+ hidden_size=self.hidden_size,
+ num_heads=config.num_attention_heads,
+ qk_nope_head_dim=qk_nope_head_dim,
+ qk_rope_head_dim=qk_rope_head_dim,
+ v_head_dim=v_head_dim,
+ q_lora_rank=config.q_lora_rank,
+ kv_lora_rank=kv_lora_rank,
+ max_position_embeddings=max_position_embeddings,
+ cache_config=cache_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.self_attn",
+ topk_indices_buffer=None,
+ )
+
+ self.is_layer_sparse = self._is_layer_sparse()
+ if self.is_layer_sparse:
+ self.mlp = AXK1MoE(
+ config=config,
+ parallel_config=parallel_config,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ else:
+ self.mlp = AXK1MLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp",
+ )
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(
+ config.hidden_size, eps=config.rms_norm_eps
+ )
+ self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.routed_scaling_factor = config.routed_scaling_factor
+
+ def _is_layer_sparse(self) -> bool:
+ return (
+ self.config.n_routed_experts is not None
+ and self.layer_idx >= self.config.first_k_dense_replace
+ and self.layer_idx % self.config.moe_layer_freq == 0
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ residual: torch.Tensor | None,
+ llama_4_scaling: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ # Self Attention
+ if residual is None:
+ residual = hidden_states.clone()
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ attn_kwargs = {
+ "positions": positions,
+ "hidden_states": hidden_states,
+ }
+ if not self.use_mha:
+ attn_kwargs["llama_4_scaling"] = llama_4_scaling
+ hidden_states = self.self_attn(**attn_kwargs)
+
+ if (
+ not isinstance(self.self_attn, DeepseekAttention)
+ and hidden_states.dtype == torch.float16
+ ):
+ # Fix FP16 overflow
+ # We scale both hidden_states and residual before
+ # rmsnorm, and rmsnorm result would not affect by scale.
+ hidden_states *= 1.0 / self.routed_scaling_factor
+ if self.layer_idx == 0:
+ # The residual is shared by all layers, we only scale it on
+ # first layer.
+ residual *= 1.0 / self.routed_scaling_factor
+
+ # Fully Connected
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+ hidden_states = self.mlp(hidden_states)
+
+ if self.is_layer_sparse:
+ hidden_states = self.post_mlp_layernorm(hidden_states)
+
+ if isinstance(self.mlp, AXK1MLP) and hidden_states.dtype == torch.float16:
+ # Fix FP16 overflow
+ # Scaling the AXK1MLP output, it is the input of
+ # input_layernorm of next decoder layer.
+ # The scaling of AXK1MOE output would be done in the forward
+ # of AXK1MOE
+ hidden_states *= 1.0 / self.routed_scaling_factor
+
+ return hidden_states, residual
+
+
+@support_torch_compile
+class AXK1Model(nn.Module):
+ fall_back_to_pt_during_load = False
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ config: AXK1Config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.device = current_platform.device_type
+ self.vocab_size = config.vocab_size
+
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens",
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ config.num_hidden_layers,
+ lambda prefix: AXK1DecoderLayer(vllm_config, prefix),
+ prefix=f"{prefix}.layers",
+ )
+
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states", "residual"], config.hidden_size
+ )
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.embed_tokens(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.embed_input_ids(input_ids)
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ # Compute llama 4 scaling once per forward pass if enabled
+ llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
+ llama_4_scaling: torch.Tensor | None
+ if llama_4_scaling_config is not None:
+ llama_4_scaling = _get_llama_4_scaling(
+ original_max_position_embeddings=llama_4_scaling_config[
+ "original_max_position_embeddings"
+ ],
+ scaling_beta=llama_4_scaling_config["beta"],
+ positions=positions,
+ )
+ else:
+ llama_4_scaling = None
+
+ for layer in islice(self.layers, self.start_layer, self.end_layer):
+ hidden_states, residual = layer(
+ positions, hidden_states, residual, llama_4_scaling
+ )
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors(
+ {"hidden_states": hidden_states, "residual": residual}
+ )
+
+ hidden_states, _ = self.norm(hidden_states, residual)
+ return hidden_states
+
+
+class AXK1MixtureOfExperts(MixtureOfExperts):
+ moe_mlp_layers: list[AXK1MoE]
+ """
+ List of MoE MLP layers in the model.
+ """
+
+ def extract_moe_parameters(self, example_moe: AXK1MoE | None):
+ if example_moe is None:
+ self.num_moe_layers = 0
+ self.num_expert_groups = 0
+ self.num_logical_experts = 0
+ self.num_physical_experts = 0
+ self.num_local_physical_experts = 0
+ self.num_routed_experts = 0
+ self.num_shared_experts = 0
+ self.num_redundant_experts = 0
+ logger.warning("AXK1: No AXK1MoE layer found in model.layers.")
+ else:
+ self.num_logical_experts = example_moe.n_logical_experts
+ self.num_physical_experts = example_moe.n_physical_experts
+ self.num_local_physical_experts = example_moe.n_local_physical_experts
+ self.num_routed_experts = example_moe.n_routed_experts
+ self.num_shared_experts = example_moe.n_shared_experts
+ self.num_redundant_experts = example_moe.n_redundant_experts
+
+ def update_physical_experts_metadata(
+ self,
+ num_physical_experts: int,
+ num_local_physical_experts: int,
+ ) -> None:
+ assert self.num_local_physical_experts == num_local_physical_experts
+ self.num_physical_experts = num_physical_experts
+ self.num_local_physical_experts = num_local_physical_experts
+ self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+ for moe in self.moe_mlp_layers:
+ moe.n_local_physical_experts = num_local_physical_experts
+ moe.n_physical_experts = num_physical_experts
+ moe.n_redundant_experts = self.num_redundant_experts
+ moe.experts.update_expert_map()
+
+
+class AXK1ForCausalLM(
+ nn.Module, SupportsPP, AXK1MixtureOfExperts, SupportsLoRA, SupportsEagle
+):
+ packed_modules_mapping = {
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+ model_cls = AXK1Model
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: AXK1Config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+
+ qk_nope_head_dim = config.qk_nope_head_dim
+ qk_rope_head_dim = config.qk_rope_head_dim
+ self.use_mha = all(dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim))
+
+ if self.use_mha:
+ self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
+
+ # `packed_modules_mapping` needs to be modified before
+ # initializing AXK1Model, as it is passed inplace to
+ # quantization config init and may be used to select the
+ # quant_method for relevant layers during initialization.
+ self.fuse_qkv_a_proj = config.q_lora_rank is not None
+ if self.fuse_qkv_a_proj:
+ self.packed_modules_mapping["fused_qkv_a_proj"] = [
+ "q_a_proj",
+ "kv_a_proj_with_mqa",
+ ]
+
+ self.model = self.model_cls(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+ )
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ else:
+ self.lm_head = PPMissingLayer()
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors
+ )
+ # Set MoE hyperparameters
+ self.num_moe_layers = (
+ self.config.num_hidden_layers - self.config.first_k_dense_replace
+ )
+ self.set_moe_parameters()
+
+ def set_moe_parameters(self):
+ self.expert_weights = []
+
+ self.num_expert_groups = getattr(self.config, "n_group", 1)
+
+ self.moe_layers = []
+ self.moe_mlp_layers = []
+ example_moe = None
+ for layer in self.model.layers:
+ if isinstance(layer, PPMissingLayer):
+ continue
+
+ assert isinstance(layer, AXK1DecoderLayer)
+ if isinstance(layer.mlp, AXK1MoE):
+                # Pick the last MoE layer since the first ones may be dense layers.
+ example_moe = layer.mlp
+ self.moe_mlp_layers.append(layer.mlp)
+ self.moe_layers.append(layer.mlp.experts)
+
+ self.extract_moe_parameters(example_moe)
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ hidden_states = self.model(
+ input_ids, positions, intermediate_tensors, inputs_embeds
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ logits = self.logits_processor(self.lm_head, hidden_states)
+ return logits
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ return SharedFusedMoE.make_expert_params_mapping(
+ self,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts,
+ num_redundant_experts=0,
+ )
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ rocm_aiter_moe_shared_expert_enabled = (
+ rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
+ )
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+ mla_params_mapping = [
+ ("fused_qkv_a_proj", "q_a_proj", 0),
+ ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
+ ]
+ mha_params_mapping = [
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ]
+ if self.use_mha:
+ stacked_params_mapping.extend(mha_params_mapping)
+ else:
+ stacked_params_mapping.extend(mla_params_mapping)
+
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
+ self,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts
+ + (
+ self.config.n_shared_experts
+ if rocm_aiter_moe_shared_expert_enabled
+ else 0
+ ),
+ num_redundant_experts=self.num_redundant_experts,
+ )
+
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if spec_layer is not None:
+ continue # skip spec decode layers for main model
+
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
+ )
+
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
+ if is_fusion_moe_shared_experts_layer:
+ continue
+ name_mapped = name.replace(weight_name, param_name)
+
+ # QKV fusion is optional, fall back to normal
+ # weight loading if it's not enabled
+ # if go with fusion option, then update name
+ if (
+ param_name == "fused_qkv_a_proj"
+ ) and name_mapped not in params_dict:
+ continue
+ else:
+ name = name_mapped
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ is_expert_weight = False
+
+ # Special handling: when AITER fusion_shared_experts is enabled,
+ # checkpoints may provide a single widened shared_experts tensor
+ # without explicit expert indices
+ # (e.g. ...mlp.shared_experts.gate_proj.weight).
+ # For models with multiple shared experts, split that tensor
+ # evenly into per-shared-expert slices and load them into
+ # appended expert slots mlp.experts.{n_routed_experts + j}.*
+ # accordingly.
+ num_chunks = 1
+ if is_fusion_moe_shared_experts_layer:
+ num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+ # Determine split axis based on op type
+ # gate/up: ColumnParallel → split along dim 0
+ # down: RowParallel → split along dim 1
+ split_dim = (
+ 1
+ if ("down_proj.weight" in name and loaded_weight.ndim > 1)
+ else 0
+ )
+ total = loaded_weight.shape[split_dim]
+ assert total % num_chunks == 0, (
+ f"Shared expert weight dim {total} "
+ f"not divisible by num_chunks {num_chunks}"
+ )
+ chunk_size = total // num_chunks
+
+ for j in range(num_chunks):
+ chunk_name = name
+ weight_to_load = loaded_weight
+
+ if is_fusion_moe_shared_experts_layer:
+ chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
+ if loaded_weight.ndim == 1:
+ weight_to_load = loaded_weight[chunk_slice]
+ elif split_dim == 0:
+ weight_to_load = loaded_weight[chunk_slice, :]
+ else:
+ weight_to_load = loaded_weight[:, chunk_slice]
+ # Synthesize an expert-style name so expert mapping
+ # can route it
+ chunk_name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts + j}",
+ )
+
+ # Use expert_params_mapping to locate the destination
+ # param and delegate to its expert-aware weight_loader
+ # with expert_id.
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in chunk_name:
+ continue
+
+                        # This is an expert weight; it must not be loaded
+                        # again as a regular weight below.
+ is_expert_weight = True
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = chunk_name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or
+ # not here since otherwise we may skip experts with
+ # other available replicas.
+ weight_loader = typing.cast(
+ Callable[..., bool], param.weight_loader
+ )
+ success = weight_loader(
+ param,
+ weight_to_load,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ if not is_fusion_moe_shared_experts_layer:
+ name = name_mapped
+ else:
+ loaded_params.add(name_mapped)
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ if not is_fusion_moe_shared_experts_layer:
+ loaded_params.add(name)
+
+ return loaded_params
+
+
+def get_spec_layer_idx_from_weight_name(
+ config: AXK1Config, weight_name: str
+) -> int | None:
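+    """Map a checkpoint weight name to its MTP (next-token-prediction) layer.
+
+    The speculative layers are appended after the num_hidden_layers decoder
+    layers, so a weight named "model.layers.{num_hidden_layers + i}..." belongs
+    to speculative layer num_hidden_layers + i. Returns None for all other
+    weights or when no MTP layers are configured.
+    """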
+ if config.num_nextn_predict_layers and config.num_nextn_predict_layers > 0:
+ layer_idx = config.num_hidden_layers
+ for i in range(config.num_nextn_predict_layers):
+ if weight_name.startswith(f"model.layers.{layer_idx + i}."):
+ return layer_idx + i
+ return None
diff --git a/vllm/model_executor/models/bailing_moe_linear.py b/vllm/model_executor/models/bailing_moe_linear.py
new file mode 100644
index 0000000..9b54ec6
--- /dev/null
+++ b/vllm/model_executor/models/bailing_moe_linear.py
@@ -0,0 +1,1246 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.configuration_utils import PretrainedConfig
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (
+ get_pp_group,
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
+from vllm.forward_context import get_forward_context
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fla.ops.layernorm_guard import (
+ RMSNormGated,
+ layernorm_fn,
+)
+from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+ ColumnParallelLinear,
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.linear_attn import (
+ MiniMaxText01LinearAttention,
+ MiniMaxText01LinearKernel,
+ MiniMaxText01RMSNormTP,
+ clear_linear_attention_cache_for_new_sequences,
+ linear_attention_decode,
+ linear_attention_prefill_and_mix,
+)
+from vllm.model_executor.layers.mamba.mamba_utils import (
+ MambaStateCopyFuncCalculator,
+ MambaStateDtypeCalculator,
+ MambaStateShapeCalculator,
+)
+from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+ default_weight_loader,
+ maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.bailing_moe import BailingMLP
+from vllm.sequence import IntermediateTensors
+from vllm.v1.attention.backend import AttentionMetadata
+from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
+
+from .interfaces import HasInnerState, IsHybrid, SupportsPP
+from .utils import (
+ AutoWeightsLoader,
+ PPMissingLayer,
+ is_pp_missing_parameter,
+ make_layers,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+def is_linear_layer(layer_idx, layer_group_size):
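+    """Return True if the layer at layer_idx uses linear attention.
+
+    Every layer_group_size-th layer is a full-attention layer; the rest use
+    linear attention. For example, with layer_group_size == 4, layers 0-2 use
+    linear attention and layer 3 uses full attention. A non-positive
+    layer_group_size disables linear attention entirely.
+    """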
+ if layer_idx is None:
+ return False
+ if layer_group_size > 0:
+ return (layer_idx + 1) % layer_group_size != 0
+ else:
+ return False
+
+
+def _build_rope_parameters(config: PretrainedConfig) -> dict | None:
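+    """Collect RoPE settings from the HF config into a single dict.
+
+    Merges config.rope_parameters with any legacy top-level rope_theta,
+    partial_rotary_factor and rope_scaling (renaming the old "type" key to
+    "rope_type") so that get_rope() receives one mapping; returns None when
+    the config defines none of these.
+    """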
+ rope_parameters = copy.deepcopy(getattr(config, "rope_parameters", None)) or {}
+ if "rope_theta" not in rope_parameters and hasattr(config, "rope_theta"):
+ rope_parameters["rope_theta"] = config.rope_theta
+ if "partial_rotary_factor" not in rope_parameters and hasattr(
+ config, "partial_rotary_factor"
+ ):
+ rope_parameters["partial_rotary_factor"] = config.partial_rotary_factor
+
+ rope_scaling = getattr(config, "rope_scaling", None)
+ if isinstance(rope_scaling, dict):
+ rope_scaling = copy.deepcopy(rope_scaling)
+ if "type" in rope_scaling and "rope_type" not in rope_scaling:
+ rope_scaling["rope_type"] = rope_scaling.pop("type")
+ rope_parameters.update(rope_scaling)
+
+ return rope_parameters or None
+
+
+class BailingMoeV25MLAAttention(nn.Module):
+ """
+ MLA Attention for BailingMoeV2.5 full attention layers.
+ """
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: QuantizationConfig | None = None,
+ layer_id: int = 0,
+ prefix: str = "attention",
+ cache_config: CacheConfig | None = None,
+ ) -> None:
+ super().__init__()
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.layer_id = layer_id
+ self.prefix = prefix
+
+ # MLA dimensions
+ self.qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 128)
+ self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 64)
+ self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim
+ self.v_head_dim = getattr(config, "v_head_dim", 128)
+
+ # LoRA ranks
+ self.q_lora_rank = getattr(config, "q_lora_rank", None)
+ self.kv_lora_rank = getattr(config, "kv_lora_rank", 512)
+
+ tp_size = get_tensor_model_parallel_world_size()
+ assert self.num_heads % tp_size == 0
+ self.num_local_heads = self.num_heads // tp_size
+
+ self.scaling = self.qk_head_dim**-0.5
+
+ # KV projections
+ self.kv_a_layernorm = RMSNorm(
+ self.kv_lora_rank,
+ eps=config.rms_norm_eps,
+ )
+ self.kv_b_proj = ColumnParallelLinear(
+ self.kv_lora_rank,
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_b_proj",
+ )
+
+ # Output projection
+ self.o_proj = RowParallelLinear(
+ self.num_heads * self.v_head_dim,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.o_proj",
+ )
+
+ if self.q_lora_rank is not None:
+ # Use fused_qkv_a_proj when q_lora_rank is set
+ self.fused_qkv_a_proj = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.fused_qkv_a_proj",
+ disable_tp=True,
+ )
+ self.q_a_layernorm = RMSNorm(
+ self.q_lora_rank,
+ eps=config.rms_norm_eps,
+ )
+ self.q_b_proj = ColumnParallelLinear(
+ self.q_lora_rank,
+ self.num_heads * self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_b_proj",
+ )
+ self.q_proj = None
+ self.kv_a_proj_with_mqa = None
+ else:
+ # Direct projections when no q_lora_rank
+ self.q_proj = ColumnParallelLinear(
+ self.hidden_size,
+ self.num_heads * self.qk_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.q_proj",
+ )
+ self.kv_a_proj_with_mqa = ReplicatedLinear(
+ self.hidden_size,
+ self.kv_lora_rank + self.qk_rope_head_dim,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.kv_a_proj_with_mqa",
+ )
+ self.fused_qkv_a_proj = None
+ self.q_a_layernorm = None
+ self.q_b_proj = None
+
+ rope_parameters = _build_rope_parameters(config)
+ max_position = getattr(config, "max_position_embeddings", 8192)
+ self.rotary_emb = get_rope(
+ head_size=self.qk_rope_head_dim,
+ max_position=max_position,
+ is_neox_style=False,
+ rope_parameters=rope_parameters or None,
+ dtype=torch.float32,
+ )
+
+ # Build MLAModules for MultiHeadLatentAttentionWrapper
+ mla_modules = MLAModules(
+ kv_a_layernorm=self.kv_a_layernorm,
+ kv_b_proj=self.kv_b_proj,
+ rotary_emb=self.rotary_emb,
+ o_proj=self.o_proj,
+ fused_qkv_a_proj=self.fused_qkv_a_proj,
+ kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
+ q_a_layernorm=self.q_a_layernorm,
+ q_b_proj=self.q_b_proj,
+ q_proj=self.q_proj,
+ indexer=None,
+ is_sparse=False,
+ topk_indices_buffer=None,
+ )
+
+ self.mla_attn = MultiHeadLatentAttentionWrapper(
+ self.hidden_size,
+ self.num_local_heads,
+ self.scaling,
+ self.qk_nope_head_dim,
+ self.qk_rope_head_dim,
+ self.v_head_dim,
+ self.q_lora_rank,
+ self.kv_lora_rank,
+ mla_modules,
+ cache_config,
+ quant_config,
+ prefix,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ positions: torch.Tensor,
+ ) -> torch.Tensor:
+ """Forward pass for MLA attention."""
+ return self.mla_attn(positions, hidden_states)
+
+
+class BailingMoEGate(nn.Module):
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ params_dtype: torch.dtype | None = None,
+ prefix: str = "",
+ ):
+ super().__init__()
+ if params_dtype is None:
+ params_dtype = torch.get_default_dtype()
+ self.params_dtype = params_dtype
+ self.weight = nn.Parameter(
+ torch.empty(
+ (config.num_experts, config.hidden_size),
+ dtype=self.params_dtype,
+ ),
+ )
+ if getattr(config, "moe_router_enable_expert_bias", False):
+ self.expert_bias = nn.Parameter(
+ torch.empty((config.num_experts,), dtype=torch.float32),
+ )
+ else:
+ self.expert_bias = None
+
+ def forward(self, hidden_states):
+ logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to(
+ hidden_states.dtype
+ )
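+        # expert_bias is intentionally not added to the logits here;
+        # BailingMoeV25 hands it to SharedFusedMoE as e_score_correction_bias.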
+ return logits
+
+
+class BailingMoeV25(nn.Module):
+ """Bailing MoE v2.5 - standalone implementation for linear attention model."""
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: QuantizationConfig | None = None,
+ layer_id: int = 0,
+ prefix: str = "",
+ ):
+ super().__init__()
+
+ self.layer_id = layer_id
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+ self.num_experts = config.num_experts
+ self.top_k = config.num_experts_per_tok
+ norm_topk_prob = getattr(config, "norm_topk_prob", None)
+ # Ring-2.5 reference implementations normalize routing weights by default.
+ self.norm_expert_prob = True if norm_topk_prob is None else bool(norm_topk_prob)
+ self.hidden_size = config.hidden_size
+ self.quant_config = quant_config
+ self.num_shared_experts = config.num_shared_experts
+ self.score_function = getattr(config, "score_function", None)
+ self.n_group = getattr(config, "n_group", None)
+ self.topk_group = getattr(config, "topk_group", None)
+ self.use_grouped_topk = self.n_group is not None and self.topk_group is not None
+ self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
+
+ router_dtype = getattr(config, "router_dtype", None)
+ if router_dtype is None or router_dtype == "fp32":
+ self.router_dtype = torch.float32
+ else:
+ self.router_dtype = torch.bfloat16
+
+ # Gate for routing
+ self.gate = BailingMoEGate(
+ config=config,
+ params_dtype=self.router_dtype,
+ prefix=f"{prefix}.gate",
+ )
+ correction_bias = (
+ self.gate.expert_bias if self.gate.expert_bias is not None else None
+ )
+ if self.score_function is not None:
+ assert (self.score_function == "softmax" and correction_bias is None) or (
+ self.score_function == "sigmoid" and correction_bias is not None
+ ), (
+ "score_function and correction_bias should be "
+ "(softmax, None) or (sigmoid, not None)"
+ )
+
+ # Shared experts (using BailingMLP)
+ if self.num_shared_experts > 0:
+ if hasattr(config, "moe_shared_expert_intermediate_size"):
+ intermediate_size = config.moe_shared_expert_intermediate_size
+ else:
+ intermediate_size = config.moe_intermediate_size
+ intermediate_size *= config.num_shared_experts
+ self.shared_experts = BailingMLP(
+ intermediate_size=intermediate_size,
+ config=config,
+ quant_config=quant_config,
+ reduce_results=False,
+ prefix=f"{prefix}.shared_experts",
+ )
+ else:
+ self.shared_experts = None
+
+ # Routed experts using SharedFusedMoE
+ self.experts = SharedFusedMoE(
+ shared_experts=self.shared_experts,
+ num_experts=self.num_experts,
+ top_k=self.top_k,
+ hidden_size=self.hidden_size,
+ intermediate_size=config.moe_intermediate_size,
+ reduce_results=False,
+ renormalize=self.norm_expert_prob,
+ quant_config=quant_config,
+ prefix=f"{prefix}.experts",
+ scoring_func=self.score_function,
+ e_score_correction_bias=correction_bias,
+ num_expert_group=self.n_group,
+ topk_group=self.topk_group,
+ use_grouped_topk=self.use_grouped_topk,
+ router_logits_dtype=self.router_dtype,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ num_tokens, hidden_size = hidden_states.shape
+ # Ensure contiguous token-major layout before router/projections.
+ hidden_states = hidden_states.contiguous().view(-1, hidden_size)
+
+ # router_logits: (num_tokens, n_experts)
+ router_logits = self.gate(hidden_states.to(self.router_dtype))
+ router_logits = router_logits.to(hidden_states.dtype)
+
+ final_hidden_states = self.experts(
+ hidden_states=hidden_states, router_logits=router_logits
+ )
+
+ # Handle tuple return from SharedFusedMoE
+ if self.shared_experts is not None:
+ shared_output, final_hidden_states = final_hidden_states
+ else:
+ shared_output = None
+
+ final_hidden_states *= self.routed_scaling_factor
+
+ if shared_output is not None:
+ final_hidden_states = final_hidden_states + shared_output
+
+ if self.tp_size > 1:
+ final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
+ final_hidden_states
+ )
+
+ return final_hidden_states.view(num_tokens, hidden_size)
+
+
+BailingRMSNormTP = MiniMaxText01RMSNormTP
+
+
+class BailingGroupRMSNormGate(RMSNormGated):
+ def __init__(
+ self,
+ hidden_size,
+ eps=1e-5,
+ group_size=None,
+ norm_before_gate=True,
+ device=None,
+ dtype=None,
+ ):
+ super().__init__(
+ hidden_size,
+ eps=eps,
+ group_size=group_size,
+ norm_before_gate=norm_before_gate,
+ device=device,
+ dtype=dtype,
+ activation="sigmoid",
+ )
+ # Add custom weight loader for TP sharding
+ self.weight.weight_loader = self._weight_loader
+
+ @staticmethod
+ def _weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None:
+ """Load weight with TP sharding."""
+ tp_size = get_tensor_model_parallel_world_size()
+ tp_rank = get_tensor_model_parallel_rank()
+ shard_size = loaded_weight.shape[0] // tp_size
+ shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+ param.data.copy_(loaded_weight[shard].contiguous())
+
+
+class BailingMoELinearAttention(nn.Module, MambaBase):
+ """
+ Bailing MoE Linear Attention implementation using minimax backend.
+
+ This implements the linear attention mechanism from sglang, adapted for vLLM's
+ v1 engine with MambaBase interface support.
+ """
+
+ @property
+ def mamba_type(self) -> str:
+ return "linear_attention"
+
+ def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
+ """Return state shape for linear attention cache.
+
+ Must match the calculation in get_mamba_state_shape_from_config.
+ """
+ return MambaStateShapeCalculator.linear_attention_state_shape(
+ num_heads=self.total_num_heads,
+ tp_size=self.tp_size,
+ head_dim=self.head_dim,
+ )
+
+ def get_state_dtype(self) -> tuple[torch.dtype, ...]:
+ """Return state dtype for linear attention cache.
+
+ Must match the calculation in get_mamba_state_dtype_from_config.
+ """
+ return MambaStateDtypeCalculator.linear_attention_state_dtype(
+ self.model_config.dtype,
+ self.cache_config.mamba_cache_dtype,
+ )
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: QuantizationConfig | None = None,
+ layer_id: int = 0,
+ prefix: str = "linear_attn",
+ model_config: ModelConfig | None = None,
+ cache_config: CacheConfig | None = None,
+ ):
+ super().__init__()
+
+ self.layer_id = layer_id
+ self.hidden_size = config.hidden_size
+ self.total_num_heads = config.num_attention_heads
+ self.total_kv_heads = config.num_attention_heads # MHA
+ self.tp_size = get_tensor_model_parallel_world_size()
+ self.tp_rank = get_tensor_model_parallel_rank()
+ self.model_config = model_config
+ self.cache_config = cache_config
+ self.prefix = prefix
+
+ self.head_dim = (
+ config.head_dim
+ if hasattr(config, "head_dim")
+ else config.hidden_size // self.total_num_heads
+ )
+
+ self.hidden_inner_size = self.head_dim * self.total_num_heads
+ self.scaling = self.head_dim**-0.5
+
+ assert self.total_num_heads % self.tp_size == 0
+ self.tp_heads = self.total_num_heads // self.tp_size
+
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = getattr(config, "rope_theta", 600000)
+
+ self.tp_kv_heads = self.total_kv_heads // self.tp_size
+ self.q_size_per_rank = self.head_dim * self.tp_heads
+ self.kv_size_per_rank = self.head_dim * self.tp_kv_heads
+
+ self.use_qk_norm = getattr(config, "use_qk_norm", False)
+ self.linear_backend = "minimax"
+ self.linear_scale = self.linear_backend == "minimax"
+ self.linear_rope = getattr(config, "linear_rope", True)
+ if hasattr(config, "use_linear_silu"):
+ self.linear_silu = config.use_linear_silu
+ elif hasattr(config, "linear_silu"):
+ self.linear_silu = config.linear_silu
+ else:
+ self.linear_silu = False
+
+ # Block size for lightning attention
+ self.BLOCK = getattr(config, "block", 256)
+
+ self.query_key_value = QKVParallelLinear(
+ self.hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ self.total_num_heads, # MHA: kv_heads = num_heads
+ bias=(config.use_bias or config.use_qkv_bias),
+ quant_config=quant_config,
+ prefix=f"{prefix}.query_key_value",
+ )
+
+ if self.use_qk_norm:
+ self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+ self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+ self.g_proj = ColumnParallelLinear(
+ self.hidden_size,
+ self.hidden_inner_size,
+ bias=False,
+ quant_config=quant_config,
+ prefix=f"{prefix}.g_proj",
+ )
+ self.dense = RowParallelLinear(
+ self.hidden_inner_size,
+ self.hidden_size,
+ bias=config.use_bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.dense",
+ reduce_results=True,
+ )
+
+ self.group_norm_size = getattr(config, "group_norm_size", 1)
+ self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5))
+ assert self.tp_size <= self.group_norm_size, (
+ "tp_size must be <= group_norm_size for local rms norm"
+ )
+ assert self.group_norm_size % self.tp_size == 0, (
+ "group_norm_size must be divisible by tp_size"
+ )
+
+ # When group_norm_size == 1, group_size equals hidden_size // tp_size
+ self.g_norm = BailingGroupRMSNormGate(
+ hidden_size=self.hidden_inner_size // self.tp_size,
+ eps=self.rms_norm_eps,
+ group_size=(
+ self.hidden_inner_size // self.group_norm_size
+ if self.group_norm_size > 1
+ else self.hidden_inner_size // self.tp_size
+ ),
+ )
+
+ # use fp32 rotary embedding
+ rope_parameters = _build_rope_parameters(config)
+
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ max_position=self.max_position_embeddings,
+ is_neox_style=True,
+ dtype=torch.float32,
+ rope_parameters=rope_parameters or None,
+ )
+
+ # Build slope tensor for linear attention decay
+ num_hidden_layers = config.num_hidden_layers
+ slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
+ self.total_num_heads
+ )
+ if num_hidden_layers <= 1:
+ self.slope_rate = slope_rate * (1 + 1e-5)
+ else:
+ self.slope_rate = slope_rate * (
+ 1 - layer_id / (num_hidden_layers - 1) + 1e-5
+ )
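+        # Keep only the decay slopes belonging to this TP rank's local heads.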
+ self.tp_slope = self.slope_rate[
+ self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads
+ ].contiguous()
+
+ # Register for compilation
+ compilation_config = get_current_vllm_config().compilation_config
+ if prefix in compilation_config.static_forward_context:
+ raise ValueError(f"Duplicate layer name: {prefix}")
+ compilation_config.static_forward_context[prefix] = self
+
+ @staticmethod
+ def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+ """Load weight for linear attention layers.
+
+ For FP8 quantized parameters, we need to use the weight_loader if available,
+ as it handles special cases like tensor parallelism sharding.
+ """
+ # Check if param has a weight_loader (for vLLM ModelWeightParameter)
+ weight_loader = getattr(param, "weight_loader", None)
+ if weight_loader is not None:
+ # Use the weight_loader which handles TP sharding and quantization
+ weight_loader(param, loaded_weight)
+ else:
+ # Fall back to direct copy for standard tensors
+ assert param.size() == loaded_weight.size(), (
+ f"Shape mismatch: {param.shape} vs {loaded_weight.shape}"
+ )
+ param.data.copy_(loaded_weight)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ output: torch.Tensor,
+ positions: torch.Tensor,
+ ) -> None:
+ """Forward method called by torch.ops.vllm.linear_attention"""
+ torch.ops.vllm.linear_attention(
+ hidden_states,
+ output,
+ positions,
+ self.prefix,
+ )
+
+ def _forward(
+ self,
+ hidden_states: torch.Tensor,
+ output: torch.Tensor,
+ positions: torch.Tensor,
+ ) -> None:
+ """Actual forward implementation."""
+ forward_context = get_forward_context()
+ attn_metadata: AttentionMetadata = forward_context.attn_metadata
+ if attn_metadata is not None:
+ assert isinstance(attn_metadata, dict)
+ attn_metadata = attn_metadata[self.prefix]
+ assert isinstance(attn_metadata, LinearAttentionMetadata)
+ num_actual_tokens = (
+ attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
+ )
+ else:
+ num_actual_tokens = hidden_states.shape[0]
+
+ # QKV projection
+ qkv, _ = self.query_key_value(hidden_states[:num_actual_tokens])
+
+ # use rotary_emb support fp32
+ qkv = qkv.to(torch.float32)
+ if self.linear_silu:
+ qkv = F.silu(qkv)
+
+ # Split q, k, v
+ q, k, v = torch.split(
+ qkv,
+ [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank],
+ dim=-1,
+ )
+
+ # Apply QK norm if needed
+ if self.use_qk_norm:
+ q = q.reshape(-1, self.tp_heads, self.head_dim)
+ k = k.reshape(-1, self.tp_kv_heads, self.head_dim)
+ q = layernorm_fn(
+ q,
+ self.query_layernorm.weight.data,
+ bias=None,
+ eps=self.rms_norm_eps,
+ is_rms_norm=True,
+ )
+ k = layernorm_fn(
+ k,
+ self.key_layernorm.weight.data,
+ bias=None,
+ eps=self.rms_norm_eps,
+ is_rms_norm=True,
+ )
+ q = q.reshape(-1, self.q_size_per_rank)
+ k = k.reshape(-1, self.kv_size_per_rank)
+
+ # Apply rotary embeddings
+ if self.linear_rope:
+ q, k = self.rotary_emb(positions[:num_actual_tokens], q, k)
+
+ # Reshape to [batch, heads, seq_len, head_dim]
+ q = q.view((qkv.shape[0], self.tp_heads, self.head_dim))
+ k = k.view((qkv.shape[0], self.tp_kv_heads, self.head_dim))
+ v = v.view((qkv.shape[0], self.tp_kv_heads, self.head_dim))
+
+ # Apply scaling if using minimax backend
+ if self.linear_scale:
+ q = q * self.scaling
+
+ # Get KV cache and state indices
+ if attn_metadata is not None:
+ kv_cache = self.kv_cache[forward_context.virtual_engine][0]
+ state_indices_tensor = attn_metadata.state_indices_tensor
+ clear_linear_attention_cache_for_new_sequences(
+ kv_cache, state_indices_tensor, attn_metadata
+ )
+
+ # Compute attention
+ decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
+ if attn_metadata is None:
+ hidden = torch.empty(
+ (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype
+ )
+ else:
+ if not decode_only:
+ hidden = self._prefill_and_mix_infer(
+ q, k, v, kv_cache, state_indices_tensor, attn_metadata
+ )
+ else:
+ hidden = self._decode_infer(
+ q, k, v, kv_cache, state_indices_tensor, attn_metadata
+ )
+
+ # Apply group norm and gate (matching SGLang behavior)
+ gate, _ = self.g_proj(hidden_states[:num_actual_tokens])
+
+ if self.group_norm_size > 1:
+ hidden = self.g_norm(hidden, gate)
+ else:
+ hidden = self.g_norm(hidden)
+ hidden = F.sigmoid(gate) * hidden
+
+ hidden = hidden.to(hidden_states.dtype)
+
+ # Output projection
+ dense_out, _ = self.dense(hidden)
+ output[:num_actual_tokens] = dense_out
+
+ def _prefill_and_mix_infer(
+ self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
+ ):
+ """Handle prefill (mixed with decode if any)."""
+ return linear_attention_prefill_and_mix(
+ q=q,
+ k=k,
+ v=v,
+ kv_cache=kv_cache,
+ state_indices_tensor=state_indices_tensor,
+ attn_metadata=attn_metadata,
+ slope_rate=self.tp_slope,
+ block_size=self.BLOCK,
+ decode_fn=self._decode_infer,
+ prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
+ layer_idx=self.layer_id,
+ )
+
+ def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
+ """Handle decode (single token per sequence)."""
+ num_prefill_tokens = attn_metadata.num_prefill_tokens
+ num_prefills = attn_metadata.num_prefills
+ hidden = linear_attention_decode(
+ q,
+ k,
+ v,
+ kv_cache,
+ self.tp_slope,
+ state_indices_tensor,
+ q_start=num_prefill_tokens,
+ q_end=None,
+ slot_start=num_prefills,
+ slot_end=None,
+ block_size=32,
+ )
+ return hidden
+
+
+class BailingMoeV25DecoderLayer(nn.Module):
+ """Decoder layer supporting both linear and full attention."""
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ quant_config: QuantizationConfig | None = None,
+ layer_id: int = 0,
+ prefix: str = "layer",
+ model_config: ModelConfig | None = None,
+ cache_config: CacheConfig | None = None,
+ ) -> None:
+ super().__init__()
+ self.layer_id = layer_id
+ self.hidden_size = config.hidden_size
+
+ # Determine attention type (0 = linear, 1 = full)
+ self.attention_type = getattr(config, "attention_type", 1)
+
+ if self.attention_type == 0: # Linear attention
+ self.self_attn = BailingMoELinearAttention(
+ config,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.self_attn",
+ model_config=model_config,
+ cache_config=cache_config,
+ )
+ else: # Full attention
+ self.self_attn = BailingMoeV25MLAAttention(
+ config,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.self_attn",
+ cache_config=cache_config,
+ )
+
+ # MLP/MoE
+ is_moe_layer = config.num_experts > 1 and layer_id >= getattr(
+ config, "first_k_dense_replace", 0
+ )
+
+ if is_moe_layer:
+ self.mlp = BailingMoeV25(
+ config,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ prefix=f"{prefix}.mlp",
+ )
+ else:
+ self.mlp = BailingMLP(
+ intermediate_size=config.intermediate_size,
+ config=config,
+ quant_config=quant_config,
+ reduce_results=True,
+ prefix=f"{prefix}.mlp",
+ )
+
+ # Layer norms
+ rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5))
+ self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
+ self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ positions: torch.Tensor,
+ attn_metadata: AttentionMetadata | None = None,
+ residual: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
+ # Input layernorm
+ if residual is None:
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ else:
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+ # Self attention
+ if self.attention_type == 0:
+            # Linear attention writes its result into a preallocated output
+            # buffer via the custom linear_attention op.
+ self_attention_output = torch.zeros_like(hidden_states)
+ self.self_attn(
+ hidden_states=hidden_states,
+ output=self_attention_output,
+ positions=positions,
+ )
+ else:
+ # Full attention
+ self_attention_output = self.self_attn(hidden_states, positions)
+
+ hidden_states, residual = self.post_attention_layernorm(
+ self_attention_output, residual
+ )
+ hidden_states = self.mlp(hidden_states)
+ return hidden_states, residual
+
+
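+# dynamic_arg_dims marks which dimension of each forward argument may vary
+# between calls (the token/batch dim 0, and the last dim of positions) so the
+# compiled graph is not specialized to a fixed number of tokens.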
+@support_torch_compile(
+ dynamic_arg_dims={
+ "input_ids": 0,
+ "positions": -1,
+ "intermediate_tensors": 0,
+ "inputs_embeds": 0,
+ }
+)
+class BailingMoeV25Model(nn.Module):
+ """Bailing MoE v2.5 Model with hybrid attention support."""
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ model_config = vllm_config.model_config
+ quant_config = vllm_config.quant_config
+ cache_config = vllm_config.cache_config
+
+ self.config = config
+ self.vocab_size = config.vocab_size
+ self.embed_dim = config.hidden_size
+
+ # Determine layer types based on layer_group_size
+ self.layer_group_size = getattr(config, "layer_group_size", 1)
+ self.num_layers = config.num_hidden_layers
+
+ # decoder_attention_types: 0 = linear, 1 = full
+ self.decoder_attention_types = [
+ 0 if is_linear_layer(i, self.layer_group_size) else 1
+ for i in range(self.num_layers)
+ ]
+
+ # Embeddings
+ if get_pp_group().is_first_rank:
+ self.word_embeddings = VocabParallelEmbedding(
+ self.vocab_size,
+ self.embed_dim,
+ org_num_embeddings=self.vocab_size,
+ )
+ else:
+            self.word_embeddings = PPMissingLayer()
+
+ # Layers
+ def layer_fn(prefix):
+ layer_idx = int(prefix.split(".")[-1])
+ layer_config = copy.deepcopy(config)
+ layer_config.attention_type = self.decoder_attention_types[layer_idx]
+
+ return BailingMoeV25DecoderLayer(
+ config=layer_config,
+ quant_config=quant_config,
+ layer_id=layer_idx,
+ prefix=prefix,
+ model_config=model_config,
+ cache_config=cache_config,
+ )
+
+ self.start_layer, self.end_layer, self.layers = make_layers(
+ self.num_layers, layer_fn, prefix=f"{prefix}.layers"
+ )
+
+ # Final norm
+ norm_kwargs = {}
+ if hasattr(config, "rms_norm_eps"):
+ norm_kwargs["eps"] = config.rms_norm_eps
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, **norm_kwargs)
+ else:
+            self.norm = PPMissingLayer()
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.word_embeddings(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ forward_context = get_forward_context()
+ attn_metadata = forward_context.attn_metadata
+
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is None:
+ hidden_states = self.word_embeddings(input_ids)
+ else:
+ hidden_states = inputs_embeds
+ residual = None
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"]
+ residual = intermediate_tensors["residual"]
+
+ for layer in self.layers[self.start_layer : self.end_layer]:
+ hidden_states, residual = layer(
+ hidden_states=hidden_states,
+ positions=positions,
+ attn_metadata=attn_metadata,
+ residual=residual,
+ )
+
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors(
+ {"hidden_states": hidden_states, "residual": residual}
+ )
+ else:
+ if residual is not None:
+ hidden_states, _ = self.norm(hidden_states, residual)
+ else:
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ """Get expert parameter mapping for MoE layers."""
+ return FusedMoE.make_expert_params_mapping(
+ self,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.num_experts,
+ num_redundant_experts=0,
+ )
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ """Load checkpoint weights with simplified mapping."""
+
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: set[str] = set()
+
+ # Stacked parameter mappings (fused projections)
+ stacked_mappings = [
+ (".fused_qkv_a_proj", ".q_a_proj", 0),
+ (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
+ (".gate_up_proj", ".gate_proj", 0),
+ (".gate_up_proj", ".up_proj", 1),
+ ]
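+        # e.g. "...self_attn.q_a_proj.weight" loads into shard 0 of
+        # "...self_attn.fused_qkv_a_proj.weight", and "...mlp.gate_proj.weight"
+        # into shard 0 of "...mlp.gate_up_proj.weight".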
+
+ # Expert parameter mappings from FusedMoE
+ expert_mappings = list(self.get_expert_mapping())
+
+ def load_param(name: str, tensor: torch.Tensor, shard_id=None) -> bool:
+ """Load a single parameter."""
+ if name not in params_dict or is_pp_missing_parameter(name, self):
+ return False
+ if name.endswith(".bias") and name not in params_dict:
+ return False
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
+ if shard_id is None:
+ weight_loader(param, tensor)
+ elif isinstance(shard_id, int):
+ weight_loader(param, tensor, shard_id)
+ else:
+ # Expert param: (expert_id, shard_id)
+ weight_loader(
+ param, tensor, name, expert_id=shard_id[0], shard_id=shard_id[1]
+ )
+
+ loaded_params.add(name)
+ return True
+
+ def normalize_name(name: str) -> str | None:
+ """Normalize checkpoint name to model parameter name."""
+ # Skip special weights
+ if name.startswith("model.mtp"):
+ return None
+ # Remove 'model.' prefix if present
+ # (e.g., 'model.layers.0...' -> 'layers.0...')
+ name = name.removeprefix("model.")
+ # Map attention.dense based on layer type
+ if "attention.dense" in name:
+ layer_idx = (
+ int(name.split("layers.")[1].split(".")[0])
+ if "layers." in name
+ else 0
+ )
+ attn_name = (
+ "self_attn.dense"
+                    if is_linear_layer(layer_idx, self.layer_group_size)
+ else "self_attn.o_proj"
+ )
+ name = name.replace("attention.dense", attn_name)
+
+ # Standard mappings
+ name = name.replace("attention.", "self_attn.")
+ name = name.replace(
+ "mlp.gate.e_score_correction_bias", "mlp.gate.expert_bias"
+ )
+
+ return maybe_remap_kv_scale_name(name, params_dict)
+
+ for orig_name, weight in weights:
+ norm_name = normalize_name(orig_name)
+ if norm_name is None:
+ continue
+
+ # Try stacked mappings
+ loaded = False
+ for param_suf, weight_suf, shard_id in stacked_mappings:
+ if weight_suf not in norm_name:
+ continue
+ mapped = norm_name.replace(weight_suf, param_suf).replace(
+ "attention.", "self_attn."
+ )
+ if load_param(mapped, weight, shard_id):
+ loaded = True
+ break
+ if loaded:
+ continue
+
+ # Handle expert weights
+ if "mlp.experts" in norm_name:
+ # Expert bias
+ if (
+ "mlp.experts.e_score_correction_bias" in norm_name
+ or "mlp.experts.expert_bias" in norm_name
+ ):
+ alt = norm_name.replace(
+ "mlp.experts.e_score_correction_bias", "mlp.gate.expert_bias"
+ ).replace("mlp.experts.expert_bias", "mlp.gate.expert_bias")
+ if load_param(alt, weight) or load_param(norm_name, weight):
+ continue
+
+ # Routed experts
+ for param_name, weight_name, expert_id, shard_id in expert_mappings:
+ if weight_name not in norm_name:
+ continue
+ mapped = norm_name.replace(weight_name, param_name)
+ if load_param(mapped, weight, (expert_id, shard_id)):
+ break
+ continue
+
+ # General parameters
+ load_param(norm_name, weight)
+
+ return loaded_params
+
+
+class BailingMoeV25ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsPP):
+ """Bailing MoE v2.5 For CausalLM."""
+
+ packed_modules_mapping = {
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+
+ self.config = config
+ self.quant_config = quant_config
+
+ self.model = BailingMoeV25Model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"),
+ )
+
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ )
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ else:
+ self.lm_head = PPMissingLayer()
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor:
+ return self.logits_processor(self.lm_head, hidden_states)
+
+ def make_empty_intermediate_tensors(
+ self, batch_size: int, dtype: torch.dtype, device: torch.device
+ ) -> IntermediateTensors:
+ return IntermediateTensors(
+ {
+ "hidden_states": torch.zeros(
+ (batch_size, self.config.hidden_size), dtype=dtype, device=device
+ ),
+ "residual": torch.zeros(
+ (batch_size, self.config.hidden_size), dtype=dtype, device=device
+ ),
+ }
+ )
+
+ @classmethod
+ def get_mamba_state_shape_from_config(
+ cls,
+ vllm_config: VllmConfig,
+ ) -> tuple[tuple[int, ...], ...]:
+ """Calculate shape for linear attention cache."""
+ config = vllm_config.model_config.hf_config
+ tp_size = vllm_config.parallel_config.tensor_parallel_size
+
+ head_dim = getattr(
+ config, "head_dim", config.hidden_size // config.num_attention_heads
+ )
+
+ # Return base state shape from linear attention (no padding)
+ return MambaStateShapeCalculator.linear_attention_state_shape(
+ num_heads=config.num_attention_heads,
+ tp_size=tp_size,
+ head_dim=head_dim,
+ )
+
+ @classmethod
+ def get_mamba_state_dtype_from_config(
+ cls,
+ vllm_config: VllmConfig,
+ ) -> tuple[torch.dtype, ...]:
+ return MambaStateDtypeCalculator.linear_attention_state_dtype(
+ vllm_config.model_config.dtype,
+ vllm_config.cache_config.mamba_cache_dtype,
+ )
+
+ @classmethod
+ def get_mamba_state_copy_func(cls) -> tuple:
+ return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights)
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ return self.model.get_expert_mapping()
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index 27cf3a7..1f326ce 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -112,6 +112,42 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig):
model_config.pooler_config.seq_pooling_type = pooling_type
+class LlamaNemotronVLConfig(VerifyAndUpdateConfig):
+ """Config handler for LlamaNemotronVL embedding models."""
+
+ @staticmethod
+ def verify_and_update_model_config(model_config: "ModelConfig") -> None:
+ from vllm.config.pooler import SequencePoolingType
+
+ hf_config = model_config.hf_config
+
+ # Set bidirectional attention on the language model config
+ hf_config.is_causal = False
+ if hasattr(hf_config, "llm_config"):
+ hf_config.llm_config.is_causal = False
+
+ if hasattr(hf_config, "vision_config"):
+ hf_config.patch_size = hf_config.vision_config.patch_size
+
+ # Set up pooling type
+ pooling_type_map: dict[str, SequencePoolingType] = {
+ "avg": "MEAN",
+ "cls": "CLS",
+ "last": "LAST",
+ }
+
+ # Get pooling type from config (check both top-level and llm_config)
+ pooling = getattr(hf_config, "pooling", None)
+ if pooling is None and hasattr(hf_config, "llm_config"):
+ pooling = getattr(hf_config.llm_config, "pooling", "avg")
+
+ pooling_type = pooling_type_map.get(pooling)
+ if pooling_type is None:
+            raise ValueError(f"pooling type {pooling!r} is not supported")
+
+ model_config.pooler_config.seq_pooling_type = pooling_type
+
+
class NomicBertModelConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
@@ -177,7 +213,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig):
"Nomic context extension is disabled. "
"Changing max_model_len from %s to %s. "
"To enable context extension, see: "
- "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
+ "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py",
max_model_len_before,
model_config.max_model_len,
)
@@ -293,6 +329,14 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
}
+class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig):
+ @staticmethod
+ def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+ # Ernie4.5-VL conditionally executes text/vision MoE branches, so
+ # fast_moe_cold_start can silently produce incorrect execution order.
+ vllm_config.compilation_config.fast_moe_cold_start = False
+
+
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -553,7 +597,7 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
if cache_config.cache_dtype.startswith("fp8"):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
- if cache_config.cache_dtype == "bfloat16":
+        if cache_config.cache_dtype in ("auto", "bfloat16"):
cache_config.cache_dtype = "auto"
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
@@ -619,11 +663,14 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Gemma3TextModel": Gemma3TextModelConfig,
"LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig,
"LlamaBidirectionalModel": LlamaBidirectionalConfig,
+ "LlamaNemotronVLModel": LlamaNemotronVLConfig,
+ "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig,
"NomicBertModel": NomicBertModelConfig,
"Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig,
"Qwen2ForRewardModel": Qwen2ForRewardModelConfig,
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
"Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig,
+ "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"XLMRobertaModel": JinaRobertaModelConfig,
"ColBERTJinaRobertaModel": JinaRobertaModelConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py
index f7ae426..68c1014 100644
--- a/vllm/model_executor/models/deepencoder.py
+++ b/vllm/model_executor/models/deepencoder.py
@@ -18,6 +18,7 @@ import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPVisionConfig
+from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.attention import MMEncoderAttention
from vllm.model_executor.layers.conv import Conv2dLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -263,9 +264,13 @@ class Block(nn.Module):
return x
-class RelPosAttention(nn.Module):
+# --8<-- [start:rel_pos_attention]
+@PluggableLayer.register("rel_pos_attention")
+class RelPosAttention(PluggableLayer):
"""Multi-head Attention block with relative position embeddings."""
+ # --8<-- [end:rel_pos_attention]
+
def __init__(
self,
dim: int,
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 182828c..8dfb679 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -32,6 +32,7 @@ from .deepseek_v2 import (
DeepseekV2MoE,
get_spec_layer_idx_from_weight_name,
)
+from .interfaces import SupportsPP
from .utils import maybe_prefix
logger = init_logger(__name__)
@@ -180,7 +181,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
@support_torch_compile
-class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
+class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.config = vllm_config.model_config.hf_config
@@ -415,8 +416,141 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
+
+ # Validate that weights were loaded for each expected MTP layer.
+ loaded_layers: set[int] = set()
+ for param_name in loaded_params:
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
+ if spec_layer is not None:
+ loaded_layers.add(spec_layer)
+ for layer_idx in range(
+ self.model.mtp_start_layer_idx,
+ self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
+ ):
+ if layer_idx not in loaded_layers:
+ raise ValueError(
+ f"MTP speculative decoding layer {layer_idx} weights "
+ f"missing from checkpoint. The checkpoint may have "
+ f"been quantized without including the MTP layers. "
+ f"Use a checkpoint that includes MTP layer weights, "
+ f"or disable speculative decoding."
+ )
+
+ # Post-load optimization: fuse q_a_proj and kv_a_proj_with_mqa
+ # into a single GEMM, then monkey-patch forward to forward_opt.
+ # Same logic as DeepseekV2ForCausalLM.load_weights.
+ opt_support_quant_method = [
+ "GGUFLinearMethod", "UnquantizedLinearMethod",
+ "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod",
+ ]
+
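+        # The fusion below concatenates kv_a_proj_with_mqa into q_a_proj /
+        # q_proj so a single GEMM produces both projections. Unquantized and
+        # int8 weights are laid out (out_features, in_features) and are
+        # concatenated along dim 0; AWQ-Marlin keeps the output dimension
+        # last in qweight/scales/qzeros, hence the dim 1 concatenation.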
+ def inject_layer(layer, quant_method, is_mla):
+            logger.info(
+                "DeepSeekMTP optimization: fused q_a_proj and "
+                "kv_a_proj_with_mqa for layer '%s' (quant_method=%s, "
+                "is_mla=%s). Forward replaced with forward_opt.",
+                layer.__class__.__name__,
+                quant_method,
+                is_mla,
+            )
+ q_lora_rank = getattr(layer, "q_lora_rank", None)
+ if quant_method in ["UnquantizedLinearMethod",
+ "CompressedTensorsW8A8Int8"]:
+ if q_lora_rank is not None:
+ layer.q_a_proj.weight.data = torch.cat(
+ [layer.q_a_proj.weight,
+ layer.kv_a_proj_with_mqa.weight], dim=0)
+ if hasattr(layer.q_a_proj, "weight_scale"):
+ layer.q_a_proj.weight_scale.data = torch.cat(
+ [layer.q_a_proj.weight_scale,
+ layer.kv_a_proj_with_mqa.weight_scale], dim=0)
+ del layer.kv_a_proj_with_mqa.weight_scale
+ elif not is_mla:
+ layer.q_proj.weight.data = torch.cat(
+ [layer.q_proj.weight,
+ layer.kv_a_proj_with_mqa.weight], dim=0)
+ if hasattr(layer.q_proj, "weight_scale"):
+ layer.q_proj.weight_scale.data = torch.cat(
+ [layer.q_proj.weight_scale,
+ layer.kv_a_proj_with_mqa.weight_scale], dim=0)
+ del layer.kv_a_proj_with_mqa.weight_scale
+ else:
+ return
+ del layer.kv_a_proj_with_mqa.weight
+ del layer.kv_a_proj_with_mqa
+ if is_mla:
+ layer.mla_attn.forward = layer.mla_attn.forward_opt
+ else:
+ layer.forward = layer.forward_opt
+ elif quant_method == "GGUFLinearMethod":
+ pass
+ elif quant_method == "AWQMarlinLinearMethod":
+ dtype = layer.kv_a_proj_with_mqa.qweight.dtype
+ assert dtype == torch.int32
+ if q_lora_rank is not None:
+ layer.q_a_proj.qweight.data = torch.cat(
+ [layer.q_a_proj.qweight,
+ layer.kv_a_proj_with_mqa.qweight], dim=1)
+ layer.q_a_proj.scales.data = torch.cat(
+ [layer.q_a_proj.scales,
+ layer.kv_a_proj_with_mqa.scales], dim=1)
+ del layer.kv_a_proj_with_mqa.scales
+ layer.q_a_proj.qzeros.data = torch.cat(
+ [layer.q_a_proj.qzeros,
+ layer.kv_a_proj_with_mqa.qzeros], dim=1)
+ del layer.kv_a_proj_with_mqa.qzeros
+ elif not is_mla:
+ layer.q_proj.weight.data = torch.cat(
+ [layer.q_proj.weight,
+ layer.kv_a_proj_with_mqa.weight], dim=1)
+ layer.q_proj.scales.data = torch.cat(
+ [layer.q_proj.scales,
+ layer.kv_a_proj_with_mqa.scales], dim=1)
+ del layer.kv_a_proj_with_mqa.scales
+ layer.q_proj.qzeros.data = torch.cat(
+ [layer.q_proj.qzeros,
+ layer.kv_a_proj_with_mqa.qzeros], dim=1)
+ del layer.kv_a_proj_with_mqa.qzeros
+ else:
+ return
+ del layer.kv_a_proj_with_mqa.qweight
+ del layer.kv_a_proj_with_mqa
+ if is_mla:
+ layer.mla_attn.forward = layer.mla_attn.forward_opt
+ else:
+ layer.forward = layer.forward_opt
+
+ for _, layer in self.model.named_modules():
+ if layer.__class__.__name__ in [
+ "DeepseekV2Attention", "DeepseekV2MLAAttention"
+ ]:
+ if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
+ quant_method = (
+ layer.kv_a_proj_with_mqa.scheme.__class__.__name__)
+ else:
+ quant_method = (
+ layer.kv_a_proj_with_mqa
+ .quant_method.__class__.__name__)
+ if quant_method not in opt_support_quant_method:
+ break
+ inject_layer(
+ layer, quant_method,
+ is_mla=(layer.__class__.__name__
+ == "DeepseekV2MLAAttention"))
+
+ # Check if all parameters have been loaded
+ all_params = set(params_dict.keys())
+ not_loaded = all_params - loaded_params
+ if not_loaded:
+ logger.warning(
+ "DeepSeekMTP weight loading: %d parameters were NOT loaded.\n%s",
+ len(not_loaded),
+ "\n".join(sorted(not_loaded)),
+ )
+ else:
+ logger.info(
+ "DeepSeekMTP weight loading: All %d parameters loaded successfully.",
+ len(all_params),
+ )
+
return loaded_params
+
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
"""
Rewrite the weight name to match the format of the original model.
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index cb2cf19..5a6547f 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -47,7 +47,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.fused_moe import SharedFusedMoE
+from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -75,7 +75,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
+from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend
+from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
)
@@ -89,6 +91,7 @@ from .utils import (
make_layers,
maybe_prefix,
)
+import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
@@ -221,73 +224,6 @@ class DeepseekV2MLP(nn.Module):
return x
-class DeepSeekV2Gate(ReplicatedLinear):
- def __init__(
- self,
- hidden_size: int,
- n_experts: int,
- quant_config: QuantizationConfig | None = None,
- prefix: str = "",
- ):
- assert quant_config is None
- super().__init__(
- hidden_size,
- n_experts,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.gate",
- )
-
- # Unquantized only, will be called "weight".
- assert hasattr(self, "weight")
- is_hopper_or_blackwell = current_platform.is_device_capability(
- (9, 0)
- ) or current_platform.is_device_capability_family(100)
- SUPPORTED_NUM_EXPERTS = [256, 384]
- SUPPORTED_HIDDEN_SIZES = [7168]
-
- self.allow_dsv3_router_gemm = (
- current_platform.is_cuda()
- and is_hopper_or_blackwell
- and n_experts in SUPPORTED_NUM_EXPERTS
- and hidden_size in SUPPORTED_HIDDEN_SIZES
- )
-
- self._out_dtype: torch.dtype | None = None
-
- def set_out_dtype(self, out_dtype: torch.dtype) -> None:
- """
- Set out dtype for the router logits. This is needed after
- __init__, b/c we need to check if the trtllm kernel is
- selected before we decide between bf16 and fp32.
- """
-
- if self._out_dtype is not None:
- raise ValueError("out_dtype has already been set")
- else:
- self._out_dtype = out_dtype
-
- @property
- def out_dtype(self) -> torch.dtype:
- if self._out_dtype is None:
- raise ValueError("out_dtype has not been set yet")
- return self._out_dtype
-
- def forward(
- self,
- x: torch.Tensor,
- ) -> tuple[torch.Tensor, None]:
- """
- Use specialized GEMM for low batch size for DSV3 and KIMI.
- """
- if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
- return ops.dsv3_router_gemm(
- hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
- ), None
- else:
- return super().forward(x)
-
-
class DeepseekV2MoE(nn.Module):
def __init__(
self,
@@ -316,23 +252,12 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
- # self.gate = DeepSeekV2Gate(
- # config.hidden_size,
- # config.n_routed_experts,
- # quant_config=None,
- # prefix=f"{prefix}.gate",
- # )
- self.gate = ReplicatedLinear(
+ self.gate = GateLinear(
config.hidden_size,
config.n_routed_experts,
- bias=False,
- quant_config=None,
prefix=f"{prefix}.gate",
)
if getattr(config, "topk_method", None) == "noaux_tc":
- # self.gate.e_score_correction_bias = nn.Parameter(
- # torch.empty(config.n_routed_experts, dtype=torch.float32)
- # )
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts)
)
@@ -401,12 +326,12 @@ class DeepseekV2MoE(nn.Module):
else None,
)
- # # NOTE(rob): this is a hack until we finish off the PR for
- # # merging TRTLLM kernels into the MK framework. Then we can
- # # query the MonolithicMK for the expected router logits.
- # self.gate.set_out_dtype(
- # torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
- # )
+ # NOTE(rob): this is a hack until we finish off the PR for
+ # merging TRTLLM kernels into the MK framework. Then we can
+ # query the MonolithicMK for the expected router logits.
+ self.gate.set_out_dtype(
+            torch.float32
+            if self.experts.quant_method.is_monolithic
+            else torch.bfloat16
+ )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
@@ -443,11 +368,12 @@ class DeepseekV2MoE(nn.Module):
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
-
+
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
+
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
@@ -596,7 +522,7 @@ class DeepseekV2Attention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
-
+
def forward(
self,
positions: torch.Tensor,
@@ -605,23 +531,20 @@ class DeepseekV2Attention(nn.Module):
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
+            kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=1
+            )
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
- q = self.q_proj(hidden_states)[0].view(
- -1, self.num_local_heads, self.qk_head_dim
- )
- q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
- latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
- kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
- latent_cache = latent_cache.unsqueeze(1)
+ q = self.q_proj(hidden_states)[0]
+            kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=1
+            )
+ q = q.view(-1, self.num_local_heads, self.qk_head_dim)
+
+ _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
kv_a = self.kv_a_layernorm(kv_a)
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
- k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
- k_pe = latent_cache[:, :, self.kv_lora_rank :]
-
- q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+ k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
@@ -671,7 +594,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
def get_attn_backend(self) -> AttentionBackend:
return DeepseekV32IndexerBackend
-
+
class Indexer(nn.Module):
def __init__(
@@ -727,8 +650,8 @@ class Indexer(nn.Module):
# where we store value in fp8 and scale in fp32
# per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache(
- head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
- dtype=torch.uint8,
+ head_dim=self.head_dim,
+ dtype=torch.bfloat16,
prefix=f"{prefix}.k_cache",
cache_config=cache_config,
)
@@ -776,23 +699,61 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
- q = q.view(-1, self.head_dim)
- q_fp8, q_scale = per_token_group_quant_fp8(
- q,
- self.quant_block_size,
- column_major_scales=False,
- use_ue8m0=self.scale_fmt is not None,
- )
- q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
- q_scale = q_scale.view(-1, self.n_head, 1)
+ # q = q.view(-1, self.head_dim)
+ # q_fp8, q_scale = per_token_group_quant_fp8(
+ # q,
+ # self.quant_block_size,
+ # column_major_scales=False,
+ # use_ue8m0=self.scale_fmt is not None,
+ # )
+ # q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
+ # q_scale = q_scale.view(-1, self.n_head, 1)
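+        # With the bfloat16 indexer k_cache above, q is passed to indexer_op
+        # unquantized, so the fp8 per-token-group quantization (left commented
+        # out) and its q_scale factor are skipped.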
weights, _ = self.weights_proj(hidden_states)
weights = (
- weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
+ weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)
- return self.indexer_op(hidden_states, q_fp8, k, weights)
+ return self.indexer_op(hidden_states, q, k, weights)
+
+
+def _min_latency_fused_qkv_a_proj_impl(
+ input_: torch.Tensor,
+ weight: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Dynamically run min-latency gemm if num_tokens <= 16.
+ This must be wrapped in a custom op because our torch.compile integration
+ does not support runtime dispatching on num_tokens.
+ """
+ num_tokens = input_.shape[0]
+ if 0 < num_tokens <= 16:
+ output = torch.empty(
+ num_tokens,
+ weight.shape[0],
+ dtype=torch.bfloat16,
+ device=input_.device,
+ )
+ ops.dsv3_fused_a_gemm(output, input_, weight.T)
+ return output
+ else:
+ return torch.nn.functional.linear(input_, weight)
+
+
+def _min_latency_fused_qkv_a_proj_fake(
+ input_: torch.Tensor,
+ weight: torch.Tensor,
+) -> torch.Tensor:
+ return input_.new_empty(input_.shape[0], weight.shape[0])
+
+
+direct_register_custom_op(
+ op_name="min_latency_fused_qkv_a_proj",
+ op_func=_min_latency_fused_qkv_a_proj_impl,
+ mutates_args=[],
+ fake_impl=_min_latency_fused_qkv_a_proj_fake,
+)
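+
+
+# Call sites reach the registered op as
+# torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, weight); keeping the
+# num_tokens dispatch inside the op lets torch.compile trace a single node
+# while the eager implementation picks the kernel at runtime.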
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
@@ -830,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
self,
input_,
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
- num_tokens = input_.shape[0]
- if self._use_min_latency_gemm and (0 < num_tokens <= 16):
- output = torch.empty(
- num_tokens,
- 2112,
- dtype=torch.bfloat16,
- device=input_.device,
- )
- ops.dsv3_fused_a_gemm(
- output,
- input_,
- self.weight.T,
- )
+ if self._use_min_latency_gemm:
+ output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
@@ -898,47 +848,35 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
- # self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
- # self.hidden_size,
- # [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
- # quant_config=quant_config,
- # prefix=f"{prefix}.fused_qkv_a_proj",
- # )
- self.fused_qkv_a_proj = MergedColumnParallelLinear(
- self.hidden_size,
- [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.fused_qkv_a_proj",
- disable_tp=True,
- )
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_a_proj",
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                self.q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_b_proj",
+            )
else:
- self.kv_a_proj_with_mqa = ReplicatedLinear(
- self.hidden_size,
- self.kv_lora_rank + self.qk_rope_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.kv_a_proj_with_mqa",
- )
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.q_proj",
+            )
- if self.q_lora_rank is not None:
- self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
- self.q_b_proj = ColumnParallelLinear(
- self.q_lora_rank,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.q_b_proj",
- )
- else:
- self.q_proj = ColumnParallelLinear(
- self.hidden_size,
- self.num_heads * self.qk_head_dim,
- bias=False,
- quant_config=quant_config,
- prefix=f"{prefix}.q_proj",
- )
- self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.kv_a_proj_with_mqa",
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1005,9 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
kv_b_proj=self.kv_b_proj,
rotary_emb=self.rotary_emb,
o_proj=self.o_proj,
- fused_qkv_a_proj=self.fused_qkv_a_proj
- if self.q_lora_rank is not None
- else None,
+ q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
if self.q_lora_rank is None
else None,
@@ -1346,14 +1282,14 @@ class DeepseekV2ForCausalLM(
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
- self.fuse_qkv_a_proj = (
- hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
- )
- if self.fuse_qkv_a_proj:
- self.packed_modules_mapping["fused_qkv_a_proj"] = [
- "q_a_proj",
- "kv_a_proj_with_mqa",
- ]
+ # self.fuse_qkv_a_proj = (
+ # hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+ # )
+ # if self.fuse_qkv_a_proj:
+ # self.packed_modules_mapping["fused_qkv_a_proj"] = [
+ # "q_a_proj",
+ # "kv_a_proj_with_mqa",
+ # ]
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
@@ -1385,19 +1321,19 @@ class DeepseekV2ForCausalLM(
self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None
+
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
-
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
-
self.extract_moe_parameters(example_moe)
+
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
@@ -1441,10 +1377,10 @@ class DeepseekV2ForCausalLM(
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
- mla_params_mapping = [
- ("fused_qkv_a_proj", "q_a_proj", 0),
- ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
- ]
+ # mla_params_mapping = [
+ # ("fused_qkv_a_proj", "q_a_proj", 0),
+ # ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
+ # ]
mha_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
@@ -1452,8 +1388,8 @@ class DeepseekV2ForCausalLM(
]
if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
- else:
- stacked_params_mapping.extend(mla_params_mapping)
+ # else:
+ # stacked_params_mapping.extend(mla_params_mapping)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -1474,168 +1410,232 @@ class DeepseekV2ForCausalLM(
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
- if "rotary_emb.inv_freq" in name:
- continue
-
- spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
- if spec_layer is not None:
- continue # skip spec decode layers for main model
-
- is_fusion_moe_shared_experts_layer = (
- rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
- )
-
- for param_name, weight_name, shard_id in stacked_params_mapping:
- # Skip non-stacked layers and experts (experts handled below).
- if weight_name not in name:
- continue
- # We have mlp.experts[0].gate_proj in the checkpoint.
- # Since we handle the experts below in expert_params_mapping,
- # we need to skip here BEFORE we update the name, otherwise
- # name will be updated to mlp.experts[0].gate_up_proj, which
- # will then be updated below in expert_params_mapping
- # for mlp.experts[0].gate_gate_up_proj, which breaks load.
- if ("mlp.experts." in name) and name not in params_dict:
- continue
- if is_fusion_moe_shared_experts_layer:
- continue
- name_mapped = name.replace(weight_name, param_name)
-
- # QKV fusion is optional, fall back to normal
- # weight loading if it's not enabled
- # if go with fusion option, then update name
- if (
- param_name == "fused_qkv_a_proj"
- ) and name_mapped not in params_dict:
- continue
- else:
- name = name_mapped
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
+ try:
+ if "rotary_emb.inv_freq" in name:
continue
- if is_pp_missing_parameter(name, self):
- continue
+ spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
+ if spec_layer is not None:
+ continue # skip spec decode layers for main model
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
- break
- else:
- is_expert_weight = False
-
- # Special handling: when AITER fusion_shared_experts is enabled,
- # checkpoints may provide a single widened shared_experts tensor
- # without explicit expert indices
- # (e.g. ...mlp.shared_experts.gate_proj.weight).
- # For models with multiple shared experts, split that tensor
- # evenly into per-shared-expert slices and load them into
- # appended expert slots mlp.experts.{n_routed_experts + j}.*
- # accordingly.
- num_chunks = 1
- if is_fusion_moe_shared_experts_layer:
- num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
- # Determine split axis based on op type
- # gate/up: ColumnParallel → split along dim 0
- # down: RowParallel → split along dim 1
- split_dim = (
- 1
- if ("down_proj.weight" in name and loaded_weight.ndim > 1)
- else 0
- )
- total = loaded_weight.shape[split_dim]
- assert total % num_chunks == 0, (
- f"Shared expert weight dim {total} "
- f"not divisible by num_chunks {num_chunks}"
- )
- chunk_size = total // num_chunks
-
- for j in range(num_chunks):
- chunk_name = name
- weight_to_load = loaded_weight
+ is_fusion_moe_shared_experts_layer = (
+ rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
+ )
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ # Skip non-stacked layers and experts (experts handled below).
+ if weight_name not in name:
+ continue
+ # We have mlp.experts[0].gate_proj in the checkpoint.
+ # Since we handle the experts below in expert_params_mapping,
+ # we need to skip here BEFORE we update the name, otherwise
+ # name will be updated to mlp.experts[0].gate_up_proj, which
+ # will then be updated below in expert_params_mapping
+ # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+ if ("mlp.experts." in name) and name not in params_dict:
+ continue
if is_fusion_moe_shared_experts_layer:
- chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
- if loaded_weight.ndim == 1:
- weight_to_load = loaded_weight[chunk_slice]
- elif split_dim == 0:
- weight_to_load = loaded_weight[chunk_slice, :]
- else:
- weight_to_load = loaded_weight[:, chunk_slice]
- # Synthesize an expert-style name so expert mapping
- # can route it
- chunk_name = name.replace(
- "mlp.shared_experts",
- f"mlp.experts.{self.config.n_routed_experts + j}",
- )
+ continue
+ name_mapped = name.replace(weight_name, param_name)
- # Use expert_params_mapping to locate the destination
- # param and delegate to its expert-aware weight_loader
- # with expert_id.
- for mapping in expert_params_mapping:
- param_name, weight_name, expert_id, shard_id = mapping
- if weight_name not in chunk_name:
- continue
-
- # Anyway, this is an expert weight and should not be
- # attempted to load as other weights later
- is_expert_weight = True
-
- # Do not modify `name` since the loop may continue here
- # Instead, create a new variable
- name_mapped = chunk_name.replace(weight_name, param_name)
-
- if is_pp_missing_parameter(name_mapped, self):
- continue
-
- param = params_dict[name_mapped]
- # We should ask the weight loader to return success or
- # not here since otherwise we may skip experts with
- # other available replicas.
- weight_loader = typing.cast(
- Callable[..., bool], param.weight_loader
- )
- success = weight_loader(
- param,
- weight_to_load,
- name_mapped,
- shard_id=shard_id,
- expert_id=expert_id,
- return_success=True,
- )
- if success:
- if not is_fusion_moe_shared_experts_layer:
- name = name_mapped
- else:
- loaded_params.add(name_mapped)
- break
+                    # QKV fusion is optional; fall back to normal weight loading
+                    # if it's not enabled. If the fusion path is taken, update
+                    # the name to the fused parameter.
+ if (
+ param_name == "fused_qkv_a_proj"
+ ) and name_mapped not in params_dict:
+ continue
else:
- if is_expert_weight:
- # We've checked that this is an expert weight
- # However it's not mapped locally to this rank
- # So we simply skip it
- continue
+ name = name_mapped
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
- # Skip loading extra bias for GPTQ models.
- if name.endswith(".bias") and name not in params_dict:
- continue
+ if is_pp_missing_parameter(name, self):
+ continue
- # Remapping the name of FP8 kv-scale.
- name = maybe_remap_kv_scale_name(name, params_dict)
- if name is None:
- continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ is_expert_weight = False
- if is_pp_missing_parameter(name, self):
- continue
-
- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
+ # Special handling: when AITER fusion_shared_experts is enabled,
+ # checkpoints may provide a single widened shared_experts tensor
+ # without explicit expert indices
+ # (e.g. ...mlp.shared_experts.gate_proj.weight).
+ # For models with multiple shared experts, split that tensor
+ # evenly into per-shared-expert slices and load them into
+ # appended expert slots mlp.experts.{n_routed_experts + j}.*
+ # accordingly.
+ num_chunks = 1
+ if is_fusion_moe_shared_experts_layer:
+ num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
+ # Determine split axis based on op type
+ # gate/up: ColumnParallel → split along dim 0
+ # down: RowParallel → split along dim 1
+ split_dim = (
+ 1
+ if ("down_proj.weight" in name and loaded_weight.ndim > 1)
+ else 0
)
- weight_loader(param, loaded_weight)
- if name is not None and not is_fusion_moe_shared_experts_layer:
- loaded_params.add(name)
+ total = loaded_weight.shape[split_dim]
+ assert total % num_chunks == 0, (
+ f"Shared expert weight dim {total} "
+ f"not divisible by num_chunks {num_chunks}"
+ )
+ chunk_size = total // num_chunks
+ for j in range(num_chunks):
+ chunk_name = name
+ weight_to_load = loaded_weight
+
+ if is_fusion_moe_shared_experts_layer:
+ chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
+ if loaded_weight.ndim == 1:
+ weight_to_load = loaded_weight[chunk_slice]
+ elif split_dim == 0:
+ weight_to_load = loaded_weight[chunk_slice, :]
+ else:
+ weight_to_load = loaded_weight[:, chunk_slice]
+ # Synthesize an expert-style name so expert mapping
+ # can route it
+ chunk_name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts + j}",
+ )
+
+ # Use expert_params_mapping to locate the destination
+ # param and delegate to its expert-aware weight_loader
+ # with expert_id.
+ for mapping in expert_params_mapping:
+ param_name, weight_name, expert_id, shard_id = mapping
+ if weight_name not in chunk_name:
+ continue
+ # Anyway, this is an expert weight and should not be
+ # attempted to load as other weights later
+ is_expert_weight = True
+
+ # Do not modify `name` since the loop may continue here
+ # Instead, create a new variable
+ name_mapped = chunk_name.replace(weight_name, param_name)
+
+ if is_pp_missing_parameter(name_mapped, self):
+ continue
+
+ param = params_dict[name_mapped]
+ # We should ask the weight loader to return success or
+ # not here since otherwise we may skip experts with
+ # other available replicas.
+ weight_loader = typing.cast(
+ Callable[..., bool], param.weight_loader
+ )
+ success = weight_loader(
+ param,
+ weight_to_load,
+ name_mapped,
+ shard_id=shard_id,
+ expert_id=expert_id,
+ return_success=True,
+ )
+ if success:
+ if not is_fusion_moe_shared_experts_layer:
+ name = name_mapped
+ else:
+ loaded_params.add(name_mapped)
+ break
+ else:
+ if is_expert_weight:
+ # We've checked that this is an expert weight
+ # However it's not mapped locally to this rank
+ # So we simply skip it
+ continue
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+
+ # Remapping the name of FP8 kv-scale.
+ name = maybe_remap_kv_scale_name(name, params_dict)
+ if name is None:
+ continue
+
+ if is_pp_missing_parameter(name, self):
+ continue
+
+ param = params_dict[name]
+ weight_loader = getattr(
+ param, "weight_loader", default_weight_loader
+ )
+ weight_loader(param, loaded_weight)
+ if name is not None and not is_fusion_moe_shared_experts_layer:
+ loaded_params.add(name)
+            except Exception:
+                # Skip weights this overlay cannot load instead of aborting the
+                # whole load.
+                pass
+
+        # Linear quant methods for which the q/kv A-projection fusion below is
+        # supported; add new methods here to extend the optimization.
+        opt_support_quant_method = [
+            "GGUFLinearMethod",
+            "UnquantizedLinearMethod",
+            "CompressedTensorsW8A8Int8",
+            "AWQMarlinLinearMethod",
+        ]
+ def inject_layer(layer, quant_method, is_mla):
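+            # Fuse kv_a_proj_with_mqa into q_a_proj (or q_proj) so a single GEMM
+            # produces both the q projection and the compressed kv latent, then
+            # switch the layer to its forward_opt path, which presumably consumes
+            # the fused projection (assumption based on the concatenation below).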
+ q_lora_rank = getattr(layer, "q_lora_rank", None)
+ if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
+ if q_lora_rank is not None:
+ layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
+ if hasattr(layer.q_a_proj, "weight_scale"):
+ layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
+ del layer.kv_a_proj_with_mqa.weight_scale
+ elif not is_mla:
+ layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
+ if hasattr(layer.q_proj, "weight_scale"):
+ layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
+ del layer.kv_a_proj_with_mqa.weight_scale
+ else:
+ return
+ del layer.kv_a_proj_with_mqa.weight
+ del layer.kv_a_proj_with_mqa
+ if is_mla:
+ layer.mla_attn.forward = layer.mla_attn.forward_opt
+ else:
+ layer.forward = layer.forward_opt
+ elif quant_method == "GGUFLinearMethod":
+ pass
+ elif quant_method == "AWQMarlinLinearMethod":
+ dtype = layer.kv_a_proj_with_mqa.qweight.dtype
+ assert dtype == torch.int32
+                if q_lora_rank is not None:
+ layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
+ layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
+ del layer.kv_a_proj_with_mqa.scales
+ layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
+ del layer.kv_a_proj_with_mqa.qzeros
+ elif not is_mla:
+                    layer.q_proj.qweight.data = torch.cat([layer.q_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
+ layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
+ del layer.kv_a_proj_with_mqa.scales
+ layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
+ del layer.kv_a_proj_with_mqa.qzeros
+ else:
+ return
+
+ del layer.kv_a_proj_with_mqa.qweight
+ del layer.kv_a_proj_with_mqa
+ if is_mla:
+ layer.mla_attn.forward = layer.mla_attn.forward_opt
+ else:
+ layer.forward = layer.forward_opt
+ else:
+ pass
+
+ for _, layer in self.model.named_modules():
+            if layer.__class__.__name__ in [
+                "DeepseekV2Attention",
+                "DeepseekV2MLAAttention",
+            ]:
+                if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
+                    quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
+                else:
+                    quant_method = (
+                        layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
+                    )
+                if quant_method not in opt_support_quant_method:
+                    break
+
+                inject_layer(
+                    layer,
+                    quant_method,
+                    is_mla=layer.__class__.__name__ == "DeepseekV2MLAAttention",
+                )
+
return loaded_params
diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py
index f038cfb..3cfdb4d 100644
--- a/vllm/model_executor/models/ernie45_moe.py
+++ b/vllm/model_executor/models/ernie45_moe.py
@@ -164,7 +164,7 @@ class Ernie4_5_MoeMoE(nn.Module):
config.hidden_size,
config.moe_num_experts,
bias=False,
- params_dtype=torch.float32,
+ # params_dtype=torch.float32,
quant_config=None,
prefix=f"{prefix}.gate",
)
@@ -209,7 +209,7 @@ class Ernie4_5_MoeMoE(nn.Module):
hidden_dim = hidden_states.shape[-1]
hidden_states = hidden_states.view(-1, hidden_dim)
- router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32))
+ router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
@@ -429,7 +429,8 @@ class Ernie4_5_MoeModel(nn.Module):
self.num_redundant_experts = eplb_config.num_redundant_experts
- if get_pp_group().is_first_rank:
+ if get_pp_group().is_first_rank or (config.tie_word_embeddings
+ and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
@@ -653,11 +654,11 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
+ if self.config.tie_word_embeddings:
+ self.lm_head.weight = self.model.embed_tokens.weight
else:
self.lm_head = PPMissingLayer()
- if self.config.tie_word_embeddings:
- self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 1df4adf..edf4c2c 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -829,16 +829,31 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo):
spatial_conv_size = hf_config.spatial_conv_size
temporal_conv_size = hf_config.temporal_conv_size
+ if self.ctx.model_config.trust_remote_code:
+ # Defined in HF Hub repo
+ min_pixels_key = "min_pixels"
+ max_pixels_key = "max_pixels"
+ else:
+ # Defined in Transformers library (requires v5.0 or above)
+ min_pixels_key = "shortest_edge"
+ max_pixels_key = "longest_edge"
+
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
- size = mm_kwargs.get("size", image_processor.size)
+ size = image_processor.size
+ if override_size := mm_kwargs.get("size"):
+ size = size | override_size
+ if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
+ size = size | {min_pixels_key: override_min_pixels}
+ if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
+ size = size | {max_pixels_key: override_max_pixels}
if do_resize:
resized_height, resized_width = smart_resize(
height=image_height,
width=image_width,
factor=patch_size * spatial_conv_size,
- min_pixels=size["min_pixels"],
- max_pixels=size["max_pixels"],
+ min_pixels=size[min_pixels_key],
+ max_pixels=size[max_pixels_key],
)
preprocessed_size = ImageSize(width=resized_width, height=resized_height)
else:
diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py
new file mode 100644
index 0000000..ae9bdb5
--- /dev/null
+++ b/vllm/model_executor/models/extract_hidden_states.py
@@ -0,0 +1,394 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Hidden States Extractor Model.
+
+This model extracts and caches hidden states from the target model
+without performing actual token generation. It's used with the
+extract_hidden_states speculative decoding method.
+"""
+
+from collections.abc import Iterable
+from typing import ClassVar
+
+import torch
+import torch.nn as nn
+
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.config.cache import CacheDType
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.layers.attention.attention import set_default_quant_scales
+from vllm.model_executor.layers.attention.kv_transfer_utils import (
+ maybe_transfer_kv_layer,
+)
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.model_executor.models.utils import maybe_prefix
+from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype
+from vllm.v1.attention.backend import (
+ AttentionBackend,
+ AttentionImpl,
+ AttentionMetadataBuilder,
+ AttentionType,
+ CommonAttentionMetadata,
+ is_quantized_kv_cache,
+)
+from vllm.v1.kv_cache_interface import (
+ AttentionSpec,
+ KVCacheSpec,
+ MLAAttentionSpec,
+)
+
+########## Custom Ops ########
+
+
+def unified_kv_cache_update(
+ to_cache: torch.Tensor,
+ layer_name: str,
+) -> torch.Tensor:
+ """
+    Returns a dummy tensor that is threaded into the attention op so that the
+    cache write and the attention call have an explicit data dependency, which
+    ensures torch.compile preserves their ordering.
+ """
+ forward_context = get_forward_context()
+ attn_layer = forward_context.no_compile_layers[layer_name]
+ kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
+
+ slot_mapping = forward_context.slot_mapping
+ assert isinstance(slot_mapping, dict), (
+ f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
+ )
+ layer_slot_mapping = slot_mapping.get(layer_name)
+ if layer_slot_mapping is not None:
+ assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
+ f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
+ )
+ attn_layer.impl.do_kv_cache_update(
+ attn_layer,
+ to_cache,
+ kv_cache,
+ layer_slot_mapping,
+ )
+
+ return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
+
+
+@maybe_transfer_kv_layer
+def dummy_attention(layer_name, _placeholder):
+ # Note: layer_name arg required by @maybe_transfer_kv_layer
+ return _placeholder
+
+
+def basic_cache(
+    to_cache: torch.Tensor,  # shape: [num_tokens, num_heads, head_size]
+    kv_cache: torch.Tensor,  # shape: [num_blocks, block_size, num_heads, head_size]
+    slot_mapping: torch.Tensor,  # shape: [num_tokens]
+):
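+    # Flatten the paged cache to [num_blocks * block_size, num_heads, head_size]
+    # and scatter each token's entry into the slot given by slot_mapping.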
+ num_blocks, block_size, num_heads, head_size = kv_cache.shape
+ token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size)
+ token_kv_cache[slot_mapping] = to_cache
+
+
+######### CacheOnlyAttentionBackend ########
+
+
+class CacheOnlyAttentionBackend(AttentionBackend):
+ """Attention backend that only caches KV without computing attention."""
+
+ accept_output_buffer: bool = False
+ supported_dtypes: ClassVar[list[torch.dtype]] = [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ]
+ supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
+ "auto",
+ "bfloat16",
+ ]
+ forward_includes_kv_cache_update: bool = False
+
+ @staticmethod
+ def get_name() -> str:
+ return "CACHE_ONLY_ATTN"
+
+ @classmethod
+ def supports_attn_type(cls, attn_type: str) -> bool:
+ return attn_type == AttentionType.DECODER
+
+ @classmethod
+ def supports_mm_prefix(cls) -> bool:
+ return True
+
+ @staticmethod
+ def get_impl_cls() -> type["CacheOnlyAttentionImpl"]:
+ return CacheOnlyAttentionImpl
+
+ @staticmethod
+ def get_kv_cache_shape(
+ num_blocks: int,
+ block_size: int,
+ num_kv_heads: int,
+ head_size: int,
+ cache_dtype_str: str = "auto",
+ ) -> tuple[int, ...]:
+ # We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size`
+ # We also don't use a k/v (2) dim
+ return (num_blocks, block_size, num_kv_heads, head_size)
+
+ @staticmethod
+ def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]:
+ return CacheOnlyAttentionMetadataBuilder
+
+ @staticmethod
+ def use_cascade_attention(*args, **kwargs) -> bool:
+ return False
+
+ @classmethod
+ def get_supported_head_sizes(cls) -> list[int]:
+ return []
+
+
+class CacheOnlyAttentionMetadata:
+ def __init__(self, slot_mapping: torch.Tensor):
+ self.slot_mapping = slot_mapping
+
+
+class CacheOnlyAttentionMetadataBuilder(
+ AttentionMetadataBuilder[CacheOnlyAttentionMetadata]
+):
+ def __init__(
+ self,
+ kv_cache_spec: AttentionSpec,
+ layer_names: list[str],
+ vllm_config: VllmConfig,
+ device: torch.device,
+ ):
+ super().__init__(kv_cache_spec, layer_names, vllm_config, device)
+
+ def build(
+ self,
+ common_prefix_len: int,
+ common_attn_metadata: CommonAttentionMetadata,
+ fast_build: bool = False,
+ ) -> CacheOnlyAttentionMetadata:
+ use_cascade = common_prefix_len > 0
+ if use_cascade:
+ raise NotImplementedError(
+ "Cascade attention not supported by CacheOnlyAttention"
+ )
+ causal = common_attn_metadata.causal
+ if not causal:
+ raise NotImplementedError(
+ "Non-causal attention not supported by CacheOnlyAttention"
+ )
+
+ return CacheOnlyAttentionMetadata(
+ slot_mapping=common_attn_metadata.slot_mapping,
+ )
+
+
+class CacheOnlyAttentionImpl(AttentionImpl):
+ """Attention implementation that only caches KV states."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_size: int,
+ kv_cache_dtype: str,
+ kv_cache_torch_dtype: torch.dtype,
+ attn_type: AttentionType = AttentionType.DECODER,
+ ) -> None:
+ self.num_heads = num_heads
+ self.head_size = head_size
+ self.kv_cache_dtype = kv_cache_dtype
+ self.kv_cache_torch_dtype = kv_cache_torch_dtype
+
+ if attn_type != AttentionType.DECODER:
+ raise NotImplementedError(f"Unsupported attention type: {attn_type}")
+ if is_quantized_kv_cache(kv_cache_dtype):
+ raise NotImplementedError("Quantized KV cache not supported")
+
+ self.num_queries_per_kv = 1
+
+ def do_kv_cache_update(
+ self,
+ layer,
+ to_cache,
+ kv_cache,
+ slot_mapping,
+ ):
+ assert to_cache.dtype == self.kv_cache_torch_dtype, (
+ f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}"
+ )
+ assert kv_cache.dtype == self.kv_cache_torch_dtype, (
+ f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}"
+ )
+
+ basic_cache(to_cache, kv_cache, slot_mapping)
+
+ def forward(self, *args, **kwargs):
+ # Empty implementation of abstract method
+ pass
+
+
+############## CacheOnlyAttentionLayer (replaces Attention) ############
+
+
+class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase):
+ """Attention layer that only caches key/value states without computing attention."""
+
+ def __init__(
+ self,
+ num_heads: int,
+ head_size: int,
+ cache_config: CacheConfig | None = None,
+ prefix: str = "",
+ attn_type: str = AttentionType.DECODER,
+ ):
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.head_size = head_size
+ self.layer_name = prefix
+
+ vllm_config = get_current_vllm_config()
+
+ # KV cache configuration
+ cache_config = cache_config or vllm_config.cache_config
+ if cache_config is not None:
+ kv_cache_dtype = cache_config.cache_dtype
+ self.block_size = cache_config.block_size
+ else:
+ kv_cache_dtype = "auto"
+ self.block_size = 16
+
+ assert kv_cache_dtype in ["auto", "bfloat16", "float16"], (
+            "CacheOnlyAttentionLayer doesn't currently support quantized kv cache, "
+            f"but kv cache dtype was set to {kv_cache_dtype}"
+ )
+ self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
+ kv_cache_dtype, vllm_config.model_config
+ )
+
+ # Initialize KV cache quantization attributes
+ set_default_quant_scales(self, register_buffer=True)
+
+ # Attention backend
+ self.attn_backend = CacheOnlyAttentionBackend
+ impl_cls = self.attn_backend.get_impl_cls()
+ self.impl = impl_cls(
+ num_heads,
+ head_size,
+ kv_cache_dtype,
+ self.kv_cache_torch_dtype,
+ attn_type,
+ )
+
+ assert not self.attn_backend.forward_includes_kv_cache_update, (
+ "KV cache update should be independent of forward"
+ )
+
+ # Placeholder KV cache (replaced by bind_kv_cache)
+ self.kv_cache = [
+ torch.tensor([])
+ for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
+ ]
+
+ # Register in compilation context
+ compilation_config = vllm_config.compilation_config
+ if prefix in compilation_config.static_forward_context:
+ raise ValueError(f"Duplicate layer name: {prefix}")
+ compilation_config.static_forward_context[prefix] = self
+
+ def forward(self, to_cache: torch.Tensor) -> torch.Tensor:
+ """Cache hidden states as KV pairs without computing attention.
+
+ Args:
+ to_cache: The tensor to insert into the kv cache.
+ shape [num_tokens, num_heads, head_size]
+
+ Returns:
+ Dummy output tensor (not used)
+ """
+ # Note: we set num_heads to num_hidden_layers and
+ # head_size to hidden_size for hidden states storage
+ output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype)
+
+ # Note: dummy_out is used to force torch.compile to preserve ordering between
+ # cache update and attention op (which triggers kv_connector transfer)
+ dummy_out = unified_kv_cache_update(to_cache, self.layer_name)
+
+ # Triggers kv_connector transfer via decorator
+ _ = dummy_attention(self.layer_name, dummy_out)
+
+ return output
+
+ def get_attn_backend(self) -> type[AttentionBackend]:
+ return self.attn_backend
+
+ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Note: we use MLAAttentionSpec here because it will
+ # produce page sizes of (block_size * num_kv_heads * head_size * dtype_size)
+ # whereas FullAttentionSpec will add an additional factor of 2
+ return MLAAttentionSpec(
+ block_size=self.block_size,
+ num_kv_heads=self.num_heads,
+ head_size=self.head_size,
+ dtype=self.kv_cache_torch_dtype,
+ )
+
+
+############ ExtractHiddenStatesModel definition ##########
+
+
+class ExtractHiddenStatesModel(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+
+ self.vllm_config = vllm_config
+ self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config
+ self.hidden_size = vllm_config.model_config.get_hidden_size()
+ self.target_num_hidden_layers = (
+ vllm_config.model_config.get_total_num_hidden_layers()
+ )
+ self.num_hidden_states = len(
+ getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", [])
+ )
+
+ cache_config = vllm_config.cache_config
+
+ # Create a single cache-only attention layer
+ # Note: We set num_heads <- self.num_hidden_states
+ # and head_size <- hidden_size so that we can insert
+ # the hidden states directly into the cache without
+ # reshaping
+ self.cache_only_layers = nn.ModuleDict(
+ {
+ str(self.target_num_hidden_layers): CacheOnlyAttentionLayer(
+ num_heads=self.num_hidden_states,
+ head_size=self.hidden_size,
+ cache_config=cache_config,
+ prefix=maybe_prefix(
+ prefix, f"cache_only_layers.{self.target_num_hidden_layers}"
+ ),
+ )
+ }
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> None:
+ """Process and cache hidden states.
+
+ Args:
+ hidden_states: Hidden states from target model
+ shape: [num_tokens, num_hidden_states, hidden_size]
+
+ Returns:
+            None; hidden states are written into the KV cache as a side effect.
+ """
+
+ # Call dummy attention layer to cache hidden states
+ # Output is ignored - we only care about the KV cache side effects
+ _ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states)
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ """No weights to load for this dummy model."""
+ return set()
diff --git a/vllm/model_executor/models/fireredasr2.py b/vllm/model_executor/models/fireredasr2.py
new file mode 100644
index 0000000..f0d3e12
--- /dev/null
+++ b/vllm/model_executor/models/fireredasr2.py
@@ -0,0 +1,829 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import math
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Annotated, Literal, cast
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import (
+ BatchFeature,
+ Qwen2Config,
+)
+
+from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.config.multimodal import BaseDummyOptions
+from vllm.inputs.data import PromptType
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.linear import (
+ ReplicatedLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.models.whisper_utils import (
+ ISO639_1_SUPPORTED_LANGS,
+)
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (
+ MultiModalDataDict,
+ MultiModalFieldConfig,
+ MultiModalKwargsItems,
+)
+from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
+from vllm.multimodal.processing import (
+ BaseDummyInputsBuilder,
+ BaseMultiModalProcessor,
+ BaseProcessingInfo,
+ PromptReplacement,
+ PromptUpdate,
+ PromptUpdateDetails,
+)
+from vllm.transformers_utils.processor import cached_processor_from_config
+from vllm.transformers_utils.processors.fireredasr2_processor import (
+ FireRedASR2FeatureExtractor,
+)
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
+
+from .interfaces import (
+ MultiModalEmbeddings,
+ SupportsMultiModal,
+ SupportsTranscription,
+ _require_is_multimodal,
+)
+from .qwen2 import Qwen2ForCausalLM
+from .utils import (
+ AutoWeightsLoader,
+ WeightsMapper,
+ _merge_multimodal_embeddings,
+ maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+class FireRedASR2AudioInputs(TensorSchema):
+ """
+ Dimensions:
+ - b: Batch size
+ - nmb: Number of mel bins
+ - t: Time frames (M)
+ """
+
+ input_features: Annotated[
+ list[torch.Tensor] | None,
+ TensorShape("b", "nmb", "t"),
+ ]
+ speech_lengths: Annotated[
+ list[torch.Tensor] | None,
+ TensorShape("b"),
+ ]
+ fake_token_lengths: Annotated[
+ list[torch.Tensor] | None,
+ TensorShape("b"),
+ ]
+
+
+class Swish(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x * torch.sigmoid(x)
+
+
+class Conv2dSubsampling(nn.Module):
+ def __init__(self, idim: int, d_model: int, out_channels: int = 32):
+ super().__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(1, out_channels, 3, 2),
+ nn.ReLU(),
+ nn.Conv2d(out_channels, out_channels, 3, 2),
+ nn.ReLU(),
+ )
+        # feature width after the two stride-2, kernel-3 convs above
+        subsample_idim = ((idim - 1) // 2 - 1) // 2
+ self.out = ReplicatedLinear(
+ input_size=out_channels * subsample_idim,
+ output_size=d_model,
+ bias=True,
+ )
+
+ self.subsampling = 4
+        left_context = right_context = 3  # both exclude current frame
+ self.context = left_context + 1 + right_context # 7
+
+ def forward(
+ self, x: torch.Tensor, x_mask: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ x = x.unsqueeze(1)
+ x = self.conv(x)
+ N, C, T, D = x.size()
+ x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
+        # Subsample the mask the same way the two stride-2 convs shrink time.
+        mask = x_mask[:, :, :-2:2][:, :, :-2:2]
+ input_lengths = mask[:, -1, :].sum(dim=-1)
+ return x, input_lengths, mask
+
+
+class RelPositionalEncoding(nn.Module):
+ def __init__(self, d_model: int, max_len: int = 5000):
+ super().__init__()
+ pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
+ pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(
+ torch.arange(0, d_model, 2).float()
+ * -(torch.log(torch.tensor(10000.0)).item() / d_model)
+ )
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+ pe_negative = pe_negative[1:].unsqueeze(0)
+ self.pe = torch.cat([pe_positive, pe_negative], dim=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
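+        # Slice a (2*T - 1)-long window of encodings centred on relative
+        # position 0, which sits at index Tmax // 2 of the concatenated buffer.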
+ # Tmax = 2 * max_len - 1
+ Tmax, T = self.pe.size(1), x.size(1)
+ pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
+ return pos_emb
+
+
+class ConformerFeedForward(nn.Module):
+ def __init__(self, d_model: int):
+ super().__init__()
+ self.pre_layer_norm = nn.LayerNorm(d_model)
+ self.linear_expand = ReplicatedLinear(
+ input_size=d_model,
+ output_size=d_model * 4,
+ bias=True,
+ )
+ self.nonlinear = Swish()
+ self.linear_project = ReplicatedLinear(
+ input_size=d_model * 4,
+ output_size=d_model,
+ bias=True,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ residual = x
+ x = self.pre_layer_norm(x)
+ x, _ = self.linear_expand(x)
+ x = self.nonlinear(x)
+ x, _ = self.linear_project(x)
+ output = x + residual
+ return output
+
+
+class EncoderMultiHeadAttention(nn.Module):
+ def __init__(self, n_head: int, d_model: int):
+ super().__init__()
+ assert d_model % n_head == 0
+ self.n_head = n_head
+ self.d_k = d_model // n_head
+ self.d_v = self.d_k
+
+ self.w_qs = ReplicatedLinear(
+ input_size=d_model, output_size=n_head * self.d_k, bias=False
+ )
+ self.w_ks = ReplicatedLinear(
+ input_size=d_model, output_size=n_head * self.d_k, bias=False
+ )
+ self.w_vs = ReplicatedLinear(
+ input_size=d_model, output_size=n_head * self.d_v, bias=False
+ )
+
+ self.layer_norm_q = nn.LayerNorm(d_model)
+ self.layer_norm_k = nn.LayerNorm(d_model)
+ self.layer_norm_v = nn.LayerNorm(d_model)
+
+ self.fc = ReplicatedLinear(
+ input_size=n_head * self.d_v, output_size=d_model, bias=False
+ )
+
+ def forward_qkv(
+ self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+ q = self.layer_norm_q(q)
+ k = self.layer_norm_k(k)
+ v = self.layer_norm_v(v)
+
+ q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k)
+ k = self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k)
+ v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v)
+ q = q.transpose(1, 2)
+ k = k.transpose(1, 2)
+ v = v.transpose(1, 2)
+ return q, k, v
+
+ def forward_output(
+ self, output: torch.Tensor, residual: torch.Tensor, sz_b: int, len_q: int
+ ) -> torch.Tensor:
+ output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+ fc_out, _ = self.fc(output)
+ output = fc_out
+ output = output + residual
+ return output
+
+ def forward_attention(
+ self, attn: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ mask = mask.eq(0)
+ attn = attn.masked_fill(mask, -float("inf"))
+ attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
+ else:
+ attn = torch.softmax(attn, dim=-1)
+
+ d_attn = attn
+ output = torch.matmul(d_attn, v)
+
+ return output, attn
+
+
+class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
+ def __init__(self, n_head: int, d_model: int):
+ super().__init__(n_head, d_model)
+ d_k = d_model // n_head
+ self.scale = 1.0 / (d_k**0.5)
+ self.linear_pos = ReplicatedLinear(
+ input_size=d_model, output_size=n_head * d_k, bias=False
+ )
+ self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k]))
+ self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k]))
+
+    def _rel_shift(self, x):
+        # Transformer-XL style relative shift: pad a zero column, reshape so each
+        # row is offset by one, then crop back so scores line up with relative
+        # positions.
+        N, H, T1, T2 = x.size()
+ zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(N, H, T2 + 1, T1)
+ x = x_padded[:, :, 1:].view_as(x)
+ x = x[:, :, :, : x.size(-1) // 2 + 1]
+ return x
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ sz_b, len_q = q.size(0), q.size(1)
+
+ residual = q
+ q, k, v = self.forward_qkv(q, k, v)
+
+ q = q.transpose(1, 2)
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k)
+ p = p.transpose(1, 2)
+
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self._rel_shift(matrix_bd)
+
+ attn_scores = matrix_ac + matrix_bd
+ attn_scores.mul_(self.scale)
+
+ output, attn = self.forward_attention(attn_scores, v, mask=mask)
+
+ output = self.forward_output(output, residual, sz_b, len_q)
+ return output, attn
+
+
+class ConformerConvolution(nn.Module):
+ def __init__(self, d_model: int, kernel_size: int = 33):
+ super().__init__()
+ assert kernel_size % 2 == 1
+ self.pre_layer_norm = nn.LayerNorm(d_model)
+ self.pointwise_conv1 = nn.Conv1d(
+ d_model, d_model * 4, kernel_size=1, bias=False
+ )
+ self.padding = (kernel_size - 1) // 2
+ self.depthwise_conv = nn.Conv1d(
+ d_model * 2,
+ d_model * 2,
+ kernel_size,
+ stride=1,
+ padding=self.padding,
+ groups=d_model * 2,
+ bias=False,
+ )
+ self.batch_norm = nn.LayerNorm(d_model * 2)
+ self.swish = Swish()
+ self.pointwise_conv2 = nn.Conv1d(
+ d_model * 2, d_model, kernel_size=1, bias=False
+ )
+
+ def forward(
+ self, x: torch.Tensor, mask: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ residual = x
+ out = self.pre_layer_norm(x)
+ out = out.transpose(1, 2)
+ if mask is not None:
+ out.masked_fill_(mask.ne(1), 0.0)
+ out = self.pointwise_conv1(out)
+ out = F.glu(out, dim=1)
+ out = self.depthwise_conv(out)
+
+ out = out.transpose(1, 2)
+ out = self.swish(self.batch_norm(out))
+ out = out.transpose(1, 2)
+
+ out = self.pointwise_conv2(out)
+ if mask is not None:
+ out.masked_fill_(mask.ne(1), 0.0)
+ out = out.transpose(1, 2)
+ return out + residual
+
+
+class RelPosEmbConformerBlock(nn.Module):
+ def __init__(self, d_model, n_head, kernel_size=33):
+ super().__init__()
+ self.ffn1 = ConformerFeedForward(d_model)
+ self.mhsa = RelPosMultiHeadAttention(n_head, d_model)
+ self.conv = ConformerConvolution(d_model, kernel_size)
+ self.ffn2 = ConformerFeedForward(d_model)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos_emb: torch.Tensor,
+ slf_attn_mask: torch.Tensor | None = None,
+ pad_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ out = 0.5 * x + 0.5 * self.ffn1(x)
+ out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
+ out = self.conv(out, pad_mask)
+ out = 0.5 * out + 0.5 * self.ffn2(out)
+ out = self.layer_norm(out)
+ return out
+
+
+class ConformerEncoder(nn.Module):
+ def __init__(
+ self,
+ idim: int,
+ n_layers_enc: int,
+ n_head: int,
+ d_model: int,
+ kernel_size: int = 33,
+ pe_maxlen: int = 5000,
+ ):
+ super().__init__()
+ self.odim = d_model
+
+ self.input_preprocessor = Conv2dSubsampling(idim, d_model)
+ self.positional_encoding = RelPositionalEncoding(d_model)
+
+ self.layer_stack = nn.ModuleList()
+ for _ in range(n_layers_enc):
+ block = RelPosEmbConformerBlock(d_model, n_head, kernel_size)
+ self.layer_stack.append(block)
+
+ def forward(
+ self, padded_input: torch.Tensor, input_lengths: torch.Tensor, pad: bool = True
+ ):
+ if pad:
+ padded_input = F.pad(
+ padded_input,
+ (0, 0, 0, self.input_preprocessor.context - 1),
+ "constant",
+ 0.0,
+ )
+ src_mask = self.padding_position_is_0(padded_input, input_lengths)
+
+ embed_output, input_lengths, src_mask = self.input_preprocessor(
+ padded_input, src_mask
+ )
+ enc_output = embed_output
+
+ pos_emb = self.positional_encoding(embed_output)
+
+ enc_outputs = []
+ for enc_layer in self.layer_stack:
+ enc_output = enc_layer(
+ enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask
+ )
+ enc_outputs.append(enc_output)
+
+ return enc_output, input_lengths, src_mask
+
+ def padding_position_is_0(
+ self, padded_input: torch.Tensor, input_lengths: torch.Tensor
+ ) -> torch.Tensor:
+ N, T = padded_input.size()[:2]
+ mask = torch.ones((N, T)).to(padded_input.device)
+ for i in range(N):
+ mask[i, input_lengths[i] :] = 0
+ mask = mask.unsqueeze(dim=1)
+ return mask.to(torch.uint8)
+
+
+class FireRedASR2Adapter(nn.Module):
+ def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2):
+ super().__init__()
+ self.ds = downsample_rate
+ self.linear1 = ReplicatedLinear(
+ input_size=encoder_dim * downsample_rate,
+ output_size=llm_dim,
+ bias=True,
+ )
+ self.relu = _ACTIVATION_REGISTRY["relu"]
+ self.linear2 = ReplicatedLinear(
+ input_size=llm_dim,
+ output_size=llm_dim,
+ bias=True,
+ )
+
+ def forward(self, x, x_lens):
+ batch_size, seq_len, feat_dim = x.size()
+        # Drop trailing frames so seq_len is divisible by the downsample rate.
+        num_frames_to_discard = seq_len % self.ds
+        if num_frames_to_discard > 0:
+            x = x[:, :-num_frames_to_discard, :]
+            seq_len = x.size(1)
+
+        # Stack every `ds` consecutive frames along the feature dimension,
+        # downsampling time by `ds`.
+        x = x.contiguous()
+        x = x.view(batch_size, seq_len // self.ds, feat_dim * self.ds)
+
+ x, _ = self.linear1(x)
+ x = self.relu(x)
+ x, _ = self.linear2(x)
+
+ new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds
+ return x, new_x_lens
+
+
+class FireRedASR2Encoder(nn.Module):
+ def __init__(
+ self,
+ *,
+ vllm_config: VllmConfig,
+ ):
+ super().__init__()
+ self.audio_encoder = ConformerEncoder(
+ **vllm_config.model_config.hf_config.audio_encoder_conf
+ )
+
+
+class FireRedASR2Model(nn.Module):
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ self.encoder = FireRedASR2Encoder(
+ vllm_config=vllm_config,
+ )
+ encoder_dim = self.encoder.audio_encoder.odim
+ llm_dim = vllm_config.model_config.hf_config.hidden_size
+ self.encoder_projector = FireRedASR2Adapter(
+ encoder_dim,
+ llm_dim,
+ vllm_config.model_config.hf_config.encoder_downsample_rate,
+ )
+
+ self.decoder = Qwen2ForCausalLM(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder")
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ decoder_outputs = self.decoder(
+ input_ids=input_ids,
+ positions=positions,
+ inputs_embeds=inputs_embeds,
+ )
+ return decoder_outputs
+
+ def get_encoder_outputs(
+ self,
+ speech: torch.Tensor | list[torch.Tensor] | None,
+ speech_lengths: torch.Tensor | list[torch.Tensor] | None,
+ ) -> torch.Tensor | None:
+ encoder_outs, enc_lengths, enc_mask = self.encoder.audio_encoder(
+ speech, speech_lengths
+ )
+ speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
+ return speech_features
+
+
+class FireRedASR2ProcessingInfo(BaseProcessingInfo):
+ def get_hf_config(self) -> Qwen2Config:
+ return self.ctx.get_hf_config(Qwen2Config)
+
+ def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+ return {"audio": 1}
+
+ def get_feature_extractor(self, **kwargs: object) -> FireRedASR2FeatureExtractor:
+ hf_processor = self.get_hf_processor(**kwargs)
+ feature_extractor = hf_processor.feature_extractor # type: ignore
+ assert isinstance(feature_extractor, FireRedASR2FeatureExtractor)
+ return feature_extractor
+
+ def get_data_parser(self) -> MultiModalDataParser:
+ feature_extractor = self.get_feature_extractor()
+ return MultiModalDataParser(
+ target_sr=feature_extractor.sampling_rate,
+ target_channels=self.get_target_channels(),
+ )
+
+ def get_target_channels(self) -> int:
+ return 1
+
+
+class FireRedASR2DummyInputsBuilder(BaseDummyInputsBuilder[FireRedASR2ProcessingInfo]):
+ def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
+ num_audios = mm_counts.get("audio", 0)
+
+ return "<|AUDIO|>" * num_audios
+
+ def get_dummy_mm_data(
+ self,
+ seq_len: int,
+ mm_counts: Mapping[str, int],
+ mm_options: Mapping[str, BaseDummyOptions],
+ ) -> MultiModalDataDict:
+ feature_extractor = self.info.get_feature_extractor()
+
+ sampling_rate = feature_extractor.sampling_rate
+ audio_len = feature_extractor.chunk_length * sampling_rate
+ num_audios = mm_counts.get("audio", 0)
+
+ audio_overrides = mm_options.get("audio")
+
+ ret = {
+ "audio": self._get_dummy_audios(
+ length=audio_len, num_audios=num_audios, overrides=audio_overrides
+ )
+ }
+ return ret
+
+
+class FireRedASR2MultiModalProcessor(
+ BaseMultiModalProcessor[FireRedASR2ProcessingInfo]
+):
+ def _call_hf_processor(
+ self,
+ prompt: str,
+ mm_data: Mapping[str, object],
+ mm_kwargs: Mapping[str, object],
+ tok_kwargs: Mapping[str, object],
+ ) -> BatchFeature:
+ if mm_data:
+ feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
+ mm_data = dict(audio=mm_data.pop("audios"))
+ mm_kwargs = dict(
+ **mm_kwargs,
+ sampling_rate=feature_extractor.sampling_rate,
+ )
+ processed_outputs = super()._call_hf_processor(
+ prompt=prompt,
+ mm_data=mm_data,
+ mm_kwargs=mm_kwargs,
+ tok_kwargs=tok_kwargs,
+ )
+ if "labels" in processed_outputs:
+ processed_outputs["input_ids"] = processed_outputs.pop("labels")
+ return processed_outputs
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ input_features=MultiModalFieldConfig.batched("audio"),
+ speech_lengths=MultiModalFieldConfig.batched("audio"),
+ fake_token_lengths=MultiModalFieldConfig.batched("audio"),
+ )
+
+ def _get_prompt_updates(
+ self,
+ mm_items: MultiModalDataItems,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ out_mm_kwargs: MultiModalKwargsItems,
+ ) -> Sequence[PromptUpdate]:
+ processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+ tokenizer = self.info.get_tokenizer()
+ vocab = tokenizer.get_vocab()
+
+ audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
+
+ audio_token_id = vocab[audio_token]
+
+ out_mm_data = out_mm_kwargs.get_data()
+
+ fake_token_lengths = out_mm_data.get("fake_token_lengths")
+
+ if fake_token_lengths is None:
+ audio_output_lengths = []
+ else:
+ assert isinstance(fake_token_lengths, torch.Tensor)
+
+ audio_output_lengths = fake_token_lengths.tolist()
+
+ def get_replacement_fireredasr2_audio(item_idx: int):
+ num_features = audio_output_lengths[item_idx]
+
+ audio_tokens = [audio_token_id] * int(num_features)
+
+ return PromptUpdateDetails.select_token_id(
+ audio_tokens,
+ embed_token_id=audio_token_id,
+ )
+
+ return [
+ PromptReplacement(
+ modality="audio",
+ target=[audio_token_id],
+ replacement=get_replacement_fireredasr2_audio,
+ )
+ ]
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ FireRedASR2MultiModalProcessor,
+ info=FireRedASR2ProcessingInfo,
+ dummy_inputs=FireRedASR2DummyInputsBuilder,
+)
+class FireRedASR2ForConditionalGeneration(
+ nn.Module, SupportsTranscription, SupportsMultiModal
+):
+ packed_modules_mapping = {
+ "self_attn.qkv_proj": [
+ "self_attn.q_proj",
+ "self_attn.k_proj",
+ "self_attn.v_proj",
+ ],
+ "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
+ }
+
+ hf_to_vllm_mapper = WeightsMapper(
+ orig_to_new_substr={
+ "llm.": "model.decoder.",
+ "encoder.": "model.encoder.audio_encoder.",
+ "encoder_projector.": "model.encoder_projector.",
+ "net.0": "pre_layer_norm",
+ "net.1": "linear_expand",
+ "net.4": "linear_project",
+ }
+ )
+
+ supports_transcription_only = True
+ supports_segment_timestamp = True
+ supported_languages = ISO639_1_SUPPORTED_LANGS
+
+ @classmethod
+ def validate_language(cls, language: str | None) -> str | None:
+ if language is None:
+ # TODO language should be optional and can be guessed.
+ # For now we default to en. See
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
+ logger.warning(
+ "Defaulting to language='en'. If you wish to transcribe "
+ "audio in a different language, pass the `language` field "
+ "in the TranscriptionRequest."
+ )
+ language = "en"
+ return super().validate_language(language)
+
+ @classmethod
+ def get_generation_prompt(
+ cls,
+ audio: np.ndarray,
+ model_config: ModelConfig, # not needed here
+ stt_config: SpeechToTextConfig,
+ language: str | None,
+ task_type: Literal["transcribe", "translate"],
+ request_prompt: str,
+ to_language: str | None,
+ ) -> PromptType:
+ if language is None:
+ raise ValueError(
+ "Language must be specified when creating the fireredasr2 prompt"
+ )
+
+ prompt_str = "<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n" # noqa: E501
+ prompt = {
+ "prompt": prompt_str,
+ "multi_modal_data": {
+ "audio": (audio, stt_config.sample_rate),
+ },
+ }
+ return cast(PromptType, prompt)
+
+ @classmethod
+ def get_speech_to_text_config(
+ cls, model_config: ModelConfig, task_type: str
+ ) -> SpeechToTextConfig:
+ processor = cached_processor_from_config(model_config)
+
+ return SpeechToTextConfig(
+ max_audio_clip_s=processor.feature_extractor.chunk_length,
+ sample_rate=processor.feature_extractor.sampling_rate,
+ )
+
+ @classmethod
+ def get_num_audio_tokens(
+ cls,
+ audio_duration_s: float,
+ stt_config: SpeechToTextConfig,
+ model_config: ModelConfig,
+ ) -> int | None:
+ processor = cached_processor_from_config(model_config)
+ hop_length = processor.feature_extractor.hop_length
+ assert hop_length is not None
+ return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ self.config = config
+ self.dtype = vllm_config.model_config.dtype
+
+ self.model = FireRedASR2Model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "model"),
+ )
+ logit_scale = getattr(config, "logit_scale", 1.0)
+
+ self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ inputs_embeds: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ decoder_outputs = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ inputs_embeds=inputs_embeds,
+ )
+ return decoder_outputs
+
+ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+ audio_input = self._parse_and_validate_audio_input(**kwargs)
+
+ speech = audio_input["input_features"]
+ speech_lengths = audio_input["speech_lengths"].to(torch.int32)
+ enc_output = self.model.get_encoder_outputs(
+ speech=speech, speech_lengths=speech_lengths
+ )
+
+ return enc_output
+
+ def embed_input_ids(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: MultiModalEmbeddings | None = None,
+ *,
+ is_multimodal: torch.Tensor | None = None,
+ handle_oov_mm_token: bool = False,
+ ) -> torch.Tensor:
+ inputs_embeds = self.model.decoder.embed_input_ids(input_ids)
+
+ ret = _merge_multimodal_embeddings(
+ inputs_embeds=inputs_embeds,
+ multimodal_embeddings=multimodal_embeddings,
+ is_multimodal=_require_is_multimodal(is_multimodal),
+ )
+ return ret
+
+ def _parse_and_validate_audio_input(
+ self, **kwargs: object
+ ) -> FireRedASR2AudioInputs:
+ input_features = kwargs.pop("input_features", None)
+ speech_lengths = kwargs.pop("speech_lengths", None)
+ fake_token_lengths = kwargs.pop("fake_token_lengths", None)
+
+ return FireRedASR2AudioInputs(
+ input_features=input_features,
+ speech_lengths=speech_lengths,
+ fake_token_lengths=fake_token_lengths,
+ )
+
+ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ logits = self.logits_processor(self.model.decoder.lm_head, hidden_states)
+ return logits
+
+ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+ loader = AutoWeightsLoader(
+ self, skip_prefixes=["model.encoder.audio_encoder.positional_encoding.pe"]
+ )
+
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/funaudiochat.py b/vllm/model_executor/models/funaudiochat.py
index 5bcb49e..2265d04 100644
--- a/vllm/model_executor/models/funaudiochat.py
+++ b/vllm/model_executor/models/funaudiochat.py
@@ -13,7 +13,6 @@ positions via `inputs_embeds`, while `position_ids` (RoPE) remains standard 1D.
from __future__ import annotations
-import os
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any
@@ -924,53 +923,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor
f"sequence of Tensors (got {type(speech_attention_mask)})"
)
- debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1"
- if debug:
- print(
- f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} "
- f"speech_attention_mask={tuple(speech_attention_mask.shape)}",
- flush=True,
- )
- attn_impl = getattr(
- self.continuous_audio_tower.config, "_attn_implementation", None
- )
- print(
- f"[FunAudioChat] audio_attn_impl={attn_impl}",
- flush=True,
- )
- if hasattr(self.continuous_audio_tower, "conv1"):
- conv1_w = self.continuous_audio_tower.conv1.weight
- print(
- f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}",
- flush=True,
- )
- try:
- attn0 = self.continuous_audio_tower.layers[0].self_attn
- q_norm = float(attn0.q_proj.weight.norm().item())
- k_norm = float(attn0.k_proj.weight.norm().item())
- v_norm = float(attn0.v_proj.weight.norm().item())
- o_norm = float(attn0.out_proj.weight.norm().item())
- print(
- f"[FunAudioChat] attn0_q_norm={q_norm:.6g} "
- f"k_norm={k_norm:.6g} "
- f"v_norm={v_norm:.6g} "
- f"o_norm={o_norm:.6g}",
- flush=True,
- )
- except Exception:
- pass
- if isinstance(input_features, torch.Tensor):
- print(
- f"[FunAudioChat] input_features={tuple(input_features.shape)}",
- flush=True,
- )
- if isinstance(feature_attention_mask, torch.Tensor):
- print(
- "[FunAudioChat] feature_attention_mask="
- f"{tuple(feature_attention_mask.shape)}",
- flush=True,
- )
-
group_size = int(self.audio_tower.group_size)
speech_maxlen = int(speech_ids.shape[-1])
@@ -1019,38 +971,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor
embeds = tuple(
audio_features[i, : int(length)] for i, length in enumerate(lengths)
)
- if debug:
- embed_lens = [int(t.shape[0]) for t in embeds]
- print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
- if embeds:
- t0 = embeds[0]
- print(
- f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
- f"nan={bool(torch.isnan(t0).any())} "
- f"norm={float(t0.norm().item()):.6g}",
- flush=True,
- )
- dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
- if (
- dump_path
- and speech_ids.shape[0] == 1
- and len(embeds) == 1
- and embed_lens[0] > 10
- ):
- if not os.path.exists(dump_path):
- np.save(dump_path, embeds[0].detach().float().cpu().numpy())
- print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
- cont_path = dump_path.replace(".npy", "_cont.npy")
- if continuous_audio_features is not None and not os.path.exists(
- cont_path
- ):
- np.save(
- cont_path,
- continuous_audio_features.detach().float().cpu().numpy(),
- )
- print(
- f"[FunAudioChat] dumped continuous to {cont_path}", flush=True
- )
return embeds
def forward(
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 770424b..5c84187 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -409,7 +409,7 @@ class Gemma3nAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-
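+ # Make the split q/k/v views contiguous before the per-head reshapes below.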
+ q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
q = q.unflatten(-1, (self.num_heads, self.head_dim))
q = self.q_norm(q)
q = q.flatten(-2, -1)
diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py
index 9705bdf..c8fdd51 100644
--- a/vllm/model_executor/models/glm4_moe.py
+++ b/vllm/model_executor/models/glm4_moe.py
@@ -110,7 +110,12 @@ class Glm4MoeMLP(nn.Module):
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
- x, _ = self.down_proj(x)
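+ # Quantized down_proj weights may carry a padded input dim; pad the
+ # activation to match before the projection.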
+ if (
+     self.down_proj.quant_method.__class__.__name__ != "UnquantizedLinearMethod"
+     and x.shape[-1] != self.down_proj.weight.shape[0]
+ ):
+ padding = self.down_proj.weight.shape[0] - x.shape[-1]
+ x_align = torch.nn.functional.pad(x, (0, padding), mode='constant', value=0)
+ else:
+ x_align = x
+ x, _ = self.down_proj(x_align)
return x
@@ -144,11 +149,10 @@ class Glm4MoE(nn.Module):
config.hidden_size,
config.n_routed_experts,
bias=False,
- # dtype=torch.float32,
+ dtype=torch.bfloat16,
)
self.gate.e_score_correction_bias = nn.Parameter(
- torch.empty(config.n_routed_experts)
- )
+ torch.empty(config.n_routed_experts, dtype=torch.bfloat16)
+ )
# Load balancing settings.
vllm_config = get_current_vllm_config()
@@ -205,8 +209,7 @@ class Glm4MoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
- # router_logits = self.gate(hidden_states.to(dtype=torch.float32))
- router_logits = self.gate(hidden_states)
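+ # Compute routing in bfloat16 to match the gate and correction bias created
+ # above.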
+ router_logits = self.gate(hidden_states.to(dtype=torch.bfloat16))
fused_moe_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
@@ -312,6 +315,9 @@ class Glm4MoeAttention(nn.Module):
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
if self.use_qk_norm:
q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape(
q.shape
diff --git a/vllm/model_executor/models/glm4_moe_lite.py b/vllm/model_executor/models/glm4_moe_lite.py
index 9139457..6529986 100644
--- a/vllm/model_executor/models/glm4_moe_lite.py
+++ b/vllm/model_executor/models/glm4_moe_lite.py
@@ -127,12 +127,10 @@ class Glm4MoeLiteDecoderLayer(nn.Module):
v_head_dim = getattr(config, "v_head_dim", 0)
kv_lora_rank = getattr(config, "kv_lora_rank", 0)
- # if model_config.use_mla:
- # attn_cls = Glm4MoeLiteMLAAttention
- # else:
- # attn_cls = Glm4MoeLiteAttention
-
- attn_cls = Glm4MoeLiteAttention
+ if model_config.use_mla:
+ attn_cls = Glm4MoeLiteMLAAttention
+ else:
+ attn_cls = Glm4MoeLiteAttention
self.self_attn = attn_cls(
vllm_config=vllm_config,
@@ -306,7 +304,7 @@ class Glm4MoeLiteModel(nn.Module):
),
}
)
-
+
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -318,6 +316,120 @@ class Glm4MoeLiteModel(nn.Module):
num_experts=self.config.n_routed_experts,
)
+
+class Glm4MoeLiteForCausalLM(
+ nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts
+):
+ packed_modules_mapping = {
+ "gate_up_proj": ["gate_proj", "up_proj"],
+ }
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ self.config = config
+ self.quant_config = quant_config
+
+ qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
+ qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+ self.use_mha = config.model_type == "deepseek" or all(
+ dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
+ )
+
+ if self.use_mha:
+ self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
+
+ # `packed_modules_mapping` needs to be modified before
+ # initializing DeepseekV2Model, as it is passed inplace to
+ # quantization config init and may be used to select the
+ # quant_method for relevant layers during initialization.
+ self.fuse_qkv_a_proj = (
+ hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
+ )
+ if self.fuse_qkv_a_proj:
+ self.packed_modules_mapping["fused_qkv_a_proj"] = [
+ "q_a_proj",
+ "kv_a_proj_with_mqa",
+ ]
+
+ self.model = Glm4MoeLiteModel(
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
+ )
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ else:
+ self.lm_head = PPMissingLayer()
+ self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors
+ )
+ # Set MoE hyperparameters
+ self.num_moe_layers = (
+ self.config.num_hidden_layers - self.config.first_k_dense_replace
+ )
+ self.set_moe_parameters()
+
+ def set_moe_parameters(self):
+ self.expert_weights = []
+
+ self.num_expert_groups = getattr(self.config, "n_group", 1)
+
+ self.moe_layers = []
+ self.moe_mlp_layers = []
+ example_moe = None
+ for layer in self.model.layers:
+ if isinstance(layer, PPMissingLayer):
+ continue
+
+ assert isinstance(layer, Glm4MoeLiteDecoderLayer)
+ if isinstance(layer.mlp, Glm4MoeLite):
+ # Pick last one layer since the first ones may be dense layers.
+ example_moe = layer.mlp
+ self.moe_mlp_layers.append(layer.mlp)
+ self.moe_layers.append(layer.mlp.experts)
+
+ self.extract_moe_parameters(example_moe)
+
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+ return self.model.embed_input_ids(input_ids)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor | None,
+ positions: torch.Tensor,
+ intermediate_tensors: IntermediateTensors | None = None,
+ inputs_embeds: torch.Tensor | None = None,
+ ) -> torch.Tensor | IntermediateTensors:
+ hidden_states = self.model(
+ input_ids, positions, intermediate_tensors, inputs_embeds
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ ) -> torch.Tensor | None:
+ logits = self.logits_processor(self.lm_head, hidden_states)
+ return logits
+
+ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+ # Params for weights, fp8 weight scales, fp8 activation scales
+ # (param_name, weight_name, expert_id, shard_id)
+ return SharedFusedMoE.make_expert_params_mapping(
+ self,
+ ckpt_gate_proj_name="gate_proj",
+ ckpt_down_proj_name="down_proj",
+ ckpt_up_proj_name="up_proj",
+ num_experts=self.config.n_routed_experts,
+ num_redundant_experts=0,
+ )
+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
rocm_aiter_moe_shared_expert_enabled = (
rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
@@ -327,12 +439,13 @@ class Glm4MoeLiteModel(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
- mla_params_mapping = [
- ("fused_qkv_a_proj", "q_a_proj", 0),
- ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
+ mha_params_mapping = [
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
]
-
- stacked_params_mapping.extend(mla_params_mapping)
+ if self.use_mha:
+ stacked_params_mapping.extend(mha_params_mapping)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -510,128 +623,71 @@ class Glm4MoeLiteModel(nn.Module):
weight_loader(param, loaded_weight)
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
+ opt_support_quant_method = [
+     "GGUFLinearMethod",
+     "UnquantizedLinearMethod",
+     "CompressedTensorsW8A8Int8",
+     "AWQMarlinLinearMethod",
+ ]
+ # Add further optimized quant methods here.
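+ # For supported linear methods, fold kv_a_proj_with_mqa into the layer's q
+ # projection and rebind the attention forward to its optimized variant
+ # (forward_opt / mla_attn.forward_opt).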
+ def inject_layer(layer, quant_method, is_mla):
+ q_lora_rank = getattr(layer, "q_lora_rank", None)
+ if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
+ if q_lora_rank is not None:
+ layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
+ if hasattr(layer.q_a_proj, "weight_scale"):
+ layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
+ del layer.kv_a_proj_with_mqa.weight_scale
+ elif not is_mla:
+ layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0)
+ if hasattr(layer.q_proj, "weight_scale"):
+ layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0)
+ del layer.kv_a_proj_with_mqa.weight_scale
+ else:
+ return
+ del layer.kv_a_proj_with_mqa.weight
+ del layer.kv_a_proj_with_mqa
+ if is_mla:
+ layer.mla_attn.forward = layer.mla_attn.forward_opt
+ else:
+ layer.forward = layer.forward_opt
+ elif quant_method == "GGUFLinearMethod":
+ pass
+ elif quant_method == "AWQMarlinLinearMethod":
+ dtype = layer.kv_a_proj_with_mqa.qweight.dtype
+ assert dtype == torch.int32
+ if q_lora_rank is not None:
+ layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1)
+ layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
+ del layer.kv_a_proj_with_mqa.scales
+ layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
+ del layer.kv_a_proj_with_mqa.qzeros
+ elif not is_mla:
+ layer.q_proj.qweight.data = torch.cat(
+     [layer.q_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1
+ )
+ layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1)
+ del layer.kv_a_proj_with_mqa.scales
+ layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1)
+ del layer.kv_a_proj_with_mqa.qzeros
+ else:
+ return
+ del layer.kv_a_proj_with_mqa.qweight
+ del layer.kv_a_proj_with_mqa
+ if is_mla:
+ layer.mla_attn.forward = layer.mla_attn.forward_opt
+ else:
+ layer.forward = layer.forward_opt
+ else:
+ pass
+
+ for _, layer in self.model.named_modules():
+ if layer.__class__.__name__ in [
+     "Glm4MoeLiteAttention",
+     "Glm4MoeLiteMLAAttention",
+ ]:
+ if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
+ quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
+ else:
+ quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
+ if quant_method not in opt_support_quant_method:
+ break
+
+ inject_layer(
+     layer, quant_method, is_mla=layer.__class__.__name__ == "Glm4MoeLiteMLAAttention"
+ )
return loaded_params
-class Glm4MoeLiteForCausalLM(
- nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts
-):
- packed_modules_mapping = {
- "gate_up_proj": ["gate_proj", "up_proj"],
- }
-
- def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
- super().__init__()
- config = vllm_config.model_config.hf_config
- quant_config = vllm_config.quant_config
- self.config = config
- self.quant_config = quant_config
-
- qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
- qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
- self.use_mha = config.model_type == "deepseek" or all(
- dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
- )
-
- if self.use_mha:
- self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
-
- # `packed_modules_mapping` needs to be modified before
- # initializing DeepseekV2Model, as it is passed inplace to
- # quantization config init and may be used to select the
- # quant_method for relevant layers during initialization.
- self.fuse_qkv_a_proj = (
- hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
- )
- if self.fuse_qkv_a_proj:
- self.packed_modules_mapping["fused_qkv_a_proj"] = [
- "q_a_proj",
- "kv_a_proj_with_mqa",
- ]
-
- self.model = Glm4MoeLiteModel(
- vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
- )
- if get_pp_group().is_last_rank:
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=maybe_prefix(prefix, "lm_head"),
- )
- else:
- self.lm_head = PPMissingLayer()
- self.logits_processor = LogitsProcessor(config.vocab_size)
- self.make_empty_intermediate_tensors = (
- self.model.make_empty_intermediate_tensors
- )
- # Set MoE hyperparameters
- self.num_moe_layers = (
- self.config.num_hidden_layers - self.config.first_k_dense_replace
- )
- self.set_moe_parameters()
-
- def set_moe_parameters(self):
- self.expert_weights = []
-
- self.num_expert_groups = getattr(self.config, "n_group", 1)
-
- self.moe_layers = []
- self.moe_mlp_layers = []
- example_moe = None
- for layer in self.model.layers:
- if isinstance(layer, PPMissingLayer):
- continue
-
- assert isinstance(layer, Glm4MoeLiteDecoderLayer)
- if isinstance(layer.mlp, Glm4MoeLite):
- # Pick last one layer since the first ones may be dense layers.
- example_moe = layer.mlp
- self.moe_mlp_layers.append(layer.mlp)
- self.moe_layers.append(layer.mlp.experts)
-
- self.extract_moe_parameters(example_moe)
-
- def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
- return self.model.embed_input_ids(input_ids)
-
- def forward(
- self,
- input_ids: torch.Tensor | None,
- positions: torch.Tensor,
- intermediate_tensors: IntermediateTensors | None = None,
- inputs_embeds: torch.Tensor | None = None,
- ) -> torch.Tensor | IntermediateTensors:
- hidden_states = self.model(
- input_ids, positions, intermediate_tensors, inputs_embeds
- )
- return hidden_states
-
- def compute_logits(
- self,
- hidden_states: torch.Tensor,
- ) -> torch.Tensor | None:
- logits = self.logits_processor(self.lm_head, hidden_states)
- return logits
-
- def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
- # Params for weights, fp8 weight scales, fp8 activation scales
- # (param_name, weight_name, expert_id, shard_id)
- return SharedFusedMoE.make_expert_params_mapping(
- self,
- ckpt_gate_proj_name="gate_proj",
- ckpt_down_proj_name="down_proj",
- ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts,
- num_redundant_experts=0,
- )
-
- def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
- loader = AutoWeightsLoader(self)
- return loader.load_weights(weights)
-
-
def get_spec_layer_idx_from_weight_name(
config: "Glm4MoeLiteConfig", weight_name: str
) -> int | None:
diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py
index fd70508..61f41c9 100644
--- a/vllm/model_executor/models/gpt_oss.py
+++ b/vllm/model_executor/models/gpt_oss.py
@@ -7,7 +7,7 @@ import torch
import torch.distributed as dist
from torch import nn
from transformers import GptOssConfig
-
+import vllm.envs as envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
@@ -23,7 +23,11 @@ from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+ QKVParallelLinear,
+ ReplicatedLinear,
+ RowParallelLinear,
+)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE
@@ -42,6 +46,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionType
+from vllm.model_executor.model_loader import padding_weight_loader
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
@@ -107,7 +112,6 @@ class OAIAttention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
-
self.o_proj = RowParallelLinear(
input_size=self.num_attention_heads * self.head_dim,
output_size=self.hidden_size,
@@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module):
self.hidden_size = config.hidden_size
self.experts_per_token = config.num_experts_per_tok
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
- self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts)
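+ # Router as a vLLM ReplicatedLinear: replicated across TP ranks, kept
+ # unquantized (quant_config=None), and returning logits directly
+ # (return_bias=False).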
+ self.router = ReplicatedLinear(
+ config.hidden_size,
+ config.num_local_experts,
+ bias=True,
+ quant_config=None,
+ prefix=f"{prefix}.router",
+ return_bias=False,
+ )
assert config.intermediate_size % self.world_size == 0
self.experts = FusedMoE(
num_experts=config.num_local_experts,
@@ -969,8 +980,18 @@ class GptOssModel(nn.Module):
weights: Iterable[tuple[str, torch.Tensor]],
stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
- params_dict = dict(self.named_parameters())
- loaded_params: set[str] = set()
+
+ def handle_weight(
+     name, weight, param_name, permute_dims=None, slice_dims=None, contiguous=True
+ ):
+ """Helper function to handle weight loading with optional slicing and permutation."""
+ param = params_dict[param_name]
+ if slice_dims:
+ weight = weight[slice_dims]
+ if permute_dims:
+ weight = weight.permute(*permute_dims)
+ if contiguous:
+ weight = weight.contiguous()
+ padding_weight_loader(param, weight)
+ loaded_params.add(param_name)
use_ep = self.parallel_config.enable_expert_parallel
@@ -986,91 +1007,71 @@ class GptOssModel(nn.Module):
intermediate_size = self.config.intermediate_size
per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
- # Calculate common slicing bounds for current rank
tp_rank_start = tp_rank * per_rank_intermediate_size
tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size)
+ params_dict = dict(self.named_parameters())
+ loaded_params: set[str] = set()
+
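+ # Overlay knobs: pack_factor accounts for the halved w2 storage dim under
+ # W4A8 MoE, and gemm_format controls whether expert weights are transposed
+ # below.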
+ pack_factor = 2 if envs.VLLM_W8A8_MOE_USE_W4A8 else 1
+ w4a8_flag = envs.VLLM_W8A8_MOE_USE_W4A8
+ gemm_format = envs.VLLM_W8A8_FORMAT
+
for name, weight in weights:
- # Skip layers on other devices.
+ # Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
-
- if ".w13_weight" in name:
- # Handle MLP gate and up projection weights
- # Extract gate and up projection parts
- if use_ep:
- narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
- else:
- narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end]
-
- narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
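+ # Expert tensors are stored per-expert; slice out this rank's EP or TP shard
+ # and transpose the last two dims unless the "NN" GEMM format wants the
+ # checkpoint layout as-is.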
+ if ".experts.w13_weight" in name and "scale" not in name and "bias" not in name:
+ slice_dims = (
+     (slice(ep_rank_start, ep_rank_end), ...) if use_ep
+     else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
+ )
+ permute_dims = None if gemm_format == "NN" else (0, 2, 1)
+ handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims)
+ elif ".experts.w2_weight" in name and "scale" not in name and "bias" not in name:
+ slice_dims = (
+     (slice(ep_rank_start, ep_rank_end), ...) if use_ep
+     else (slice(None),
+           slice(tp_rank_start // pack_factor, tp_rank_end // pack_factor),
+           slice(None))
+ )
+ permute_dims = None if gemm_format == "NN" else (0, 2, 1)
+ handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims)
+ elif ".experts.gate_up_proj_scale" in name:
+ new_name = name.replace("gate_up_proj_scale", "w13_weight_scale")
+ slice_dims = (
+     (slice(ep_rank_start, ep_rank_end), ...) if use_ep
+     else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
+ )
+ permute_dims = None if w4a8_flag else (0, 2, 1)
+ handle_weight(name, weight, new_name, permute_dims=permute_dims,
+               slice_dims=slice_dims, contiguous=w4a8_flag)
+ elif ".experts.down_proj_scale" in name:
+ new_name = name.replace("down_proj_scale", "w2_weight_scale")
+ slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else None
+ permute_dims = None if w4a8_flag else (0, 2, 1)
+ handle_weight(name, weight, new_name, permute_dims=permute_dims,
+               slice_dims=slice_dims, contiguous=w4a8_flag)
+ elif ".experts.w13_bias" in name:
+ slice_dims = (
+     (slice(ep_rank_start, ep_rank_end), ...) if use_ep
+     else (slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end))
+ )
+ handle_weight(name, weight, name, slice_dims=slice_dims, contiguous=False)
+ elif ".experts.w2_bias" in name:
param = params_dict[name]
-
- param.copy_(narrow_weight)
- loaded_params.add(name)
- continue
- elif ".w2_weight" in name:
- # Handle MLP down projection weights
- if use_ep:
- narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
- else:
- narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
- narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
- param = params_dict[name]
-
- param.copy_(narrow_weight)
- loaded_params.add(name)
- continue
- elif ".w13_bias" in name:
- # Handle MLP gate and up projection biases
- # Extract gate and up projection bias parts
- if use_ep:
- narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
- else:
- narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end]
-
- param = params_dict[name]
- param.copy_(narrow_weight)
- loaded_params.add(name)
- continue
- elif ".w2_bias" in name:
- # Handle MLP down projection bias
if use_ep:
weight = weight[ep_rank_start:ep_rank_end, ...]
- else:
- # (only load on rank 0 to avoid duplication)
- if tp_rank != 0:
- weight.zero_()
- param = params_dict[name]
- param.copy_(weight)
+ elif tp_rank != 0:
+ weight.zero_()
+ param.data.copy_(weight)
loaded_params.add(name)
- continue
elif "sinks" in name:
- # Handle attention sinks (distributed across ranks)
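+ # Attention sinks live on the attention module ("attn") and are sharded by
+ # head across TP ranks.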
+ name = name.replace("self_attn", "attn")
param = params_dict[name]
narrow_weight = weight.narrow(0, head_start, heads_per_rank)
param.data.copy_(narrow_weight)
loaded_params.add(name)
- continue
- for param_name, weight_name, shard_id in stacked_params_mapping:
- if weight_name not in name:
- continue
- name = name.replace(weight_name, param_name)
- param = params_dict[name]
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
- if weight_loader == default_weight_loader:
- weight_loader(param, weight)
- else:
- weight_loader(param, weight, shard_id)
- break
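+ # Pack separate q/k/v checkpoint weights into the fused qkv_proj parameter.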
+ elif ("q_proj" in name or "k_proj" in name or "v_proj" in name):
+ shard_id = ("q" if "q_proj" in name else "k" if "k_proj" in name else "v")
+ name = name.replace("self_attn", "attn")
+ param_name = name.replace(f"{shard_id}_proj", "qkv_proj")
+ param = params_dict[param_name]
+ weight_loader = param.weight_loader
+ weight_loader(param, weight, loaded_shard_id=shard_id)
+ loaded_params.add(param_name)
else:
- # Handle all other weights with potential renaming
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, weight)
- loaded_params.add(name)
+ loaded_params.add(name)
+
return loaded_params
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py
index 3f2d0e7..b6fda25 100644
--- a/vllm/model_executor/models/hunyuan_vision.py
+++ b/vllm/model_executor/models/hunyuan_vision.py
@@ -636,7 +636,13 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo):
spatial_merge_size = vision_config.spatial_merge_size
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
- size = mm_kwargs.get("size", image_processor.size)
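+ # Start from the processor's default size and layer any per-request
+ # size / min_pixels / max_pixels overrides on top.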
+ size = image_processor.size
+ if override_size := mm_kwargs.get("size"):
+ size = size | override_size
+ if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
+ size = size | {"shortest_edge": override_min_pixels}
+ if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
+ size = size | {"longest_edge": override_max_pixels}
if do_resize:
resized_height, resized_width = smart_resize(
diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py
index 1fb0d5e..5b0dfe4 100644
--- a/vllm/model_executor/models/hyperclovax_vision.py
+++ b/vllm/model_executor/models/hyperclovax_vision.py
@@ -49,7 +49,6 @@ from .utils import (
)
from .vision import get_vision_encoder_info
-EOT = "<|endofturn|>"
IMAGE_TOKEN: str = "<|dummy3|>"
VIDEO_TOKEN: str = "<|_unuse_missing_100270|>"
diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py
index f4f7ce4..6d8b45a 100644
--- a/vllm/model_executor/models/isaac.py
+++ b/vllm/model_executor/models/isaac.py
@@ -551,7 +551,7 @@ def process_vision_for_patches(
`(num_images, height, width, channels)` for a batch. Channels are
expected to be RGB.
patch_size (`int`):
- Edge length of square patches; implictly controls resize grid granularity.
+ Edge length of square patches; implicitly controls resize grid granularity.
max_num_patches (`int`):
Maximum number of patches allowed after resizing.
min_num_patches (`int`, *optional*):
diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py
index 2cb7dc4..4c43e41 100644
--- a/vllm/model_executor/models/keye.py
+++ b/vllm/model_executor/models/keye.py
@@ -1021,7 +1021,13 @@ class KeyeProcessingInfo(BaseProcessingInfo):
temporal_patch_size = 1
mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs)
- size = mm_kwargs.get("size", image_processor.size)
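+ # Merge per-request size / min_pixels / max_pixels overrides into the
+ # processor's default size dict.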
+ size = image_processor.size
+ if override_size := mm_kwargs.get("size"):
+ size = size | override_size
+ if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None:
+ size = size | {"min_pixels": override_min_pixels}
+ if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None:
+ size = size | {"max_pixels": override_max_pixels}
if do_resize:
resized_height, resized_width = smart_resize(
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 4492b57..529233c 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -654,4 +654,4 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
- return loader.load_weights(weights)
+ return loader.load_weights(weights)
\ No newline at end of file
diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py
index e61e9d0..522f4a2 100644
--- a/vllm/model_executor/models/minicpm3.py
+++ b/vllm/model_executor/models/minicpm3.py
@@ -230,4 +230,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
}
def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""):
- return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
+ return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix)
\ No newline at end of file
diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
index c4591ac..0761393 100644
--- a/vllm/model_executor/models/minimax_m2.py
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -88,7 +88,7 @@ class MiniMaxM2MoE(nn.Module):
self.use_routing_bias = getattr(config, "use_routing_bias", False)
if self.use_routing_bias:
self.e_score_correction_bias = nn.Parameter(
- torch.empty(config.num_local_experts, dtype=torch.float32)
+ torch.empty(config.num_local_experts, dtype=torch.get_default_dtype())
)
self.e_score_correction_bias.weight_loader = (
MiniMaxM2MoE.ebias_weight_loader
@@ -107,13 +107,14 @@ class MiniMaxM2MoE(nn.Module):
renormalize=True,
quant_config=quant_config,
prefix=f"{prefix}.experts",
+ router_logits_dtype=torch.float32,
)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=False,
- # params_dtype=torch.float32,
+ params_dtype=torch.float32,
quant_config=None,
prefix=f"{prefix}.gate",
)
@@ -121,7 +122,6 @@ class MiniMaxM2MoE(nn.Module):
@staticmethod
def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
- # param.data.copy_(loaded_weight.to(torch.float32))
param.data.copy_(loaded_weight)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -129,10 +129,9 @@ class MiniMaxM2MoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
- # router_logits, _ = self.gate(hidden_states.to(torch.float32))
- router_logits, _ = self.gate(hidden_states)
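+ # The gate runs in float32; cast the logits back to the model dtype for the
+ # experts.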
+ router_logits, _ = self.gate(hidden_states.to(torch.float32))
final_hidden_states = self.experts(
- hidden_states=hidden_states, router_logits=router_logits
+ hidden_states=hidden_states,
+ router_logits=router_logits.to(hidden_states.dtype),
)
final_hidden_states = final_hidden_states
if self.tp_size > 1:
diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py
index 46cf7fe..82422e8 100644
--- a/vllm/model_executor/models/nano_nemotron_vl.py
+++ b/vllm/model_executor/models/nano_nemotron_vl.py
@@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import (
)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
+from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet
from vllm.model_executor.models.radio import RadioModel, calc_seq_lens
from vllm.model_executor.models.utils import (
init_vllm_registered_model,
@@ -55,12 +56,14 @@ from vllm.multimodal.evs import (
compute_retention_mask,
)
from vllm.multimodal.inputs import (
+ AudioItem,
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
VideoItem,
)
from vllm.multimodal.parse import (
+ AudioProcessorItems,
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
@@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely
# Alternative: Set a specific higher limit
# Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels
+
+class NanoNemotronVLAudioFeatureInputs(TensorSchema):
+ """
+ Dimensions:
+ - b: Number of audio clips
+ - t: Audio feature length
+ - f: Feature size (mel bins)
+ """
+
+ type: Literal["audio_features"] = "audio_features"
+ input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")]
+ feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")]
+ audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")]
+
+
+MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes
+
IMG_START = "
"
IMG_END = ""
IMG_CONTEXT = ""
+AUDIO_START = ""
+AUDIO_END = ""
+AUDIO_CONTEXT = ""
# Profiling
# MAX_FRAMES = 16
@@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
self.video_token = video_token
self.video_pruning_rate = video_pruning_rate
+ self.audio_extractor: ParakeetExtractor | None = None
+ raw_sound_config = getattr(config, "sound_config", None)
+ if raw_sound_config is not None:
+ self.audio_extractor = ParakeetExtractor(raw_sound_config)
+
# Pre-tokenize special tokens for video processing
# to avoid repeated tokenization
self._img_start_token_ids = tokenizer.encode(
@@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor):
text = [t.replace("