From 938d0854a58f17d17a758daa698e4ed463541535 Mon Sep 17 00:00:00 2001 From: LZiBee <2736864745@qq.com> Date: Wed, 29 Apr 2026 19:38:22 +0800 Subject: [PATCH] Upgrade to vllm 0.17.0 corex v4.1 overlay --- Dockerfile | 11 +- README.md | 69 +- .../direct_url.json | 1 - .../INSTALLER | 0 .../METADATA | 15 +- .../RECORD | 838 +++--- .../REQUESTED | 0 .../WHEEL | 2 +- .../direct_url.json | 1 + .../entry_points.txt | 0 .../licenses/LICENSE | 0 .../top_level.txt | 0 vllm/.gitignore | 244 -- vllm/__init__.py | 6 - vllm/_aiter_ops.py | 160 ++ vllm/_bc_linter.py | 54 - vllm/_custom_ops.py | 1741 ++++++------ vllm/benchmarks/datasets.py | 31 +- vllm/benchmarks/lib/__init__.py | 3 + vllm/benchmarks/lib/endpoint_request_func.py | 802 ++++++ vllm/benchmarks/lib/ready_checker.py | 79 + vllm/benchmarks/lib/utils.py | 131 + vllm/benchmarks/plot.py | 316 +++ vllm/benchmarks/serve.py | 167 +- vllm/benchmarks/sweep/cli.py | 6 +- vllm/benchmarks/sweep/plot.py | 17 +- vllm/benchmarks/sweep/plot_pareto.py | 5 +- vllm/benchmarks/sweep/serve.py | 129 +- vllm/benchmarks/sweep/serve_sla.py | 305 --- vllm/benchmarks/sweep/serve_workload.py | 328 +++ vllm/benchmarks/sweep/startup.py | 103 +- vllm/compilation/backends.py | 70 +- vllm/compilation/caching.py | 21 +- vllm/compilation/compiler_interface.py | 48 +- vllm/compilation/counter.py | 2 + vllm/compilation/cuda_graph.py | 7 +- .../passes/fusion/collective_fusion.py | 6 +- .../passes/fusion/rocm_aiter_fusion.py | 22 +- .../passes/fusion/sequence_parallelism.py | 4 - .../passes/utility/fix_functionalization.py | 10 +- vllm/compilation/piecewise_backend.py | 24 +- vllm/compilation/wrapper.py | 49 + vllm/config/attention.py | 4 +- vllm/config/compilation.py | 18 +- vllm/config/model.py | 25 +- vllm/config/parallel.py | 142 +- vllm/config/speculative.py | 230 +- vllm/config/vllm.py | 131 +- vllm/config/weight_transfer.py | 2 +- vllm/device_allocator/cumem.py | 17 +- .../device_communicators/all2all.py | 128 +- .../device_communicators/all_reduce_utils.py | 49 +- .../base_device_communicator.py | 68 +- .../device_communicators/cpu_communicator.py | 37 +- .../device_communicators/cuda_communicator.py | 115 +- .../device_communicators/pynccl.py | 36 +- vllm/distributed/elastic_ep/__init__.py | 0 .../distributed/elastic_ep/elastic_execute.py | 529 ++++ vllm/distributed/elastic_ep/elastic_state.py | 563 ++++ vllm/distributed/elastic_ep/standby_state.py | 117 + vllm/distributed/eplb/async_worker.py | 4 - vllm/distributed/eplb/eplb_state.py | 314 +-- vllm/distributed/eplb/rebalance_execute.py | 80 +- vllm/distributed/kv_events.py | 4 + .../kv_transfer/kv_connector/factory.py | 6 + .../kv_transfer/kv_connector/utils.py | 45 +- .../kv_transfer/kv_connector/v1/base.py | 22 + .../kv_connector/v1/example_connector.py | 19 +- .../v1/example_hidden_states_connector.py | 354 +++ .../kv_connector/v1/lmcache_connector.py | 10 + .../v1/mooncake/mooncake_connector.py | 31 +- .../kv_connector/v1/multi_connector.py | 15 + .../kv_connector/v1/nixl_connector.py | 1231 +++++---- vllm/distributed/parallel_state.py | 301 ++- vllm/distributed/stateless_coordinator.py | 322 +++ vllm/distributed/utils.py | 93 +- vllm/distributed/weight_transfer/base.py | 29 +- vllm/distributed/weight_transfer/factory.py | 6 + .../distributed/weight_transfer/ipc_engine.py | 291 ++ .../weight_transfer/nccl_engine.py | 85 +- vllm/engine/arg_utils.py | 30 +- vllm/entrypoints/anthropic/api_router.py | 66 +- vllm/entrypoints/anthropic/protocol.py | 44 +- vllm/entrypoints/anthropic/serving.py | 684 +++-- 
vllm/entrypoints/chat_utils.py | 8 + vllm/entrypoints/cli/__init__.py | 25 +- vllm/entrypoints/cli/benchmark/main.py | 28 - vllm/entrypoints/cli/serve.py | 12 +- vllm/entrypoints/grpc_server.py | 15 +- vllm/entrypoints/llm.py | 158 +- vllm/entrypoints/logger.py | 14 + .../openai/chat_completion/protocol.py | 41 +- .../openai/chat_completion/serving.py | 22 +- vllm/entrypoints/openai/cli_args.py | 7 +- .../entrypoints/openai/completion/protocol.py | 43 +- vllm/entrypoints/openai/engine/serving.py | 21 +- vllm/entrypoints/openai/responses/protocol.py | 19 +- vllm/entrypoints/openai/responses/serving.py | 79 +- .../openai/speech_to_text/speech_to_text.py | 21 +- .../openai/translations/__init__.py | 12 - .../openai/translations/api_router.py | 14 - .../openai/translations/protocol.py | 14 - .../openai/translations/serving.py | 14 - .../openai/translations/speech_to_text.py | 15 - vllm/entrypoints/pooling/__init__.py | 1 + vllm/entrypoints/pooling/base/io_processor.py | 189 ++ vllm/entrypoints/pooling/base/protocol.py | 4 - vllm/entrypoints/pooling/base/serving.py | 378 +++ .../pooling/classify/api_router.py | 31 +- .../pooling/classify/io_processor.py | 50 + vllm/entrypoints/pooling/classify/protocol.py | 2 - vllm/entrypoints/pooling/classify/serving.py | 132 +- vllm/entrypoints/pooling/embed/protocol.py | 19 - .../pooling/io_processor_factories.py | 31 + vllm/entrypoints/pooling/pooling/protocol.py | 19 - vllm/entrypoints/pooling/score/protocol.py | 2 - vllm/entrypoints/pooling/score/serving.py | 19 +- vllm/entrypoints/pooling/score/utils.py | 88 +- vllm/entrypoints/pooling/typing.py | 51 + vllm/entrypoints/sagemaker/api_router.py | 3 +- vllm/entrypoints/utils.py | 71 +- vllm/env_override.py | 41 + vllm/envs.py | 203 +- vllm/forward_context.py | 3 +- vllm/kernels/helion/register.py | 127 +- vllm/lora/layers/fused_moe.py | 25 +- vllm/lora/layers/logits_processor.py | 6 +- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 252 +- vllm/lora/ops/triton_ops/lora_expand_op.py | 6 +- .../ops/triton_ops/lora_kernel_metadata.py | 42 +- vllm/lora/ops/triton_ops/lora_shrink_op.py | 9 +- vllm/lora/ops/triton_ops/utils.py | 6 + vllm/lora/punica_wrapper/punica_gpu.py | 13 + .../model_executor/kernels/linear/__init__.py | 2 +- .../kernels/linear/mixed_precision/machete.py | 39 +- .../kernels/linear/mixed_precision/marlin.py | 314 ++- .../kernels/linear/scaled_mm/cutlass.py | 56 +- vllm/model_executor/layers/activation.py | 28 +- .../layers/attention/attention.py | 123 +- .../layers/attention/extra_cache.py | 131 + .../layers/attention/mla_attention.py | 2039 +++++++------- .../layers/attention/mm_encoder_attention.py | 209 +- .../model_executor/layers/fla/ops/__init__.py | 8 +- vllm/model_executor/layers/fla/ops/chunk.py | 10 +- .../layers/fla/ops/chunk_delta_h.py | 2 +- vllm/model_executor/layers/fla/ops/chunk_o.py | 4 +- .../layers/fla/ops/chunk_scaled_dot_kkt.py | 4 +- .../layers/fla/ops/fused_recurrent.py | 248 +- .../layers/fla/ops/fused_sigmoid_gating.py | 279 ++ vllm/model_executor/layers/fla/ops/index.py | 10 +- vllm/model_executor/layers/fla/ops/kda.py | 16 +- .../layers/fla/ops/layernorm_guard.py | 35 +- vllm/model_executor/layers/fla/ops/wy_fast.py | 2 +- .../layers/fused_moe/__init__.py | 12 +- .../layers/fused_moe/activation.py | 12 +- .../layers/fused_moe/all2all_utils.py | 68 +- .../layers/fused_moe/batched_deep_gemm_moe.py | 2 +- .../model_executor/layers/fused_moe/config.py | 16 +- ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 147 + .../E=128,N=96,device_name=NVIDIA_H200.json | 147 + 
.../layers/fused_moe/cutlass_moe.py | 17 +- .../layers/fused_moe/deep_gemm_moe.py | 2 +- .../fused_moe/deepep_ht_prepare_finalize.py | 9 +- .../fused_moe/deepep_ll_prepare_finalize.py | 12 +- .../layers/fused_moe/experts/__init__.py | 0 .../fused_moe/experts/trtllm_fp8_moe.py | 335 +++ .../fused_moe/experts/trtllm_nvfp4_moe.py | 326 +++ .../layers/fused_moe/fallback.py | 12 +- .../flashinfer_a2a_prepare_finalize.py | 6 +- .../fused_moe/flashinfer_cutedsl_moe.py | 2 +- .../fused_moe/flashinfer_cutlass_moe.py | 2 +- .../layers/fused_moe/flashinfer_trtllm_moe.py | 298 --- .../layers/fused_moe/fused_batched_moe.py | 12 +- .../layers/fused_moe/fused_marlin_moe.py | 2 +- .../layers/fused_moe/fused_moe.py | 594 +++-- .../layers/fused_moe/fused_moe_method_base.py | 51 +- .../fused_moe/fused_moe_modular_method.py | 18 +- .../fused_moe/gpt_oss_triton_kernels_moe.py | 139 +- vllm/model_executor/layers/fused_moe/layer.py | 138 +- .../layers/fused_moe/modular_kernel.py | 701 +++-- .../layers/fused_moe/mori_prepare_finalize.py | 2 +- .../layers/fused_moe/oracle/fp8.py | 132 +- .../layers/fused_moe/oracle/nvfp4.py | 143 +- .../layers/fused_moe/oracle/unquantized.py | 35 +- .../layers/fused_moe/pplx_prepare_finalize.py | 373 --- .../layers/fused_moe/prepare_finalize.py | 209 -- .../fused_moe/prepare_finalize/__init__.py | 22 + .../fused_moe/prepare_finalize/naive_dp_ep.py | 253 ++ .../fused_moe/prepare_finalize/no_dp_ep.py | 141 + .../layers/fused_moe/rocm_aiter_fused_moe.py | 2 +- .../fused_moe/routed_experts_capturer.py | 3 +- .../layers/fused_moe/router/base_router.py | 3 +- .../router/fused_topk_bias_router.py | 9 +- .../fused_moe/router/fused_topk_router.py | 8 +- .../layers/fused_moe/router/gate_linear.py | 115 + .../fused_moe/router/grouped_topk_router.py | 83 +- .../layers/fused_moe/router/router_factory.py | 2 +- .../fused_moe/runner/default_moe_runner.py | 143 +- .../fused_moe/topk_weight_and_reduce.py | 13 +- .../layers/fused_moe/triton_cutlass_moe.py | 6 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 6 +- .../layers/fused_moe/trtllm_moe.py | 2 +- .../fused_moe/unquantized_fused_moe_method.py | 28 +- .../layers/fused_moe/xpu_fused_moe.py | 2 +- vllm/model_executor/layers/layernorm.py | 36 +- vllm/model_executor/layers/linear.py | 260 +- .../model_executor/layers/logits_processor.py | 5 +- .../layers/mamba/linear_attn.py | 189 +- .../layers/mamba/mamba_mixer.py | 4 + .../layers/mamba/mamba_utils.py | 3 - .../layers/mamba/ops/causal_conv1d.py | 2 +- .../layers/mamba/ops/mamba_ssm.py | 4 + vllm/model_executor/layers/mla.py | 98 +- .../layers/quantization/__init__.py | 12 +- .../layers/quantization/awq_marlin.py | 256 +- .../compressed_tensors/compressed_tensors.py | 16 +- .../compressed_tensors_moe.py | 2367 +++++++++++++---- .../compressed_tensors/schemes/__init__.py | 3 +- .../schemes/compressed_tensors_w8a8_int8.py | 145 +- .../model_executor/layers/quantization/fp8.py | 97 +- .../layers/quantization/gguf.py | 14 +- .../layers/quantization/gptq.py | 2 +- .../layers/quantization/gptq_marlin.py | 560 +++- .../layers/quantization/modelopt.py | 540 ++-- .../layers/quantization/mxfp4.py | 126 +- .../layers/quantization/ptpc_fp8.py | 5 - .../layers/quantization/quark/quark.py | 48 +- .../layers/quantization/quark/quark_moe.py | 202 +- .../quark/schemes/quark_ocp_mx.py | 93 +- .../quantization/utils/flashinfer_fp4_moe.py | 296 +-- .../quantization/utils/flashinfer_utils.py | 119 +- .../layers/quantization/utils/fp8_utils.py | 5 +- .../layers/quantization/utils/gguf_utils.py | 373 +++ 
.../layers/quantization/utils/marlin_utils.py | 18 - .../layers/quantization/utils/mxfp4_utils.py | 4 +- .../layers/quantization/utils/quant_utils.py | 10 +- .../layers/quantization/w8a16.py | 114 + .../layers/rotary_embedding/base.py | 1 + .../layers/rotary_embedding/common.py | 15 - .../rotary_embedding/deepseek_scaling_rope.py | 74 +- .../layers/rotary_embedding/mrope.py | 78 +- .../phi3_long_rope_scaled_rope.py | 38 +- .../shared_fused_moe/shared_fused_moe.py | 56 + .../layers/sparse_attn_indexer.py | 349 ++- vllm/model_executor/layers/utils.py | 100 +- vllm/model_executor/model_loader/__init__.py | 4 + .../model_loader/default_loader.py | 9 +- .../model_loader/weight_utils.py | 229 +- vllm/model_executor/models/AXK1.py | 1168 ++++++++ .../models/bailing_moe_linear.py | 1246 +++++++++ vllm/model_executor/models/config.py | 51 +- vllm/model_executor/models/deepencoder.py | 7 +- vllm/model_executor/models/deepseek_mtp.py | 136 +- vllm/model_executor/models/deepseek_v2.py | 668 ++--- vllm/model_executor/models/ernie45_moe.py | 11 +- vllm/model_executor/models/ernie45_vl.py | 21 +- .../models/extract_hidden_states.py | 394 +++ vllm/model_executor/models/fireredasr2.py | 829 ++++++ vllm/model_executor/models/funaudiochat.py | 80 - vllm/model_executor/models/gemma3n.py | 2 +- vllm/model_executor/models/glm4_moe.py | 18 +- vllm/model_executor/models/glm4_moe_lite.py | 316 ++- vllm/model_executor/models/gpt_oss.py | 141 +- vllm/model_executor/models/hunyuan_vision.py | 8 +- .../models/hyperclovax_vision.py | 1 - vllm/model_executor/models/isaac.py | 2 +- vllm/model_executor/models/keye.py | 8 +- vllm/model_executor/models/minicpm.py | 2 +- vllm/model_executor/models/minicpm3.py | 2 +- vllm/model_executor/models/minimax_m2.py | 11 +- .../model_executor/models/nano_nemotron_vl.py | 372 ++- vllm/model_executor/models/nemotron_h.py | 20 +- vllm/model_executor/models/nemotron_vl.py | 359 ++- vllm/model_executor/models/ovis2_5.py | 3 - vllm/model_executor/models/paddleocr_vl.py | 21 +- vllm/model_executor/models/parakeet.py | 145 + vllm/model_executor/models/phimoe.py | 3 +- vllm/model_executor/models/pixtral.py | 27 +- vllm/model_executor/models/qwen.py | 2 +- .../models/qwen2_5_omni_thinker.py | 92 +- vllm/model_executor/models/qwen2_5_vl.py | 13 +- vllm/model_executor/models/qwen2_audio.py | 20 + vllm/model_executor/models/qwen2_vl.py | 18 +- vllm/model_executor/models/qwen3.py | 30 +- vllm/model_executor/models/qwen3_5.py | 30 +- vllm/model_executor/models/qwen3_5_mtp.py | 2 +- vllm/model_executor/models/qwen3_moe.py | 44 +- vllm/model_executor/models/qwen3_next.py | 281 +- .../models/qwen3_omni_moe_thinker.py | 62 +- vllm/model_executor/models/qwen3_vl.py | 883 ++++-- vllm/model_executor/models/qwen3_vl_moe.py | 7 +- vllm/model_executor/models/qwen_vl.py | 68 +- vllm/model_executor/models/registry.py | 15 + vllm/model_executor/models/step3_text.py | 2 +- vllm/model_executor/models/step3p5.py | 67 +- vllm/model_executor/models/step3p5_mtp.py | 63 +- vllm/model_executor/models/swin.py | 500 ---- vllm/model_executor/models/teleflm.py | 3 +- .../models/transformers/base.py | 14 +- .../models/transformers/multimodal.py | 26 +- vllm/model_executor/models/utils.py | 16 +- vllm/model_executor/parameter.py | 29 +- .../model_executor/warmup/deep_gemm_warmup.py | 2 +- vllm/model_executor/warmup/kernel_warmup.py | 11 +- vllm/multimodal/evs.py | 78 +- vllm/multimodal/processing/processor.py | 12 +- vllm/multimodal/utils.py | 18 - vllm/platforms/__init__.py | 40 +- vllm/platforms/cpu.py | 24 + 
vllm/platforms/cuda.py | 37 +- vllm/platforms/interface.py | 13 +- vllm/platforms/rocm.py | 150 +- vllm/platforms/xpu.py | 6 + vllm/plugins/io_processors/__init__.py | 18 +- vllm/plugins/io_processors/interface.py | 3 +- vllm/pooling_params.py | 8 +- vllm/reasoning/__init__.py | 4 +- vllm/reasoning/kimi_k2_reasoning_parser.py | 228 ++ vllm/reasoning/minimax_m2_reasoning_parser.py | 7 +- vllm/reasoning/qwen3_reasoning_parser.py | 25 +- vllm/renderers/qwen_vl.py | 29 + vllm/renderers/registry.py | 1 + vllm/sampling_params.py | 87 +- vllm/third_party/pynvml.py | 2 +- vllm/tokenizers/deepseek_v32.py | 2 +- vllm/tokenizers/qwen_vl.py | 67 + vllm/tokenizers/registry.py | 5 + vllm/tool_parsers/qwen3coder_tool_parser.py | 399 ++- vllm/tool_parsers/utils.py | 15 - vllm/transformers_utils/configs/AXK1.py | 215 ++ vllm/transformers_utils/configs/__init__.py | 2 + .../configs/extract_hidden_states.py | 53 + vllm/transformers_utils/configs/parakeet.py | 49 + .../model_arch_config_convertor.py | 10 +- vllm/transformers_utils/processor.py | 94 +- .../transformers_utils/processors/__init__.py | 4 + .../processors/fireredasr2_processor.py | 341 +++ vllm/transformers_utils/repo_utils.py | 2 +- vllm/triton_utils/allocation.py | 13 + vllm/utils/deep_gemm.py | 121 + vllm/utils/flashinfer.py | 9 +- vllm/utils/import_utils.py | 5 - vllm/utils/math_utils.py | 5 + vllm/utils/system_utils.py | 16 + vllm/utils/torch_utils.py | 39 +- vllm/v1/attention/backend.py | 71 +- vllm/v1/attention/backends/fa_utils.py | 139 +- vllm/v1/attention/backends/flash_attn.py | 616 +++-- vllm/v1/attention/backends/flashinfer.py | 295 +- vllm/v1/attention/backends/mamba1_attn.py | 33 +- vllm/v1/attention/backends/mamba2_attn.py | 112 +- vllm/v1/attention/backends/mamba_attn.py | 121 + .../attention/backends/mla/flashattn_mla.py | 6 +- vllm/v1/attention/backends/mla/flashmla.py | 6 +- .../attention/backends/mla/flashmla_sparse.py | 112 +- vllm/v1/attention/backends/mla/indexer.py | 247 +- .../attention/backends/mla/rocm_aiter_mla.py | 6 +- vllm/v1/attention/backends/mla/triton_mla.py | 71 +- .../backends/rocm_aiter_unified_attn.py | 31 + vllm/v1/attention/backends/rocm_attn.py | 80 +- vllm/v1/attention/ops/flashmla.py | 28 + vllm/v1/attention/ops/vit_attn_wrappers.py | 92 +- vllm/v1/core/kv_cache_manager.py | 38 + vllm/v1/core/kv_cache_utils.py | 143 +- vllm/v1/core/sched/output.py | 13 +- vllm/v1/core/sched/scheduler.py | 583 +++- vllm/v1/core/sched/utils.py | 66 + vllm/v1/core/single_type_kv_cache_manager.py | 11 + vllm/v1/cudagraph_dispatcher.py | 60 +- vllm/v1/engine/__init__.py | 20 +- vllm/v1/engine/async_llm.py | 42 +- vllm/v1/engine/coordinator.py | 21 +- vllm/v1/engine/core.py | 413 ++- vllm/v1/engine/core_client.py | 336 ++- vllm/v1/engine/input_processor.py | 11 - vllm/v1/engine/llm_engine.py | 1 + vllm/v1/engine/utils.py | 68 +- vllm/v1/executor/abstract.py | 10 +- vllm/v1/executor/multiproc_executor.py | 42 +- vllm/v1/executor/ray_executor.py | 6 +- vllm/v1/executor/ray_utils.py | 52 +- vllm/v1/executor/uniproc_executor.py | 17 +- vllm/v1/kv_cache_interface.py | 112 +- vllm/v1/kv_offload/worker/cpu_gpu.py | 22 +- vllm/v1/outputs.py | 70 +- vllm/v1/request.py | 2 + vllm/v1/sample/logits_processor/__init__.py | 7 +- vllm/v1/sample/logits_processor/builtin.py | 54 + vllm/v1/sample/logits_processor/state.py | 4 +- vllm/v1/sample/rejection_sampler.py | 15 +- vllm/v1/spec_decode/eagle.py | 596 +++-- vllm/v1/spec_decode/extract_hidden_states.py | 395 +++ vllm/v1/spec_decode/metadata.py | 42 + 
vllm/v1/spec_decode/multi_layer_eagle.py | 504 ++++ vllm/v1/spec_decode/utils.py | 13 +- vllm/v1/worker/cpu_model_runner.py | 7 +- vllm/v1/worker/cpu_worker.py | 5 +- vllm/v1/worker/dp_utils.py | 48 +- vllm/v1/worker/gpu/async_utils.py | 36 + vllm/v1/worker/gpu/block_table.py | 11 + vllm/v1/worker/gpu/cudagraph_utils.py | 132 +- vllm/v1/worker/gpu/input_batch.py | 54 +- vllm/v1/worker/gpu/mm/encoder_cache.py | 40 + vllm/v1/worker/gpu/mm/encoder_runner.py | 53 +- vllm/v1/worker/gpu/model_runner.py | 524 ++-- vllm/v1/worker/gpu/model_states/__init__.py | 18 + vllm/v1/worker/gpu/model_states/default.py | 161 ++ vllm/v1/worker/gpu/model_states/interface.py | 67 + vllm/v1/worker/gpu/pool/__init__.py | 0 vllm/v1/worker/gpu/pool/pooling_runner.py | 45 + vllm/v1/worker/gpu/sample/bad_words.py | 18 +- vllm/v1/worker/gpu/sample/gumbel.py | 48 +- vllm/v1/worker/gpu/sample/logit_bias.py | 32 +- vllm/v1/worker/gpu/sample/min_p.py | 24 +- vllm/v1/worker/gpu/sample/penalties.py | 30 +- vllm/v1/worker/gpu/sample/sampler.py | 24 +- vllm/v1/worker/gpu/sample/states.py | 14 +- .../worker/gpu/spec_decode/eagle/cudagraph.py | 27 +- .../gpu/spec_decode/eagle/speculator.py | 114 +- vllm/v1/worker/gpu/warmup.py | 105 + vllm/v1/worker/gpu_input_batch.py | 89 + vllm/v1/worker/gpu_model_runner.py | 963 ++++--- vllm/v1/worker/gpu_worker.py | 326 +-- vllm/v1/worker/mamba_utils.py | 94 +- vllm/v1/worker/utils.py | 366 ++- vllm/v1/worker/worker_base.py | 17 +- vllm/v1/worker/workspace.py | 28 + vllm/version.py | 4 +- vllm/vllm_flash_attn/__init__.py | 24 + vllm/vllm_flash_attn/flash_attn_interface.py | 567 ++++ 430 files changed, 35969 insertions(+), 14511 deletions(-) delete mode 100644 vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/INSTALLER (100%) rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/METADATA (95%) rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/RECORD (87%) rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/REQUESTED (100%) rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/WHEEL (65%) create mode 100644 vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/entry_points.txt (100%) rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/licenses/LICENSE (100%) rename {vllm-0.16.1rc0+corex.4.4.0.dist-info => vllm-0.17.0+corex.20260420090923.dist-info}/top_level.txt (100%) delete mode 100644 vllm/.gitignore delete mode 100644 vllm/_bc_linter.py create mode 100644 vllm/benchmarks/lib/__init__.py create mode 100644 vllm/benchmarks/lib/endpoint_request_func.py create mode 100644 vllm/benchmarks/lib/ready_checker.py create mode 100644 vllm/benchmarks/lib/utils.py create mode 100644 vllm/benchmarks/plot.py delete mode 100644 vllm/benchmarks/sweep/serve_sla.py create mode 100644 vllm/benchmarks/sweep/serve_workload.py create mode 100644 vllm/distributed/elastic_ep/__init__.py create mode 100644 vllm/distributed/elastic_ep/elastic_execute.py create mode 100644 vllm/distributed/elastic_ep/elastic_state.py create mode 100644 vllm/distributed/elastic_ep/standby_state.py create mode 100644 vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py create mode 100644 
vllm/distributed/stateless_coordinator.py create mode 100644 vllm/distributed/weight_transfer/ipc_engine.py delete mode 100644 vllm/entrypoints/openai/translations/__init__.py delete mode 100644 vllm/entrypoints/openai/translations/api_router.py delete mode 100644 vllm/entrypoints/openai/translations/protocol.py delete mode 100644 vllm/entrypoints/openai/translations/serving.py delete mode 100644 vllm/entrypoints/openai/translations/speech_to_text.py create mode 100644 vllm/entrypoints/pooling/base/io_processor.py create mode 100644 vllm/entrypoints/pooling/base/serving.py create mode 100644 vllm/entrypoints/pooling/classify/io_processor.py create mode 100644 vllm/entrypoints/pooling/io_processor_factories.py create mode 100644 vllm/entrypoints/pooling/typing.py create mode 100644 vllm/model_executor/layers/attention/extra_cache.py create mode 100644 vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json create mode 100644 vllm/model_executor/layers/fused_moe/experts/__init__.py create mode 100644 vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py delete mode 100644 vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py delete mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize.py create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py create mode 100644 vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py create mode 100644 vllm/model_executor/layers/fused_moe/router/gate_linear.py create mode 100644 vllm/model_executor/layers/quantization/utils/gguf_utils.py create mode 100644 vllm/model_executor/layers/quantization/w8a16.py create mode 100644 vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py create mode 100644 vllm/model_executor/models/AXK1.py create mode 100644 vllm/model_executor/models/bailing_moe_linear.py create mode 100644 vllm/model_executor/models/extract_hidden_states.py create mode 100644 vllm/model_executor/models/fireredasr2.py create mode 100644 vllm/model_executor/models/parakeet.py delete mode 100644 vllm/model_executor/models/swin.py create mode 100644 vllm/reasoning/kimi_k2_reasoning_parser.py create mode 100644 vllm/renderers/qwen_vl.py create mode 100644 vllm/tokenizers/qwen_vl.py create mode 100644 vllm/transformers_utils/configs/AXK1.py create mode 100644 vllm/transformers_utils/configs/extract_hidden_states.py create mode 100644 vllm/transformers_utils/configs/parakeet.py create mode 100644 vllm/transformers_utils/processors/fireredasr2_processor.py create mode 100644 vllm/triton_utils/allocation.py create mode 100644 vllm/v1/spec_decode/extract_hidden_states.py create mode 100644 vllm/v1/spec_decode/multi_layer_eagle.py create mode 100644 vllm/v1/worker/gpu/mm/encoder_cache.py create mode 100644 vllm/v1/worker/gpu/model_states/__init__.py create mode 100644 vllm/v1/worker/gpu/model_states/default.py create mode 100644 vllm/v1/worker/gpu/model_states/interface.py create mode 100644 vllm/v1/worker/gpu/pool/__init__.py create mode 100644 vllm/v1/worker/gpu/pool/pooling_runner.py create mode 100644 vllm/v1/worker/gpu/warmup.py create mode 100644 vllm/vllm_flash_attn/__init__.py create 
mode 100644 vllm/vllm_flash_attn/flash_attn_interface.py

diff --git a/Dockerfile b/Dockerfile
index e83ace5..99280a6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,11 +1,12 @@
-FROM registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8
+ARG BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1
+FROM ${BASE_IMAGE}
 
-# Keep the runtime stack from the known-good v8 image, but replace the
-# installed Python package with the repository's patched 0.16.1rc0 sources.
-WORKDIR /tmp
+WORKDIR /home
 
 RUN rm -rf /usr/local/lib/python3.12/dist-packages/vllm \
     /usr/local/lib/python3.12/dist-packages/vllm-*.dist-info
 
 COPY vllm /usr/local/lib/python3.12/dist-packages/vllm
-COPY vllm-0.16.1rc0+corex.4.4.0.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.16.1rc0+corex.4.4.0.dist-info
+COPY vllm-0.17.0+corex.20260420090923.dist-info /usr/local/lib/python3.12/dist-packages/vllm-0.17.0+corex.20260420090923.dist-info
+
+ENTRYPOINT ["/bin/bash"]
diff --git a/README.md b/README.md
index 8689007..81019bd 100644
--- a/README.md
+++ b/README.md
@@ -1,62 +1,37 @@
 # bi_150-vllm
 
-基于 `registry.iluvatar.com.cn:10443/customer/sz/vllm0.11.2-4.4.0-x86:v8` 的
-`vLLM 0.16.1rc0` 构建仓库，用于在 BI-V150 虚拟机环境中生成可直接运行的镜像。
+This repository contains the extracted `vLLM 0.17.0+corex.20260420090923`
+Python package used to overlay the vendor image
+`registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1`.
 
-## 改动说明
-
-本仓库只保留构建镜像所需的最小内容:
+## Included files
 
 - `vllm/`
-  当前运行代码
+  The Python package code copied from the image payload.
 
-- `vllm-0.16.1rc0+corex.4.4.0.dist-info/`
-  对应的包元数据
+- `vllm-0.17.0+corex.20260420090923.dist-info/`
+  The package metadata extracted from the image.
 
 - `Dockerfile`
-  构建最终镜像
+  Builds a new image by replacing the installed `vllm` package in the vendor base image.
 
-与基础镜像相比，本仓库保留的关键代码改动如下:
+## Build
 
-- 在 `vllm/platforms/__init__.py` 中修复 CUDA 平台识别逻辑
-- 当 NVML 不可用且出现 `NVML Shared Library Not Found` 一类错误时
-  不再直接判定为非 CUDA 平台
-- 改为回退到 `torch.cuda.is_available()` 和
-  `torch.cuda.device_count()` 继续判断 CUDA 是否可用
-- 调整 CLI 初始化逻辑，避免 benchmark 可选依赖缺失时阻塞
-  `vllm serve ...` 启动
-
-这个修复用于解决如下启动失败:
-
-```text
-RuntimeError: Failed to infer device type
-```
-
-## 构建镜像
-
-在仓库根目录执行:
+Run the following command from the repository root:
 ```bash
-docker build -t bi_150_vllm:0.16.1 .
+docker build --pull=false \
+    --build-arg BASE_IMAGE=registry.iluvatar.com.cn:10443/customer/sz/vllm0.17.0-4.4.0-x86:v4.1 \
+    -t bi_150_vllm:0.17.0 \
+    .
 ```
 
-## 启动镜像
+## Verify
 ```bash
-docker run -dit \
-  --name iluvatar_test \
-  -p 38047:8000 \
-  --privileged \
-  -v /lib/modules:/lib/modules \
-  -v /dev:/dev \
-  -v /usr/src:/usr/src \
-  -v /mnt/gpfs/leaderboard/modelHubXC/Amu/t1-1.5B:/model \
-  -e CUDA_VISIBLE_DEVICES=0 \
-  --entrypoint vllm \
-  bi_150_vllm:0.16.1 \
-  serve /model \
-  --port 8000 \
-  --served-model-name llm \
-  --max-model-len 2048 \
-  --enforce-eager \
-  --trust-remote-code \
-  -tp 1
+docker run --rm -it bi_150_vllm:0.17.0 \
+  python3 -c "import vllm; print(vllm.__file__); print(vllm.__version__)"
 ```
+
+## Notes
+
+- This is an overlay-style repository, not the original upstream git source tree.
+- The Docker image keeps the vendor runtime stack and only replaces the Python package files.
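The README's `Verify` step above only prints the import path and reported version. A slightly deeper check, not part of this patch, is to confirm inside the built image that the overlaid package and the copied `vllm-0.17.0+corex.20260420090923.dist-info` metadata agree; the sketch below assumes the container layout produced by the Dockerfile and uses only the standard-library `importlib.metadata`.

```python
# Illustrative sketch only (not part of this patch): run inside the built
# image to check that the overlaid vllm package and its dist-info metadata
# describe the same distribution.
import importlib.metadata as md

import vllm

dist = md.distribution("vllm")
print("import path  :", vllm.__file__)
print("module ver   :", vllm.__version__)
print("dist-info ver:", dist.version)  # expected: 0.17.0+corex.20260420090923

# RECORD lists every installed file; check that the copied package directory
# is actually tracked by the installed metadata.
tracked = [p for p in (dist.files or []) if p.parts and p.parts[0] == "vllm"]
print("files recorded under vllm/:", len(tracked))

assert dist.version == vllm.__version__, "overlay and dist-info out of sync"
```

This can be run with the same `docker run ... python3 -c` pattern shown in the README, or copied into the container as a file.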
diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json b/vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json deleted file mode 100644 index 3a7409e..0000000 --- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json +++ /dev/null @@ -1 +0,0 @@ -{"archive_info": {"hash": "sha256=f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1", "hashes": {"sha256": "f19c1c4880bc4199fc95de8d77590fcdb4912b1fad9bf939883005812f2338c1"}}, "url": "file:///workspace/vllm-0.16.1rc0%2Bcorex.4.4.0-py3-none-any.whl"} \ No newline at end of file diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER b/vllm-0.17.0+corex.20260420090923.dist-info/INSTALLER similarity index 100% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER rename to vllm-0.17.0+corex.20260420090923.dist-info/INSTALLER diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA b/vllm-0.17.0+corex.20260420090923.dist-info/METADATA similarity index 95% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA rename to vllm-0.17.0+corex.20260420090923.dist-info/METADATA index 9eab02e..15c009c 100644 --- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA +++ b/vllm-0.17.0+corex.20260420090923.dist-info/METADATA @@ -1,9 +1,9 @@ Metadata-Version: 2.4 Name: vllm -Version: 0.16.1rc0+corex.4.4.0 +Version: 0.17.0+corex.20260420090923 Summary: A high-throughput and memory-efficient inference and serving engine for LLMs Author: vLLM Team -License-Expression: Apache-2.0 +License: Apache-2.0 Project-URL: Homepage, https://github.com/vllm-project/vllm Project-URL: Documentation, https://docs.vllm.ai/en/latest/ Project-URL: Slack, https://slack.vllm.ai/ @@ -23,7 +23,7 @@ Requires-Dist: regex Requires-Dist: cachetools Requires-Dist: psutil Requires-Dist: sentencepiece -Requires-Dist: numpy==1.26.4 +Requires-Dist: numpy Requires-Dist: requests>=2.26.0 Requires-Dist: tqdm Requires-Dist: blake3 @@ -33,7 +33,7 @@ Requires-Dist: tokenizers>=0.21.1 Requires-Dist: protobuf!=6.30.*,!=6.31.*,!=6.32.*,!=6.33.0.*,!=6.33.1.*,!=6.33.2.*,!=6.33.3.*,!=6.33.4.*,>=5.29.6 Requires-Dist: fastapi[standard]>=0.115.0 Requires-Dist: aiohttp>=3.13.3 -Requires-Dist: openai>=1.99.1 +Requires-Dist: openai<2.25.0,>=1.99.1 Requires-Dist: pydantic>=2.12.0 Requires-Dist: prometheus_client>=0.18.0 Requires-Dist: pillow @@ -52,6 +52,7 @@ Requires-Dist: pyzmq>=25.0.0 Requires-Dist: msgspec Requires-Dist: gguf>=0.17.0 Requires-Dist: mistral_common[image]>=1.9.1 +Requires-Dist: opencv-python-headless>=4.13.0 Requires-Dist: pyyaml Requires-Dist: six>=1.16.0; python_version > "3.11" Requires-Dist: setuptools<81.0.0,>=77.0.3; python_version > "3.11" @@ -76,6 +77,7 @@ Requires-Dist: opentelemetry-sdk>=1.27.0 Requires-Dist: opentelemetry-api>=1.27.0 Requires-Dist: opentelemetry-exporter-otlp>=1.27.0 Requires-Dist: opentelemetry-semantic-conventions-ai>=0.4.1 +Requires-Dist: kaldi-native-fbank>=1.18.7 Requires-Dist: numba==0.61.2 Requires-Dist: ray[cgraph]>=2.48.0 Provides-Extra: bench @@ -84,6 +86,7 @@ Requires-Dist: matplotlib; extra == "bench" Requires-Dist: seaborn; extra == "bench" Requires-Dist: datasets; extra == "bench" Requires-Dist: scipy; extra == "bench" +Requires-Dist: plotly; extra == "bench" Provides-Extra: tensorizer Requires-Dist: tensorizer==2.10.1; extra == "tensorizer" Provides-Extra: fastsafetensors @@ -97,6 +100,10 @@ Requires-Dist: soundfile; extra == "audio" Requires-Dist: mistral_common[audio]; extra == "audio" Provides-Extra: video Provides-Extra: flashinfer +Provides-Extra: petit-kernel +Requires-Dist: petit-kernel; extra == "petit-kernel" 
+Provides-Extra: helion +Requires-Dist: helion; extra == "helion" Provides-Extra: otel Requires-Dist: opentelemetry-sdk>=1.26.0; extra == "otel" Requires-Dist: opentelemetry-api>=1.26.0; extra == "otel" diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD b/vllm-0.17.0+corex.20260420090923.dist-info/RECORD similarity index 87% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD rename to vllm-0.17.0+corex.20260420090923.dist-info/RECORD index d7a8649..4b5fab1 100644 --- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD +++ b/vllm-0.17.0+corex.20260420090923.dist-info/RECORD @@ -1,17 +1,16 @@ ../../../bin/vllm,sha256=X5a5Mk4d820sk-Ykhco4CdCc8Bp3-tHJuyboP7UXkgw,172 -vllm-0.16.1rc0+corex.4.4.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 -vllm-0.16.1rc0+corex.4.4.0.dist-info/METADATA,sha256=Moq2RINlNwYeWq4WfWK1brgpmC_yY_sdEEEWz8is5uQ,9243 -vllm-0.16.1rc0+corex.4.4.0.dist-info/RECORD,, -vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 -vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91 -vllm-0.16.1rc0+corex.4.4.0.dist-info/direct_url.json,sha256=-_DIKKgt63x-HOusl6PX9Hvuz52Kdn6rtr73dXXPyUQ,265 -vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt,sha256=N3YeH_2RPW4YWbdFCRGvTmJ1qSzLsxllfCjpUf5e4UA,276 -vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357 -vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt,sha256=fAgb8Pt4zQoKTUA3ZnKEIgcjh0L97_dwEjYDTL5MEEo,5 -vllm/__init__.py,sha256=fS_Nsr5gWhdic4rOVRUVgy1PvdSFt-iIxjh_wb1-KHY,3881 +vllm-0.17.0+corex.20260420090923.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +vllm-0.17.0+corex.20260420090923.dist-info/METADATA,sha256=QtgKJZW0hL48Lkk1ylapaZSJL0FeD6duFwiioo_yu40,9512 +vllm-0.17.0+corex.20260420090923.dist-info/RECORD,, +vllm-0.17.0+corex.20260420090923.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +vllm-0.17.0+corex.20260420090923.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92 +vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json,sha256=dfda9jaQTAyViphUkyn_78MzNasJkr2bA6plpqmvM-k,309 +vllm-0.17.0+corex.20260420090923.dist-info/entry_points.txt,sha256=N3YeH_2RPW4YWbdFCRGvTmJ1qSzLsxllfCjpUf5e4UA,276 +vllm-0.17.0+corex.20260420090923.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357 +vllm-0.17.0+corex.20260420090923.dist-info/top_level.txt,sha256=fAgb8Pt4zQoKTUA3ZnKEIgcjh0L97_dwEjYDTL5MEEo,5 +vllm/__init__.py,sha256=_TcI5aE6vpjFZq_XmmsZh873DAVIuxnBIY_2p_nUM0Y,3661 vllm/__pycache__/__init__.cpython-312.pyc,, vllm/__pycache__/_aiter_ops.cpython-312.pyc,, -vllm/__pycache__/_bc_linter.cpython-312.pyc,, vllm/__pycache__/_custom_ops.cpython-312.pyc,, vllm/__pycache__/_oink_ops.cpython-312.pyc,, vllm/__pycache__/_xpu_ops.cpython-312.pyc,, @@ -34,9 +33,8 @@ vllm/__pycache__/scripts.cpython-312.pyc,, vllm/__pycache__/sequence.cpython-312.pyc,, vllm/__pycache__/tasks.cpython-312.pyc,, vllm/__pycache__/version.cpython-312.pyc,, -vllm/_aiter_ops.py,sha256=giwrq3xY5REN8-dp_FuSq9yYPZ0uHz7kUo_mSHpVlzY,56984 -vllm/_bc_linter.py,sha256=2F7ZE_cba80XNDfbrZKVBunazE7ZDxVMtpOUdEFdH3w,1105 -vllm/_custom_ops.py,sha256=B3_ssd1AKErsUByn6nzMfwZttRGooGspgHKK76cbuNE,127805 +vllm/_aiter_ops.py,sha256=3m_hqOTaJaCJKHGCvWmmroU7GyVlzogqwxVdz2rlKFQ,61895 +vllm/_custom_ops.py,sha256=B1njqipvzncK3HpOspkwYyQ9-CL1tubXDbePbDR8YUs,128364 
vllm/_oink_ops.py,sha256=AcvOGv-Yr-4MU9AkUU3G9B2kTs2c_TAHZGKTloSCfgM,3057 vllm/_xpu_ops.py,sha256=Mdpuen6oVkGmy0At4rrzdBHBEicEfquh0tqrIkGR8kw,5287 vllm/assets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 @@ -55,10 +53,11 @@ vllm/benchmarks/__pycache__/__init__.cpython-312.pyc,, vllm/benchmarks/__pycache__/datasets.cpython-312.pyc,, vllm/benchmarks/__pycache__/latency.cpython-312.pyc,, vllm/benchmarks/__pycache__/mm_processor.cpython-312.pyc,, +vllm/benchmarks/__pycache__/plot.cpython-312.pyc,, vllm/benchmarks/__pycache__/serve.cpython-312.pyc,, vllm/benchmarks/__pycache__/startup.cpython-312.pyc,, vllm/benchmarks/__pycache__/throughput.cpython-312.pyc,, -vllm/benchmarks/datasets.py,sha256=2AwzVH9a_SW3F6_bYtQFDU04TzVZFE3NqDZH9uY_2x0,128835 +vllm/benchmarks/datasets.py,sha256=hbTjt0B-81rTZMzUVO_c-poGdZFEWC1aOjdh4_haQeI,129511 vllm/benchmarks/latency.py,sha256=S-LAXVvNOdfRJRVj4KtX0WVRIdyD6-l7BOnzCfkB_1I,5837 vllm/benchmarks/lib/__init__.py,sha256=BlbNrv7ga0LXYeHLKKiey2woiGh5MkW36g7QG6L6BN0,142 vllm/benchmarks/lib/__pycache__/__init__.cpython-312.pyc,, @@ -69,7 +68,8 @@ vllm/benchmarks/lib/endpoint_request_func.py,sha256=v3Hcx5fBVsjgVDDxzg1T3A7xQezc vllm/benchmarks/lib/ready_checker.py,sha256=Hth9XB4BZsleV_vs0GiKBxb7bc-2gk0BJ4dR14rHtMA,2607 vllm/benchmarks/lib/utils.py,sha256=QZ45aYXPSbqa_V-WY2tOjd1b3KsEo8X7aOM6qSXNAPk,4248 vllm/benchmarks/mm_processor.py,sha256=gqIa5hIoYcBQ7Yg50gda0c24TIhGySKhKEw5Y2AbQH8,17883 -vllm/benchmarks/serve.py,sha256=jthwgqL4yjHY4AKU6BBLfGlpy61OfyvppgiPbWoey_Y,67322 +vllm/benchmarks/plot.py,sha256=UZF66p7jp-N2H53nwKQGeoTzO8085LyjvMAZEXWRmOY,9975 +vllm/benchmarks/serve.py,sha256=8qPEljDYvDJ6IStp0sUWrNhCEkr6eNv3BgJ6MHB5HtI,72360 vllm/benchmarks/startup.py,sha256=L7yRKATwX4CUL1DqlugNCiysFEhLA3c7BVaa4Bp8ldU,11743 vllm/benchmarks/sweep/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/benchmarks/sweep/__pycache__/__init__.cpython-312.pyc,, @@ -78,18 +78,18 @@ vllm/benchmarks/sweep/__pycache__/param_sweep.cpython-312.pyc,, vllm/benchmarks/sweep/__pycache__/plot.cpython-312.pyc,, vllm/benchmarks/sweep/__pycache__/plot_pareto.cpython-312.pyc,, vllm/benchmarks/sweep/__pycache__/serve.cpython-312.pyc,, -vllm/benchmarks/sweep/__pycache__/serve_sla.cpython-312.pyc,, +vllm/benchmarks/sweep/__pycache__/serve_workload.cpython-312.pyc,, vllm/benchmarks/sweep/__pycache__/server.cpython-312.pyc,, vllm/benchmarks/sweep/__pycache__/startup.cpython-312.pyc,, vllm/benchmarks/sweep/__pycache__/utils.cpython-312.pyc,, -vllm/benchmarks/sweep/cli.py,sha256=_dDV5dIvxKflDgaXURwvYj25pvMvjxwstgMJag3J4fQ,1455 +vllm/benchmarks/sweep/cli.py,sha256=55RWr5Nrln867HPRz5LMv7gCdOqi_mInZ_jIHlkOFbw,1485 vllm/benchmarks/sweep/param_sweep.py,sha256=7fOlYnrpnZc6kMxfAW203AFHADSugT-wlR3ajaRuWls,5594 -vllm/benchmarks/sweep/plot.py,sha256=ghtvWMH8OuHENd5cc82kzAIkSn17_BhiN2MosBLYMS0,19800 -vllm/benchmarks/sweep/plot_pareto.py,sha256=pMFyla6LX7dafUWSlHvuxa9fikilD91MG4Qx-uD07hE,11015 -vllm/benchmarks/sweep/serve.py,sha256=FQXjWRoK7mM5xjznHQJK4AeH_atheOx9hGq9oR6NhuE,14407 -vllm/benchmarks/sweep/serve_sla.py,sha256=IJIpj7W6aDaoN28xM4AZ5WD15nwMvecTfgFLMzHIVK8,9390 +vllm/benchmarks/sweep/plot.py,sha256=Whg0sFICzqu-j2YxGv9eylawYcsilCPkW0SbNFdu9IA,19833 +vllm/benchmarks/sweep/plot_pareto.py,sha256=NbDyL63B5XwF93injxzGNFm9tX6n0IPETSHy5IAO7VQ,10992 +vllm/benchmarks/sweep/serve.py,sha256=WLMVR3F6KM3Y7Nx2vy65BbuQm41bI2iCLwgnIQyIsPw,15634 +vllm/benchmarks/sweep/serve_workload.py,sha256=rw339djPijtbdLI5lTq4Hl_VYOrlfQiuyjRmCGFjZj0,10218 
vllm/benchmarks/sweep/server.py,sha256=RBUnoU85t_a4yMTFbZhdz77IhJLhQbg5KkmQsBsiTSE,4619 -vllm/benchmarks/sweep/startup.py,sha256=rranpLrrfBMvwSj_r0jBLhQxFsmydyiHTrotezRuQFk,12302 +vllm/benchmarks/sweep/startup.py,sha256=gCHF4JNCz80wApQJvNvjKSdVPSKNKmgh_eJJpfYZeBo,13371 vllm/benchmarks/sweep/utils.py,sha256=YzhiqE4jmEN7TzUnzQXnlm8TLERkNjgoQ_SF6jJ8hB4,232 vllm/benchmarks/throughput.py,sha256=Kp14mTYV8cf1vcabXar15tWlCfRG5KdQ3HM3vJocwPE,34962 vllm/collect_env.py,sha256=lih4_OSz6OlOE6eWoAYhMCR5aTUs80ZziBLGEWlpQgo,27835 @@ -106,12 +106,12 @@ vllm/compilation/__pycache__/monitor.cpython-312.pyc,, vllm/compilation/__pycache__/partition_rules.cpython-312.pyc,, vllm/compilation/__pycache__/piecewise_backend.cpython-312.pyc,, vllm/compilation/__pycache__/wrapper.cpython-312.pyc,, -vllm/compilation/backends.py,sha256=sVkt0Ydf8ruosaV-KWAJXKoUUTkSifDXwIv4DimJTiI,43776 +vllm/compilation/backends.py,sha256=970-FYlOkVesieL04LjD-03xBMRTJzXuFMwn3aWLfwQ,41386 vllm/compilation/base_static_graph.py,sha256=XP6OBtuEqkLNQc3ZPr6TrwCzokLIARBCBXUw4pIrUjo,2007 -vllm/compilation/caching.py,sha256=Is0Ej-9Wu3tWXuLixlsgaXcw9FJIl1f238AWPiU9ggM,19889 -vllm/compilation/compiler_interface.py,sha256=El4c_sUAvKaqcK1pIMXxyhctjemvwU0sZnNT2OFhmoE,25450 -vllm/compilation/counter.py,sha256=WiKJRLbgqBchplITJ7FrpZ54f6A87829K8VJa1un6hA,1735 -vllm/compilation/cuda_graph.py,sha256=yhpcOz7IoRffPfBgVleLF7cKSbImH6McViTxcWGPhnk,13462 +vllm/compilation/caching.py,sha256=r4YuAboiTyvirfsKKREmVIQnmlaKNqsnZ_fS1LBvx4Q,20301 +vllm/compilation/compiler_interface.py,sha256=5nJ-WxRMyK_6BMWlRxrCrAJXJfTbi86C56yrtxM3Xo4,27254 +vllm/compilation/counter.py,sha256=fx9MGc70eex-sqBlL-IKdX-pmYBRmxgJKNOYi4zURTI,1854 +vllm/compilation/cuda_graph.py,sha256=srsYfNoiv6X_2gUiLzI9Wq6q3MeZqyJ_GLjW3gOBNS4,13345 vllm/compilation/decorators.py,sha256=xOb8eZ5JhusNy8Jl0lL0L3oHHtv4arwQFDkQ0OGHM7Y,26849 vllm/compilation/monitor.py,sha256=VkCQdUi9cLAG2DZnUj2bNfkOkRmSC-ooFB3JM9KrFmM,2202 vllm/compilation/partition_rules.py,sha256=hX0vsG3IceodNxfNAfvU5SBZhBANmLe3Kikz1v8c-k0,2296 @@ -136,13 +136,13 @@ vllm/compilation/passes/fusion/__pycache__/sequence_parallelism.cpython-312.pyc, vllm/compilation/passes/fusion/act_quant_fusion.py,sha256=HcSahkJO6nOhDJG0JYuKJ64qmEqC1jO4hnNgG95yl6o,6929 vllm/compilation/passes/fusion/allreduce_rms_fusion.py,sha256=EQIhoUTSenqef9YBWi1sp1-4iEOqeKVgSdIKWAGeWeQ,32060 vllm/compilation/passes/fusion/attn_quant_fusion.py,sha256=LXR4NrhQ7H9IFO_gEFNFaai9rGvAEenWpj5983uBwys,13228 -vllm/compilation/passes/fusion/collective_fusion.py,sha256=0dpP9QNj5DIx44jboZUMH96C7mH2fQBaE3VXWk1lNnI,15000 +vllm/compilation/passes/fusion/collective_fusion.py,sha256=IUOqTzy7gaH8_GML13sOjns-NjbVxP0s768dUjyr-a8,15000 vllm/compilation/passes/fusion/matcher_utils.py,sha256=9sT6rbkiFb9otMrqEhy1BHmbs-FOPOE2TfpGW5p05eQ,15834 vllm/compilation/passes/fusion/qk_norm_rope_fusion.py,sha256=R6hOWdYz5nTAgeAPokPpVcDrcLNrFykHNUKYK72NSz8,8906 vllm/compilation/passes/fusion/rms_quant_fusion.py,sha256=unEcPpWRu3tsjDIxKjTFkPsd5LkGToFIHA_HQ9LtXds,22500 -vllm/compilation/passes/fusion/rocm_aiter_fusion.py,sha256=v1JTk9hCukpRlwF0j8DiQeP1-FGejRHe2w_HdTsxXmg,17267 +vllm/compilation/passes/fusion/rocm_aiter_fusion.py,sha256=SvOQog2ufjG7yrQ6IcKnnYkusWrFQ_dYtOT34rjZvpI,17091 vllm/compilation/passes/fusion/rope_kvcache_fusion.py,sha256=ZOxtqsNg-H8390X0G4Ouy-vV1F-G7jgGe7xPfYqYt3Q,8267 -vllm/compilation/passes/fusion/sequence_parallelism.py,sha256=4qGn7LENXs-9PxmRbI5fgmJFu2gDE1YRzGrZk7Etkgs,17753 
+vllm/compilation/passes/fusion/sequence_parallelism.py,sha256=fy58E1FEgTq0LzjwkKeRdtdMPMoRrNn44jK-hyg7CIY,17666 vllm/compilation/passes/fx_utils.py,sha256=BvALST35BtfoeBhH0_H39Wrn7RaagLv7Dizm_kB3ujM,2681 vllm/compilation/passes/inductor_pass.py,sha256=IdgVtsX6yzMQiJ7ji00Fltu7qWkPaHQNuMyoUIhBahs,4003 vllm/compilation/passes/pass_manager.py,sha256=Xf51XSGWM806CT6fnTBUVuv2DCRHv1xBPXhguYsXfBg,7003 @@ -153,14 +153,14 @@ vllm/compilation/passes/utility/__pycache__/noop_elimination.cpython-312.pyc,, vllm/compilation/passes/utility/__pycache__/post_cleanup.cpython-312.pyc,, vllm/compilation/passes/utility/__pycache__/scatter_split_replace.cpython-312.pyc,, vllm/compilation/passes/utility/__pycache__/split_coalescing.cpython-312.pyc,, -vllm/compilation/passes/utility/fix_functionalization.py,sha256=7lRG0kQhM_XY07dztxOA_oH7HpJrcfWPcl6EidIYX-A,12785 +vllm/compilation/passes/utility/fix_functionalization.py,sha256=gzVk0IyuYqhLixrRyyhIAmHOAi_Pl-mhOW2M-6K6iHI,13021 vllm/compilation/passes/utility/noop_elimination.py,sha256=5IfkneFTFC_vK37tgGSnrAwvWh4DAJ9rgMeqgDH8A-k,5287 vllm/compilation/passes/utility/post_cleanup.py,sha256=HDixT8IJFly8-Mp2OhesTlKLl_WbVz5dUt8szn11jeM,743 vllm/compilation/passes/utility/scatter_split_replace.py,sha256=E4QCzDCXdooy6eW40aBUthduSUvdyGhOI8otxf3D_do,6168 vllm/compilation/passes/utility/split_coalescing.py,sha256=CwKs2eGtQ9ez4Si5FOlpuVgHc_KBgHiEbfvxY6i9CHE,2274 vllm/compilation/passes/vllm_inductor_pass.py,sha256=JcsztOiPeFIIdSeiKesbcsQBfbqnOt5mNqf9BKCXDpM,6783 -vllm/compilation/piecewise_backend.py,sha256=coB0urtV6sdcTvySBnZeWLPYU29qI08PFfRR6KLqJaI,13986 -vllm/compilation/wrapper.py,sha256=Ntbqm4hGTiIouXGUjfQPxHVQd3gxvaIue55r3yRoJ98,13050 +vllm/compilation/piecewise_backend.py,sha256=n49Pm2d0pmsmX8acy7IjqNJmTouY7vzS1_7qrJedjN8,13956 +vllm/compilation/wrapper.py,sha256=rTZkzKRt-k5xE77Cs-y54EqllTPeYuC0Lmo5OUT1jsk,15182 vllm/config/__init__.py,sha256=Johf-4RN5M25Yag9OQrJytrGprPizqln8CyoK0eEyjg,3810 vllm/config/__pycache__/__init__.cpython-312.pyc,, vllm/config/__pycache__/attention.cpython-312.pyc,, @@ -188,9 +188,9 @@ vllm/config/__pycache__/structured_outputs.cpython-312.pyc,, vllm/config/__pycache__/utils.cpython-312.pyc,, vllm/config/__pycache__/vllm.cpython-312.pyc,, vllm/config/__pycache__/weight_transfer.cpython-312.pyc,, -vllm/config/attention.py,sha256=9tS0q7u4dpz51FUceowhQnBj5Tolxe788U3m-mxGV60,2518 +vllm/config/attention.py,sha256=F93F9gXDE9GuB-lcqQYE4QgLASkgJqLr9vowGjkATNQ,2525 vllm/config/cache.py,sha256=eWUIia30-vZZxYWDI_hJQvxX7lDL0JzFHLnz3Pja48c,11993 -vllm/config/compilation.py,sha256=KePNi2O6PVBGiAoGMfVPWpuu0LWlHvr5SdLg1lIhnRo,50769 +vllm/config/compilation.py,sha256=1Mc_PGxVpIkwj84JECS_oZtqZToDKRAJwyqyqc5ZJ80,51104 vllm/config/device.py,sha256=NlIM9KMfuETmlvp_JHhxMrvKDsmoO3_-wkrMwV08M2M,2719 vllm/config/ec_transfer.py,sha256=6Cc3nmfPrfcFSkDHOmO_1QHoV7FhzUi8QtZna-G4-pw,3841 vllm/config/kernel.py,sha256=bCrlswXTw1wRJd9AbtSPyqiQnKzwBPHOKJ5H1yM9zA8,2694 @@ -198,31 +198,32 @@ vllm/config/kv_events.py,sha256=ZB7tW09-WdmXAC0uXnSSgepOzQLn3yqVQam2_10MRG4,1591 vllm/config/kv_transfer.py,sha256=-CnBqhNYnEEIhLTtRRBTNKpPe-gHcga7qEzcfAd3h0k,4254 vllm/config/load.py,sha256=qlVIok6eKmoRjEe63KXJheU_W59xn8Dh7tI7_qyR8Ck,5697 vllm/config/lora.py,sha256=igUl4RE5B38Q46JONF0HZNaWpiCsFKdKM_Oj7kron6k,4473 -vllm/config/model.py,sha256=AMS9fGA722XhhOg84AX8AKujBs07U5mEiScV2eh7I7U,83285 +vllm/config/model.py,sha256=uoGS6kmKizZJcC8Vx_6f18SJHy-jQOliwLVLzySI9w0,83647 vllm/config/model_arch.py,sha256=Z10j3I3r851inz9mivcErhNMtINjAMMGpaMpb318PKQ,1658 
vllm/config/multimodal.py,sha256=2zBq1kXtUUoZ6kQ_c6TxtUHwnxxxJBQ7HzmoeFcaDpw,10917 vllm/config/observability.py,sha256=TN02f-Clehu5Av-FfSrGXPVHIn2MIobJfaTc5vo9Fps,6402 vllm/config/offload.py,sha256=by2z2aJzBYhHPVFVgRaQxiQbFnoh3LYvu9nrtjht5Cw,6522 -vllm/config/parallel.py,sha256=QH_NWb2FrERrfN5tbI5oE8KVFka_nlPpDYWkAdyBuAU,29859 +vllm/config/parallel.py,sha256=b0rHxXQF9lwzbURqjJjkO-NtcUofHslOA08jhxaOZFk,35453 vllm/config/pooler.py,sha256=v3OFT5ep3tjWwsXQIJDngXybuqFDaMSm2LrMCl3sdrI,5382 vllm/config/profiler.py,sha256=3tFoS0peSYRgpkZUIfDUrEMU0yJjvothMcFRyLKEDxU,4799 vllm/config/scheduler.py,sha256=dEUbBOBjvjVN8XlrYQu17zx4PBYxEF9QyFZpr8z86ws,12671 -vllm/config/speculative.py,sha256=luo8LX7UCcCzs1MHBAcJ_aw28rqzS8_Fz5GSJIddU6I,35030 +vllm/config/speculative.py,sha256=j8qwVoiZ4E9L1-vxcrnFMr7-tnHL_ei_kt4FAA7QQNk,42136 vllm/config/speech_to_text.py,sha256=mXSpXPUVxjxGrEz2xCLo1gllx3SYybIiblxphCXGcSE,1560 vllm/config/structured_outputs.py,sha256=wCVC9wuK7vaYo0ZbQuv45s-y6djocq75Qof_9Po5p9I,3296 vllm/config/utils.py,sha256=0swsZ-Ddp_mKhemVKjG3Kh9OcaaosnyJv2NBtodk0Go,15108 -vllm/config/vllm.py,sha256=uFE5EZ91DXFNq_eYhnR0Ea3Za1BwQkDeSrnEE8XSFsE,76962 -vllm/config/weight_transfer.py,sha256=BRBCpStnmCHbCw_5v8c6qYO3Kr-W61Cf9ElIa1Q_wOI,363 +vllm/config/vllm.py,sha256=IetpfDsD5qBAyzWS_GguYi0K_D7EmCLpMT65RsMwqlw,79354 +vllm/config/weight_transfer.py,sha256=iZGfB2dMA5vdL4OhWoMeWyuBoGJxyUu5JTxrIxB9ny8,370 vllm/connections.py,sha256=66GGcz2MO5RtwBp4NinqL355v5lJq8EjB9cMAB7eZtI,5352 vllm/device_allocator/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/device_allocator/__pycache__/__init__.cpython-312.pyc,, vllm/device_allocator/__pycache__/cumem.cpython-312.pyc,, -vllm/device_allocator/cumem.py,sha256=2itL8dIqAIY8tkeNvFh6jHRqN6aZuYKXhJ4d1THDvxU,11525 +vllm/device_allocator/cumem.py,sha256=4zAbmNjZxkOFhK__yKLqLUuCM4lMGRqxDDwleZ1hhRI,11632 vllm/distributed/__init__.py,sha256=l_KMMLMmYq-MqxY5OeTuQjA38RYwdS4TtVy8JafQq3E,191 vllm/distributed/__pycache__/__init__.cpython-312.pyc,, vllm/distributed/__pycache__/communication_op.cpython-312.pyc,, vllm/distributed/__pycache__/kv_events.cpython-312.pyc,, vllm/distributed/__pycache__/parallel_state.cpython-312.pyc,, +vllm/distributed/__pycache__/stateless_coordinator.cpython-312.pyc,, vllm/distributed/__pycache__/utils.cpython-312.pyc,, vllm/distributed/communication_op.py,sha256=PD-Kxg2zjs6JHTncHi9vGUetQ4aFiT0uV3fQcjIL0bg,1323 vllm/distributed/device_communicators/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 @@ -245,16 +246,16 @@ vllm/distributed/device_communicators/__pycache__/shm_broadcast.cpython-312.pyc, vllm/distributed/device_communicators/__pycache__/shm_object_storage.cpython-312.pyc,, vllm/distributed/device_communicators/__pycache__/symm_mem.cpython-312.pyc,, vllm/distributed/device_communicators/__pycache__/xpu_communicator.cpython-312.pyc,, -vllm/distributed/device_communicators/all2all.py,sha256=CXZN9HaxCO5MI9njTUKD0Xe2s0emqR5g6-cp9yUrbNQ,23610 -vllm/distributed/device_communicators/all_reduce_utils.py,sha256=FXN49WKCqJM3WKxMlQntN1xVwrlss5IjrGRXjfIaoYY,12324 -vllm/distributed/device_communicators/base_device_communicator.py,sha256=2R29s4AZIQuagNYinQtWTOTkkCiLO3P4PbpjY55UQA4,13373 -vllm/distributed/device_communicators/cpu_communicator.py,sha256=ULLbY3DLSpqsLJg3znYfRTM_ooI-uRcD2WQOQRMRAuM,10115 -vllm/distributed/device_communicators/cuda_communicator.py,sha256=BXe2fjCkTenM1tAs30Bzl0G2dB6SESIMxW9vkFVweFI,17214 +vllm/distributed/device_communicators/all2all.py,sha256=b-lzmlCCIZYn8jhL_yr6yRr1vkArUq4bI9qgRqIBAAE,21093 
+vllm/distributed/device_communicators/all_reduce_utils.py,sha256=tKM6UaaviabxhpA-t4EGclm23ZEYxLHfxob6X38nktA,13965 +vllm/distributed/device_communicators/base_device_communicator.py,sha256=tU-YjSUFT2FkgoYR7TzMLhU6hxP-1hE-EsTTIPSL-m8,14582 +vllm/distributed/device_communicators/cpu_communicator.py,sha256=p6YxyWRfgKGJdA38G5UIK0tI79bCVvtPYoOBQW7q4qA,11137 +vllm/distributed/device_communicators/cuda_communicator.py,sha256=Iy2FjDDm4QXj8QErGZVDU2Ew4iXGwlu4HwutmCSnjOg,16797 vllm/distributed/device_communicators/cuda_wrapper.py,sha256=N1znzgx_UgCm5XLXVnKdtmtMcguhmFE0B6lxCPoVcGY,7483 vllm/distributed/device_communicators/custom_all_reduce.py,sha256=qwIMODFiNFmXGFWjr9wyzXZ6dl6WASbya1wEqAEmXP0,12857 vllm/distributed/device_communicators/flashinfer_all_reduce.py,sha256=x1gkMrLjnnNgzk50L6r7JaTBjWdpRIwzOMpxKw_B5PM,7846 vllm/distributed/device_communicators/mnnvl_compat.py,sha256=zgi7ly54YOyLax2TgmlHn5R_iQ6JsdQwxqixY3Hfcqw,1207 -vllm/distributed/device_communicators/pynccl.py,sha256=IZRiSwMlWK-yS5cwGBPr7kn4Ov4XSVhA61aVP5ggpxs,13902 +vllm/distributed/device_communicators/pynccl.py,sha256=6rk6TKGVUMqAXOK9Bj_70ug9s9iple6arbDlKe2WUus,14954 vllm/distributed/device_communicators/pynccl_allocator.py,sha256=D4IyD8MX76loodZhzbsHSmtT0LLoyhcZbIMFN_BZVMU,6260 vllm/distributed/device_communicators/pynccl_wrapper.py,sha256=Md9XRm8EPyhzErzSp8xyN8r86psxZLG7sa-w3Uvw5HM,19098 vllm/distributed/device_communicators/quick_all_reduce.py,sha256=BU5G_-FaMmper_qOYtsA8ViItWNE0u8XUlwoORUX1cw,10903 @@ -275,14 +276,22 @@ vllm/distributed/ec_transfer/ec_connector/base.py,sha256=Nx_zH02nSLHgSicRDx0E4x9 vllm/distributed/ec_transfer/ec_connector/example_connector.py,sha256=tCi_MwxdnEgJuzJ21sJzJOTqlrcxTYLDxixiopzx0hU,7215 vllm/distributed/ec_transfer/ec_connector/factory.py,sha256=LRo-mSc6_mm5a9WjcHTxzSO7Lds7HrLHYLdkCpWfmKs,3070 vllm/distributed/ec_transfer/ec_transfer_state.py,sha256=60Iu-N4qVnjSa2kXaysANPa77HX-cYomPTu7GRk5Evw,1152 +vllm/distributed/elastic_ep/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +vllm/distributed/elastic_ep/__pycache__/__init__.cpython-312.pyc,, +vllm/distributed/elastic_ep/__pycache__/elastic_execute.cpython-312.pyc,, +vllm/distributed/elastic_ep/__pycache__/elastic_state.cpython-312.pyc,, +vllm/distributed/elastic_ep/__pycache__/standby_state.cpython-312.pyc,, +vllm/distributed/elastic_ep/elastic_execute.py,sha256=IfwBwevEgROygmxlOsHh0lyY1SKMRQdk4-jHPlVVp70,22059 +vllm/distributed/elastic_ep/elastic_state.py,sha256=xDT71-jpKGvC9-iKBt0HZ9osSvZIhY450yq-Fv2Km6M,21767 +vllm/distributed/elastic_ep/standby_state.py,sha256=t7A9rZOGCtUXxwFXebXwd9v1nGkcuxR0Hb1cb4vqPK0,3495 vllm/distributed/eplb/__init__.py,sha256=G5wu4iq7WIjDVMWqAfCUTyU-XiThNdAZd4OYu-SVpig,154 vllm/distributed/eplb/__pycache__/__init__.cpython-312.pyc,, vllm/distributed/eplb/__pycache__/async_worker.cpython-312.pyc,, vllm/distributed/eplb/__pycache__/eplb_state.cpython-312.pyc,, vllm/distributed/eplb/__pycache__/eplb_utils.cpython-312.pyc,, vllm/distributed/eplb/__pycache__/rebalance_execute.cpython-312.pyc,, -vllm/distributed/eplb/async_worker.py,sha256=GZ1i90SmMYFqxaPmuKK2QH-MTVGbzUC3ICWrARxIOLo,7816 -vllm/distributed/eplb/eplb_state.py,sha256=tmOe5JfvpHvgjOf_EBXP2cqp7ufhi-bzSpdaMSoRZVc,48615 +vllm/distributed/eplb/async_worker.py,sha256=Wz5JRddZ_tI-l-y4m7IxU5LBBy-ADJigTVwmGXFSAYQ,7618 +vllm/distributed/eplb/eplb_state.py,sha256=FXlKFof025ZjKIatiM7aPp6YVDOosydOxeYwGcz4lFA,43591 vllm/distributed/eplb/eplb_utils.py,sha256=h8WJ1iqulRVQbxtoxb9xg7MZ5Ax0vPQkvjOCoVoV3K8,2281 
vllm/distributed/eplb/policy/__init__.py,sha256=2usmbIhEKDKxq4yMTMBkiRs4uem6AK6jJPkmyH3G034,542 vllm/distributed/eplb/policy/__pycache__/__init__.cpython-312.pyc,, @@ -290,8 +299,8 @@ vllm/distributed/eplb/policy/__pycache__/abstract.cpython-312.pyc,, vllm/distributed/eplb/policy/__pycache__/default.cpython-312.pyc,, vllm/distributed/eplb/policy/abstract.py,sha256=xkphHsTmcIZSkrw5_CbMhxDiZsDMSEYqVGL9QW4jHcE,1624 vllm/distributed/eplb/policy/default.py,sha256=TOe_1Vy5dxCau9rVn3uH9RQP7UXJJ2KjA8Ya0m9tDIs,16382 -vllm/distributed/eplb/rebalance_execute.py,sha256=ufQLtQp-RLGGORVVB8scSg2bZpYbk7SUNorE6_S3xQc,28133 -vllm/distributed/kv_events.py,sha256=_sUdT-GYOZPmzgcc7Qhkqm0NaDdJv4ZLxj0t_hOLqRA,16513 +vllm/distributed/eplb/rebalance_execute.py,sha256=zLj2KcKTjod--3GRQPaU79tkCgKX1oFCRfiRSMTIC0o,29492 +vllm/distributed/kv_events.py,sha256=f22G2ao2-PyXuA5PeHbV_mxgCTJO6rey6jqXnplHYjw,16658 vllm/distributed/kv_transfer/README.md,sha256=cKIw6vXYSxBlf0wWwO7haP82CX2oB2QzW0-RxZE5mT4,2007 vllm/distributed/kv_transfer/__init__.py,sha256=wb5OWOxpI7rmB6NVrSKJHHQPBWpp_3NYseEOo_0mgOo,552 vllm/distributed/kv_transfer/__pycache__/__init__.cpython-312.pyc,, @@ -303,23 +312,25 @@ vllm/distributed/kv_transfer/kv_connector/__pycache__/base.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/__pycache__/factory.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/__pycache__/utils.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/base.py,sha256=KuKixI9XMfNTMWanVED-kedkMyMFtdcT34QO26lweJ0,370 -vllm/distributed/kv_transfer/kv_connector/factory.py,sha256=C2YTjVvEC7r_P7x3wGIGuXR1in0nLWkQYg7noMDKJFk,7245 -vllm/distributed/kv_transfer/kv_connector/utils.py,sha256=FnRC_x34z_AMTN_ZzmvTXpG0xu3uOQBdRKZR3VVLioU,19179 +vllm/distributed/kv_transfer/kv_connector/factory.py,sha256=DhEKiTtBrvkL8MYZwUfJ1lfPHK7exTdDOltajTzis94,7443 +vllm/distributed/kv_transfer/kv_connector/utils.py,sha256=4gLVJsbKl8pIxMCk1bV3l4Socd9ex2H9vRUwB93k2_8,20481 vllm/distributed/kv_transfer/kv_connector/v1/__init__.py,sha256=W-tsytEv3P4VgZgtpvGDIcrKUOc6azJZ6_PI3Is__Mo,508 vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/__init__.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/base.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/decode_bench_connector.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/example_connector.cpython-312.pyc,, +vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/example_hidden_states_connector.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/lmcache_connector.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/lmcache_mp_connector.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/metrics.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/multi_connector.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/nixl_connector.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/__pycache__/offloading_connector.cpython-312.pyc,, -vllm/distributed/kv_transfer/kv_connector/v1/base.py,sha256=bfDdRRgMWeKL8HsE6_Gi8WWmdS5XmCyWWsarXG9C4Xk,21225 +vllm/distributed/kv_transfer/kv_connector/v1/base.py,sha256=QuWuf-D4UysNmPAcigABKrBzlXcsrNxiY5Xi8HJnQh4,22102 vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py,sha256=bzlQ8_pgzmMviZaOnHsdNhlWH6TUALRsajNoBHsbV4A,15531 -vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py,sha256=8O6lyNeinE1BHJwAx8kjBpBLk106_CGcJ9jHLv2bp_I,16765 
-vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py,sha256=zDBxiFNCjg5kX1XKYXmMfRze0iLidEJaIXIEEFBE1iw,12005 +vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py,sha256=5sFs135kFbU-WfJT796s3rg63XXqGL7mAHtZ5kUAnXc,17539 +vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py,sha256=UoXHU0LKUg06MbiZfDtqo0rYhYH0WxDlfHRlZoPfbW4,12484 +vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py,sha256=rp-i3iAaXfkyniz4kA6_PkvuH0uo1mWLaceFuuhrM50,12429 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py,sha256=k3tOaUj8P0IjtWcZ0CGxO3iE8IuALduzLpW5ejYJAFE,426 vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__pycache__/__init__.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__pycache__/multi_process_adapter.cpython-312.pyc,, @@ -334,7 +345,7 @@ vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__init__.py,sha256=47DEQpj vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__pycache__/__init__.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__pycache__/mooncake_connector.cpython-312.pyc,, vllm/distributed/kv_transfer/kv_connector/v1/mooncake/__pycache__/mooncake_utils.cpython-312.pyc,, -vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py,sha256=3wBhN0pKAV-6ae59q7xoi3vwKSJU-SZ0fRFkP8ALPeA,51087 +vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py,sha256=wTGfVM2k2tCu8uqirfPLxBXryUx7bom6CesZJz7FL8o,52127 vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_utils.py,sha256=wSg8lscXSMpUESTGpDr8xFqiLgAkVsPJTXtkVIzurTA,4170 vllm/distributed/kv_transfer/kv_connector/v1/moriio/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/distributed/kv_transfer/kv_connector/v1/moriio/__pycache__/__init__.cpython-312.pyc,, @@ -344,8 +355,8 @@ vllm/distributed/kv_transfer/kv_connector/v1/moriio/__pycache__/moriio_engine.cp vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py,sha256=hAA-iozOOJOf-IdwuynQDB3imwtwbKGI_DLUGuI5l6Q,8898 vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_connector.py,sha256=uyL5lnmtIK1LQKTu4y0Br-IFLx1PDrSBBbQiGRuXfVI,59062 vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_engine.py,sha256=5vX3LF6AOwyJFCV3jwEBNIDi5Ai2cEC964ScgBYrgwM,21306 -vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py,sha256=AyIOBx0NEzJfiitEeMsn7e9y4d25VeBvnTTC02mZPpo,20165 -vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py,sha256=x2mCraQgUZgWM7scX9dLQZ9rZ4KQTvEjfyHefi2Qq2o,116288 +vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py,sha256=EnyCeJ3Yy_EIuVjfIidjpQfo6tnFf__WjmU0n-_Qekw,20856 +vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py,sha256=9SuIjfh79psb536pCgOHEkQH87ePtFiDB8QRs4iVd5Y,121294 vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py,sha256=GbNAYMx4Fw6nqLqy9PJTcaKPfNyVZgEy_kcGosIThV0,31212 vllm/distributed/kv_transfer/kv_connector/v1/p2p/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/distributed/kv_transfer/kv_connector/v1/p2p/__pycache__/__init__.cpython-312.pyc,, @@ -356,17 +367,20 @@ vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py,sha256=nY vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py,sha256=gdHYw55Qm1z8cz5cWoq76siI8tL6tewzYldcxNCG2s4,23390 vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py,sha256=A2mMTFKJ1zBYEo0gPhe166-ms1S_gAaL4Zs9tULMbTc,9225 
vllm/distributed/kv_transfer/kv_transfer_state.py,sha256=js4h2XWbnmB8IoshDJBaGK9yc7bX10uB2D7PY3A3TEs,2277 -vllm/distributed/parallel_state.py,sha256=Xv9s57_NweT5g3Vg6X3XL4wlwTSNbTvLAf1vT1uMed4,71800 -vllm/distributed/utils.py,sha256=BLdN-Vn16TYG-wY4MTEMXcGIc6MIKbMSk_60a4pQFSM,21419 +vllm/distributed/parallel_state.py,sha256=CUE7LArKpC31tWDjrPVsWQn06JCX_59wluqrdWe3Wc8,77664 +vllm/distributed/stateless_coordinator.py,sha256=opc5FRC4IU4TZ1CsFni87nO8uvq_kdv2aXQ2fE51eC0,11625 +vllm/distributed/utils.py,sha256=0aLYrVvnGLdokC9mCRzL8kahYKtHlJfpciPUgol-_dk,23717 vllm/distributed/weight_transfer/__init__.py,sha256=BNylvNnKXwhGl-mBSwRKuAdi0sGXpZzaQRk4DHn1XZI,333 vllm/distributed/weight_transfer/__pycache__/__init__.cpython-312.pyc,, vllm/distributed/weight_transfer/__pycache__/base.cpython-312.pyc,, vllm/distributed/weight_transfer/__pycache__/factory.cpython-312.pyc,, +vllm/distributed/weight_transfer/__pycache__/ipc_engine.cpython-312.pyc,, vllm/distributed/weight_transfer/__pycache__/nccl_engine.cpython-312.pyc,, vllm/distributed/weight_transfer/__pycache__/packed_tensor.cpython-312.pyc,, -vllm/distributed/weight_transfer/base.py,sha256=WsY9kVujjSIvFiDtyNwx8D_LF8WCHA-t38NQKp1iieM,5084 -vllm/distributed/weight_transfer/factory.py,sha256=a6-Jk4YmVWSACbxqY9KmSmpx8yXzPg1_afVzhy15FUI,3972 -vllm/distributed/weight_transfer/nccl_engine.py,sha256=1z12EBYiXFamrRkQi_Q1f-gShrww20nN_8vOssEtRqQ,12459 +vllm/distributed/weight_transfer/base.py,sha256=9DmAtuLTxWJY2aRdHbZWvY-59P_1Iu1Ek6kG_WOIj-E,6245 +vllm/distributed/weight_transfer/factory.py,sha256=OCyp9reYTzMMJHLeqHisf9Z_6L98BejEJq79zhddGqo,4113 +vllm/distributed/weight_transfer/ipc_engine.py,sha256=0oOLGMNy_CgNiZGMNGRcyFDH68Pya-My4nixHHpnhNU,11019 +vllm/distributed/weight_transfer/nccl_engine.py,sha256=AJQIVm5-IjJ7YC8Amv0rFsLsyQN41Cml8qCdhtt6nts,13089 vllm/distributed/weight_transfer/packed_tensor.py,sha256=Sl28ATzJUAjQx30o_QEZFyqyQb7R_U2-jpOGCBkkX_I,8989 vllm/engine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/engine/__pycache__/__init__.cpython-312.pyc,, @@ -374,7 +388,7 @@ vllm/engine/__pycache__/arg_utils.cpython-312.pyc,, vllm/engine/__pycache__/async_llm_engine.cpython-312.pyc,, vllm/engine/__pycache__/llm_engine.cpython-312.pyc,, vllm/engine/__pycache__/protocol.cpython-312.pyc,, -vllm/engine/arg_utils.py,sha256=5KYdKSnnoZOUZEOPv6nX9NxqfM0d8SCQVORYc-Qpkhc,94731 +vllm/engine/arg_utils.py,sha256=hG_lreAPllHnzuSHCSxstp6wLxmxWQaoxoo2g2eKF58,94861 vllm/engine/async_llm_engine.py,sha256=P3UcYRNf9T5GO7j-Nj00LVs2OYKbzl0ivavGILQS21s,284 vllm/engine/llm_engine.py,sha256=Ormr0eFokwGJd2DBWJzJGraI47sc3TPbBBprWEvyU0I,296 vllm/engine/protocol.py,sha256=R8vhuOZRFsOgyH2bXz-TYGr-kLCDYcg25PeSo51-194,6963 @@ -394,11 +408,11 @@ vllm/entrypoints/anthropic/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/anthropic/__pycache__/api_router.cpython-312.pyc,, vllm/entrypoints/anthropic/__pycache__/protocol.cpython-312.pyc,, vllm/entrypoints/anthropic/__pycache__/serving.cpython-312.pyc,, -vllm/entrypoints/anthropic/api_router.py,sha256=VvFf7mLGh4EXm0mK537cPkxax6fUDh3vXeK_lWtZ-64,3079 -vllm/entrypoints/anthropic/protocol.py,sha256=nrxqThJiw762ZKd0m-rHe0hWnmdcm-Hs_D8hhBZaU14,4582 -vllm/entrypoints/anthropic/serving.py,sha256=MC1AyTerhhuMsoolVF_TRf-fyyTsru6KB4xxWVdAYc4,21312 +vllm/entrypoints/anthropic/api_router.py,sha256=-_NhIEW9OnukH5lE6copkgOS44Ob7NPAxbAM3mqjsnE,4591 +vllm/entrypoints/anthropic/protocol.py,sha256=IQOP0NQuIi2sdzbg_8hCXjjbVd5r--ByJTXJ-rIaCgw,5620 
+vllm/entrypoints/anthropic/serving.py,sha256=HOzHRtIrVk5WoEtvRjxunaAVT_rmTAnZRqLCEq4XthI,33710 vllm/entrypoints/api_server.py,sha256=JUM2xq5yDLIVY6lDfhEtY002kQBo_YomSfmrrRzu7cE,5902 -vllm/entrypoints/chat_utils.py,sha256=XpNCh1EjCSTTd5TdFBFsazf4CqEMHKaPrzkKJTSuT38,56673 +vllm/entrypoints/chat_utils.py,sha256=f0X_hOTfyJb6Iav1LDI8RihXk7X3eO8ySBBRm8FVsGg,56904 vllm/entrypoints/cli/__init__.py,sha256=-dU_jpAUU9ksRsxbw8Leuhy-wPcRTp7YQp_q5-tQM1Y,828 vllm/entrypoints/cli/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/cli/__pycache__/collect_env.cpython-312.pyc,, @@ -429,13 +443,13 @@ vllm/entrypoints/cli/collect_env.py,sha256=H9qpqasc2FRmAQUvmRV5WuCxJqdj-ouYAQ2CM vllm/entrypoints/cli/main.py,sha256=vHRDDzGjKK2ZE9EWyGn--jpfBWewDtiYFDab9srDc78,2468 vllm/entrypoints/cli/openai.py,sha256=CWzTEwp_nLWLXEPpvKlw7rZjiM-CSEnMaxic9s8IpIc,8134 vllm/entrypoints/cli/run_batch.py,sha256=8ChYNc3UbbCVEbtp_tTbgR32_jVq3WRJgO1RzGhP8yI,2242 -vllm/entrypoints/cli/serve.py,sha256=bNd6Mq7gyXw2hP89wKZSgk-KII_k6BmLJPW_oICtf4g,11424 +vllm/entrypoints/cli/serve.py,sha256=ZLRCfZidbjWzU9HXGAdaHGVg4cZQJt6FkqV8FWfnsOw,11829 vllm/entrypoints/cli/types.py,sha256=On70PbJbmYXWTQZaL3EbI7I-fd1eIE21FAhKuHnTcfo,812 vllm/entrypoints/constants.py,sha256=7rJGI5ZxBEJr2_7YjIJHzVaRX_9PChU7aOSYp95v6tE,356 -vllm/entrypoints/grpc_server.py,sha256=NrqiGJQr3C_3o41sYZ4AkkZ0Y3y136kQUTQ3G7PUTJs,17665 +vllm/entrypoints/grpc_server.py,sha256=HmPLZdHZuuXzY7Mg9J20l8D55ExVd74jL52-jMSuZE0,17994 vllm/entrypoints/launcher.py,sha256=iLd368aVcRHtb2i35cn99MZZVxPYVDpmbSKTvVRuNHM,6446 -vllm/entrypoints/llm.py,sha256=jJS8T-bhdeSTX0YS6E0qf3bXPMvxajaDRMUcoUeDJdI,85063 -vllm/entrypoints/logger.py,sha256=GeSbm6Y6FndcA5HU_PYmlFC-6xFe2oGmegmNUgcHMKA,2659 +vllm/entrypoints/llm.py,sha256=mmxbulxBDgC-Won2sSt7zbmHSIcgn2i_aZkMpbn8Kbg,85440 +vllm/entrypoints/logger.py,sha256=GZqaMQEyqkOPRCVpSdVe2NPD6HVV2SmKo0OLaSGy9-A,3297 vllm/entrypoints/mcp/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107 vllm/entrypoints/mcp/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/mcp/__pycache__/tool.cpython-312.pyc,, @@ -458,24 +472,24 @@ vllm/entrypoints/openai/chat_completion/__pycache__/protocol.cpython-312.pyc,, vllm/entrypoints/openai/chat_completion/__pycache__/serving.cpython-312.pyc,, vllm/entrypoints/openai/chat_completion/__pycache__/stream_harmony.cpython-312.pyc,, vllm/entrypoints/openai/chat_completion/api_router.py,sha256=XZsKmAyfurBNM5QzjCJk5Pv107cTH--Z8vi0NY3z4ds,3691 -vllm/entrypoints/openai/chat_completion/protocol.py,sha256=vt7RuTozB0X9j7jl1lG7VMSM0w63Q-Cple5Kgu99FF8,28983 -vllm/entrypoints/openai/chat_completion/serving.py,sha256=A-1WyPsqH2HGiPWysrlVKfvRc9H3qmIVlKP4InpJzLc,88403 +vllm/entrypoints/openai/chat_completion/protocol.py,sha256=CQw5DFLdgsBXDSp8MJSrvu4WjgOOi4tFADvb-67WTzQ,30517 +vllm/entrypoints/openai/chat_completion/serving.py,sha256=vot1wNgmHoKB7x_uRXI6-qX2Po9WZDqZEcvSNlm9xkE,89084 vllm/entrypoints/openai/chat_completion/stream_harmony.py,sha256=de1cKmdVEfJfkYVKPLpH_R2VEBC-Zefh4x8zfHkviWI,6158 -vllm/entrypoints/openai/cli_args.py,sha256=1BidcF3nh8AAdo0MoMDxCaiyzfQRFwkDbdNk0X83_iU,15969 +vllm/entrypoints/openai/cli_args.py,sha256=CMlFwCLd9kRWd7q4Llz5KLQocGeCwJNHSw5jWbchUvU,16302 vllm/entrypoints/openai/completion/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107 vllm/entrypoints/openai/completion/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/openai/completion/__pycache__/api_router.cpython-312.pyc,, vllm/entrypoints/openai/completion/__pycache__/protocol.cpython-312.pyc,, 
vllm/entrypoints/openai/completion/__pycache__/serving.cpython-312.pyc,, vllm/entrypoints/openai/completion/api_router.py,sha256=QDNpQA5gC1YjLRsTFZZEg7eT9EDyEqD9Cou4CcP2pKo,3641 -vllm/entrypoints/openai/completion/protocol.py,sha256=oRPTwHWqWzZ4OabxCGH-C7sKRFhLepRTHGGUttlDgCE,17864 +vllm/entrypoints/openai/completion/protocol.py,sha256=zMu9piJ9s0UxBJCtQRIXQ5NZBofDhBAPwF5wPMNChsI,19367 vllm/entrypoints/openai/completion/serving.py,sha256=i4uMF351n5j7wsNYfTsBwm8zXFLTTJ1XyaHQgR6NjBM,27370 vllm/entrypoints/openai/engine/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107 vllm/entrypoints/openai/engine/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/openai/engine/__pycache__/protocol.cpython-312.pyc,, vllm/entrypoints/openai/engine/__pycache__/serving.cpython-312.pyc,, vllm/entrypoints/openai/engine/protocol.py,sha256=WpE0x_RUno9Gvs4kjiGgQB30bmiGV29Z0zKdgG9uNx8,9841 -vllm/entrypoints/openai/engine/serving.py,sha256=RTIOkuMZSC0p_EhvX6Eu5K3GqDfHA6SdqbNK36J_3xI,47254 +vllm/entrypoints/openai/engine/serving.py,sha256=GCcYWRN1ta_FTPw61Sxa4tQwQYDNXJAcr1ebtYnrI0c,46702 vllm/entrypoints/openai/generate/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/entrypoints/openai/generate/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/openai/generate/__pycache__/api_router.cpython-312.pyc,, @@ -517,8 +531,8 @@ vllm/entrypoints/openai/responses/__pycache__/utils.cpython-312.pyc,, vllm/entrypoints/openai/responses/api_router.py,sha256=Y8p4WT9POGG7nQstsYE1-T6x3U13T9O971eN7S71f3U,4748 vllm/entrypoints/openai/responses/context.py,sha256=l7JXOgngdBqd84LuJWUTlUT0qusm34rwbHnHJBDp_GM,35185 vllm/entrypoints/openai/responses/harmony.py,sha256=iqXkCyKNeIwxYs0l745c3tsmOnA0uj5Qs3RQCtzfA20,20068 -vllm/entrypoints/openai/responses/protocol.py,sha256=yzPLhSjZyY-8vBixnltO7FsJv1w4V_K-77GknPvt2EI,23685 -vllm/entrypoints/openai/responses/serving.py,sha256=VtogA39uqpHG9OQvFAHjysHPKLQSl31wH3iVbIA2SYU,70206 +vllm/entrypoints/openai/responses/protocol.py,sha256=AuXLRUo_OpWW9wF8Cr59C1__Wiy4SqPhXvf8_LO3mO4,23945 +vllm/entrypoints/openai/responses/serving.py,sha256=y7odWu6baXWx25MNLLQvaYij8sQZYIviDZgGhanX7F4,72058 vllm/entrypoints/openai/responses/streaming_events.py,sha256=ynVNwoBd4x38f0n9SNTCQyymzuFn_V3tnP6GhnJCvuQ,27220 vllm/entrypoints/openai/responses/utils.py,sha256=NSkKEUcUMmCrgVYMKzvshGoQPEd-7QbNXpt_8XcN_Ow,9128 vllm/entrypoints/openai/run_batch.py,sha256=VaHRg-2Ja_3u7wzGtq_EvvqEFDiJDGTlKKR0fp95gMM,29303 @@ -532,48 +546,47 @@ vllm/entrypoints/openai/speech_to_text/__pycache__/speech_to_text.cpython-312.py vllm/entrypoints/openai/speech_to_text/api_router.py,sha256=-CFrCyG9z4FVf6fv4aQxJAmNPebdELGTAQ0I3pLShlg,5104 vllm/entrypoints/openai/speech_to_text/protocol.py,sha256=ZAnM3Jk4vCGQyTydyUXsaF0IbHOnU9fXm_HH9nzLKY0,17205 vllm/entrypoints/openai/speech_to_text/serving.py,sha256=mZl6FEbpzS-JyvKsov09odZeIXEY9c8Br82T0dTmoK0,6150 -vllm/entrypoints/openai/speech_to_text/speech_to_text.py,sha256=ByVpW4SKmYbwjgq-vqSjbcMrxUqAr6Es3AYGiuo8P5U,31459 -vllm/entrypoints/openai/translations/__init__.py,sha256=3L2Iwb-SbRw9JjhRlP-GUF9aaGJ0Z_52o6smXxbve0c,410 -vllm/entrypoints/openai/translations/__pycache__/__init__.cpython-312.pyc,, -vllm/entrypoints/openai/translations/__pycache__/api_router.cpython-312.pyc,, -vllm/entrypoints/openai/translations/__pycache__/protocol.cpython-312.pyc,, -vllm/entrypoints/openai/translations/__pycache__/serving.cpython-312.pyc,, -vllm/entrypoints/openai/translations/__pycache__/speech_to_text.cpython-312.pyc,, 
-vllm/entrypoints/openai/translations/api_router.py,sha256=bz-XDlh_lQFLoPvrsa74dg61EHo5I6yG39T19IotaSc,508 -vllm/entrypoints/openai/translations/protocol.py,sha256=PUVFoOn9f1kpagYj9ogWtF-mRA5kNtWD2sqwE35z9Us,502 -vllm/entrypoints/openai/translations/serving.py,sha256=OMTUCsZQOY95vQmnaZ_CVykK1R5VzFFKmLZTimzhV7E,499 -vllm/entrypoints/openai/translations/speech_to_text.py,sha256=KP2KldbK8EafCwgtaQ5pHknF4-p-Oo00haK0FdpuHE0,527 +vllm/entrypoints/openai/speech_to_text/speech_to_text.py,sha256=dmonxmIRKckLuVTShHU1ja-G9YPxZIzJJLiIBpbCA7I,32279 vllm/entrypoints/openai/utils.py,sha256=4ws1wnAPy-ozf3Dw1-D2h1S07xWkjRC1AFpsUECWf0o,1624 -vllm/entrypoints/pooling/__init__.py,sha256=I4DWP-s2rOTaXSbjdfC3H_d9I0ZJVFo5jyz38vIPTX4,4325 +vllm/entrypoints/pooling/__init__.py,sha256=0aDd_nORG54TwBZQcHV2LniatQ1XuxI5AHzl1BGUSN4,4414 vllm/entrypoints/pooling/__pycache__/__init__.cpython-312.pyc,, +vllm/entrypoints/pooling/__pycache__/io_processor_factories.cpython-312.pyc,, +vllm/entrypoints/pooling/__pycache__/typing.cpython-312.pyc,, vllm/entrypoints/pooling/__pycache__/utils.cpython-312.pyc,, vllm/entrypoints/pooling/base/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/entrypoints/pooling/base/__pycache__/__init__.cpython-312.pyc,, +vllm/entrypoints/pooling/base/__pycache__/io_processor.cpython-312.pyc,, vllm/entrypoints/pooling/base/__pycache__/protocol.cpython-312.pyc,, -vllm/entrypoints/pooling/base/protocol.py,sha256=peoLP5PbN28KHiUHe8ybS4ZsHor5XqphHxsF89pUDKg,7490 +vllm/entrypoints/pooling/base/__pycache__/serving.cpython-312.pyc,, +vllm/entrypoints/pooling/base/io_processor.py,sha256=_VtgNro9ycFsbDCZIzs0ot9YA3qPpWlfJPp4PWb87x8,6374 +vllm/entrypoints/pooling/base/protocol.py,sha256=QA7obcw7N_XsNcxOfdWjBaKNMCOMVcD_VGLX1ifqtsk,7354 +vllm/entrypoints/pooling/base/serving.py,sha256=mTIKxd-d-yPWMs9ZYxsXvB09x1_xnzv4HEndQfEO7f0,13018 vllm/entrypoints/pooling/classify/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/entrypoints/pooling/classify/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/pooling/classify/__pycache__/api_router.cpython-312.pyc,, +vllm/entrypoints/pooling/classify/__pycache__/io_processor.cpython-312.pyc,, vllm/entrypoints/pooling/classify/__pycache__/protocol.cpython-312.pyc,, vllm/entrypoints/pooling/classify/__pycache__/serving.cpython-312.pyc,, -vllm/entrypoints/pooling/classify/api_router.py,sha256=oVlc9VP1R1cNy_WIw1mQofOmrHdjKyCHm9-8ybGu85w,1704 -vllm/entrypoints/pooling/classify/protocol.py,sha256=u_JCHyPp7DN97DRCUgPrt8GTNUGAtzLfhEoLqfvHnGY,2822 -vllm/entrypoints/pooling/classify/serving.py,sha256=R-B9-uiTDUeJYgM3zDAdHxHnt661A4kIaJwRUUmK4kI,5578 +vllm/entrypoints/pooling/classify/api_router.py,sha256=eId06qUMKV80fSWvmpYVUo1tf3IZi4t_s-wnkeAVC1w,1272 +vllm/entrypoints/pooling/classify/io_processor.py,sha256=Wck30F81yUmDAtM1mcL8O_8eS2Ucjk5uhTTnA846tN8,1997 +vllm/entrypoints/pooling/classify/protocol.py,sha256=o6oI_xXXTEsbuhChHQaWrn_lGdeh-2vtxK4vna54d40,2694 +vllm/entrypoints/pooling/classify/serving.py,sha256=nJ3o3swwS00bWhB04ddfUFRSvf-aCRpsqMJPEJVWh7k,2552 vllm/entrypoints/pooling/embed/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/entrypoints/pooling/embed/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/pooling/embed/__pycache__/api_router.cpython-312.pyc,, vllm/entrypoints/pooling/embed/__pycache__/protocol.cpython-312.pyc,, vllm/entrypoints/pooling/embed/__pycache__/serving.cpython-312.pyc,, vllm/entrypoints/pooling/embed/api_router.py,sha256=eO-thkJZ0_ACfBE6I26dSeblrxu-U4Bmbsw446dV3-A,2628 
-vllm/entrypoints/pooling/embed/protocol.py,sha256=dX35bWvbj_ls2kK6hcB26y9B1pAK47yDB6WkixmBjRY,4454 +vllm/entrypoints/pooling/embed/protocol.py,sha256=uz5gJ7a05ssZLPwulaWhrxFKr5vDZhT79UqOWSMEqP0,3724 vllm/entrypoints/pooling/embed/serving.py,sha256=-byoFBEG4zpM5Qz6I9n0wq2l8twDI_vbVqjdQbeanL8,25045 +vllm/entrypoints/pooling/io_processor_factories.py,sha256=A0jAz430HHdgvDPtjCZ3jXrZAvZI8bOSPudP6kZzIEQ,1026 vllm/entrypoints/pooling/pooling/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/entrypoints/pooling/pooling/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/pooling/pooling/__pycache__/api_router.cpython-312.pyc,, vllm/entrypoints/pooling/pooling/__pycache__/protocol.cpython-312.pyc,, vllm/entrypoints/pooling/pooling/__pycache__/serving.cpython-312.pyc,, vllm/entrypoints/pooling/pooling/api_router.py,sha256=BT3Wf9C3Wf-0U-3Fr9WBIxE8ij2qGrLXNMSLI4hqR8s,2163 -vllm/entrypoints/pooling/pooling/protocol.py,sha256=Tov1Nx3PpL6xOliqP7sXendfrEh6mMwbqLZjHCSfr1g,4893 +vllm/entrypoints/pooling/pooling/protocol.py,sha256=lkF3Mu3lzzFUeURmO2wrsmqY5M1tIfCdeJ9oSKgyIjI,4163 vllm/entrypoints/pooling/pooling/serving.py,sha256=8Q7wD_UEyU_6Uf_mFEdKoGfCAUldL0izZvuB8X7a78E,12756 vllm/entrypoints/pooling/score/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/entrypoints/pooling/score/__pycache__/__init__.cpython-312.pyc,, @@ -582,14 +595,15 @@ vllm/entrypoints/pooling/score/__pycache__/protocol.cpython-312.pyc,, vllm/entrypoints/pooling/score/__pycache__/serving.cpython-312.pyc,, vllm/entrypoints/pooling/score/__pycache__/utils.cpython-312.pyc,, vllm/entrypoints/pooling/score/api_router.py,sha256=EgEEBPugF80xLXgMZwpLZcPGXpQ3-ZmYM3uYsQDzvPg,4690 -vllm/entrypoints/pooling/score/protocol.py,sha256=Vr4YLFtegqo6tFT6Dk_czMh6B_jmE2-ftiXoxpMrG7M,4025 -vllm/entrypoints/pooling/score/serving.py,sha256=jq2pbuajHZGrTll_by52AtCL9A0xrvIZxBROHvSOlF4,22113 -vllm/entrypoints/pooling/score/utils.py,sha256=L2y1JuMywqwnUpKs3PCuVZP_YmxIvKYIbWmHvxdsf3o,13844 +vllm/entrypoints/pooling/score/protocol.py,sha256=2nVYKBKp00HHfKm-g4CKMvq2cnK32wiEllHt_JxmE88,3897 +vllm/entrypoints/pooling/score/serving.py,sha256=IZHHmlA4LeaI757tidNif1WZkxsVQvewWGQyESjkdnE,22253 +vllm/entrypoints/pooling/score/utils.py,sha256=0F8ijuXas6wBey4ZXFfH8L8POaO8zB2jKQgUPyPSpqc,17063 +vllm/entrypoints/pooling/typing.py,sha256=p184UjkQARdEqT_iiJ2AfJVu6fH4QmN4_AHNVmupzwA,1302 vllm/entrypoints/pooling/utils.py,sha256=hutLstZWBpVpLPGur21YBatWDw1qKm8x-ye4clPjwrw,3150 vllm/entrypoints/sagemaker/__init__.py,sha256=Fh71s9Oigag26xVp42xT7JCAr_dvDRGnjoRhEV18cwk,155 vllm/entrypoints/sagemaker/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/sagemaker/__pycache__/api_router.cpython-312.pyc,, -vllm/entrypoints/sagemaker/api_router.py,sha256=-A-nAq0SwH0yUX50EeeT06EjltvBLbqVprJV6JZ_TD8,6225 +vllm/entrypoints/sagemaker/api_router.py,sha256=Z3jd_n66sTm-9foJS0k5NfjBwN7IbXNE0QqmYXs4O7E,6307 vllm/entrypoints/serve/__init__.py,sha256=3NTr65tVBGmL_xe-ypSqffMKBMrCaHlVItF2QCiSmiU,1439 vllm/entrypoints/serve/__pycache__/__init__.cpython-312.pyc,, vllm/entrypoints/serve/cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 @@ -655,11 +669,11 @@ vllm/entrypoints/serve/tokenize/api_router.py,sha256=HH4-2P4hruiUbusT6Rq0F8j2zIE vllm/entrypoints/serve/tokenize/protocol.py,sha256=FQ_wXfO7rupRl94JQkqNznQXhQoNIIPneQhft10ZrNA,6071 vllm/entrypoints/serve/tokenize/serving.py,sha256=61jRHCiABN8U0G-p9IIxeO5RVHdwXMWFQK5zypvFDnU,7052 vllm/entrypoints/ssl.py,sha256=P6ApVGAQFLHERnJeoLNpZk-HQI1nNqWdlF8UAQpnC48,2712 
-vllm/entrypoints/utils.py,sha256=IiE53NLG1Hn7byyvqx5WH8toPrCaeCP1yELV-f5DIK8,11253 -vllm/env_override.py,sha256=NkjKlikC3W5yjLfT9-wSl79RuFvtRDskUBahkmgK-Bk,18087 -vllm/envs.py,sha256=vNPBXbjpmKZY1EWKrU64VnhFnnp6HO6-lq0_mPNAAbA,83716 +vllm/entrypoints/utils.py,sha256=grCOSCM66SZzcom-_lSZIdjco22dVzRsbOG1j7WA77s,13270 +vllm/env_override.py,sha256=4HQLIXC8V6H6-pzHEVk8Jw0Y1BdWQmyAqQS6qzzxH94,19912 +vllm/envs.py,sha256=mmV1CyoMXLPSHEQ2DQqElndXnM-M5Wqvk5UhItcAEAI,87511 vllm/exceptions.py,sha256=AtRlqcG-WMiNYd_W1uJsbHdRoJusuBTWHAfpUru1NkQ,1067 -vllm/forward_context.py,sha256=poboqbv4FfR4uN5K84MI_JZDY92FoPjw5Hw-qUUAFic,16286 +vllm/forward_context.py,sha256=fyE3Z3oH1FBse9k145wQkUPFSGdewxk9rHkvCXUxkm8,16248 vllm/grpc/__init__.py,sha256=leqIiw5daoeDRF-SfkvbfjNXiJVmEp4p0vZcsNyYOME,504 vllm/grpc/__pycache__/__init__.cpython-312.pyc,, vllm/grpc/__pycache__/compile_protos.cpython-312.pyc,, @@ -686,7 +700,7 @@ vllm/kernels/helion/ops/__init__.py,sha256=vdhTpQvMYs9RDhS4APjiSh0S8dmeUMLBv1aV1 vllm/kernels/helion/ops/__pycache__/__init__.cpython-312.pyc,, vllm/kernels/helion/ops/__pycache__/silu_mul_fp8.cpython-312.pyc,, vllm/kernels/helion/ops/silu_mul_fp8.py,sha256=UMGebTNqxQ-HtMTxi0kN1ytwrbw4Gxw7Ai83d-DW1Mw,4836 -vllm/kernels/helion/register.py,sha256=5pZvkAeAVfaijr1vATcoXDMHiVe4wXRvUPQ5fsIkiw0,16371 +vllm/kernels/helion/register.py,sha256=whg7hgdYp9XNDnBEPCziZyVz38kKOG1Q7trfVR_Yxuw,19890 vllm/kernels/helion/utils.py,sha256=clw61SkYlr-fNwdm9DwhUowPgooDhEoycqFAassb8I8,3101 vllm/logger.py,sha256=stQuWFxaJ_CKcNOvN9EQIaOA__ylmSIOAUIVMLAp7CY,11133 vllm/logging_utils/__init__.py,sha256=91YL2ScSZ4kK19HWo6DCGXwhFH5K6jqDrweZWk3Ixg8,538 @@ -727,8 +741,8 @@ vllm/lora/layers/__pycache__/vocal_parallel_embedding.cpython-312.pyc,, vllm/lora/layers/base.py,sha256=djqVQ30rFpuG71V8nIj4QARsbEDzjfeosBk9jTt98Ts,1825 vllm/lora/layers/base_linear.py,sha256=ixQzJ4ogWgyITKjvUFCb7vZU1Rul-9496vhVeLrbKkE,5909 vllm/lora/layers/column_parallel_linear.py,sha256=Rp3sSTG4AlG_N2FA7QwmEmva06qxRzg0mLFg8dz_7CM,22947 -vllm/lora/layers/fused_moe.py,sha256=PcbT2LH_IeASAYwkbIricUeBJXW9rBDHsbQavWRwJWM,29585 -vllm/lora/layers/logits_processor.py,sha256=W_ZzEyMsydDqrynWqabzIRMmjfZQVnrNuNnmc5DHkyo,6625 +vllm/lora/layers/fused_moe.py,sha256=g2N2Wi2zhRUPoyHJIyLT2US7wJbKnw0gEY7sjEIL5e8,29746 +vllm/lora/layers/logits_processor.py,sha256=6uUYo43HEL7-sfqpD7UddbjnGWI1NoCDRZp8BuXsCUg,6544 vllm/lora/layers/replicated_linear.py,sha256=oswJbpelrd1Smiz1AV9zlLRcZb62XW700uporDliAJE,2205 vllm/lora/layers/row_parallel_linear.py,sha256=y6MpIbqGU3eSv50nlzuaSDRZdyUfwUJp66E6Fwct8GM,6187 vllm/lora/layers/utils.py,sha256=PGhIDjvImTX6hte5t-PtuwBV9EGPB1yZeTKCEq_uRC4,3329 @@ -753,12 +767,12 @@ vllm/lora/ops/triton_ops/__pycache__/lora_kernel_metadata.cpython-312.pyc,, vllm/lora/ops/triton_ops/__pycache__/lora_shrink_op.cpython-312.pyc,, vllm/lora/ops/triton_ops/__pycache__/utils.cpython-312.pyc,, vllm/lora/ops/triton_ops/fused_moe_lora_fp8_op.py,sha256=ZPbTYFG2YvplXu5Y5BT_0Q0CmVnk31irzeAgaRnL2vM,32151 -vllm/lora/ops/triton_ops/fused_moe_lora_op.py,sha256=jFWC86-yuhiUUAsYZom_8vpQLBQHodh2qgtHMxfMY38,24369 +vllm/lora/ops/triton_ops/fused_moe_lora_op.py,sha256=x4ONU29wqBfxD8tTdLLoVmn5_ifG0f1ZJG1FjpwCxEM,29829 vllm/lora/ops/triton_ops/kernel_utils.py,sha256=YiE1F7sojPFFF06QsVSeaDN19Fpo4X1MhFSoX4F8EXA,10375 -vllm/lora/ops/triton_ops/lora_expand_op.py,sha256=lqhtVPpv8kGz5bm2FY_ueiM3IBdcA1t0TeLUOXZGGd4,9477 -vllm/lora/ops/triton_ops/lora_kernel_metadata.py,sha256=oTXKTYFBLEYZR60kkro7g1AxiNX2Dre2FxxAnQzN0u0,6889 
-vllm/lora/ops/triton_ops/lora_shrink_op.py,sha256=AEXWzv_099JVS2f4hJEnorwKQNx76pqtDYyx1nwXu-o,8932 -vllm/lora/ops/triton_ops/utils.py,sha256=kUfcxG1ALtR0_9M4LulLS7ah-PAmK0jBmg5OjTmR53Y,10511 +vllm/lora/ops/triton_ops/lora_expand_op.py,sha256=V2_Jnj-Lq5NMzm7opVTTuipWQ86hgTpFpzSA9WroUa4,9530 +vllm/lora/ops/triton_ops/lora_kernel_metadata.py,sha256=dmnu60keid9-9P8FCnnTk0znpjsUb-kyOfCB0XbGiIE,7695 +vllm/lora/ops/triton_ops/lora_shrink_op.py,sha256=AVrVt7uAprP94HpZEfi4hZn8-jG2Z6-fHViaINilh74,9204 +vllm/lora/ops/triton_ops/utils.py,sha256=OqcUyicFRFj2XrTWpOZAuZOD33XP6qiobVgUGe6M54c,10723 vllm/lora/ops/xpu_ops/__init__.py,sha256=7h0eSNAE9MWqAPvjKCCZOaNQSsXkgCxP-rBcJrEfqrI,258 vllm/lora/ops/xpu_ops/__pycache__/__init__.cpython-312.pyc,, vllm/lora/ops/xpu_ops/__pycache__/lora_ops.cpython-312.pyc,, @@ -774,7 +788,7 @@ vllm/lora/punica_wrapper/__pycache__/punica_xpu.cpython-312.pyc,, vllm/lora/punica_wrapper/__pycache__/utils.cpython-312.pyc,, vllm/lora/punica_wrapper/punica_base.py,sha256=gh0ZvV_6bdhnBhk5IbOoP74HB-RbQ94sShDDc2tBbO4,15864 vllm/lora/punica_wrapper/punica_cpu.py,sha256=qJLetoy4aPyQazVqbEq0pOOr848ox_hXvO3AOIH4XR4,10817 -vllm/lora/punica_wrapper/punica_gpu.py,sha256=uklNQbXbs6XPUgQthvSLjufEOvx-w02UE9fjhakUgXk,14605 +vllm/lora/punica_wrapper/punica_gpu.py,sha256=M9RfCIZ2hkB7VZqS3-6ZRXrkka4UrOrhXFiJTZZmLws,15232 vllm/lora/punica_wrapper/punica_selector.py,sha256=-IyHqLvTPs5onKUXacOqS9J03otlfaT8J7yj4iD_YRs,818 vllm/lora/punica_wrapper/punica_xpu.py,sha256=YjsIUsRV_ciLquTZO7GpC-Gwjr7EvoaebJ0QnwF67ys,13633 vllm/lora/punica_wrapper/utils.py,sha256=omm684H29vAOXrC4ny19kmKFRk0_-pIAMTR87CmRMXw,5517 @@ -790,7 +804,7 @@ vllm/model_executor/__pycache__/utils.cpython-312.pyc,, vllm/model_executor/custom_op.py,sha256=gNfYzyjwHnT0GbpaWmeuD-P6PrbeBTfBUG_O3OoijCM,13728 vllm/model_executor/kernels/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/model_executor/kernels/__pycache__/__init__.cpython-312.pyc,, -vllm/model_executor/kernels/linear/__init__.py,sha256=_TTL-2c1LaCjdwzX8aZj6sf3Zg5bKj_IlAenw-NhCM4,12957 +vllm/model_executor/kernels/linear/__init__.py,sha256=y3NUXd2h96j2EfJ1aDuqxwvWFui5RrFi8hKrcS8rQf0,12957 vllm/model_executor/kernels/linear/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/kernels/linear/mixed_precision/MPLinearKernel.py,sha256=dvSBv9nYGrOiSEN3udSv2ElQYpf7vQorIjA61b32D2g,2825 vllm/model_executor/kernels/linear/mixed_precision/__init__.py,sha256=sQyWnhpVrXm5f2SrQwmtap4F1dvUdKCAWAETyhwLgdw,1453 @@ -811,8 +825,8 @@ vllm/model_executor/kernels/linear/mixed_precision/cpu.py,sha256=OdM7Q6LUYwfT4ny vllm/model_executor/kernels/linear/mixed_precision/cutlass.py,sha256=LY3XjsyGlY1y1cGflBYkC40MqsJ9-uuvF6DIi2-aaI8,4631 vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py,sha256=1NT07Ioc9wQN5CoQZ4CKE0yoFOKq8aAGngLIrcjg-dw,6115 vllm/model_executor/kernels/linear/mixed_precision/exllama.py,sha256=9vy2P7jHm_bDNCDpiNkCH8wEg20_pwN4HLAhpHg-Org,6443 -vllm/model_executor/kernels/linear/mixed_precision/machete.py,sha256=gdBdme4JUaZ_YcrKTzJ1JfdoeX_oG9e0SSQa204AvTI,5726 -vllm/model_executor/kernels/linear/mixed_precision/marlin.py,sha256=K9HDpxkAEuopvXCsExA8Gd7o3bwumkte9TzVTRkNUnc,7456 +vllm/model_executor/kernels/linear/mixed_precision/machete.py,sha256=YSXF_YG2KNEaz58sGvTKMCFyqRJUY1r0pko1H97rGes,6019 +vllm/model_executor/kernels/linear/mixed_precision/marlin.py,sha256=8CnwYzgqZF9ilcHbxWHO_f-hpcp6We3_v-qvbNU1TOU,13793 vllm/model_executor/kernels/linear/mixed_precision/xpu.py,sha256=sudCl-MAKmFx_Tx_B3YxjhxPT78KABj65lj_7YSc1zo,3012 
vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py,sha256=hhjrsj8bfhSkrSU0pMGcMq31KzC5JbhFOhgIWhdNRHE,5577 vllm/model_executor/kernels/linear/scaled_mm/__init__.py,sha256=XAedDbqkte2vfh0Z0yw4_20yczzSt3Xud8lnFRpk6Tw,1839 @@ -828,7 +842,7 @@ vllm/model_executor/kernels/linear/scaled_mm/__pycache__/triton.cpython-312.pyc, vllm/model_executor/kernels/linear/scaled_mm/__pycache__/xpu.cpython-312.pyc,, vllm/model_executor/kernels/linear/scaled_mm/aiter.py,sha256=l2XlFMcOWwv8MWzC_O-U8bAgYRz-PnIJEhtpURfnQYU,4253 vllm/model_executor/kernels/linear/scaled_mm/cpu.py,sha256=TK1c_-zADd-McY4tmtlKRAZniNt7RXpfM5RH-b965nw,7980 -vllm/model_executor/kernels/linear/scaled_mm/cutlass.py,sha256=EbTxE3ti7Kysh3LWBu95O_0cHW1clPbOaV1zJ5bE3K0,6589 +vllm/model_executor/kernels/linear/scaled_mm/cutlass.py,sha256=ddc33PnnGAfOJ8Vb8qpd-WwPpVnF5vdwap2SN0FejKw,8027 vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py,sha256=dcs5UdeM-sWoDqkCuDue9O63kQ3RkKeROoXA_-UVGNQ,1783 vllm/model_executor/kernels/linear/scaled_mm/pytorch.py,sha256=4H6pqQN66UbgTMnkwa6ifZzxUb2Lfe5h4ukps5HxYBk,7730 vllm/model_executor/kernels/linear/scaled_mm/rocm.py,sha256=erNElT7iQxk1_4gMLvJqZzZoCkGSGcPrVtdx-3z74AA,3286 @@ -850,31 +864,33 @@ vllm/model_executor/layers/__pycache__/resampler.cpython-312.pyc,, vllm/model_executor/layers/__pycache__/sparse_attn_indexer.cpython-312.pyc,, vllm/model_executor/layers/__pycache__/utils.cpython-312.pyc,, vllm/model_executor/layers/__pycache__/vocab_parallel_embedding.cpython-312.pyc,, -vllm/model_executor/layers/activation.py,sha256=zNAl1vvd-aYcZOULsK0g6e5oD6iCfqSN_pBmgE_WTLU,24690 +vllm/model_executor/layers/activation.py,sha256=5KIB6Quq92koiYif6svQXN99fPPMRg4CvF9hX_KoaEM,24813 vllm/model_executor/layers/attention/__init__.py,sha256=WVCsMgceTb4i77vkNi4z42zaM5DhAZLttnO-gAcmZnY,912 vllm/model_executor/layers/attention/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/attention.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/chunked_local_attention.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/cross_attention.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/encoder_only_attention.cpython-312.pyc,, +vllm/model_executor/layers/attention/__pycache__/extra_cache.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/kv_transfer_utils.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/mla_attention.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/mm_encoder_attention.cpython-312.pyc,, vllm/model_executor/layers/attention/__pycache__/static_sink_attention.cpython-312.pyc,, -vllm/model_executor/layers/attention/attention.py,sha256=DnmYpSuIH1UjammjzOGXQhWFseHvowkCA8nDd4b5_Ww,28445 +vllm/model_executor/layers/attention/attention.py,sha256=s3wmcNTMpfmbw6bqZw7YcTPYGSKdOQIetY96P4DbCzI,30292 vllm/model_executor/layers/attention/chunked_local_attention.py,sha256=Lra8sc8XRfHewPca_suNdn9wYKmivWNnIRdG_pseXe4,4619 vllm/model_executor/layers/attention/cross_attention.py,sha256=jEq6GJk4AgKm0Va52HTI82Oy1uX2s1gY68J7XaBqXd0,8117 vllm/model_executor/layers/attention/encoder_only_attention.py,sha256=7Pa-sQszZrggyFQawK-cfDRNiqd8_valjo3XYV-cRUI,3036 +vllm/model_executor/layers/attention/extra_cache.py,sha256=47q-6XEOyj29m_V0YLuBO7ocJosX3WtHD4qsQyiXLPY,4525 vllm/model_executor/layers/attention/kv_transfer_utils.py,sha256=zzMz0HAky07ebU5kZhnl8IdbBxhp-8lDiiBF65B04_4,2030 -vllm/model_executor/layers/attention/mla_attention.py,sha256=wjAHqWj-EQ8nTjajRZemuTcqC75mcSbEzYv8biDPd7A,120514 
-vllm/model_executor/layers/attention/mm_encoder_attention.py,sha256=cdKkO1cgu-SUY0wv4qng6AMiZ47ztlPsW9KI805gyaI,8964 +vllm/model_executor/layers/attention/mla_attention.py,sha256=otfm87etz4vBu94_vHHfLNX4RPNn--34rI9coY2cmmk,122436 +vllm/model_executor/layers/attention/mm_encoder_attention.py,sha256=P0I_pUQgjp-zvCZ18uT2HsWHsq_UsVq3SvKaM8y9bFc,15052 vllm/model_executor/layers/attention/static_sink_attention.py,sha256=MITS6MM1-fied2C0YANgXFOdwLMzNmJY9W3RU52yca4,8379 vllm/model_executor/layers/attention_layer_base.py,sha256=T8X8sxru-eTivn94eWOVqjEosV6LNaJbpTrR8sOe1QQ,1015 vllm/model_executor/layers/batch_invariant.py,sha256=lJAO7RJc79q5Nl_r26c7M_T6SJzCAI6MZAsXZbzrx_w,33841 vllm/model_executor/layers/conv.py,sha256=UVHGlHFBo56rLzFK7bOE22WkDEB4W3xpMJp4pYtxOjc,8152 vllm/model_executor/layers/fla/__init__.py,sha256=Xbv6P8KG7bLpcOF17zPO5PYYHImweE9aiYYgGyMS9U4,391 vllm/model_executor/layers/fla/__pycache__/__init__.cpython-312.pyc,, -vllm/model_executor/layers/fla/ops/__init__.py,sha256=Tx9ajmbsCCptuJofX-QpCC_Ut2SwZBTT56SUjGXGNs0,642 +vllm/model_executor/layers/fla/ops/__init__.py,sha256=6lsg78KTrKLkZ0Re9n2GmIJmVzqyLzzyWxS4ixI0Rzc,876 vllm/model_executor/layers/fla/ops/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/chunk.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/chunk_delta_h.cpython-312.pyc,, @@ -882,6 +898,7 @@ vllm/model_executor/layers/fla/ops/__pycache__/chunk_o.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/chunk_scaled_dot_kkt.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/cumsum.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/fused_recurrent.cpython-312.pyc,, +vllm/model_executor/layers/fla/ops/__pycache__/fused_sigmoid_gating.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/index.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/kda.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/l2norm.cpython-312.pyc,, @@ -890,21 +907,22 @@ vllm/model_executor/layers/fla/ops/__pycache__/op.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/solve_tril.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/utils.cpython-312.pyc,, vllm/model_executor/layers/fla/ops/__pycache__/wy_fast.cpython-312.pyc,, -vllm/model_executor/layers/fla/ops/chunk.py,sha256=YWK9csLUNXzExLLjGi2QMc-OocBUCzoLIZzdxucgfww,7885 -vllm/model_executor/layers/fla/ops/chunk_delta_h.py,sha256=SGaH1JZJB4Kqa_G2gXxwQSGD0eX2vBOKY6hDgUxw9T0,11944 -vllm/model_executor/layers/fla/ops/chunk_o.py,sha256=zshbCdGrTizDEbFVp9ahHmCEyiPOmxZ9QF9MPm1HOMM,5109 -vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py,sha256=_i7MMGO66MjD1RYQFDqNpa9yGknPyHBOTEDwChc9hFw,4653 +vllm/model_executor/layers/fla/ops/chunk.py,sha256=vM41AcghOkRolUZ-FhgKAlbXsonZMXvedUmVngACn9I,7870 +vllm/model_executor/layers/fla/ops/chunk_delta_h.py,sha256=gaRUkiXe67fo36q8z52wPB16tpNWp49cDSZR-X5_0cU,11940 +vllm/model_executor/layers/fla/ops/chunk_o.py,sha256=80kaFqamRCEq5Sl4raQ-J68e1Y11WP_WdBGjfSt6ZSI,5101 +vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py,sha256=bdoiRdFOvWHZPNT6F-CzHSkcR_JDw63jajArkF2OVIg,4645 vllm/model_executor/layers/fla/ops/cumsum.py,sha256=qhS_Lny-roiECMEDvtWBE2RpiFYDKjqWxwQONRdL2gA,8224 -vllm/model_executor/layers/fla/ops/fused_recurrent.py,sha256=W9qGwriB3oaeB1BLmMSmfs_Oy3NLD_BJg9p2wE668YA,13510 -vllm/model_executor/layers/fla/ops/index.py,sha256=cg2IV5Tn7cvDzPw8fTdI96sIYFy87tk960b0nPJFXSE,1221 
-vllm/model_executor/layers/fla/ops/kda.py,sha256=TWHF34L4V1h-Ff0TpASmp7EmnvrpuIFarENkgtduLR4,38297 +vllm/model_executor/layers/fla/ops/fused_recurrent.py,sha256=Oio8UkWsuhJ68ybws_KHs1J3ao6x4fbkDCSnZ_rjwu8,21309 +vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py,sha256=AAq4mWr5eI_biEOmo7kYM-ehTIrMDh6gc6U2Mw9ky28,9166 +vllm/model_executor/layers/fla/ops/index.py,sha256=H6Nrd4M8Vs3B7Brz6jgiXYrp4frsWyGM9Ir1SC3fFaA,1185 +vllm/model_executor/layers/fla/ops/kda.py,sha256=mFUW0Pg9S-efqJQe5hwAMx8bbBxjduKuExMm2TdN1eY,38265 vllm/model_executor/layers/fla/ops/l2norm.py,sha256=HfUKWKl1EvrUKcowbDu4E5QPQCgZbPFumK43eli0Q6Q,3978 -vllm/model_executor/layers/fla/ops/layernorm_guard.py,sha256=ToZ3wAZ66DJ9RZN5FC0i4kXOZgEOkc3nr1gh-mZEhz8,12272 +vllm/model_executor/layers/fla/ops/layernorm_guard.py,sha256=pdEFfJWMfrjKALIqqCejzybUCUaYYLlNZwywdf98d_E,13019 vllm/model_executor/layers/fla/ops/op.py,sha256=atkZKsyTmBUTa68GnRk0VfEVfKAbmRoFLYtQlf8_JBM,1712 vllm/model_executor/layers/fla/ops/solve_tril.py,sha256=ARfbUgKOHld2Zbq3VlgjP07bjczFSCCQc0MRLmABcV8,19385 vllm/model_executor/layers/fla/ops/utils.py,sha256=8RRhqjpzSTv0DUsZpUor_qzrBrX6FkFuOOpha_HpmGc,6374 -vllm/model_executor/layers/fla/ops/wy_fast.py,sha256=TqyZCDUrAUo0ZFdtHPRY4ntUxavxsYsAGGj0m23oyiI,4372 -vllm/model_executor/layers/fused_moe/__init__.py,sha256=CLy7BL0W8uOuP1IuHlWwim3olUZfQmcX2VdF-eRn2Pc,4110 +vllm/model_executor/layers/fla/ops/wy_fast.py,sha256=9xUY538LvcvXnqKm2jt16wOUNBszvc41owl563RfCDQ,4368 +vllm/model_executor/layers/fused_moe/__init__.py,sha256=7RbCAApAQ4x32dvPzaIc91FP7NlqbUbs9x7cnb51_N0,4202 vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/activation.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/all2all_utils.cpython-312.pyc,, @@ -932,8 +950,6 @@ vllm/model_executor/layers/fused_moe/__pycache__/modular_kernel.cpython-312.pyc, vllm/model_executor/layers/fused_moe/__pycache__/moe_align_block_size.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/moe_permute_unpermute.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/mori_prepare_finalize.cpython-312.pyc,, -vllm/model_executor/layers/fused_moe/__pycache__/pplx_prepare_finalize.cpython-312.pyc,, -vllm/model_executor/layers/fused_moe/__pycache__/prepare_finalize.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/rocm_aiter_fused_moe.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/routed_experts_capturer.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/shared_fused_moe.cpython-312.pyc,, @@ -945,10 +961,10 @@ vllm/model_executor/layers/fused_moe/__pycache__/unquantized_fused_moe_method.cp vllm/model_executor/layers/fused_moe/__pycache__/utils.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/xpu_fused_moe.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/__pycache__/zero_expert_fused_moe.cpython-312.pyc,, -vllm/model_executor/layers/fused_moe/activation.py,sha256=OiQlBhQsJbpZefi-ZOFAg2fF_qdQv5fZ_7EyX9ce-SI,4919 -vllm/model_executor/layers/fused_moe/all2all_utils.py,sha256=nXmupKl6wIitp3bwdAJd7w9H_URRVdncBJs_7R_2KYA,9574 -vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py,sha256=V4XdBBljjms2ToISaHvJY0-m33nl_hUTvNLmSWG3QaY,16068 -vllm/model_executor/layers/fused_moe/config.py,sha256=GTQD51HOaDCzWNzQADHsnKQzd5d1cT3tRBqhimZ7cmw,41671 +vllm/model_executor/layers/fused_moe/activation.py,sha256=U9QeFMk9XFPMdkhm91ncYvmobf9QsGOOk-vYFhI5HSQ,4733 
+vllm/model_executor/layers/fused_moe/all2all_utils.py,sha256=Rb5piZhuzyXHGzGCrHjxePI2t0NU3o3qohqGVRD7HEc,8105 +vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py,sha256=gU7jzYR2hfgoRmjlTugnlNw5cE6jURjKDRAipWxSUWQ,16059 +vllm/model_executor/layers/fused_moe/config.py,sha256=Guwe4HSIbLxcwwY0XA5YtHWLDFdWbDSUB6FtNZgE1D4,41748 "vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json",sha256=iNGsE2ZeVnQEnN4A8UJ9Jv0d3hbRF2MJ9oBgjup5Szk,2737 "vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json",sha256=hH5rRN9Wtyv35azxMzyUMHWtiKgOHev5tNjIG8j6dsE,2751 "vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json",sha256=qPumkNxaHMvVBnEjPe_Xiuz9ICb6Hqc-9I1DAR8s3gA,4130 @@ -1016,6 +1032,8 @@ vllm/model_executor/layers/fused_moe/config.py,sha256=GTQD51HOaDCzWNzQADHsnKQzd5 "vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json",sha256=iplHjPj366nNBu44ZHlCRe7aMaV2T4QthXINmkWi218,3269 "vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json",sha256=l59WyUnnymmJQykoJzEvvkzwlhIlmyEeVtvCFRzXGRM,3279 "vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json",sha256=JtcHRlPz8xQEAqJ9EWI63oYvdmjQFG6VTHqtt85VOSA,3221 +"vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json",sha256=LJgf-vNZtRHawIm17vq6eN6cuzxw256i_MJ-VZGHIt0,3276 +"vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json",sha256=KoFfThKn3gvLmzE4NMKRe9-g7ZA1zDMJJy30ABJSfwE,3269 "vllm/model_executor/layers/fused_moe/configs/E=129,N=704,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Workstation_Edition,dtype=fp8_w8a8.json",sha256=vaiXepgaaixIxU3_vVGcYsA18XAiC8V13phIN5niRrc,3250 "vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json",sha256=f3iM3xm8hGUirJ4ilAIPO6Pe9bs4sm3qaRKMswN9SKE,4731 "vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8.json",sha256=Pux4G7sjfPL22uOJU6t35TXe-VU3OaoPA97TDj9_wZQ,3251 @@ -1250,39 +1268,49 @@ vllm/model_executor/layers/fused_moe/config.py,sha256=GTQD51HOaDCzWNzQADHsnKQzd5 "vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json",sha256=WQLKugnKzlQ0avf1N-41lRHtG6wJ56DfVPv_nip6NBc,3273 vllm/model_executor/layers/fused_moe/configs/README,sha256=W2yIZkP9O8GGlg97We9BJfTtWUtPbuz5ZH3esrrjBX0,572 vllm/model_executor/layers/fused_moe/cpu_fused_moe.py,sha256=GDuJQLv1AzB6X6dFZBOXe24Vo2xN5x3nJgmyKqhpa8w,16076 -vllm/model_executor/layers/fused_moe/cutlass_moe.py,sha256=1NNO0WoatYpynhVerZdFGr67BnyRGN0fXu66YAL4mU4,40797 -vllm/model_executor/layers/fused_moe/deep_gemm_moe.py,sha256=TR0dsJm3n_XE5DmZVGjDhUrOYvnttHY7FFU24SWJnXI,10621 +vllm/model_executor/layers/fused_moe/cutlass_moe.py,sha256=I5-_euo_sy2FcOzyICtjdQrUwvtMz7mzbspdKA0jn0Y,40767 +vllm/model_executor/layers/fused_moe/deep_gemm_moe.py,sha256=3WPsba24gKi7QKL6Vq6r37Ij1aYYJZL1F2O3sVoypuI,10612 vllm/model_executor/layers/fused_moe/deep_gemm_utils.py,sha256=OZ5vnMME0tDx4xSaWTroxkaO3akmLyznT_VWs1IZ42w,13264 -vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py,sha256=ajZ_R-CdU9AsKTDRkK6fGY0yldMCCNsYYXPcvt8uSF4,15657 -vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py,sha256=ewj4N7Bl4gdr2M82CtlaeD7ADHtr1cffVl6DRFIEXI4,15598 
-vllm/model_executor/layers/fused_moe/fallback.py,sha256=pU1jxeelW47TOVKlrM6R1t6CTn-MMfXmZ1gY8-btqoQ,6174 -vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py,sha256=AdVfb0Q-ExuEjYgiouFL_rF54yOGLm_4tRjCV6uX5zc,7067 -vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py,sha256=l7knaZMEdAcDR8-RGcravrNeELR0eeogshwi5FxCKx8,11981 -vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py,sha256=jkBpj1WT8KzUurnhSnhz7w8DzJRbXq1pcKbo8ZcI8rc,15070 -vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py,sha256=0Q-wv4UzfxkmS48M5IiSJTa1s0sy51z6UGwfQJv_0RE,15335 -vllm/model_executor/layers/fused_moe/fused_batched_moe.py,sha256=SCqKf05Ih-bRvfj-3HnxSyQNZg7B1mTd3BVGYMU1Jgo,35954 -vllm/model_executor/layers/fused_moe/fused_marlin_moe.py,sha256=q_3FZBo530Isl3Lfw-qDICQZ08QdqLGEBcBxZug4SVk,29612 -vllm/model_executor/layers/fused_moe/fused_moe.py,sha256=lzrRaouy1432Ksd_h9iuDJXLiVEhMhh-ctjyaMu4vrk,83276 -vllm/model_executor/layers/fused_moe/fused_moe_method_base.py,sha256=fBF60yY-oztJDLqNxGgeUEN7cWv3xvqUYM_vQTnbxTo,4598 -vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py,sha256=bF6Wplb4ep1c0kmlbYwWcNhKpWOpXqCyxy8ILiENJ9E,3409 -vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py,sha256=tPuxX4nzZUQPrHlremrDyz_U2aNFIjZdNAjGUEYWsRg,23335 -vllm/model_executor/layers/fused_moe/layer.py,sha256=-ipIIEZZ0Np4M_m8IC7VHHf0voqxJUgADYWA5EZ4QKY,62185 -vllm/model_executor/layers/fused_moe/modular_kernel.py,sha256=THV7B6GCwJIB7wv2N4RARUbthdb_PABX02P8UT-KdWI,51871 +vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py,sha256=wsLE3TDA9d_mqm-b3586sqAE3XCr3y4ZK7ZKHPHOcSI,15764 +vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py,sha256=d6A-1FC3IiDbQmJUmuEsFfjbtjIyCf-Rp6uKWhz-MZU,15607 +vllm/model_executor/layers/fused_moe/experts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +vllm/model_executor/layers/fused_moe/experts/__pycache__/__init__.cpython-312.pyc,, +vllm/model_executor/layers/fused_moe/experts/__pycache__/trtllm_fp8_moe.cpython-312.pyc,, +vllm/model_executor/layers/fused_moe/experts/__pycache__/trtllm_nvfp4_moe.cpython-312.pyc,, +vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py,sha256=tjZVW89ghNWv_6JfFbwqyMaNSnTNYPyvNLGyrlx3fEY,12948 +vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py,sha256=O-DrUoyGWgQ65EB98yUpOGntAxLhlbgaKjOX5wEb6CI,11862 +vllm/model_executor/layers/fused_moe/fallback.py,sha256=b7uGJJs_NqZkjZLGH1AXRzv4KnbbPtAFB0GQJG-IYGs,6120 +vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py,sha256=fswvle7q9qP8Op8jVbPsjthuu0lGqs6dFG_48WlZ7qA,7144 +vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py,sha256=O-XyogSPLBaqZFZxdANa_81EYVC8CGT1_MRTxPqAqfs,11972 +vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py,sha256=hUqPWz2wxXBTvkvnZb6E4FgtdTMo2EXRg06I0w3Cpbg,15061 +vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py,sha256=yplWuWVHhtMClU7DZaUk3IN1Ldr_TaQ9WwSkTdF5PWo,4574 +vllm/model_executor/layers/fused_moe/fused_batched_moe.py,sha256=FqLxGcRCoFL4THzP67P-IXiAlxN05D_EpuvPdjzPIL8,35952 +vllm/model_executor/layers/fused_moe/fused_marlin_moe.py,sha256=By6-b6VX0FC4i6O1ttAIHvJoM4WTS1DshhDkC5BbqlU,29603 +vllm/model_executor/layers/fused_moe/fused_moe.py,sha256=8_cvvSE5Hk6tyksBPFnlnJQboIgMF0Rqq-iY_jY9ang,85591 +vllm/model_executor/layers/fused_moe/fused_moe_method_base.py,sha256=3dPVwK0OrOrPf6ZzSId4X_suscOH_rCZ3s20mEOq63Q,4453 
+vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py,sha256=b395WXOpgOJoL554tkS13h5GKADyT9Gaw2Puvpei7MM,3430 +vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py,sha256=e6x5_qaZdwc7A_MyYwfzVF0jvOyVcJtGE94Ir4P6pV4,27690 +vllm/model_executor/layers/fused_moe/layer.py,sha256=99JSTy05tAKbP7umntGGpOwjBOnKTtYi-MgLLOPWvno,66842 +vllm/model_executor/layers/fused_moe/modular_kernel.py,sha256=bYIDbRbeMIrBuPbe9G_ZA9dVghbtGGgVdaFKvgOAglc,63407 vllm/model_executor/layers/fused_moe/moe_align_block_size.py,sha256=w_f8IIeDbwFgoyq5nOsdj26Hx5NnnaNN1nLiJct_oxs,8339 vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py,sha256=ztL9u2t25hNBKFgD0Qj1buDbVmWlcIBgF9XjtV02Xok,4992 -vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py,sha256=Rm85qW912Y_e7vFHCeFE4lrnigIYyp7XDfVRmhJxZf0,4121 +vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py,sha256=6B-0hMQjOGSGZIhsbfTcQMGv08o54jdhgHZbsT2YGPY,4128 vllm/model_executor/layers/fused_moe/oracle/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107 vllm/model_executor/layers/fused_moe/oracle/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/oracle/__pycache__/fp8.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/oracle/__pycache__/nvfp4.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/oracle/__pycache__/unquantized.cpython-312.pyc,, -vllm/model_executor/layers/fused_moe/oracle/fp8.py,sha256=QyPaQzy82JVVRTsL6X2uFtqF3FUnRLO1RrvdZGy-oLY,22226 -vllm/model_executor/layers/fused_moe/oracle/nvfp4.py,sha256=YM-prlv66hDjT3rGfSnjahPt2PcWOSaxjr_R3fs92sA,15875 -vllm/model_executor/layers/fused_moe/oracle/unquantized.py,sha256=jie64ySw2lz7Sb2ZGLqfKqWuq0G1TO5NFiFIRnsL00s,9533 -vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py,sha256=PND1lFLLUKz4KygyrX_1HefXzocVVeUAd07cobUXCME,12218 -vllm/model_executor/layers/fused_moe/prepare_finalize.py,sha256=eOvUSMWpdaiJuRTqPv98rHKNxI1Wa4_QIXDpi_E3FBk,7337 -vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py,sha256=EuP6YL2gFZEKLylrVHtBatmrzLHe_Tpzz1RplYpolv0,14004 -vllm/model_executor/layers/fused_moe/routed_experts_capturer.py,sha256=z7ufSRlsWlFAl3ULl1R2wLD650ouQ0MP4LinCbfego4,11568 +vllm/model_executor/layers/fused_moe/oracle/fp8.py,sha256=5PMXJ5zNFAqmGgPteG2hT95EBT86PdCzXM2Yi-6kL3M,20299 +vllm/model_executor/layers/fused_moe/oracle/nvfp4.py,sha256=dI6Okdrp6V6rDhMEYOudxU3sOmcXMdwtcFOsjuprNTM,14537 +vllm/model_executor/layers/fused_moe/oracle/unquantized.py,sha256=DS14JoXWTLDnQrkorcQZ8vhXY6EI5os0GP_osFFVT7Q,8916 +vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py,sha256=iwuOYZXkbmX0D3Udt6LpGTFlyo_mlK1jEyov1cxvqDk,822 +vllm/model_executor/layers/fused_moe/prepare_finalize/__pycache__/__init__.cpython-312.pyc,, +vllm/model_executor/layers/fused_moe/prepare_finalize/__pycache__/naive_dp_ep.cpython-312.pyc,, +vllm/model_executor/layers/fused_moe/prepare_finalize/__pycache__/no_dp_ep.cpython-312.pyc,, +vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py,sha256=pqx-xzTaWh_v4wpg7oQKOZB0OknksFQi3yuju7ILBiI,8276 +vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py,sha256=4QjzExh0Wtb15JmAxOchreKDo1pdRMaS1XqgswXyKTA,4597 +vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py,sha256=9GSW36FKmdfkBFKjIGb14Sj8SiiJg0vvV56tQcILy3o,13995 +vllm/model_executor/layers/fused_moe/routed_experts_capturer.py,sha256=3b-7pwPTN_rvMase5plWxYxQnV5I6w0TXh3v0yUkHM0,11634 vllm/model_executor/layers/fused_moe/router/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107 
vllm/model_executor/layers/fused_moe/router/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/router/__pycache__/base_router.cpython-312.pyc,, @@ -1290,37 +1318,39 @@ vllm/model_executor/layers/fused_moe/router/__pycache__/custom_routing_router.cp vllm/model_executor/layers/fused_moe/router/__pycache__/fused_moe_router.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/router/__pycache__/fused_topk_bias_router.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/router/__pycache__/fused_topk_router.cpython-312.pyc,, +vllm/model_executor/layers/fused_moe/router/__pycache__/gate_linear.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/router/__pycache__/grouped_topk_router.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/router/__pycache__/router_factory.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/router/__pycache__/routing_simulator_router.cpython-312.pyc,, -vllm/model_executor/layers/fused_moe/router/base_router.py,sha256=cLvlX4sFRnt9nC44GLThNoLIePmTcKUhFm3xS7XbE30,9601 +vllm/model_executor/layers/fused_moe/router/base_router.py,sha256=BMsX5xlhNYqY0xTu2cS4S1WZvX6TfrWe2bBbhx1n1Wo,9652 vllm/model_executor/layers/fused_moe/router/custom_routing_router.py,sha256=vDfX3qXVCb5-AXK4pKo-EymWoBxWODStPivZVPDpWl0,2122 vllm/model_executor/layers/fused_moe/router/fused_moe_router.py,sha256=DlbIBHHQs-kqisruGPe7vKzN_UeTu4ogcC-K0t9WNro,1261 -vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py,sha256=anUL0jqvRd24aTjo_aRwJDNVcdNjGDTZOFrG4ZwypC0,6144 -vllm/model_executor/layers/fused_moe/router/fused_topk_router.py,sha256=F60qRZ3sr1J-OCRopf6PTWK5QqGWbaPzPZGuy6zqi-c,4941 -vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py,sha256=hKXp091Heu6Ey0sjLXqSzd29gCwrcWjNgPBWRkGznWw,12916 -vllm/model_executor/layers/fused_moe/router/router_factory.py,sha256=Vegx6kR16u9ADbfCaN8tbL1rRpFs-zPo-WpjxrIsm6A,6337 +vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py,sha256=2VU8IAvMBEblib-bX6KgJyhzY-3cdR-hJyp9z63jAcU,6237 +vllm/model_executor/layers/fused_moe/router/fused_topk_router.py,sha256=uBDhmVUzKbgz3xKNu1s0deV0cxsx0ArLLLwIK_rKSaY,5061 +vllm/model_executor/layers/fused_moe/router/gate_linear.py,sha256=qVUylcrqPZIPO92jJy6kj7rynnCIhIadRk5iVNFjsYE,4255 +vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py,sha256=SvG7t755eUNPYlE4sJymxNLgvrh_Q5gL6Fw46OlOw0E,10288 +vllm/model_executor/layers/fused_moe/router/router_factory.py,sha256=UZPMhGivUlIj95GAMvnC8Rc6EPW4pw6FqlMk50HozjM,6337 vllm/model_executor/layers/fused_moe/router/routing_simulator_router.py,sha256=tD040EnxUfoa8wKjjJj2KqcOhgIFYxXC31r6Hfu6qhk,12060 vllm/model_executor/layers/fused_moe/runner/__init__.py,sha256=48Vcuw16i9R99vM3iDVkRkX107yJtFeTtdWgLQE-XFg,107 vllm/model_executor/layers/fused_moe/runner/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/runner/__pycache__/default_moe_runner.cpython-312.pyc,, vllm/model_executor/layers/fused_moe/runner/__pycache__/moe_runner.cpython-312.pyc,, -vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py,sha256=9ve4iCY0qAJUChLCGIqWyMuYsRsioL7MarCLJCWFyIs,30270 +vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py,sha256=qOK4FIcc-hvmrjeDQAgWxPAZ6J8-3j8Ew39TdPMAt1c,30096 vllm/model_executor/layers/fused_moe/runner/moe_runner.py,sha256=nQnuvW9D9DsMvA9Y_wWVVMvxozQD6UQg_WdsjFVOaZY,992 vllm/model_executor/layers/fused_moe/shared_fused_moe.py,sha256=xrkplE36sFIWFkhEiCrb06UmYJTuyNtx2HWjqZPjsfA,2161 
-vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py,sha256=sLKn82zPW7Yij8Bcrt4sXW78SyuwIwDELfiqsXaTxsw,5801 -vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py,sha256=O5wADe6kTwJkHXoMKypPcxETSPpFPJfDankwEZgClQk,2684 -vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py,sha256=bADax16fYxG-wKPYuW9idzwIlpuVO9MTk7OeTczSMPI,2857 -vllm/model_executor/layers/fused_moe/trtllm_moe.py,sha256=i1J7aU7NMzQ-rxkgatIT-7sMGtwhqQ20D1Wr7uQQ3F4,6338 -vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py,sha256=yU1xzCy2LeXuAe3nMXEAtXnKLO0D1fgt66_yhdxjl4M,14792 +vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py,sha256=XC84o7emdFIyVMP8YfndvtavqfVPkDzauLGjheEkG4I,5874 +vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py,sha256=qLDd3tjuEfcdSZHU3apLMGQ5ioFwpa8Zb1oQ0JC4Iw0,2657 +vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py,sha256=4RKgs0D0xjTY2E2JFDfTLyRLdXZ17rAI5Usn_B4Nkqg,2830 +vllm/model_executor/layers/fused_moe/trtllm_moe.py,sha256=QcyXsi8Uufh-Se4R4etoVFjc1ZPOi07pvOq1B79Ji8c,6329 +vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py,sha256=gsbKewhv312LLpk03lp6-MKpPuv1RH04zN7rbIbO3-M,15008 vllm/model_executor/layers/fused_moe/utils.py,sha256=HmnOO-olejNBBTMf3XVtYR5mT3tc1gshzLeHjVMZ2Jk,11712 -vllm/model_executor/layers/fused_moe/xpu_fused_moe.py,sha256=fvl7rUCDeuLqtANnPkaMQCysc7CWmZKI_D4D7_RmaY8,4676 +vllm/model_executor/layers/fused_moe/xpu_fused_moe.py,sha256=6CD_APqC4SIK4qXBCTKEiWC1gxNTOYsjiQB5PsL0EIA,4667 vllm/model_executor/layers/fused_moe/zero_expert_fused_moe.py,sha256=4LiP1U9X2ESu_SOVdZ1SHrHZOe7XzdqBt2Ee0RDQPgM,7676 vllm/model_executor/layers/kda.py,sha256=9gozc7-nXg7ubhTPk_CzE0c1oWVNGZcxlP6SaWepRyo,15497 -vllm/model_executor/layers/layernorm.py,sha256=bsZuPCvDsDZo1-KnSawYwKoIGidSdMUB6j2yPu7XpU8,25348 +vllm/model_executor/layers/layernorm.py,sha256=k3Q_kdNGEO-m6hxm83Oog2SvPCju8BVogJLUFH9FEJs,25645 vllm/model_executor/layers/lightning_attn.py,sha256=yknJVad9ilzRGFdFhvjkoe_xEiffNp_KopUsd32Qzf8,20864 -vllm/model_executor/layers/linear.py,sha256=NzlAcv23OLV5PnN5IjF4ey7CtJkvhFt3IaFDcUUcSeY,62213 -vllm/model_executor/layers/logits_processor.py,sha256=UyFpVbQLr9Y7JxNxYtBOb6Tv9mIIB6RGhJB93Wf1dSE,6208 +vllm/model_executor/layers/linear.py,sha256=kr2f0QFQEtHYOIeB8KCJhCkVKomdQiAk3wCamy4Z4Wk,66525 +vllm/model_executor/layers/logits_processor.py,sha256=1H2XdBVL__PwzlQXlZ4lNPOobQNng222HMBZioXNeOE,6396 vllm/model_executor/layers/mamba/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/model_executor/layers/mamba/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/mamba/__pycache__/abstract.cpython-312.pyc,, @@ -1330,10 +1360,10 @@ vllm/model_executor/layers/mamba/__pycache__/mamba_mixer2.cpython-312.pyc,, vllm/model_executor/layers/mamba/__pycache__/mamba_utils.cpython-312.pyc,, vllm/model_executor/layers/mamba/__pycache__/short_conv.cpython-312.pyc,, vllm/model_executor/layers/mamba/abstract.py,sha256=8HIHxbE1th8PUbL2rIfhM81n9WhanyUwvPezabHuxsc,2213 -vllm/model_executor/layers/mamba/linear_attn.py,sha256=J00weDmLflZV1jshtA02XasPMHTrIYl5WYas5txSS4U,15254 -vllm/model_executor/layers/mamba/mamba_mixer.py,sha256=VNWCtLeI_E2VdgW2TuqEkhpKPYUu0cOvYY4bVMP2UwQ,19660 +vllm/model_executor/layers/mamba/linear_attn.py,sha256=Qm9bDDzN2o5JrC_SlEsZaHkty58oDSea9fiv3MJVXXM,16562 +vllm/model_executor/layers/mamba/mamba_mixer.py,sha256=pKZFo82YqqMD0pR0k65CC-EcRcJI0DxnV78rvtXMh4A,19902 vllm/model_executor/layers/mamba/mamba_mixer2.py,sha256=Fdr4hKModFKAP1WDTSnu-exWcqqTUuiZxjqj2YorAhA,38444 
-vllm/model_executor/layers/mamba/mamba_utils.py,sha256=3w7V0LrNiZHBkFrZG-OBMYbQd7ED56K7jv3-rzPB6fQ,10372 +vllm/model_executor/layers/mamba/mamba_utils.py,sha256=NQ7XbYBIv2_VxFyZ_trUxaI3zYXSr-KszHkwn7szGzw,10326 vllm/model_executor/layers/mamba/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/model_executor/layers/mamba/ops/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/mamba/ops/__pycache__/causal_conv1d.cpython-312.pyc,, @@ -1344,16 +1374,16 @@ vllm/model_executor/layers/mamba/ops/__pycache__/ssd_chunk_scan.cpython-312.pyc, vllm/model_executor/layers/mamba/ops/__pycache__/ssd_chunk_state.cpython-312.pyc,, vllm/model_executor/layers/mamba/ops/__pycache__/ssd_combined.cpython-312.pyc,, vllm/model_executor/layers/mamba/ops/__pycache__/ssd_state_passing.cpython-312.pyc,, -vllm/model_executor/layers/mamba/ops/causal_conv1d.py,sha256=g4ahWb0tzWIbXky0huQtfCJYA5Hn6SJ4NazsQ6-HZhE,49672 +vllm/model_executor/layers/mamba/ops/causal_conv1d.py,sha256=bi_Saq1xSfzL5cIzq7ABpxNyFCM7JEw7nWCm1a92pc8,49670 vllm/model_executor/layers/mamba/ops/layernorm_gated.py,sha256=KA5JRI8jEStDEkN6U6dqaOBQSCWugxUelAhZISZyf7A,5228 -vllm/model_executor/layers/mamba/ops/mamba_ssm.py,sha256=KRP4YhI9MQCFi_HpDxjkkR-fTxB3dqyjHFVVlhERVv0,20010 +vllm/model_executor/layers/mamba/ops/mamba_ssm.py,sha256=s5mUlU7lK5hoFwzNeViBpra_GtQrMWklAqb7EuaU16s,20118 vllm/model_executor/layers/mamba/ops/ssd_bmm.py,sha256=mlQ258Jsexnb5vHGh09YW0l-rF3SZM4H1R29o3hb2-s,6842 vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py,sha256=Yow2h10PSWWKvrwT3TRz7UZTCq--99j7MZToXN59aCc,15550 vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py,sha256=6-0Dl2-frs5av7ra-H0xHWPupVaSo0Vex4d5vZVq41Q,24226 vllm/model_executor/layers/mamba/ops/ssd_combined.py,sha256=oKPphPh0wjU3XT3k5U2t9cHb4jDLuYzIQuWK9w0sOcs,7353 vllm/model_executor/layers/mamba/ops/ssd_state_passing.py,sha256=OWIEJ6WjdYHbX_7HRzhPt50oejUjvzhBTST6bQPhjas,5275 vllm/model_executor/layers/mamba/short_conv.py,sha256=peKNJ66UhX78ZWjfV2kLedNlcAOPJn58vQ1oUiCbJ1Y,8257 -vllm/model_executor/layers/mla.py,sha256=SNpV2Bi1BnP9iOcm0jnN1PL9V2i8LTcJ4kepi6Zu6yI,6440 +vllm/model_executor/layers/mla.py,sha256=3g8P4Zn2t_lWHbmBVws_3NvIsMNWu_agE6lIQ5s4HV8,6516 vllm/model_executor/layers/pooler/__init__.py,sha256=AvMJtY1sQcP_n7SWB9Zp0F7dBfgNxKOpoUfjuZfP1_k,176 vllm/model_executor/layers/pooler/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/pooler/__pycache__/abstract.cpython-312.pyc,, @@ -1380,7 +1410,7 @@ vllm/model_executor/layers/pooler/tokwise/__pycache__/poolers.cpython-312.pyc,, vllm/model_executor/layers/pooler/tokwise/heads.py,sha256=CCwLTzPIZ_s344JSCPoZWNgX6iO-vWl03uB7AbBFTLo,4149 vllm/model_executor/layers/pooler/tokwise/methods.py,sha256=ewPD-pazZ_mxULxR5ZW1m-3tfHoHz8SKlu_OKlCEO7c,4240 vllm/model_executor/layers/pooler/tokwise/poolers.py,sha256=BU6GwjUveV7Mx6aSuVdd9HBDwFY8MEMHJqreCC9QhpE,4096 -vllm/model_executor/layers/quantization/__init__.py,sha256=p-p8jPHEkQyu7nkhbkESZEo8D45OhooF8OLMUYjzdLE,5324 +vllm/model_executor/layers/quantization/__init__.py,sha256=2Vn15_EIGAHFE4BLagygwzfSzluLbpcTGZCEM4NJgB8,5550 vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/quantization/__pycache__/awq.cpython-312.pyc,, vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-312.pyc,, @@ -1406,8 +1436,9 @@ vllm/model_executor/layers/quantization/__pycache__/ptpc_fp8.cpython-312.pyc,, vllm/model_executor/layers/quantization/__pycache__/qutlass_utils.cpython-312.pyc,, 
vllm/model_executor/layers/quantization/__pycache__/schema.cpython-312.pyc,, vllm/model_executor/layers/quantization/__pycache__/torchao.cpython-312.pyc,, +vllm/model_executor/layers/quantization/__pycache__/w8a16.cpython-312.pyc,, vllm/model_executor/layers/quantization/awq.py,sha256=EqFKOHtFQepFitFjmAdMMnrjeeRCwlQvjcL-k2rZJ1o,10208 -vllm/model_executor/layers/quantization/awq_marlin.py,sha256=FnOD0LI11ztLDgbVSf64BLEPOb8y4UU1la3UWeiDd5k,36870 +vllm/model_executor/layers/quantization/awq_marlin.py,sha256=uHW2Q0Pa6MV7S0j-I08opJjlq2napcjJeiVj-c-a8G0,38811 vllm/model_executor/layers/quantization/awq_triton.py,sha256=fRH9rX5jDSqQtdmcrXwxtPtQEiKxTUdPkAJmtvW_zng,11620 vllm/model_executor/layers/quantization/base_config.py,sha256=ZY_lnCDF33boaqF9u3DPV3D9-KWJ37MRr8amKC9aVRo,6563 vllm/model_executor/layers/quantization/bitsandbytes.py,sha256=JiHCrscq3pKS_FFaETQ9rpzEmripvXSxK_AbIKi04IE,20250 @@ -1417,9 +1448,9 @@ vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compresse vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors_moe.cpython-312.pyc,, vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/triton_scaled_mm.cpython-312.pyc,, vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/utils.cpython-312.pyc,, -vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py,sha256=sK8R-FeQtwOQ6TIcaBz87nMrC8D6vafmR4uW-wBahQo,43488 -vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=id0owM9G5wvMfL7WIelnELsPY7lvVP6WZcnFgECzfxA,125810 -vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py,sha256=9ulnB6i8TA_ET4N84o56K8j-vxYHpcjL_csVUZCo0kI,1357 +vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py,sha256=hlDlYVo7JQcuu4C2lCdtP03QrocshL1zJfPYttNgiGY,43131 +vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py,sha256=JjZpIZ0RcoJLz73U2gLodCg2yl-R4NSxayR-kM0EOi0,185236 +vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py,sha256=9MgIcmvg_tbj7cvK1V2b2WNdsou4Etb7xStFdvnZ9X8,1298 vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_24.cpython-312.pyc,, vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_scheme.cpython-312.pyc,, @@ -1441,7 +1472,7 @@ vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_te vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py,sha256=cwo918k-M8UUqf0o7sN5IttmgnA--mWudwLplZbwH8M,5124 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py,sha256=IX-MRlV3KEnSaReMS5oP4cnzYD2qMOjziPh2A1v3plg,5365 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py,sha256=3wzNEzEaS75TxKnymt3OkXuD-X4G7RengWzAFlPGs5g,7650 -vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=5_7c-lYG0XRXr-6dmB4w7w4Vt8QJlx2O98XpDt08cBY,7628 +vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py,sha256=hEh9SeMz0AXnF_9SnVB51SJwOgLZgT2FJI4SNIoSk5g,5022 vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py,sha256=pUnlwIWELDxSfKyOUnIaF1LKydCNhpW7cdEzTk4YNGU,7875 
vllm/model_executor/layers/quantization/compressed_tensors/transform/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/model_executor/layers/quantization/compressed_tensors/transform/__pycache__/__init__.cpython-312.pyc,, @@ -1460,33 +1491,33 @@ vllm/model_executor/layers/quantization/compressed_tensors/utils.py,sha256=MGBWA vllm/model_executor/layers/quantization/cpu_wna16.py,sha256=CCh4moMa9rDrpMxDViuD9bMh-AObWLugjiLX8Wagovc,9713 vllm/model_executor/layers/quantization/experts_int8.py,sha256=uEOc_6HbAxRHk4kGVrVXKXY-NDmsZhKPzOz8f9O7E-8,6891 vllm/model_executor/layers/quantization/fbgemm_fp8.py,sha256=4WYLD1DECYNdK1vlV9JNm5xQz0tA-QW5PZLuwiqqT1M,6537 -vllm/model_executor/layers/quantization/fp8.py,sha256=ftGOMgzAt8hGmBBHHPvL1iZ0dV36zB7N_36EISWkkT4,51874 +vllm/model_executor/layers/quantization/fp8.py,sha256=BbDGeeAJcdzPT3TcEAeBBYnG368Lz5vwXFG_qJQL87g,49438 vllm/model_executor/layers/quantization/fp_quant.py,sha256=-_Y67uqhWYnF3-VgelEieYYyJdHGFNHQU7N0EZRkS30,12808 -vllm/model_executor/layers/quantization/gguf.py,sha256=vNzIk01jPHcfybG1DpW84SL2ZSCnGRwOjCPLFZ0lSJg,23210 -vllm/model_executor/layers/quantization/gptq.py,sha256=9N215iLPUsDYI4bIaaEh690XvH3Zei7VgXAMysOxEcI,14814 -vllm/model_executor/layers/quantization/gptq_marlin.py,sha256=jFL9Z4WCqD0PpR8R9k9wFBe--hSmWyvUJvOPIve6W5w,34985 +vllm/model_executor/layers/quantization/gguf.py,sha256=UloLOKvRfV26yLl2v5R8eUqaqMXZQ4a9JUP2yuvgC3g,23559 +vllm/model_executor/layers/quantization/gptq.py,sha256=VMFcCvRxWW8Y0yQ8g_0F4lZav5VoHmox3ck7IoeblyQ,14814 +vllm/model_executor/layers/quantization/gptq_marlin.py,sha256=cfXVJepdbqDkiKLKe7F4zZmaiehy4lssLQk4tZs87To,45705 vllm/model_executor/layers/quantization/inc.py,sha256=e1WXuAsEHbeCMhCDL_D-JiLpKuo6xz0o0udCjjbZaqo,16368 vllm/model_executor/layers/quantization/input_quant_fp8.py,sha256=dPgo7XUhapj0dcbi57i-iRxGtXrDfWGbae8Ck2A8pS4,9421 vllm/model_executor/layers/quantization/kv_cache.py,sha256=MEa3GgF2uX3PMmj27T5XaITpXpRKvanEhahDHHMvICE,6616 -vllm/model_executor/layers/quantization/modelopt.py,sha256=gTNqO6NhocH0472tJe9RjacPCP8wpHyWQB-ebVpdcyE,67413 +vllm/model_executor/layers/quantization/modelopt.py,sha256=9DyDh7ZPjCLRDS3FKP2eCEzWHb9ZnrEotzKi4BpZNPQ,68978 vllm/model_executor/layers/quantization/moe_wna16.py,sha256=kDQ--mJ4XsCp95yR_txEEIZGZBP0Pahj_RpTEKQC0Vc,20437 -vllm/model_executor/layers/quantization/mxfp4.py,sha256=Mbaol2gJw0nQRgyEwf2pYF0Ecjp2bCLCkggjmPghUOg,46827 +vllm/model_executor/layers/quantization/mxfp4.py,sha256=SbmLqaO6aMQ8vXINFQq-DvcieorwDywAzENwwf6yRKo,50841 vllm/model_executor/layers/quantization/petit.py,sha256=OxEhXQ83yAPyoz_wQanCo9ZIeIAr77ERIThMW7imJb0,11563 -vllm/model_executor/layers/quantization/ptpc_fp8.py,sha256=U2UlS5PeDUGVqNfI_QIk7DQr0-OhOJi2Psabz7dtyFY,5178 +vllm/model_executor/layers/quantization/ptpc_fp8.py,sha256=dGjDPG2Jry6tream6pXiFza1fc8vZcxJ3mJe29BaChM,5066 vllm/model_executor/layers/quantization/quark/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/model_executor/layers/quantization/quark/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/quantization/quark/__pycache__/quark.cpython-312.pyc,, vllm/model_executor/layers/quantization/quark/__pycache__/quark_moe.cpython-312.pyc,, vllm/model_executor/layers/quantization/quark/__pycache__/utils.cpython-312.pyc,, -vllm/model_executor/layers/quantization/quark/quark.py,sha256=zhgVltnFyOecCtZ7uQP6ONV_kulSwe925Erxips6Zpk,22873 -vllm/model_executor/layers/quantization/quark/quark_moe.py,sha256=3j8Ah6ve7RPMGUJyQjTwOFeN5WTdroOqX6BBxJmU2cA,41962 
+vllm/model_executor/layers/quantization/quark/quark.py,sha256=2tl5jg2Kz-VJ3YBhiSBwp8-CoR6k6eQZIN2Z_YXir9g,24362 +vllm/model_executor/layers/quantization/quark/quark_moe.py,sha256=s1n8LsqPzkfLpFAZAPzxUhDnr1S27YBTm6mbJxljgvE,47045 vllm/model_executor/layers/quantization/quark/schemes/__init__.py,sha256=A699VuIbNQcvuANSmPjZSu9SmSHuqKfSQhvn9Y7GlpE,343 vllm/model_executor/layers/quantization/quark/schemes/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_ocp_mx.cpython-312.pyc,, vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_scheme.cpython-312.pyc,, vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_fp8.cpython-312.pyc,, vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_int8.cpython-312.pyc,, -vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py,sha256=1_TY1bFMBtkP2N2_CblVGPcHmYyuIorMQAUoH6qLN5w,12912 +vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py,sha256=SD95VzSRowSDcRZvek20PXCERY3xXEydRMVVTSsTh1c,14048 vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py,sha256=VMPNUJAQ6PqGlGmVD0wUx3a2NstTp2aC5Dx6w5ZTHn0,1516 vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py,sha256=VKI1F1WfR7bXnsMpAvipR_PgLoBjGZsnNBUrJzQbrmc,7051 vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py,sha256=58h2xdOgp60F5I8juKR5kfs0Q_tCJ2HnVKsfD4LFxRs,4489 @@ -1501,6 +1532,7 @@ vllm/model_executor/layers/quantization/utils/__pycache__/flashinfer_fp4_moe.cpy vllm/model_executor/layers/quantization/utils/__pycache__/flashinfer_mxint4_moe.cpython-312.pyc,, vllm/model_executor/layers/quantization/utils/__pycache__/flashinfer_utils.cpython-312.pyc,, vllm/model_executor/layers/quantization/utils/__pycache__/fp8_utils.cpython-312.pyc,, +vllm/model_executor/layers/quantization/utils/__pycache__/gguf_utils.cpython-312.pyc,, vllm/model_executor/layers/quantization/utils/__pycache__/gptq_utils.cpython-312.pyc,, vllm/model_executor/layers/quantization/utils/__pycache__/int8_utils.cpython-312.pyc,, vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-312.pyc,, @@ -1738,27 +1770,29 @@ vllm/model_executor/layers/quantization/utils/allspark_utils.py,sha256=-VRplxaZT "vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json",sha256=HIoWSUgAOcNaK2kj2YwDjDa23PzQVTT2C2ePW985Ovw,3805 "vllm/model_executor/layers/quantization/utils/configs/N=9216,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json",sha256=gZCqNrWU0B8v6xL1Vgizi1K7fO5gChFMlxp2UbNkOZ4,3254 vllm/model_executor/layers/quantization/utils/configs/README.md,sha256=kfjjurECwd-xH4EDjuueS0Xezi86c_pYu2yELgiw8Ew,102 -vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py,sha256=b_P-NpzJT-f9M_etWKOGj97hQU1ZFx65_hnAccUswAw,20186 +vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py,sha256=RiXhdFi85Oy3rr0_KYzFMhFFJzpTRhU6B-59ObhU5SY,10291 vllm/model_executor/layers/quantization/utils/flashinfer_mxint4_moe.py,sha256=Y1hwfirmxoXradoUqkJ3CeLzHY0KH5jELIDHwjeEN_I,10157 -vllm/model_executor/layers/quantization/utils/flashinfer_utils.py,sha256=sAAMKIgyDSoBIltyMdpIomZ6XafCzcKfi7AGD5wUswg,16759 -vllm/model_executor/layers/quantization/utils/fp8_utils.py,sha256=8QKemKKq3WrbqTQ1SEjGg7bH18pYSeAppw-0Utp5dyo,53968 
+vllm/model_executor/layers/quantization/utils/flashinfer_utils.py,sha256=VuifuvCV8lgj6uR7F4pLtl9KQZmFZIo7veKvYY8Dtbw,13212 +vllm/model_executor/layers/quantization/utils/fp8_utils.py,sha256=5HKm6LG5zyyVpCcQQ7AQyLzaBBtbyhsXNV3SSFOzp8Y,54014 +vllm/model_executor/layers/quantization/utils/gguf_utils.py,sha256=UsgKgIbb1sKz2CB8J-q6k7BuVNtwN2zXXgXCCQNgP2c,14654 vllm/model_executor/layers/quantization/utils/gptq_utils.py,sha256=-0GnKH51q6vGtt42RhlbGRkvDI4Yhta4QBG1FlWWV5w,5886 vllm/model_executor/layers/quantization/utils/int8_utils.py,sha256=JI0KrskdcFxZ073TZ-1iLMOgPWN6Gc-BZIA415T8q4s,14424 vllm/model_executor/layers/quantization/utils/layer_utils.py,sha256=CvEUZEj2ZFt7oxcDaRklOJo0IFc-llCiA2bVjybRKFc,1574 vllm/model_executor/layers/quantization/utils/machete_utils.py,sha256=Wv5uFC2UuuoBC_X1YGFOORBf23Zl_Yp5Ck_hFQIbRUA,1699 -vllm/model_executor/layers/quantization/utils/marlin_utils.py,sha256=kIsJzQpmpvhFueA17kgAIUsC82WiJeyo6bPfMXLyXaM,21561 +vllm/model_executor/layers/quantization/utils/marlin_utils.py,sha256=sPn-SwJww6GT4XSgCBdU0KtH9JsFLighLOgqjwdkUPQ,21044 vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py,sha256=L2lfxNwj86u3R3ZClLIECuxpnG1VkI-gw7NY3lCeda0,19454 vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py,sha256=U7Ne8oQRqyV4POZKh0XwINlwUcdV7V4otNFPao0etDc,13977 vllm/model_executor/layers/quantization/utils/marlin_utils_test.py,sha256=D8c0-obb1g9qQVAMptBTTF_2SzHzXR5sIV9ant1_-1M,7011 -vllm/model_executor/layers/quantization/utils/mxfp4_utils.py,sha256=mleIf8cK_qaxk_nehd0Y3d-pl3DplvdLjEa-ikaLQIg,6477 +vllm/model_executor/layers/quantization/utils/mxfp4_utils.py,sha256=ow7rlvA-XCOStp35_UPEB092xXfVeImobEcdTjSifhA,6423 vllm/model_executor/layers/quantization/utils/mxfp6_utils.py,sha256=3ccQ0AmjiDMVtIyZ-G5WZ778QDKbpbRqjv8vfNE9u5I,4533 vllm/model_executor/layers/quantization/utils/mxfp8_utils.py,sha256=aod_8kn4USLCdVojSAgPIv_pEvI3gQeopHA36hH5j_U,7496 vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py,sha256=rFpYWNmA1DD4I536b0U0FgjB4zBVDSpN6cpeldjssvw,4633 vllm/model_executor/layers/quantization/utils/nvfp4_utils.py,sha256=xQ9PdFxtD4_LpV7Uv_zym-y0Zt7O00M8jjjxpixaCmw,13078 vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py,sha256=wkKQfLry_jtVO5sDNBDJDajAOUP2YjRixRrYY_8kqpY,2610 vllm/model_executor/layers/quantization/utils/petit_utils.py,sha256=jKuC_wfRf4D55z9IgDsSJQC8ycsBI2SjBOePR8qxDy4,3841 -vllm/model_executor/layers/quantization/utils/quant_utils.py,sha256=ZqlU3Q4PPBFJqAB0HwglvzETTfhStlWUMiwk-8Q8pUM,27903 +vllm/model_executor/layers/quantization/utils/quant_utils.py,sha256=9nx9Tq_EZNcP3FPMDnumFQxVUOVwj57BgxM8GjkeBtg,27974 vllm/model_executor/layers/quantization/utils/w8a8_utils.py,sha256=ltM6EmBuI9dpUz6bp3xnEoab58z5AfbFYig-3y5z2YU,4856 +vllm/model_executor/layers/quantization/w8a16.py,sha256=Et0jZF5U-2tyRAy6cEuM0IYaZx-2Jyj5Ed8l0wjbk3w,3350 vllm/model_executor/layers/resampler.py,sha256=kK_gbASYVbvMoPPN3petUTfNGK8DkdNdihYnFx-JLn4,9877 vllm/model_executor/layers/rotary_embedding/__init__.py,sha256=PllRvEzr6kmY_ZJeGA9p3u_ubdEKhnbtjoHzD0CmgO4,11094 vllm/model_executor/layers/rotary_embedding/__pycache__/__init__.cpython-312.pyc,, @@ -1779,9 +1813,9 @@ vllm/model_executor/layers/rotary_embedding/__pycache__/ntk_scaling_rope.cpython vllm/model_executor/layers/rotary_embedding/__pycache__/phi3_long_rope_scaled_rope.cpython-312.pyc,, vllm/model_executor/layers/rotary_embedding/__pycache__/xdrope.cpython-312.pyc,, vllm/model_executor/layers/rotary_embedding/__pycache__/yarn_scaling_rope.cpython-312.pyc,, 
-vllm/model_executor/layers/rotary_embedding/base.py,sha256=MuIUUjEBL5jIxhPMN9_mIZlqNQivd-5-AmW6795cETk,10129 -vllm/model_executor/layers/rotary_embedding/common.py,sha256=2s9Oiytr2WeKr1ebloMvap-MKho8Yh94YqZw2qL0zTg,8505 -vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py,sha256=nWQRNBkTYYj1yrMIqUCKzlXnTi9TvnEsgp7RNCtv4v0,6143 +vllm/model_executor/layers/rotary_embedding/base.py,sha256=mFFI_QYNResCyYVOBFVNZ-oK_83POWHSsk8ljUmsVG8,10158 +vllm/model_executor/layers/rotary_embedding/common.py,sha256=sHf-fXTKfjulR7st_Fj5bvR8GwpUkWl9gmPuvdv3We8,7885 +vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py,sha256=13K3gZiUXQ1o15jCqXq7XPeZ2YRsKA8BkLRlsMG8hjo,3474 vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py,sha256=vCIgcU7FgX8eymrhLHyZ2MXl9i5-kevyyE-J3ex9Mtc,8461 vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py,sha256=gkD_OmIqduxIUvBgqEqFu22dVDm4zzOmtJ5SgRQJOAc,1289 vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py,sha256=6lZQAj1v-nvRbZkLIR9_4a9nzAczr09YByOF7IXDr_U,2681 @@ -1790,16 +1824,18 @@ vllm/model_executor/layers/rotary_embedding/fope.py,sha256=RqRQj8sMSazROYGZjF9f3 vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py,sha256=5a1OVeRvwUeU35x7xjUbhpc_UNtWd2MRWlZ1QftnxkI,4643 vllm/model_executor/layers/rotary_embedding/llama3_rope.py,sha256=EMeHz5tB4hKDXr6ce4GFfFpt8x03HKYlmedwXUgDcHw,1785 vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py,sha256=ls-dUts73Wyp99QPkk4_ED1fewJVH05BighSzRpCmm4,3313 -vllm/model_executor/layers/rotary_embedding/mrope.py,sha256=1fxcpkC_R_38J0R7zGtnsilMmshAu8c2_nXMYJjipRw,14342 +vllm/model_executor/layers/rotary_embedding/mrope.py,sha256=p6HWG-CXP1TsteD3HrnJPJNwgsLI13DogWo4wZEfo-k,14501 vllm/model_executor/layers/rotary_embedding/mrope_interleaved.py,sha256=U83KHEt0F8tM4jjgDgpybXYW47A3csafhFOT1qXbdgE,6291 vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py,sha256=KzRfhq8fROqVm5Eu6ETmQZyvc8LZjIKzde0ZGcr8Pkc,1462 -vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py,sha256=AT89hmBGCmGUtX1hoy4agKZfYdHd98eJnE1fGlzkQVA,5699 +vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py,sha256=XSIA8_Ab5B2QqrF60Tjhy6xyd6eppPBzxcHy6HeUiuo,5108 vllm/model_executor/layers/rotary_embedding/xdrope.py,sha256=MDHhx7D3Lir4BWx_bhmOYIruQfBJCzsD-OV7g2sUfGA,5108 vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py,sha256=rZENzx8gm_Ysbp_OzOn0qQB9v-D1faYWz5ZILHJBbrs,2863 -vllm/model_executor/layers/sparse_attn_indexer.py,sha256=0KOwB0Pb7Ut5FrnD8I6qRlSot7K5lnKNQ_F3lmQdrKc,11877 -vllm/model_executor/layers/utils.py,sha256=nkb9R36XoorqHg0mm5bYvTr_gbEiY61Es7BVVdfSVpE,9720 +vllm/model_executor/layers/shared_fused_moe/__pycache__/shared_fused_moe.cpython-312.pyc,, +vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py,sha256=0KI0seYhUNqYNH15p_siLdW_CHYc2hec6P6OYI8VmH0,1936 +vllm/model_executor/layers/sparse_attn_indexer.py,sha256=H8JHgeV4igtlXBG1Y_ICGcmd8hetsc-brpTYN7Cikms,17630 +vllm/model_executor/layers/utils.py,sha256=5mO4gzXtSA8o7nq1W3kpQ26IrpWe8mrf_SXFEgP6I58,11933 vllm/model_executor/layers/vocab_parallel_embedding.py,sha256=nf9sq3wXeYmRzEjoVPJ1prWWzGN-VpTxZsIM3GjTGHk,22192 -vllm/model_executor/model_loader/__init__.py,sha256=snspEqkFK3If_Ev_pvYP5fcIOAq49tnfO-HQ1ocsbVs,5019 +vllm/model_executor/model_loader/__init__.py,sha256=LUgqYUBeNGqh0zXN2Gh_VQWBvGgiIFc2ADWNpKBgfKY,5108 vllm/model_executor/model_loader/__pycache__/__init__.cpython-312.pyc,, 
vllm/model_executor/model_loader/__pycache__/base_loader.cpython-312.pyc,, vllm/model_executor/model_loader/__pycache__/bitsandbytes_loader.cpython-312.pyc,, @@ -1814,7 +1850,7 @@ vllm/model_executor/model_loader/__pycache__/utils.cpython-312.pyc,, vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-312.pyc,, vllm/model_executor/model_loader/base_loader.py,sha256=VzYyzhjqk_IgUJuJ9pJm6WYHggNyunYCILg8F2usd94,3113 vllm/model_executor/model_loader/bitsandbytes_loader.py,sha256=QjnWFa8OGprcFE0qI2njvjEEmXgdumgBfyypwRfch3U,34623 -vllm/model_executor/model_loader/default_loader.py,sha256=ms-p8MRpMW0vR1-0u5bCPDUZEDwLFNyljMC_aRxdiLc,11817 +vllm/model_executor/model_loader/default_loader.py,sha256=GbtJE43yTXJmMleW6E4-1XndD6zcZ8f0_Met0eeSYuM,11928 vllm/model_executor/model_loader/dummy_loader.py,sha256=R7xjrtLVejnd9j8zejScUTuY9zeAN6JlpBcNCfeUXJ0,1129 vllm/model_executor/model_loader/gguf_loader.py,sha256=cDZIzZ18naZghuK1Si6d6gxe-GyJwvHdQpA5sjY-qBI,15557 vllm/model_executor/model_loader/reload/__init__.py,sha256=h3WL_a1isSc2RKQBuEzD-0HaYslFIVWPsR7BMPCZVdc,1435 @@ -1836,8 +1872,10 @@ vllm/model_executor/model_loader/sharded_state_loader.py,sha256=KxhCy60owpwAzg8O vllm/model_executor/model_loader/tensorizer.py,sha256=kqhUvxmpKEb2WKDQdE-8kHqsogZMvFtUpTmPXjH49jA,29902 vllm/model_executor/model_loader/tensorizer_loader.py,sha256=Gq7pYVtrOr8engSSWbW9x1qtoos5HuaS-ZrzWQHvR40,5900 vllm/model_executor/model_loader/utils.py,sha256=LQ6dgzLXaDDpPghN1VaWPtIjCBSap2DyAYWRGgt9Fbk,11020 -vllm/model_executor/model_loader/weight_utils.py,sha256=_AROcisvdtIP-Nopk_oRwecL7Qs0qd3CJmbxtjamv3w,53018 +vllm/model_executor/model_loader/weight_utils.py,sha256=wSjU-TUkeMLTBv-sRrAOulQWtQTnB7Nb28m4LTsw4_w,49996 +vllm/model_executor/models/AXK1.py,sha256=gliagXFSEpAOjNjYxy0RU5tzjGpDsceiS8E2bOCHNtU,45919 vllm/model_executor/models/__init__.py,sha256=JC8KFZavvBlB4woyRgOBd_vcuoiA31nmBOR5dGcJSjo,997 +vllm/model_executor/models/__pycache__/AXK1.cpython-312.pyc,, vllm/model_executor/models/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/models/__pycache__/adapters.cpython-312.pyc,, vllm/model_executor/models/__pycache__/afmoe.cpython-312.pyc,, @@ -1851,6 +1889,7 @@ vllm/model_executor/models/__pycache__/aya_vision.cpython-312.pyc,, vllm/model_executor/models/__pycache__/bagel.cpython-312.pyc,, vllm/model_executor/models/__pycache__/baichuan.cpython-312.pyc,, vllm/model_executor/models/__pycache__/bailing_moe.cpython-312.pyc,, +vllm/model_executor/models/__pycache__/bailing_moe_linear.cpython-312.pyc,, vllm/model_executor/models/__pycache__/bamba.cpython-312.pyc,, vllm/model_executor/models/__pycache__/bee.cpython-312.pyc,, vllm/model_executor/models/__pycache__/bert.cpython-312.pyc,, @@ -1888,9 +1927,11 @@ vllm/model_executor/models/__pycache__/exaone.cpython-312.pyc,, vllm/model_executor/models/__pycache__/exaone4.cpython-312.pyc,, vllm/model_executor/models/__pycache__/exaone_moe.cpython-312.pyc,, vllm/model_executor/models/__pycache__/exaone_moe_mtp.cpython-312.pyc,, +vllm/model_executor/models/__pycache__/extract_hidden_states.cpython-312.pyc,, vllm/model_executor/models/__pycache__/fairseq2_llama.cpython-312.pyc,, vllm/model_executor/models/__pycache__/falcon.cpython-312.pyc,, vllm/model_executor/models/__pycache__/falcon_h1.cpython-312.pyc,, +vllm/model_executor/models/__pycache__/fireredasr2.cpython-312.pyc,, vllm/model_executor/models/__pycache__/flex_olmo.cpython-312.pyc,, vllm/model_executor/models/__pycache__/funasr.cpython-312.pyc,, vllm/model_executor/models/__pycache__/funaudiochat.cpython-312.pyc,, @@ 
-2021,6 +2062,7 @@ vllm/model_executor/models/__pycache__/ovis.cpython-312.pyc,, vllm/model_executor/models/__pycache__/ovis2_5.cpython-312.pyc,, vllm/model_executor/models/__pycache__/paddleocr_vl.cpython-312.pyc,, vllm/model_executor/models/__pycache__/paligemma.cpython-312.pyc,, +vllm/model_executor/models/__pycache__/parakeet.cpython-312.pyc,, vllm/model_executor/models/__pycache__/persimmon.cpython-312.pyc,, vllm/model_executor/models/__pycache__/phi.cpython-312.pyc,, vllm/model_executor/models/__pycache__/phi3.cpython-312.pyc,, @@ -2070,7 +2112,6 @@ vllm/model_executor/models/__pycache__/step3_vl.cpython-312.pyc,, vllm/model_executor/models/__pycache__/step3p5.cpython-312.pyc,, vllm/model_executor/models/__pycache__/step3p5_mtp.cpython-312.pyc,, vllm/model_executor/models/__pycache__/step_vl.cpython-312.pyc,, -vllm/model_executor/models/__pycache__/swin.cpython-312.pyc,, vllm/model_executor/models/__pycache__/tarsier.cpython-312.pyc,, vllm/model_executor/models/__pycache__/telechat2.cpython-312.pyc,, vllm/model_executor/models/__pycache__/teleflm.cpython-312.pyc,, @@ -2097,6 +2138,7 @@ vllm/model_executor/models/aya_vision.py,sha256=5XLbDFjKP3TMhhMEZKiPvtuvqR13m-B_ vllm/model_executor/models/bagel.py,sha256=RFmNsYD-cHvkJYu2t_lqcwsUzYf2Ge-xBQWPOrad1iI,20840 vllm/model_executor/models/baichuan.py,sha256=UDYfE2aKMTOLJVLcCa5tPOlGnpgulbYgSp42EYnUs98,17841 vllm/model_executor/models/bailing_moe.py,sha256=_mYgI8Wx_MoTTWaso0jZcC7LzH2J-OmbabLzxUJYiR8,23151 +vllm/model_executor/models/bailing_moe_linear.py,sha256=2cFzHeSqLXII49rzD8hBLiQ4MNGxJn_WIwFVdw9PXGA,44543 vllm/model_executor/models/bamba.py,sha256=F5BlJxQgw4jtqoW_D62rmAkVV_pdWMA_IFtMhCyM9nU,17905 vllm/model_executor/models/bee.py,sha256=VyADm_f5piyv7b0-_D_mEZpSCq-tRld4dmJ9CJT5be4,5412 vllm/model_executor/models/bert.py,sha256=2BWsUrhAPe7XDu-2EvlPnMzR5sAU4DnpC-dGoT0zxcc,30889 @@ -2112,47 +2154,49 @@ vllm/model_executor/models/colbert.py,sha256=9khjV3Pvo30Rn5ZRiRTZOBt8JaltTSy4Ct8 vllm/model_executor/models/colmodernvbert.py,sha256=cb4sak9EiKmNdWgL4xJYjM2PpTGnMCbwRCCZWXOW2LU,15895 vllm/model_executor/models/colqwen3.py,sha256=uaJdiDF7kCNFE8tst-MWmKfUQtI9KKUqCyjLV8FmlIw,12519 vllm/model_executor/models/commandr.py,sha256=nt2FhYi9BMJy-dNWdt_n5QrrKIYLi4yGcq1TuMaGR7U,17426 -vllm/model_executor/models/config.py,sha256=nJ--ydGzN2cVqpdlq5ycKoxHfVoU0Isl2RPdSCEpT7w,27285 +vllm/model_executor/models/config.py,sha256=8e8wpp_tSIaKCmTtgAzXEhQbM7DSH7rV9h7gyRQGTK8,29241 vllm/model_executor/models/dbrx.py,sha256=lotryU3Rj2LkGDvDh56OmpUpett9FAG1OulvJFndfi0,17459 -vllm/model_executor/models/deepencoder.py,sha256=dGi2ga5F1pIvKzDkVXTx4QYfw2Jg17eJgD5MheqjIIs,23840 +vllm/model_executor/models/deepencoder.py,sha256=fQpcg2l92QZxT_zZxwO6l26YCrKEDUVHZxI8uRL9jxU,24021 vllm/model_executor/models/deepencoder2.py,sha256=wX5MmEGmVcOzXCYdBvUcPECUpTaY9c4uQkfl2Yu3bcw,9295 vllm/model_executor/models/deepseek_eagle.py,sha256=godN3QkVdWqAcpDrX7lpHe_8tVFmvvThd_BAzcexk8k,9347 -vllm/model_executor/models/deepseek_mtp.py,sha256=71PGNDRN0Kdg12MfGDdwGJDmcHubVkxIN4jLAxBAmNQ,18208 +vllm/model_executor/models/deepseek_mtp.py,sha256=0oOlEU_fAzzw8SpFZpHzcQ5DFcbmfQpZfmssmUCZYVs,24697 vllm/model_executor/models/deepseek_ocr.py,sha256=n2e35GbvjPPc5ROT8pKEBTyo9iorCfdayb4d6MOLyuo,21479 vllm/model_executor/models/deepseek_ocr2.py,sha256=XTfawQbWR_pj3s9MCwzc49DpnQCZRPyJcH7EIo2sOF4,15521 -vllm/model_executor/models/deepseek_v2.py,sha256=gxB24k4gMtblkj1FmwyHu8BpORBH4bnsJkvV9YiuuPo,63828 +vllm/model_executor/models/deepseek_v2.py,sha256=zimuPBOFDeebwyhjKQNXmiHOWl4ti_ZljNphvKcm7oY,66562 
vllm/model_executor/models/deepseek_vl2.py,sha256=Xp9LsW6G7d2uF-D94luBaFW3om-A8bD4twa4ufKctGo,22435 vllm/model_executor/models/dots1.py,sha256=BqFjkVW2NlrLoTDhY4IWSWpUK1L_C-nd9FFiFmaoNUo,20865 vllm/model_executor/models/dots_ocr.py,sha256=Lt5qSAqHgZ8jiV8hNBDzv3usUMFJMDMMUhZkmbRra4I,26938 vllm/model_executor/models/eagle2_5_vl.py,sha256=4qvPgo27uo2UzrMTZWN94Ffilktau1ZGT5wyGX6rhRU,16273 vllm/model_executor/models/ernie45.py,sha256=gxkle1tmFbYsDkQx8AjoYZ-2_e93ALUtHMMdVnPA_fY,2244 -vllm/model_executor/models/ernie45_moe.py,sha256=V1DB-RLpKkN-DxrkUw5uZYzbhFOb4NK8_dbP6CGOyE0,28140 -vllm/model_executor/models/ernie45_vl.py,sha256=8H_gqsqmjADNqEjetMKt8bkb2LTG2wIrBoagRrstb2I,60912 +vllm/model_executor/models/ernie45_moe.py,sha256=_-UdUvUiUCxcqh5p8HYLv79WjFdu0Prs3fSlUiXOkPg,28234 +vllm/model_executor/models/ernie45_vl.py,sha256=brF7polIt5aOwjkAdIN82smjI1HmQnn5NtBkKswt_mA,61614 vllm/model_executor/models/ernie45_vl_moe.py,sha256=B_xBD4bvbQsGbV3Vk1tPPpt9TAvtnirxuyfSo2M-vBY,30498 vllm/model_executor/models/ernie_mtp.py,sha256=Fxb09Xis5scfTMtOngTHbftAFvYLMCx4Ja4R5qW1Ov0,10528 vllm/model_executor/models/exaone.py,sha256=ZSt7DOqMrA58yJ0pUWF-_lpmN72cFTUsRKrDkq31EsI,18832 vllm/model_executor/models/exaone4.py,sha256=UVO950qM-kG6qJg8CQ7CTUz7mobl2I-Ki1Fd-9MiTtc,18753 vllm/model_executor/models/exaone_moe.py,sha256=L2jDqHY1KPMUzXTIDPfTbkrJQVOcDA21KpDnqHPVsrw,21895 vllm/model_executor/models/exaone_moe_mtp.py,sha256=84xCAROtS2_BNWcySKRFm-v7bUS1sjUsIjhoE6clH_k,9294 +vllm/model_executor/models/extract_hidden_states.py,sha256=EcCjVVgjSTj7XqMTpXG6CLHJk6gbmaGOsJ4pKj4iYDI,13218 vllm/model_executor/models/fairseq2_llama.py,sha256=dIadCgG1u8yASjN4qmOTvkli4xxu-Ygl3t0iL2Czsik,6375 vllm/model_executor/models/falcon.py,sha256=LqHvEA3eLJG5qmZyby1y94rGXhuWQxvWJNMqogPoWXg,20490 vllm/model_executor/models/falcon_h1.py,sha256=fNGP9Yl1rc8pa4rkBHaP0VnQZrZmRWm6Td6pPXPHhhc,24732 +vllm/model_executor/models/fireredasr2.py,sha256=HT2ZxZgm20a-xdRoOhWR0qZLI1aZ0w5n0L3a-ACOySc,27555 vllm/model_executor/models/flex_olmo.py,sha256=8xGT9dxqgsyGeebsCzlV_4Z9WsJvUv868bOyWEd21xs,5610 vllm/model_executor/models/funasr.py,sha256=13zyQqR7uiqiw4r9wuARByxKE-rGPDZSksva5MDEm40,34185 -vllm/model_executor/models/funaudiochat.py,sha256=TneAjOvWfTleNiTKnTsL_LQ5oQtZiid2hoWwc_QYKvE,42133 +vllm/model_executor/models/funaudiochat.py,sha256=_XJvJA7-qEeraJaKwzjRcw_bhPbPd2wd_tA98W3fJfw,38660 vllm/model_executor/models/fuyu.py,sha256=60pJLMtw1ahoISbFi-rocZrTD2UWE7lXAh2Imo11VCU,12790 vllm/model_executor/models/gemma.py,sha256=3Jge3UOI-LmSOKHOaUFH8riFZ4SNxDxsRHx4Zb4M4hI,15564 vllm/model_executor/models/gemma2.py,sha256=OFktjmniv2G9EvQVF9CjMCQbeGbcty7a1sHdNawj2QM,16385 vllm/model_executor/models/gemma3.py,sha256=T7ulpHlAUoIBwpQv_1EmU51GclNoglosMX7hbC3rDak,19521 vllm/model_executor/models/gemma3_mm.py,sha256=CPBhSYTD5bQnYRcBTncCGUnUuWnDE8QFIsf8zYgQVQU,23327 -vllm/model_executor/models/gemma3n.py,sha256=yc7muBjfrsDIOZFPUiK4ZbGasGZ8wVngTVlU6WavluM,44213 +vllm/model_executor/models/gemma3n.py,sha256=FxdCWLWN-nETIWtn3dg8jTlw4NPkJApkPIPwv1MWGDA,44277 vllm/model_executor/models/gemma3n_audio_utils.py,sha256=8QRbJkqTWWcTtQZ8q8w-jBC6WlslT1ytYX49Zqn-oxQ,2414 vllm/model_executor/models/gemma3n_mm.py,sha256=yFLFL2y1u4iJxwBfmaRNlbSKgwBmLZqrzbrnzhZFwyE,30115 vllm/model_executor/models/glm.py,sha256=Xnak9Z2N9i2PFVY9m9-ocRcocTEQCPow5mGFena0URI,1108 vllm/model_executor/models/glm4.py,sha256=6aVSxPSrNVikr9hXe4r5NyplJE7YKni26JwEAOsa2qk,14460 vllm/model_executor/models/glm4_1v.py,sha256=lPbnzKQrxGjHvBKjW9cBUTd_zPopnLy18uw5EbLIBHo,63465 
-vllm/model_executor/models/glm4_moe.py,sha256=1mXD0vtG1JyzCAVsK1XepWZUyq7ulGSE4WkkHeDX1nk,27686 -vllm/model_executor/models/glm4_moe_lite.py,sha256=RNW4hradY48Mbh7aLF6MU-RRHkg4vTVY_pRMVT7wNj8,24704 +vllm/model_executor/models/glm4_moe.py,sha256=xW2YcmruADEtMUmvTJk4fmeBJMx1rCmeiDFMuxe-tj4,28067 +vllm/model_executor/models/glm4_moe_lite.py,sha256=DVw9ss8nwtAG4tIO_8TAc_5buu7xOF_CWoZmKTDbsw8,28356 vllm/model_executor/models/glm4_moe_lite_mtp.py,sha256=KBVnPqBX9_ebWB-B6p81wYgl-gEUkGqSu4bhpbDJTyk,18956 vllm/model_executor/models/glm4_moe_mtp.py,sha256=xBuVd-8HQf0dPcXdCSuVRkIizLbOpKBJ63r2gxEY31U,14543 vllm/model_executor/models/glm4v.py,sha256=hLOKgSaMmkppivK7PKlJoUlQsT_y-SjGysoBOitKCds,22760 @@ -2164,7 +2208,7 @@ vllm/model_executor/models/gpt2.py,sha256=hn1Ke4iG7EFl2Z08cDh_h4eT_H6-A_cDkBWzvV vllm/model_executor/models/gpt_bigcode.py,sha256=Z6CXbWDnof42fumqyNk1gkHP6EM3HaaSS52OnFlYnVQ,12165 vllm/model_executor/models/gpt_j.py,sha256=rCYV9bJeMnNPfOSmprPN7SgIBsDn83wzGinGflYh98s,12788 vllm/model_executor/models/gpt_neox.py,sha256=WmjxCXe50QcSJ4BYnmy3vLFwx62cUxwlfp4xKp3s6QA,12733 -vllm/model_executor/models/gpt_oss.py,sha256=zAwPWoFm2Cj1HVw3ypn3dA1vj6KSHNtaF1T5IwL3BJE,49178 +vllm/model_executor/models/gpt_oss.py,sha256=ez6LZkZUljft_qX8delpij9KY9iTQCnKZGjQnXH6-kQ,50254 vllm/model_executor/models/granite.py,sha256=fhn0J0lU3X2JDaOL0XAXwLPJWkHYgafLNjoEgubw4uQ,17392 vllm/model_executor/models/granite_speech.py,sha256=C0PblOLooxXdRDC39aFHv55T6E3_HAmgS351Ok_NVq4,34913 vllm/model_executor/models/granitemoe.py,sha256=X8iv2scVPqg0RmGZbb1sHzj8ajIK9DKTB_l1ansVbc8,21148 @@ -2174,8 +2218,8 @@ vllm/model_executor/models/gritlm.py,sha256=TI4KV2jQeocNtwY-VF1gGq_66EvX1kPo7NNG vllm/model_executor/models/grok1.py,sha256=tNciZn3JfgQ1xiKrtT_mmNcHgpPx0zYXexaalcvXwjg,29214 vllm/model_executor/models/h2ovl.py,sha256=-_UGnfuUCDJgKP0eLyg_z79fOg8aCsMHY2ewRYcv4nI,17187 vllm/model_executor/models/hunyuan_v1.py,sha256=FetuRvv6Ry8ksbkbZ1l9257j9YnllULpxo85aURtm4A,40081 -vllm/model_executor/models/hunyuan_vision.py,sha256=RUhknygqSk0dlEnFjwXfwxlWD1JLMEHH1GDv8gFUQcg,34868 -vllm/model_executor/models/hyperclovax_vision.py,sha256=UejOXTZcqP8ko8CkMQa1fHK3xyT9hX726eNrvvzB_ns,40079 +vllm/model_executor/models/hunyuan_vision.py,sha256=acNd-RjmKUNg3eg7z0rNgA9TOKRlFNjSjIE_KavwsQs,35219 +vllm/model_executor/models/hyperclovax_vision.py,sha256=yfAs11Bh5EBnukembCU7k6iaQgIp1vtlRRFidmAGTb8,40057 vllm/model_executor/models/idefics2_vision_model.py,sha256=sq8wgfKf8s84Vfvhf0QvYgti7EuDHY0BUjXkfo-gWbk,15354 vllm/model_executor/models/idefics3.py,sha256=LVJLfwIXkQd1CLUX9rgAgjAM8E0OEYI0pshqAVjR45o,24142 vllm/model_executor/models/interfaces.py,sha256=qfCq1h0UP7A5MKnpV_DSg2ToG7wh-9ZuN3mz_6SRkQM,45722 @@ -2188,13 +2232,13 @@ vllm/model_executor/models/interns1_pro.py,sha256=DMT6NCMIeKUpg2_Meulj-bNmJdtJVn vllm/model_executor/models/interns1_vit.py,sha256=_xYnzI342n_VEGVD6Bh_paQRAkMkhMl0wAh9P0VFe8U,15097 vllm/model_executor/models/internvl.py,sha256=gUczPpJolPFpESSGfxQnbUC3axD69Zrbnh_YobrJyTA,49486 vllm/model_executor/models/iquest_loopcoder.py,sha256=eoGgtrYyWMN4GbBm_Y26Blj5UCghLOZ-I-_OjQ6FwXk,22512 -vllm/model_executor/models/isaac.py,sha256=tAROikBtYtDAnVLqIV16p21c545eVIogfgEdYrMZQ_0,53516 +vllm/model_executor/models/isaac.py,sha256=QH42oeAUyYTEHiBsTncCeE1gA57rvvKL2MRG7ihnnKQ,53517 vllm/model_executor/models/jais.py,sha256=6Gh19Z0QfJ_Sy7fyLpYB7ev9OmtEoyFBRvNbEijb8u0,14182 vllm/model_executor/models/jais2.py,sha256=Pxu3I7vI3yy3jTCeIg3NbPdMYFiy7sqM6fccKKXejE0,18586 vllm/model_executor/models/jamba.py,sha256=MA-WJEfGyTWwgzaBDBvWLcBA7Nnq6OuejRc-X61FTsc,21509 
vllm/model_executor/models/jina_vl.py,sha256=zp_BVbAQ5toYF6_pwlzTk8qsh9ezG0dW4IgBOaeoVtk,5084 vllm/model_executor/models/kanana_v.py,sha256=IORzJfTJBOSQfnkmpS7BCtnN5CsRyNIOP5QxGOoGRZg,27312 -vllm/model_executor/models/keye.py,sha256=JyFkuxrek5__J5p85I_dwkxmpNZm5Y3m93tL1QQE2c4,59197 +vllm/model_executor/models/keye.py,sha256=uEINJwt2PbcHWZghTfLxXkyp7qNv9wWP_bP-b4v3NB0,59543 vllm/model_executor/models/keye_vl1_5.py,sha256=a4ZzxhDIqDSU3J3GtzPm-I22lw34s8UzLgpFSoqnpJg,24967 vllm/model_executor/models/kimi_k25.py,sha256=tRvAZV2UxwxeaSvPatl0qncAFCRn0z8wx3RmpQB4MtQ,17313 vllm/model_executor/models/kimi_k25_vit.py,sha256=LU_0gXaYzaWo48mEXGxePpf-6KeMfxtKWwX7jd2MWjc,23163 @@ -2223,12 +2267,12 @@ vllm/model_executor/models/midashenglm.py,sha256=qKLjwkuRF4rX3lgsXeaAc5vCuqW3uGz vllm/model_executor/models/mimo.py,sha256=oMHhy2sKkX4jrQUE4wxo9Ogk7ftrJ7PHJAxDzKaCfAs,7339 vllm/model_executor/models/mimo_mtp.py,sha256=c763S3lq5jzNswJ4TeyWlZ4YEjNEkLqt9qKKAP4uUxM,11068 vllm/model_executor/models/mimo_v2_flash.py,sha256=sYAPzCmLc_W5gKPGKjSwNnN1-4VDc0_oCGxjf_2NRJM,25791 -vllm/model_executor/models/minicpm.py,sha256=onw84qqdkTxWInx271tkZ1Ed6EAyk38px8ftGmTNAYE,23790 -vllm/model_executor/models/minicpm3.py,sha256=lA7ah336BFyl-3PKNHGEDqvdNdydk3CpEp2yJHRKHRc,8585 +vllm/model_executor/models/minicpm.py,sha256=UZGtCej1sXoS3QS5Gd8-jQvmnHxFjDkKKmlt4e5Yts4,23789 +vllm/model_executor/models/minicpm3.py,sha256=5PhiB4l7f6mlkGumDtiko6GdjijI2AqVjSOPcIgR88c,8584 vllm/model_executor/models/minicpm_eagle.py,sha256=-Y-N1X76_Crvq1OMd_r-2FqPfZ-d-5-OYbOhgZYHI1o,14435 vllm/model_executor/models/minicpmo.py,sha256=mlRcavEVtZJZuUXuXILevcdsJpcYcPsPXhEhgeRHkJk,30605 vllm/model_executor/models/minicpmv.py,sha256=igRUuvDYijPdX4OE-sSA37NZ251azZbb6qqT6ZyVYbQ,59008 -vllm/model_executor/models/minimax_m2.py,sha256=4LEPOv737S2cLwTTy27Ye4mNmaw4BBkcZLbGl4283t0,20952 +vllm/model_executor/models/minimax_m2.py,sha256=I57RFtBryk3ryVYRmQPUPewNTHUGmXodl34GwVg5JhY,20919 vllm/model_executor/models/minimax_text_01.py,sha256=KXXkFsjYQketFw_dz3uTVDtNJkq_9aULi5f1ypQ6XVk,38245 vllm/model_executor/models/minimax_vl_01.py,sha256=pzFFB2OvyRsx4CcRhQLUlkpwEXIHqFql-Vpqg7loKq8,14147 vllm/model_executor/models/mistral.py,sha256=YYklS51bg-BRZv2dNTODfpjXOF2jxtphb_xO4EZIhv8,11706 @@ -2245,13 +2289,13 @@ vllm/model_executor/models/molmo2.py,sha256=q26NYVOKkLzjZIwAkhCecE7WLONbl-gyJFsT vllm/model_executor/models/moonvit.py,sha256=0kIg-6OIbPmRWWnNCuWcMSLgD5GN623A8c5f_EtOeQU,21491 vllm/model_executor/models/mpt.py,sha256=hT2CtvYFxHRbmGmkDp2fmZ2n5C7OWKuVgTFsbkml0t8,12149 vllm/model_executor/models/musicflamingo.py,sha256=W9K3xxiB-hbl07faLAZ5Tl_XaP2t87tZPhor4oPq8os,2383 -vllm/model_executor/models/nano_nemotron_vl.py,sha256=PdswnArosgwi4KipZEoGNdbAkIkAYCFmHPVPI2FFvwk,81121 +vllm/model_executor/models/nano_nemotron_vl.py,sha256=ncAcHPfWsiUIsaOUOmQO3qSrMWz29Q2_B9w3UNlcSHU,86755 vllm/model_executor/models/nemotron.py,sha256=NFaGuOxfWOJGbR9s1EjqWjHAf4V40P5LYvOAzLgoiTw,17952 -vllm/model_executor/models/nemotron_h.py,sha256=j_pTGXNsUVKRyufxV9q0Cc8jRCcVdwNfHn7YRxBNLgk,35030 +vllm/model_executor/models/nemotron_h.py,sha256=NZkUNMkfqasFpPIjFH1xl5ZC6uMuPacAV8nzv3XAwGI,35147 vllm/model_executor/models/nemotron_h_mtp.py,sha256=HKJFEN2ue2_vq5aym9HIiQ023j5v0ONq6WH-N7g4nww,17912 vllm/model_executor/models/nemotron_nas.py,sha256=Er37_CCWpEreyAQKoDujVmDTGDBZPupfUBnJNluK63I,17101 vllm/model_executor/models/nemotron_parse.py,sha256=K3Qr4c3Tc7jCtkRD0Z-uqduYsukIyPjzNmJ3UHCUde0,32172 -vllm/model_executor/models/nemotron_vl.py,sha256=ryzL1J8mKLlAovTs8cgP7Xr6oCCWVNShGbZv3hWSZhs,22405 
+vllm/model_executor/models/nemotron_vl.py,sha256=nXwin_4Spok5Krdf0JhUB-4TU8i_qn04O3LPwnpj-nA,33620 vllm/model_executor/models/nvlm_d.py,sha256=piXDDA8xuNAWfruEPTHbYj5X5l3YZ_TJ1EaCvNXWNgU,7598 vllm/model_executor/models/olmo.py,sha256=5TXW9R5WZYCIr9ZXm340uIkrXJolsmorMAmTNt83S4A,14357 vllm/model_executor/models/olmo2.py,sha256=cd4UEDpwHV90pln9e8m-HmvPROAczb54qTP9q5Cy3lg,16618 @@ -2264,9 +2308,10 @@ vllm/model_executor/models/opt.py,sha256=VovPe87bR_iDzkUMnTiRiO8sq6pmzRbARTPDXy3 vllm/model_executor/models/orion.py,sha256=3-7b_q8dW9C8mbHdF59rp8bnJxUhLpp95nmIbs566VQ,13242 vllm/model_executor/models/ouro.py,sha256=FUcWm2suD7ROuA-RAcjHRe6ae-NRk_UPMTML_YceuBA,18806 vllm/model_executor/models/ovis.py,sha256=f-F1pVovgx_b7JegMEhHgXgFjZhbY2UyJXFinrWoRBY,20489 -vllm/model_executor/models/ovis2_5.py,sha256=L-XiG74IOJvbsVmj8TaVkxh-jNMZ2tLaa8v4CAb_SYc,24277 -vllm/model_executor/models/paddleocr_vl.py,sha256=0vaw2w5w8faNdmVgHuqHTaBjX8BrscxOsfCvBZOuq6A,43231 +vllm/model_executor/models/ovis2_5.py,sha256=FTR4GVuwDrlEMkncXpoBQZoearimbqHa-84Wgky1yv0,24189 +vllm/model_executor/models/paddleocr_vl.py,sha256=dTa73tvoy6Kicpi5_Anw8HtxcgYrGn68mDpcP83ht9s,43933 vllm/model_executor/models/paligemma.py,sha256=xZedgaDERYrwBP5OyTx0D7veWPo97Oj07J2NjSyFNoo,13873 +vllm/model_executor/models/parakeet.py,sha256=vwq0mgj_Y0EzMSoUmUb7BDlP9C-HPSmcFInb54Jw2Ts,5774 vllm/model_executor/models/persimmon.py,sha256=ZrT26c5oE6YXyuO9Lu-qFfFqwh6QQKHYxGtz-Ii12FQ,13620 vllm/model_executor/models/phi.py,sha256=CS86mMtB-SS1d-fxzKLu26EgvcE6CwyCcMYgx_jgdFw,12987 vllm/model_executor/models/phi3.py,sha256=N380ApMP7_3VL-rlWeVD27SUjC-kLXbldzYNWgG8ApM,456 @@ -2274,32 +2319,32 @@ vllm/model_executor/models/phi3v.py,sha256=16lU2Ry4D9HQWDXTSHgf_TYOYm-H9QrGOR3vk vllm/model_executor/models/phi4mm.py,sha256=fY5utJ3CmhehRhawB2ITPdtsCZtZAfPgaXc_bEA6K7o,45373 vllm/model_executor/models/phi4mm_audio.py,sha256=4la2vPrfGNSC0Y_LAD5qeoJDOk514IXfVWij6_z_e70,50426 vllm/model_executor/models/phi4mm_utils.py,sha256=GxjDFelDuVA-jdf9HBYUlaXNfgs62ebCAVK80CFzqTo,66152 -vllm/model_executor/models/phimoe.py,sha256=BuUlupn3L7i49USU0tTuCEjOKNMg0EhD7TYLdF0Ki-8,23530 -vllm/model_executor/models/pixtral.py,sha256=ugP9TulFlcm_ynoJnld7bC1eRJ4YEL5j-DE2ZBuMiEc,49151 +vllm/model_executor/models/phimoe.py,sha256=4AZ8TQfyOrTYdhRgFs3YA_g8VH5Dq4ZqqoeAUWqo1KY,23631 +vllm/model_executor/models/pixtral.py,sha256=pbs1OE7xfwwNq78xozpcoRuTPcqwXAGQ0jUYiW7aUOc,49305 vllm/model_executor/models/plamo2.py,sha256=UH6XNDM0ll6GznSqOHSaLMRemk2NkHIoofflHYeD9fs,37903 vllm/model_executor/models/plamo3.py,sha256=CbqsDUY59dgrr5j37aT4_74mMZKYe3EEntzhHsExjpo,15850 -vllm/model_executor/models/qwen.py,sha256=hK3zUTqM8oRU2x5hl1CMpVulQpcXyVwFSu26QJdR-XU,13274 +vllm/model_executor/models/qwen.py,sha256=n0801rEic97cw6YOetUbBi1iYitoOBbiE1FW_Y7btAI,13282 vllm/model_executor/models/qwen2.py,sha256=-UoJnvCi1tQUODCDbPooJJesEDtOuiGnPat0eQYUvP4,22235 -vllm/model_executor/models/qwen2_5_omni_thinker.py,sha256=bWgLbLPXox9U0yVY0vwe8SqunBJn-II4SW4HLkm1wBY,55350 -vllm/model_executor/models/qwen2_5_vl.py,sha256=PDLj70FNFMHVw0L4O3iiFrsP-6_iRT4NW4xO6bK6oO0,53962 -vllm/model_executor/models/qwen2_audio.py,sha256=cY4Jy1r7EsVsWxBfFQcUbRTXPhi-O5hJ2bs8fA5okC8,16642 +vllm/model_executor/models/qwen2_5_omni_thinker.py,sha256=b7Htse1Mb2UgqO2EwMB4bn1Wdets0l132nOraCLoLeE,57875 +vllm/model_executor/models/qwen2_5_vl.py,sha256=LBG2zprvqKNkfDyKnB8yWmmGbnxerp9-Qvjmi5nS1Fs,54596 +vllm/model_executor/models/qwen2_audio.py,sha256=g0veswhwz4mWEZIDFHkgru5AdWiz_LwMOMSrFR_fJB4,17394 
vllm/model_executor/models/qwen2_moe.py,sha256=WIma4Bczak6gDjd4h4iWnhemF4wROWFIasyp5lafSo8,23264 vllm/model_executor/models/qwen2_rm.py,sha256=qDtLznXFur9l3y6sQSeG_AMNIdMken95UHdKBbgVG6E,4120 -vllm/model_executor/models/qwen2_vl.py,sha256=kStWdKZsqnimKdxSBGqj1wkW2qFFBz4QicKd7LWJC7g,53333 -vllm/model_executor/models/qwen3.py,sha256=r7q8DSESqiHQjcI51rjIYHvFt1d9fB4MUzjRE11b9rY,13294 -vllm/model_executor/models/qwen3_5.py,sha256=RqkiCHK8zkTiPh7n8B80o-eebWWiPGCrE628UMo8_dQ,32317 -vllm/model_executor/models/qwen3_5_mtp.py,sha256=wlktoP4-wYSR_eHLqq-O2KSgzzD2_UqugeZ4TkPSzTk,16574 +vllm/model_executor/models/qwen2_vl.py,sha256=vaOZPPwEvbNCV-msL8XmcYiDR0TLFA_OutIbB4IkoCk,54141 +vllm/model_executor/models/qwen3.py,sha256=_WYUj0_aBKUDmJ7i4-1H-nyLT-AcoM_lpp7v_zJl4mc,13397 +vllm/model_executor/models/qwen3_5.py,sha256=jHhBKnHzl-rPUovQkg0Y6hIPqZ-VHu_JI62P8VxH6_k,32218 +vllm/model_executor/models/qwen3_5_mtp.py,sha256=6yjQt59R0VhrlCCe9toyScBmEM9cpcZ4Sq4lI2z83vA,16574 vllm/model_executor/models/qwen3_asr.py,sha256=H2s7E2svPp9Gpy1pKEIKTOzl3Y2mbQxn4cW19mNPd8Q,21181 vllm/model_executor/models/qwen3_asr_realtime.py,sha256=6k2QkZZ9E6wb5JEk9zb7MkJ7mhmquGzysEW1Nrqlcps,8879 -vllm/model_executor/models/qwen3_moe.py,sha256=iMmeP73CF02m4l_Mhk8OU3Mu6JVeoyv0yGbWoZagQqg,31302 -vllm/model_executor/models/qwen3_next.py,sha256=WEqxvBk95T5lRBaJ6BWjb_YgH3yXnutAKZzBSGDPMqA,56101 +vllm/model_executor/models/qwen3_moe.py,sha256=ue8JOXWEGoD4vUo-AJOtbv_EW0_k36njr-9tTacClxk,32314 +vllm/model_executor/models/qwen3_next.py,sha256=sDhPnDxtmyA3jISo-Nr7VyaayVYO7TMQ0GAhlAsakGg,62735 vllm/model_executor/models/qwen3_next_mtp.py,sha256=lMY5ULOAj2ykUsLFsQTnSx2-wfn_Z5Akwdw2g41Jtzw,10668 -vllm/model_executor/models/qwen3_omni_moe_thinker.py,sha256=8J-LV7hkxkx_gplEwALhMDwtbIFdksmFnNG2O3ETD7M,88291 -vllm/model_executor/models/qwen3_vl.py,sha256=k-9DiyzpFJECpyaYjqgRNZDTMIbHoYJ9Z1JqK5yA90w,81041 -vllm/model_executor/models/qwen3_vl_moe.py,sha256=_RQDu20sMXoPnKi3FslBORFgFl8V8LKjNyRupcXfr-0,19664 -vllm/model_executor/models/qwen_vl.py,sha256=dpe4J6aIOynwdnP3oMr43g8Okk9Qwuq0Lrwkda0jMsQ,25918 +vllm/model_executor/models/qwen3_omni_moe_thinker.py,sha256=0K3PbENWnKowRxqY59fIIxZ6ZYC0LY-LYlDpG0mdAHY,90452 +vllm/model_executor/models/qwen3_vl.py,sha256=BGwW4P8fzP6sKpZX6bUZP-jKgUNSdcizPX3VSmbts1U,95398 +vllm/model_executor/models/qwen3_vl_moe.py,sha256=P2W7NLHLT953mQDhrNFP1gm_hJ4_ILSIWrFRnl5jLOk,19890 +vllm/model_executor/models/qwen_vl.py,sha256=F0Ubyo0LslTfWNXTta-gbKs2gII7FjAJAIkt6tBfrCY,23987 vllm/model_executor/models/radio.py,sha256=46UFQzPd4bkoyZwt03rV8iM9Ld6YN4rXjwlt1FZNePk,25512 -vllm/model_executor/models/registry.py,sha256=V-IHKusPALbPvNiQsVrnpNk0ca2urn4D4ZE6rfrhqN0,51521 +vllm/model_executor/models/registry.py,sha256=RohMN3KxCV2iFc9iDy-jz4uMA-ehAjfys5xf4QdhZHE,52098 vllm/model_executor/models/roberta.py,sha256=kY8L-R4SqEbL7cGc_En3LMTyNVN-NkfHH8BITXjq4Gg,12952 vllm/model_executor/models/rvl.py,sha256=x0LpjLj9DPzT2ErbrxM2ZDZgxM84QMD-5m1ROp8Ino8,3529 vllm/model_executor/models/seed_oss.py,sha256=3au3bmOI5W0hYE_Qq1ExX8NJrktd-17Kn2pt9c1w_e0,17901 @@ -2311,15 +2356,14 @@ vllm/model_executor/models/solar.py,sha256=cp5LoWIxnQZ2_WFjeb6YdkTOpaMevW5lxw2XJ vllm/model_executor/models/stablelm.py,sha256=qxsgs-JAfeR0UXOv58PAaBuEO8oKmU_PJe0TU46ivl4,13595 vllm/model_executor/models/starcoder2.py,sha256=3qw3Rg7K1qyNzrzXkuP1aDGqj4b7D1WxLhzOoi_Yvhk,13401 vllm/model_executor/models/step1.py,sha256=1hH9WJwssqYIP4wF9Nq4HqGjnKVwhllCXQcO6a5GYSc,15190 -vllm/model_executor/models/step3_text.py,sha256=Jho9rywZWBQn4cFeRbLbCegRxhMEOrllUsei3y4MZZM,19894 
+vllm/model_executor/models/step3_text.py,sha256=I9J573qe95GSW72sstQZNrZOBr74NMCLAGZeEL8AlZg,19905 vllm/model_executor/models/step3_vl.py,sha256=Pakr_mwr9FnsSzALSDgIHJBq9oXKIN-6S13yKsYfW4s,39484 -vllm/model_executor/models/step3p5.py,sha256=6r-PcRDlGXE_jdJHsIQh8AQBgesnVdUd1BqegOwHBDA,37688 -vllm/model_executor/models/step3p5_mtp.py,sha256=lWs3tSxr_sYT4SDisJaBSq09Du2-Er9P9xRkGHK0KfI,12542 +vllm/model_executor/models/step3p5.py,sha256=oMQMKsqifSj4tMXPgPEo3JsZDZodmCVQR9Rriq-GWho,39071 +vllm/model_executor/models/step3p5_mtp.py,sha256=prsiPktOD0zQ4TemNe4ybXdlppaUMAsbB8sRg4Y3OI0,14279 vllm/model_executor/models/step_vl.py,sha256=464Mjc6zeUApfPiNC04yPBJzNBw_xS5VMIxjOFW35Y0,18859 -vllm/model_executor/models/swin.py,sha256=gKgBbW4e1gkSkUw_2OlqvLyijWGPkPu6YBYvfkKZ-ro,16428 vllm/model_executor/models/tarsier.py,sha256=7KdUQBig2GNlnP4i1KqZZuiH0CAsGyh-mLIRuZNyRvQ,22354 vllm/model_executor/models/telechat2.py,sha256=wOEquSezPeLg2qqUNZ7ncsIzpfxHxRCrFoproADyARE,6245 -vllm/model_executor/models/teleflm.py,sha256=vlTObvAPJAJhkhy-8Snyx_7R1jGqAbYU8OYbWAPgvh4,3030 +vllm/model_executor/models/teleflm.py,sha256=AQvwDlusYU4VyDuOO5_uj2FSyoD4pB7cUx9GLf4ZG50,3104 vllm/model_executor/models/terratorch.py,sha256=ZyRLssdTyADURq2uHROxn-ejp3mgnQwzDg4_aaa_IzA,11086 vllm/model_executor/models/transformers/__init__.py,sha256=KR0hAoCBU4oqFfOpZ1kT9u80zQ_s-Z6HWpTfkvra-iM,4264 vllm/model_executor/models/transformers/__pycache__/__init__.cpython-312.pyc,, @@ -2330,15 +2374,15 @@ vllm/model_executor/models/transformers/__pycache__/moe.cpython-312.pyc,, vllm/model_executor/models/transformers/__pycache__/multimodal.cpython-312.pyc,, vllm/model_executor/models/transformers/__pycache__/pooling.cpython-312.pyc,, vllm/model_executor/models/transformers/__pycache__/utils.cpython-312.pyc,, -vllm/model_executor/models/transformers/base.py,sha256=CB8v6-tcJHPjPda8X5b55b0H-nWaLKR5_Vl6VrDY5ao,21105 +vllm/model_executor/models/transformers/base.py,sha256=Ans3HGdUNAg4R_6AkXvqwxeZDk8e2NuyDkImpgGdQvk,21881 vllm/model_executor/models/transformers/causal.py,sha256=lO_fPotGi08ZcpwRMKeEP2DKctafVjhW9XZAvNA-S7A,2661 vllm/model_executor/models/transformers/legacy.py,sha256=NY96kEaStY66o4Sa3tOFCqLhhwgiIf3eO2dnViXowIQ,3229 vllm/model_executor/models/transformers/moe.py,sha256=ey3ADmdqPZkwTFdQ5avKdC5gXVp_z8dagNg28uO9GIc,14328 -vllm/model_executor/models/transformers/multimodal.py,sha256=U70KCN4N5PmsX1md_6BYw_d7RNrxN-fZKb_X4uneJ-A,19289 +vllm/model_executor/models/transformers/multimodal.py,sha256=wFVkVDR9zBwfE-ZLBA_sBOe_TWE32tGNJddL7_WMq9E,20313 vllm/model_executor/models/transformers/pooling.py,sha256=SS_1ciigE75k4WB5IthheWqYLZPeIGZWiT3tvKCmaqE,4077 vllm/model_executor/models/transformers/utils.py,sha256=HVyyfp_2Y9XPwg15SacdGFlpbWSK8abHE1ZQRFb4-I8,8825 vllm/model_executor/models/ultravox.py,sha256=mRny8WOuDkOToEvAi_hx_NdY2gkLPc9Bom3pGL5BNbM,29549 -vllm/model_executor/models/utils.py,sha256=beOOPhvdL-75EhwiAbcUlg8kE2ApiHOp4-YKQ5TwqKQ,29424 +vllm/model_executor/models/utils.py,sha256=ThsVFW09be9dVIoReq4WFNaMJa9fhbKRAcS2P7cWHQY,29149 vllm/model_executor/models/vision.py,sha256=0jQuAx2aSJKCS7ZuVCB6gfFWbDqWmfL522yjmWnX66E,21391 vllm/model_executor/models/voxtral.py,sha256=QTP6ebmVb3uwzABO7rYzEmCQCjty0N811hBloLP2NJE,34824 vllm/model_executor/models/voxtral_realtime.py,sha256=lbMt7I0oDjUk8KjMdvFB6q451phLJo9vDW9S6l6py-E,18997 @@ -2357,14 +2401,14 @@ vllm/model_executor/offloader/base.py,sha256=Eg89-UqJ22EGGYYPQhSoZpXoWyEq8dfTtYZ vllm/model_executor/offloader/prefetch.py,sha256=GnH8axqlOL2AJjaUFo9YSBJAoc8RXmcWDXPi35a2fd4,28295 
vllm/model_executor/offloader/prefetch_ops.py,sha256=DB1mKrQ91LBO4pEIRwI6_pEOdamCX3hn2vHrrR3VvxI,2545 vllm/model_executor/offloader/uva.py,sha256=NaWXzOKNus3zNgqpL_cmNB55SnT01VddiKQGsykfE34,5220 -vllm/model_executor/parameter.py,sha256=E9TVgals5xAY4jx0mcgbi90hJygCJJL4qjb7gU5sjvE,21945 +vllm/model_executor/parameter.py,sha256=Fs6RjMq461a8eAt7lJP-497YU8Zcu9UD6aP09NbYC_E,22429 vllm/model_executor/utils.py,sha256=WI1ei6LPUsDfoZaw7fgwdSlVEfRkKegN6qK5rsKZHac,4542 vllm/model_executor/warmup/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/model_executor/warmup/__pycache__/__init__.cpython-312.pyc,, vllm/model_executor/warmup/__pycache__/deep_gemm_warmup.cpython-312.pyc,, vllm/model_executor/warmup/__pycache__/kernel_warmup.cpython-312.pyc,, -vllm/model_executor/warmup/deep_gemm_warmup.py,sha256=OC6sL7gNXDU9lgmktF1HGNEkjKE0kVYH34NqpnYe-Eo,13025 -vllm/model_executor/warmup/kernel_warmup.py,sha256=_m1mE9tiMgcqja1kvjxPWg1byGJndrRm87sqOa2o8ws,3755 +vllm/model_executor/warmup/deep_gemm_warmup.py,sha256=rhnLDXmHTxV7_iymb_H-JHX3n2FFYUuf3_BJLL5Y5uA,13029 +vllm/model_executor/warmup/kernel_warmup.py,sha256=VSL0EIGcY59NQjgefYM4u5UbgqS3xzmr1lZ1SYwo2fo,4036 vllm/model_inspection.py,sha256=99U1s9FBinjggNPt5NvHvabMcIh3IK_dnXLgbFYC_jc,4743 vllm/multimodal/__init__.py,sha256=tnAxNnu1kVSEBizAbOTSXQ2zpRHvc_ZnyxORjLjmGBc,985 vllm/multimodal/__pycache__/__init__.cpython-312.pyc,, @@ -2382,7 +2426,7 @@ vllm/multimodal/__pycache__/video.cpython-312.pyc,, vllm/multimodal/audio.py,sha256=sS5AeYnH01hdM2TRzgAHHLgdPrTL2b3I9H5FsfTxMCc,11139 vllm/multimodal/cache.py,sha256=uZushqgTUTR4F-m__CFuVFRH2_BC_sF-j2VX0oFu1Bk,22324 vllm/multimodal/encoder_budget.py,sha256=dqryW7h1X3Iyxfic0VFzorsI2M2qqhB5cZmklEiltJs,7111 -vllm/multimodal/evs.py,sha256=fRGDSsF8VNjS6c665IH4SgKZd__RKIll-sEeDZP3Kyg,10886 +vllm/multimodal/evs.py,sha256=cs1ZvdSx7r3QpQDZjDDQkMvLqLdv_LmDB4mzEve0Adk,14494 vllm/multimodal/hasher.py,sha256=AET4pGTelW9oVmHIKGGlSo_aYMv9I5cbQ9fjQcxmZos,5340 vllm/multimodal/image.py,sha256=V7Ev9PMZt4MqN6bXefVCOTe1_RdZPlyldtkstzN-XH0,1192 vllm/multimodal/inputs.py,sha256=zcJhUiqaiLV3KANARgLRK2u2UEpPj6LN0qb8IXR-D1k,35665 @@ -2408,9 +2452,9 @@ vllm/multimodal/processing/__pycache__/processor.cpython-312.pyc,, vllm/multimodal/processing/context.py,sha256=0xnFAD3iAVIYu5gbnpS93lWenQ_EFounyIPNkWpN9Lg,16378 vllm/multimodal/processing/dummy_inputs.py,sha256=C7wgoX9gmrtWsur4EPFB-FuzlC_9z_Zr6h6pHQmXdQU,6281 vllm/multimodal/processing/inputs.py,sha256=JCMdwmCD_w4XNvupdPSSFCkxDxuC-_AOU--IOU9ZO24,2798 -vllm/multimodal/processing/processor.py,sha256=PwZ4_2wjCw7nqdjoDHvigF7I1J52xPmiWRBYTVF2tmA,58929 +vllm/multimodal/processing/processor.py,sha256=V1lIReMzRuHG_2DsjHBHQNWpEv6aOGG9sEg_wiRXqes,58563 vllm/multimodal/registry.py,sha256=IX6LWAFYnZuKo0bF-SVpa2hk7nkONnnWOjm-IhyWLWA,12284 -vllm/multimodal/utils.py,sha256=vW2Md85sh_vygLaDjLZOY89IFsdYUh-KcFeR-gwpUYA,9832 +vllm/multimodal/utils.py,sha256=jLnnfxOa5ZMhlnanpQkTKNeKRJ9Jjr63XcoifWvg8LU,9281 vllm/multimodal/video.py,sha256=CiILeXYUrjyW7WSkd34AEtpYvtVpS8jeCA_SmsdslCA,30131 vllm/outputs.py,sha256=mRwJYHcKYZBmzL9GnMCM4elSPn_hnjoRbEzRPB4XUKA,12628 vllm/parser/__init__.py,sha256=tISZkyTnsqET8CWO-_w9DKGfR016aJRuy_7e00GoVow,920 @@ -2421,7 +2465,7 @@ vllm/parser/__pycache__/parser_manager.cpython-312.pyc,, vllm/parser/abstract_parser.py,sha256=1vzOQp-KTgkhFQxlHkyqMsdE7VhwZFC6jNDr3lhSfUk,18946 vllm/parser/minimax_m2_parser.py,sha256=7ITK9L57iW60y18H2LLxJuCem2FL3fvflhogeo-gszU,1908 vllm/parser/parser_manager.py,sha256=AMWOnEKFSwaKx6hhxQDnZuZBpVxMEitD85jGvXp__o8,10514 
-vllm/platforms/__init__.py,sha256=yVDxm_6yxNOgxHqxWbytL1k0gbvzSCLXrjTBzdDMqfc,9727 +vllm/platforms/__init__.py,sha256=-RUbNJzU1WSuvMPAhNh5g9NSKOl6W6OAA6u164WqHG0,9681 vllm/platforms/__pycache__/__init__.cpython-312.pyc,, vllm/platforms/__pycache__/cpu.cpython-312.pyc,, vllm/platforms/__pycache__/cuda.cpython-312.pyc,, @@ -2429,25 +2473,25 @@ vllm/platforms/__pycache__/interface.cpython-312.pyc,, vllm/platforms/__pycache__/rocm.cpython-312.pyc,, vllm/platforms/__pycache__/tpu.cpython-312.pyc,, vllm/platforms/__pycache__/xpu.cpython-312.pyc,, -vllm/platforms/cpu.py,sha256=eK2HRaEKigFmFwChBOjfTdeo4SOsR3e0PdF_lTF3gLg,18338 -vllm/platforms/cuda.py,sha256=eqGY7NXVVP3CxWNAgtyUfSYqC_Vn8qW8vZ-UNGzOUpc,25508 -vllm/platforms/interface.py,sha256=BTxw-MJJnYQ0vC7w-pAZoaR7mEFX3ZMeJ-1GTEk1HN8,23504 -vllm/platforms/rocm.py,sha256=yIh1P_PSnf0-yNJES_Wm2vHq7HrPjq27MVhHq4_K6F0,26010 +vllm/platforms/cpu.py,sha256=qRBZPDrcJxQrqOGuTWo6OpxyHwQocHafifH_HVGzvXw,19426 +vllm/platforms/cuda.py,sha256=dH5TE7jcUkWIEkDnL5LIxKHjE5y0DJMP2l7uPggWzdw,26656 +vllm/platforms/interface.py,sha256=GkftAURIiA3AUXQT516vNeTiFLv5XJgvrg6T4AJpGic,23529 +vllm/platforms/rocm.py,sha256=oWfl-cNmhvLiQ4iyRtpLg-SPESYvFi8nG2qGj14eXS0,30601 vllm/platforms/tpu.py,sha256=jP2fyAJCwm7b0phOFcRlT7BxXHubFOiQMFLxA3oE31Q,497 -vllm/platforms/xpu.py,sha256=Q9asZAboAyNz5bB2hi67iajp19fVOeIkOxCKKRs57WI,11389 +vllm/platforms/xpu.py,sha256=1q6G6_jsxBCmTiDSJ0ukjIcyJtPGby9SB5jjVsmp66A,11716 vllm/plugins/__init__.py,sha256=S-ZhkM6u6dBGX2Kt6AGo6UqQfXq5_bCmf6FM6HRIrp8,2902 vllm/plugins/__pycache__/__init__.cpython-312.pyc,, -vllm/plugins/io_processors/__init__.py,sha256=UhJEF--e0B816MCdrtyBxon4RQgsuM6vxLTBu6V6Y4A,2517 +vllm/plugins/io_processors/__init__.py,sha256=cJWVuQGdBZT0wyg1_omUBpx-PcIERlN_KUAG7Lqbbvo,3067 vllm/plugins/io_processors/__pycache__/__init__.cpython-312.pyc,, vllm/plugins/io_processors/__pycache__/interface.cpython-312.pyc,, -vllm/plugins/io_processors/interface.py,sha256=cl9MdLYrYUq-EmbqBu-59d3TW34luNT7XXf0NxHWIlc,4260 +vllm/plugins/io_processors/interface.py,sha256=JM4M0WsF7MKwKRrnRXnPSUAcrzHTrr5SD4ZT6GKAoho,4324 vllm/plugins/lora_resolvers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/plugins/lora_resolvers/__pycache__/__init__.cpython-312.pyc,, vllm/plugins/lora_resolvers/__pycache__/filesystem_resolver.cpython-312.pyc,, vllm/plugins/lora_resolvers/__pycache__/hf_hub_resolver.cpython-312.pyc,, vllm/plugins/lora_resolvers/filesystem_resolver.py,sha256=ys4JylLFIcxzO8maZID3OJKM_ZbP_A7JRh3wmqNs2_M,2345 vllm/plugins/lora_resolvers/hf_hub_resolver.py,sha256=dsXM5Vlb2YVNhYTeUKaF-grzx-nnN4I41tUR6M7FlRQ,5600 -vllm/pooling_params.py,sha256=m8gF0DVPyP-blqTcYKL6PxbPwPHvkL7wHjbMSLNc8tM,7925 +vllm/pooling_params.py,sha256=e1Hvm4fzEunEb-5MT86oWyC_FJwf97GzI0Rj7w2OD7U,7522 vllm/profiler/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/profiler/__pycache__/__init__.cpython-312.pyc,, vllm/profiler/__pycache__/layerwise_profile.cpython-312.pyc,, @@ -2463,7 +2507,7 @@ vllm/ray/__pycache__/lazy_utils.cpython-312.pyc,, vllm/ray/__pycache__/ray_env.cpython-312.pyc,, vllm/ray/lazy_utils.py,sha256=0QXF-uTIUdzl992ugkYgkvHbWUG-85yjuIDNSLVu2gQ,651 vllm/ray/ray_env.py,sha256=x04IQZOk22XIrW7rUJo7EUJFJChzm8utQNv72-1jRh4,4169 -vllm/reasoning/__init__.py,sha256=WY68O9kcZXb9Lmj6wUdX_y6Gnic-bnXX8n8Am3RQFkE,2576 +vllm/reasoning/__init__.py,sha256=5wv00rMoOjbqRTLDmUCXdtmkr6N0arx06NRotQHcWgM,2556 vllm/reasoning/__pycache__/__init__.cpython-312.pyc,, vllm/reasoning/__pycache__/abs_reasoning_parsers.cpython-312.pyc,, 
vllm/reasoning/__pycache__/basic_parsers.cpython-312.pyc,, @@ -2474,6 +2518,7 @@ vllm/reasoning/__pycache__/gptoss_reasoning_parser.cpython-312.pyc,, vllm/reasoning/__pycache__/granite_reasoning_parser.cpython-312.pyc,, vllm/reasoning/__pycache__/hunyuan_a13b_reasoning_parser.cpython-312.pyc,, vllm/reasoning/__pycache__/identity_reasoning_parser.cpython-312.pyc,, +vllm/reasoning/__pycache__/kimi_k2_reasoning_parser.cpython-312.pyc,, vllm/reasoning/__pycache__/minimax_m2_reasoning_parser.cpython-312.pyc,, vllm/reasoning/__pycache__/mistral_reasoning_parser.cpython-312.pyc,, vllm/reasoning/__pycache__/olmo3_reasoning_parser.cpython-312.pyc,, @@ -2490,10 +2535,11 @@ vllm/reasoning/gptoss_reasoning_parser.py,sha256=HWJ2eXOICySjOm3jFuamMJMcALsDPb2 vllm/reasoning/granite_reasoning_parser.py,sha256=MeDBbB2Iyv1SrKJisfsqYHR1Cn_JQfgE3QD3TeFXH5U,15272 vllm/reasoning/hunyuan_a13b_reasoning_parser.py,sha256=ThfBaOuTVFqY7OEU8vq_OWsiRBZXQ7Uu-xqcv8Y3bgI,9651 vllm/reasoning/identity_reasoning_parser.py,sha256=0wNlliwNu_wea0IdvdJqdZq1xYoIwtwwef04oV8slQo,2210 -vllm/reasoning/minimax_m2_reasoning_parser.py,sha256=ag2xV0h86r1yIdsDneIUs5m9W3b2tDmBj0ptUx6w__8,3941 +vllm/reasoning/kimi_k2_reasoning_parser.py,sha256=j-6uGi3WcXgL-lgNb3LpkmIkfrYuj-fGAQL5VeXTJDo,8309 +vllm/reasoning/minimax_m2_reasoning_parser.py,sha256=m0PlsWAa_PLA6mHLM8Ox6UQnwkKjrRGOv81Slu5SiAA,4134 vllm/reasoning/mistral_reasoning_parser.py,sha256=1OVWfroTTHBwlOI-L2COq7FPQvVDDZnD8vMB3VTqQnE,6319 vllm/reasoning/olmo3_reasoning_parser.py,sha256=ocBot-3mlnfuEkWg941OY-bbpcdB5A8PHU1jlI89brE,11123 -vllm/reasoning/qwen3_reasoning_parser.py,sha256=_gMOu9SXAnxUHVajPYfY4TVJRROY7fmWnNOOIpznzsc,5339 +vllm/reasoning/qwen3_reasoning_parser.py,sha256=OLWU8THTcNcLHWmlMLBjIDP3ky3NXrynsDvPVuGMO5o,6099 vllm/reasoning/seedoss_reasoning_parser.py,sha256=Oz4zSgr6seYtssSyDKsFi7J-b9kTy-IcuC01HNJkv_c,832 vllm/reasoning/step3_reasoning_parser.py,sha256=PD92XgxuGaH7yJf3pQLt6DTXeZkekorTIv9bY4SDPlQ,4403 vllm/reasoning/step3p5_reasoning_parser.py,sha256=RA4AmtxGLicQL_V2i3duPW4_fMRUCz0rzF8F6rKnaJE,7227 @@ -2506,6 +2552,7 @@ vllm/renderers/__pycache__/grok2.cpython-312.pyc,, vllm/renderers/__pycache__/hf.cpython-312.pyc,, vllm/renderers/__pycache__/mistral.cpython-312.pyc,, vllm/renderers/__pycache__/params.cpython-312.pyc,, +vllm/renderers/__pycache__/qwen_vl.cpython-312.pyc,, vllm/renderers/__pycache__/registry.cpython-312.pyc,, vllm/renderers/__pycache__/terratorch.cpython-312.pyc,, vllm/renderers/base.py,sha256=fP7HzLgm-oMDOr5vbPDtxDpvIvOZZJ2NNx30VZmkUCw,26149 @@ -2521,9 +2568,10 @@ vllm/renderers/inputs/preprocess.py,sha256=OgO3sBDbaKcFk95rkQJZ1Kuq0cfoGmNkgH5HE vllm/renderers/inputs/tokenize.py,sha256=p1SGoV6Z_BdoW1ZclSNhzHm8aH4gFzOLv0m799XqKe0,1393 vllm/renderers/mistral.py,sha256=5xJEpFTCaJgxrvYti3eB2yKtyXGEcgsdANx-bILpbGs,4337 vllm/renderers/params.py,sha256=KZYXBx9JRuxhD1EVxizfgnRzYvhZnZgWsgg2qlKY7o8,14128 -vllm/renderers/registry.py,sha256=eOVvDfFdt7mQwlwvjqruBTeyzFVDoodf8Mu1P02XbAg,2850 +vllm/renderers/qwen_vl.py,sha256=nDMh0Siz9z85FCJl2qDZfWhEcxTruNuv0cMzC6VMLc8,871 +vllm/renderers/registry.py,sha256=wzx4GDoucNRHKGFRqfJnuM6Li6FP98Gz61J7wEtftEI,2896 vllm/renderers/terratorch.py,sha256=QNmRVC-dddsBXE3NV6v_rBK-u8aVamBRh54bkBLMNc4,2278 -vllm/sampling_params.py,sha256=_qcGtMojZzw74PGt3ZuhItnus2e2yOG0rzAIuzC8EfI,35852 +vllm/sampling_params.py,sha256=QUuPyJnCrWmXvKP6CdQe1KLgkCr5n_NuoDJnDpB5y2o,37465 vllm/scalar_type.py,sha256=s8CZpZru6-4vEexIofKsJzhw5OBjVuvueG6b6yb6A94,12573 vllm/scripts.py,sha256=d2RU3e6Qi_xY7QUD44uqsdz8DdBR-_zQnFf4rYUW_iw,504 
vllm/sequence.py,sha256=q10MDnLiFsnhKuuKBeIrQ6Pyj9LpGzVcHKnm-1uK0YI,2173 @@ -2533,7 +2581,7 @@ vllm/third_party/__pycache__/__init__.cpython-312.pyc,, vllm/third_party/__pycache__/pynvml.cpython-312.pyc,, vllm/third_party/flashmla/__init__.py,sha256=eMVQe_qZ6bEFHduSNgQ7ZAP3AvV3oMbEmyy_A-s-RSc,31 vllm/third_party/flashmla/__pycache__/__init__.cpython-312.pyc,, -vllm/third_party/pynvml.py,sha256=HRQEbE5ZB-AgMIySG4j24hjlEgSGLrNvHJUq7UawCfA,234653 +vllm/third_party/pynvml.py,sha256=7ylxuC4d6eXttiQdX9CjOAVWe2sqEEtIztGQKLCf_RE,234646 vllm/tokenizers/__init__.py,sha256=F3xPUuvHJZfUCeLImLWYyDBZa-e_1aMlj37SfUZHr_A,418 vllm/tokenizers/__pycache__/__init__.cpython-312.pyc,, vllm/tokenizers/__pycache__/deepseek_v32.cpython-312.pyc,, @@ -2543,15 +2591,17 @@ vllm/tokenizers/__pycache__/grok2.cpython-312.pyc,, vllm/tokenizers/__pycache__/hf.cpython-312.pyc,, vllm/tokenizers/__pycache__/mistral.cpython-312.pyc,, vllm/tokenizers/__pycache__/protocol.cpython-312.pyc,, +vllm/tokenizers/__pycache__/qwen_vl.cpython-312.pyc,, vllm/tokenizers/__pycache__/registry.cpython-312.pyc,, -vllm/tokenizers/deepseek_v32.py,sha256=b0nnVfswcvF-IVct3zOnVNuOIf0m18bBQI2zDaIvzHM,3213 +vllm/tokenizers/deepseek_v32.py,sha256=nNR_ToCswY76ZhuYX2UUxrUHqO2uXICMs6YM94cl1Uo,3221 vllm/tokenizers/deepseek_v32_encoding.py,sha256=3jX1CQW-tmoh-YB8w2uxm_jWdo3eUHcPtjnnQyvjVLU,15812 vllm/tokenizers/detokenizer_utils.py,sha256=nyIW3eO0H7ovSUyyHt4mGbmmKOr3Ip-jAgK-3BN5KXQ,7821 vllm/tokenizers/grok2.py,sha256=x5j44_g31paTCXVp4WvccBRGA5XnqecwrXjDyk4wKpk,14522 vllm/tokenizers/hf.py,sha256=9GIQPaNBIgUkAqmWwSdRLR4cXQqiX8Ewz5fUCCqS9vg,4493 vllm/tokenizers/mistral.py,sha256=kntyYVaIafQ3N5fEn8PAmxIbSIQVV-Z2d_8bHgDqO0Q,21738 vllm/tokenizers/protocol.py,sha256=_LfaT2d9M6AEu-nPJW82pdw4b8Byk-T6Ps2viGAh1EI,3291 -vllm/tokenizers/registry.py,sha256=5KEg0piy6hX9wi0OGxvmTNPIQq0b-ykgOdes6gajDT4,7991 +vllm/tokenizers/qwen_vl.py,sha256=cDjvLkjkfNgfkBjM2SbFQmnK-WT7Z2t5yxtH1MMu2uc,2165 +vllm/tokenizers/registry.py,sha256=83OHhZhAmMKu7HIQjAY-rI2jZLc42J-z62Hqcv5O-_g,8177 vllm/tool_parsers/__init__.py,sha256=TvPqzR3Wlr2qzbX50fmzbt22zxN9_bGOwHSORuh7Ebc,3666 vllm/tool_parsers/__pycache__/__init__.cpython-312.pyc,, vllm/tool_parsers/__pycache__/abstract_tool_parser.cpython-312.pyc,, @@ -2613,12 +2663,12 @@ vllm/tool_parsers/olmo3_tool_parser.py,sha256=-pGGNRErh_F8AJTCUo4uOTitEDE1hpoDai vllm/tool_parsers/openai_tool_parser.py,sha256=NnO1_6IwnAgmbsPiHZ0-Et3e1EW-WyKrBoRC_Or5K_M,4279 vllm/tool_parsers/phi4mini_tool_parser.py,sha256=5QFCzfRrwpUGuwGaXrOakOhwpsFqvQdyXQF-q_U99Bo,4084 vllm/tool_parsers/pythonic_tool_parser.py,sha256=9C8fOprJgamB_8g0rHsHfgBhrU_KW7BQRv6jmsiY4tQ,12269 -vllm/tool_parsers/qwen3coder_tool_parser.py,sha256=MBPZKQoCpQB4EMtn_JmXevEZ4-bxWIrfs3dqSwXKEYQ,32804 +vllm/tool_parsers/qwen3coder_tool_parser.py,sha256=zuu4Nml-xRndFBcfZSVcVCFWt2VklPZ9XyKMvCYq32Q,29185 vllm/tool_parsers/qwen3xml_tool_parser.py,sha256=0cRnsspFfki7fPwrGFD8RD1wAfTyl0fWtuK-JVMYD8A,53857 vllm/tool_parsers/seed_oss_tool_parser.py,sha256=gEQbkx3uY54FBcXK513d9kWPurkuGIFPUMnpOcTHO68,31404 vllm/tool_parsers/step3_tool_parser.py,sha256=wakZ3uuPcW8IOFe0tG3SONYtfEF55LTKhy3DA_y6098,12159 vllm/tool_parsers/step3p5_tool_parser.py,sha256=zmSzhP9207A1cEEodvfviFihTBePIWDpPsjQwKdvbn4,61867 -vllm/tool_parsers/utils.py,sha256=PaTL3WKpzrZwU7wJHqI-XlO9McJR4LGL4ldFk1ukMC0,7460 +vllm/tool_parsers/utils.py,sha256=TDnPFr7xYcLQuEoAfEMIYkZs5oSu_lkm-gCkezspxZI,7087 vllm/tool_parsers/xlam_tool_parser.py,sha256=1xOuZ1IRvsvbhTwLx2WXejI2ylfiFrDYQLy7N-SAaHg,24689 
vllm/tracing/__init__.py,sha256=LYFnkXkGpQBWKbAY95Am9yKeOmoi-iRmM4WQhoju8mE,4308 vllm/tracing/__pycache__/__init__.cpython-312.pyc,, @@ -2652,7 +2702,9 @@ vllm/transformers_utils/chat_templates/template_fuyu.jinja,sha256=hzdsPgeUMaZnd5 vllm/transformers_utils/chat_templates/template_minicpmv45.jinja,sha256=tNvJgf7S0je1GqFxPUQ-q5nSGwXoINdqW3-VOFNekh8,4231 vllm/transformers_utils/config.py,sha256=PP_sZlNDfxx_yHTIJYITVv_Nzg4v5SjYzfmiix8MWhw,42310 vllm/transformers_utils/config_parser_base.py,sha256=qSyL1AYBvvnQTII-lS9_N602OJ_Huf5BkUZYBhQEFtU,523 -vllm/transformers_utils/configs/__init__.py,sha256=MCljcKMzfFhdTPLIaOTMl3IEMKt0mlEq75stsaRD1J0,5482 +vllm/transformers_utils/configs/AXK1.py,sha256=7KL6K89TunBNrye7TKtNH76tbeT67m4lvP97wQttqF0,10995 +vllm/transformers_utils/configs/__init__.py,sha256=EbMCeMBZf-g46jMnb1C-QyY1gVLGk8NgaOttlI5bnQ8,5558 +vllm/transformers_utils/configs/__pycache__/AXK1.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/__init__.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/afmoe.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/arctic.cpython-312.pyc,, @@ -2663,6 +2715,7 @@ vllm/transformers_utils/configs/__pycache__/colqwen3.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/deepseek_vl2.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/dotsocr.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/eagle.cpython-312.pyc,, +vllm/transformers_utils/configs/__pycache__/extract_hidden_states.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/falcon.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/flex_olmo.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/funaudiochat.cpython-312.pyc,, @@ -2682,6 +2735,7 @@ vllm/transformers_utils/configs/__pycache__/nemotron.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/nemotron_h.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/olmo3.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/ovis.cpython-312.pyc,, +vllm/transformers_utils/configs/__pycache__/parakeet.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/qwen3_5.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/qwen3_5_moe.cpython-312.pyc,, vllm/transformers_utils/configs/__pycache__/qwen3_asr.cpython-312.pyc,, @@ -2700,6 +2754,7 @@ vllm/transformers_utils/configs/colqwen3.py,sha256=IUckSjIN7qxmKVDEttyXK91jP-6iq vllm/transformers_utils/configs/deepseek_vl2.py,sha256=I4Ge43Aqaw1cS1M7gps9utrM7XwmXj_ktbxRcHAeoKI,4075 vllm/transformers_utils/configs/dotsocr.py,sha256=_mxuA6xDpO6N2-FD4R5XVqLDP79D8i0OAZcPvqUCSdU,2445 vllm/transformers_utils/configs/eagle.py,sha256=R8LT924n_tTE8LkfYcENCv_UVvLmsIWKl6th6okyOWg,3012 +vllm/transformers_utils/configs/extract_hidden_states.py,sha256=84LXbdhcsDSJQ1E3bKFgVgOPmKuDiz-7RnUPgyh3FQE,1749 vllm/transformers_utils/configs/falcon.py,sha256=XNrg2x08UDZngZqtv5T0CiEFZq7qhoJPWdEYnI64e_Q,2937 vllm/transformers_utils/configs/flex_olmo.py,sha256=NInzKT7pO5w6CKxxaaEAXW_YnN6f1Sv3hDWQYhPL7ko,3157 vllm/transformers_utils/configs/funaudiochat.py,sha256=F3R7akXupVr2KQTh9vycB6S8Z90oEFWDiihyLx2qr6o,4695 @@ -2719,6 +2774,7 @@ vllm/transformers_utils/configs/nemotron.py,sha256=EAYX0byXwgziKfac4nL2cFtaatPyD vllm/transformers_utils/configs/nemotron_h.py,sha256=fL8DudDBuHNAdBAdkXjP1XhbGpGRFIOW98qXH3zHCdc,13414 vllm/transformers_utils/configs/olmo3.py,sha256=pnjjin2JhHw7lA0JzpsVaRiF4l3-jDwj5_DrroEEVTU,3026 vllm/transformers_utils/configs/ovis.py,sha256=zjfzz8-btPauEhKwQm9dRqmsZ8jakSaVCqqZ2rJZv2g,7535 
+vllm/transformers_utils/configs/parakeet.py,sha256=B2bzUzRWpUqojQhx4_ZK9bBrymoXawnTbfue9iPyhSg,1658 vllm/transformers_utils/configs/qwen3_5.py,sha256=d4NwouDihpFiCkOGWGPxq_5rPjWhT4C_LeUZ9SnJiW4,6996 vllm/transformers_utils/configs/qwen3_5_moe.py,sha256=t85XK6BlTnhevVDiUr55xlr3oIh0X0_0AVl78EwcL9Q,7661 vllm/transformers_utils/configs/qwen3_asr.py,sha256=Vv0fHqk53_R1M7BSli0pQ3XoVILtMQ5IzK1ALPVpmBg,20563 @@ -2736,13 +2792,14 @@ vllm/transformers_utils/configs/tarsier2.py,sha256=ydiT9TPSYbSL58QHX8y2KwtboQ0gt vllm/transformers_utils/configs/ultravox.py,sha256=4DwreIyyKxSwyie6av42UJof5wejQHeDLWtcA5jBPO0,4944 vllm/transformers_utils/dynamic_module.py,sha256=Gh0_B_IHLEedgKAHPNUZcQo2F_dAxicKo6luvfTo7zw,2038 vllm/transformers_utils/gguf_utils.py,sha256=DHxZAHenR0iZAwxRaK2KuzhUSZXCgJ-wgAw7SF9OgtI,10279 -vllm/transformers_utils/model_arch_config_convertor.py,sha256=MKieLiPi4Se4DjICUz4jq6jhg5kbBIRhbSdFol_Yiyw,17803 -vllm/transformers_utils/processor.py,sha256=QEIBQBoYMr-zgNnj2aGHyhmtC0I93ZbD1NlEnE0a8kw,17747 -vllm/transformers_utils/processors/__init__.py,sha256=MHOjadEoO3x803thzujcfC-raEWzofV5uCC7EQIQaXg,1067 +vllm/transformers_utils/model_arch_config_convertor.py,sha256=6yGeq38_M65ZyWvUJvVlkPJg8QSdlJ2OkWmxqBTFqek,17980 +vllm/transformers_utils/processor.py,sha256=UF5zLB6CwEYpEFT5tGGwqIvRguVQeNLBOBgoB25QUR4,17576 +vllm/transformers_utils/processors/__init__.py,sha256=TDMetNekay1ikxBwruohrq1aMXi2y5l_OnzgakelPJY,1194 vllm/transformers_utils/processors/__pycache__/__init__.cpython-312.pyc,, vllm/transformers_utils/processors/__pycache__/bagel.cpython-312.pyc,, vllm/transformers_utils/processors/__pycache__/deepseek_ocr.cpython-312.pyc,, vllm/transformers_utils/processors/__pycache__/deepseek_vl2.cpython-312.pyc,, +vllm/transformers_utils/processors/__pycache__/fireredasr2_processor.cpython-312.pyc,, vllm/transformers_utils/processors/__pycache__/funasr_processor.cpython-312.pyc,, vllm/transformers_utils/processors/__pycache__/hunyuan_vl.cpython-312.pyc,, vllm/transformers_utils/processors/__pycache__/hunyuan_vl_image.cpython-312.pyc,, @@ -2752,20 +2809,23 @@ vllm/transformers_utils/processors/__pycache__/qwen3_asr.cpython-312.pyc,, vllm/transformers_utils/processors/bagel.py,sha256=zi9b5HjfDpnd7a7GS-ANquWCft-3KcMrMuVfcVZSJuA,2745 vllm/transformers_utils/processors/deepseek_ocr.py,sha256=IG7NDyVUoZa5lNsMebYz5FBPzN-wy4o0LsKgYwR2r_8,15899 vllm/transformers_utils/processors/deepseek_vl2.py,sha256=YWmA16RtLAOE_eMHTJwHJSFoRFlBU3wD7cR3ZXP9JyQ,15121 +vllm/transformers_utils/processors/fireredasr2_processor.py,sha256=NvYz59sEnOEt93IGTo7LETY2WvyvKNohlDjecjBRMos,13047 vllm/transformers_utils/processors/funasr_processor.py,sha256=cUe7FyL5vwQHIIYT2Y0U16kUk0bXJWuX62CZVMb8Y_Q,18547 vllm/transformers_utils/processors/hunyuan_vl.py,sha256=4Bo9Ge6OXn4_KnzVArwFr37gdMNdPI_EpZhufv9Nnnk,9435 vllm/transformers_utils/processors/hunyuan_vl_image.py,sha256=KeauATCER1RdedquOh8bvXPKk9GT6eQAG8avSRqSdz8,22292 vllm/transformers_utils/processors/ovis.py,sha256=X7F5oc4jep7cV_G2Zj9r7uJE3QOccZCbeB9NsK2mZpg,19608 vllm/transformers_utils/processors/ovis2_5.py,sha256=LbudomwDaXccixhZcYT7V-J5eWYNbEnANn4Q_69RzbE,20083 vllm/transformers_utils/processors/qwen3_asr.py,sha256=w82qQ7wMbWe97bGoqrGjBZYpkXMr5BQ5gVAki9_1_Rk,9095 -vllm/transformers_utils/repo_utils.py,sha256=Ktp76jBwmBN8YUDxlc8adZ6IYQOHVgD0IyTt1gE5ClE,9005 +vllm/transformers_utils/repo_utils.py,sha256=vTY2zuYnYiQGsjPsRu4sl4ORMyy6XjV6o047Af7mi7M,9015 vllm/transformers_utils/runai_utils.py,sha256=SUDsljkSzlbmLhGzqpgykG4kgNVaVOxe2g4-kmXm4z0,3154 
vllm/transformers_utils/s3_utils.py,sha256=3Vk1LybSbByVKEZZzOly7d3ggdhMdPRwS2wvqxUOpcE,2770 vllm/transformers_utils/tokenizer.py,sha256=v4LtF_1D5_aqj_1cpv9d53FNsVUAWqtX4-osRvkU-rs,665 vllm/transformers_utils/utils.py,sha256=H5CxGzLlIpCA-Zd8lkKlghMcFgeHopgYBA_9p7Rrle4,3078 vllm/triton_utils/__init__.py,sha256=FY20qv1flNw_TDjM9j544voqnflClkYGSOEF5lPwv0M,567 vllm/triton_utils/__pycache__/__init__.cpython-312.pyc,, +vllm/triton_utils/__pycache__/allocation.cpython-312.pyc,, vllm/triton_utils/__pycache__/importing.cpython-312.pyc,, +vllm/triton_utils/allocation.py,sha256=oGBpkpw7qD5MH7ZNr5jb7y0IJXUPYQw6v_7Ya7bgDhA,376 vllm/triton_utils/importing.py,sha256=CcaJ7kpP7bKP3yfuR_1w_vNRBxiZk8qNdcatjsvoAYI,3589 vllm/usage/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/usage/__pycache__/__init__.cpython-312.pyc,, @@ -2806,14 +2866,14 @@ vllm/utils/async_utils.py,sha256=8QiEviVKKi3NQ_OHRZVQqrPk4L-K_rAqFV77gZDQCHU,113 vllm/utils/cache.py,sha256=p6Cvrul9e1fOuw4h9-n74qmY_lHMQyS0BqhHXQqN-_Y,6206 vllm/utils/collection_utils.py,sha256=UHEf5oYJtCzYI8bEW6K5J-zLhGTF8gnmIzBnFhwiFVs,3592 vllm/utils/counter.py,sha256=6Xa0WjHey2LwyPJkGfYa37mgOsQfcHJF9kRypIcaHDQ,1165 -vllm/utils/deep_gemm.py,sha256=Kl9l_pcLZIYmuECIpA-XVG6D_38L3nXoBACWIXDRerU,14925 -vllm/utils/flashinfer.py,sha256=Y820S37DtbPacSFvrXfxxa8F4ty1L8l7OBqL0AQq1jE,24792 +vllm/utils/deep_gemm.py,sha256=gzpu5Yzh7zN1qLP7NFar4d024HF3xUfhxkEwbjNKVbQ,19593 +vllm/utils/flashinfer.py,sha256=apQmb9MkQtVB75qpZjKkeOEwPLUKKfMkLvNorw3BBRc,24831 vllm/utils/func_utils.py,sha256=pLXKoub-DlGrxZRwliCGmuMy2eiEHrglZPEQSbvPjKg,7706 vllm/utils/gc_utils.py,sha256=zMuLZmEcjECblk30kHjKGwLdOHANND5vopi17x_IRjI,4970 vllm/utils/hashing.py,sha256=ibS9K9JGpb_c8A12bqOKgYI3auz9a2t9XEfZaxVZJD4,3618 -vllm/utils/import_utils.py,sha256=iv2ycHmNYKHljx3OpNxtoIwO6DE10HaJ388a-H6isTg,14154 +vllm/utils/import_utils.py,sha256=cbPEomdgp8bfCu5ueAoNb19DwXCHc8lnOImpIiXpXIE,14021 vllm/utils/jsontree.py,sha256=lktXNEOq0w1qqAGTPgpJ5eDogjetHDC15Hq___2OLmg,3755 -vllm/utils/math_utils.py,sha256=j0iLxkURCK07Y3Kc9zr5NGWn_oKK57ijNKKB2wvyD7E,890 +vllm/utils/math_utils.py,sha256=3HHFjy16D9iPqbPd_4xdBgNQrgXLZ-O3TcZHqZEFaL4,1042 vllm/utils/mem_constants.py,sha256=DerNf9-iEKE2kx8c6ufmLY23mn7vV3K1vskiAw3SWJU,390 vllm/utils/mem_utils.py,sha256=fUzur9cxpAP1E5Jf0mGQ1Q9Xtv5TIvCJDjIuu1dI43M,10285 vllm/utils/mistral.py,sha256=aEBEQhLIXcgkv8RaDEi6mSruMq9_5-sgflgpTP4dQ04,1029 @@ -2825,9 +2885,9 @@ vllm/utils/print_utils.py,sha256=M_6PwgPusaGHGQLPo9e3TovYSAdpW5HoNsyM4ZZxqYY,302 vllm/utils/profiling.py,sha256=UWku3rWFiMP4UclIokbOxF1fh5xVuwiOc4vjsq2Fdo4,1469 vllm/utils/registry.py,sha256=gWRZ2XRwU-YapctL5A05O1bG2Vgm-QBcPrxjUPRugFU,1555 vllm/utils/serial_utils.py,sha256=W_2bYeQugj6AlLZLGHgsm0TXsfEz1kDxe0ecvIwfj0Q,2805 -vllm/utils/system_utils.py,sha256=3T02U9mDJzjxQ70tbRxnZLUtcbYmUpktzps-fDYXj54,9241 +vllm/utils/system_utils.py,sha256=sKIPCcmjTXHjaMhCGpalqPBzLJ6097C4WE4-aDl93hs,9767 vllm/utils/tensor_schema.py,sha256=Q2GzQ-CzlvQjA9qvZBcP8f3layD454s7c5nRt_FeZWU,9140 -vllm/utils/torch_utils.py,sha256=TyjxKoyoWZh4dS2-XmKCjNwjZZK5LXH_zZOSMIjGFzg,26997 +vllm/utils/torch_utils.py,sha256=ymLP-Yv-tk3wvoUfG19z0Y0HGvQR5_Ei_NtnKOxIXeY,27945 vllm/utils/tqdm_utils.py,sha256=JW5ymX2EnmIix7RLymsiDPFasapB6udhbSurNvXTlXM,831 vllm/v1/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/__pycache__/__init__.cpython-312.pyc,, @@ -2841,7 +2901,7 @@ vllm/v1/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU vllm/v1/attention/__pycache__/__init__.cpython-312.pyc,, 
vllm/v1/attention/__pycache__/backend.cpython-312.pyc,, vllm/v1/attention/__pycache__/selector.cpython-312.pyc,, -vllm/v1/attention/backend.py,sha256=muIkHtECY0qPILG3MO4DSzIwKx2_lGgkl_XzGDKDTB4,30672 +vllm/v1/attention/backend.py,sha256=P6tojKvUNwYyZRzt7n_S-9aRD3S-x8BsRYfaTy7ZbO4,32661 vllm/v1/attention/backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/attention/backends/__pycache__/__init__.cpython-312.pyc,, vllm/v1/attention/backends/__pycache__/cpu_attn.cpython-312.pyc,, @@ -2864,16 +2924,16 @@ vllm/v1/attention/backends/__pycache__/tree_attn.cpython-312.pyc,, vllm/v1/attention/backends/__pycache__/triton_attn.cpython-312.pyc,, vllm/v1/attention/backends/__pycache__/utils.cpython-312.pyc,, vllm/v1/attention/backends/cpu_attn.py,sha256=iYyIvrNtpSBcKrIplpde1u49-KcmqkNTTHJGCK76oA8,18338 -vllm/v1/attention/backends/fa_utils.py,sha256=JUYQ594bQYVZGUU6Pj_jOeWt15t_yhQF7r67abH0rQ0,6110 -vllm/v1/attention/backends/flash_attn.py,sha256=v6wtmZFVPTl__VjipyEBqI4SKwNWKEOwLmrlXSAEN1Y,51112 +vllm/v1/attention/backends/fa_utils.py,sha256=qeo-t4CkCimRDBlo__qumUeAsKzsv3DT4LDk4fQJY8w,7318 +vllm/v1/attention/backends/flash_attn.py,sha256=ry_b2avMWWD9SbNnV-uyQFCzCWxyCRwg2gEKxGZNbyA,60988 vllm/v1/attention/backends/flash_attn_diffkv.py,sha256=8EU5rcDQZDtc8znzCvrL8mc5AFWFbQbPUDsPIap9DgE,11068 -vllm/v1/attention/backends/flashinfer.py,sha256=NAP38iPLg3gU6kOAJpxDZrbUR4jgZ5XwkumwuA8Qo5I,70745 +vllm/v1/attention/backends/flashinfer.py,sha256=49NpSoqLCwY_765DXN4Gx50gH2_58vHRT4h-tYbTVEE,69003 vllm/v1/attention/backends/flex_attention.py,sha256=0Hk1rJaq04PnApS8y23K0n_meHwd5SctGOxgIQ0ewfc,39786 vllm/v1/attention/backends/gdn_attn.py,sha256=93921EKaiUJ_ThvRVayxt00tRY6hIHCDyQgnjRql1CM,17327 vllm/v1/attention/backends/linear_attn.py,sha256=5CzIcHkvZMu-2Nkf2kaTWl3ZRuAWIOJGIqwxtHpvzM0,2722 -vllm/v1/attention/backends/mamba1_attn.py,sha256=b35cfbixMEN6XEETvY3sMpXJbVBykJJ7o9AIyTIc7Do,811 -vllm/v1/attention/backends/mamba2_attn.py,sha256=xCB_VlCKDHP-twOhbuUdZuTK1cP7-dLKVkWQJBMwtdc,9885 -vllm/v1/attention/backends/mamba_attn.py,sha256=QBoMLjkGgK_T3wWmSRh-TDf_ly0jnGdJyNbuS_kngFc,18401 +vllm/v1/attention/backends/mamba1_attn.py,sha256=rP-FJ3xGx1QHO6tcjHn57vnZHqc5Yk3Q4TsjDxIMdCo,1743 +vllm/v1/attention/backends/mamba2_attn.py,sha256=aY0ksoNS_hCdFr2vwLfjtjStt6BCjm77KXwstUJfZj4,5692 +vllm/v1/attention/backends/mamba_attn.py,sha256=vpu7bKD0xbsmFKpZhS211qL5FyWVJx68ge7OFGcrnPk,23072 vllm/v1/attention/backends/mla/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/attention/backends/mla/__pycache__/__init__.cpython-312.pyc,, vllm/v1/attention/backends/mla/__pycache__/aiter_triton_mla.cpython-312.pyc,, @@ -2890,20 +2950,20 @@ vllm/v1/attention/backends/mla/__pycache__/sparse_utils.cpython-312.pyc,, vllm/v1/attention/backends/mla/__pycache__/triton_mla.cpython-312.pyc,, vllm/v1/attention/backends/mla/aiter_triton_mla.py,sha256=7Q4oXNrgagGRUMQcyVizwLXOsOTymdUxjRyQ4sI5U5A,2000 vllm/v1/attention/backends/mla/cutlass_mla.py,sha256=U5piAUug73lO6ioNkDe48gD0gQNG1q5p7ikcXFgMmo0,8799 -vllm/v1/attention/backends/mla/flashattn_mla.py,sha256=RUh_zQxwPTSeYLQyWfc9O-Lnkvk98_OljSsOlJtVHZI,12805 +vllm/v1/attention/backends/mla/flashattn_mla.py,sha256=AJEMZ4_dcy-MN_naxGwVcoBcgg5kIupb9vpNS_nzFG0,12974 vllm/v1/attention/backends/mla/flashinfer_mla.py,sha256=RnAiv3mmxEZ4spmkqPNcARcQN-5EUqR5pe-1_sqpKRc,6591 vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py,sha256=L9kd6cixFsX-bwzUT_fhS4WxJMfPqVilK_0KTFvahPo,11872 
-vllm/v1/attention/backends/mla/flashmla.py,sha256=ZgpQjV1kltnILiyOGPwPG5jTXEpotue2H4KPpx3c6G4,11002 -vllm/v1/attention/backends/mla/flashmla_sparse.py,sha256=UkyZDdbijKAcJozt_BZ5MeL2XwDWAW7rBR-RywuuX2Q,33493 -vllm/v1/attention/backends/mla/indexer.py,sha256=QnxZDvmRnTyolQYA_ZPP5Zpo0fSmC3DE_BdNoSOWY8I,14592 -vllm/v1/attention/backends/mla/rocm_aiter_mla.py,sha256=lsCZA_CKMXQft__BMA6dVbK2dTWuwtTAqMXRQOvickg,9871 +vllm/v1/attention/backends/mla/flashmla.py,sha256=bNDNcVPaWPFOA7yVxCtDboo0fAVkuDKYdopgxb7_i5E,11171 +vllm/v1/attention/backends/mla/flashmla_sparse.py,sha256=g2iNW_VqDaA1_vQGjhSRYyIpox9dM_z6dy8ElSgZsyA,36349 +vllm/v1/attention/backends/mla/indexer.py,sha256=YRALm24dJuR3z7plkjuun9aoddITD1mYL16lo9o6duQ,21150 +vllm/v1/attention/backends/mla/rocm_aiter_mla.py,sha256=ppnGFEhOslTchdZN_81-7gW6Q1QJFStAayeR1_JL8-w,10040 vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py,sha256=A-psJkVpEhf6rF_g_EHso0hR6jWMO415e52Z96lTWTw,12857 vllm/v1/attention/backends/mla/sparse_utils.py,sha256=Niep6oTN_8_CyrqTvO_JPByN-7cRoeY7SvnYekSYumU,7124 -vllm/v1/attention/backends/mla/triton_mla.py,sha256=4ChkFVZNMuyhGe87gNSBB7em6xMffJ9WS2Cj3Vlscdc,7038 +vllm/v1/attention/backends/mla/triton_mla.py,sha256=62lvVPNyWlDBYb-lXHonDcNYC_ZmlTVVZfUYR8TucgA,7616 vllm/v1/attention/backends/registry.py,sha256=3nIonTWV4rJx-rI_LIDJUa0fjVaev7N5n3HzSJ1BdQw,9695 vllm/v1/attention/backends/rocm_aiter_fa.py,sha256=VD-JfyOSt2jFfP3_lscGG7gPmRlDIqMpZJ_NfixZJvc,53117 -vllm/v1/attention/backends/rocm_aiter_unified_attn.py,sha256=UU14VRIMXKiNEupXHveZuBIA_MgtpSgbT1VOgGQJuBE,8056 -vllm/v1/attention/backends/rocm_attn.py,sha256=R6i2CkKwPi9MMiTLvjoShfTzDDPFpI1HjKGJ8VrJTwI,16626 +vllm/v1/attention/backends/rocm_aiter_unified_attn.py,sha256=ObOdNzBZ-j5RxsUMuwbVX1Re4_fs2-jN-nKBweYi1IU,9329 +vllm/v1/attention/backends/rocm_attn.py,sha256=vllPGdybcBMF13lR1tgXTZ-wrtkE8N-ZUiP1Jjp_DT0,19124 vllm/v1/attention/backends/short_conv_attn.py,sha256=MEf-oDqbQl62cWs8vmKeSo5_ZcOuNUjW4PC724iGIns,835 vllm/v1/attention/backends/tree_attn.py,sha256=ctIymZevnFTX2VuZCrpa6Qce2RqpmK_iWE8hsf8ecSg,15716 vllm/v1/attention/backends/triton_attn.py,sha256=rymjUf9jVPJhYd2PRLH2dX_mpWBfZt61P8MIpMpOi5o,23275 @@ -2925,7 +2985,7 @@ vllm/v1/attention/ops/__pycache__/triton_unified_attention.cpython-312.pyc,, vllm/v1/attention/ops/__pycache__/vit_attn_wrappers.cpython-312.pyc,, vllm/v1/attention/ops/chunked_prefill_paged_decode.py,sha256=QJOLdArCSXYIhTX6s-HoVbYa3Yol6oR3zLOmNNDEBrA,15750 vllm/v1/attention/ops/common.py,sha256=J_g0iiIkHoqwaxgVe8IGTRNluldoeriDRRg6cmLYA9s,13602 -vllm/v1/attention/ops/flashmla.py,sha256=bl_rqebAj-ip7_5M-U_8bH2EpE8DlkeD4t-72mQZc_c,5119 +vllm/v1/attention/ops/flashmla.py,sha256=N6aXtzwa_u3TBTXClwzoiPGiMx9GKTQ4XIRDq8riTYI,6025 vllm/v1/attention/ops/merge_attn_states.py,sha256=gwJoIzgUPGqEr20vWSF8ByEUPxkBDTSeUd5whJPvAXM,1658 vllm/v1/attention/ops/paged_attn.py,sha256=AWqV-JLsdRyGs3vh4SFxutnwtAdrvyIZ_zNvTpxQNlM,1437 vllm/v1/attention/ops/prefix_prefill.py,sha256=_p33Jl6zj3PqWkq7klbBo6pS_RJvnImBnaA9WUzAaxM,27286 @@ -2935,7 +2995,7 @@ vllm/v1/attention/ops/triton_merge_attn_states.py,sha256=4fOT2PY6AocmhvwAYt5Zu1P vllm/v1/attention/ops/triton_prefill_attention.py,sha256=jO4m6SCLKLk9GYc5zA2ki8qxWoYG34RR-KmYJMwmPbk,7585 vllm/v1/attention/ops/triton_reshape_and_cache_flash.py,sha256=l23FcUrpKLfL1CAjtRKRE7jDotvOm7bGhh4FTBy84cU,12771 vllm/v1/attention/ops/triton_unified_attention.py,sha256=jo05PE1VFUfeOXhZRirXw3UCMIQUWOZY1zOB2_P1kAU,39926 -vllm/v1/attention/ops/vit_attn_wrappers.py,sha256=H7TzVtBZeFed44ec_NJOF2eGfJOVGw7VIF8VJWxcvyw,7404 
+vllm/v1/attention/ops/vit_attn_wrappers.py,sha256=mTc6aReCGePTbFltpxQXba1E9YIOY46-8fgQeX-Rxnc,10054 vllm/v1/attention/selector.py,sha256=6YRzCyWndCuO773kxQnjgFRjkpU6OJAFfL1LDuY-jA4,4765 vllm/v1/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/core/__pycache__/__init__.cpython-312.pyc,, @@ -2949,9 +3009,9 @@ vllm/v1/core/__pycache__/single_type_kv_cache_manager.cpython-312.pyc,, vllm/v1/core/block_pool.py,sha256=LqkyYorGoGYH3Uje16kI9wW1XD62gBIyJk7i2xQsbNk,20030 vllm/v1/core/encoder_cache_manager.py,sha256=yG0-WuDAGL0ljoYtxZfbW1ulyaY93k0Z68y8m2HP5DI,15537 vllm/v1/core/kv_cache_coordinator.py,sha256=jwTjrG1gjcMC0Ha5fEm9E_geGwUGNShYCyyUfCvMkeg,22399 -vllm/v1/core/kv_cache_manager.py,sha256=5owMv2aer3Sj_p6iNLc-MN2pRiVZESmdOo429zBpin4,20597 +vllm/v1/core/kv_cache_manager.py,sha256=9byU-exEalqxVO5ZvjoAmy2YlXuZ9p3bx2BFm87nQlI,22107 vllm/v1/core/kv_cache_metrics.py,sha256=7-3sBP7_XwSm1RxsRkKQtSq4wknWUSGpq58TcNa2f0c,3138 -vllm/v1/core/kv_cache_utils.py,sha256=ZkLJ_QXRbjpJIuAo2NtW2oaR0zc739pc0rm4HE3Bcqo,67368 +vllm/v1/core/kv_cache_utils.py,sha256=SOMvwuYKJcbrmBwTPNlBKFTQwXTPlwq1GFbyc78ady8,68571 vllm/v1/core/sched/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/core/sched/__pycache__/__init__.cpython-312.pyc,, vllm/v1/core/sched/__pycache__/async_scheduler.cpython-312.pyc,, @@ -2962,13 +3022,13 @@ vllm/v1/core/sched/__pycache__/scheduler.cpython-312.pyc,, vllm/v1/core/sched/__pycache__/utils.cpython-312.pyc,, vllm/v1/core/sched/async_scheduler.py,sha256=TfN-iEL7BtXM47xBTpQjoJgI6TNqlqJVMN0icPB40JM,2659 vllm/v1/core/sched/interface.py,sha256=W-rtDuHPA6D3YndYfY97-e7njS0fAkhxNEqMwH60520,9077 -vllm/v1/core/sched/output.py,sha256=_ynXfniAPeJ3dfOj97D1mUAqjhmfvEs8kDIYRrW4_aA,9828 +vllm/v1/core/sched/output.py,sha256=X6_GOOUpT_AmpbAZJ_s0JPughfBIC3WKgdvZo_LBXfo,10129 vllm/v1/core/sched/request_queue.py,sha256=S42Tjl-4FS_mEDC4qpkfmD_KyaQFw37YkIBF4F6Crl4,6989 -vllm/v1/core/sched/scheduler.py,sha256=3Tp8IkUUBcdMj996i70WTEFbk2KC6ZjLdolCaB6u-WU,101673 -vllm/v1/core/sched/utils.py,sha256=3EmgmNpjOubCbvW5q0BFekIRQoGkGjSHdX1EskOOP9w,2148 -vllm/v1/core/single_type_kv_cache_manager.py,sha256=SfO1IxrkoY64LyN17CbJbSg_USU6TQmuVr_nRMC0Dmo,48201 -vllm/v1/cudagraph_dispatcher.py,sha256=D5UbrFbjT2lake_WN8shyvVLFAGCpwJmNk4rcqFbjNg,14433 -vllm/v1/engine/__init__.py,sha256=-eyCzUh0c7eXKl0R2LAGqTROR5eDAsVS1olvFlTH-vM,7650 +vllm/v1/core/sched/scheduler.py,sha256=yJ-_VWDF8E2SVLV6uHdZO1YR-x77IgU37XUFpE3XM00,126753 +vllm/v1/core/sched/utils.py,sha256=hegurlVaA0l60qwVQO1WKmw2_CYYWqYjNyXJFIFqobM,4115 +vllm/v1/core/single_type_kv_cache_manager.py,sha256=wTqh7tVrf9jqJQrv57Xftd2LYi5jWm9ijjBaUNgz2ZY,48726 +vllm/v1/cudagraph_dispatcher.py,sha256=tucITt0mub8XERm8UhoveEkUutts2TTFjC8EkFGXLmA,15325 +vllm/v1/engine/__init__.py,sha256=gIHKUrwQ5kP1I985MbkbsFU6CA0oZCPWBtuN_SrV7NM,8342 vllm/v1/engine/__pycache__/__init__.cpython-312.pyc,, vllm/v1/engine/__pycache__/async_llm.cpython-312.pyc,, vllm/v1/engine/__pycache__/coordinator.cpython-312.pyc,, @@ -2982,18 +3042,18 @@ vllm/v1/engine/__pycache__/logprobs.cpython-312.pyc,, vllm/v1/engine/__pycache__/output_processor.cpython-312.pyc,, vllm/v1/engine/__pycache__/parallel_sampling.cpython-312.pyc,, vllm/v1/engine/__pycache__/utils.cpython-312.pyc,, -vllm/v1/engine/async_llm.py,sha256=0xjQJ_tdSIhgY3uuHqY09Trz8ISVvteCuuMeWjjBhkI,41673 -vllm/v1/engine/coordinator.py,sha256=wGbIqHFaHGThFBjgDDsdANVMI33mcp6G1uahTybRpto,17400 -vllm/v1/engine/core.py,sha256=caGTN0fD35dWGXP5iHnI7VzKt73aIxoe3qplR1braSA,74937 
-vllm/v1/engine/core_client.py,sha256=_WBCgwuez4dD-qF0iGR1pmQb4V2eZpD-yK_cWuHrWqM,55909 +vllm/v1/engine/async_llm.py,sha256=r4BTWzNMdCA5-FU0hZHoeVPMXY_LFtNvjIPN3dVcn70,42353 +vllm/v1/engine/coordinator.py,sha256=kF3GdPZlPdZcVU4pudADKsZxRzfrpDjpu5UzIheSG2c,18416 +vllm/v1/engine/core.py,sha256=XEzaITtIP0gSG22sMgjkRVIOLmiunWXWHfVrF1EMaMY,87333 +vllm/v1/engine/core_client.py,sha256=P6XQMGBLHJfA6D9Z6G32syo1IdiC6P_dPPYXy-cTu-8,67207 vllm/v1/engine/detokenizer.py,sha256=2bi9yMmZWhtJxxPf_7DGXmQNwqAFewSg35BqdtCQwWc,12779 vllm/v1/engine/exceptions.py,sha256=Uqu6bUKEvMYpeVJYxlqYeveRcd9WnsFUENKt0JLMins,732 -vllm/v1/engine/input_processor.py,sha256=vFNkZ4gEZMwFkUiOTSTA4dDv5Ktm2QfZsZptgVB4Clo,19620 -vllm/v1/engine/llm_engine.py,sha256=-zRacelQUffg0eRLI70fu4874bsZ8-ykhD3mKx7zHi4,16761 +vllm/v1/engine/input_processor.py,sha256=tzo6FjUp3e7MbTMaqj8A19KLahWKtSxjGBXb_id7vbk,19184 +vllm/v1/engine/llm_engine.py,sha256=3l-FyV1wKuDWgmosrxJPoD7yOa0c2lrtMTl2v18rqVI,16788 vllm/v1/engine/logprobs.py,sha256=3KDfKD9nIkUyjzc9Jf5SgfBn8tPN3cnJkO8Vh4Sat3U,8682 vllm/v1/engine/output_processor.py,sha256=VcabTEItWEKOsuBqsnqDuzMFg8VCU1JPsyun_wQRkqI,31131 vllm/v1/engine/parallel_sampling.py,sha256=vz8KPGZAqvcG0bb37DqOCalw892efpb6apAXeeqE9h0,5481 -vllm/v1/engine/utils.py,sha256=QnqRiiHXdLk6T-0k3jZjRotKmSSCNXo8W2T98tkrjUI,42322 +vllm/v1/engine/utils.py,sha256=uc8uVI0inY4mxNHogC6yCkmgHX3X2_eeRdEXyqTAFGI,43327 vllm/v1/executor/__init__.py,sha256=VwU0pRwKjvcCcXII_7N2FIYT2oNj1fquLjvUS-qqE2U,227 vllm/v1/executor/__pycache__/__init__.cpython-312.pyc,, vllm/v1/executor/__pycache__/abstract.cpython-312.pyc,, @@ -3002,13 +3062,13 @@ vllm/v1/executor/__pycache__/ray_distributed_executor.cpython-312.pyc,, vllm/v1/executor/__pycache__/ray_executor.cpython-312.pyc,, vllm/v1/executor/__pycache__/ray_utils.cpython-312.pyc,, vllm/v1/executor/__pycache__/uniproc_executor.cpython-312.pyc,, -vllm/v1/executor/abstract.py,sha256=wugXhVnGaxPbqqNeBXnoLrDvenn9O28tsw55ZAQR-xU,13543 -vllm/v1/executor/multiproc_executor.py,sha256=jnhZ3EmZqTP48gi8gFgIGp6TCOeQF5JksK-pjtHPywY,36296 +vllm/v1/executor/abstract.py,sha256=wVly5AWYsLBqaVmW5kBgUKTr8d8pxAXsyPvk9fIwVlA,13984 +vllm/v1/executor/multiproc_executor.py,sha256=_mctcYpZNrK9D847MN-u8yS402BjBMJyCgi_cUHFHtg,37701 vllm/v1/executor/ray_distributed_executor.py,sha256=tl8FrTK1jqQk2SKF-Rt6RtXacGD6ND-hXIjTrquAbkc,289 -vllm/v1/executor/ray_executor.py,sha256=ZTQBTnJWWC2v8-WLNmmX2jInSo7vnpd3oq9Zvzr5dSs,25988 -vllm/v1/executor/ray_utils.py,sha256=BMK5PvSLYPVGGGj4BUWN9bTj70uGe6ZoXDBz7w19uc0,19384 -vllm/v1/executor/uniproc_executor.py,sha256=ZOJR_Xq5V581VWJjkvPQFwJMHuMqNDL7k_O7tzqNvQo,7244 -vllm/v1/kv_cache_interface.py,sha256=SAnPe2VhcWNmJLgrY_EABO9Atn4uhsqbhKftTyuN6hc,17920 +vllm/v1/executor/ray_executor.py,sha256=FfbQAD4_jA6yCyfhSihOcJyssdIJhR3oIL_JjWEBrhs,26095 +vllm/v1/executor/ray_utils.py,sha256=uQ0LTGxcW3xMTx3DvgI7xCjO6230GsMXCJfi65IN0T8,21725 +vllm/v1/executor/uniproc_executor.py,sha256=J29IY8nFtbJIpnpvgBKofNx-cTiY8H6MUvgPwURqvfQ,6930 +vllm/v1/kv_cache_interface.py,sha256=qrp2bwc4nZGnJAepkfF4Ab5LYFxFkBUeiYWst61q6pE,21143 vllm/v1/kv_offload/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/kv_offload/__pycache__/__init__.cpython-312.pyc,, vllm/v1/kv_offload/__pycache__/abstract.cpython-312.pyc,, @@ -3035,7 +3095,7 @@ vllm/v1/kv_offload/worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW vllm/v1/kv_offload/worker/__pycache__/__init__.cpython-312.pyc,, vllm/v1/kv_offload/worker/__pycache__/cpu_gpu.cpython-312.pyc,, vllm/v1/kv_offload/worker/__pycache__/worker.cpython-312.pyc,, 
-vllm/v1/kv_offload/worker/cpu_gpu.py,sha256=MjoWZ6i1prPRYxcFLvasoZgXLEw_NmPaZ79ylgedudU,12260 +vllm/v1/kv_offload/worker/cpu_gpu.py,sha256=INfbItgRU9Jr0Hven2rcLbjKBZQ7GDp_T6ejCVdsdu4,12515 vllm/v1/kv_offload/worker/worker.py,sha256=fbTsSg9KkNsFLDyq8NfO64M4pa9A0iGbtpsBH9m-B4M,5275 vllm/v1/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/metrics/__pycache__/__init__.cpython-312.pyc,, @@ -3051,25 +3111,25 @@ vllm/v1/metrics/prometheus.py,sha256=dhI7DQbT7X3gFlcQOXmoa7PY5JPyw3fvFm0azVEP-ok vllm/v1/metrics/ray_wrappers.py,sha256=dOK01HU7uWdjDIjJ14Y98tnFXp81YfFA9-4COAig-SQ,6614 vllm/v1/metrics/reader.py,sha256=drwtEKJn4gOnaCFvlaYEjibmQx7cQHIsDiQii4B9Uk4,8694 vllm/v1/metrics/stats.py,sha256=8KdlA4K8y2-sfWQMDWfsGZIIfjccXhEiVToYJj047Gc,17808 -vllm/v1/outputs.py,sha256=CGVW4PTP6aUnyEwD-EbFrDsDTc-XhOZrIk1h-Dnw-tg,8705 +vllm/v1/outputs.py,sha256=6B8Mapc9GXmNgQEFmuv1QiXU06tNj1Vp1ldsE-Cwyfo,10547 vllm/v1/pool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/pool/__pycache__/__init__.cpython-312.pyc,, vllm/v1/pool/__pycache__/metadata.cpython-312.pyc,, vllm/v1/pool/metadata.py,sha256=_ztjF86T0kNUDHkBtRyoXJPeM02qt0cCmQqv76B9lUE,4106 -vllm/v1/request.py,sha256=hqCuC8XgYG9b8JqPs7um07-Mv_ZXY3upv61f2UAR184,12414 +vllm/v1/request.py,sha256=9fdnsu3d_QLkn5xUj_rfUeW6b41k8NbaFnChdqZyYz8,12516 vllm/v1/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/sample/__pycache__/__init__.cpython-312.pyc,, vllm/v1/sample/__pycache__/metadata.cpython-312.pyc,, vllm/v1/sample/__pycache__/rejection_sampler.cpython-312.pyc,, vllm/v1/sample/__pycache__/sampler.cpython-312.pyc,, -vllm/v1/sample/logits_processor/__init__.py,sha256=Hs6S01p-p_Xo-LcG5ihCwwGjqypt_8d_TEFnNLNfbIU,12047 +vllm/v1/sample/logits_processor/__init__.py,sha256=zxtm6oiHsp642GXAylVFjVjRSf4xJJrl8sAHoltyPT0,12085 vllm/v1/sample/logits_processor/__pycache__/__init__.cpython-312.pyc,, vllm/v1/sample/logits_processor/__pycache__/builtin.cpython-312.pyc,, vllm/v1/sample/logits_processor/__pycache__/interface.cpython-312.pyc,, vllm/v1/sample/logits_processor/__pycache__/state.cpython-312.pyc,, -vllm/v1/sample/logits_processor/builtin.py,sha256=ejN7QaSXMarG07G5rVPwCZ-MbYua3ji6JbnjErnz1x8,10323 +vllm/v1/sample/logits_processor/builtin.py,sha256=N-Bq-Nk_Nzc2xgrx7LdPLJpovdmfliMmyc0j5_YfGPM,12440 vllm/v1/sample/logits_processor/interface.py,sha256=t2hnDF_0VXH9D0ZzZ1v3KAteJslPs6fZyBe6XxA6gIY,3205 -vllm/v1/sample/logits_processor/state.py,sha256=Fk_xGySqz7y2MiAi8hL0Q95DNXWelyu5E4XB-pD0d4E,5666 +vllm/v1/sample/logits_processor/state.py,sha256=qWQR8UNixF-DnhgdyzrAHVuICBK8yzVS_BeGOyGfRKw,5676 vllm/v1/sample/metadata.py,sha256=jLeOuEJuWBJhgcy9TBL2gfjFj2DaJLWk-zPXARUEbsI,1142 vllm/v1/sample/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/sample/ops/__pycache__/__init__.cpython-312.pyc,, @@ -3083,27 +3143,31 @@ vllm/v1/sample/ops/logprobs.py,sha256=krdqPu-zZQChtxbssdvIQtm49l4EckUTljJNVSxZfj vllm/v1/sample/ops/penalties.py,sha256=OOrkMB8o3sFEbYjnCq8fvTbT0EEdL7Zn2U1Euf44CjM,1929 vllm/v1/sample/ops/topk_topp_sampler.py,sha256=dfxdo8oMZWuCxd_RYyz3bXAVN21LAIoSO6RsTW42-PU,15403 vllm/v1/sample/ops/topk_topp_triton.py,sha256=Fp-qC5xPzkFodAWqKevHKQL2XxGhothV0okhin6CRAo,49608 -vllm/v1/sample/rejection_sampler.py,sha256=EFjwscvMpQKlUjKYl4oB2tYjvfPkYKMczi5rK4edzSQ,31448 +vllm/v1/sample/rejection_sampler.py,sha256=Mbe4cf4PxqXHzDVSvyexW81_tzMJ1LI5S5UMDeWmqq0,31742 vllm/v1/sample/sampler.py,sha256=uTYAiQ2r3lQxCjmzTPKpYXP7tmIRudSnQ4bBq57krl4,12583 
vllm/v1/serial_utils.py,sha256=kDvyyWcvgeomZ5i8K4Vvap8wCbn8ZjiQyJFFv-HoR5k,19753 vllm/v1/spec_decode/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/spec_decode/__pycache__/__init__.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/draft_model.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/eagle.cpython-312.pyc,, +vllm/v1/spec_decode/__pycache__/extract_hidden_states.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/medusa.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/metadata.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/metrics.cpython-312.pyc,, +vllm/v1/spec_decode/__pycache__/multi_layer_eagle.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/ngram_proposer.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/suffix_decoding.cpython-312.pyc,, vllm/v1/spec_decode/__pycache__/utils.cpython-312.pyc,, vllm/v1/spec_decode/draft_model.py,sha256=mcv4VYfi9qcg_IUIwpYV82NXPUoJqHg_kDq9byFUSo8,2904 -vllm/v1/spec_decode/eagle.py,sha256=BnMaN9L6Ablyl38UCGnLMgNpTw-VX_6aRwmHnEcnK2Q,78574 +vllm/v1/spec_decode/eagle.py,sha256=K1yxTolHg1FqE1JxRl_KTKzFSw_Mw0L0JFmh_Ecuaqk,84237 +vllm/v1/spec_decode/extract_hidden_states.py,sha256=U2OVdz1_cTgAKm-cMRzpMuGm2tcLeUAcJU8l4NDc-ag,16545 vllm/v1/spec_decode/medusa.py,sha256=8MlI6DXXZLx6dZbGeu-PkKUuWmRUW0eg9TnCH6tk_V4,2694 -vllm/v1/spec_decode/metadata.py,sha256=pmlkh7PIgjwRn2Ad6BUENBve_v6Vuv5nPB9XCrICjag,2353 +vllm/v1/spec_decode/metadata.py,sha256=ekCirq6LvkOCh39r4pAfJ-3MTp3FOjYI8fcjypnKJoI,3771 vllm/v1/spec_decode/metrics.py,sha256=KZ4k_OZ5hyJFc9EciPxjEXJCTDWG6BuT9369H0VANxI,7867 +vllm/v1/spec_decode/multi_layer_eagle.py,sha256=0Tn6vHKEzHfibuiMjqheF8_KXIZeXGNRkFNa8h_conY,16506 vllm/v1/spec_decode/ngram_proposer.py,sha256=TTS_3Z_72a1U5pgTTJYorYgbk8cFGOrC1I8OUB-cZ2M,10935 vllm/v1/spec_decode/suffix_decoding.py,sha256=cbt6zu0gqI8Fe7mjNqea2vKy2dQWLStQs3yIvj5i2Pw,4442 -vllm/v1/spec_decode/utils.py,sha256=Bn6fU9Jzg6P2mNP-SjRCAfQs-Iw0lSqlReAgUW_puqA,15121 +vllm/v1/spec_decode/utils.py,sha256=iTaZanw9Azukt0U2mZYTxfcL8xLAH3ljY7JRrdcepBA,15624 vllm/v1/structured_output/__init__.py,sha256=RnhQEalmOswNbXZGTVF-nvyzl8U31ppO6mGFin8cU4Q,14970 vllm/v1/structured_output/__pycache__/__init__.cpython-312.pyc,, vllm/v1/structured_output/__pycache__/backend_guidance.cpython-312.pyc,, @@ -3146,9 +3210,9 @@ vllm/v1/worker/__pycache__/xpu_model_runner.cpython-312.pyc,, vllm/v1/worker/__pycache__/xpu_worker.cpython-312.pyc,, vllm/v1/worker/block_table.py,sha256=G2urbvJ3fJQXzTimkvkxokCG5DVVL2KDQ76w4TuQ5uU,13264 vllm/v1/worker/cp_utils.py,sha256=BOVYrUTA2bv1YpaoRRTQPiy0oAERoWldGwFFkitUCcg,2389 -vllm/v1/worker/cpu_model_runner.py,sha256=ALLO8HBcqsntz68X8w92AvjSUmE49fJfpamsrjq75Qc,4316 -vllm/v1/worker/cpu_worker.py,sha256=nasaeTVYKSwDOa4cds-kcnHfmK4FP2nhmCA0eANJsrc,9015 -vllm/v1/worker/dp_utils.py,sha256=HF6MZ5bGffPbVV7rsyfasF18J3yqkj6PSp27NBMo2Eo,8851 +vllm/v1/worker/cpu_model_runner.py,sha256=ymDHFo4eiV06SIrUVvWpWfbgb1S9YxugtyukIPxIvr0,4531 +vllm/v1/worker/cpu_worker.py,sha256=dhxpYHOor_HFRckX6K_MaogHBHypOwOnrX7IKjZGq1k,9066 +vllm/v1/worker/dp_utils.py,sha256=wUPKqk9kn4OV14Pf3rhr9yLv0VTNoTzz5wRBxW7qPyU,8688 vllm/v1/worker/ec_connector_model_runner_mixin.py,sha256=j9YWGXs4tuxFrDJ-SLOcdPbTJ6623cI5ptbDeQCpzTs,2988 vllm/v1/worker/gpu/README.md,sha256=1H_fnHpMus5c95vNXWX5-Ryx3UkQiRwY_b5WDd0D8Ww,181 vllm/v1/worker/gpu/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 @@ -3167,14 +3231,15 @@ vllm/v1/worker/gpu/__pycache__/model_runner.cpython-312.pyc,, vllm/v1/worker/gpu/__pycache__/pp_utils.cpython-312.pyc,, 
vllm/v1/worker/gpu/__pycache__/states.cpython-312.pyc,, vllm/v1/worker/gpu/__pycache__/structured_outputs.cpython-312.pyc,, -vllm/v1/worker/gpu/async_utils.py,sha256=EJlTaJg44ySvT2WFxfIgQk71K1ikYoVLOXT3MnJAGas,3533 +vllm/v1/worker/gpu/__pycache__/warmup.cpython-312.pyc,, +vllm/v1/worker/gpu/async_utils.py,sha256=ulLo02hHgh5QxMd_nCGN9nXBfRiAcnyJpGjlpL86mGg,4923 vllm/v1/worker/gpu/attn_utils.py,sha256=AZ_u3enlzaIsQT0y5r2A4-aEr7QeQnEy1ZscR28zmy0,8624 -vllm/v1/worker/gpu/block_table.py,sha256=9r3ddR4ikKo6BaXYRHWmh0-QYWYtZM5NE2pFYWzkW6Q,9709 +vllm/v1/worker/gpu/block_table.py,sha256=1o5qeyq8neXRt_yJf8WIWlvmoDRP-hK8tEWaehCTabQ,10454 vllm/v1/worker/gpu/buffer_utils.py,sha256=rECYHbR75Ls1zLIRa4hJfRokKfYaR5tuDgFpy9UHZYc,7209 vllm/v1/worker/gpu/cp_utils.py,sha256=ffa-o_Ujw2M5D4ceuTBqSBeFuLD0lbDHDqhASxvuJ7E,1700 -vllm/v1/worker/gpu/cudagraph_utils.py,sha256=0Q5OHMcWm5FIwVTDp7EjoBl5QlGirZknUcmzlv0BGJM,17605 +vllm/v1/worker/gpu/cudagraph_utils.py,sha256=_bzWlD267-JvuMOkFKuHgVYiBkyVHU28z-wbMenj-8c,16053 vllm/v1/worker/gpu/dp_utils.py,sha256=ZgZkMHZgAZGBfpJE-ZDQEOgEENRh583xt3Xq_LQrOdE,2936 -vllm/v1/worker/gpu/input_batch.py,sha256=4NaX9YxOlGjG2WjwuzB6XG4xDQjS2QKpqVGthxsDrzk,17462 +vllm/v1/worker/gpu/input_batch.py,sha256=Hj6nmlXwoD8zKT66wHeuAs6sygVlu7xt7BR5YI6DJE0,17907 vllm/v1/worker/gpu/kv_connector.py,sha256=rEIQ0BaHqWdq0KNafnsBwwmwNo0l0ljA2MHKLq3eTb0,4709 vllm/v1/worker/gpu/lora_utils.py,sha256=XF82PxJNFA4AR6BUk2kyqmgKHdfF9o6tkLibEELbMM0,1563 vllm/v1/worker/gpu/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 @@ -3183,11 +3248,23 @@ vllm/v1/worker/gpu/metrics/__pycache__/logits.cpython-312.pyc,, vllm/v1/worker/gpu/metrics/logits.py,sha256=6gWNdRCgd6Vfisfkw9WWUINGJE6QjLNyhl1xWfYu3Ik,1205 vllm/v1/worker/gpu/mm/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/worker/gpu/mm/__pycache__/__init__.cpython-312.pyc,, +vllm/v1/worker/gpu/mm/__pycache__/encoder_cache.cpython-312.pyc,, vllm/v1/worker/gpu/mm/__pycache__/encoder_runner.cpython-312.pyc,, vllm/v1/worker/gpu/mm/__pycache__/mrope_utils.cpython-312.pyc,, -vllm/v1/worker/gpu/mm/encoder_runner.py,sha256=mBrs8T1_A2niX2Qo6ktDYJoTPT7Karv8bDV8TBPBwLA,7244 +vllm/v1/worker/gpu/mm/encoder_cache.py,sha256=B5nfNNjgosWQPKvWbJXcacBs4Yxg8ZGL8fdKlorhCIk,1326 +vllm/v1/worker/gpu/mm/encoder_runner.py,sha256=_J3xYnU6qh4S7QbB_KMgLVCQCz_Ubxwhl6mhnD7hk4g,6159 vllm/v1/worker/gpu/mm/mrope_utils.py,sha256=NO9QS1rPfqAAMOI8ppNp4aFJrsmhagge8ZAZpTYS-RM,4879 -vllm/v1/worker/gpu/model_runner.py,sha256=5ggLg-eUroMW5c1fz4GOnx8cfe3_b6blXjfZTmoKB_g,47058 +vllm/v1/worker/gpu/model_runner.py,sha256=WreH6o1QR0TjJLBslGTxpn_H8huNpCf3M6EiNsRSrcs,48326 +vllm/v1/worker/gpu/model_states/__init__.py,sha256=WkIIo16IxVTj5oOB_opmueNvTS_XHhf0WJ1g6gOSu04,530 +vllm/v1/worker/gpu/model_states/__pycache__/__init__.cpython-312.pyc,, +vllm/v1/worker/gpu/model_states/__pycache__/default.cpython-312.pyc,, +vllm/v1/worker/gpu/model_states/__pycache__/interface.cpython-312.pyc,, +vllm/v1/worker/gpu/model_states/default.py,sha256=rKG1UEuAEylulebXXZvCp8aVwm57bmrZOVHaFZ-NVCo,6345 +vllm/v1/worker/gpu/model_states/interface.py,sha256=xZ5GQBbbWr6DnpHhz8X-a6xm583BSbf8nF94XtdZBrg,1971 +vllm/v1/worker/gpu/pool/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +vllm/v1/worker/gpu/pool/__pycache__/__init__.cpython-312.pyc,, +vllm/v1/worker/gpu/pool/__pycache__/pooling_runner.cpython-312.pyc,, +vllm/v1/worker/gpu/pool/pooling_runner.py,sha256=_RNihQqlvQnR1SfH46vZunkIAxFfAB-co40frFR5nQ0,1682 
vllm/v1/worker/gpu/pp_utils.py,sha256=4iDJrkigTxM3IcyoEAnZzHlT50wBVGKQwKqHu7vhBP8,1371 vllm/v1/worker/gpu/sample/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 vllm/v1/worker/gpu/sample/__pycache__/__init__.cpython-312.pyc,, @@ -3201,16 +3278,16 @@ vllm/v1/worker/gpu/sample/__pycache__/penalties.cpython-312.pyc,, vllm/v1/worker/gpu/sample/__pycache__/prompt_logprob.cpython-312.pyc,, vllm/v1/worker/gpu/sample/__pycache__/sampler.cpython-312.pyc,, vllm/v1/worker/gpu/sample/__pycache__/states.cpython-312.pyc,, -vllm/v1/worker/gpu/sample/bad_words.py,sha256=rU2u_9OnLDLdTsJX2pOqy5xaAihPew7YUlG5AqfW-gE,6483 -vllm/v1/worker/gpu/sample/gumbel.py,sha256=_sCRSx0VS0H743wu2ZYMvAMsL4XsJj0iDjvGzFFQWbo,4372 -vllm/v1/worker/gpu/sample/logit_bias.py,sha256=c2meBhCqk2KplYd5r1_HlafQjqTEodM04Rfhx06qb04,9438 +vllm/v1/worker/gpu/sample/bad_words.py,sha256=6ZWcrjNSk_DZ5iosRDjrBSRjC3w_VAcis6r3SZrJ55Q,6489 +vllm/v1/worker/gpu/sample/gumbel.py,sha256=saZxdQQBUBuAff1GIE_aozusW_hxUb1T22jJLCt8xXI,4458 +vllm/v1/worker/gpu/sample/logit_bias.py,sha256=BbkqDRpL4pGiYtcH-hRIK564p7QVTX5QPRh1A8Lt4EY,9496 vllm/v1/worker/gpu/sample/logprob.py,sha256=mhQxQN8aKznHd4hOTtBcwNY8dIeasQkpksdggp1pX-8,3983 -vllm/v1/worker/gpu/sample/min_p.py,sha256=lQPdt01iJjWeIvkC5Tan-3oQM1kBOMaC-4IzhpjOjwQ,1661 +vllm/v1/worker/gpu/sample/min_p.py,sha256=iEcKMo0tc9Z7UP2i_J-tf59tmDPyo5qTzbpPp9Y2Ij0,1761 vllm/v1/worker/gpu/sample/output.py,sha256=Uzj6Lujb5LEswYwYUsTmJ9HkDwJ-QoV1s1uyH4VZ0CE,349 -vllm/v1/worker/gpu/sample/penalties.py,sha256=w_WNkWAH1wBLZ5omhaE44uex-C8mL3AnO-hJ5r4pOso,10636 +vllm/v1/worker/gpu/sample/penalties.py,sha256=7HUBXmO-CHSLVpIpn-OUG-D64RIqW8XURZF_grCjoJs,10757 vllm/v1/worker/gpu/sample/prompt_logprob.py,sha256=bAdWFBumN1rPxQbUi_tcYZCKVN-3l6dKMtdmwzZNwms,8131 -vllm/v1/worker/gpu/sample/sampler.py,sha256=wwaw91gaFN1glvTNEo6Pg56iWGZvW7AnZIRwcg-qiuY,5838 -vllm/v1/worker/gpu/sample/states.py,sha256=hkXodXx-Vqe6Vj9m7XHJSurskz9TuR_7OzORx0ZMJ_Q,3886 +vllm/v1/worker/gpu/sample/sampler.py,sha256=EBx21NwKubvWQu_X-5qov_Jmn2vgUeoUPSNMA0cuAKM,5972 +vllm/v1/worker/gpu/sample/states.py,sha256=OW-3nHrPXB3Ksa_H_UMA2bJAe-_IdNHoiuxnwtTQ5qk,3949 vllm/v1/worker/gpu/spec_decode/__init__.py,sha256=MwFTIriVJZZqVhVWbas6czZW_eXW5TwSoCg1Ro-ETqQ,584 vllm/v1/worker/gpu/spec_decode/__pycache__/__init__.cpython-312.pyc,, vllm/v1/worker/gpu/spec_decode/__pycache__/rejection_sample.cpython-312.pyc,, @@ -3221,28 +3298,33 @@ vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/cudagraph.cpython-312.pyc,, vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/eagle3_utils.cpython-312.pyc,, vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/speculator.cpython-312.pyc,, vllm/v1/worker/gpu/spec_decode/eagle/__pycache__/utils.cpython-312.pyc,, -vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py,sha256=79E6LvNpC2GK13mCX-uKGhbvDokyR1DcTFyxnsns8AA,6807 +vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py,sha256=19wpXyZ5sT6MwdBNnDJYgumwpYuuzEacnO70gWsIHXM,7681 vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py,sha256=h2v6qPgEyrmyA_VYBrEXXyAMiPWe7qmNsHMaB2i69gc,1726 -vllm/v1/worker/gpu/spec_decode/eagle/speculator.py,sha256=LgYzo-ddGBlwgMrG_Lrv5iN0AJ1rG8wUiRrKHbmu83k,21159 +vllm/v1/worker/gpu/spec_decode/eagle/speculator.py,sha256=XJJn3xQiq-NVBY4LNFVinMWPrdNsEZy-mOuFCIoe1OM,22298 vllm/v1/worker/gpu/spec_decode/eagle/utils.py,sha256=1gmAeKdUTXhgnEf1NS33GEaNyIuBx_KQTEKZbPA8P5U,2170 vllm/v1/worker/gpu/spec_decode/rejection_sample.py,sha256=316CS-i86IEDzWt6y6bEhwQkb4VviFQfqaXDTAUuCPg,2218 
vllm/v1/worker/gpu/spec_decode/utils.py,sha256=w82Z2kbCmr05EjZHIPMaDWhwMGbSN9qLFkw0vntVpOA,1915 vllm/v1/worker/gpu/states.py,sha256=ZaKsOUHd9kd7j_oC9RHkcQnRXJCLJ3dPyAdND0DQe3E,4806 vllm/v1/worker/gpu/structured_outputs.py,sha256=qykZ-Y9PFFk2cY8JZgz1l9Rhg_zXT5XHKsGyxw5cIcg,4108 -vllm/v1/worker/gpu_input_batch.py,sha256=Eb9Biwl5MM2UKzTolIaNY3DLMv1U4OpIAN6FDkJaqYM,43321 -vllm/v1/worker/gpu_model_runner.py,sha256=2wfvqi9aOVmYPeRk0p582CxdNNBNlYwEwqViq-_0LVA,277994 +vllm/v1/worker/gpu/warmup.py,sha256=mK46LlfKN0f9dVt8jPUfjCXLuZ0nd4p_1Uyu0uy-5AE,4091 +vllm/v1/worker/gpu_input_batch.py,sha256=hpOGLUJEaFj4FKhD2lMka9N5aobDZ1muy3PPT8fucFE,46692 +vllm/v1/worker/gpu_model_runner.py,sha256=xgmFWuLHIpF1pLkF8nrz5e8l7iM7AjXUK-dBYdtO_1Q,295633 vllm/v1/worker/gpu_ubatch_wrapper.py,sha256=M_sKqzRgCPDEGTP8Fpjahm_9k-gSerCHRnPUgsXXfaA,19455 -vllm/v1/worker/gpu_worker.py,sha256=Jo_zldnw2w1dlJiBcZUQa5IVgo9WbjJsl33sigOvIJ4,47469 +vllm/v1/worker/gpu_worker.py,sha256=O4Hdp4J2rS8Rw_wI1eMyTP9Um3HNPL-vlpr-yx1xIXs,40459 vllm/v1/worker/kv_connector_model_runner_mixin.py,sha256=JZrNatrj2mu_1H5BMHmV2PC3_npX3MKd68RUDFWp_Co,10927 vllm/v1/worker/lora_model_runner_mixin.py,sha256=SbtsW2hCNf1T6yF2al3rwXYXZ4gBwWVW8JdiUUqSFMY,11014 -vllm/v1/worker/mamba_utils.py,sha256=4tM43Oh8BJnHdx7Zuo_yF3kcqFP55gaUt8YdJXGXhbE,10073 +vllm/v1/worker/mamba_utils.py,sha256=BA1n7csi212an-tECS1Ui8MNkKV1CVqTK8prqQF6uNo,10489 vllm/v1/worker/tpu_input_batch.py,sha256=Cc9z4nxam3SR1QSIPEkPYsAJ1e2ins8oLL5FcppSAyU,24177 vllm/v1/worker/ubatch_utils.py,sha256=LKzttglKXQ-8k-aecrgp3wt1gYuQsCwOyiw0rTshUeQ,8550 vllm/v1/worker/ubatching.py,sha256=QDkSQcVk_rXxbHeJiuauFS7W5xpGguKkBjh3hdjeAtc,8394 -vllm/v1/worker/utils.py,sha256=oYzLhX7_-dSG1qq7EInQC8PkYlD7SpCFB5HZzOSxg1g,9446 -vllm/v1/worker/worker_base.py,sha256=sHdv9d9kNFNcZnzgPIwGzt8YJD3VKJ9JZJA2WDM_8Zw,14129 -vllm/v1/worker/workspace.py,sha256=YgWweZ2HOEV75SnSlTpIlUQonFSktC_EjGzOvXmTbH4,9102 +vllm/v1/worker/utils.py,sha256=HAUB-4Ba9FE8eJBmmimrikc8DIu9W9lBT3ipUch2MNU,23346 +vllm/v1/worker/worker_base.py,sha256=qF5yIRnHw1lm5jCb4OKQ3gE0h2TWxCUaPwVIOP-6qvo,13986 +vllm/v1/worker/workspace.py,sha256=ZDxKVepVtm9gdJK_dzHyxpkg2iz1tDA5mVkqf4zqqsE,10070 vllm/v1/worker/xpu_model_runner.py,sha256=Vuw43jij32tueGHg-2Hxiw59ZwxKsWdn2y2WipBudK4,1534 vllm/v1/worker/xpu_worker.py,sha256=XLL1AXvgP8k_PTXpZRMFYn2el156VZ068Ktx9i4ML4Y,4105 -vllm/version.py,sha256=nYdjue0uc9o6X6OSVlXN3DROsCofHt0xAlw6EDm5vOA,57 +vllm/version.py,sha256=qlCNmDctpuxf10HWjBqD2XAPyLvtWXBrLFVw274Zk7w,54 vllm/vllm_flash_attn/.gitkeep,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +vllm/vllm_flash_attn/__init__.py,sha256=qVPVedLMtxi42E_M9z7_Z8nNtQPhunOQ6RIaKzTcjq4,689 +vllm/vllm_flash_attn/__pycache__/__init__.cpython-312.pyc,, +vllm/vllm_flash_attn/__pycache__/flash_attn_interface.cpython-312.pyc,, +vllm/vllm_flash_attn/flash_attn_interface.py,sha256=VclHCNwVRywtvnHOVklDhkBcm1fdnIGLvQlz8h4FWCs,20201 diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED b/vllm-0.17.0+corex.20260420090923.dist-info/REQUESTED similarity index 100% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/REQUESTED rename to vllm-0.17.0+corex.20260420090923.dist-info/REQUESTED diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL b/vllm-0.17.0+corex.20260420090923.dist-info/WHEEL similarity index 65% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL rename to vllm-0.17.0+corex.20260420090923.dist-info/WHEEL index 1ef5583..0885d05 100644 --- a/vllm-0.16.1rc0+corex.4.4.0.dist-info/WHEEL +++ b/vllm-0.17.0+corex.20260420090923.dist-info/WHEEL @@ -1,5 +1,5 @@ Wheel-Version: 1.0 
-Generator: setuptools (82.0.0) +Generator: setuptools (80.10.2) Root-Is-Purelib: true Tag: py3-none-any diff --git a/vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json b/vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json new file mode 100644 index 0000000..2795c48 --- /dev/null +++ b/vllm-0.17.0+corex.20260420090923.dist-info/direct_url.json @@ -0,0 +1 @@ +{"archive_info": {"hash": "sha256=844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740", "hashes": {"sha256": "844cb01bfec51cf2ec37322ff74c77b31cced2cd9253312cff724ebd06b7f740"}}, "url": "file:///home/poweruser/zrl/code/vllm_hub/vllm/build_pip/vllm-0.17.0%2Bcorex.20260420090923-py3-none-any.whl"} \ No newline at end of file diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt b/vllm-0.17.0+corex.20260420090923.dist-info/entry_points.txt similarity index 100% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/entry_points.txt rename to vllm-0.17.0+corex.20260420090923.dist-info/entry_points.txt diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE b/vllm-0.17.0+corex.20260420090923.dist-info/licenses/LICENSE similarity index 100% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/licenses/LICENSE rename to vllm-0.17.0+corex.20260420090923.dist-info/licenses/LICENSE diff --git a/vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt b/vllm-0.17.0+corex.20260420090923.dist-info/top_level.txt similarity index 100% rename from vllm-0.16.1rc0+corex.4.4.0.dist-info/top_level.txt rename to vllm-0.17.0+corex.20260420090923.dist-info/top_level.txt diff --git a/vllm/.gitignore b/vllm/.gitignore deleted file mode 100644 index 3492fbe..0000000 --- a/vllm/.gitignore +++ /dev/null @@ -1,244 +0,0 @@ -# version file generated by setuptools-scm -/vllm/_version.py - -# vllm-flash-attn built from source -vllm/vllm_flash_attn/* -!vllm/vllm_flash_attn/__init__.py -!vllm/vllm_flash_attn/flash_attn_interface.py - -# OpenAI triton kernels copied from source -vllm/third_party/triton_kernels/* - -# FlashMLA interface copied from source -vllm/third_party/flashmla/flash_mla_interface.py - -# triton jit -.triton - -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -cmake-build-*/ -CMakeUserPresets.json -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST -/.deps/ - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
-*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# generated files -**/generated/** - -# uv -uv.lock - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. -# https://pdm.fming.dev/#use-with-ide -.pdm.toml - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site -docs/argparse -docs/examples/* -!docs/examples/README.md - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
-.idea/ - -# VSCode -.vscode/ - -# Claude -.claude/ - -# Codex -.codex/ - -# Cursor -.cursor/ - -# DS Store -.DS_Store - -# Results -*.csv - -# Python pickle files -*.pkl - -# Sphinx documentation -_build/ - -# vim swap files -*.swo -*.swp - -# hip files generated by PyTorch -*.hip -*_hip* -hip_compat.h - -# Benchmark dataset -benchmarks/**/*.json - -# Linting -actionlint -shellcheck*/ - -# Ignore moe/marlin_moe gen code -csrc/moe/marlin_moe_wna16/kernel_* - -# Ignore ep_kernels_workspace folder -ep_kernels_workspace/ - -# Allow tracked library source folders under submodules (e.g., benchmarks/lib) -!vllm/benchmarks/lib/ - -# Generated gRPC protobuf files (compiled at build time from vllm_engine.proto) -vllm/grpc/vllm_engine_pb2.py -vllm/grpc/vllm_engine_pb2_grpc.py -vllm/grpc/vllm_engine_pb2.pyi - -# Ignore generated cpu headers -csrc/cpu/cpu_attn_dispatch_generated.h - diff --git a/vllm/__init__.py b/vllm/__init__.py index 19b2cdc..968d1a1 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -14,8 +14,6 @@ import typing import vllm.env_override # noqa: F401 MODULE_ATTRS = { - "bc_linter_skip": "._bc_linter:bc_linter_skip", - "bc_linter_include": "._bc_linter:bc_linter_include", "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs", "EngineArgs": ".engine.arg_utils:EngineArgs", "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine", @@ -62,8 +60,6 @@ if typing.TYPE_CHECKING: from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.v1.executor.ray_utils import initialize_ray_cluster - - from ._bc_linter import bc_linter_include, bc_linter_skip else: def __getattr__(name: str) -> typing.Any: @@ -79,8 +75,6 @@ else: __all__ = [ "__version__", - "bc_linter_skip", - "bc_linter_include", "__version_tuple__", "LLM", "ModelRegistry", diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 8ef34bf..c8366ec 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -87,6 +87,10 @@ def _rocm_aiter_fused_moe_impl( a2_scale: torch.Tensor | None = None, num_local_tokens: torch.Tensor | None = None, output_dtype: torch.dtype | None = None, + hidden_pad: int = 0, + intermediate_pad: int = 0, + bias1: torch.Tensor | None = None, + bias2: torch.Tensor | None = None, ) -> torch.Tensor: from aiter import ActivationType, QuantType from aiter.fused_moe import fused_moe @@ -110,6 +114,10 @@ def _rocm_aiter_fused_moe_impl( a2_scale, num_local_tokens=num_local_tokens, dtype=output_dtype, + hidden_pad=hidden_pad, + intermediate_pad=intermediate_pad, + bias1=bias1, + bias2=bias2, ) @@ -307,6 +315,28 @@ def _rocm_aiter_grouped_topk_fake( pass +def _rocm_aiter_fused_topk_impl( + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + gate_up: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + from aiter.fused_moe import fused_topk + + # fused_topk returns (topk_weights, topk_indices) + return fused_topk(x, router_logits, top_k, gate_up) + + +def _rocm_aiter_fused_topk_fake( + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + gate_up: bool, +) -> None: + # tuple[torch.Tensor, torch.Tensor]: + pass + + # Cache whether aiter supports FP8 MLA parameters _AITER_MLA_SUPPORTS_FP8: bool | None = None @@ -994,6 +1024,70 @@ class rocm_aiter_ops: cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS cls._TRITON_UNQUANT_GEMM = envs.VLLM_ROCM_USE_AITER_TRITON_GEMM + @staticmethod + def get_aiter_activation_type(activation_str: str): + """ + Given an activation type as a string, returns the corresponding aiter ActivationType 
enum.
+        Supported activation types: "no", "none", "silu", "gelu", "swiglu".
+        Returns None if the mapping fails.
+
+        Args:
+            activation_str (str): Activation type as string.
+
+        Returns:
+            Aiter ActivationType enum value, or None if not found.
+        """
+        # Import only locally, since aiter may not always be available.
+        try:
+            from aiter import ActivationType
+        except ImportError:
+            return None
+
+        if not isinstance(activation_str, str):
+            return None
+
+        name = activation_str.strip().lower()
+        mapping = {
+            "none": ActivationType.No,
+            "no": ActivationType.No,
+            "silu": ActivationType.Silu,
+            "gelu": ActivationType.Gelu,
+            "swiglu": ActivationType.Swiglu,
+        }
+        return mapping.get(name)
+
+    @staticmethod
+    def get_aiter_quant_type(quant_type_str: str):
+        """
+        Given a quantization type as a string, returns the corresponding aiter QuantType enum.
+        Supported quantization types: "no", "per_tensor", "per_token", "per_1x32", "per_1x128", "per_128x128".
+        Returns None if the mapping fails.
+
+        Args:
+            quant_type_str (str): Quantization type as string.
+
+        Returns:
+            Aiter QuantType enum value, or None if not found.
+        """
+        try:
+            from aiter import QuantType
+        except ImportError:
+            return None
+
+        if not isinstance(quant_type_str, str):
+            return None
+
+        name = quant_type_str.strip().lower()
+        mapping = {
+            "no": QuantType.No,
+            "per_tensor": QuantType.per_Tensor,
+            "per_token": QuantType.per_Token,
+            "per_1x32": QuantType.per_1x32,
+            "per_1x128": QuantType.per_1x128,
+            "per_128x128": QuantType.per_128x128,
+        }
+        return mapping.get(name)
+
     @classmethod
     @if_aiter_supported
     def is_enabled(cls) -> bool:
@@ -1127,6 +1221,14 @@ class rocm_aiter_ops:
         dispatch_key=current_platform.dispatch_key,
     )
 
+    direct_register_custom_op(
+        op_name="rocm_aiter_fused_topk",
+        op_func=_rocm_aiter_fused_topk_impl,
+        mutates_args=[],
+        fake_impl=_rocm_aiter_fused_topk_fake,
+        dispatch_key=current_platform.dispatch_key,
+    )
+
     direct_register_custom_op(
         op_name="rocm_aiter_mla_decode_fwd",
         op_func=_rocm_aiter_mla_decode_fwd_impl,
@@ -1360,6 +1462,10 @@ class rocm_aiter_ops:
         a2_scale: torch.Tensor | None = None,
         num_local_tokens: torch.Tensor | None = None,
         output_dtype: torch.dtype | None = None,
+        hidden_pad: int = 0,
+        intermediate_pad: int = 0,
+        bias1: torch.Tensor | None = None,
+        bias2: torch.Tensor | None = None,
     ) -> torch.Tensor:
         return torch.ops.vllm.rocm_aiter_fused_moe(
             hidden_states,
@@ -1377,6 +1483,10 @@ class rocm_aiter_ops:
             a2_scale,
             num_local_tokens,
             output_dtype,
+            hidden_pad,
+            intermediate_pad,
+            bias1,
+            bias2,
         )
 
     @staticmethod
@@ -1481,6 +1591,15 @@ class rocm_aiter_ops:
             routed_scaling_factor,
         )
 
+    @staticmethod
+    def fused_topk(
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        gate_up: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return torch.ops.vllm.rocm_aiter_fused_topk(x, router_logits, top_k, gate_up)
+
     @staticmethod
     def mla_decode_fwd(
         q: torch.Tensor,
@@ -1701,6 +1820,47 @@ class rocm_aiter_ops:
 
         return shuffle_weight(tensor, layout=layout)
 
+    @staticmethod
+    def shuffle_weight_a16w4(
+        tensor: "torch.Tensor",
+        nLane: int,
+        gate_up: bool,
+    ) -> "torch.Tensor":
+        """
+        Shuffles the weight tensor into (A16W4) layout for AITER kernels.
+
+        Args:
+            tensor: The input weight tensor to be shuffled.
+            nLane: Lane count used by the A16W4 shuffle layout; gate_up selects w13 (True) or w2 (False) weights.
+
+        Returns:
+            torch.Tensor: The shuffled tensor.
+ """ + from aiter.ops.shuffle import shuffle_weight_a16w4 + + return shuffle_weight_a16w4(tensor, nLane, gate_up) + + @staticmethod + def shuffle_scale_a16w4( + tensor: "torch.Tensor", + num_experts: int, + gate_up: bool, + ) -> "torch.Tensor": + """ + Shuffles the scale tensor into (A16W4) layout for AITER kernels. + + Args: + tensor: The input scale tensor to be shuffled. + num_experts: Number of experts, needed for reshaping logic. + gate_up: Whether the scale is for w13 (True) or w2 (False). + + Returns: + torch.Tensor: The shuffled scale tensor. + """ + from aiter.ops.shuffle import shuffle_scale_a16w4 + + return shuffle_scale_a16w4(tensor, num_experts, gate_up) + @staticmethod def shuffle_weights( *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16) diff --git a/vllm/_bc_linter.py b/vllm/_bc_linter.py deleted file mode 100644 index 2929a8b..0000000 --- a/vllm/_bc_linter.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# vllm/_bc_linter.py -from collections.abc import Callable -from typing import Any, TypeVar, overload - -T = TypeVar("T") - - -@overload -def bc_linter_skip(obj: T) -> T: ... - - -@overload -def bc_linter_skip(*, reason: str | None = ...) -> Callable[[T], T]: ... - - -def bc_linter_skip(obj: Any = None, *, reason: str | None = None): - """ - No-op decorator to mark symbols/files for BC-linter suppression. - - Usage: - @bc_linter_skip - def legacy_api(...): ... - """ - - def _wrap(x: T) -> T: - return x - - return _wrap if obj is None else obj - - -@overload -def bc_linter_include(obj: T) -> T: ... - - -@overload -def bc_linter_include(*, reason: str | None = ...) -> Callable[[T], T]: ... - - -def bc_linter_include(obj: Any = None, *, reason: str | None = None): - """ - Usage: - @bc_linter_include - def public_api(...): ... 
- """ - - def _wrap(x: T) -> T: - return x - - return _wrap if obj is None else obj - - -__all__ = ["bc_linter_skip", "bc_linter_include"] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 5f60a3d..1cbebfe 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, Optional, List, Dict, Any import torch - -import vllm.envs as envs +import torch.nn.functional as F from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.scalar_type import ScalarType @@ -14,13 +13,12 @@ from vllm.utils.flashinfer import ( ) from vllm.utils.math_utils import cdiv -import torch.nn.functional as F import ixformer.inference.functions as ops - -from ixformer.core import config from ixformer.distributed import _distributed as cdist - -logger = init_logger(__name__) +import vllm.envs as envs +from ixformer.core import config +import math +_USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS current_platform.import_kernels() @@ -34,25 +32,12 @@ else: except ImportError: from torch.library import impl_abstract as register_fake - -def swiglustep_and_mul_torch(output, input, limit=7.0): - b, n = input.shape - d = n // 2 - - gate = input[:, :d] - up = input[:, d:] - - # 直接写入 output - torch.mul( - torch.clamp(F.silu(gate), max=limit), - torch.clamp(up, -limit, limit), - out=output - ) - - +# activation ops def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: ops.silu_and_mul(x, out) +def silu_and_mul_quant(input: torch.Tensor, out_dim: int, i8_output: torch.Tensor = None, output_scales: torch.Tensor = None, is_dynamic=True) -> None: + return ops.silu_and_mul_quant(input, out_dim, i8_output, output_scales, is_dynamic) def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: ops.gelu_and_mul(x, out) @@ -60,95 +45,31 @@ def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: ops.gelu_tanh_and_mul(x, out) + +def swigluoai_and_mul(out: torch.Tensor, x: torch.Tensor, + alpha: float = 1.702, limit: float = 7.0) -> None: + ops.swigluoai_and_mul(x, out, alpha, limit) - -def fatrelu_and_mul(out: torch.Tensor, x: torch.Tensor, threshold: float = 0.0) -> None: - raise NotImplementedError("FIX soon") - - +#https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: - out.copy_(F.gelu(x, approximate="tanh")) + x = 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x))) + out.copy_(x) return out def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: - out.copy_(F.gelu(x, approximate="tanh")) + x = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + out.copy_(x) return out def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: - out.copy_(F.gelu(x, approximate="tanh")) + #inplace + out.copy_(x) + out.mul_(torch.sigmoid(x * 1.702)) return out -def swigluoai_and_mul( - out: torch.Tensor, x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0 -) -> None: - ops.swigluoai_and_mul(x, out, alpha, limit) - # return - - -def swigluoai_and_mul_torch( - out: torch.Tensor, x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0 -): - gate, up = x[..., ::2], x[..., 1::2] - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - glu = gate * 
torch.sigmoid(gate * alpha) - gated_output = (up + 1) * glu - out.copy_(gated_output) - - -def rms_norm_qk( - output_q: torch.Tensor, - output_k: torch.Tensor, - input_q: torch.Tensor, - input_k: torch.Tensor, - weight_q: torch.Tensor, - weight_k: torch.Tensor, - epsilon: float, -) -> None: - torch.ops.ixf_ops.rms_norm_qk( - output_q, output_k, input_q, input_k, weight_q, weight_k, epsilon - ) - - -def advance_step_flashattn( - num_seqs: int, - num_queries: int, - block_size: int, - input_tokens: torch.Tensor, - sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, - seq_lens: torch.Tensor, - slot_mapping: torch.Tensor, - block_tables: torch.Tensor, -) -> None: - """Advance a step on GPU for existing inputs for a multi-step runner""" - return ops.advance_step_flashattn( - num_seqs, - num_queries, - block_size, - input_tokens, - sampled_token_ids, - input_positions, - seq_lens, - slot_mapping, - block_tables, - ) - - -def quant_kv(kv): - amax_, _ = torch.max(torch.abs(kv), dim=-1, keepdim=True) - f_scale = amax_.float() / 127.0 - scales = f_scale.view(kv.shape[:-1]) - - # 量化 - kv = kv / f_scale - kv = torch.clamp(torch.round(kv), -127, 127).to(torch.int8) - return kv, scales - - # page attention ops def paged_attention_v1( out: torch.Tensor, @@ -193,7 +114,6 @@ def paged_attention_v1( blocksparse_head_sliding_step, ) - def paged_attention_v2( out: torch.Tensor, exp_sum: torch.Tensor, @@ -298,9 +218,7 @@ def mla_decode_kvcache_cpu( block_tables: torch.Tensor, seq_lens: torch.Tensor, ) -> None: - torch.ops._C_cpu.mla_decode_kvcache( - out, query, kv_cache, scale, block_tables, seq_lens - ) + torch.ops._C.mla_decode_kvcache(out, query, kv_cache, scale, block_tables, seq_lens) # merge attn states ops @@ -442,65 +360,54 @@ def rotary_embedding( head_size: int, cos_sin_cache: torch.Tensor, is_neox: bool, + rotary_dim: int=-1, ) -> None: - # torch.ops._C.rotary_embedding( - # positions, query, key, head_size, cos_sin_cache, is_neox - # ) - ops.vllm_rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) + ops.vllm_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rotary_dim=rotary_dim) - -def batched_rotary_embedding( +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: Optional[torch.Tensor], head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + ops.vllm_batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) +def m_rotary_embedding( positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor | None, + key: Optional[torch.Tensor], head_size: int, cos_sin_cache: torch.Tensor, + smrope_section: torch.Tensor, is_neox: bool, - rot_dim: int, - cos_sin_cache_offsets: torch.Tensor, ) -> None: - ops.vllm_batched_rotary_embedding( - positions, - query, - key, - head_size, - cos_sin_cache, - is_neox, - rot_dim, - cos_sin_cache_offsets, - ) - + ops.vllm_m_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, smrope_section, is_neox) # layer norm ops -def rms_norm( - out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float -) -> None: - # torch.ops._C.rms_norm(out, input, weight, epsilon) - input_contiguous = input.contiguous() - ops.rms_norm(input_contiguous, weight, epsilon, out) +def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + ops.rms_norm(input, weight, epsilon, out) -# def 
fused_add_rms_norm( -# input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float -# ) -> None: -# torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) -def fused_add_rms_norm( - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float, + residual_alpha: Optional[float] = 1) -> None: + output, residual_output = ops.residual_rms_norm(input, weight, epsilon, residual_alpha, residual) + return output, residual_output + +def rms_norm_qk( + output_q: torch.Tensor, + output_k: torch.Tensor, + input_q: torch.Tensor, + input_k: torch.Tensor, + weight_q: torch.Tensor, + weight_k: torch.Tensor, epsilon: float, - residual_alpha: float | None = 1.0, ) -> None: - # torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) - output, residual_output = ops.residual_rms_norm( - input=input, - weight=weight, - residual=residual, - eps=epsilon, - residual_alpha=residual_alpha, - ) - input[:] = output - residual[:] = residual_output + ops.rms_norm_qk( + input_q, input_k, weight_q, weight_k, epsilon, output_q, output_k) def fused_qk_norm_rope( @@ -553,10 +460,7 @@ def apply_repetition_penalties_cuda( output_mask: torch.Tensor, repetition_penalties: torch.Tensor, ) -> None: - # torch.ops._C.apply_repetition_penalties_( - # logits, prompt_mask, output_mask, repetition_penalties - # ) - apply_repetition_penalties_torch( + torch.ops._C.apply_repetition_penalties_( logits, prompt_mask, output_mask, repetition_penalties ) @@ -575,15 +479,14 @@ def apply_repetition_penalties( output_mask: A boolean tensor indicating which tokens appear in the output. repetition_penalties: The repetition penalties of shape (num_seqs, ). 
""" - if logits.is_cuda and logits.is_contiguous(): - apply_repetition_penalties_cuda( - logits, prompt_mask, output_mask, repetition_penalties - ) - else: - apply_repetition_penalties_torch( - logits, prompt_mask, output_mask, repetition_penalties - ) - + # if logits.is_cuda and logits.is_contiguous(): + # apply_repetition_penalties_cuda( + # logits, prompt_mask, output_mask, repetition_penalties + # ) + # else: + apply_repetition_penalties_torch( + logits, prompt_mask, output_mask, repetition_penalties + ) # fused quant layer norm ops def rms_norm_dynamic_per_token_quant( @@ -700,73 +603,33 @@ if hasattr(torch.ops._C, "awq_dequantize"): return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) -# def awq_gemm( -# input: torch.Tensor, -# qweight: torch.Tensor, -# scales: torch.Tensor, -# qzeros: torch.Tensor, -# split_k_iters: int, -# ) -> torch.Tensor: -# if envs.VLLM_USE_TRITON_AWQ: -# from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton - -# return awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) -# return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters) - - -def awq_gemm( - input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - pack_factor, - group_size: int = 128, -) -> torch.Tensor: +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, + pack_factor, group_size: int = 128) -> torch.Tensor: + if _USE_TORCH_OPS: + return torch.ops.ixf_ops.wui4a16( + input, qweight, scales, qzeros, None, group_size, "NN" + ) return ops.wui4a16(input, qweight, scales, qzeros, None, group_size, "NN") -if hasattr(torch.ops._C, "awq_gemm"): - @register_fake("_C::awq_gemm") - def _awq_gemm_fake( - input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - split_k_iters: torch.SymInt, - ) -> torch.Tensor: - num_in_feats = input.size(0) - return torch.empty( - (split_k_iters, num_in_feats, qweight.size(1) * 8), - dtype=input.dtype, - device=input.device, - ).sum(0) +def custom_gptq_marlin_gemm(input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, + pack_factor, group_size: int = 128, bias = None) -> torch.Tensor: + if _USE_TORCH_OPS: + return torch.ops.ixf_ops.wui4a16(input, qweight, scales, qzeros, bias, group_size, "NN") + else: + return ops.wui4a16(input, qweight, scales, qzeros, bias, group_size, "NN") # gptq -def gptq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, - use_exllama: bool, - use_v2_format: bool, - bit: int, -) -> torch.Tensor: - # return torch.ops._C.gptq_gemm( - # a, - # b_q_weight, - # b_gptq_qzeros, - # b_gptq_scales, - # b_g_idx, - # use_exllama, - # use_v2_format, - # bit, - # ) - return ops.gptq_gemm( - a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit - ) +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, use_v2_format: bool, + bit: int) -> torch.Tensor: + if use_v2_format: + raise NotImplementedError("gptq_gemm not support use_v2_format") + return ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros ,b_gptq_scales, + b_g_idx, use_exllama, bit) if hasattr(torch.ops._C, "gptq_gemm"): @@ -787,8 +650,8 @@ if hasattr(torch.ops._C, "gptq_gemm"): ) -def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: - # 
torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: ops.vllm_gptq_shuffle(q_weight, q_perm, bit) @@ -896,12 +759,10 @@ def cutlass_scaled_fp4_mm( def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - # return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) return False def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: - # return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) return False @@ -912,7 +773,8 @@ def cutlass_scaled_mm( scale_b: torch.Tensor, out_dtype: torch.dtype, bias: torch.Tensor | None = None, - format: str | None = "TN", + format: str = "TN", + is_w4a8_linear: bool = False ) -> torch.Tensor: """ `cutlass_scaled_mm` implements a fused version of @@ -936,38 +798,41 @@ def cutlass_scaled_mm( scale_a.shape * [1, 128] == a.shape scale_b.shape * [128, 128] == b.shape """ - target_shape = (*a.shape[:-1], b.shape[1]) assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - # a is x, b is weight - m = a.shape[:-1] - n = b.shape[1] * 2 if envs.VLLM_W8A8_LINEAR_USE_W4A8 else b.shape[1] + assert a.ndim <= 3 and b.ndim == 2 + assert bias is None or (bias.dtype == out_dtype and bias.numel() == b.shape[1] \ + if not is_w4a8_linear else bias.numel() == b.shape[1] * 2) + + # NN format: GEMM is a @ b, requires a.shape[1] == b.shape[0]. + # When weight K-dim was padded for 64-alignment (e.g. K=160 → pad_K=192), + # input must be padded to match. + if format == "NN" and a.ndim == 2 and b.ndim == 2 and a.shape[-1] < b.shape[0]: + pad_size = b.shape[0] - a.shape[-1] + a = F.pad(a, (0, pad_size), mode="constant", value=0) + + m = a.shape[0] + n = b.shape[1] if not is_w4a8_linear else b.shape[1] * 2 if format == "TN": b = b.t() - out = torch.empty(m + (n,), dtype=out_dtype, device=a.device) - if envs.VLLM_W8A8_LINEAR_USE_W4A8: - assert format == "NN" - ops.w4a8( - a, - b, - scale_a, - scale_b, - bias=bias, - format=0, - output=out.view(-1, n), - output_dtype=out_dtype, - ) + if a.shape[-1] != b.shape[1]: + padding = b.shape[1] - a.shape[-1] + a = torch.nn.functional.pad(a, (0, padding), mode='constant', value=0) else: - ops.w8a8( - a, - b, - scale_a, - scale_b, - bias, - format=format, - output=out.view(-1, n), - out_dtype=out_dtype, - ) + if a.shape[-1] != b.shape[0]: + padding = b.shape[0] - a.shape[-1] + a = torch.nn.functional.pad(a, (0, padding), mode='constant', value=0) + + if a.ndim == 2: + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + else: + bs = a.shape[1] + out = torch.empty((m, bs, n), dtype=out_dtype, device=a.device) + + if is_w4a8_linear: + assert format=="NN", "W4A8 linear only supports NN format" + ops.w4a8(a, b, scale_a, scale_b, bias=bias, format=0, output=out.view(-1, n), output_dtype=out_dtype) + else: + ops.w8a8(a, b, scale_a, scale_b, bias, format=format, output=out, out_dtype=out_dtype) return out @@ -1000,10 +865,8 @@ def cutlass_scaled_mm_azp( torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) return out.view(*target_shape) - def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: - # return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability) - return False + return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability) def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: @@ -1187,7 
+1050,7 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): return output_tensor -def get_cutlass_pplx_moe_mm_data( +def get_cutlass_batched_moe_mm_data( expert_offsets: torch.Tensor, problem_sizes1: torch.Tensor, problem_sizes2: torch.Tensor, @@ -1210,7 +1073,7 @@ def get_cutlass_pplx_moe_mm_data( multiplication in two grouped MMs used in the fused MoE operation. """ - return torch.ops._C.get_cutlass_pplx_moe_mm_data( + return torch.ops._C.get_cutlass_batched_moe_mm_data( expert_offsets, problem_sizes1, problem_sizes2, @@ -1303,6 +1166,76 @@ def cutlass_fp4_moe_mm( ) +def mxfp8_experts_quant( + input_tensor: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + quant_output: torch.Tensor, + scale_factor: torch.Tensor, +) -> None: + torch.ops._C.mxfp8_experts_quant( + input_tensor, + problem_sizes, + expert_offsets, + blockscale_offsets, + quant_output, + scale_factor, + ) + + +def cutlass_mxfp8_grouped_mm( + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + out_tensors: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, +) -> None: + torch.ops._C.cutlass_mxfp8_grouped_mm( + a_tensors, + b_tensors, + a_scales, + b_scales, + out_tensors, + problem_sizes, + expert_offsets, + blockscale_offsets, + ) + + +if hasattr(torch.ops._C, "mxfp8_experts_quant"): + + @register_fake("_C::mxfp8_experts_quant") + def _mxfp8_experts_quant_fake( + input_tensor: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + quant_output: torch.Tensor, + scale_factor: torch.Tensor, + ) -> None: + return None + + +if hasattr(torch.ops._C, "cutlass_mxfp8_grouped_mm"): + + @register_fake("_C::cutlass_mxfp8_grouped_mm") + def _cutlass_mxfp8_grouped_mm_fake( + a_tensors: torch.Tensor, + b_tensors: torch.Tensor, + a_scales: torch.Tensor, + b_scales: torch.Tensor, + out_tensors: torch.Tensor, + problem_sizes: torch.Tensor, + expert_offsets: torch.Tensor, + blockscale_offsets: torch.Tensor, + ) -> None: + return None + + # gptq_marlin def gptq_marlin_repack( b_q_weight: torch.Tensor, @@ -2123,7 +2056,7 @@ def scaled_int8_quant( assert symmetric == (azp is None), ( "azp must only be provided for asymmetric quantization." ) - torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) + ops.static_scaled_int8_quant(output, input, scale) return output, scale, azp # dynamic-per-token quantization. 
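On the static path above, ops.static_scaled_int8_quant is called without the azp argument, so only the symmetric case reaches the ixformer kernel. As a rough mental model for review, here is a minimal pure-PyTorch sketch of symmetric per-tensor int8 quantization; the round-and-saturate convention is an assumption about the kernel, not something stated in this patch, and the helper name is hypothetical.

import torch

def static_int8_quant_reference(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Assumed convention: round to nearest, then saturate to the int8 range.
    # Intended only as a parity check against ops.static_scaled_int8_quant.
    q = torch.round(x.float() / scale.float())
    return q.clamp_(-128, 127).to(torch.int8)

A quick device-side check would compare this reference against the output tensor filled by the call above for a random activation and a fixed scale.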
@@ -2131,9 +2064,6 @@ def scaled_int8_quant( (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 ) input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - # torch.ops._C.dynamic_scaled_int8_quant( - # output, input.contiguous(), input_scales, input_azp - # ) ops.dynamic_scaled_int8_quant(output, input, input_scales) return output, input_scales, input_azp @@ -2153,7 +2083,6 @@ def ggml_mul_mat_vec_a8( ) -> torch.Tensor: return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) - def ggml_mul_mat_a8( W: torch.Tensor, X: torch.Tensor, @@ -2202,7 +2131,6 @@ def ggml_moe_a8_vec( def ggml_moe_get_block_size(quant_type: int) -> int: return torch.ops._C.ggml_moe_get_block_size(quant_type) - # mamba def selective_scan_fwd( u: torch.Tensor, @@ -2223,6 +2151,8 @@ def selective_scan_fwd( block_idx_first_scheduled_token: torch.Tensor | None = None, block_idx_last_scheduled_token: torch.Tensor | None = None, initial_state_idx: torch.Tensor | None = None, + cu_chunk_seqlen: torch.Tensor | None = None, + last_chunk_indices: torch.Tensor | None = None, ): torch.ops._C.selective_scan_fwd( u, @@ -2243,6 +2173,8 @@ def selective_scan_fwd( block_idx_first_scheduled_token, block_idx_last_scheduled_token, initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, ) @@ -2291,23 +2223,9 @@ def moe_align_block_size( num_tokens_post_pad: torch.Tensor, expert_map: torch.Tensor | None = None, ) -> None: - # torch.ops._moe_C.moe_align_block_size( - # topk_ids, - # num_experts, - # block_size, - # sorted_token_ids, - # experts_ids, - # num_tokens_post_pad, - # expert_map, - # ) - ops.vllm_moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_token_ids, - experts_ids, - num_tokens_post_pad, - ) + ops.vllm_moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) def batched_moe_align_block_size( @@ -2398,6 +2316,23 @@ def moe_wna16_gemm( ) +def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """bf16 x bf16 -> fp32 GEMM via cuBLAS. 
weight shape: (N, K).""" + return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight) + + +if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"): + + @register_fake("_moe_C::router_gemm_bf16_fp32") + def router_gemm_bf16_fp32_fake( + input: torch.Tensor, + weight: torch.Tensor, + ) -> torch.Tensor: + return torch.empty( + input.shape[0], weight.shape[0], dtype=torch.float32, device=input.device + ) + + def dsv3_router_gemm( hidden_states: torch.Tensor, router_weight: torch.Tensor, @@ -2421,53 +2356,15 @@ def topk_softmax( renormalize: bool = False, e_score_correction_bias: torch.Tensor | None = None, ) -> None: - # torch.ops._moe_C.topk_softmax( - # topk_weights, - # topk_ids, - # token_expert_indices, - # gating_output, - # renormalize, - # e_score_correction_bias, - # ) - ops.vllm_moe_topk_softmax( - topk_weights, topk_ids, token_expert_indices, gating_output - ) - - if renormalize: - topk_weights[:] = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - - -def topk_sigmoid_torch( - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool = False, - e_score_correction_bias: torch.Tensor | None = None, -): - batch, num_experts = gating_output.shape - k = topk_weights.shape[1] - - # Sigmoid + bias - probs = torch.sigmoid(gating_output) - if e_score_correction_bias is not None: - probs = probs + e_score_correction_bias.unsqueeze(0) - - # Top-K - topk_vals, topk_idx = torch.topk(probs, k=k, dim=1) - - # 写入结果 - topk_weights[:] = topk_vals - topk_ids[:] = topk_idx - token_expert_indices[:] = (torch.arange(k, device=gating_output.device).unsqueeze(0) * batch - + torch.arange(batch, device=gating_output.device).unsqueeze(1)) - - # renormalize - if renormalize: - denom = topk_weights.sum(dim=1, keepdim=True) - denom = torch.where(denom > 0, denom, torch.ones_like(denom)) - topk_weights[:] = topk_weights / denom + ops.moe_fused_topk_bias( + gating_output=gating_output, + topk=topk_weights.shape[-1], + scoring_func="softmax", + routed_scaling_factor=1, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + topk_weight=topk_weights, + topk_ids=topk_ids) def topk_sigmoid( @@ -2478,15 +2375,15 @@ def topk_sigmoid( renormalize: bool = False, e_score_correction_bias: torch.Tensor | None = None, ) -> None: - # torch.ops._moe_C.topk_sigmoid( - # topk_weights, - # topk_ids, - # token_expert_indices, - # gating_output, - # renormalize, - # e_score_correction_bias, - # ) - topk_sigmoid_torch(topk_weights, topk_ids, token_expert_indices, gating_output, renormalize, e_score_correction_bias) + ops.moe_fused_topk_bias( + gating_output=gating_output, + topk=topk_weights.shape[-1], + scoring_func="sigmoid", + routed_scaling_factor=1, + renormalize=renormalize, + e_score_correction_bias=e_score_correction_bias, + topk_weight=topk_weights, + topk_ids=topk_ids) def grouped_topk( @@ -2685,26 +2582,24 @@ def reshape_and_cache_flash( k_scale: torch.Tensor, v_scale: torch.Tensor, ) -> None: - # torch.ops._C_cache_ops.reshape_and_cache_flash( - # key, - # value, - # key_cache, - # value_cache, - # slot_mapping, - # kv_cache_dtype, - # k_scale, - # v_scale, - # ) - ops.reshape_and_cache_flash( - key, - value, - key_cache, - value_cache, - slot_mapping, - kv_cache_dtype, - 1.0, - 1.0, - ) + ops.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, 1.0, 1.0) + + +def reshape_and_cache_flash_mix( + key: torch.Tensor, + value: 
torch.Tensor, + key_scale: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + key_scale_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, +): + ops.reshape_and_cache_flash_mix(key, value, key_scale, + key_cache, value_cache, key_scale_cache, + slot_mapping, kv_cache_dtype) def concat_and_cache_mla( @@ -2715,37 +2610,11 @@ def concat_and_cache_mla( kv_cache_dtype: str, scale: torch.Tensor, ) -> None: - # torch.ops._C_cache_ops.concat_and_cache_mla( - # kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale - # ) - ops.vllm_concat_and_cache_mla( - kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale - ) + ops.vllm_concat_and_cache_mla(kv_c, k_pe, kv_cache, + slot_mapping, kv_cache_dtype, + scale) -def concat_and_cache_mla_int8( - kv_c_int8: torch.Tensor, - kv_c_scale: torch.Tensor, - k_pe_int8: torch.Tensor, - k_pe_scale: torch.Tensor, - kv_cache: torch.Tensor, - kv_cache_scale: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - scale: torch.Tensor, -) -> None: - ops.vllm_concat_and_cache_mla_int8( - kv_c_int8, - kv_c_scale, - k_pe_int8, - k_pe_scale, - kv_cache, - kv_cache_scale, - slot_mapping, - kv_cache_dtype, - scale, - ) - def concat_and_cache_mla_rope_fused( positions: torch.Tensor, @@ -2799,7 +2668,6 @@ def swap_blocks( but not both on cpu. the block mapping tensor must on cpu. """ - # torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping) ops.vllm_swap_blocks(src, dst, block_mapping) @@ -2841,9 +2709,6 @@ def cp_gather_cache( batch_size: int, seq_starts: torch.Tensor | None = None, ) -> None: - # torch.ops._C_cache_ops.cp_gather_cache( - # src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts - # ) ops.vllm_cp_gather_cache( src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts ) @@ -2882,19 +2747,309 @@ def indexer_k_quant_and_cache( torch.ops._C_cache_ops.indexer_k_quant_and_cache( k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype ) + +def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,slot_mapping: torch.Tensor)-> None: + ops.indexer_k_cache(k, kv_cache, slot_mapping) +def dsa_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch.Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, num_topk_tokens] + block_size: int, + output: torch.Tensor = None, +) -> torch.Tensor: + """ + Convert request-local token indices to global KV cache indices. 
-def cp_gather_indexer_k_quant_cache(
-    kv_cache: torch.Tensor,
-    dst_k: torch.Tensor,
-    dst_scale: torch.Tensor,
-    block_table: torch.Tensor,
-    cu_seq_lens: torch.Tensor,
-) -> None:
-    torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache(
-        kv_cache, dst_k, dst_scale, block_table, cu_seq_lens
+    For each (token_id, indice_id):
+        tok = token_indices[token_id, indice_id]
+        if tok < 0 or tok // block_size >= max_num_blocks_per_req:
+            output[token_id, indice_id] = -1
+        else:
+            req = req_id[token_id]
+            block_idx = block_table[req, tok // block_size]
+            output[token_id, indice_id] = block_idx * block_size + tok % block_size
+    """
+    return ops.dsa_convert_req_index_to_global_index(
+        req_id, block_table, token_indices, block_size, output
     )
+
+def ref_mqa_logits(
+    q: torch.Tensor,  # [num_tokens, n_head, head_dim] - possibly already quantized
+    k: torch.Tensor,  # [num_blocks, block_size, head_dim] or flattened form - possibly already quantized
+    weights: torch.Tensor,  # [num_tokens, n_head, 1] - weights
+    cu_seqlen_ks: torch.Tensor,  # sequence start offsets
+    cu_seqlen_ke: torch.Tensor,  # sequence end offsets
+) -> torch.Tensor:
+    """
+    PyTorch-equivalent implementation of the multi-query attention logits computation.
+    """
+    M, H, D = q.shape
+    N = k.shape[0]
+    device = q.device
+    # Initialize output logits [M, N]
+    logits = torch.full((M, N), -float('inf'), device=device, dtype=torch.float32)
+    for i in range(M):
+        seq_start = cu_seqlen_ks[i]
+        seq_end = cu_seqlen_ke[i]
+
+        if seq_start >= seq_end:
+            continue
+
+        # Q for the current query [H, D]
+        q_i = q[i]  # [H, D]
+
+        seq_k = k[seq_start:seq_end]  # [seq_len, head_dim]
+
+        # Compute attention scores [H, seq_len]
+        attention_scores = torch.matmul(q_i, seq_k.T)  # computed in BF16
+        attention_scores = F.relu(attention_scores)
+
+        # Apply weights [H, seq_len]
+        attention_scores_f32 = attention_scores.float()
+        weights_i = weights[i].unsqueeze(1)  # [H, 1]
+        weighted_scores = attention_scores_f32 * weights_i  # [H, seq_len]
+
+        # Sum logits over all heads [seq_len]
+        logits_i = torch.sum(weighted_scores, dim=0)  # [seq_len]
+
+        # Write the result into the corresponding positions of the output logits
+        logits[i, seq_start:seq_end] = logits_i
+
+    return logits
+
+def ref_paged_mqa_logits(
+    q: torch.Tensor,
+    kv_cache: torch.Tensor,
+    weights: torch.Tensor,
+    context_lens: torch.Tensor,
+    block_tables: torch.Tensor,
+    max_model_len: int,
+    clean_logits: bool = True
+) -> torch.Tensor:
+    """PyTorch implementation of FP8 multi-query attention logits over a paged KV cache.
+
+    Args:
+        q: query tensor [B, next_n, H, D]
+        kv_cache: paged KV cache [num_blocks, block_size, 1, D]
+        weights: weight tensor [B * next_n, H], dtype=torch.float32
+        context_lens: context lengths [B], dtype=int32
+        block_tables: block mapping table [B, max_blocks], dtype=int32
+        schedule_metadata: scheduling metadata (unused by this reference)
+        max_model_len: maximum sequence length, used to size the output logits
+
+    Returns:
+        Logits tensor [B * next_n, max_model_len], dtype=torch.float32
+    """
+    def reassemble_k_from_paged_cache(
+        kv_cache: torch.Tensor,
+        block_table: torch.Tensor,
+        context_len: int,
+        head_dim: int,
+        block_size: int
+    ) -> torch.Tensor:
+        """Reassemble the K values from the paged cache."""
+        num_blocks_needed = (context_len + block_size - 1) // block_size
+        valid_blocks = block_table[:num_blocks_needed]
+        device = kv_cache.device
+        # Initialize the output K sequence [context_len, head_dim]
+        k_sequence = torch.zeros(context_len, head_dim, device=device, dtype=kv_cache.dtype)
+        token_offset = 0
+        for block_idx in valid_blocks:
+            if block_idx < 0:
+                break
+            # Number of tokens in the current block
+            tokens_in_block = min(block_size, context_len - token_offset)
+            if tokens_in_block <= 0:
+                break
+            # Read the current cache block
+            block_data = kv_cache[block_idx]  # [block_size, 1, D]
+
+            # Extract the K values
+            k_sequence[token_offset:token_offset + tokens_in_block] = block_data[:tokens_in_block, 0, :head_dim]  # [tokens_in_block, D]
+            token_offset += tokens_in_block
+
+        return k_sequence
+
+    def compute_mqa_logits(
+        q: torch.Tensor,            # [next_n, H, D]
+        k: torch.Tensor,            # [context_len, D]
+        weights: torch.Tensor,      # [next_n, H]
+        context_len: int,
+        max_model_len: int
+    ) -> torch.Tensor:
+        """Compute multi-query attention logits."""
+        next_n, H, D = q.shape
+        device = q.device
+
+        # Initialize per-batch logits [next_n, max_model_len]
+        batch_logits = torch.full((next_n, max_model_len), -float('inf'),
+                                  device=device, dtype=torch.float32)
+
+        # Expand K to match the number of heads [context_len, H, D]
+        k_expanded = k.unsqueeze(1).expand(-1, H, -1)  # [context_len, H, D]
+
+        # Transpose for matrix multiplication
+        q_transposed = q.transpose(0, 1)  # [H, next_n, D]
+        k_transposed = k_expanded.transpose(0, 1)  # [H, context_len, D]
+        # Batched attention scores [H, next_n, context_len]
+        attention_scores = torch.bmm(q_transposed, k_transposed.transpose(1, 2))  # [H, next_n, context_len]
+        attention_scores = F.relu(attention_scores)
+        # Apply weights and sum over all heads [next_n, context_len]
+        weights_expanded = weights.transpose(0, 1).unsqueeze(2)  # [H, next_n, 1]
+        weighted_scores = attention_scores * weights_expanded  # [H, next_n, context_len]
+        logits_per_token = weighted_scores.sum(dim=0)  # [next_n, context_len]
+
+        # Fill into the output logits
+        batch_logits[:, :context_len] = logits_per_token
+
+        return batch_logits
+    def clean_logits_tensor(
+        logits: torch.Tensor,
+        context_lens: torch.Tensor,
+        next_n: int,
+        max_model_len: int
+    ) -> torch.Tensor:
+        """Clean the logits tensor: set positions beyond the context length to -inf."""
+        B = len(context_lens)
+
+        for batch_idx in range(B):
+            context_len = context_lens[batch_idx].item()
+            if context_len >= max_model_len:
+                continue
+
+            # Position of this batch inside the logits
+            batch_start = batch_idx * next_n
+            batch_end = (batch_idx + 1) * next_n
+
+            # Set positions beyond the context length to -inf
+            logits[batch_start:batch_end, context_len:] = -float('inf')
+
+        return logits
+
+    B, next_n, H, D = q.shape
+    num_blocks, block_size, _, cache_stride = kv_cache.shape
+    device = q.device
+
+    # Initialize output logits [B * next_n, max_model_len]
+    logits = torch.full((B * next_n, max_model_len), -float('inf'),
+                        device=device, dtype=torch.float32)
+
+    # Process each batch entry
+    for batch_idx in range(B):
+        context_len = context_lens[batch_idx].item()
+        if context_len == 0:
+            continue
+
+        # Queries for the current batch entry [next_n, H, D]
+        batch_q = q[batch_idx]  # [next_n, H, D]
+
+        # Weights for the current batch entry [next_n, H]
+        batch_weights_start = batch_idx * next_n
+        batch_weights_end = (batch_idx + 1) * next_n
+        batch_weights = weights[batch_weights_start:batch_weights_end]  # [next_n, H]
+
+        # Reassemble K from the paged cache
+        batch_k = reassemble_k_from_paged_cache(
+            kv_cache, block_tables[batch_idx], context_len, D, block_size
+        )  # [context_len, D]
+        # Compute multi-query attention logits
+        batch_logits = compute_mqa_logits(
+            batch_q, batch_k, batch_weights, context_len, max_model_len
+        )  # [next_n, max_model_len]
+
+        # Fill into the output logits
+        logits[batch_weights_start:batch_weights_end] = batch_logits
+
+    if clean_logits:
+        # Clean logits: set positions beyond the context length to -inf
+        logits = clean_logits_tensor(logits, context_lens, next_n, max_model_len)
+
+    return logits
+
+def sparse_prefill_fwd(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    sm_scale: float,
+    d_v: int = 512,
+):
+    return ops.sparse_flash_attn(q, kv, indices, sm_scale, d_v)
+
+def ref_sparse_prefill_fwd(
+    q: torch.Tensor,
+    kv: torch.Tensor,
+    indices: torch.Tensor,
+    sm_scale: float,
+    d_v: int = 512,
+):
+    """
+    PyTorch reference implementation of the sparse-attention prefill kernel.
+
+    Args:
+    - q: [s_q, h_q, d_qk], bfloat16
+    - kv: [s_kv, h_kv, d_qk], bfloat16
+    - indices: [s_q, h_kv, topk], int32. Invalid indices are set to -1 or >= s_kv
+    - sm_scale: float
+    - d_v: dimension of the value vectors; must be 512
+
+    Returns:
+    - (output, max_logits, lse) for the kernel; this PyTorch reference returns only output
+    - output: [s_q, h_q, d_v], bfloat16
+    - max_logits: [s_q, h_q], float
+    - lse: [s_q, h_q], float, log-sum-exp in base 2
+    """
+    def ref_masked_attention(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        sm_scale: float,
+    ) -> torch.Tensor:
+        query = query * sm_scale
+        dtype = query.dtype
+        device = query.device
+        attn = torch.einsum("qhd,khd->hqk", query, key)
+        attn = attn.to(torch.float)
+        attn = torch.softmax(attn, dim=-1)
+        value = value.to(torch.float)
+        out = torch.einsum("hqk,khd->qhd", attn, value)
+        out = out.to(device).to(dtype)
+        return out
+    s_q, h_q, d_qk = q.shape
+    s_kv, h_kv, _ = kv.shape
+    _, _, topk = indices.shape
+
+    device = q.device
+    dtype = q.dtype
+
+    # Split K and V
+    k = kv  # [s_kv, h_kv, d_qk]
+    v = kv[:, :, :d_v]  # [s_kv, h_kv, d_v]
+
+    # Initialize the output
+    output = torch.zeros(s_q, h_q, d_v, device=device, dtype=dtype)
+    # Process each query position
+    for i in range(s_q):
+        # Current query [h_q, d_qk]
+        q_i = q[i].unsqueeze(0)  # [1, h_q, d_qk]
+        # Sparse indices for the current query position [topk]
+        sparse_indices = indices[i, 0]  # [topk]
+        # Keep only valid indices (>= 0 and < s_kv)
+        valid_mask = (sparse_indices >= 0) & (sparse_indices < s_kv)
+        valid_indices = sparse_indices[valid_mask]
+        # Gather the valid K and V
+        valid_k = k[valid_indices].repeat(1, h_q, 1)  # [valid_len, h_q, d_qk]
+        valid_v = v[valid_indices].repeat(1, h_q, 1)  # [valid_len, h_q, d_v]
+        out = ref_masked_attention(
+            q_i,
+            valid_k,
+            valid_v,
+            sm_scale
+        )
+        out = out.view(h_q, d_v)
+        output[i].copy_(out, non_blocking=True)
+    return output
 
 def get_device_attribute(attribute: int, device: int) -> int:
     return torch.ops._C_cuda_utils.get_device_attribute(attribute, device)
@@ -2902,10 +3057,9 @@ def get_device_attribute(attribute: int, device: int) -> int:
 
 def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
     # ruff: noqa: E501
-    # return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
-    #     device
-    # )
-    return 32 * 1024
+    return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute(
+        device
+    )
 
 
 # custom ar
@@ -2951,7 +3105,6 @@ def register_graph_buffers(
 ) -> None:
     torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
 
-
 def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
     return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
 
@@ -3014,7 +3167,6 @@ def get_flash_mla_metadata(
         cache_seqlens, num_heads_per_head_k, num_heads_k
     )
 
-
 def flash_mla_with_kvcache(
     q: torch.Tensor,
    k_cache: torch.Tensor,
@@ -3164,211 +3316,6 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
            return torch.empty((M, N), dtype=out_dtype)
 
 
-# Add our new features here..
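ref_sparse_prefill_fwd above gives a pure-PyTorch baseline for the sparse prefill path. A minimal sketch of how it might be exercised follows; the shapes are illustrative only, h_kv = 1 matches how the reference reads indices[i, 0], and the import path plus the availability of the corex/ixformer build are assumptions rather than facts from this patch.

import torch
from vllm import _custom_ops as custom_ops  # assumes this overlay is installed

s_q, s_kv, h_q, h_kv, d_qk, d_v, topk = 4, 128, 16, 1, 576, 512, 32
q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16)
kv = torch.randn(s_kv, h_kv, d_qk, dtype=torch.bfloat16)
# Per the documented contract, invalid entries are -1 (or >= s_kv) and get masked out.
indices = torch.randint(-1, s_kv, (s_q, h_kv, topk), dtype=torch.int32)

ref_out = custom_ops.ref_sparse_prefill_fwd(q, kv, indices, sm_scale=d_qk ** -0.5, d_v=d_v)
# On supported hardware, sparse_prefill_fwd (ops.sparse_flash_attn) could be compared
# against ref_out up to bfloat16 tolerance; the reference returns only the attention output.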
- - -# moe -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: torch.Tensor | None, - B_scale: torch.Tensor | None, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: dict[str, "any"], - compute_type, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - block_shape: list[int] | None = None, - bias: torch.Tensor | None = None, -) -> None: - ops.vllm_invoke_fused_moe_kernel( - A, - B, - C, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - mul_routed_weight, - top_k, - config["BLOCK_SIZE_M"], - bias=bias - ) - -# broadcast -class Async_helper: - # For now, the comm and the other kernels are in the same stream, so we can remove the stream wait.. - def wait( - self, - ): - return True - - -def broadcast(tensor, src=0, group=None, async_op=False): - cdist.broadcast(tensor, src, group, async_op=True) - if async_op: - return Async_helper() - else: - pass - - -# w8a16 -def linear_w8a16( - x: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - group_size: int = -1, - format: str = "TN", -) -> torch.Tensor: - return ops.w8a16(x, qweight, scales, format="TN", group_size=group_size) - - -## lora sgmv / bgmv -def sbgmv_expand( - x: torch.Tensor, - w_t_all: torch.Tensor, - y: torch.Tensor, - b_seq_start_loc: torch.Tensor = None, - seq_len_tensor: torch.Tensor = None, - lora_indices_tensor: torch.Tensor = None, - batches: int = -1, - max_seq_length: int = -1, - token_nums: int = -1, - add_input=True, -): - """ - x: inputs - w_t_all: lora weight - y: output - - y += x@wt_t_all - """ - assert x.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert w_t_all.dtype in [ - torch.float16, - torch.bfloat16, - ] - - assert x.is_contiguous() - # assert y.is_contiguous() - if x.dtype == torch.float: - x = x.to(w_t_all.dtype) - - if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) - assert w_t_all.size(1) == 1 - w_t_all = w_t_all.squeeze(dim=1) - else: - assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) - assert w_t_all.is_contiguous() - - assert add_input == True - - lora_indices = lora_indices_tensor.cpu().tolist() - lora_num = w_t_all.shape[0] - - ## 单一lora model, 且所有request均使用lora - if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): - if lora_indices[0] != -1: - w_t = w_t_all[0] - y += torch.matmul(x, w_t.t()) - ## 多个lora model - else: - ## prefill - if batches != -1: - for i, lora_id, start, seq_len in zip( - range(batches), lora_indices, b_seq_start_loc, seq_len_tensor - ): - if lora_id != -1: - xi = x[start : start + seq_len] - w_t = w_t_all[lora_id] - y[start : start + seq_len] += xi @ w_t.t() - ## decode - else: - batches = x.shape[0] - for i, lora_id in zip(range(batches), lora_indices): - if lora_id != -1: - xi = x[i].unsqueeze(0) - w_t = w_t_all[lora_id] - y[i] += (xi @ w_t.t()).squeeze(0) - - return y - - -def sbgmv_shrink( - x: torch.Tensor, - w_t_all: torch.Tensor, - y: torch.Tensor, - b_seq_start_loc: torch.Tensor = None, - seq_len_tensor: torch.Tensor = None, - lora_indices_tensor: torch.Tensor = None, - batches: int = -1, - max_seq_length: int = -1, - token_nums: int = -1, - scale: float = 1.0, -): - """ - xx: inputs - w_t_all: lora weight - y: output - scale: float - - y = x@w_t_all * scale - """ - assert x.dtype == w_t_all.dtype - assert x.dtype in [torch.float16, torch.bfloat16] - assert x.is_contiguous() - assert 
y.is_contiguous() - - if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) - assert w_t_all.size(1) == 1 - w_t_all = w_t_all.squeeze(dim=1) - else: - assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) - assert w_t_all.is_contiguous() - - lora_num = w_t_all.shape[0] - lora_indices = lora_indices_tensor.cpu().tolist() - - ## 单一lora model, 且所有request均使用lora - if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): - if lora_indices[0] != -1: - w_t = w_t_all[0] - y = torch.matmul(x, w_t.t()) * scale - ## 多个lora model - else: - ## prefill - if batches != -1: - for i, lora_id, start, seq_len in zip( - range(batches), lora_indices, b_seq_start_loc, seq_len_tensor - ): - if lora_id != -1: - xi = x[start : start + seq_len] - w_t = w_t_all[lora_id] - y[start : start + seq_len] = (xi @ w_t.t()) * scale - ## decode - else: - batches = x.shape[0] - for i, lora_id in zip(range(batches), lora_indices): - if lora_id != -1: - xi = x[i].unsqueeze(0) - w_t = w_t_all[lora_id] - y[i] = (xi @ w_t.t()).squeeze(0) * scale - - return y - - -def dynamic_scaled_quant_dynamic_int8(x, input_scales=None, int8_out=None, scales=None): - return ops.dynamic_scaled_quant_smoothquant(x, input_scales, int8_out, scales) - - class CPUDNNLGEMMHandler: def __init__(self) -> None: self.handler_tensor: torch.Tensor | None = None @@ -3799,50 +3746,246 @@ if hasattr(torch.ops._C, "hadacore_transform"): @register_fake("_C::hadacore_transform") def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor: return torch.empty_like(x) if not inplace else x - - +# Add our new features here.. def gather_cache( - src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] - dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] - block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] - cu_seq_lens: torch.Tensor, # [BATCH+1] + src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] + block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] + cu_seq_lens: torch.Tensor, # [BATCH+1] batch_size: int, - seq_starts: torch.Tensor = None, + seq_starts: torch.Tensor = None ): - ops.vllm_gather_cache( - src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts - ) - - + ops.vllm_gather_cache(src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts) + def gather_cache_int8( - src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] - src_cache_scale: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, 2] + src_cache: torch.Tensor, # [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...] + src_cache_scale: torch.Tensor,# [NUM_BLOCKS, BLOCK_SIZE, 2] kv_lora_rank: int, - dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] - block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] - cu_seq_lens: torch.Tensor, # [BATCH+1] + dst: torch.Tensor, # [TOT_TOKENS, ENTRIES...] 
+ block_table: torch.Tensor, # [BATCH, BLOCK_INDICES] + cu_seq_lens: torch.Tensor, # [BATCH+1] batch_size: int, - seq_starts: torch.Tensor = None, + seq_starts: torch.Tensor = None ): - ops.vllm_gather_cache_int8( - src_cache, - src_cache_scale, - kv_lora_rank, - dst, - block_table, - cu_seq_lens, - batch_size, - seq_starts, + ops.vllm_gather_cache_int8(src_cache,src_cache_scale, kv_lora_rank, dst, block_table, cu_seq_lens, batch_size, seq_starts) + +def quant_kv(kv): + amax_, _ = torch.max(torch.abs(kv), dim=-1, keepdim=True) + f_scale = amax_.float() / 127.0 + scales = f_scale.view(kv.shape[:-1]) + + # 量化 + kv = kv / f_scale + kv = torch.clamp(torch.round(kv), -127, 127).to(torch.int8) + return kv, scales + + +def concat_and_cache_mla_int8( + kv_c_int8: torch.Tensor, + kv_c_scale: torch.Tensor, + k_pe_int8: torch.Tensor, + k_pe_scale: torch.Tensor, + kv_cache: torch.Tensor, + kv_cache_scale: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + scale: torch.Tensor, +) -> None: + ops.vllm_concat_and_cache_mla_int8(kv_c_int8,kv_c_scale, k_pe_int8, k_pe_scale, kv_cache, kv_cache_scale, slot_mapping, kv_cache_dtype, scale) +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_shape: Optional[List[int]] = None, + bias: Optional[torch.Tensor] = None, +) -> None: + ops.vllm_invoke_fused_moe_kernel( + A, + B, + C, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + mul_routed_weight, + top_k, + config['BLOCK_SIZE_M'], + bias=bias ) +# broadcast +class Async_helper(): + # For now, the comm and the other kernels are in the same stream, so we can remove the stream wait.. 
+ def wait(self,): + return True + + +def broadcast(tensor, src=0, group=None, async_op=False): + cdist.broadcast(tensor,src,group,async_op=True) + if async_op: + return Async_helper() + else: + pass + + +# w8a16 +def linear_w8a16(x: torch.Tensor, qweight: torch.Tensor, scales:torch.Tensor, + group_size: int = -1, format: str = "TN")-> torch.Tensor: + return ops.w8a16(x, qweight, scales, format="TN", group_size=group_size) + + +## lora sgmv / bgmv +def sbgmv_expand(x: torch.Tensor, + w_t_all: torch.Tensor, + y: torch.Tensor, + b_seq_start_loc: torch.Tensor = None, + seq_len_tensor: torch.Tensor = None, + lora_indices_tensor: torch.Tensor = None, + batches: int = -1, + max_seq_length: int = -1, + token_nums: int = -1, + add_input=True, + ): + ''' + x: inputs + w_t_all: lora weight + y: output + + y += x@wt_t_all + ''' + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert w_t_all.dtype in [ + torch.float16, + torch.bfloat16, + ] + + assert x.is_contiguous() + # assert y.is_contiguous() + if x.dtype == torch.float: + x = x.to(w_t_all.dtype) + + if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) + assert w_t_all.size(1) == 1 + w_t_all = w_t_all.squeeze(dim=1) + else: + assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) + assert w_t_all.is_contiguous() + + assert add_input == True + + lora_indices = lora_indices_tensor.cpu().tolist() + lora_num = w_t_all.shape[0] + + ## 单一lora model, 且所有request均使用lora + if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): + if lora_indices[0] != -1: + w_t = w_t_all[0] + y += torch.matmul(x, w_t.t()) + ## 多个lora model + else: + ## prefill + if batches != -1: + for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor): + if lora_id != -1: + xi = x[start: start+seq_len] + w_t = w_t_all[lora_id] + y[start:start+seq_len] += (xi @ w_t.t()) + ## decode + else: + batches = x.shape[0] + for i, lora_id in zip(range(batches), lora_indices): + if lora_id != -1: + xi = x[i].unsqueeze(0) + w_t = w_t_all[lora_id] + y[i] += (xi @ w_t.t()).squeeze(0) + + return y + + +def sbgmv_shrink(x: torch.Tensor, + w_t_all: torch.Tensor, + y: torch.Tensor, + b_seq_start_loc: torch.Tensor = None, + seq_len_tensor: torch.Tensor = None, + lora_indices_tensor: torch.Tensor = None, + batches: int = -1, + max_seq_length: int = -1, + token_nums: int = -1, + scale: float = 1.0,): + """ + xx: inputs + w_t_all: lora weight + y: output + scale: float + + y = x@w_t_all * scale + """ + assert x.dtype == w_t_all.dtype + assert x.dtype in [torch.float16, torch.bfloat16] + assert x.is_contiguous() + assert y.is_contiguous() + + if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) + assert w_t_all.size(1) == 1 + w_t_all = w_t_all.squeeze(dim=1) + else: + assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) + assert w_t_all.is_contiguous() + + lora_num = w_t_all.shape[0] + lora_indices = lora_indices_tensor.cpu().tolist() + + ## 单一lora model, 且所有request均使用lora + if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): + if lora_indices[0] != -1: + w_t = w_t_all[0] + y = torch.matmul(x, w_t.t()) * scale + ## 多个lora model + else: + ## prefill + if batches != -1: + for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor): + if lora_id != -1: + xi = x[start: start+seq_len] + w_t = w_t_all[lora_id] + y[start:start+seq_len] = (xi @ w_t.t())* scale + ## decode + else: + batches = x.shape[0] + for i, lora_id in zip(range(batches), lora_indices): + if lora_id != -1: + xi = 
x[i].unsqueeze(0) + w_t = w_t_all[lora_id] + y[i] = (xi @ w_t.t()).squeeze(0) * scale + + return y + +def dynamic_scaled_quant_dynamic_int8(x, input_scales=None, int8_out=None, scales=None): + return ops.dynamic_scaled_quant_smoothquant(x, input_scales, int8_out, scales) + + def rejection_greedy_sample_torch( - output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1] - cu_num_draft_tokens: torch.Tensor, # [batch_size] (前缀和形式) - draft_token_ids: torch.Tensor, # [num_tokens] - target_argmax: torch.Tensor, # [num_tokens] - bonus_token_ids: torch.Tensor, # [batch_size] - is_greedy: torch.Tensor = None, # [batch_size] 或 None + output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens: torch.Tensor, # [batch_size] (前缀和形式) + draft_token_ids: torch.Tensor, # [num_tokens] + target_argmax: torch.Tensor, # [num_tokens] + bonus_token_ids: torch.Tensor, # [batch_size] + is_greedy: torch.Tensor = None, # [batch_size] 或 None ): """ 完全等价于 rejection_greedy_sample_kernel 的 PyTorch 实现 @@ -3886,18 +4029,17 @@ def rejection_greedy_sample_torch( return output_token_ids # 原位修改 - def rejection_random_sample_torch( - output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1] - cu_num_draft_tokens: torch.Tensor, # [batch_size] (前缀和形式) - draft_token_ids: torch.Tensor, # [num_tokens] - draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] 或 None - target_probs: torch.Tensor, # [num_tokens, vocab_size] - bonus_token_ids: torch.Tensor, # [batch_size] - recovered_token_ids: torch.Tensor, # [num_tokens] - uniform_probs: torch.Tensor, # [num_tokens] (0~1均匀分布) - is_greedy: torch.Tensor | None, # [batch_size] 或 None - NO_DRAFT_PROBS: bool = False, # 是否忽略draft_probs + output_token_ids: torch.Tensor, # [batch_size, max_spec_len + 1] + cu_num_draft_tokens: torch.Tensor, # [batch_size] (前缀和形式) + draft_token_ids: torch.Tensor, # [num_tokens] + draft_probs: torch.Tensor | None, # [num_tokens, vocab_size] 或 None + target_probs: torch.Tensor, # [num_tokens, vocab_size] + bonus_token_ids: torch.Tensor, # [batch_size] + recovered_token_ids: torch.Tensor, # [num_tokens] + uniform_probs: torch.Tensor, # [num_tokens] (0~1均匀分布) + is_greedy: torch.Tensor | None, # [batch_size] 或 None + NO_DRAFT_PROBS: bool = False, # 是否忽略draft_probs ): batch_size = output_token_ids.size(0) max_spec_len_plus_1 = output_token_ids.size(1) @@ -3928,9 +4070,7 @@ def rejection_random_sample_torch( if NO_DRAFT_PROBS: draft_prob = 1.0 else: - assert ( - draft_probs is not None - ), "draft_probs不能为None当NO_DRAFT_PROBS=False" + assert draft_probs is not None, "draft_probs不能为None当NO_DRAFT_PROBS=False" draft_prob = draft_probs[global_pos, draft_token_id] # 获取target概率和均匀随机数 @@ -3952,287 +4092,4 @@ def rejection_random_sample_torch( return output_token_ids - weak_ref_tensor = ops.weak_ref_tensor - - -def indexer_k_cache(k: torch.Tensor, kv_cache: torch.Tensor,slot_mapping: torch.Tensor)-> None: - num_tokens, head_dim = k.shape - _, block_size, cache_stride = kv_cache.shape - assert head_dim == cache_stride - for i in range(num_tokens): - block_idx = torch.div(slot_mapping[i], block_size, rounding_mode="floor") - block_offset = slot_mapping[i] % block_size - kv_cache[block_idx, block_offset, :] = k[i] - - -def ref_mqa_logits( - q: torch.Tensor, # [num_tokens, n_head, head_dim] - 可能已量化 - k: torch.Tensor, # [num_blocks, block_size, head_dim] 或展开形式 - 可能已量化 - weights: torch.Tensor, # [num_tokens, n_head, 1] - 权重 - cu_seqlen_ks: torch.Tensor, # 序列起始位置 - cu_seqlen_ke: torch.Tensor, # 序列结束位置 -) -> torch.Tensor: - """ - 
多查询注意力logits计算的PyTorch等价实现 - """ - - M, H, D = q.shape - N = k.shape[0] - device = q.device - # 初始化输出logits [M, N] - logits = torch.full((M, N), -float('inf'), device=device, dtype=torch.float32) - for i in range(M): - seq_start = cu_seqlen_ks[i] - seq_end = cu_seqlen_ke[i] - - if seq_start >= seq_end: - continue - - #当前查询的Q [H, D] - q_i = q[i] # [H, D] - - seq_k = k[seq_start:seq_end] # [seq_len, head_dim] - - # 计算注意力分数 [H, seq_len] - attention_scores = torch.matmul(q_i, seq_k.T) # BF16计算 - attention_scores = F.relu(attention_scores) - - # 应用权重 [H, seq_len] - attention_scores_f32 = attention_scores.float() - weights_i = weights[i].unsqueeze(1) # [H, 1] - weighted_scores = attention_scores_f32 * weights_i # [H, seq_len] - - # 汇总所有头的logits [seq_len] - logits_i = torch.sum(weighted_scores, dim=0) # [seq_len] - - # 将结果填充到输出logits的对应位置 - logits[i, seq_start:seq_end] = logits_i - - return logits - - -def ref_paged_mqa_logits( - q: torch.Tensor, - kv_cache: torch.Tensor, - weights: torch.Tensor, - context_lens: torch.Tensor, - block_tables: torch.Tensor, - max_model_len: int, - clean_logits: bool = True -) -> torch.Tensor: - """使用分页KV缓存计算FP8多查询注意力logits的PyTorch实现 - - Args: - q: 查询张量 [B, next_n, H, D] - kv_cache: 分页KV缓存 [num_blocks, block_size, 1, D] - weights: 权重张量 [B * next_n, H], dtype=torch.float32 - context_lens: 上下文长度 [B], dtype=int32 - block_tables: 块映射表 [B, max_blocks], dtype=int32 - schedule_metadata: 调度元数据 - max_model_len: 最大序列长度,用于确定输出logits大小 - - Returns: - Logits张量 [B * next_n, max_model_len], dtype=torch.float32 - """ - def reassemble_k_from_paged_cache( - kv_cache: torch.Tensor, - block_table: torch.Tensor, - context_len: int, - head_dim: int, - block_size: int - ) -> torch.Tensor: - """从分页缓存中重组K值""" - num_blocks_needed = (context_len + block_size - 1) // block_size - valid_blocks = block_table[:num_blocks_needed] - device = kv_cache.device - # 初始化输出K序列 [context_len, head_dim] - k_sequence = torch.zeros(context_len, head_dim, device=device, dtype=kv_cache.dtype) - token_offset = 0 - for block_idx in valid_blocks: - if block_idx < 0: - break - # 当前块中的token数量 - tokens_in_block = min(block_size, context_len - token_offset) - if tokens_in_block <= 0: - break - # 从缓存块中提取K值 - block_data = kv_cache[block_idx] # [block_size, 1, D] - - # 提取K值 - k_sequence[token_offset:token_offset + tokens_in_block] = block_data[:tokens_in_block, 0, :head_dim] # [tokens_in_block, D] - token_offset += tokens_in_block - - return k_sequence - - def compute_mqa_logits( - q: torch.Tensor, # [next_n, H, D] - k: torch.Tensor, # [context_len, D] - weights: torch.Tensor, # [next_n, H] - context_len: int, - max_model_len: int - ) -> torch.Tensor: - """计算多查询注意力logits""" - next_n, H, D = q.shape - device = q.device - - # 初始化批次logits [next_n, max_model_len] - batch_logits = torch.full((next_n, max_model_len), -float('inf'), - device=device, dtype=torch.float32) - - # 扩展K以匹配多头 [context_len, H, D] - k_expanded = k.unsqueeze(1).expand(-1, H, -1) # [context_len, H, D] - - # 转置以便矩阵乘法 - q_transposed = q.transpose(0, 1) # [H, next_n, D] - k_transposed = k_expanded.transpose(0, 1) # [H, context_len, D] - # 批量计算注意力分数 [H, next_n, context_len] - attention_scores = torch.bmm(q_transposed, k_transposed.transpose(1, 2)) # [H, next_n, context_len] - attention_scores = F.relu(attention_scores) - # 应用权重并汇总所有头 [next_n, context_len] - weights_expanded = weights.transpose(0, 1).unsqueeze(2) # [H, next_n, 1] - weighted_scores = attention_scores * weights_expanded # [H, next_n, context_len] - logits_per_token = weighted_scores.sum(dim=0) # 
[next_n, context_len] - - # 填充到输出logits中 - batch_logits[:, :context_len] = logits_per_token - - return batch_logits - - def clean_logits_tensor( - logits: torch.Tensor, - context_lens: torch.Tensor, - next_n: int, - max_model_len: int - ) -> torch.Tensor: - """清理logits张量,将超出上下文长度的位置设为负无穷""" - B = len(context_lens) - - for batch_idx in range(B): - context_len = context_lens[batch_idx].item() - if context_len >= max_model_len: - continue - - # 当前批次在logits中的位置 - batch_start = batch_idx * next_n - batch_end = (batch_idx + 1) * next_n - - # 将超出上下文长度的位置设为负无穷 - logits[batch_start:batch_end, context_len:] = -float('inf') - - return logits - - B, next_n, H, D = q.shape - num_blocks, block_size, _, cache_stride = kv_cache.shape - device = q.device - - # 初始化输出logits [B * next_n, max_model_len] - logits = torch.full((B * next_n, max_model_len), -float('inf'), - device=device, dtype=torch.float32) - - # 处理每个批次 - for batch_idx in range(B): - context_len = context_lens[batch_idx].item() - if context_len == 0: - continue - - # 当前批次的查询 [next_n, H, D] - batch_q = q[batch_idx] # [next_n, H, D] - - # 当前批次的权重 [next_n, H] - batch_weights_start = batch_idx * next_n - batch_weights_end = (batch_idx + 1) * next_n - batch_weights = weights[batch_weights_start:batch_weights_end] # [next_n, H] - - # 从分页缓存中重组K值 - batch_k = reassemble_k_from_paged_cache( - kv_cache, block_tables[batch_idx], context_len, D, block_size - ) # [context_len, D] - # 计算多查询注意力logits - batch_logits = compute_mqa_logits( - batch_q, batch_k, batch_weights, context_len, max_model_len - ) # [next_n, max_model_len] - - # 填充到输出logits中 - logits[batch_weights_start:batch_weights_end] = batch_logits - - if clean_logits: - # 清理logits:将超出上下文长度的位置设为负无穷 - logits = clean_logits_tensor(logits, context_lens, next_n, max_model_len) - - return logits - - -def sparse_prefill_fwd( - q: torch.Tensor, - kv: torch.Tensor, - indices: torch.Tensor, - sm_scale: float, - d_v: int = 512, -): - """ - 稀疏注意力预填充内核的PyTorch实现 - - Args: - - q: [s_q, h_q, d_qk], bfloat16 - - kv: [s_kv, h_kv, d_qk], bfloat16 - - indices: [s_q, h_kv, topk], int32. 
无效索引设为-1或>=s_kv - - sm_scale: float - - d_v: 值向量的维度,只能为512 - - Returns: - - (output, max_logits, lse) - - output: [s_q, h_q, d_v], bfloat16 - - max_logits: [s_q, h_q], float - - lse: [s_q, h_q], float, 以2为底的对数求和指数 - """ - def ref_masked_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sm_scale: float, - ) -> torch.Tensor: - query = query * sm_scale - dtype = query.dtype - device = query.device - attn = torch.einsum("qhd,khd->hqk", query, key) - attn = attn.to(torch.float) - attn = torch.softmax(attn, dim=-1) - value = value.to(torch.float) - out = torch.einsum("hqk,khd->qhd", attn, value) - out = out.to(device).to(dtype) - return out - s_q, h_q, d_qk = q.shape - s_kv, h_kv, _ = kv.shape - _, _, topk = indices.shape - - device = q.device - dtype = q.dtype - - # 分离K和V - k = kv # [s_kv, h_kv, d_qk] - v = kv[:, :, :d_v] # [s_kv, h_kv, d_v] - - # 初始化输出 - output = torch.zeros(s_q, h_q, d_v, device=device, dtype=dtype) - # 处理每个查询位置 - for i in range(s_q): - # 当前查询 [h_q, d_qk] - q_i = q[i].unsqueeze(0) # [1, h_q, d_qk] - # 获取当前查询位置的稀疏索引 [topk] - sparse_indices = indices[i, 0] # [topk] - # 过滤有效索引 (>=0 且 < s_kv) - valid_mask = (sparse_indices >= 0) & (sparse_indices < s_kv) - valid_indices = sparse_indices[valid_mask] - # 获取有效的K和V - valid_k = k[valid_indices].repeat(1, h_q, 1) # [valid_len, h_q, d_qk] - valid_v = v[valid_indices].repeat(1, h_q, 1) # [valid_len, h_q, d_v] - out = ref_masked_attention( - q_i, - valid_k, - valid_v, - sm_scale - ) - out = out.view(h_q, d_v) - output[i].copy_(out, non_blocking=True) - return output \ No newline at end of file diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index a8b6b21..21ebeb9 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -31,6 +31,7 @@ from tempfile import NamedTemporaryFile from typing import Any, cast import numpy as np +from huggingface_hub import snapshot_download from PIL import Image from typing_extensions import deprecated @@ -60,6 +61,8 @@ except ImportError: logger = logging.getLogger(__name__) +DEFAULT_NUM_PROMPTS = 1000 + # ----------------------------------------------------------------------------- # Data Classes # ----------------------------------------------------------------------------- @@ -303,9 +306,11 @@ def process_image(image: Any) -> Mapping[str, Any]: a JPEG in memory. - Encodes the JPEG data as a base64 string. - Returns a dictionary with the image as a base64 data URL. - 3. String input: - Treats the string as a URL or local file path. - - Prepends "file://" if the string doesn't start with "http://" or - "file://". - Returns a dictionary with the image URL. + 3. String input: - Treats the string as a URL, local file path, or base64 + encoded data. - If string starts with "data:image/", treats as base64. + - If string starts with "http://", "https://", or "file://", treats as URL. + - Otherwise treats as local file path and prepends "file://". + - Returns a dictionary with the image URL or base64 data. Raises: ValueError: If the input is not a supported type. @@ -325,14 +330,14 @@ def process_image(image: Any) -> Mapping[str, Any]: if isinstance(image, str): image_url = ( image - if image.startswith(("http://", "https://", "file://")) + if image.startswith(("http://", "https://", "file://", "data:image/")) else f"file://{image}" ) return {"type": "image_url", "image_url": {"url": image_url}} raise ValueError( - f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes." 
+ f"Invalid image input {image}. Must be a PIL.Image.Image, " + "str (URL, file path, or base64 data URL), or dictionary with raw image bytes." ) @@ -1338,7 +1343,7 @@ def add_dataset_parser(parser: FlexibleArgumentParser): parser.add_argument( "--num-prompts", type=int, - default=1000, + default=DEFAULT_NUM_PROMPTS, help="Number of prompts to process.", ) parser.add_argument( @@ -2676,6 +2681,14 @@ class MMVUDataset(HuggingFaceDataset): + (" ".join(f"{k}.{v}" for k, v in x["choices"].items())), } + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + + self._remote_path_root = ( + f"https://huggingface.co/datasets/{self.hf_name}/resolve/main" + ) + self._local_path_root = snapshot_download(self.hf_name, repo_type="dataset") + def sample( self, tokenizer: TokenizerLike, @@ -2698,7 +2711,9 @@ class MMVUDataset(HuggingFaceDataset): break prompt = parser_fn(item) - mm_content = process_video(item["video"]) + mm_content = process_video( + item["video"].replace(self._remote_path_root, self._local_path_root) + ) prompt_len = len(tokenizer.encode(prompt)) if enable_multimodal_chat: # Note: when chat is enabled the request prompt_len is no longer diff --git a/vllm/benchmarks/lib/__init__.py b/vllm/benchmarks/lib/__init__.py new file mode 100644 index 0000000..005e87a --- /dev/null +++ b/vllm/benchmarks/lib/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark library utilities.""" diff --git a/vllm/benchmarks/lib/endpoint_request_func.py b/vllm/benchmarks/lib/endpoint_request_func.py new file mode 100644 index 0000000..e231ccf --- /dev/null +++ b/vllm/benchmarks/lib/endpoint_request_func.py @@ -0,0 +1,802 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""The request function for API endpoints.""" + +import io +import json +import os +import sys +import time +import traceback +from collections.abc import Awaitable +from dataclasses import dataclass, field +from typing import Any, Literal, Protocol + +import aiohttp +import regex as re +from tqdm.asyncio import tqdm + +AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60) + + +class StreamedResponseHandler: + """Handles streaming HTTP responses by accumulating chunks until complete + messages are available.""" + + def __init__(self): + self.buffer = "" + + def add_chunk(self, chunk_bytes: bytes) -> list[str]: + """Add a chunk of bytes to the buffer and return any complete + messages.""" + chunk_str = chunk_bytes.decode("utf-8") + self.buffer += chunk_str + + messages = [] + + # Split by double newlines (SSE message separator) + while "\n\n" in self.buffer: + message, self.buffer = self.buffer.split("\n\n", 1) + message = message.strip() + if message: + messages.append(message) + + # if self.buffer is not empty, check if it is a complete message + # by removing data: prefix and check if it is a valid JSON + if self.buffer.startswith("data: "): + message_content = self.buffer.removeprefix("data: ").strip() + if message_content == "[DONE]": + messages.append(self.buffer.strip()) + self.buffer = "" + elif message_content: + try: + json.loads(message_content) + messages.append(self.buffer.strip()) + self.buffer = "" + except json.JSONDecodeError: + # Incomplete JSON, wait for more chunks. 
+ pass + + return messages + + +@dataclass +class RequestFuncInput: + """The input for the request function.""" + + prompt: str | list[str] + api_url: str + prompt_len: int + output_len: int + model: str + model_name: str | None = None + logprobs: int | None = None + extra_headers: dict | None = None + extra_body: dict | None = None + multi_modal_content: dict | list[dict] | None = None + ignore_eos: bool = False + language: str | None = None + request_id: str | None = None + + +@dataclass +class RequestFuncOutput: + """The output of the request function including metrics.""" + + generated_text: str = "" + success: bool = False + latency: float = 0.0 + output_tokens: int = 0 + ttft: float = 0.0 # Time to first token + itl: list[float] = field(default_factory=list) # list of inter-token latencies + tpot: float = 0.0 # avg next-token latencies + prompt_len: int = 0 + error: str = "" + start_time: float = 0.0 + input_audio_duration: float = 0.0 # in seconds + + +class RequestFunc(Protocol): + def __call__( + self, + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, + ) -> Awaitable[RequestFuncOutput]: ... + + +def _validate_api_url( + api_url: str, + api_name: str, + expected_suffixes: str | set[str], +) -> None: + if isinstance(expected_suffixes, str): + expected_suffixes = {expected_suffixes} + + expected_suffixes = {*expected_suffixes, "profile"} + + if not api_url.endswith(tuple(expected_suffixes)): + raise ValueError(f"{api_name} URL must end with one of: {expected_suffixes}.") + + +def _update_payload_common( + payload: dict[str, Any], + request_func_input: RequestFuncInput, +) -> None: + if request_func_input.ignore_eos: + payload["ignore_eos"] = request_func_input.ignore_eos + if request_func_input.extra_body: + payload.update(request_func_input.extra_body) + + +def _update_headers_common( + headers: dict[str, Any], + request_func_input: RequestFuncInput, +) -> None: + if request_func_input.extra_headers: + headers |= request_func_input.extra_headers + if request_func_input.request_id: + headers["x-request-id"] = request_func_input.request_id + + +def _get_headers(content_type: str | None = None) -> dict[str, str]: + headers = {} + if content_type: + headers["Content-Type"] = content_type + api_key = os.environ.get("OPENAI_API_KEY") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + return headers + + +async def async_request_openai_completions( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + """The async request function for the OpenAI Completions API. + + Args: + request_func_input: The input for the request function. + pbar: The progress bar to display the progress. + + Returns: + The output of the request function. 
+ """ + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Completions API", "completions") + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "prompt": request_func_input.prompt, + "repetition_penalty": 1.0, + "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + _update_payload_common(payload, request_func_input) + + headers = _get_headers() + _update_headers_common(headers, request_func_input) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + st = time.perf_counter() + output.start_time = st + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, headers=headers) as response: + if response.status == 200: + first_chunk_received = False + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. + if message.startswith(":"): + continue + + chunk = message.removeprefix("data: ") + + if chunk != "[DONE]": + data = json.loads(chunk) + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + if choices := data.get("choices"): + # Note that text could be empty here + # e.g. for special tokens + text = choices[0].get("text") + timestamp = time.perf_counter() + # First token + if not first_chunk_received: + first_chunk_received = True + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text += text or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") + if first_chunk_received: + output.success = True + else: + output.success = False + output.error = ( + "Never received a valid chunk to calculate TTFT." + "This response will be marked as failed!" 
+ ) + output.generated_text = generated_text + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +def _get_chat_content( + request_func_input: RequestFuncInput, + mm_position: Literal["first", "last"] = "last", +) -> list[dict[str, Any]]: + text_contents = [{"type": "text", "text": request_func_input.prompt}] + + mm_contents = [] + if request_func_input.multi_modal_content: + mm_content = request_func_input.multi_modal_content + if isinstance(mm_content, list): + mm_contents.extend(request_func_input.multi_modal_content) + elif isinstance(mm_content, dict): + mm_contents.append(request_func_input.multi_modal_content) + else: + raise TypeError( + "multi_modal_content must be a dict or list[dict] for openai-chat" + ) + + if mm_position == "first": + return mm_contents + text_contents + + return text_contents + mm_contents + + +async def async_request_openai_chat_completions( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, + mm_position: Literal["first", "last"] = "last", +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Chat Completions API", "chat/completions") + + content = _get_chat_content(request_func_input, mm_position=mm_position) + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "messages": [ + {"role": "user", "content": content}, + ], + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "stream_options": { + "include_usage": True, + }, + } + _update_payload_common(payload, request_func_input) + + headers = _get_headers("application/json") + _update_headers_common(headers, request_func_input) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + output.start_time = st + most_recent_timestamp = st + try: + async with session.post(url=api_url, json=payload, headers=headers) as response: + if response.status == 200: + handler = StreamedResponseHandler() + async for chunk_bytes in response.content.iter_any(): + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + messages = handler.add_chunk(chunk_bytes) + for message in messages: + # NOTE: SSE comments (often used as pings) start with + # a colon. These are not JSON data payload and should + # be skipped. 
+ if message.startswith(":"): + continue + + chunk = message.removeprefix("data: ") + + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + generated_text += content or "" + elif usage := data.get("usage"): + output.output_tokens = usage.get("completion_tokens") + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_audio( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + # Lazy import without PlaceholderModule to avoid vllm dep. + import soundfile + + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Audio API", {"transcriptions", "translations"}) + + content = [{"type": "text", "text": request_func_input.prompt}] + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "max_completion_tokens": request_func_input.output_len, + "stream": True, + "language": "en", + # Flattened due to multipart/form-data + "stream_include_usage": True, + "stream_continuous_usage_stats": True, + } + _update_payload_common(payload, request_func_input) + + headers = _get_headers() + _update_headers_common(headers, request_func_input) + + # Send audio file + def to_bytes(y, sr): + buffer = io.BytesIO() + soundfile.write(buffer, y, sr, format="WAV") + buffer.seek(0) + return buffer + + mm_audio = request_func_input.multi_modal_content + if not isinstance(mm_audio, dict) or "audio" not in mm_audio: + raise TypeError("multi_modal_content must be a dict containing 'audio'") + with to_bytes(*mm_audio["audio"]) as f: + form = aiohttp.FormData() + form.add_field("file", f, content_type="audio/wav") + for key, value in payload.items(): + form.add_field(key, str(value)) + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + output.input_audio_duration = soundfile.info(f).duration + f.seek(0) + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + output.start_time = st + most_recent_timestamp = st + try: + async with session.post( + url=api_url, data=form, headers=headers + ) as response: + if response.status == 200: + handler = StreamedResponseHandler() + + async for chunk_bytes in response.content.iter_any(): + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + messages = handler.add_chunk(chunk_bytes) + for message in messages: + if type(message) is bytes: + message = message.decode("utf-8") + chunk = message.removeprefix("data: ") + if chunk != "[DONE]": + timestamp = time.perf_counter() + data = json.loads(chunk) + + if choices := data.get("choices"): + content = choices[0]["delta"].get("content") + # First token + if ttft == 0.0: + ttft = timestamp - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append( + timestamp - most_recent_timestamp + ) + + generated_text += content or "" + elif usage := 
data.get("usage"): + output.output_tokens = usage.get( + "completion_tokens" + ) + + most_recent_timestamp = timestamp + + output.generated_text = generated_text + output.success = True + output.latency = most_recent_timestamp - st + else: + output.error = response.reason or "" + output.success = False + except Exception: + output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + +async def _run_pooling_request( + session: aiohttp.ClientSession, + api_url: str, + payload: dict[str, Any], + headers: dict[str, Any], + pbar: tqdm | None = None, +) -> RequestFuncOutput: + output = RequestFuncOutput() + st = time.perf_counter() + output.start_time = st + try: + async with session.post(url=api_url, headers=headers, json=payload) as response: + if response.status == 200: + output.ttft = output.latency = time.perf_counter() - st + + if payload.get("encoding_format", "float") == "bytes": + metadata = json.loads(response.headers["metadata"]) + usage = metadata.get("usage", {}) + else: + data = await response.json() + usage = data.get("usage", {}) + + output.success = True + output.generated_text = "" + output.prompt_len = usage.get("prompt_tokens", 0) + else: + output.success = False + output.error = response.reason or "" + except Exception as e: + output.success = False + output.error = str(e) + + if pbar: + pbar.update(1) + return output + + +async def async_request_openai_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings") + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "input": request_func_input.prompt, + # Many embedding models have short context length, + # this is to avoid dropping some of the requests. + "truncate_prompt_tokens": -1, + } + _update_payload_common(payload, request_func_input) + + headers = _get_headers("application/json") + _update_headers_common(headers, request_func_input) + + return await _run_pooling_request( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +async def async_request_vllm_rerank( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "vLLM score API", "rerank") + + assert ( + isinstance(request_func_input.prompt, list) + and len(request_func_input.prompt) > 1 + ) + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "query": request_func_input.prompt[0], + "documents": request_func_input.prompt[1:], + # Many reranker models have short context length, + # this is to avoid dropping some of the requests. 
+ "truncate_prompt_tokens": -1, + } + + headers = _get_headers("application/json") + _update_headers_common(headers, request_func_input) + + return await _run_pooling_request( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +async def async_request_openai_embeddings_chat( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, + mm_position: Literal["first", "last"] = "last", +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "OpenAI Embeddings API", "embeddings") + + content = _get_chat_content(request_func_input, mm_position=mm_position) + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "messages": [ + {"role": "user", "content": content}, + ], + # Many embedding models have short context length, + # this is to avoid dropping some of the requests. + "truncate_prompt_tokens": -1, + } + _update_payload_common(payload, request_func_input) + + headers = _get_headers("application/json") + _update_headers_common(headers, request_func_input) + + return await _run_pooling_request( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +def _try_extract_request_idx(request_func_input: RequestFuncInput): + if request_func_input.request_id: + match = re.search(r"(\d+)$", request_func_input.request_id) + if match: + try: + return int(match.group(1)) + except ValueError: + pass + + return None + + +def _preprocess_clip(request_func_input: RequestFuncInput): + if request_func_input.multi_modal_content: + # Image input + request_func_input.prompt = "" + + +def _preprocess_vlm2vec(request_func_input: RequestFuncInput): + if request_func_input.multi_modal_content: + request_idx = _try_extract_request_idx(request_func_input) + + # Adjust the ratio manually if needed. + use_image_only_prompt = request_idx is None or request_idx % 2 == 0 + + if use_image_only_prompt: + # Image input + request_func_input.prompt = "Represent the given image." 
+ else: + # Text+Image input + request_func_input.prompt = ( + f"Represent the given image with the following question: " + f"{request_func_input.prompt}" + ) + + +async def async_request_openai_embeddings_clip( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + _preprocess_clip(request_func_input) + + return await async_request_openai_embeddings_chat( + request_func_input, + session, + pbar=pbar, + ) + + +async def async_request_openai_embeddings_vlm2vec( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + _preprocess_vlm2vec(request_func_input) + + return await async_request_openai_embeddings_chat( + request_func_input, + session, + pbar=pbar, + mm_position="first", + ) + + +async def async_request_infinity_embeddings( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "Infinity Embeddings API", "embeddings") + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + } + + if request_func_input.prompt: + payload["input"] = request_func_input.prompt + else: + mm_content = request_func_input.multi_modal_content + assert isinstance(mm_content, dict) + + mm_type = mm_content["type"] + payload["input"] = mm_content[mm_type]["url"] + payload["modality"] = mm_type.split("_", 1)[0] + + _update_payload_common(payload, request_func_input) + + headers = _get_headers("application/json") + _update_headers_common(headers, request_func_input) + + return await _run_pooling_request( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +async def async_request_infinity_embeddings_clip( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + _preprocess_clip(request_func_input) + + return await async_request_infinity_embeddings( + request_func_input, + session, + pbar=pbar, + ) + + +async def async_request_vllm_pooling( + request_func_input: RequestFuncInput, + session: aiohttp.ClientSession, + pbar: tqdm | None = None, +) -> RequestFuncOutput: + api_url = request_func_input.api_url + _validate_api_url(api_url, "vLLM Pooling API", "pooling") + + payload = { + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, + "truncate_prompt_tokens": -1, + } + + payload = payload | request_func_input.prompt + + _update_payload_common(payload, request_func_input) + + headers = _get_headers("application/json") + _update_headers_common(headers, request_func_input) + + return await _run_pooling_request( + session, + api_url, + payload=payload, + headers=headers, + pbar=pbar, + ) + + +# TODO: Add more request functions for different API protocols. 
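+# Registry mapping benchmark backend names to their request functions; each
+# entry follows the RequestFunc protocol, and OPENAI_COMPATIBLE_BACKENDS below
+# is derived from the completions/chat-completions entries.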
+ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = { + "vllm": async_request_openai_completions, + "openai": async_request_openai_completions, + "openai-chat": async_request_openai_chat_completions, + "openai-audio": async_request_openai_audio, + "openai-embeddings": async_request_openai_embeddings, + "openai-embeddings-chat": async_request_openai_embeddings_chat, + "openai-embeddings-clip": async_request_openai_embeddings_clip, + "openai-embeddings-vlm2vec": async_request_openai_embeddings_vlm2vec, + # Infinity embedding server: https://github.com/michaelfeil/infinity + "infinity-embeddings": async_request_infinity_embeddings, + "infinity-embeddings-clip": async_request_infinity_embeddings_clip, + # (Infinity embedding server does not support vlm2vec) + "vllm-pooling": async_request_vllm_pooling, + "vllm-rerank": async_request_vllm_rerank, +} + +OPENAI_COMPATIBLE_BACKENDS = [ + k + for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, async_request_openai_chat_completions) +] diff --git a/vllm/benchmarks/lib/ready_checker.py b/vllm/benchmarks/lib/ready_checker.py new file mode 100644 index 0000000..eec4a42 --- /dev/null +++ b/vllm/benchmarks/lib/ready_checker.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utilities for checking endpoint readiness.""" + +import asyncio +import time + +import aiohttp +from tqdm.asyncio import tqdm + +from vllm.logger import init_logger + +from .endpoint_request_func import RequestFunc, RequestFuncInput, RequestFuncOutput + +logger = init_logger(__name__) + + +async def wait_for_endpoint( + request_func: RequestFunc, + test_input: RequestFuncInput, + session: aiohttp.ClientSession, + timeout_seconds: int = 600, + retry_interval: int = 5, +) -> RequestFuncOutput: + """ + Wait for an endpoint to become available before starting benchmarks. + + Args: + request_func: The async request function to call + test_input: The RequestFuncInput to test with + timeout_seconds: Maximum time to wait in seconds (default: 10 minutes) + retry_interval: Time between retries in seconds (default: 5 seconds) + + Returns: + RequestFuncOutput: The successful response + + Raises: + ValueError: If the endpoint doesn't become available within the timeout + """ + deadline = time.perf_counter() + timeout_seconds + output = RequestFuncOutput(success=False) + print(f"Waiting for endpoint to become up in {timeout_seconds} seconds") + + with tqdm( + total=timeout_seconds, + bar_format="{desc} |{bar}| {elapsed} elapsed, {remaining} remaining", + unit="s", + ) as pbar: + while True: + # update progress bar + remaining = deadline - time.perf_counter() + elapsed = timeout_seconds - remaining + update_amount = min(elapsed - pbar.n, timeout_seconds - pbar.n) + pbar.update(update_amount) + pbar.refresh() + if remaining <= 0: + pbar.close() + break + + # ping the endpoint using request_func + try: + output = await request_func( + request_func_input=test_input, session=session + ) + if output.success: + pbar.close() + return output + else: + err_last_line = str(output.error).rstrip().rsplit("\n", 1)[-1] + logger.warning("Endpoint is not ready. 
Error='%s'", err_last_line) + except aiohttp.ClientConnectorError: + pass + + # retry after a delay + sleep_duration = min(retry_interval, remaining) + if sleep_duration > 0: + await asyncio.sleep(sleep_duration) + + return output diff --git a/vllm/benchmarks/lib/utils.py b/vllm/benchmarks/lib/utils.py new file mode 100644 index 0000000..99a3bf9 --- /dev/null +++ b/vllm/benchmarks/lib/utils.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import json +import math +import os +from contextlib import contextmanager +from typing import Any + + +def extract_field( + args: argparse.Namespace, extra_info: dict[str, Any], field_name: str +) -> str: + if field_name in extra_info: + return extra_info[field_name] + + v = args + # For example, args.compilation_config.mode + for nested_field in field_name.split("."): + if not hasattr(v, nested_field): + return "" + v = getattr(v, nested_field) + return v + + +def use_compile(args: argparse.Namespace, extra_info: dict[str, Any]) -> bool: + """ + Check if the benchmark is run with torch.compile + """ + return not ( + extract_field(args, extra_info, "compilation_config.mode") == "0" + or "eager" in getattr(args, "output_json", "") + or "eager" in getattr(args, "result_filename", "") + ) + + +def convert_to_pytorch_benchmark_format( + args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any] +) -> list: + """ + Save the benchmark results in the format used by PyTorch OSS benchmark with + on metric per record + https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database + """ + records = [] + if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): + return records + + for name, benchmark_values in metrics.items(): + if not isinstance(benchmark_values, list): + raise TypeError( + f"benchmark_values for metric '{name}' must be a list, " + f"but got {type(benchmark_values).__name__}" + ) + + record = { + "benchmark": { + "name": "vLLM benchmark", + "extra_info": { + "args": vars(args), + "compilation_config.mode": extract_field( + args, extra_info, "compilation_config.mode" + ), + "optimization_level": extract_field( + args, extra_info, "optimization_level" + ), + # A boolean field used by vLLM benchmark HUD dashboard + "use_compile": use_compile(args, extra_info), + }, + }, + "model": { + "name": args.model, + }, + "metric": { + "name": name, + "benchmark_values": benchmark_values, + "extra_info": extra_info, + }, + } + + tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size") + # Save tensor_parallel_size parameter if it's part of the metadata + if not tp and "tensor_parallel_size" in extra_info: + record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = ( + extra_info["tensor_parallel_size"] + ) + + records.append(record) + + return records + + +class InfEncoder(json.JSONEncoder): + def clear_inf(self, o: Any): + if isinstance(o, dict): + return { + str(k) + if not isinstance(k, (str, int, float, bool, type(None))) + else k: self.clear_inf(v) + for k, v in o.items() + } + elif isinstance(o, list): + return [self.clear_inf(v) for v in o] + elif isinstance(o, float) and math.isinf(o): + return "inf" + return o + + def iterencode(self, o: Any, *args, **kwargs) -> Any: + return super().iterencode(self.clear_inf(o), *args, **kwargs) + + +def write_to_json(filename: str, records: list) -> None: + with open(filename, "w") as f: + json.dump( + records, + f, + 
cls=InfEncoder, + default=lambda o: f"<{type(o).__name__} is not JSON serializable>", + ) + + +@contextmanager +def default_vllm_config(): + """Set a default VllmConfig for cases that directly test CustomOps or pathways + that use get_current_vllm_config() outside of a full engine context. + """ + from vllm.config import VllmConfig, set_current_vllm_config + + with set_current_vllm_config(VllmConfig()): + yield diff --git a/vllm/benchmarks/plot.py b/vllm/benchmarks/plot.py new file mode 100644 index 0000000..3f36ede --- /dev/null +++ b/vllm/benchmarks/plot.py @@ -0,0 +1,316 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Generate plots for benchmark results.""" + +from pathlib import Path +from typing import Any + +from vllm.utils.import_utils import PlaceholderModule + +try: + import plotly.express as px + import plotly.io as pio +except ImportError: + _plotly = PlaceholderModule("plotly") + px = _plotly.placeholder_attr("express") + pio = _plotly.placeholder_attr("io") + +try: + import matplotlib.pyplot as plt +except ImportError: + _matplotlib = PlaceholderModule("matplotlib") + plt = _matplotlib.placeholder_attr("pyplot") + + +def generate_timeline_plot( + results: list[dict[str, Any]], + output_path: Path, + colors: list[str] | None = None, + itl_thresholds: list[float] | None = None, + labels: list[str] | None = None, +) -> None: + """ + Generate an HTML timeline plot from benchmark results. + + Args: + results: List of per-request result dictionaries containing: + - start_time: Request start time (seconds) + - ttft: Time to first token (seconds) + - itl: List of inter-token latencies (seconds) + - latency: Total request latency (seconds) + - prompt_len: Number of prompt tokens + - output_tokens: Number of output tokens + output_path: Path where the HTML file will be saved + colors: List of colors for ITL categories (default: green, orange, red, black) + itl_thresholds: ITL thresholds in seconds (default: [1.0, 4.0, 6.0]) + labels: Labels for ITL categories (default based on thresholds) + """ + + # Set defaults + if colors is None: + colors = ["#109618", "#FF7F0E", "#D62728"] + if itl_thresholds is None: + itl_thresholds = [0.025, 0.050] + if labels is None: + labels = [ + f"ITL < {itl_thresholds[0] * 1000:.0f}ms", + f"{itl_thresholds[0] * 1000:.0f}ms ≤ ITL < {itl_thresholds[1] * 1000:.0f}ms", # noqa + f"ITL ≥ {itl_thresholds[1] * 1000:.0f}ms", + ] + + labels_colors = {"TTFT": "#636EFA", **dict(zip(labels, colors))} + labels_order = ["TTFT"] + labels + + timeline_data = construct_timeline_data(results, itl_thresholds, labels) + + if not timeline_data: + print("No timeline data to plot") + return + + # Create the plot + fig = px.timeline( + timeline_data, + x_start="start", + x_end="end", + y="request_id", + color="type", + color_discrete_map=labels_colors, + category_orders={"type": labels_order}, + hover_data=[ + "prompt_tokens", + "output_tokens", + "req_start_time", + "req_finish_time", + "segment_start", + "segment_end", + "duration", + ], + ) + + # Customize hover template to show only time without date + fig.update_traces( + hovertemplate="%{y}
" + "Type: %{fullData.name}
" + "Start: %{customdata[4]}
" + "End: %{customdata[5]}
" + "Duration: %{customdata[6]}
" + "Prompt Tokens: %{customdata[0]}
" + "Output Tokens: %{customdata[1]}
" + "Request Start Time: %{customdata[2]}
" + "Request End Time: %{customdata[3]}
" + "" + ) + + fig.update_yaxes(autorange="reversed") + fig.update_layout( + xaxis_title="Time", + yaxis_title="Request ID", + showlegend=True, + ) + + # Save to HTML + pio.write_html(fig, str(output_path)) + print(f"Timeline plot saved to: {output_path}") + + +def construct_timeline_data( + requests_data: list[dict[str, Any]], + itl_thresholds: list[float], + labels: list[str], +) -> list[dict[str, Any]]: + """ + Construct timeline data from request results. + + Args: + requests_data: List of per-request result dictionaries + itl_thresholds: ITL thresholds in seconds + labels: Labels for ITL categories + + Returns: + List of timeline segments for plotting + """ + + def tostr(sec_time: float) -> str: + """Convert seconds to HH:MM:SS.mmm format.""" + h = int(sec_time // 3600) + assert h < 100, "time seems to last more than 100 hours" + m = int((sec_time % 3600) // 60) + s = sec_time % 60 + return f"{h:02d}:{m:02d}:{s:06.3f}" + + def itl_type(itl: float) -> str: + """Categorize ITL based on thresholds.""" + if itl < itl_thresholds[0]: + return labels[0] + elif itl < itl_thresholds[1]: + return labels[1] + else: + return labels[2] + + # Find the earliest start time to use as t0 + t0 = None + for request in requests_data: + start_time = request.get("start_time") + if start_time is not None and (t0 is None or start_time < t0): + t0 = start_time + + if t0 is None: + return [] + + timeline_data = [] + + for i, request in enumerate(requests_data): + start_time = request.get("start_time") + ttft = request.get("ttft") + itl = request.get("itl", []) + latency = request.get("latency") + prompt_len = request.get("prompt_len", 0) + output_tokens = request.get("output_tokens", 0) + + # Skip requests without required data + if start_time is None or ttft is None or latency is None: + continue + + # Normalize start time + start_time = start_time - t0 + start_time_str = tostr(start_time) + + # TTFT segment + ttft_end = start_time + ttft + ttft_end_str = tostr(ttft_end) + + timeline_data.append( + { + "request_id": f"Req {i}", + "start": start_time_str, + "end": ttft_end_str, + "type": "TTFT", + "prompt_tokens": prompt_len, + "output_tokens": output_tokens, + "req_start_time": tostr(start_time), + "req_finish_time": tostr(start_time + latency), + "segment_start": start_time_str, + "segment_end": ttft_end_str, + "duration": f"{ttft:.3f}s", + } + ) + + # ITL segments + prev_time = ttft_end + prev_time_str = ttft_end_str + + for itl_value in itl: + itl_end = prev_time + itl_value + itl_end_str = tostr(itl_end) + + timeline_data.append( + { + "request_id": f"Req {i}", + "start": prev_time_str, + "end": itl_end_str, + "type": itl_type(itl_value), + "prompt_tokens": prompt_len, + "output_tokens": output_tokens, + "req_start_time": tostr(start_time), + "req_finish_time": tostr(start_time + latency), + "segment_start": prev_time_str, + "segment_end": itl_end_str, + "duration": f"{itl_value:.3f}s", + } + ) + + prev_time = itl_end + prev_time_str = itl_end_str + + return timeline_data + + +def generate_dataset_stats_plot( + results: list[dict[str, Any]], + output_path: Path, +) -> None: + """ + Generate a matplotlib figure with dataset statistics. 
+ + Creates a figure with 4 subplots: + - Top-left: Prompt tokens distribution (histogram) + - Top-right: Output tokens distribution (histogram) + - Bottom-left: Prompt+output tokens distribution (histogram) + - Bottom-right: Stacked bar chart (request_id vs tokens) + + Args: + results: List of per-request result dictionaries containing: + - prompt_len: Number of prompt tokens + - output_tokens: Number of output tokens + output_path: Path where the figure will be saved + """ + # Extract data + prompt_tokens = [] + output_tokens = [] + total_tokens = [] + + for request in results: + prompt_len = request.get("prompt_len", 0) + output_len = request.get("output_tokens", 0) + + prompt_tokens.append(prompt_len) + output_tokens.append(output_len) + total_tokens.append(prompt_len + output_len) + + if not prompt_tokens: + print("No data available for dataset statistics plot") + return + + # Create figure with 4 subplots + fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10)) + + # Top-left: Prompt tokens distribution + ax1.hist(prompt_tokens, bins=30, color="steelblue", edgecolor="black", alpha=0.7) + ax1.set_xlabel("Prompt Tokens") + ax1.set_ylabel("Frequency") + ax1.set_title("Prompt Tokens Distribution") + ax1.grid(True, alpha=0.3) + + # Top-right: Output tokens distribution + ax2.hist(output_tokens, bins=30, color="coral", edgecolor="black", alpha=0.7) + ax2.set_xlabel("Output Tokens") + ax2.set_ylabel("Frequency") + ax2.set_title("Output Tokens Distribution") + ax2.grid(True, alpha=0.3) + + # Bottom-left: Prompt+output tokens distribution + ax3.hist( + total_tokens, bins=30, color="mediumseagreen", edgecolor="black", alpha=0.7 + ) + ax3.set_xlabel("Total Tokens (Prompt + Output)") + ax3.set_ylabel("Frequency") + ax3.set_title("Total Tokens Distribution") + ax3.grid(True, alpha=0.3) + + # Bottom-right: Stacked bar chart + request_ids = list(range(len(prompt_tokens))) + ax4.bar( + request_ids, prompt_tokens, label="Prompt Tokens", color="steelblue", alpha=0.7 + ) + ax4.bar( + request_ids, + output_tokens, + bottom=prompt_tokens, + label="Output Tokens", + color="coral", + alpha=0.7, + ) + ax4.set_xlabel("Request ID") + ax4.set_ylabel("Tokens") + ax4.set_title("Tokens per Request (Stacked)") + ax4.legend() + ax4.grid(True, alpha=0.3, axis="y") + + # Adjust layout to prevent overlap + plt.tight_layout() + + # Save figure + plt.savefig(str(output_path), dpi=150, bbox_inches="tight") + plt.close(fig) + + print(f"Dataset statistics plot saved to: {output_path}") diff --git a/vllm/benchmarks/serve.py b/vllm/benchmarks/serve.py index 06e67f9..7c9a95e 100644 --- a/vllm/benchmarks/serve.py +++ b/vllm/benchmarks/serve.py @@ -34,6 +34,7 @@ from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime from enum import Enum +from pathlib import Path from typing import Any, Literal import aiohttp @@ -1183,6 +1184,49 @@ def save_to_pytorch_benchmark_format( write_to_json(pt_file, pt_records) +def compute_result_filename( + args: argparse.Namespace, + model_id: str, + label: str, + current_dt: str, +) -> str | None: + """Compute the result filename based on benchmark configuration. 
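+
+ The default name combines the label (or backend), the request rate or ramp-up settings, the max concurrency, the base model name, and the timestamp; it is replaced by --result-filename and written under --result-dir when those are given.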
+ + Args: + args: Command line arguments containing result configuration + model_id: The model identifier + label: The benchmark label + current_dt: Current datetime string + + Returns: + The computed filename path or None if no result saving is requested + """ + if not (args.plot_timeline or args.save_result or args.append_result): + return None + + base_model_id = model_id.split("/")[-1] + max_concurrency_str = ( + f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None + else "" + ) + label = label or args.backend + + if args.ramp_up_strategy is not None: + file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + else: + file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + + if args.result_filename: + file_name = args.result_filename + + if args.result_dir: + os.makedirs(args.result_dir, exist_ok=True) + file_name = os.path.join(args.result_dir, file_name) + + return file_name + + def add_cli_args(parser: argparse.ArgumentParser): add_dataset_parser(parser) parser.add_argument( @@ -1277,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser): - "slow" will always use the slow tokenizer.\n - "mistral" will always use the tokenizer from `mistral_common`.\n - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n + - "qwen_vl" will always use the tokenizer from `qwen_vl`.\n - Other custom values can be supported via plugins.""", ) parser.add_argument("--use-beam-search", action="store_true") @@ -1535,6 +1580,30 @@ def add_cli_args(parser: argparse.ArgumentParser): "connecting to servers with self-signed certificates.", ) + parser.add_argument( + "--plot-timeline", + action="store_true", + help="Generate an HTML timeline plot showing request execution. " + "The plot will be saved alongside the results JSON file.", + ) + parser.add_argument( + "--timeline-itl-thresholds", + type=float, + nargs=2, + default=[25.0, 50.0], + metavar=("THRESHOLD1", "THRESHOLD2"), + help="ITL thresholds in milliseconds for timeline plot coloring. " + "Specify two values to categorize inter-token latencies into three groups: " + "below first threshold (green), between thresholds (orange), " + "and above second threshold (red). 
Default: 25 50 (milliseconds).", + ) + parser.add_argument( + "--plot-dataset-stats", + action="store_true", + help="Generate a matplotlib figure with dataset statistics showing " + "prompt tokens, output tokens, and combined token distributions.", + ) + def main(args: argparse.Namespace) -> dict[str, Any]: return asyncio.run(main_async(args)) @@ -1770,6 +1839,86 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: # Merge with benchmark result result_json = {**result_json, **benchmark_result} + # Compute file_name once before using it for plots or saving results + file_name = compute_result_filename(args, model_id, label, current_dt) + + # Generate timeline plot if requested + if args.plot_timeline: + try: + from vllm.benchmarks.plot import generate_timeline_plot + + # Prepare per-request data for timeline + per_request_data = [] + start_times = benchmark_result.get("start_times", []) + ttfts = benchmark_result.get("ttfts", []) + itls = benchmark_result.get("itls", []) + input_lens = benchmark_result.get("input_lens", []) + output_lens = benchmark_result.get("output_lens", []) + + if start_times and ttfts and itls: + for i in range(len(start_times)): + # Calculate latency as ttft + sum of all itls + latency = ttfts[i] + sum(itls[i]) if itls[i] else ttfts[i] + + per_request_data.append( + { + "start_time": start_times[i], + "ttft": ttfts[i], + "itl": itls[i], + "latency": latency, + "prompt_len": input_lens[i], + "output_tokens": output_lens[i], + } + ) + + timeline_path = Path(file_name).with_suffix(".timeline.html") + # Convert thresholds from milliseconds to seconds + itl_thresholds_sec = [t / 1000.0 for t in args.timeline_itl_thresholds] + generate_timeline_plot( + per_request_data, timeline_path, itl_thresholds=itl_thresholds_sec + ) + else: + warnings.warn( + "Timeline plot requires detailed metrics. " + "Ensure the benchmark completed successfully.", + stacklevel=2, + ) + except Exception as e: + warnings.warn(f"Failed to generate timeline plot: {e}", stacklevel=2) + + # Generate dataset statistics plot if requested + if args.plot_dataset_stats: + try: + from vllm.benchmarks.plot import generate_dataset_stats_plot + + # Prepare per-request data for dataset stats + per_request_data = [] + input_lens = benchmark_result.get("input_lens", []) + output_lens = benchmark_result.get("output_lens", []) + + if input_lens and output_lens: + for req_input_len, req_output_len in zip(input_lens, output_lens): + per_request_data.append( + { + "prompt_len": req_input_len, + "output_tokens": req_output_len, + } + ) + + stats_path = Path(file_name).with_suffix(".dataset_stats.png") + generate_dataset_stats_plot(per_request_data, stats_path) + else: + warnings.warn( + "Dataset statistics plot requires input and " + "output length data. 
Ensure the benchmark completed " + "successfully.", + stacklevel=2, + ) + except Exception as e: + warnings.warn( + f"Failed to generate dataset statistics plot: {e}", stacklevel=2 + ) + if not args.save_detailed: # Remove fields with too many data points for field in [ @@ -1786,24 +1935,8 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]: if field in benchmark_result: del benchmark_result[field] - # Save to file + # Save to file if args.save_result or args.append_result: - base_model_id = model_id.split("/")[-1] - max_concurrency_str = ( - f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None - else "" - ) - label = label or args.backend - if args.ramp_up_strategy is not None: - file_name = f"{label}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - else: - file_name = f"{label}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa - if args.result_filename: - file_name = args.result_filename - if args.result_dir: - os.makedirs(args.result_dir, exist_ok=True) - file_name = os.path.join(args.result_dir, file_name) with open( file_name, mode="a+" if args.append_result else "w", encoding="utf-8" ) as outfile: diff --git a/vllm/benchmarks/sweep/cli.py b/vllm/benchmarks/sweep/cli.py index a752000..7554910 100644 --- a/vllm/benchmarks/sweep/cli.py +++ b/vllm/benchmarks/sweep/cli.py @@ -10,14 +10,14 @@ from .plot_pareto import SweepPlotParetoArgs from .plot_pareto import main as plot_pareto_main from .serve import SweepServeArgs from .serve import main as serve_main -from .serve_sla import SweepServeSLAArgs -from .serve_sla import main as serve_sla_main +from .serve_workload import SweepServeWorkloadArgs +from .serve_workload import main as serve_workload_main from .startup import SweepStartupArgs from .startup import main as startup_main SUBCOMMANDS = ( (SweepServeArgs, serve_main), - (SweepServeSLAArgs, serve_sla_main), + (SweepServeWorkloadArgs, serve_workload_main), (SweepStartupArgs, startup_main), (SweepPlotArgs, plot_main), (SweepPlotParetoArgs, plot_pareto_main), diff --git a/vllm/benchmarks/sweep/plot.py b/vllm/benchmarks/sweep/plot.py index 53c7db3..156e18f 100644 --- a/vllm/benchmarks/sweep/plot.py +++ b/vllm/benchmarks/sweep/plot.py @@ -324,6 +324,11 @@ def _plot_fig( df = filter_by.apply(df) df = bin_by.apply(df) + if len(df) == 0: + print(f"No data to plot. 
Filters: {filter_by}") + print("[END FIGURE]") + return + # Sort by curve_by columns alphabetically for consistent legend ordering if curve_by: df = df.sort_values(by=curve_by) @@ -494,7 +499,7 @@ class SweepPlotArgs: @classmethod def from_cli_args(cls, args: argparse.Namespace): - output_dir = Path(args.OUTPUT_DIR) + output_dir = Path(args.EXPERIMENT_DIR) if not output_dir.exists(): raise ValueError(f"No parameter sweep results under {output_dir}") @@ -526,11 +531,9 @@ class SweepPlotArgs: @classmethod def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument( - "OUTPUT_DIR", + "EXPERIMENT_DIR", type=str, - default="results", - help="The directory containing the results to plot, " - "i.e., the `--output-dir` argument to the parameter sweep script.", + help="The directory containing the sweep results to plot.", ) parser.add_argument( "--fig-dir", @@ -570,13 +573,13 @@ class SweepPlotArgs: parser.add_argument( "--var-x", type=str, - default="request_throughput", + default="total_token_throughput", help="The variable for the x-axis.", ) parser.add_argument( "--var-y", type=str, - default="p99_ttft_ms", + default="median_ttft_ms", help="The variable for the y-axis", ) parser.add_argument( diff --git a/vllm/benchmarks/sweep/plot_pareto.py b/vllm/benchmarks/sweep/plot_pareto.py index 3d17e47..365e87f 100644 --- a/vllm/benchmarks/sweep/plot_pareto.py +++ b/vllm/benchmarks/sweep/plot_pareto.py @@ -325,7 +325,7 @@ class SweepPlotParetoArgs: @classmethod def from_cli_args(cls, args: argparse.Namespace): - output_dir = Path(args.OUTPUT_DIR) + output_dir = Path(args.EXPERIMENT_DIR) if not output_dir.exists(): raise ValueError(f"No parameter sweep results under {output_dir}") @@ -342,9 +342,8 @@ class SweepPlotParetoArgs: @classmethod def add_cli_args(cls, parser: argparse.ArgumentParser): parser.add_argument( - "OUTPUT_DIR", + "EXPERIMENT_DIR", type=str, - default="results", help="The directory containing the sweep results to plot.", ) parser.add_argument( diff --git a/vllm/benchmarks/sweep/serve.py b/vllm/benchmarks/sweep/serve.py index 7420f25..f64006e 100644 --- a/vllm/benchmarks/sweep/serve.py +++ b/vllm/benchmarks/sweep/serve.py @@ -4,6 +4,7 @@ import argparse import contextlib import json import shlex +from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -135,17 +136,21 @@ def run_benchmark( def _get_comb_base_path( - output_dir: Path, + experiment_dir: Path, serve_comb: ParameterSweepItem, bench_comb: ParameterSweepItem, + *, + extra_parts: tuple[str, ...] 
= (), ): parts = list[str]() if serve_comb: parts.extend(("SERVE-", serve_comb.name)) if bench_comb: parts.extend(("BENCH-", bench_comb.name)) + if extra_parts: + parts.extend(extra_parts) - return output_dir / sanitize_filename("-".join(parts)) + return experiment_dir / sanitize_filename("-".join(parts)) def _get_comb_run_path(base_path: Path, run_number: int | None): @@ -158,10 +163,10 @@ def _get_comb_run_path(base_path: Path, run_number: int | None): def _comb_needs_server( serve_comb: ParameterSweepItem, bench_combs: ParameterSweep, - output_dir: Path, + experiment_dir: Path, ): for bench_comb in bench_combs: - base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) + base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb) if not _get_comb_run_path(base_path, run_number=None).exists(): return True @@ -175,11 +180,11 @@ def server_ctx( show_stdout: bool, serve_comb: ParameterSweepItem, bench_params: ParameterSweep, - output_dir: Path, + experiment_dir: Path, dry_run: bool, server_ready_timeout: int = 300, ): - if not _comb_needs_server(serve_comb, bench_params, output_dir): + if not _comb_needs_server(serve_comb, bench_params, experiment_dir): return contextlib.nullcontext() return run_server( @@ -211,10 +216,10 @@ def run_comb( *, serve_comb: ParameterSweepItem, bench_comb: ParameterSweepItem, + link_vars: list[tuple[str, str]], base_path: Path, num_runs: int, dry_run: bool, - link_vars: list[tuple[str, str]], ): if not _comb_is_valid(serve_comb, bench_comb, link_vars): return None @@ -253,10 +258,10 @@ def run_combs( server_ready_timeout: int, serve_params: ParameterSweep, bench_params: ParameterSweep, - output_dir: Path, + link_vars: list[tuple[str, str]], + experiment_dir: Path, num_runs: int, dry_run: bool, - link_vars: list[tuple[str, str]], ): all_data = list[dict[str, object]]() for serve_comb in serve_params: @@ -266,22 +271,22 @@ def run_combs( show_stdout=show_stdout, serve_comb=serve_comb, bench_params=bench_params, - output_dir=output_dir, + experiment_dir=experiment_dir, dry_run=dry_run, server_ready_timeout=server_ready_timeout, ) as server: for bench_comb in bench_params: - base_path = _get_comb_base_path(output_dir, serve_comb, bench_comb) + base_path = _get_comb_base_path(experiment_dir, serve_comb, bench_comb) comb_data = run_comb( server, bench_cmd, serve_comb=serve_comb, bench_comb=bench_comb, + link_vars=link_vars, base_path=base_path, num_runs=num_runs, dry_run=dry_run, - link_vars=link_vars, ) if comb_data is not None: @@ -291,7 +296,7 @@ def run_combs( return None combined_df = pd.DataFrame.from_records(all_data) - combined_df.to_csv(output_dir / "summary.csv") + combined_df.to_csv(experiment_dir / "summary.csv") return combined_df @@ -305,11 +310,12 @@ class SweepServeArgs: server_ready_timeout: int serve_params: ParameterSweep bench_params: ParameterSweep + link_vars: list[tuple[str, str]] output_dir: Path + experiment_name: str num_runs: int dry_run: bool - resume: str | None - link_vars: list[tuple[str, str]] + resume: bool parser_name: ClassVar[str] = "serve" parser_help: ClassVar[str] = "Run vLLM server benchmark under multiple settings." 
@@ -336,6 +342,11 @@ class SweepServeArgs: link_vars = cls.parse_link_vars(args.link_vars) + if args.experiment_name: + experiment_name = args.experiment_name + else: + experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S") + num_runs = args.num_runs if num_runs < 1: raise ValueError("`num_runs` should be at least 1.") @@ -347,11 +358,12 @@ class SweepServeArgs: show_stdout=args.show_stdout, serve_params=serve_params, bench_params=bench_params, + link_vars=link_vars, output_dir=Path(args.output_dir), + experiment_name=experiment_name, num_runs=num_runs, dry_run=args.dry_run, resume=args.resume, - link_vars=link_vars, server_ready_timeout=args.server_ready_timeout, ) @@ -388,6 +400,7 @@ class SweepServeArgs: default=300, help="Timeout in seconds to wait for the server to become ready.", ) + parser.add_argument( "--serve-params", type=str, @@ -398,6 +411,16 @@ class SweepServeArgs: "If both `serve_params` and `bench_params` are given, " "this script will iterate over their Cartesian product.", ) + parser.add_argument( + "--link-vars", + type=str, + default="", + help=( + "Comma-separated list of linked variables between serve and bench, " + "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len" + ), + ) + parser.add_argument( "--bench-params", type=str, @@ -413,7 +436,15 @@ class SweepServeArgs: "--output-dir", type=str, default="results", - help="The directory to which results are written.", + help="The main directory to which results are written.", + ) + parser.add_argument( + "-e", + "--experiment-name", + type=str, + default=None, + help="The name of this experiment (defaults to current timestamp). " + "Results will be stored under `output_dir/experiment_name`.", ) parser.add_argument( "--num-runs", @@ -429,21 +460,10 @@ class SweepServeArgs: ) parser.add_argument( "--resume", - type=str, - default=None, - help="Set this to the name of a directory under `output_dir` (which is a " - "timestamp) to resume a previous execution of this script, i.e., only run " - "parameter combinations for which there are still no output files.", - ) - - parser.add_argument( - "--link-vars", - type=str, - default="", - help=( - "Comma-separated list of linked variables between serve and bench, " - "e.g. max_num_seqs=max_concurrency,max_model_len=random_input_len" - ), + action="store_true", + help="Resume a previous execution of this script, i.e., only run " + "parameter combinations for which there are still no output files " + "under `output_dir/experiment_name`.", ) return parser @@ -458,33 +478,52 @@ class SweepServeArgs: pairs.append((a.strip(), b.strip())) return pairs + def resolve_experiment_dir(self) -> Path: + experiment_dir = self.output_dir / self.experiment_name + + if self.resume: + if not experiment_dir.exists(): + raise ValueError(f"Cannot resume from non-existent {experiment_dir=}") + else: + if experiment_dir.exists(): + raise ValueError(f"Cannot overwrite existing {experiment_dir=}") + + return experiment_dir + + @contextmanager + def run_ctx(self, experiment_dir: Path): + if self.dry_run: + yield + print(f"Experiment will be saved at: {experiment_dir}") + return + + try: + yield + print(f"Experiment has been saved at: {experiment_dir}") + except BaseException as exc: + raise RuntimeError( + "The script was terminated early. Use `--resume` " + "to continue the script from its last checkpoint." 
+ ) from exc + def run_main(args: SweepServeArgs): - timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") - output_dir = args.output_dir / timestamp + experiment_dir = args.resolve_experiment_dir() - if args.resume and not output_dir.exists(): - raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") - - try: + with args.run_ctx(experiment_dir): return run_combs( serve_cmd=args.serve_cmd, bench_cmd=args.bench_cmd, + link_vars=args.link_vars, after_bench_cmd=args.after_bench_cmd, show_stdout=args.show_stdout, server_ready_timeout=args.server_ready_timeout, serve_params=args.serve_params, bench_params=args.bench_params, - output_dir=output_dir, + experiment_dir=experiment_dir, num_runs=args.num_runs, dry_run=args.dry_run, - link_vars=args.link_vars, ) - except BaseException as exc: - raise RuntimeError( - f"The script was terminated early. Use `--resume {timestamp}` " - f"to continue the script from its last checkpoint." - ) from exc def main(args: argparse.Namespace): diff --git a/vllm/benchmarks/sweep/serve_sla.py b/vllm/benchmarks/sweep/serve_sla.py deleted file mode 100644 index 89169ec..0000000 --- a/vllm/benchmarks/sweep/serve_sla.py +++ /dev/null @@ -1,305 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import math -from dataclasses import asdict, dataclass -from datetime import datetime -from pathlib import Path -from typing import ClassVar, Literal, get_args - -import numpy as np -from typing_extensions import assert_never - -from vllm.utils.import_utils import PlaceholderModule - -from .param_sweep import ParameterSweep, ParameterSweepItem -from .serve import ( - SweepServeArgs, - _get_comb_base_path, - run_comb, - server_ctx, -) -from .server import ServerProcess - -try: - import pandas as pd -except ImportError: - pd = PlaceholderModule("pandas") - - -SLAVariable = Literal["request_rate", "max_concurrency"] - - -def _estimate_sla_value(run_data: dict[str, object], sla_variable: SLAVariable): - request_throughput = float(run_data["request_throughput"]) # type: ignore - if sla_variable == "request_rate": - return request_throughput - if sla_variable == "max_concurrency": - mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore - return request_throughput * mean_latency_ms / 1000 - - assert_never(sla_variable) - - -def _estimate_sla_avg(runs: list[dict[str, object]], sla_variable: SLAVariable): - return sum(_estimate_sla_value(run, sla_variable) for run in runs) / len(runs) - - -def run_comb_sla( - server: ServerProcess | None, - bench_cmd: list[str], - *, - serve_comb: ParameterSweepItem, - bench_comb: ParameterSweepItem, - output_dir: Path, - num_runs: int, - dry_run: bool, - link_vars: list[tuple[str, str]], - sla_variable: SLAVariable, - sla_value: int, -) -> list[dict[str, object]] | None: - bench_comb_sla = bench_comb | {sla_variable: sla_value} - - return run_comb( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb_sla, - base_path=_get_comb_base_path(output_dir, serve_comb, bench_comb_sla), - num_runs=num_runs, - dry_run=dry_run, - link_vars=link_vars, - ) - - -def explore_sla( - server: ServerProcess | None, - bench_cmd: list[str], - *, - serve_comb: ParameterSweepItem, - bench_comb: ParameterSweepItem, - sla_variable: SLAVariable, - sla_iters: int, - output_dir: Path, - num_runs: int, - dry_run: bool, - link_vars: list[tuple[str, str]], -): - print("[SLA START]") - print(f"Serve parameters: {serve_comb.as_text() or 
'(None)'}") - print(f"Bench parameters: {bench_comb.as_text() or '(None)'}") - print(f"Number of SLA iterations: {sla_iters}") - - if sla_iters < 2: - raise ValueError("`sla_iters` should be at least 2") - - serial_comb_data = run_comb_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - output_dir=output_dir, - num_runs=num_runs, - dry_run=dry_run, - link_vars=link_vars, - sla_variable=sla_variable, - sla_value=1, - ) - batch_comb_data = run_comb_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - output_dir=output_dir, - num_runs=num_runs, - dry_run=dry_run, - link_vars=link_vars, - sla_variable=sla_variable, - sla_value=int(bench_comb.get("num_prompts", 1000)), # type: ignore - ) - - if serial_comb_data is None or batch_comb_data is None: - if dry_run: - print("Omitting intermediate SLA iterations.") - print("[SLA END]") - - return - - serial_sla_value = math.ceil(_estimate_sla_avg(serial_comb_data, sla_variable)) - print(f"Serial inference: {sla_variable}={serial_sla_value}") - - batch_sla_value = math.floor(_estimate_sla_avg(batch_comb_data, sla_variable)) - print(f"Batch inference: {sla_variable}={batch_sla_value}") - - # Avoid duplicated runs for intermediate values if the range between - # `serial_sla_value` and `batch_sla_value` is small - inter_sla_values = np.linspace(serial_sla_value, batch_sla_value, sla_iters)[1:-1] - inter_sla_values = sorted(set(map(round, inter_sla_values))) - - inter_combs_data: list[dict[str, object]] = [] - for inter_sla_value in inter_sla_values: - print(f"Exploring: {sla_variable}={inter_sla_value}") - inter_comb_data = run_comb_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - output_dir=output_dir, - num_runs=num_runs, - dry_run=dry_run, - link_vars=link_vars, - sla_variable=sla_variable, - sla_value=inter_sla_value, - ) - if inter_comb_data is not None: - inter_combs_data.extend(inter_comb_data) - - print("[SLA END]") - - return serial_comb_data + inter_combs_data + batch_comb_data - - -def run_slas( - serve_cmd: list[str], - bench_cmd: list[str], - after_bench_cmd: list[str], - *, - show_stdout: bool, - server_ready_timeout: int, - serve_params: ParameterSweep, - bench_params: ParameterSweep, - sla_variable: SLAVariable, - sla_iters: int, - output_dir: Path, - num_runs: int, - dry_run: bool, - link_vars: list[tuple[str, str]], -): - if any(bench_comb.has_param(sla_variable) for bench_comb in bench_params): - raise ValueError( - f"You should not override `{sla_variable}` in `bench_params` in SLA mode, " - "since it is supposed to be determined automatically." 
- ) - - all_data = list[dict[str, object]]() - for serve_comb in serve_params: - with server_ctx( - serve_cmd, - after_bench_cmd, - show_stdout=show_stdout, - server_ready_timeout=server_ready_timeout, - serve_comb=serve_comb, - bench_params=bench_params, - output_dir=output_dir, - dry_run=dry_run, - ) as server: - for bench_comb in bench_params: - comb_data = explore_sla( - server, - bench_cmd, - serve_comb=serve_comb, - bench_comb=bench_comb, - sla_variable=sla_variable, - sla_iters=sla_iters, - output_dir=output_dir, - num_runs=num_runs, - dry_run=dry_run, - link_vars=link_vars, - ) - - if comb_data is not None: - all_data.extend(comb_data) - - if dry_run: - return None - - combined_df = pd.DataFrame.from_records(all_data) - combined_df.to_csv(output_dir / "summary.csv") - - return combined_df - - -@dataclass -class SweepServeSLAArgs(SweepServeArgs): - sla_variable: SLAVariable - sla_iters: int - - parser_name: ClassVar[str] = "serve_sla" - parser_help: ClassVar[str] = ( - "Explore the latency-throughput space for determining SLAs." - ) - - @classmethod - def from_cli_args(cls, args: argparse.Namespace): - # NOTE: Don't use super() as `from_cli_args` calls `cls()` - base_args = SweepServeArgs.from_cli_args(args) - - return cls( - **asdict(base_args), - sla_variable=args.sla_variable, - sla_iters=args.sla_iters, - ) - - @classmethod - def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - parser = super().add_cli_args(parser) - - sla_group = parser.add_argument_group("sla options") - sla_group.add_argument( - "--sla-variable", - type=str, - choices=get_args(SLAVariable), - default="request_rate", - help="The variable to adjust in each iteration.", - ) - sla_group.add_argument( - "--sla-iters", - type=int, - default=10, - help="Number of iterations used to explore the latency-throughput space. " - "This includes the first two iterations used to interpolate the value of " - "`sla_variable` for remaining iterations.", - ) - - return parser - - -def run_main(args: SweepServeSLAArgs): - timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") - output_dir = args.output_dir / timestamp - - if args.resume and not output_dir.exists(): - raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") - - try: - return run_slas( - serve_cmd=args.serve_cmd, - bench_cmd=args.bench_cmd, - after_bench_cmd=args.after_bench_cmd, - show_stdout=args.show_stdout, - server_ready_timeout=args.server_ready_timeout, - serve_params=args.serve_params, - bench_params=args.bench_params, - sla_variable=args.sla_variable, - sla_iters=args.sla_iters, - output_dir=output_dir, - num_runs=args.num_runs, - dry_run=args.dry_run, - link_vars=args.link_vars, - ) - except BaseException as exc: - raise RuntimeError( - f"The script was terminated early. Use `--resume {timestamp}` " - f"to continue the script from its last checkpoint." 
- ) from exc - - -def main(args: argparse.Namespace): - run_main(SweepServeSLAArgs.from_cli_args(args)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description=SweepServeSLAArgs.parser_help) - SweepServeSLAArgs.add_cli_args(parser) - - main(parser.parse_args()) diff --git a/vllm/benchmarks/sweep/serve_workload.py b/vllm/benchmarks/sweep/serve_workload.py new file mode 100644 index 0000000..ca7ba09 --- /dev/null +++ b/vllm/benchmarks/sweep/serve_workload.py @@ -0,0 +1,328 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import math +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import ClassVar, Literal, get_args + +import numpy as np +from typing_extensions import assert_never + +from vllm.benchmarks.datasets import DEFAULT_NUM_PROMPTS +from vllm.utils.import_utils import PlaceholderModule + +from .param_sweep import ParameterSweep, ParameterSweepItem +from .serve import ( + SweepServeArgs, + _get_comb_base_path, + run_comb, + server_ctx, +) +from .server import ServerProcess + +try: + import pandas as pd +except ImportError: + pd = PlaceholderModule("pandas") + + +WorkloadVariable = Literal["request_rate", "max_concurrency"] + + +def _estimate_workload_value( + run_data: dict[str, object], + workload_var: WorkloadVariable, +): + request_throughput = float(run_data["request_throughput"]) # type: ignore + if workload_var == "request_rate": + return request_throughput + if workload_var == "max_concurrency": + mean_latency_ms = float(run_data["mean_e2el_ms"]) # type: ignore + return request_throughput * mean_latency_ms / 1000 + + assert_never(workload_var) + + +def _estimate_workload_avg( + runs: list[dict[str, object]], + workload_var: WorkloadVariable, +): + total = sum(_estimate_workload_value(run, workload_var) for run in runs) + return total / len(runs) + + +def run_comb_workload( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + link_vars: list[tuple[str, str]], + experiment_dir: Path, + num_runs: int, + dry_run: bool, + workload_var: WorkloadVariable, + workload_value: int, +) -> list[dict[str, object]] | None: + bench_comb_workload = bench_comb | {workload_var: workload_value} + + return run_comb( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb_workload, + link_vars=link_vars, + base_path=_get_comb_base_path( + experiment_dir, + serve_comb, + bench_comb, + extra_parts=("WL-", f"{workload_var}={workload_value}"), + ), + num_runs=num_runs, + dry_run=dry_run, + ) + + +def explore_comb_workloads( + server: ServerProcess | None, + bench_cmd: list[str], + *, + serve_comb: ParameterSweepItem, + bench_comb: ParameterSweepItem, + link_vars: list[tuple[str, str]], + workload_var: WorkloadVariable, + workload_iters: int, + experiment_dir: Path, + num_runs: int, + dry_run: bool, +): + print("[WL START]") + print(f"Serve parameters: {serve_comb.as_text() or '(None)'}") + print(f"Bench parameters: {bench_comb.as_text() or '(None)'}") + print(f"Number of workload iterations: {workload_iters}") + + if workload_iters < 2: + raise ValueError("`workload_iters` should be at least 2") + + dataset_size = DEFAULT_NUM_PROMPTS + if "num_prompts" in bench_comb: + dataset_size = int(bench_comb["num_prompts"]) # type: ignore + else: + for i, arg in enumerate(bench_cmd): + if arg == "--num-prompts" and i + 1 < len(bench_cmd): + dataset_size = int(bench_cmd[i + 1]) + break + 
elif arg.startswith("--num-prompts="): + dataset_size = int(arg.split("=", 1)[1]) + break + + print(f"Dataset size: {dataset_size}") + + serial_workload_data = run_comb_workload( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb | {"max_concurrency": 1}, + link_vars=link_vars, + experiment_dir=experiment_dir, + num_runs=num_runs, + dry_run=dry_run, + workload_var=workload_var, + workload_value=1, + ) + batch_workload_data = run_comb_workload( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb | {"max_concurrency": dataset_size}, + link_vars=link_vars, + experiment_dir=experiment_dir, + num_runs=num_runs, + dry_run=dry_run, + workload_var=workload_var, + workload_value=dataset_size, + ) + + if serial_workload_data is None or batch_workload_data is None: + if dry_run: + print("Omitting intermediate Workload iterations.") + print("[WL END]") + + return + + serial_workload_value = math.ceil( + _estimate_workload_avg(serial_workload_data, workload_var) + ) + print(f"Serial inference: {workload_var}={serial_workload_value}") + + batch_workload_value = math.floor( + _estimate_workload_avg(batch_workload_data, workload_var) + ) + print(f"Batch inference: {workload_var}={batch_workload_value}") + + # Avoid duplicated runs for intermediate values if the range between + # `serial_workload_value` and `batch_workload_value` is small + inter_workload_values = np.linspace( + serial_workload_value, batch_workload_value, workload_iters + )[1:-1] + inter_workload_values = sorted(set(map(round, inter_workload_values))) + + inter_workloads_data: list[dict[str, object]] = [] + for inter_workload_value in inter_workload_values: + print(f"Exploring: {workload_var}={inter_workload_value}") + inter_workload_data = run_comb_workload( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + link_vars=link_vars, + experiment_dir=experiment_dir, + num_runs=num_runs, + dry_run=dry_run, + workload_var=workload_var, + workload_value=inter_workload_value, + ) + if inter_workload_data is not None: + inter_workloads_data.extend(inter_workload_data) + + print("[WL END]") + + return serial_workload_data + inter_workloads_data + batch_workload_data + + +def explore_combs_workloads( + serve_cmd: list[str], + bench_cmd: list[str], + after_bench_cmd: list[str], + *, + show_stdout: bool, + server_ready_timeout: int, + serve_params: ParameterSweep, + bench_params: ParameterSweep, + link_vars: list[tuple[str, str]], + workload_var: WorkloadVariable, + workload_iters: int, + experiment_dir: Path, + num_runs: int, + dry_run: bool, +): + if any(bench_comb.has_param(workload_var) for bench_comb in bench_params): + raise ValueError( + f"You should not override `{workload_var}` in `bench_params` " + "since it is supposed to be explored automatically." 
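# A minimal, self-contained sketch of the workload arithmetic used above:
# `request_rate` is read straight off the measured request throughput, while
# `max_concurrency` follows Little's law (concurrency is roughly throughput
# times mean end-to-end latency). The serial and batch endpoint runs are then
# interpolated with numpy.linspace to pick intermediate workload levels.
# Metric names match the benchmark output consumed above; the sample numbers
# below are made up.
import math

import numpy as np


def estimate_workload_value(run: dict[str, float], workload_var: str) -> float:
    if workload_var == "request_rate":
        return run["request_throughput"]
    # max_concurrency: throughput (req/s) * mean e2e latency (s)
    return run["request_throughput"] * run["mean_e2el_ms"] / 1000


serial_run = {"request_throughput": 0.8, "mean_e2el_ms": 1200.0}
batch_run = {"request_throughput": 25.0, "mean_e2el_ms": 4000.0}

serial_value = math.ceil(estimate_workload_value(serial_run, "max_concurrency"))  # 1
batch_value = math.floor(estimate_workload_value(batch_run, "max_concurrency"))   # 100

# Drop the two endpoints that were already measured and de-duplicate the
# rounded intermediate levels, mirroring the exploration loop above.
levels = np.linspace(serial_value, batch_value, num=10)[1:-1]
print(sorted(set(map(round, levels))))  # [12, 23, 34, 45, 56, 67, 78, 89]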
+ ) + + all_data = list[dict[str, object]]() + for serve_comb in serve_params: + with server_ctx( + serve_cmd, + after_bench_cmd, + show_stdout=show_stdout, + server_ready_timeout=server_ready_timeout, + serve_comb=serve_comb, + bench_params=bench_params, + experiment_dir=experiment_dir, + dry_run=dry_run, + ) as server: + for bench_comb in bench_params: + comb_data = explore_comb_workloads( + server, + bench_cmd, + serve_comb=serve_comb, + bench_comb=bench_comb, + link_vars=link_vars, + workload_var=workload_var, + workload_iters=workload_iters, + experiment_dir=experiment_dir, + num_runs=num_runs, + dry_run=dry_run, + ) + + if comb_data is not None: + all_data.extend(comb_data) + + if dry_run: + return None + + combined_df = pd.DataFrame.from_records(all_data) + combined_df.to_csv(experiment_dir / "summary.csv") + + return combined_df + + +@dataclass +class SweepServeWorkloadArgs(SweepServeArgs): + workload_var: WorkloadVariable + workload_iters: int + + parser_name: ClassVar[str] = "serve_workload" + parser_help: ClassVar[str] = ( + "Explore the latency-throughput tradeoff for different workload levels." + ) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # NOTE: Don't use super() as `from_cli_args` calls `cls()` + base_args = SweepServeArgs.from_cli_args(args) + + return cls( + **asdict(base_args), + workload_var=args.workload_var, + workload_iters=args.workload_iters, + ) + + @classmethod + def add_cli_args(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = super().add_cli_args(parser) + + workload_group = parser.add_argument_group("workload options") + workload_group.add_argument( + "--workload-var", + type=str, + choices=get_args(WorkloadVariable), + default="request_rate", + help="The variable to adjust in each iteration.", + ) + workload_group.add_argument( + "--workload-iters", + type=int, + default=10, + help="Number of workload levels to explore. 
" + "This includes the first two iterations used to interpolate the value of " + "`workload_var` for remaining iterations.", + ) + + return parser + + +def run_main(args: SweepServeWorkloadArgs): + experiment_dir = args.resolve_experiment_dir() + + with args.run_ctx(experiment_dir): + return explore_combs_workloads( + serve_cmd=args.serve_cmd, + bench_cmd=args.bench_cmd, + after_bench_cmd=args.after_bench_cmd, + show_stdout=args.show_stdout, + server_ready_timeout=args.server_ready_timeout, + serve_params=args.serve_params, + bench_params=args.bench_params, + link_vars=args.link_vars, + workload_var=args.workload_var, + workload_iters=args.workload_iters, + experiment_dir=experiment_dir, + num_runs=args.num_runs, + dry_run=args.dry_run, + ) + + +def main(args: argparse.Namespace): + run_main(SweepServeWorkloadArgs.from_cli_args(args)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description=SweepServeWorkloadArgs.parser_help) + SweepServeWorkloadArgs.add_cli_args(parser) + + main(parser.parse_args()) diff --git a/vllm/benchmarks/sweep/startup.py b/vllm/benchmarks/sweep/startup.py index b4d979b..6f5217e 100644 --- a/vllm/benchmarks/sweep/startup.py +++ b/vllm/benchmarks/sweep/startup.py @@ -4,6 +4,7 @@ import argparse import json import shlex import subprocess +from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime from functools import lru_cache @@ -111,7 +112,7 @@ def _apply_output_json(cmd: list[str], output_path: Path) -> list[str]: def _get_comb_base_path( - output_dir: Path, + experiment_dir: Path, serve_comb: ParameterSweepItem, startup_comb: ParameterSweepItem, ) -> Path: @@ -120,7 +121,8 @@ def _get_comb_base_path( parts.extend(("SERVE-", serve_comb.name)) if startup_comb: parts.extend(("STARTUP-", startup_comb.name)) - return output_dir / sanitize_filename("-".join(parts)) + + return experiment_dir / sanitize_filename("-".join(parts)) def _get_comb_run_path(base_path: Path, run_number: int | None) -> Path: @@ -225,7 +227,7 @@ def run_combs( *, serve_params: ParameterSweep, startup_params: ParameterSweep, - output_dir: Path, + experiment_dir: Path, num_runs: int, show_stdout: bool, dry_run: bool, @@ -233,7 +235,7 @@ def run_combs( all_data = list[dict[str, object]]() for serve_comb in serve_params: for startup_comb in startup_params: - base_path = _get_comb_base_path(output_dir, serve_comb, startup_comb) + base_path = _get_comb_base_path(experiment_dir, serve_comb, startup_comb) comb_data = run_comb( startup_cmd, serve_comb=serve_comb, @@ -250,7 +252,7 @@ def run_combs( return None combined_df = pd.DataFrame.from_records(all_data) - combined_df.to_csv(output_dir / "summary.csv") + combined_df.to_csv(experiment_dir / "summary.csv") return combined_df @@ -260,11 +262,11 @@ class SweepStartupArgs: serve_params: ParameterSweep startup_params: ParameterSweep output_dir: Path + experiment_name: str num_runs: int show_stdout: bool dry_run: bool - resume: str | None - strict_params: bool + resume: bool parser_name: ClassVar[str] = "startup" parser_help: ClassVar[str] = ( @@ -286,13 +288,19 @@ class SweepStartupArgs: startup_params = ParameterSweep.from_records([{}]) supported = _get_supported_startup_keys() + strict_params = args.strict_params serve_params = _filter_params( - serve_params, supported=supported, strict=args.strict_params + serve_params, supported=supported, strict=strict_params ) startup_params = _filter_params( - startup_params, supported=supported, strict=args.strict_params + startup_params, 
supported=supported, strict=strict_params ) + if args.experiment_name: + experiment_name = args.experiment_name + else: + experiment_name = datetime.now().strftime("%Y%m%d_%H%M%S") + if args.num_runs < 1: raise ValueError("`num_runs` should be at least 1.") @@ -301,11 +309,11 @@ class SweepStartupArgs: serve_params=serve_params, startup_params=startup_params, output_dir=Path(args.output_dir), + experiment_name=experiment_name, num_runs=args.num_runs, show_stdout=args.show_stdout, dry_run=args.dry_run, resume=args.resume, - strict_params=args.strict_params, ) @classmethod @@ -316,6 +324,7 @@ class SweepStartupArgs: default="vllm bench startup", help="The command used to run the startup benchmark.", ) + parser.add_argument( "--serve-params", type=str, @@ -331,12 +340,27 @@ class SweepStartupArgs: help="Path to JSON file containing parameter combinations " "for the `vllm bench startup` command.", ) + parser.add_argument( + "--strict-params", + action="store_true", + help="If set, unknown parameters in sweep files raise an error " + "instead of being ignored.", + ) + parser.add_argument( "-o", "--output-dir", type=str, default="results", - help="The directory to which results are written.", + help="The main directory to which results are written.", + ) + parser.add_argument( + "-e", + "--experiment-name", + type=str, + default=None, + help="The name of this experiment (defaults to current timestamp). " + "Results will be stored under `output_dir/experiment_name`.", ) parser.add_argument( "--num-runs", @@ -357,43 +381,56 @@ class SweepStartupArgs: ) parser.add_argument( "--resume", - type=str, - default=None, - help="Set this to the name of a directory under `output_dir` (which is a " - "timestamp) to resume a previous execution of this script, i.e., only run " - "parameter combinations for which there are still no output files.", - ) - parser.add_argument( - "--strict-params", action="store_true", - help="If set, unknown parameters in sweep files raise an error " - "instead of being ignored.", + help="Resume a previous execution of this script, i.e., only run " + "parameter combinations for which there are still no output files " + "under `output_dir/experiment_name`.", ) + return parser + def resolve_experiment_dir(self) -> Path: + experiment_dir = self.output_dir / self.experiment_name + + if self.resume: + if not experiment_dir.exists(): + raise ValueError(f"Cannot resume from non-existent {experiment_dir=}") + else: + if experiment_dir.exists(): + raise ValueError(f"Cannot overwrite existing {experiment_dir=}") + + return experiment_dir + + @contextmanager + def run_ctx(self, experiment_dir: Path): + if self.dry_run: + yield + print(f"Experiment will be saved at: {experiment_dir}") + return + + try: + yield + print(f"Experiment has been saved at: {experiment_dir}") + except BaseException as exc: + raise RuntimeError( + "The script was terminated early. Use `--resume` " + "to continue the script from its last checkpoint." 
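# A condensed sketch of the experiment-directory handling introduced above:
# results live under `output_dir/experiment_name`, a fresh run refuses to
# overwrite an existing directory, `--resume` requires that directory to
# exist, and a context manager turns an early termination into a hint to
# rerun with `--resume`. Names are illustrative; only the control flow
# mirrors the code above.
from contextlib import contextmanager
from pathlib import Path


def resolve_experiment_dir(output_dir: Path, name: str, resume: bool) -> Path:
    experiment_dir = output_dir / name
    if resume and not experiment_dir.exists():
        raise ValueError(f"Cannot resume from non-existent {experiment_dir=}")
    if not resume and experiment_dir.exists():
        raise ValueError(f"Cannot overwrite existing {experiment_dir=}")
    return experiment_dir


@contextmanager
def run_ctx(experiment_dir: Path, dry_run: bool = False):
    if dry_run:
        yield
        print(f"Experiment will be saved at: {experiment_dir}")
        return
    try:
        yield
        print(f"Experiment has been saved at: {experiment_dir}")
    except BaseException as exc:
        raise RuntimeError(
            "The script was terminated early. Use `--resume` to continue "
            "the script from its last checkpoint."
        ) from exc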
+ ) from exc + def run_main(args: SweepStartupArgs): - timestamp = args.resume or datetime.now().strftime("%Y%m%d_%H%M%S") - output_dir = args.output_dir / timestamp + experiment_dir = args.resolve_experiment_dir() - if args.resume and not output_dir.exists(): - raise ValueError(f"Cannot resume from non-existent directory ({output_dir})") - - try: + with args.run_ctx(experiment_dir): return run_combs( startup_cmd=args.startup_cmd, serve_params=args.serve_params, startup_params=args.startup_params, - output_dir=output_dir, + experiment_dir=experiment_dir, num_runs=args.num_runs, show_stdout=args.show_stdout, dry_run=args.dry_run, ) - except BaseException as exc: - raise RuntimeError( - f"The script was terminated early. Use `--resume {timestamp}` " - f"to continue the script from its last checkpoint." - ) from exc def main(args: argparse.Namespace): diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index b974ca0..843c042 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -282,61 +282,13 @@ class CompilerManager: maybe_key += f"{compile_range.start}_{compile_range.end}" maybe_key += f"_subgraph_{graph_index}" with self.compile_context(compile_range): - # There is a compilation time optimization here. - # - # If the (input metadata, graph, compiler config) are the same, then - # we want to avoid compiling the same artifact again. If we didn't - # do this optimization, the backend compilation (InductorAdaptor or - # InductorStandaloneAdaptor) - # is able to cache hit and produce an artifact faster if it was - # already created, but it is still a duplicate artifact that - # requires unnecessary things e.g. disk IO. - # - # The optimization is: If the backend compilation cache hits, - # then do an early return from the backend compilation and look up - # which of the previous in-memory artifacts we created to reuse. - # - # We implemented this by monkey-patching torch (torch does not - # easily expose the cache_key function), but in the future torch - # should expose the cache_key function that we can just call - # directly before invoking backend compilation. - cache_key = None - orig = torch._functorch._aot_autograd.autograd_cache.autograd_cache_key - - def autograd_cache_key(*args, **kwargs): - result = orig(*args, **kwargs) - if result is None: - return None - nonlocal cache_key - cache_key = result[0] - if cache_key in self.loaded_artifacts: - raise StopCompiling() - return result - - from unittest.mock import patch - - with ( - # Graphs that are isometric (different node names but same - # structure) should be treated as the same. - torch._functorch.config.patch(autograd_cache_normalize_inputs=True), - patch( - "torch._functorch._aot_autograd.autograd_cache.autograd_cache_key", - autograd_cache_key, - ), - ): - try: - compiled_graph, handle = self.compiler.compile( - graph, - example_inputs, - additional_inductor_config, - compile_range, - maybe_key, - ) - except StopCompiling: - assert cache_key is not None - return self.loaded_artifacts[cache_key] - if cache_key is not None and compiled_graph is not None: - self.loaded_artifacts[cache_key] = compiled_graph + compiled_graph, handle = self.compiler.compile( + graph, + example_inputs, + additional_inductor_config, + compile_range, + maybe_key, + ) assert compiled_graph is not None, "Failed to compile the graph" @@ -497,7 +449,7 @@ def wrap_with_cudagraph_if_needed( # it from the FULL cudagraph runtime mode, no matter it # is wrapped on a full or piecewise fx graph. 
return static_graph_wrapper_class( - runnable=piecewise_backend, + runnable=piecewise_backend.graph.forward, vllm_config=vllm_config, runtime_mode=CUDAGraphMode.PIECEWISE, cudagraph_options=CUDAGraphOptions( @@ -780,7 +732,7 @@ class VllmBackend: return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map def configure_post_pass(self) -> None: - # self.pass_manager.configure(self.vllm_config) + self.pass_manager.configure(self.vllm_config) # Post-grad custom passes are run using the post_grad_custom_post_pass # hook. If a pass for that hook exists, add it to the pass manager. @@ -846,7 +798,7 @@ class VllmBackend: ), ) - def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: + def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any], **kwargs) -> Any: from .caching import ( VllmSerializableFunction, ) @@ -988,7 +940,7 @@ class VllmBackend: assert not self._called, "VllmBackend can only be called once" self.graph = graph - self.configure_post_pass() + # self.configure_post_pass() if self.compilation_config.use_inductor_graph_partition: # Let Inductor decide partitioning; avoid FX-level pre-splitting. diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py index 07f9db4..3917a4f 100644 --- a/vllm/compilation/caching.py +++ b/vllm/compilation/caching.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib import hashlib import inspect import os @@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts: self.loaded_submodule_store = {} +@contextlib.contextmanager +def patch_pytree_map_over_slice(): + pytree._private_register_pytree_node( + slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x) + ) + + try: + yield + finally: + pytree._deregister_pytree_node(slice) + + class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] """ A wrapper around a compiled function by vllm. It will forward the tensor @@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] lambda inp: torch.empty_like(inp, device="meta"), state["example_inputs"], ) - with patch.object(GraphPickler, "reducer_override", _graph_reducer_override): + with ( + patch.object(GraphPickler, "reducer_override", _graph_reducer_override), + patch_pytree_map_over_slice(), + ): state["graph_module"] = GraphPickler.dumps( state["graph_module"], Options(ops_filter=None) ) @@ -261,7 +277,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] state = pickle.loads(data) fake_mode = FakeTensorMode(shape_env=ShapeEnv()) - state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) + with patch_pytree_map_over_slice(): + state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["graph_module"].recompile() state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index c00486a..b6fa5d1 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -184,6 +184,47 @@ def is_compile_cache_enabled( ) +def _patch_standalone_compile_atomic_save() -> None: + """Backport of pytorch/pytorch#162432 for torch < 2.10.0. + + Patches CompiledArtifact.save() to use write_atomic for binary format, + preventing corrupt cache files when multiple processes compile + concurrently. 
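# Why `slice` needs a pytree registration in the caching change above:
# GraphPickler maps over the serialized example inputs with
# torch.utils._pytree, and a `slice` object would otherwise be treated as an
# opaque leaf. Registering it (as `patch_pytree_map_over_slice` does, via the
# same private pytree helpers) lets the map recurse into start/stop/step and
# rebuild the slice afterwards. A toy demonstration of that registration,
# assuming a recent torch where these private helpers behave as used above:
import torch.utils._pytree as pytree

pytree._private_register_pytree_node(
    slice, lambda s: ([s.start, s.stop, s.step], None), lambda xs, ctx: slice(*xs)
)
try:
    example = {"indices": slice(0, 128, 2), "scale": 1.0}
    doubled = pytree.tree_map(
        lambda v: v * 2 if isinstance(v, int) else v, example
    )
    print(doubled["indices"])  # slice(0, 256, 4)
finally:
    pytree._deregister_pytree_node(slice)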
+ """ + from torch._inductor.codecache import write_atomic + from torch._inductor.standalone_compile import CompiledArtifact as cls + + if getattr(cls.save, "_vllm_patched", False): + return + + original_save = cls.save + + def _save( + self: Any, *, path: str, format: Literal["binary", "unpacked"] = "binary" + ) -> None: + if format != "binary": + return original_save(self, path=path, format=format) + from torch._dynamo.utils import dynamo_timed + from torch._inductor.codecache import torch_key + from torch.utils._appending_byte_serializer import BytesWriter + + with dynamo_timed("CompiledArtifact.save"): + assert self._artifacts is not None + artifact_bytes, cache_info = self._artifacts + assert len(cache_info.aot_autograd_artifacts) == 1, cache_info + key = cache_info.aot_autograd_artifacts[0] + assert not os.path.isdir(path) + writer = BytesWriter() + writer.write_bytes(torch_key()) + writer.write_str(key) + writer.write_bytes(artifact_bytes) + write_atomic(path, writer.to_bytes()) + + _save._vllm_patched = True # type: ignore[attr-defined] + cls.save = _save # type: ignore[assignment] + logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__) + + class InductorStandaloneAdaptor(CompilerInterface): """ The adaptor for the Inductor compiler. @@ -197,6 +238,8 @@ class InductorStandaloneAdaptor(CompilerInterface): name = "inductor_standalone" def __init__(self, save_format: Literal["binary", "unpacked"]) -> None: + if not is_torch_equal_or_newer("2.10.0"): + _patch_standalone_compile_atomic_save() self.save_format = save_format def compute_hash(self, vllm_config: VllmConfig) -> str: @@ -224,7 +267,7 @@ class InductorStandaloneAdaptor(CompilerInterface): if compiler_config is not None: current_config.update(compiler_config) set_inductor_config(current_config, compile_range) - set_functorch_config() + # set_functorch_config() if compile_range.is_single_size(): dynamic_shapes = "from_example_inputs" @@ -325,6 +368,7 @@ class InductorStandaloneAdaptor(CompilerInterface): inductor_compiled_graph = torch._inductor.CompiledArtifact.load( path=path, format=self.save_format ) + compilation_counter.num_compiled_artifacts_loaded += 1 from torch._inductor.compile_fx import graph_returns_tuple returns_tuple = graph_returns_tuple(graph) @@ -395,7 +439,7 @@ class InductorAdaptor(CompilerInterface): current_config["fx_graph_remote_cache"] = False set_inductor_config(current_config, compile_range) - set_functorch_config() + # set_functorch_config() # inductor can inplace modify the graph, so we need to copy it # see https://github.com/pytorch/pytorch/issues/138980 diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 29d3045..2ed49b9 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -29,6 +29,8 @@ class CompilationCounter: num_cache_entries_updated: int = 0 # The number of standalone_compile compiled artifacts saved num_compiled_artifacts_saved: int = 0 + # The number of standalone_compile compiled artifacts loaded from cache + num_compiled_artifacts_loaded: int = 0 # Number of times a model was loaded with CompilationMode.STOCK_TORCH_COMPILE stock_torch_compile_count: int = 0 diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 2014808..dfc9dc8 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -21,6 +21,7 @@ from vllm.model_executor.offloader.base import get_offloader from vllm.platforms import current_platform from vllm.utils.torch_utils import current_stream, 
weak_ref_tensors from vllm.sequence import IntermediateTensors + logger = init_logger(__name__) @@ -204,14 +205,14 @@ class CUDAGraphWrapper: def unwrap(self) -> Callable[..., Any]: # in case we need to access the original runnable. return self.runnable - + def weak_ref_tensors_with_intermediate(self, output): if isinstance(output, IntermediateTensors): intermediate_states = IntermediateTensors( tensors={key: weak_ref_tensors(value) for key, value in output.tensors.items()}) return intermediate_states return weak_ref_tensors(output) - + def __call__(self, *args: Any, **kwargs: Any) -> Any | None: forward_context = get_forward_context() batch_descriptor = forward_context.batch_descriptor @@ -298,12 +299,10 @@ class CUDAGraphWrapper: # the last graph in piecewise cuadgraph mode, because # the output of the last graph will not be used by # any other cuda graph. - # output = weak_ref_tensors(output) output = self.weak_ref_tensors_with_intermediate(output) # here we always use weak ref for the output # to save memory - # entry.output = weak_ref_tensors(output) entry.output = self.weak_ref_tensors_with_intermediate(output) entry.cudagraph = cudagraph diff --git a/vllm/compilation/passes/fusion/collective_fusion.py b/vllm/compilation/passes/fusion/collective_fusion.py index 55a5a2e..a9b64ad 100644 --- a/vllm/compilation/passes/fusion/collective_fusion.py +++ b/vllm/compilation/passes/fusion/collective_fusion.py @@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern): gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter( mul, mm_weight, - "avg", + "sum", scatter_dim=0, group_name=self.tp.device_group.group_name, ) @@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern): mat2, scale_a, scale_b, - "avg", + "sum", scatter_dim, # orig_scatter_dim scatter_dim, # scatter_dim_after_maybe_reshape self.tp.device_group.group_name, @@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern): mat2, scale_a, scale_b, - "avg", + "sum", scatter_dim, # orig_scatter_dim scatter_dim, # scatter_dim_after_maybe_reshape self.tp.device_group.group_name, diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index d8131ce..59c94db 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -5,7 +5,6 @@ import torch import torch._inductor.pattern_matcher as pm from torch import fx from torch._inductor.pattern_matcher import PatternMatcherPass -from torch._ops import OpOverload import vllm.model_executor.layers.quantization.utils.fp8_utils # noqa: F401 from vllm._aiter_ops import rocm_aiter_ops @@ -15,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, QuantKey, ScaleDesc, + kFp8Dynamic128Sym, ) from vllm.platforms import current_platform @@ -312,7 +312,9 @@ class RocmAiterRMSNormQuantFusionPass(VllmPatternMatcherPass): @VllmInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: self.matched_count = self.patterns.apply(graph) - logger.debug("Replaced %s patterns", self.matched_count) + logger.debug( + "%s Replaced %s patterns", self.__class__.__name__, self.matched_count + ) def uuid(self) -> str: fusion_patterns = [ @@ -332,9 +334,11 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): FUSED_SILU_MUL_QUANT_OP = rocm_aiter_ops.get_act_mul_fused_fp8_group_quant_op() - def __init__(self, quant_op: OpOverload) -> None: + def __init__(self) -> None: self.silu_and_mul_matcher = 
MatcherSiluAndMul() - self.quant_op = quant_op + self.quant_matcher = MatcherQuantFP8( + quant_key=kFp8Dynamic128Sym, match_rocm_aiter=True + ) def get_inputs(self) -> list[torch.Tensor]: return [ @@ -346,7 +350,7 @@ class AiterSiluMulFp8GroupQuantPattern(ActivationQuantPattern): input: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: at1 = self.silu_and_mul_matcher(input) - at2 = self.quant_op(at1, 128) + at2 = self.quant_matcher(at1) return at2[0], at2[1] def replacement( @@ -370,11 +374,6 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980 """ - AITER_GROUP_FP8_QUANT_OP = rocm_aiter_ops.get_group_quant_op() - TRITON_GROUP_FP8_QUANT_OP = torch.ops.vllm.triton_per_token_group_quant_fp8.default - - QUANT_OPS = [AITER_GROUP_FP8_QUANT_OP, TRITON_GROUP_FP8_QUANT_OP] - @enable_fake_mode def __init__(self, config: VllmConfig) -> None: super().__init__(config) @@ -383,8 +382,7 @@ class RocmAiterSiluMulFp8GroupQuantFusionPass(VllmPatternMatcherPass): pass_name="rocm_aiter_silu_mul_fp8_group_quant_fusion_pass" ) - for quant_op in self.QUANT_OPS: - AiterSiluMulFp8GroupQuantPattern(quant_op).register(self.patterns) + AiterSiluMulFp8GroupQuantPattern().register(self.patterns) self.dump_patterns(config, self.patterns) diff --git a/vllm/compilation/passes/fusion/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py index 63de859..b7ae3dc 100644 --- a/vllm/compilation/passes/fusion/sequence_parallelism.py +++ b/vllm/compilation/passes/fusion/sequence_parallelism.py @@ -18,7 +18,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8StaticTensorSym, ) -from vllm.platforms import current_platform from ..inductor_pass import enable_fake_mode from ..utility.noop_elimination import NoOpEliminationPass @@ -215,9 +214,6 @@ class MiddleAllReduceRMSNormPattern(_SequenceParallelPatternHelper): ) -FP8_DTYPE = current_platform.fp8_dtype() - - class FirstAllReduceRMSNormStaticFP8Pattern(_SequenceParallelPatternHelper): def __init__( self, diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index c7df5f9..1b656d0 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -37,6 +37,14 @@ class FixFunctionalizationPass(VllmInductorPass): self.nodes_to_remove: list[torch.fx.Node] = [] count = 0 + + rope_targets = [torch.ops._C.rotary_embedding.default] + + if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"): + rope_targets.append( + torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default + ) + for node in graph.nodes: if not is_func(node, auto_functionalized): continue # Avoid deep if-elif nesting @@ -44,7 +52,7 @@ class FixFunctionalizationPass(VllmInductorPass): kwargs = node.kwargs at_target = node.args[0] - if at_target == torch.ops._C.rotary_embedding.default: + if at_target in rope_targets: query = kwargs["query"] key = kwargs["key"] getitem_nodes = self.getitem_users(node) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index f9eb245..91d5dc6 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -298,18 +298,18 @@ class PiecewiseBackend: else list(args) ) - with ( - torch._functorch.config.patch("bundled_autograd_cache", True), - ): - range_entry.runnable = 
self.vllm_backend.compiler_manager.compile( - self.graph, - args_list, - self.vllm_backend.inductor_config, - self.compilation_config, - compile_range=range_entry.compile_range, - graph_index=self.piecewise_compile_index, - num_graphs=self.total_piecewise_compiles, - ) + # with ( + # torch._functorch.config.patch("bundled_autograd_cache", True), + # ): + range_entry.runnable = self.vllm_backend.compiler_manager.compile( + self.graph, + args_list, + self.vllm_backend.inductor_config, + self.compilation_config, + compile_range=range_entry.compile_range, + graph_index=self.piecewise_compile_index, + num_graphs=self.total_piecewise_compiles, + ) range_entry.compiled = True self.to_be_compiled_ranges.remove(range_entry.compile_range) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 850ddae..5dff296 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -319,3 +319,52 @@ class TorchCompileWithNoGuardsWrapper: yield finally: self.__class__.forward.__code__ = original + + +def reset_compile_wrapper(model: torch.nn.Module) -> None: + """ + Clean up compiled model and captured CUDA graphs for elastic EP. + """ + if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr( + model, "model" + ): + model = model.model + if not isinstance(model, TorchCompileWithNoGuardsWrapper): + return + # model.do_not_compile is set by the @support_torch_compile decorator + if hasattr(model, "do_not_compile") and model.do_not_compile: + return + from vllm.compilation.counter import compilation_counter + + # reset the compilation counter + compilation_counter.num_models_seen = 0 + compilation_counter.num_graphs_seen = 0 + compilation_counter.num_piecewise_graphs_seen = 0 + compilation_counter.num_piecewise_capturable_graphs_seen = 0 + compilation_counter.num_backend_compilations = 0 + compilation_counter.num_gpu_runner_capture_triggers = 0 + compilation_counter.num_cudagraph_captured = 0 + compilation_counter.num_inductor_compiles = 0 + compilation_counter.num_eager_compiles = 0 + compilation_counter.num_cache_entries_updated = 0 + compilation_counter.num_compiled_artifacts_saved = 0 + compilation_counter.stock_torch_compile_count = 0 + + # Clear the AOT compiled function so the model is forced to + # recompile on the next call. Without this, decorators.py + # __call__ uses the stale aot_compiled_fn whose torchinductor + # kernels have old parameters (expert_map size for example) + # baked in as compile-time constants. + if hasattr(model, "aot_compiled_fn"): + model.aot_compiled_fn = None + if hasattr(model, "was_aot_compile_fn_loaded_from_disk"): + model.was_aot_compile_fn_loaded_from_disk = False + + # Reset the cache_dir so VllmBackend recomputes the hash + # (data_parallel_size changed, so the config hash differs). + compilation_config = model.vllm_config.compilation_config + compilation_config.cache_dir = "" + compilation_config.local_cache_dir = "" + + model.__class__.forward.__code__ = model.original_code_object() + TorchCompileWithNoGuardsWrapper.__init__(model) diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 97a139c..74bb3d6 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -16,8 +16,8 @@ class AttentionConfig: backend: AttentionBackendEnum | None = None """Attention backend to use. If None, will be selected automatically.""" - flash_attn_version: Literal[2, 3] | None = None - """Force vllm to use a specific flash-attention version (2 or 3). 
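# The reset helper above works because the wrapper memoizes its compiled
# callable: once `aot_compiled_fn` is set, later calls skip compilation
# entirely, so anything baked in at compile time (such as the expert map
# size) goes stale after an elastic scale up or down. A stripped-down sketch
# of why clearing the cached function forces a fresh compile on the next
# call (names here are illustrative, not vLLM internals):
class CompileOnFirstCall:
    def __init__(self, fn):
        self._fn = fn
        self._compiled = None
        self.num_compiles = 0

    def _compile(self):
        self.num_compiles += 1  # stands in for the real (expensive) compile
        return self._fn

    def __call__(self, x):
        if self._compiled is None:  # compile lazily, then reuse
            self._compiled = self._compile()
        return self._compiled(x)

    def reset(self):
        self._compiled = None  # next call recompiles against fresh state


wrapped = CompileOnFirstCall(lambda x: x + 1)
wrapped(1)
wrapped(2)
wrapped.reset()
wrapped(3)
print(wrapped.num_compiles)  # 2: one compile before the reset, one after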
+ flash_attn_version: Literal[2, 3, 4] | None = None + """Force vllm to use a specific flash-attention version (2, 3, or 4). Only valid when using the flash-attention backend.""" use_prefill_decode_attention: bool = False diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index d22e9a9..10fc0a0 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -86,9 +86,16 @@ class CUDAGraphMode(enum.Enum): def separate_routine(self) -> bool: return isinstance(self.value, tuple) + + def decode_use_graph(self) -> bool: + return self.decode_mode() == CUDAGraphMode.FULL - def valid_runtime_modes(self) -> bool: - return self in [CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL] + @classmethod + def valid_runtime_modes(cls) -> frozenset["CUDAGraphMode"]: + return frozenset({cls.NONE, cls.PIECEWISE, cls.FULL}) + + def is_valid_runtime_mode(self) -> bool: + return self in CUDAGraphMode.valid_runtime_modes() def __str__(self) -> str: return self.name @@ -385,7 +392,7 @@ class CompilationConfig: Please use mode. Currently all levels are mapped to mode. """ # Top-level Compilation control - mode: CompilationMode = Field(default=None) + mode: CompilationMode = Field(default=CompilationMode.NONE) """The compilation approach used for torch.compile-based compilation of the model. @@ -503,7 +510,7 @@ class CompilationConfig: constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`.""" # CudaGraph compilation - cudagraph_mode: CUDAGraphMode = Field(default=None) + cudagraph_mode: CUDAGraphMode = Field(default=CUDAGraphMode.FULL_DECODE_ONLY) """ The mode of the cudagraph: @@ -1003,6 +1010,7 @@ class CompilationConfig: # https://github.com/vllm-project/vllm/issues/33267 if not self.use_inductor_graph_partition: self.splitting_ops.append("vllm::unified_kv_cache_update") + self.splitting_ops.append("vllm::unified_mla_kv_cache_update") elif len(self.splitting_ops) == 0: if ( @@ -1045,7 +1053,7 @@ class CompilationConfig: "are optimized for prefill and are incompatible with CUDA Graphs. " "In order to use CUDA Graphs for decode-optimized workloads, " "use --all2all-backend with another option, such as " - "deepep_low_latency, pplx, or allgather_reducescatter." + "deepep_low_latency or allgather_reducescatter." 
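# The CUDAGraphMode helpers above turn `valid_runtime_modes` from an
# instance-level bool check into a classmethod returning the set of modes
# that are legal at runtime, with `is_valid_runtime_mode` as the membership
# test. A reduced sketch of that shape (the values are simplified; the real
# enum carries additional tuple-valued mixed modes):
import enum


class Mode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    FULL_DECODE_ONLY = (2, 0)  # mixed mode: FULL for decode, NONE otherwise

    @classmethod
    def valid_runtime_modes(cls) -> frozenset["Mode"]:
        return frozenset({cls.NONE, cls.PIECEWISE, cls.FULL})

    def is_valid_runtime_mode(self) -> bool:
        return self in Mode.valid_runtime_modes()


print(Mode.FULL.is_valid_runtime_mode())              # True
print(Mode.FULL_DECODE_ONLY.is_valid_runtime_mode())  # False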
) self.cudagraph_mode = CUDAGraphMode.NONE diff --git a/vllm/config/model.py b/vllm/config/model.py index 377a84a..3686537 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -50,8 +50,6 @@ from vllm.transformers_utils.utils import maybe_model_redirect from vllm.utils.import_utils import LazyLoader from vllm.v1.attention.backends.registry import AttentionBackendEnum -import os - if TYPE_CHECKING: from transformers import PretrainedConfig @@ -128,6 +126,7 @@ class ModelConfig: - "slow" will always use the slow tokenizer.\n - "mistral" will always use the tokenizer from `mistral_common`.\n - "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n + - "qwen_vl" will always use the tokenizer from `qwen_vl`.\n - Other custom values can be supported via plugins.""" trust_remote_code: bool = False """Trust remote code (e.g., from HuggingFace) when downloading the model @@ -463,8 +462,6 @@ class ModelConfig: self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer) - from vllm.platforms import current_platform - if self.override_attention_dtype is not None and not current_platform.is_rocm(): warnings.warn( "override-attention-dtype is set but not using ROCm platform", @@ -473,10 +470,9 @@ class ModelConfig: if self.enable_sleep_mode and not current_platform.is_sleep_mode_available(): raise ValueError("Sleep mode is not supported on current platform.") - - temp_hf_config_path = os.environ.get("CUSTOM_QUANT_CONFIG", None) + hf_config = get_config( - temp_hf_config_path or self.hf_config_path or self.model, + self.hf_config_path or self.model, self.trust_remote_code, self.revision, self.code_revision, @@ -622,6 +618,16 @@ class ModelConfig: self._try_verify_and_update_model_config() self._verify_quantization() self._verify_cuda_graph() + import os + enforce_cuda_graph = os.environ.get("VLLM_ENFORCE_CUDA_GRAPH",None) + if enforce_cuda_graph is not None and enforce_cuda_graph in ["1", "y", "Y"]: + self.enforce_eager = False + else: + self.enforce_eager = True + logger.warning_once( + "Please export VLLM_ENFORCE_CUDA_GRAPH=1 to enable cuda graph. " + "For now, cuda graph is not used and --enforce-eager is disabled ," + "we are trying to use cuda graph as the default mode") self._verify_bnb_config() def get_model_arch_config( @@ -886,6 +892,7 @@ class ModelConfig: "modelopt", "modelopt_fp4", "modelopt_mxfp8", + "modelopt_mixed", "petit_nvfp4", # Ensure heavy backends are probed last to avoid unnecessary # imports during override detection (e.g., MXFP4 imports Triton) @@ -942,8 +949,6 @@ class ModelConfig: f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}." 
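# The overlay gates CUDA graph capture behind VLLM_ENFORCE_CUDA_GRAPH: unless
# the variable is set to one of the accepted truthy values, the model config
# falls back to eager execution and logs a hint. A self-contained restatement
# of that toggle (same accepted values as above; the vLLM logger is replaced
# with a plain warning here):
import os
import warnings


def resolve_enforce_eager() -> bool:
    flag = os.environ.get("VLLM_ENFORCE_CUDA_GRAPH")
    if flag is not None and flag in ("1", "y", "Y"):
        return False  # CUDA graphs allowed
    warnings.warn(
        "Please export VLLM_ENFORCE_CUDA_GRAPH=1 to enable CUDA graphs; "
        "running in eager mode for now."
    )
    return True


print(resolve_enforce_eager())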
) - from vllm.platforms import current_platform - current_platform.verify_quantization(self.quantization) if self.quantization in me_quant.DEPRECATED_QUANTIZATION_METHODS: @@ -1813,8 +1818,6 @@ def _resolve_auto_dtype( *, is_pooling_model: bool, ): - from vllm.platforms import current_platform - supported_dtypes = [ dtype for dtype in current_platform.supported_dtypes diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index cc2cfa9..6e84cf1 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -152,7 +152,6 @@ class ParallelConfig: - "naive": Naive all2all implementation using broadcasts\n - "allgather_reducescatter": All2all based on allgather and reducescatter\n - - "pplx": Use pplx kernels\n - "deepep_high_throughput": Use deepep high-throughput kernels\n - "deepep_low_latency": Use deepep low-latency kernels\n - "mori": Use mori kernels\n @@ -166,6 +165,9 @@ class ParallelConfig: disable_custom_all_reduce: bool = False """Disable the custom all-reduce kernel and fall back to NCCL.""" + enable_elastic_ep: bool = False + """Enable elastic expert parallelism with stateless NCCL groups for DP/EP.""" + enable_dbo: bool = False """Enable dual batch overlap for the model executor.""" ubatch_size: int = 0 @@ -245,6 +247,34 @@ class ParallelConfig: Set to be private as it's not intended to be configured by users. """ + _stateless_dp_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless DP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + It is a list of list[int], with each inner list contains a set of 3 ports + to be used for setting up the stateless CPU/device/TCPStore groups + in StatelessGroupCoordinator. The number of inner lists is equal to + the number of DP groups, + i.e., len(self._stateless_dp_group_port_list) == world_size_across_dp // dp_size, + and len(self._stateless_dp_group_port_list[i]) == 3 for all i. + """ + + _stateless_ep_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless EP groups when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + len(self._stateless_ep_group_port_list) == world_size_across_dp // ep_size, + """ + + _stateless_eplb_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless EPLB groups when enable_elastic_ep is True. + Same topology as EP but separate NCCL communicator to avoid deadlocks. + """ + + _stateless_world_group_port_list: list[list[int]] = Field(default_factory=list) + """List of open ports for stateless world group when enable_elastic_ep is True. + Set to be private as it's not intended to be configured by users. + len(self._stateless_world_group_port_list) == 1, + """ + decode_context_parallel_size: int = 1 """Number of decode context parallel groups, because the world size does not change by dcp, it simply reuse the GPUs of TP group, and tp_size @@ -310,6 +340,13 @@ class ParallelConfig: f"but found: {self._api_process_rank}" ) + if self.all2all_backend == "pplx": + logger.warning( + "The 'pplx' all2all backend has been removed. " + "Falling back to 'allgather_reducescatter'." 
+ ) + self.all2all_backend = "allgather_reducescatter" + if self.data_parallel_size_local > self.data_parallel_size: raise ValueError( f"data_parallel_size_local ({self.data_parallel_size_local}) " @@ -396,7 +433,67 @@ class ParallelConfig: return answer - def stateless_init_dp_group(self) -> ProcessGroup: + def allocate_elastic_ep_ports(self) -> None: + """Allocate all ports for elastic EP (stateless groups + DP master). + + Must be called AFTER ray.init() so that ports claimed by Ray's + idle worker pool are already in use and won't be returned by + get_open_ports_list(). + """ + if not self.enable_elastic_ep: + return + if self._stateless_world_group_port_list: + return + + num_world_groups = 1 + dp_size = self.data_parallel_size + ep_size = self.data_parallel_size * self.world_size_across_dp + num_dp_groups = max(1, self.world_size_across_dp // dp_size) + num_ep_groups = max(1, self.world_size_across_dp // ep_size) + num_eplb_groups = num_ep_groups + total_stateless_ports = ( + num_world_groups + num_dp_groups + num_ep_groups + num_eplb_groups + ) * 3 + num_dp_master_ports = 5 + + all_ports = get_open_ports_list(total_stateless_ports + num_dp_master_ports) + + self._data_parallel_master_port_list = all_ports[-num_dp_master_ports:] + self.data_parallel_master_port = self._data_parallel_master_port_list.pop() + all_ports = all_ports[:-num_dp_master_ports] + + self._stateless_world_group_port_list = [ + all_ports[i : i + 3] for i in range(0, num_world_groups * 3, 3) + ] + start_idx = num_world_groups * 3 + self._stateless_dp_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_dp_groups * 3, 3) + ] + start_idx += num_dp_groups * 3 + self._stateless_ep_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_ep_groups * 3, 3) + ] + start_idx += num_ep_groups * 3 + self._stateless_eplb_group_port_list = [ + all_ports[i : i + 3] + for i in range(start_idx, start_idx + num_eplb_groups * 3, 3) + ] + + def get_next_stateless_world_group_port(self) -> list[int]: + return self._stateless_world_group_port_list.pop() + + def get_next_stateless_dp_group_port(self) -> list[int]: + return self._stateless_dp_group_port_list.pop() + + def get_next_stateless_ep_group_port(self) -> list[int]: + return self._stateless_ep_group_port_list.pop() + + def get_next_stateless_eplb_group_port(self) -> list[int]: + return self._stateless_eplb_group_port_list.pop() + + def stateless_init_dp_group(self, return_store: bool = False) -> ProcessGroup: # NOTE: In high-concurrency scenarios multiple processes # can pick the same (currently free) port through a race # condition when calling `get_open_port()`. When the first @@ -420,7 +517,8 @@ class ParallelConfig: self.get_next_dp_init_port(), self.data_parallel_rank, self.data_parallel_size, - backend=current_platform.dist_backend, + backend="gloo", + return_store=return_store, ) except DistNetworkError as e: # We only want to retry when the root cause is EADDRINUSE. @@ -442,7 +540,6 @@ class ParallelConfig: # In this case, ensure the input to the experts is sequence parallel # to avoid the excess work. # - # Not needed for pplx-kernels as it can handle duplicate input tokens. 
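# The elastic-EP port bookkeeping above reserves three ports per stateless
# group (CPU group, device group, TCPStore) plus a handful of DP master
# ports, then slices one flat list of open ports into per-group triples that
# are later consumed with pop(). A small arithmetic sketch of that slicing,
# with made-up port numbers standing in for `get_open_ports_list()`:
def slice_port_triples(ports: list[int], num_groups: int) -> list[list[int]]:
    assert len(ports) >= num_groups * 3
    return [ports[i : i + 3] for i in range(0, num_groups * 3, 3)]


num_world_groups, num_dp_groups = 1, 2
flat_ports = list(range(20000, 20000 + (num_world_groups + num_dp_groups) * 3))

world_ports = slice_port_triples(flat_ports, num_world_groups)
dp_ports = slice_port_triples(flat_ports[num_world_groups * 3 :], num_dp_groups)
print(world_ports)  # [[20000, 20001, 20002]]
print(dp_ports)     # [[20003, 20004, 20005], [20006, 20007, 20008]]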
@property def use_sequence_parallel_moe(self) -> bool: return ( @@ -556,6 +653,21 @@ class ParallelConfig: logger.info("Using external launcher for distributed inference.") self.world_size *= self.data_parallel_size + if self.enable_elastic_ep: + if not self.enable_eplb: + raise ValueError("Elastic EP is only supported with enable_eplb=True.") + if self.pipeline_parallel_size > 1: + raise ValueError( + "Elastic EP is not supported with pipeline parallelism " + f"(pipeline_parallel_size={self.pipeline_parallel_size})." + ) + if self.data_parallel_external_lb or self.data_parallel_hybrid_lb: + raise NotImplementedError( + "Elastic EP is not compatible with data_parallel_external_lb " + "or data_parallel_hybrid_lb. Elastic EP relies on a single API " + "server and core client to coordinate scale up/down." + ) + if self.data_parallel_size > 1 or self.data_parallel_size_local == 0: # Data parallel was specified in the engine args. if self.distributed_executor_backend == "external_launcher": @@ -568,9 +680,12 @@ class ParallelConfig: "Set data_parallel_rank to %d automatically.", self.data_parallel_rank, ) - if not self._data_parallel_master_port_list: - self._data_parallel_master_port_list = get_open_ports_list(5) - self.data_parallel_master_port = self._data_parallel_master_port_list.pop() + if not self.enable_elastic_ep: + if not self._data_parallel_master_port_list: + self._data_parallel_master_port_list = get_open_ports_list(5) + self.data_parallel_master_port = ( + self._data_parallel_master_port_list.pop() + ) if not (0 <= self.data_parallel_rank < self.data_parallel_size): raise ValueError( @@ -597,7 +712,7 @@ class ParallelConfig: os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" logger.info("Disabling V1 multiprocessing for external launcher.") - if self.distributed_executor_backend is None and self.world_size > 1: + if self.distributed_executor_backend is None and self.world_size_across_dp > 1: # We use multiprocessing by default if world_size fits on the # current node and we aren't in a ray placement group. @@ -659,6 +774,17 @@ class ParallelConfig: "backend is mp, uni or external_launcher." ) + if ( + self.all2all_backend in ("allgather_reducescatter", "naive") + and self.eplb_config.use_async + ): + logger.warning( + "Async EPLB causes hangs with the '%s' all2all backend. " + "Forcing synchronous EPLB.", + self.all2all_backend, + ) + self.eplb_config.use_async = False + @property def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index c2bced7..308356a 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import ast +import copy from typing import TYPE_CHECKING, Any, Literal, get_args from pydantic import Field, SkipValidation, model_validator @@ -45,7 +46,7 @@ MTPModelTypes = Literal[ "pangu_ultra_moe_mtp", "step3p5_mtp", ] -EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes] +EagleModelTypes = Literal["eagle", "eagle3", "extract_hidden_states", MTPModelTypes] SpeculativeMethod = Literal[ "ngram", "medusa", @@ -77,12 +78,24 @@ class SpeculativeConfig: If using `ngram` method, the related configuration `prompt_lookup_max` and `prompt_lookup_min` should be considered.""" + enable_multi_layers_mtp: bool = False + """If set to True, the MTP method will run multiple layers of MTP + speculator. If set to False, it will run only one layer of MTP speculator. 
+ This is only effective when the method is set to `mtp`.""" draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" + draft_pipeline_parallel_size: int | None = Field(default=None, ge=1) + """The degree of pipeline parallelism for the draft model. + + Defaults to the target model's pipeline parallel size. Set this to 1 to + run the drafter locally on the last target PP stage.""" tensor_parallel_size: int | None = None """Users should pass "draft_tensor_parallel_size". This parameter's purpose is to warn users when they mistakenly provide the wrong argument.""" + pipeline_parallel_size: int | None = None + """Users should pass "draft_pipeline_parallel_size". This parameter's + purpose is to warn users when they mistakenly provide the wrong argument.""" # Draft model configuration quantization: me_quant.QuantizationMethods | None = None @@ -181,9 +194,22 @@ class SpeculativeConfig: the final hidden states. """ factors: list[Any] = [] - # Eagle3 affects the computation graph because it returns intermediate - # hidden states in addition to the final hidden state. - factors.append(self.method == "eagle3") + # Eagle3 and extract_hidden_states affect the computation graph because + # they return intermediate hidden states in addition to the final hidden state. + uses_aux_hidden_states = self.method in ("eagle3", "extract_hidden_states") + factors.append(uses_aux_hidden_states) + + # The specific layers used also affect the computation graph + if uses_aux_hidden_states and self.draft_model_config is not None: + layer_ids = getattr( + self.draft_model_config.hf_config, + "eagle_aux_hidden_state_layer_ids", + None, + ) + if layer_ids is not None: + # Convert to tuple to make it hashable + factors.append(tuple(layer_ids)) + hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest() return hash_str @@ -352,6 +378,8 @@ class SpeculativeConfig: self.model = "ngram" elif self.method == "suffix": self.model = "suffix" + elif self.method == "extract_hidden_states": + self.model = "extract_hidden_states" else: raise ValueError( "num_speculative_tokens was provided but without speculative model." 
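# The compilation-cache factors above fold in whether auxiliary hidden states
# are returned and, when they are, which layer ids feed them, so two setups
# differing only in `eagle_aux_hidden_state_layer_ids` get different cache
# keys. A standalone sketch of that keying, using hashlib as a stand-in for
# vLLM's `safe_hash` helper:
import hashlib


def spec_hash(method: str, aux_layer_ids: list[int] | None) -> str:
    factors: list[object] = []
    uses_aux_hidden_states = method in ("eagle3", "extract_hidden_states")
    factors.append(uses_aux_hidden_states)
    if uses_aux_hidden_states and aux_layer_ids is not None:
        factors.append(tuple(aux_layer_ids))  # tuples keep the key stable
    return hashlib.sha256(str(factors).encode()).hexdigest()


print(spec_hash("eagle3", [2, 17, 30]) == spec_hash("eagle3", [2, 17, 31]))  # False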
@@ -394,6 +422,34 @@ class SpeculativeConfig: self.draft_parallel_config = self.target_parallel_config elif self.method == "suffix": self._validate_suffix_decoding() + elif self.method == "extract_hidden_states": + from vllm.transformers_utils.configs.extract_hidden_states import ( + ExtractHiddenStatesConfig, + ) + + # ExtractHiddenStatesModel is instantiated manually in load_model() + # We just need to store the target model config for KV cache shape info + self.model = "extract_hidden_states" + self.prompt_lookup_max = 0 + self.prompt_lookup_min = 0 + + if hasattr(self.draft_model_config, "hf_config"): + hf_config = self.draft_model_config.hf_config.to_dict() + elif ( + isinstance(self.draft_model_config, dict) + and "hf_config" in self.draft_model_config + ): + hf_config = self.draft_model_config["hf_config"] + else: + hf_config = {} + + self.draft_model_config = copy.copy(self.target_model_config) + self.draft_model_config.hf_config = ExtractHiddenStatesConfig( + self.draft_model_config.hf_config, **hf_config + ) + self.update_arch_() + self.draft_parallel_config = self.target_parallel_config + else: self.prompt_lookup_max = 0 self.prompt_lookup_min = 0 @@ -439,7 +495,10 @@ class SpeculativeConfig: MTPModelTypes ): self.method = "mtp" - if self.num_speculative_tokens > 1: + if ( + self.enable_multi_layers_mtp is False + and self.num_speculative_tokens > 1 + ): logger.warning( "Enabling num_speculative_tokens > 1 will run " "multiple times of forward on same MTP layer" @@ -478,23 +537,8 @@ class SpeculativeConfig: method=self.method, model_type="eagle", ) - # EAGLEConfig primarily updates architectures, so update - # all architectures-related fields in draft_model_config self.draft_model_config.hf_config = eagle_config - self.draft_model_config.hf_text_config = get_hf_text_config( - self.draft_model_config.hf_config - ) - self.draft_model_config.model_arch_config = ( - self.draft_model_config.get_model_arch_config() - ) - model_info, arch = ( - self.draft_model_config.registry.inspect_model_cls( - self.draft_model_config.architectures, - self.draft_model_config, - ) - ) - self.draft_model_config._model_info = model_info - self.draft_model_config._architecture = arch + self.update_arch_() if self.num_speculative_tokens is not None and hasattr( self.draft_model_config.hf_config, "num_lookahead_tokens" @@ -510,6 +554,17 @@ class SpeculativeConfig: if self.num_speculative_tokens is None: # Default to max value defined in draft model config. 
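# The `extract_hidden_states` branch above builds its draft config by
# shallow-copying the target model config and swapping only `hf_config` for a
# wrapper, so every other field keeps referring to the target model's
# objects. A toy illustration of that copy-and-wrap pattern (the classes here
# are hypothetical stand-ins, not vLLM types):
import copy
from dataclasses import dataclass, field


@dataclass
class ToyHfConfig:
    model_type: str = "llama"


@dataclass
class ToyModelConfig:
    hf_config: ToyHfConfig = field(default_factory=ToyHfConfig)
    max_model_len: int = 4096


@dataclass
class ToyWrapperConfig:
    base: ToyHfConfig
    extra: dict = field(default_factory=dict)


target = ToyModelConfig()
draft = copy.copy(target)  # shallow copy: all fields are shared with target
draft.hf_config = ToyWrapperConfig(target.hf_config, {"num_hidden_layers": 1})

print(draft.max_model_len == target.max_model_len)  # True, value is shared
print(type(draft.hf_config).__name__)                # ToyWrapperConfig
print(type(target.hf_config).__name__)               # ToyHfConfig (unchanged)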
self.num_speculative_tokens = n_predict + elif ( + self.method == "mtp" + and self.enable_multi_layers_mtp + and self.num_speculative_tokens > n_predict + ): + logger.warning_once( + "For multi_layer_eagle, num_speculative_tokens " + "is greater than the layer_num, adjusting to " + "layer_num" + ) + self.num_speculative_tokens = n_predict elif ( self.num_speculative_tokens > n_predict and self.num_speculative_tokens % n_predict != 0 @@ -555,9 +610,17 @@ class SpeculativeConfig: ) ) + self.draft_pipeline_parallel_size = ( + SpeculativeConfig._verify_and_get_draft_pp( + self.target_parallel_config, + self.draft_pipeline_parallel_size, + ) + ) self.draft_parallel_config = ( SpeculativeConfig.create_draft_parallel_config( - self.target_parallel_config, self.draft_tensor_parallel_size + self.target_parallel_config, + self.draft_tensor_parallel_size, + self.draft_pipeline_parallel_size, ) ) return self @@ -671,17 +734,61 @@ class SpeculativeConfig: ) return speculative_draft_tensor_parallel_size + @staticmethod + def _verify_and_get_draft_pp( + target_parallel_config: ParallelConfig, + speculative_draft_pipeline_parallel_size: int | None, + ) -> int: + """ + Verifies and adjusts the pipeline parallel size for a draft model + specified using speculative_draft_pipeline_parallel_size. + """ + if speculative_draft_pipeline_parallel_size is None: + return target_parallel_config.pipeline_parallel_size + + if speculative_draft_pipeline_parallel_size not in ( + 1, + target_parallel_config.pipeline_parallel_size, + ): + raise ValueError( + f"{speculative_draft_pipeline_parallel_size=} cannot be " + "other value than 1 or target model " + f"pipeline_parallel_size=" + f"{target_parallel_config.pipeline_parallel_size}" + ) + return speculative_draft_pipeline_parallel_size + + def update_arch_(self): + """ + EagleConfig and ExtractHiddenStatesConfig update architectures, so update all + architectures-related fields in self.draft_model_config + """ + self.draft_model_config.hf_text_config = get_hf_text_config( + self.draft_model_config.hf_config + ) + self.draft_model_config.model_arch_config = ( + self.draft_model_config.get_model_arch_config() + ) + model_info, arch = self.draft_model_config.registry.inspect_model_cls( + self.draft_model_config.architectures, + self.draft_model_config, + ) + self.draft_model_config._model_info = model_info + self.draft_model_config._architecture = arch + @staticmethod def create_draft_parallel_config( target_parallel_config: ParallelConfig, speculative_draft_tensor_parallel_size: int, + speculative_draft_pipeline_parallel_size: int, ) -> ParallelConfig: """Create a parallel config for use by the draft worker. - This is mostly a copy of the target parallel config, except the tp_size. + This is mostly a copy of the target parallel config, except the tp/pp + sizes used by the draft model. """ draft_parallel_config = ParallelConfig( - pipeline_parallel_size=target_parallel_config.pipeline_parallel_size, + pipeline_parallel_size=speculative_draft_pipeline_parallel_size, tensor_parallel_size=speculative_draft_tensor_parallel_size, distributed_executor_backend=target_parallel_config.distributed_executor_backend, max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers, @@ -699,6 +806,12 @@ class SpeculativeConfig: "'tensor_parallel_size' is not a valid argument in the " "speculative_config. Please pass 'draft_tensor_parallel_size' instead." 
) + if self.pipeline_parallel_size is not None: + raise ValueError( + "'pipeline_parallel_size' is not a valid argument in the " + "speculative_config. Please pass " + "'draft_pipeline_parallel_size' instead." + ) if self.num_speculative_tokens is None: raise ValueError( @@ -718,7 +831,7 @@ class SpeculativeConfig: self.draft_parallel_config ) - eagle3_target_supported = [ + aux_hidden_states_supported = [ "llama", "qwen", "minicpm", @@ -729,16 +842,16 @@ class SpeculativeConfig: "nemotron_h", ] if ( - self.method == "eagle3" + self.method in ("eagle3", "extract_hidden_states") and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type - for supported_model in eagle3_target_supported + for supported_model in aux_hidden_states_supported ) ): raise ValueError( - f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501 - f"Got {self.target_model_config.hf_text_config.model_type=}" + f"{self.method} is only supported for {aux_hidden_states_supported}" + f" models. Got {self.target_model_config.hf_text_config.model_type=}" ) self.verify_equal_vocab_size_if_draft_model() return self @@ -782,8 +895,65 @@ class SpeculativeConfig: def uses_draft_model(self) -> bool: return self.method == "draft_model" + def uses_extract_hidden_states(self) -> bool: + return self.method == "extract_hidden_states" + + def needs_partial_pp_draft_remap( + self, target_parallel_config: ParallelConfig + ) -> bool: + """Whether draft PP is smaller than target PP and needs rank remap.""" + if self.draft_parallel_config is None: + return False + return ( + target_parallel_config.pipeline_parallel_size + > self.draft_parallel_config.pipeline_parallel_size + ) + + def resolve_partial_pp_draft_rank( + self, target_parallel_config: ParallelConfig + ) -> int: + """Map a target rank to the local draft rank for partial-PP drafting. + + Currently this only supports running the draft model with `draft_pp=1` + on the last target PP stage. + """ + if not self.needs_partial_pp_draft_remap(target_parallel_config): + return target_parallel_config.rank + + assert self.draft_parallel_config is not None + draft_pp = self.draft_parallel_config.pipeline_parallel_size + if draft_pp != 1: + raise ValueError( + "Partial pp drafter rank remapping only supports " + "draft_pipeline_parallel_size=1 when target PP is larger." + ) + + target_tp = target_parallel_config.tensor_parallel_size + draft_tp = self.draft_parallel_config.tensor_parallel_size + if draft_tp != target_tp: + raise ValueError( + "Partial pp drafter rank remapping requires " + "draft_tensor_parallel_size to equal target tensor_parallel_size. " + f"Got draft_tp={draft_tp}, target_tp={target_tp}." 
+ ) + + target_pp = target_parallel_config.pipeline_parallel_size + target_rank = target_parallel_config.rank + target_pp_rank = target_rank // target_tp + target_tp_rank = target_rank % target_tp + if target_pp_rank != target_pp - 1: + raise ValueError( + "Partial pp drafter should only run on the last " + f"pipeline stage, but got pp rank {target_pp_rank} / {target_pp}" + ) + return target_tp_rank + def __repr__(self) -> str: method = self.method - model = None if method in ("ngram", "suffix") else self.draft_model_config.model + model = ( + None + if method in ("ngram", "suffix", "extract_hidden_states") + else self.draft_model_config.model + ) num_spec_tokens = self.num_speculative_tokens return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 6a90a4c..c236837 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -126,6 +126,9 @@ def enable_allreduce_rms_fusion(cfg: "VllmConfig") -> bool: # tp-dp combination broken: # https://github.com/vllm-project/vllm/issues/34458 and cfg.parallel_config.data_parallel_size == 1 + # tp-pp combination broken: + # https://github.com/vllm-project/vllm/issues/35426 + and cfg.parallel_config.pipeline_parallel_size == 1 ) @@ -857,7 +860,7 @@ class VllmConfig: self.compilation_config.pass_config.fuse_gemm_comms = False else: # Compute SP threshold early; disable if None (model too - # small) before +rms_norm gets forced into custom_ops. + # small for SP to be beneficial). pass_config = self.compilation_config.pass_config if pass_config.sp_min_token_num is None: from vllm.compilation.passes.fusion.sequence_parallelism import ( @@ -880,15 +883,13 @@ class VllmConfig: self.compilation_config.pass_config.enable_sp = False self.compilation_config.pass_config.fuse_gemm_comms = False - if self.compilation_config.pass_config.enable_sp: - if "-rms_norm" in self.compilation_config.custom_ops: - logger.warning( - "RMS norm force disabled, sequence parallelism might break" - ) - else: - self.compilation_config.custom_ops.append("+rms_norm") + from vllm.utils.torch_utils import HAS_OPAQUE_TYPE - if self.compilation_config.fast_moe_cold_start is None: + if HAS_OPAQUE_TYPE: + # On torch >= 2.11 the hoisted OpaqueObject approach supersedes + # fast_moe_cold_start, so force it off. + self.compilation_config.fast_moe_cold_start = False + elif self.compilation_config.fast_moe_cold_start is None: # resolve default behavior: try to be as safe as possible # this config is unsafe if any spec decoding draft model has a MOE. # We'll conservatively turn it off if we see spec decoding. @@ -907,9 +908,9 @@ class VllmConfig: ): logger.warning_once( "Pooling models do not support full cudagraphs. " - "Overriding cudagraph_mode to PIECEWISE." + "Overriding cudagraph_mode to NONE." 
) - self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE elif ( model_config.is_encoder_decoder and self.compilation_config.cudagraph_mode @@ -924,6 +925,33 @@ class VllmConfig: CUDAGraphMode.FULL_DECODE_ONLY ) + # Check if KV connector requires PIECEWISE mode for CUDA graphs + if ( + self.kv_transfer_config is not None + and self.kv_transfer_config.is_kv_transfer_instance + and self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ): + # Lazy import to avoid circular dependencies + from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory, + ) + + connector_cls = KVConnectorFactory.get_connector_class( + self.kv_transfer_config + ) + if connector_cls.requires_piecewise_for_cudagraph( + self.kv_transfer_config.kv_connector_extra_config + ): + logger.warning_once( + "KV connector %s requires PIECEWISE CUDA graph mode " + "due to layerwise async operations that cannot be " + "captured in CUDA graphs. " + "Overriding cudagraph_mode from %s to PIECEWISE.", + connector_cls.__name__, + self.compilation_config.cudagraph_mode.name, + ) + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE + # disable cudagraph when enforce eager execution if self.model_config is not None and self.model_config.enforce_eager: logger.info("Cudagraph is disabled under eager mode") @@ -1113,6 +1141,20 @@ class VllmConfig: if not self.instance_id: self.instance_id = random_uuid()[:5] + + def is_ixserver_connector(kv_transfer_config) -> bool: + if kv_transfer_config is not None and hasattr( + kv_transfer_config, "kv_connector" + ): + connector = kv_transfer_config.kv_connector + if isinstance(connector, str): + connector_name = connector + else: + connector_name = getattr( + type(connector), "__name__", str(connector) + ) + return "IxServer" in connector_name + return False # Hybrid KV cache manager (HMA) runtime rules: # - Explicit enable (--no-disable-kv-cache-manager): error if runtime @@ -1154,21 +1196,29 @@ class VllmConfig: if self.scheduler_config.disable_hybrid_kv_cache_manager is None: # Default to disable HMA, but only if the user didn't express a preference. if self.kv_transfer_config is not None: + if is_ixserver_connector(self.kv_transfer_config): + pass # NOTE(Kuntai): turn HMA off for connector unless specifically enabled. - need_disable_hybrid_kv_cache_manager = True - logger.warning( - "Turning off hybrid kv cache manager because " - "`--kv-transfer-config` is set. This will reduce the " - "performance of vLLM on LLMs with sliding window attention " - "or Mamba attention. If you are a developer of kv connector" - ", please consider supporting hybrid kv cache manager for " - "your connector by making sure your connector is a subclass" - " of `SupportsHMA` defined in kv_connector/v1/base.py and" - " use --no-disable-hybrid-kv-cache-manager to start vLLM." + else: + need_disable_hybrid_kv_cache_manager = True + logger.warning( + "Turning off hybrid kv cache manager because " + "`--kv-transfer-config` is set. This will reduce the " + "performance of vLLM on LLMs with sliding window attention " + "or Mamba attention. If you are a developer of kv connector" + ", please consider supporting hybrid kv cache manager for " + "your connector by making sure your connector is a subclass" + " of `SupportsHMA` defined in kv_connector/v1/base.py and" + " use --no-disable-hybrid-kv-cache-manager to start vLLM." 
+ ) + self.scheduler_config.disable_hybrid_kv_cache_manager = ( + need_disable_hybrid_kv_cache_manager + ) + + else: + self.scheduler_config.disable_hybrid_kv_cache_manager = ( + need_disable_hybrid_kv_cache_manager ) - self.scheduler_config.disable_hybrid_kv_cache_manager = ( - need_disable_hybrid_kv_cache_manager - ) elif ( self.scheduler_config.disable_hybrid_kv_cache_manager is False and need_disable_hybrid_kv_cache_manager @@ -1466,22 +1516,22 @@ class VllmConfig: if compile_range_end is not None: computed_compile_ranges_split_points.append(compile_range_end) - # # Add the compile ranges for flashinfer - # if compilation_config.pass_config.fuse_allreduce_rms: - # tp_size = self.parallel_config.tensor_parallel_size - # max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) - # if max_size is not None: - # max_token_num = max_size // ( - # self.model_config.get_hidden_size() - # * self.model_config.dtype.itemsize - # ) - # if compile_range_end is not None and max_token_num < compile_range_end: - # computed_compile_ranges_split_points.append(max_token_num) - # else: - # logger.debug( - # "Max num batched tokens below allreduce-rms fusion threshold, " - # "allreduce-rms fusion will be enabled for all num_tokens." - # ) + # Add the compile ranges for flashinfer + if compilation_config.pass_config.fuse_allreduce_rms: + tp_size = self.parallel_config.tensor_parallel_size + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + if max_size is not None: + max_token_num = max_size // ( + self.model_config.get_hidden_size() + * self.model_config.dtype.itemsize + ) + if compile_range_end is not None and max_token_num < compile_range_end: + computed_compile_ranges_split_points.append(max_token_num) + else: + logger.debug( + "Max num batched tokens below allreduce-rms fusion threshold, " + "allreduce-rms fusion will be enabled for all num_tokens." 
+ ) # Add the compile ranges for sequence parallelism if compilation_config.pass_config.enable_sp: @@ -1618,6 +1668,7 @@ class VllmConfig: f"pipeline_parallel_size={self.parallel_config.pipeline_parallel_size}, " # noqa f"data_parallel_size={self.parallel_config.data_parallel_size}, " # noqa f"disable_custom_all_reduce={self.parallel_config.disable_custom_all_reduce}, " # noqa + f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " f"enable_return_routed_experts={self.model_config.enable_return_routed_experts}, " # noqa f"kv_cache_dtype={self.cache_config.cache_dtype}, " diff --git a/vllm/config/weight_transfer.py b/vllm/config/weight_transfer.py index 855b0d9..1da1f96 100644 --- a/vllm/config/weight_transfer.py +++ b/vllm/config/weight_transfer.py @@ -9,5 +9,5 @@ from vllm.config.utils import config class WeightTransferConfig: """Configuration for weight transfer during RL training.""" - backend: Literal["nccl"] = "nccl" + backend: Literal["nccl", "ipc"] = "nccl" """The backend to use for weight transfer.""" diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 2f97288..554a34b 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -11,7 +11,7 @@ import dataclasses import gc import os -from collections.abc import Callable +from collections.abc import Callable, Iterator from contextlib import contextmanager from typing import Any @@ -25,6 +25,7 @@ logger = init_logger(__name__) cumem_available = False +libcudart: Any = None try: from vllm.cumem_allocator import ( init_module, @@ -41,9 +42,7 @@ except ModuleNotFoundError: init_module = None python_create_and_map = None python_unmap_and_release = None - CudaRTLibrary = None lib_name = None - libcudart = None # py_device, py_alignedSize, py_d_mem, py_p_memHandle HandleType = tuple[int, int, int, int] @@ -65,7 +64,8 @@ def unmap_and_release(allocation_handle: HandleType) -> None: def get_pluggable_allocator( - python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] + python_malloc_fn: Callable[[HandleType], None], + python_free_func: Callable[[int], HandleType], ) -> torch.cuda.memory.CUDAPluggableAllocator: init_module(python_malloc_fn, python_free_func) new_alloc = torch.cuda.memory.CUDAPluggableAllocator( @@ -76,8 +76,11 @@ def get_pluggable_allocator( @contextmanager def use_memory_pool_with_allocator( - python_malloc_fn: Callable[[int], int], python_free_func: Callable[[int, int], None] -) -> None: + python_malloc_fn: Callable[[HandleType], None], + python_free_func: Callable[[int], HandleType], +) -> Iterator[ + tuple[torch.cuda.memory.MemPool, torch.cuda.memory.CUDAPluggableAllocator] +]: new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) with torch.cuda.memory.use_mem_pool(mem_pool): @@ -109,7 +112,7 @@ class CuMemAllocator: not work as expected. 
""" - instance: "CuMemAllocator" = None + instance: "CuMemAllocator | None" = None default_tag: str = "default" @staticmethod diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 678cd45..7e1c9e6 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -3,14 +3,13 @@ from typing import Any import torch -import torch.distributed as dist import vllm.envs as envs from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils.flashinfer import has_flashinfer_all2all -from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx +from vllm.utils.import_utils import has_deep_ep, has_mori from .base_device_communicator import All2AllManagerBase, Cache @@ -32,8 +31,8 @@ class NaiveAll2AllManager(All2AllManagerBase): debugging. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def naive_multicast( self, @@ -139,8 +138,8 @@ class AgRsAll2AllManager(All2AllManagerBase): all-gather (dispatch) and reduce-scatter (combine). """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def dispatch_router_logits( self, @@ -235,107 +234,17 @@ class AgRsAll2AllManager(All2AllManagerBase): pass -class PPLXAll2AllManager(All2AllManagerBase): - """ - All2All communication based on PPLX kernels. - """ - - def __init__(self, cpu_group): - assert has_pplx(), ( - "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" - " to install pplx_kernels." 
- ) - super().__init__(cpu_group) - - if self.internode: - # inter-node communication needs nvshmem, - # intra-node communication uses p2p mapping directly - from pplx_kernels.nvshmem import ( # type: ignore[import-not-found] - nvshmem_alloc_empty_unique_id, - nvshmem_get_unique_id, - nvshmem_init, - ) - - logger.debug( - "Initialize NVSHMEM for pplx_kernels: rank=%d, world size=%d", - self.rank, - self.world_size, - ) - uid = ( - nvshmem_get_unique_id() - if self.rank == 0 - else nvshmem_alloc_empty_unique_id() - ) - dist.broadcast( - uid, - src=dist.get_process_group_ranks(self.cpu_group)[0], - group=self.cpu_group, - ) - logger.debug("PPLX NVSHMEM UID = %s", uid) - nvshmem_init(uid, self.rank, self.world_size) - - self.handle_cache = Cache() - - def get_handle(self, kwargs): - import pplx_kernels as pplx # type: ignore[import-not-found] - - return self.handle_cache.get_or_create( - kwargs, - pplx.AllToAll.internode if self.internode else pplx.AllToAll.intranode, - ) - - def dispatch_router_logits( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - is_sequence_parallel: bool = False, - extra_tensors: list[torch.Tensor] | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - raise NotImplementedError - - def dispatch( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - is_sequence_parallel: bool = False, - extra_tensors: list[torch.Tensor] | None = None, - ) -> ( - tuple[torch.Tensor, torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]] - ): - raise NotImplementedError - - def combine( - self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False - ) -> torch.Tensor: - raise NotImplementedError - - def destroy(self): - with self.handle_cache._lock: - for _, handle in self.handle_cache._cache.items(): - handle.destroy() - - if self.internode: - from pplx_kernels.nvshmem import ( - nvshmem_finalize, # type: ignore[import-not-found] - ) - - logger.debug("PPLX NVSHMEM finalize") - nvshmem_finalize() - - class DeepEPAll2AllManagerBase(All2AllManagerBase): """ All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_deep_ep(), ( "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md" " to install DeepEP kernels." ) # noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) self.handle_cache = Cache() # This is the DeepEP default. Stick to it till we can establish @@ -373,7 +282,10 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): raise NotImplementedError def destroy(self): - pass + with self.handle_cache._lock: + for _, handle in self.handle_cache._cache.items(): + handle.destroy() + self.handle_cache._cache.clear() class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): @@ -381,8 +293,8 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP High-Throughput kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs(self) -> dict[Any, Any]: # Defaults for internode and intranode are taken from DeepEP tests. 
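A minimal sketch of the cache-and-explicit-destroy pattern the DeepEP managers now follow in destroy() (illustrative class name; assumes simple hashable kwargs and handles that expose a destroy() method):

    import threading
    from collections.abc import Callable
    from typing import Any

    class HandleCache:
        """Lazily create one handle per kwargs; tear them all down on destroy()."""

        def __init__(self) -> None:
            self._lock = threading.Lock()
            self._cache: dict[tuple, Any] = {}

        def get_or_create(self, kwargs: dict[str, Any], factory: Callable[..., Any]) -> Any:
            key = tuple(sorted(kwargs.items()))
            with self._lock:
                if key not in self._cache:
                    self._cache[key] = factory(**kwargs)
                return self._cache[key]

        def destroy(self) -> None:
            # Mirror the patched manager: destroy every cached handle, then clear.
            with self._lock:
                for handle in self._cache.values():
                    handle.destroy()
                self._cache.clear()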
@@ -405,6 +317,7 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): num_rdma_bytes=num_rdma_bytes, low_latency_mode=False, num_qps_per_rank=num_qps_per_rank, + explicitly_destroy=True, ) def get_handle(self, kwargs): @@ -438,8 +351,8 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): All2All communication based on DeepEP Low-Latency kernels. """ - def __init__(self, cpu_group): - super().__init__(cpu_group) + def __init__(self, cpu_group, tcp_store_group=None): + super().__init__(cpu_group, tcp_store_group) def _make_all2all_kwargs( self, @@ -476,8 +389,9 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_qps_per_rank, - allow_nvlink_for_low_latency_mode=True, - allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL, + # allow_nvlink_for_low_latency_mode=True, + # allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL, + explicitly_destroy=True, ) def get_handle(self, kwargs): @@ -509,11 +423,11 @@ class FlashInferAllToAllManager(All2AllManagerBase): rank: int world_size: int - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): assert has_flashinfer_all2all(), ( "flashinfer all2all module not found. Please install/check flashinfer" ) # noqa - super().__init__(cpu_group) + super().__init__(cpu_group, tcp_store_group) logger.debug( "Initialize for flashinfer All2All rank=%d, world size=%d", self.rank, diff --git a/vllm/distributed/device_communicators/all_reduce_utils.py b/vllm/distributed/device_communicators/all_reduce_utils.py index ff2d743..3c347ef 100644 --- a/vllm/distributed/device_communicators/all_reduce_utils.py +++ b/vllm/distributed/device_communicators/all_reduce_utils.py @@ -27,6 +27,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless logger = init_logger(__name__) +KiB = 1024 MiB = 1024 * 1024 # Max size for each world size in case symmetric memory is available # For different SM architectures @@ -60,17 +61,44 @@ SYMM_MEM_ALL_REDUCE_MAX_SIZES = { }, } +# NCCL symmetric memory allreduce configuration based on H100 and GB200 benchmarks. +# PyNCCL-symm outperforms custom_AR for small and large tensor sizes, +# while custom_AR wins for mid-range sizes. +# +# Benchmark results (8 GPUs): +# 2K - 16K: PyNCCL-symm wins (1.35x - 1.48x faster) +# 32K - 64K: custom_AR wins +# 128K - 1G: PyNCCL-symm wins (1.12x - 6.14x faster) +# +# Benchmark results (4 GPUs): +# 2K - 16K: PyNCCL-symm wins (1.21x - 1.30x faster) +# 32K - 256K: custom_AR wins (1.07x - 1.35x faster) +# 512K - 1G: PyNCCL-symm wins (1.10x - 2.32x faster) +# +# The config defines ranges where custom_AR is preferred (symm_mem disabled). NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = { "min_world_size": 4, - "thresholds": { - 4: 2 * MiB, # 2 MB - 8: 1 * MiB, # 1 MB + # Ranges where custom_AR outperforms NCCL symm_mem: (lower_bound, upper_bound) + # NCCL symm_mem will NOT be used for sizes in range: lower < size < upper + "custom_ar_preferred_ranges": { + 4: (16 * KiB, 512 * KiB), # custom_AR wins for 32K-256K + 8: (16 * KiB, 128 * KiB), # custom_AR wins for 32K-64K }, "always_use_above_world_size": 8, # Always use symm mem for world_size > 8 } def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) -> bool: + """ + Determine if NCCL symmetric memory allreduce should be used. 
+ + Based on H100 and GB200 benchmarks, NCCL symm_mem is preferred for: + - Small tensors (≤16K): Lower latency than custom_AR + - Large tensors (≥128K for 8 GPUs, ≥512K for 4 GPUs): Better bandwidth + + Custom_AR is preferred for mid-range sizes where its P2P approach + has lower overhead than the symm_mem copy-in/copy-out pattern. + """ from vllm.distributed.device_communicators.pynccl_allocator import ( is_symmetric_memory_enabled, ) @@ -80,11 +108,20 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor) if not is_symmetric_memory_enabled(): return False + if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]: return False - threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size) - if threshold is not None and input_tensor.nbytes >= threshold: - return True + + tensor_size = input_tensor.nbytes + custom_ar_range = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["custom_ar_preferred_ranges"].get( + world_size + ) + + if custom_ar_range is not None: + lower_bound, upper_bound = custom_ar_range + # Use symm_mem for small sizes (≤ lower_bound) and large sizes (≥ upper_bound) + # Use custom_AR (not symm_mem) for mid-range sizes + return tensor_size <= lower_bound or tensor_size >= upper_bound return world_size > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"] diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index fb055ba..52c6954 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -30,8 +30,9 @@ class All2AllManagerBase: rank: int world_size: int - def __init__(self, cpu_group): + def __init__(self, cpu_group, tcp_store_group=None): self.cpu_group = cpu_group + self.tcp_store_group = tcp_store_group # compute some common properties from vllm.distributed.parallel_state import ( @@ -48,12 +49,17 @@ class All2AllManagerBase: # when we create this object self.dp_rank = self.dp_group.rank_in_group self.dp_world_size = self.dp_group.world_size - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() # all2all communication often has separate implementations for # intra-node and inter-node communication - self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + if tcp_store_group is None: + self.internode = not all(in_the_same_node_as(cpu_group, source_rank=0)) + else: + self.internode = not all( + in_the_same_node_as(tcp_store_group, source_rank=0) + ) def get_handle(self, kwargs): # get a handle for the all2all communication, @@ -122,17 +128,36 @@ class DeviceCommunicatorBase: device: torch.device | None = None, device_group: ProcessGroup | None = None, unique_name: str = "", + global_ranks: list[int] | None = None, + global_world_size: int | None = None, ): self.device = device or torch.device("cpu") self.cpu_group = cpu_group self.device_group = device_group self.unique_name = unique_name - self.rank = dist.get_rank(cpu_group) - self.world_size = dist.get_world_size(cpu_group) - self.ranks = dist.get_process_group_ranks(cpu_group) - self.global_rank = dist.get_rank() - self.global_world_size = dist.get_world_size() - self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) + + # Check if this is a stateless process group + from torch.distributed.distributed_c10d import _world + + is_stateless = 
_world.pg_map.get(cpu_group, None) is None + + if is_stateless: + # For stateless groups, we can't use torch.distributed methods + self.rank = cpu_group.rank() + self.world_size = cpu_group.size() + assert global_ranks is not None + assert global_world_size is not None + self.ranks = global_ranks + self.global_rank = self.ranks[self.rank] + self.global_world_size = global_world_size + self.rank_in_group = self.rank + else: + self.rank = dist.get_rank(cpu_group) + self.world_size = dist.get_world_size(cpu_group) + self.ranks = dist.get_process_group_ranks(cpu_group) + self.global_rank = dist.get_rank() + self.global_world_size = dist.get_world_size() + self.rank_in_group = dist.get_group_rank(self.cpu_group, self.global_rank) use_ep = False all2all_backend = None @@ -146,7 +171,7 @@ class DeviceCommunicatorBase: use_ep = config.parallel_config.data_parallel_size > 1 all2all_backend = config.parallel_config.all2all_backend - self.is_ep_communicator = "ep" in unique_name + self.is_ep_communicator = unique_name.split(":")[0] == "ep" self.use_all2all = self.is_ep_communicator and use_ep self.all2all_backend = all2all_backend self.all2all_manager: All2AllManagerBase | None = None @@ -175,9 +200,7 @@ class DeviceCommunicatorBase: group=self.device_group, async_op=True) else: - torch.distributed.all_gather_into_tensor(output_tensor, - input_, - group=self.device_group) + dist.all_gather_into_tensor(output_tensor, input_, group=self.device_group) # Reshape output_tensor = output_tensor.reshape((self.world_size,) + input_size) output_tensor = output_tensor.movedim(0, dim) @@ -263,10 +286,9 @@ class DeviceCommunicatorBase: group=self.device_group, async_op=True) else: - torch.distributed.gather(input_, - gather_list, - dst=self.ranks[dst], - group=self.device_group) + torch.distributed.gather( + input_, gather_list, dst=self.ranks[dst], group=self.device_group + ) if self.rank_in_group == dst: output_tensor = torch.cat(gather_list, dim=dim) else: @@ -292,6 +314,13 @@ class DeviceCommunicatorBase: torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast a tensor from source rank to all ranks.""" + if self.world_size == 1: + return tensor + torch.distributed.broadcast(tensor, self.ranks[src], self.device_group) + return tensor + def destroy(self): pass @@ -360,3 +389,6 @@ class DeviceCommunicatorBase: This is a no-op in the base class. 
""" return hidden_states + + def batch_isend_irecv(self, p2p_ops: list): + raise NotImplementedError diff --git a/vllm/distributed/device_communicators/cpu_communicator.py b/vllm/distributed/device_communicators/cpu_communicator.py index 23be8fc..2bce5fa 100644 --- a/vllm/distributed/device_communicators/cpu_communicator.py +++ b/vllm/distributed/device_communicators/cpu_communicator.py @@ -35,8 +35,15 @@ class CpuCommunicator(DeviceCommunicatorBase): ) and hasattr(torch.ops._C, "init_shm_manager") and (unique_name.startswith("tp") or unique_name.startswith("pp")) + and self._all_group_ranks_share_shm_group_name() ): self.dist_module = _CPUSHMDistributed(self) + elif unique_name.startswith("tp") or unique_name.startswith("pp"): + logger.info( + "CPU SHM communicator disabled for group %s: ranks do not share " + "the same SHM group name, falling back to torch.distributed.", + unique_name, + ) if self.use_all2all: if self.all2all_backend != "naive": # type: ignore[has-type] @@ -52,6 +59,20 @@ class CpuCommunicator(DeviceCommunicatorBase): self.all2all_manager = NaiveAll2AllManager(self.cpu_group) logger.info("Using naive all2all manager.") + def _all_group_ranks_share_shm_group_name(self) -> bool: + """ + CPUSHM requires all ranks in this group to agree on one SHM group name. + This is a lightweight consistency check for VLLM_DIST_IDENT/name inputs. + """ + local_name = _CPUSHMDistributed.make_group_name(self) + names: list[str] = [""] * self.world_size + torch.distributed.all_gather_object( + names, + local_name, + group=self.device_group, + ) + return len(set(names)) == 1 + def all_reduce(self, input_): self.dist_module.all_reduce(input_, group=self.device_group) return input_ @@ -193,16 +214,20 @@ class CpuCommunicator(DeviceCommunicatorBase): class _CPUSHMDistributed: def __init__(self, communicator: CpuCommunicator): + self.communicator = communicator + + self.group_name = self.make_group_name(communicator) + + self.handle = self._init_cpu_shm() + + @staticmethod + def make_group_name(communicator: CpuCommunicator) -> str: instance_identifier = os.environ["VLLM_DIST_IDENT"] unique_name = communicator.unique_name instance_identifier = f"{instance_identifier}-{unique_name}" - self.communicator = communicator - - group_ranks = [str(rank) for rank in self.communicator.ranks] + group_ranks = [str(rank) for rank in communicator.ranks] shm_group_identifier = f"[{'-'.join(group_ranks)}]" - self.group_name = f"{instance_identifier}-{shm_group_identifier}-cpushm" - - self.handle = self._init_cpu_shm() + return f"{instance_identifier}-{shm_group_identifier}-cpushm" def _init_cpu_shm(self) -> int: thread_num_tensor = torch.tensor( diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 6302e50..519a37f 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -16,6 +16,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import ( from vllm.logger import init_logger from vllm.platforms import current_platform +from ..utils import StatelessProcessGroup from .base_device_communicator import DeviceCommunicatorBase import ixformer.distributed as ixfd import os @@ -29,8 +30,18 @@ class CudaCommunicator(DeviceCommunicatorBase): device: torch.device | None = None, device_group: ProcessGroup | None = None, unique_name: str = "", + global_ranks: list[int] | None = None, + global_world_size: int | None = None, + tcp_store_group: 
StatelessProcessGroup | None = None, ): - super().__init__(cpu_group, device, device_group, unique_name) + super().__init__( + cpu_group, + device, + device_group, + unique_name, + global_ranks, + global_world_size, + ) if "tp" not in unique_name: # custom allreduce or torch symm mem can be used only by tp use_custom_allreduce = False @@ -46,8 +57,9 @@ class CudaCommunicator(DeviceCommunicatorBase): self.use_custom_allreduce = use_custom_allreduce self.use_torch_symm_mem = use_torch_symm_mem self.use_flashinfer_allreduce = use_flashinfer_allreduce - + self.use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM",None) not in ["1", "Y", "y"] + # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce, @@ -64,7 +76,7 @@ class CudaCommunicator(DeviceCommunicatorBase): self.pynccl_comm: PyNcclCommunicator | None = None if self.world_size > 1: self.pynccl_comm = PyNcclCommunicator( - group=self.cpu_group, + group=self.cpu_group if tcp_store_group is None else tcp_store_group, device=self.device, ) if is_symmetric_memory_enabled(): @@ -109,23 +121,27 @@ class CudaCommunicator(DeviceCommunicatorBase): if self.all2all_backend == "naive": from .all2all import NaiveAll2AllManager - self.all2all_manager = NaiveAll2AllManager(self.cpu_group) + self.all2all_manager = NaiveAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "allgather_reducescatter": from .all2all import AgRsAll2AllManager - self.all2all_manager = AgRsAll2AllManager(self.cpu_group) - elif self.all2all_backend == "pplx": - from .all2all import PPLXAll2AllManager - - self.all2all_manager = PPLXAll2AllManager(self.cpu_group) + self.all2all_manager = AgRsAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "deepep_high_throughput": from .all2all import DeepEPHTAll2AllManager - self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPHTAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "deepep_low_latency": from .all2all import DeepEPLLAll2AllManager - self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) + self.all2all_manager = DeepEPLLAll2AllManager( + self.cpu_group, tcp_store_group + ) elif self.all2all_backend == "mori": from .all2all import MoriAll2AllManager @@ -133,7 +149,9 @@ class CudaCommunicator(DeviceCommunicatorBase): elif self.all2all_backend == "flashinfer_all2allv": from .all2all import FlashInferAllToAllManager - self.all2all_manager = FlashInferAllToAllManager(self.cpu_group) + self.all2all_manager = FlashInferAllToAllManager( + self.cpu_group, tcp_store_group + ) else: raise ValueError(f"Unknown all2all backend: {self.all2all_backend}") @@ -188,27 +206,12 @@ class CudaCommunicator(DeviceCommunicatorBase): return out if self.world_size == 1: return input_ + if self.use_vllm_comm: - # torch.ops.ixf_ops.vllm_all_reduce(input_, async_op=True) ixfd.all_reduce(input_, group=self.device_group, async_op=True) else: torch.distributed.all_reduce(input_, group=self.device_group) return input_ - pynccl_comm = self.pynccl_comm - if pynccl_comm is None or pynccl_comm.disabled: - out = input_.clone() - torch.distributed.all_reduce(out, group=self.device_group) - return out - assert pynccl_comm is not None - out = pynccl_comm.all_reduce(input_) - if out is None: - # fall back to the default all-reduce using PyTorch. - # this usually happens during testing. 
- # when we run the model, allreduce only happens for the TP - # group, where we always have either custom allreduce or pynccl. - out = input_.clone() - torch.distributed.all_reduce(out, group=self.device_group) - return out def reduce_scatter(self, input_: torch.Tensor, dim: int = -1): world_size = self.world_size @@ -230,10 +233,8 @@ class CudaCommunicator(DeviceCommunicatorBase): output_shape, dtype=input_tensor.dtype, device=input_tensor.device ) - # pynccl_comm.reduce_scatter(output, input_tensor) - torch.distributed.reduce_scatter_tensor(output, - input_tensor, - group=self.device_group) + # Perform reduce-scatter operation + ixfd.reduce_scatter_tensor(output,input_tensor,group=self.device_group, async_op=True) # Reshape before returning return output.movedim(0, dim).contiguous() @@ -278,12 +279,6 @@ class CudaCommunicator(DeviceCommunicatorBase): """NOTE: `dst` is the local rank of the destination rank.""" if dst is None: dst = (self.rank_in_group + 1) % self.world_size - - pynccl_comm = self.pynccl_comm - # if pynccl_comm is not None and not pynccl_comm.disabled: - # pynccl_comm.send(tensor, dst) - # else: - # torch.distributed.send(tensor, self.ranks[dst], self.device_group) if self.use_vllm_comm: ixfd.send(tensor, self.ranks[dst], self.device_group) else: @@ -298,17 +293,24 @@ class CudaCommunicator(DeviceCommunicatorBase): src = (self.rank_in_group - 1) % self.world_size tensor = torch.empty(size, dtype=dtype, device=self.device) - # pynccl_comm = self.pynccl_comm - # if pynccl_comm is not None and not pynccl_comm.disabled: - # pynccl_comm.recv(tensor, src) - # else: - # torch.distributed.recv(tensor, self.ranks[src], self.device_group) if self.use_vllm_comm: ixfd.recv(tensor, self.ranks[src], self.device_group) else: torch.distributed.recv(tensor, self.ranks[src], self.device_group) return tensor + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast a tensor from source rank to all ranks.""" + if self.world_size == 1: + return tensor + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.broadcast(tensor, src) + return tensor + else: + raise ValueError("No PyNCCL communicator found") + def destroy(self): if self.pynccl_comm is not None: self.pynccl_comm = None @@ -319,7 +321,7 @@ class CudaCommunicator(DeviceCommunicatorBase): self.fi_ar_comm = None if self.all2all_manager is not None: self.all2all_manager.destroy() - self.all2all_manager = None + self.all2all_manager = None # type: ignore[assignment] def all_gatherv( self, @@ -372,7 +374,6 @@ class CudaCommunicator(DeviceCommunicatorBase): def dispatch_router_logits( self, hidden_states: torch.Tensor, - extra_residual:torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, @@ -409,16 +410,13 @@ class CudaCommunicator(DeviceCommunicatorBase): This is a no-op in the base class. 
""" assert self.all2all_manager is not None - # return self.all2all_manager.dispatch( - # hidden_states, - # topk_weights, - # topk_ids, - # is_sequence_parallel, - # extra_tensors=extra_tensors, - # ) - hidden_states, extra_residual, router_logits = self.all2all_manager.dispatch( - hidden_states, extra_residual, router_logits) - return hidden_states, extra_residual, router_logits + return self.all2all_manager.dispatch( + hidden_states, + topk_weights, + topk_ids, + is_sequence_parallel, + extra_tensors=extra_tensors, + ) def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -432,3 +430,10 @@ class CudaCommunicator(DeviceCommunicatorBase): hidden_states, is_sequence_parallel, ) + + def batch_isend_irecv(self, p2p_ops: list): + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.batch_isend_irecv(p2p_ops) + else: + raise ValueError("No PyNCCL communicator found") diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 2fc35e8..44dc113 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -312,10 +312,19 @@ class PyNcclCommunicator: ) if stream is None: stream = current_stream() + if tensor.dtype in [ + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclSend( buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + nccl_dtype, dst, self.comm, cudaStream_t(stream.cuda_stream), @@ -330,10 +339,19 @@ class PyNcclCommunicator: ) if stream is None: stream = current_stream() + if tensor.dtype in [ + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + nccl_dtype = ncclDataTypeEnum.from_torch(torch.uint8) + else: + nccl_dtype = ncclDataTypeEnum.from_torch(tensor.dtype) self.nccl.ncclRecv( buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + nccl_dtype, src, self.comm, cudaStream_t(stream.cuda_stream), @@ -384,3 +402,17 @@ class PyNcclCommunicator: def deregister_comm_window(self, window): return self.nccl.ncclCommWindowDeregister(self.comm, window) + + def batch_isend_irecv(self, p2p_ops: list, stream=None): + if self.disabled: + return + if stream is None: + stream = current_stream() + self.group_start() + for op in p2p_ops: + if op.op is torch.distributed.isend: + self.send(op.tensor, op.group_peer, stream) + elif op.op is torch.distributed.irecv: + self.recv(op.tensor, op.group_peer, stream) + + self.group_end() diff --git a/vllm/distributed/elastic_ep/__init__.py b/vllm/distributed/elastic_ep/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/elastic_ep/elastic_execute.py b/vllm/distributed/elastic_ep/elastic_execute.py new file mode 100644 index 0000000..22d5706 --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_execute.py @@ -0,0 +1,529 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +import gc +import weakref +from collections.abc import Iterable, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import P2POp + +from vllm.compilation.counter import compilation_counter +from vllm.compilation.cuda_graph import 
CUDAGraphWrapper +from vllm.compilation.wrapper import reset_compile_wrapper +from vllm.config import ( + CompilationMode, + set_current_vllm_config, +) +from vllm.distributed import ( + get_dp_group, + get_ep_group, + get_pcp_group, + get_tp_group, +) +from vllm.distributed.elastic_ep.standby_state import ( + create_standby_groups, + get_standby_dp_group, + get_standby_ep_group, + pop_standby_groups, +) +from vllm.distributed.parallel_state import ( + _replace_active_groups, + prepare_communication_buffer_for_model, +) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import FusedMoEParallelConfig +from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType +from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper +from vllm.v1.worker.workspace import lock_workspace, unlock_workspace + +logger = init_logger(__name__) + + +def batch_transfer_weights( + model: nn.Module, + is_sender: bool, + peer_rank: int, + dp_group: StatelessGroupCoordinator, + expert_weights: Sequence[Iterable[torch.Tensor]], +) -> None: + device_comm = dp_group.device_communicator + if device_comm is None: + raise ValueError("No device communicator found") + + expert_weights_set = set() + for weight_group in expert_weights: + for weight in weight_group: + expert_weights_set.add(weight.data_ptr()) + + state_dict = model.state_dict() + all_params = [] + + for name, param in state_dict.items(): + if name.endswith("expert_map"): + continue + if param.data_ptr() not in expert_weights_set: + all_params.append(param.data) + + assert len(all_params) > 0 + p2p_ops = [] + for param in all_params: + op = object.__new__(P2POp) + if is_sender: + op.op = torch.distributed.isend + op.tensor = param + else: + op.op = torch.distributed.irecv + op.tensor = param + op.group_peer = peer_rank + p2p_ops.append(op) + device_comm.batch_isend_irecv(p2p_ops) + + +def broadcast_expert_mapping( + physical_to_logical: torch.Tensor | None, + num_local_physical_experts: int | None, + num_logical_experts: int | None, + dp_group: StatelessGroupCoordinator, + device: torch.device, + src_rank: int = 0, +) -> tuple[torch.Tensor, int, int]: + if dp_group.rank_in_group == src_rank: + assert physical_to_logical is not None + assert num_local_physical_experts is not None + assert num_logical_experts is not None + assert physical_to_logical.dtype == torch.int64 + shape_tensor = torch.tensor( + list(physical_to_logical.shape), dtype=torch.int64, device="cpu" + ) + metadata_tensor = torch.tensor( + [num_local_physical_experts, num_logical_experts], + dtype=torch.int64, + device="cpu", + ) + else: + shape_tensor = torch.empty(2, dtype=torch.int64, device="cpu") + metadata_tensor = torch.empty(2, dtype=torch.int64, device="cpu") + + shape_tensor = dp_group.tcp_store_group.broadcast(shape_tensor, src_rank) + metadata_tensor = dp_group.tcp_store_group.broadcast(metadata_tensor, src_rank) + + if dp_group.rank_in_group != src_rank: + assert device is not None + physical_to_logical = torch.empty( + tuple(shape_tensor.tolist()), + dtype=torch.int64, + device=device, + ) + + assert physical_to_logical is not None + physical_to_logical = dp_group.broadcast(physical_to_logical, src_rank) + num_local_physical_experts = int(metadata_tensor[0].item()) + num_logical_experts = int(metadata_tensor[1].item()) + + return physical_to_logical, num_local_physical_experts, num_logical_experts + + +class ElasticEPScalingExecutor: + def __init__(self, 
worker): + self.worker_ref = weakref.ref(worker) + self.reconfig_request = None + + @property + def worker(self): + worker = self.worker_ref() + if worker is None: + raise RuntimeError("Worker has been garbage collected") + return worker + + def execute(self, execute_method: str, *args, **kwargs): + method = getattr(self, execute_method, None) + if method is None: + raise ValueError(f"Unknown execute method: {execute_method}") + return method(*args, **kwargs) + + def create_standby_groups( + self, reconfig_request: ReconfigureDistributedRequest + ) -> None: + self.reconfig_request = reconfig_request + new_dp_size = reconfig_request.new_data_parallel_size + world_size = self.worker.vllm_config.parallel_config.world_size + new_world_size_across_dp = world_size * new_dp_size + updated_config = copy.copy(self.worker.vllm_config) + updated_config.parallel_config = copy.deepcopy( + self.worker.vllm_config.parallel_config + ) + updated_config.parallel_config.data_parallel_size = new_dp_size + with set_current_vllm_config(updated_config): + create_standby_groups( + new_dp_size=new_dp_size, + new_world_size_across_dp=new_world_size_across_dp, + master_ip=reconfig_request.new_data_parallel_master_ip, + world_group_ports=reconfig_request.new_stateless_world_group_port_list, + dp_group_ports=reconfig_request.new_stateless_dp_group_port_list, + ep_group_ports=reconfig_request.new_stateless_ep_group_port_list, + eplb_group_ports=reconfig_request.new_stateless_eplb_group_port_list, + ) + self.worker.model_runner.eep_eplb_suppressed = True + standby_ep_group = get_standby_ep_group() + assert standby_ep_group is not None + if standby_ep_group.rank == 0: + logger.info("[Elastic EP] EPLB disabled during elastic scaling transition") + + def transfer_weights(self, old_dp_size: int, new_dp_size: int) -> None: + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + # Broadcast old_dp_size to all workers in standby group + if standby_dp_group.rank_in_group < old_dp_size: + old_dp_size_tensor = torch.tensor( + [old_dp_size], dtype=torch.int64, device="cpu" + ) + else: + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = standby_dp_group.tcp_store_group.broadcast( + old_dp_size_tensor, 0 + ) + + num_new_workers = new_dp_size - old_dp_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Sender-receiver pairing: the first new_workers % old_dp_size + # senders get (k+1) contiguous receivers, the rest get k + # receivers. 
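# Worked example (illustrative numbers): with old_dp_size=2 and
# new_dp_size=5 there are 3 new workers, so k = 3 // 2 = 1 and
# remainder = 3 % 2 = 1; sender dp_rank=0 therefore gets global ranks
# [2, 3] and sender dp_rank=1 gets [4].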
+ num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if dp_rank < remainder: + recv_begin = dp_rank * (num_dst_per_sender + 1) + recv_end = recv_begin + num_dst_per_sender + 1 + else: + recv_begin = ( + remainder * (num_dst_per_sender + 1) + + (dp_rank - remainder) * num_dst_per_sender + ) + recv_end = recv_begin + num_dst_per_sender + + ranks_to_send = list(range(old_dp_size + recv_begin, old_dp_size + recv_end)) + + model = self.worker.model_runner.get_model() + for new_worker_rank in sorted(ranks_to_send): + batch_transfer_weights( + model=model, + is_sender=True, + peer_rank=new_worker_rank, + dp_group=standby_dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def broadcast_expert_mapping(self) -> None: + standby_dp_group = get_standby_dp_group() + assert standby_dp_group is not None + model_config = self.worker.model_runner.model_config + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + physical_to_logical = eplb_model_state.physical_to_logical_map + num_physical_experts = physical_to_logical.shape[1] + num_local_physical_experts = num_physical_experts // get_ep_group().world_size + num_logical_experts = eplb_model_state.logical_replica_count.shape[1] + broadcast_expert_mapping( + physical_to_logical=physical_to_logical, + num_local_physical_experts=num_local_physical_experts, + num_logical_experts=num_logical_experts, + dp_group=standby_dp_group, + src_rank=0, + device=self.worker.device, + ) + + def switch_and_remove(self) -> None: + _replace_active_groups(world=None, dp=None, ep=None, eplb=None, node_count=None) + + def switch_and_prepare(self) -> None: + old_dp_size = get_dp_group().world_size + old_ep_size = get_ep_group().world_size + + _replace_active_groups(**pop_standby_groups()) + + parallel_config = self.worker.vllm_config.parallel_config + reconfig_request = self.reconfig_request + assert reconfig_request is not None + new_dp_size = reconfig_request.new_data_parallel_size + new_ep_size = get_ep_group().world_size + + parallel_config.data_parallel_size = new_dp_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( + reconfig_request.new_data_parallel_rank_local + ) + parallel_config.data_parallel_master_ip = ( + reconfig_request.new_data_parallel_master_ip + ) + parallel_config.data_parallel_master_port = ( + reconfig_request.new_data_parallel_master_port + ) + + # Reconfigure MoE modules with new EP size + moe_modules = [ + module + for module in self.worker.model_runner.model.modules() + if ( + module.__class__.__name__ == "FusedMoE" + or module.__class__.__name__ == "SharedFusedMoE" + ) + ] + num_local_experts = moe_modules[0].moe_config.num_local_experts + assert all( + module.moe_config.num_local_experts == num_local_experts + for module in moe_modules + ), "All MoE modules must have the same number of experts" + for module in moe_modules: + module.moe_config.num_experts = num_local_experts * new_ep_size + module.global_num_experts = module.moe_config.num_experts + tp_size = get_tp_group().world_size + is_sequence_parallel = parallel_config.use_sequence_parallel_moe + sp_size = tp_size if is_sequence_parallel 
else 1 + module.moe_parallel_config = FusedMoEParallelConfig.make( + tp_size_=tp_size, + pcp_size_=get_pcp_group().world_size, + dp_size_=get_dp_group().world_size, + sp_size_=sp_size, + vllm_parallel_config=parallel_config, + ) + module.moe_config.moe_parallel_config = module.moe_parallel_config + + # Update EPLB state + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + model_config = self.worker.model_runner.model_config + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + + num_physical_experts = num_local_experts * new_ep_size + num_logical_experts = eplb_model_state.logical_replica_count.shape[1] + parallel_config.eplb_config.num_redundant_experts = ( + num_physical_experts - num_logical_experts + ) + old_physical_to_logical = eplb_model_state.physical_to_logical_map + num_moe_layers = old_physical_to_logical.shape[0] + num_local_experts = eplb_model_state.expert_load_pass.shape[1] // old_ep_size + if new_dp_size > old_dp_size: + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_experts * new_ep_size), + -1, + dtype=old_physical_to_logical.dtype, + device=old_physical_to_logical.device, + ) + expanded_physical_to_logical[:, : num_local_experts * old_ep_size] = ( + old_physical_to_logical + ) + eplb_model_state.physical_to_logical_map = expanded_physical_to_logical + + old_num_physical_experts = eplb_model_state.expert_load_pass.shape[1] + pad_size = num_physical_experts - old_num_physical_experts + if new_dp_size > old_dp_size: + assert pad_size > 0 + expanded_expert_load_pass = F.pad( + eplb_model_state.expert_load_pass, (0, pad_size), value=0 + ) + expanded_expert_load_window = F.pad( + eplb_model_state.expert_load_window, (0, pad_size), value=0 + ) + eplb_model_state.expert_load_pass = expanded_expert_load_pass + eplb_model_state.expert_load_window = expanded_expert_load_window + eplb_state.num_valid_physical_experts = old_num_physical_experts + else: + assert pad_size < 0 + eplb_model_state.expert_load_pass = eplb_model_state.expert_load_pass[ + :, :num_physical_experts + ] + eplb_model_state.expert_load_window = eplb_model_state.expert_load_window[ + :, :, :num_physical_experts + ] + eplb_state.num_valid_physical_experts = num_physical_experts + + model = self.worker.model_runner.get_model() + model.expert_weights = [] + with set_current_vllm_config(self.worker.vllm_config): + model.set_eplb_state( + eplb_model_state.expert_load_pass, + eplb_model_state.logical_to_physical_map, + eplb_model_state.logical_replica_count, + ) + model.update_physical_experts_metadata( + num_physical_experts=num_physical_experts, + num_local_physical_experts=num_local_experts, + ) + # Force re-creation of the modular kernel (and all2all manager) + # for the new EP size by resetting quant_method to base + for module in moe_modules: + if hasattr(module.quant_method, "old_quant_method"): + module.quant_method = module.quant_method.old_quant_method + module.runner = module._init_runner() + prepare_communication_buffer_for_model(self.worker.model_runner.model) + if ( + self.worker.vllm_config.compilation_config.mode + == CompilationMode.STOCK_TORCH_COMPILE + ): + # NOTE(yongji): when using stock torch.compile, + # torch.compile is triggered during GPUModelRunner's load_model() + # TODO(yongji):check do we need to re-trigger torch.compile here? + # any changes to the tensor shapes in execution should already + # be handled internally by torch.compile. 
+ backend = self.worker.vllm_config.compilation_config.init_backend( + self.worker.vllm_config + ) + compilation_counter.stock_torch_compile_count += 1 + self.worker.model_runner.model.compile(fullgraph=True, backend=backend) + + # release all previously captured CUDA graphs + if isinstance(self.worker.model_runner.model, CUDAGraphWrapper): + wrapper = self.worker.model_runner.model + wrapper.concrete_cudagraph_entries = {} + elif isinstance(self.worker.model_runner.model, UBatchWrapper): + raise RuntimeError("DBO is not yet supported in elastic EP") + + multi_block_table = self.worker.model_runner.input_batch.block_table + saved_block_tables: list[tuple[torch.Tensor, torch.Tensor]] = [] + for bt in multi_block_table.block_tables: + saved_block_tables.append( + (bt.block_table.gpu.clone(), bt.block_table.cpu.clone()) + ) + multi_block_table.clear() + + # reset the compile wrapper + torch.compiler.reset() + with set_current_vllm_config(self.worker.vllm_config): + reset_compile_wrapper(self.worker.model_runner.get_model()) + + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + unlock_workspace() + self.worker.compile_or_warm_up_model() + lock_workspace() + + for bt, (saved_gpu, saved_cpu) in zip( + multi_block_table.block_tables, saved_block_tables + ): + bt.block_table.gpu.copy_(saved_gpu) + bt.block_table.cpu.copy_(saved_cpu) + + def perform_eplb_reshuffle(self, new_dp_size: int | None = None) -> None: + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Starting expert resharding...") + + eplb_state = self.worker.model_runner.eplb_state + assert eplb_state is not None + + model_config = self.worker.model_runner.model_config + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + is_async_enabled = eplb_state.is_async + eplb_state.is_async = False + if new_dp_size is None: + eplb_state.rearrange() + else: + # scale down + parallel_config = self.worker.vllm_config.parallel_config + tp_size = parallel_config.tensor_parallel_size + old_ep_size = parallel_config.data_parallel_size * tp_size + new_ep_size = new_dp_size * tp_size + + rank_mapping = { + old_ep_rank: old_ep_rank if old_ep_rank < new_ep_size else -1 + for old_ep_rank in range(old_ep_size) + } + + eplb_state.rearrange(rank_mapping=rank_mapping) + # NOTE(yongji): check whether we need to synchronize here + torch.cuda.synchronize() + # reset expert_rearrangement_step to ensure all ranks are synchronized + eplb_state.expert_rearrangement_step = 0 + eplb_state.num_valid_physical_experts = ( + eplb_model_state.physical_to_logical_map.shape[1] + ) + eplb_state.is_async = is_async_enabled + self.worker.model_runner.eep_eplb_suppressed = False + if get_ep_group().rank == 0: + logger.info("[Elastic EP] Expert resharding completed") + + def receive_weights(self) -> None: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + new_dp_size = dp_group.world_size + dp_rank = self.worker.vllm_config.parallel_config.data_parallel_rank + + # Receive old_dp_size broadcasted during transfer_weights + old_dp_size_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + old_dp_size_tensor = dp_group.tcp_store_group.broadcast(old_dp_size_tensor, 0) + old_dp_size = int(old_dp_size_tensor[0].item()) + + # Calculate which existing worker will send to this new worker + num_new_workers = new_dp_size - old_dp_size + new_worker_idx = dp_rank - old_dp_size + num_dst_per_sender = num_new_workers // old_dp_size + remainder = num_new_workers % old_dp_size + + if new_worker_idx < remainder 
* (num_dst_per_sender + 1): + sender_rank = new_worker_idx // (num_dst_per_sender + 1) + else: + sender_rank = ( + remainder + + (new_worker_idx - remainder * (num_dst_per_sender + 1)) + // num_dst_per_sender + ) + + model = self.worker.model_runner.get_model() + batch_transfer_weights( + model=model, + is_sender=False, + peer_rank=sender_rank, + dp_group=dp_group, + expert_weights=model.expert_weights, + ) + torch.cuda.synchronize() + + def receive_expert_mapping(self) -> tuple[torch.Tensor, int, int]: + dp_group = get_dp_group() + assert isinstance(dp_group, StatelessGroupCoordinator) + physical_to_logical, num_local_physical_experts, num_logical_experts = ( + broadcast_expert_mapping( + physical_to_logical=None, + num_local_physical_experts=None, + num_logical_experts=None, + dp_group=dp_group, + src_rank=0, + device=self.worker.device, + ) + ) + num_moe_layers = physical_to_logical.shape[0] + new_dp_size = get_dp_group().world_size + tp_size = self.worker.vllm_config.parallel_config.tensor_parallel_size + new_ep_size = new_dp_size * tp_size + expanded_physical_to_logical = torch.full( + (num_moe_layers, num_local_physical_experts * new_ep_size), + -1, + dtype=physical_to_logical.dtype, + device=physical_to_logical.device, + ) + old_num_physical_experts = physical_to_logical.shape[1] + expanded_physical_to_logical[:, :old_num_physical_experts] = physical_to_logical + return ( + expanded_physical_to_logical, + num_logical_experts, + old_num_physical_experts, + ) + + def prepare_new_worker(self) -> None: + with set_current_vllm_config(self.worker.vllm_config): + prepare_communication_buffer_for_model(self.worker.model_runner.get_model()) diff --git a/vllm/distributed/elastic_ep/elastic_state.py b/vllm/distributed/elastic_ep/elastic_state.py new file mode 100644 index 0000000..4845a16 --- /dev/null +++ b/vllm/distributed/elastic_ep/elastic_state.py @@ -0,0 +1,563 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import enum +import time +import weakref +from datetime import timedelta +from typing import TYPE_CHECKING, Literal + +import torch.distributed + +from vllm.config import ParallelConfig +from vllm.distributed import ( + sched_yield, + stateless_destroy_torch_distributed_process_group, +) +from vllm.logger import init_logger +from vllm.v1.engine import ( + EEPNotificationType, + ReconfigureDistributedRequest, + ReconfigureRankType, +) +from vllm.v1.engine.core import DPEngineCoreProc + +if TYPE_CHECKING: + from vllm.config import VllmConfig + from vllm.v1.executor.abstract import Executor + +logger = init_logger(__name__) + +WorkerType = Literal["existing", "new", "removing"] + + +class ScaleUpExistingEngineState(enum.IntEnum): + WAIT_NEW_CORE_ENGINES_INIT = 0 + CREATE_STANDBY_GROUPS = 1 + TRANSFER_EXPERT_MAPPING = 2 + WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT = 3 + TRANSFER_WEIGHTS = 4 + SYNC_KV_CACHE_MEMORY_SIZE = 5 + SWITCH_AND_PREPARE = 6 + EPLB_RESHUFFLE = 7 + COMPLETE = 8 + + +class ScaleUpNewEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class ScaleDownRemainingEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + SWITCH_AND_PREPARE = 2 + COMPLETE = 3 + + +class ScaleDownRemovingEngineState(enum.IntEnum): + PREPARE = 0 + EPLB_RESHUFFLE = 1 + COMPLETE = 2 + + +class _BarrierTimeoutError(RuntimeError): + """ + Exception raised for timeout + in the first stage of our two-staged + TCPStore based barrier to synchronize the + execution of all engines in the DP group. 
+ """ + + +class ElasticEPScalingState: + def __init__( + self, + model_executor: "Executor", + engine_core: "DPEngineCoreProc", + vllm_config: "VllmConfig", + new_parallel_config: ParallelConfig, + worker_type: WorkerType, + scale_type: Literal["scale_up", "scale_down"], + reconfig_request: ReconfigureDistributedRequest | None = None, + ): + self.model_executor_ref = weakref.ref(model_executor) + self.engine_core_ref = weakref.ref(engine_core) + self.vllm_config = vllm_config + self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None + self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None + self.new_parallel_config: ParallelConfig = new_parallel_config + self.new_dp_group: torch.distributed.ProcessGroup | None = ( + self.engine_core.dp_group if worker_type == "new" else None + ) + self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None + self.worker_type = worker_type + self.scale_type = scale_type + self.reconfig_request = reconfig_request + + if scale_type == "scale_up": + self.state = ( + ScaleUpNewEngineState.PREPARE + if worker_type == "new" + else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT + ) + else: + self.state = ( + ScaleDownRemovingEngineState.PREPARE + if worker_type == "removing" + else ScaleDownRemainingEngineState.PREPARE + ) + + @property + def model_executor(self) -> "Executor": + model_executor = self.model_executor_ref() + if model_executor is None: + raise RuntimeError("Model executor has been garbage collected") + return model_executor + + @property + def engine_core(self) -> "DPEngineCoreProc": + engine_core = self.engine_core_ref() + if engine_core is None: + raise RuntimeError("Engine core has been garbage collected") + return engine_core + + def progress(self) -> bool: + if self.scale_type == "scale_up": + return ( + self._progress_new_engine() + if self.worker_type == "new" + else self._progress_existing_engine() + ) + return ( + self._progress_removing_engine() + if self.worker_type == "removing" + else self._progress_remaining_engine() + ) + + def _execute_tcp_store_barrier( + self, dp_store, group_rank, group_size, barrier_id, timeout=None + ): + arrival_key = f"arrival_{barrier_id}_{group_rank}" + dp_store.set(arrival_key, b"1") + + start_time = time.time() + processes_arrived: set[int] = set() + + while len(processes_arrived) < group_size: + if ( + timeout is not None + and time.time() - start_time > timeout.total_seconds() + ): + raise _BarrierTimeoutError( + f"Barrier timed out after {timeout.total_seconds()} seconds" + ) + + for i in range(group_size): + if i in processes_arrived: + continue + + key = f"arrival_{barrier_id}_{i}" + present = dp_store.check([key]) + if present: + processes_arrived.add(i) + + if len(processes_arrived) < group_size: + sched_yield() + + def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool: + """ + Execute a two-staged barrier to synchronize all engines in the DP group. + + Some DP EngineCores may receive the reconfiguration notifications + later than others, and already proceed to engine step (model forward) + in the busy loop. + In this case, EngineCores that already proceed to reconfiguration + should skip reconfiguration and execute model forward for one more + step, so in the next step, all EngineCores will be synchronized. + We use a two-staged barrier to achieve this. 
The first time each + EngineCore executes the barrier, if a timeout is reached before the + barrier completes, that means some EngineCores have already entered + engine step. The EngineCores that timed out will then proceed to + engine step, and will synchronize with the other EngineCores in the + next step with a barrier without timeout. + """ + dp_store = self.new_dp_store if use_new_group else self.old_dp_store + dp_group = self.new_dp_group if use_new_group else self.old_dp_group + assert dp_group is not None + + group_rank = dp_group.rank() + group_size = dp_group.size() + barrier_id = f"eep_barrier_{barrier_name}" + sync_key = f"{barrier_id}_sync" + + # TODO(yongji): figure out appropriate timeout for the barrier + timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5) + + try: + self._execute_tcp_store_barrier( + dp_store, group_rank, group_size, barrier_id, timeout=timeout + ) + torch.distributed.barrier(dp_group) + if group_rank == 0: + dp_store.delete_key(sync_key) + for i in range(group_size): + dp_store.delete_key(f"arrival_{barrier_id}_{i}") + return True + except _BarrierTimeoutError as e: + if timeout is None: + raise RuntimeError("Unexpected timeout encountered") from e + dp_store.compare_set(sync_key, "", b"1") + return False + + def _progress_existing_engine(self) -> bool: + state = self.state + + if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT: + return False + + elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS: + # NOTE(yongji): wait for all existing workers to receive the request + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="create_standby_groups" + ): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._create_standby_groups() + self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING + return True + + elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING: + self._transfer_expert_mapping() + self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT + return True + + elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT: + return False + + elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="transfer_weights" + ): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._transfer_weights() + self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE + return True + + elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE: + self._sync_kv_cache_memory_size() + self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE + return True + + elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE: + self._switch_and_prepare() + self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE + self.new_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE: + assert self.new_dp_group is not None + if ( + int(self.new_dp_store.get("eep_barrier_engine_count")) + < self.new_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=True, barrier_name="eplb_reshuffle" + ): + return False + if self.new_dp_group.rank() == 0: + 
self.new_dp_store.delete_key("eep_barrier_engine_count") + self._eplb_reshuffle() + self.state = ScaleUpExistingEngineState.COMPLETE + self._update_parallel_config() + return True + + else: + assert self.state == ScaleUpExistingEngineState.COMPLETE + return True + + def _progress_new_engine(self) -> bool: + state = self.state + assert self.new_dp_group is not None + + if state == ScaleUpNewEngineState.PREPARE: + tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu") + torch.distributed.all_reduce( + tensor, + op=torch.distributed.ReduceOp.MAX, + group=self.new_dp_group, + ) + data = tensor.tolist() + self.engine_core.engines_running = bool(data[0]) + self.engine_core.current_wave = int(data[1]) + self.engine_core.step_counter = int(data[2]) + self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE + self.new_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE: + if ( + int(self.new_dp_store.get("eep_barrier_engine_count")) + < self.new_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=True, barrier_name="eplb_reshuffle" + ): + return False + assert self.new_dp_group.rank() > 0 + self._eplb_reshuffle() + self.state = ScaleUpNewEngineState.COMPLETE + return True + + else: + assert self.state == ScaleUpNewEngineState.COMPLETE + return True + + def _progress_remaining_engine(self) -> bool: + state = self.state + + if state == ScaleDownRemainingEngineState.PREPARE: + self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE + self.old_dp_store.add("eep_barrier_engine_count", 1) + return True + + elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="eplb_reshuffle" + ): + return False + if self.old_dp_group.rank() == 0: + self.old_dp_store.delete_key("eep_barrier_engine_count") + self._eplb_reshuffle_before_scale_down() + self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE + # NOTE(yongji): currently, after EPLB reshuffle + # that redistributes experts to remaining workers, workers + # to be removed will immediately initiate shutdown; + # existing workers can no longer execute forward steps using + # the old setup. In the future, we may keep + # the removing workers alive a bit longer, + # e.g., to drain in-batch requests. 
+ self._create_standby_groups() + self._switch_and_prepare() + self._update_parallel_config() + self.state = ScaleDownRemainingEngineState.COMPLETE + return True + + else: + assert self.state == ScaleDownRemainingEngineState.COMPLETE + return True + + def _progress_removing_engine(self) -> bool: + state = self.state + + if state == ScaleDownRemovingEngineState.PREPARE: + self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE + self.old_dp_store.add("eep_barrier_engine_count", 1) + return True + + if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE: + if ( + int(self.old_dp_store.get("eep_barrier_engine_count")) + < self.old_dp_group.size() + ): + return False + if not self._staged_barrier( + use_new_group=False, barrier_name="eplb_reshuffle" + ): + return False + assert self.old_dp_group.rank() > 0 + self._eplb_reshuffle_before_scale_down() + self._switch_and_remove() + self.state = ScaleDownRemovingEngineState.COMPLETE + self.engine_core._eep_send_engine_core_notification( + EEPNotificationType.SHUTDOWN_COMPLETE + ) + self.engine_core.shutdown() + return True + + else: + assert self.state == ScaleDownRemovingEngineState.COMPLETE + return True + + def handle_notification(self, notification_type: EEPNotificationType): + assert self.worker_type != "new" + if ( + notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY + and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT + ): + self.old_dp_store.add("eep_barrier_engine_count", 1) + self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS + elif ( + notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY + and self.state + == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT + ): + self.old_dp_store.add("eep_barrier_engine_count", 1) + self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS + + def is_complete(self) -> bool: + if self.scale_type == "scale_up": + return ( + self.state == ScaleUpNewEngineState.COMPLETE + if self.worker_type == "new" + else self.state == ScaleUpExistingEngineState.COMPLETE + ) + return ( + self.state == ScaleDownRemovingEngineState.COMPLETE + if self.worker_type == "removing" + else self.state == ScaleDownRemainingEngineState.COMPLETE + ) + + def _create_standby_groups(self): + self.new_dp_group, self.new_dp_store = ( + self.new_parallel_config.stateless_init_dp_group(return_store=True) + ) + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("create_standby_groups", self.reconfig_request) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Created standby communication groups") + + def _transfer_weights(self): + assert self.reconfig_request is not None + old_dp_size = self.old_dp_group.size() + new_dp_size = self.reconfig_request.new_data_parallel_size + + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Transferred weights to new workers") + + def _transfer_expert_mapping(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("broadcast_expert_mapping",) + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] Broadcasted expert mapping to new workers") + + def _sync_kv_cache_memory_size(self): + assert self.engine_core.available_gpu_memory_for_kv_cache > 0 + assert self.new_dp_group is not None + ParallelConfig.sync_kv_cache_memory_size( + self.new_dp_group, + self.engine_core.available_gpu_memory_for_kv_cache, + ) + if self.old_dp_group.rank() 
== 0: + logger.info("[Elastic EP] Synced KV cache memory size to new workers") + + def _switch_and_prepare(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("switch_and_prepare",) + ) + old_dp_group = self.old_dp_group + stateless_destroy_torch_distributed_process_group(old_dp_group) + assert self.new_dp_group is not None + new_dp_group = self.new_dp_group + self.engine_core.dp_group = new_dp_group + self.engine_core.dp_rank = new_dp_group.rank() + self.engine_core.dp_store = self.new_dp_store + engines_running = int(self.engine_core.engines_running) + current_wave = self.engine_core.current_wave + step_counter = self.engine_core.step_counter + tensor = torch.tensor( + [engines_running, current_wave, step_counter], + dtype=torch.int32, + device="cpu", + ) + torch.distributed.all_reduce( + tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group + ) + data = tensor.tolist() + self.engine_core.engines_running = bool(data[0]) + self.engine_core.current_wave = int(data[1]) + self.engine_core.step_counter = int(data[2]) + if new_dp_group.rank() == 0: + self.engine_core._eep_send_engine_core_notification( + EEPNotificationType.RECONFIGURE_FINISHED + ) + logger.info("[Elastic EP] Switched to new setup") + + def _eplb_reshuffle(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("perform_eplb_reshuffle",) + ) + assert self.new_dp_group is not None + if self.new_dp_group.rank() == 0: + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _eplb_reshuffle_before_scale_down(self): + assert self.reconfig_request is not None + self.model_executor.collective_rpc( + "elastic_ep_execute", + args=( + "perform_eplb_reshuffle", + self.reconfig_request.new_data_parallel_size, + ), + ) + if self.old_dp_group.rank() == 0: + logger.info("[Elastic EP] EPLB reshuffle completed") + + def _switch_and_remove(self): + self.model_executor.collective_rpc( + "elastic_ep_execute", args=("switch_and_remove",) + ) + + def _update_parallel_config(self): + assert self.reconfig_request is not None + reconfig_request = self.reconfig_request + parallel_config = self.vllm_config.parallel_config + parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size + if ( + reconfig_request.new_data_parallel_rank + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank + if ( + reconfig_request.new_data_parallel_rank_local + != ReconfigureRankType.KEEP_CURRENT_RANK + ): + parallel_config.data_parallel_rank_local = ( + reconfig_request.new_data_parallel_rank_local + ) + parallel_config.data_parallel_master_ip = ( + reconfig_request.new_data_parallel_master_ip + ) + parallel_config.data_parallel_master_port = ( + reconfig_request.new_data_parallel_master_port + ) + parallel_config._data_parallel_master_port_list = ( + reconfig_request.new_data_parallel_master_port_list + ) + parallel_config._stateless_world_group_port_list = ( + reconfig_request.new_stateless_world_group_port_list + ) + parallel_config._stateless_dp_group_port_list = ( + reconfig_request.new_stateless_dp_group_port_list + ) + parallel_config._stateless_ep_group_port_list = ( + reconfig_request.new_stateless_ep_group_port_list + ) + parallel_config._stateless_eplb_group_port_list = ( + reconfig_request.new_stateless_eplb_group_port_list + ) diff --git a/vllm/distributed/elastic_ep/standby_state.py b/vllm/distributed/elastic_ep/standby_state.py new file mode 100644 index 0000000..d11e0b5 --- /dev/null +++ 
b/vllm/distributed/elastic_ep/standby_state.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +from vllm.distributed.parallel_state import ( + _init_stateless_group, + _node_count, + get_pp_group, + get_tp_group, + get_world_group, +) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + +_STANDBY_WORLD: StatelessGroupCoordinator | None = None +_STANDBY_WORLD_NODE_COUNT: int | None = None +_STANDBY_DP: StatelessGroupCoordinator | None = None +_STANDBY_EP: StatelessGroupCoordinator | None = None +_STANDBY_EPLB: StatelessGroupCoordinator | None = None + + +def get_standby_dp_group() -> StatelessGroupCoordinator | None: + return _STANDBY_DP + + +def get_standby_ep_group() -> StatelessGroupCoordinator | None: + return _STANDBY_EP + + +def get_standby_eplb_group() -> StatelessGroupCoordinator | None: + return _STANDBY_EPLB + + +def get_standby_world_group() -> StatelessGroupCoordinator | None: + return _STANDBY_WORLD + + +def create_standby_groups( + new_dp_size: int, + new_world_size_across_dp: int, + master_ip: str, + world_group_ports: list[list[int]], + dp_group_ports: list[list[int]], + ep_group_ports: list[list[int]], + eplb_group_ports: list[list[int]] | None = None, + backend: str | None = None, +) -> None: + global \ + _STANDBY_WORLD, \ + _STANDBY_WORLD_NODE_COUNT, \ + _STANDBY_DP, \ + _STANDBY_EP, \ + _STANDBY_EPLB + + assert new_world_size_across_dp == torch.distributed.get_world_size() * new_dp_size + world_group = get_world_group() + assert isinstance(world_group, StatelessGroupCoordinator) + backend = backend or world_group.backend + + standby_world_ranks = [list(range(new_world_size_across_dp))] + _STANDBY_WORLD = _init_stateless_group( + standby_world_ranks, + "world", + world_group_ports, + master_ip, + backend, + use_device_communicator=False, + ) + _STANDBY_WORLD_NODE_COUNT = _node_count(_STANDBY_WORLD.tcp_store_group) + + tp_size = get_tp_group().world_size + pp_size = get_pp_group().world_size + + all_ranks = torch.arange(new_world_size_across_dp).reshape( + -1, new_dp_size, pp_size, tp_size + ) + standby_dp_ranks = all_ranks.transpose(1, 3).reshape(-1, new_dp_size).unbind(0) + standby_dp_ranks = [x.tolist() for x in standby_dp_ranks] + _STANDBY_DP = _init_stateless_group( + standby_dp_ranks, "dp", dp_group_ports, master_ip, backend + ) + + standby_ep_ranks = ( + all_ranks.transpose(1, 2).reshape(-1, new_dp_size * tp_size).unbind(0) + ) + standby_ep_ranks = [x.tolist() for x in standby_ep_ranks] + _STANDBY_EP = _init_stateless_group( + standby_ep_ranks, "ep", ep_group_ports, master_ip, backend + ) + + if eplb_group_ports is not None: + _STANDBY_EPLB = _init_stateless_group( + standby_ep_ranks, "eplb", eplb_group_ports, master_ip, backend + ) + + +def pop_standby_groups() -> dict: + """Return all standby groups and clear the standby state.""" + global \ + _STANDBY_WORLD, \ + _STANDBY_WORLD_NODE_COUNT, \ + _STANDBY_DP, \ + _STANDBY_EP, \ + _STANDBY_EPLB + + result = dict( + world=_STANDBY_WORLD, + dp=_STANDBY_DP, + ep=_STANDBY_EP, + eplb=_STANDBY_EPLB, + node_count=_STANDBY_WORLD_NODE_COUNT, + ) + _STANDBY_WORLD = None + _STANDBY_WORLD_NODE_COUNT = None + _STANDBY_DP = None + _STANDBY_EP = None + _STANDBY_EPLB = None + return result diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index b81c7fa..5dd862f 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -24,7 +24,6 @@ 
logger = init_logger(__name__) def start_async_worker( state: "EplbState", - rank_mapping: dict[int, int] | None = None, is_profile: bool = False, ) -> threading.Thread: eplb_group = get_eplb_group().device_group @@ -45,7 +44,6 @@ def start_async_worker( eplb_group=eplb_group, cuda_stream=cuda_stream, is_profile=is_profile, - rank_mapping=rank_mapping, ) ) except Exception as exc: # pragma: no cover - diagnostic path @@ -107,7 +105,6 @@ async def transfer_run_periodically( eplb_group: ProcessGroup, cuda_stream: torch.cuda.Stream, is_profile: bool = False, - rank_mapping: dict[int, int] | None = None, ) -> None: while True: await asyncio.to_thread(state.rearrange_event.wait) @@ -176,7 +173,6 @@ async def transfer_run_periodically( ep_group=eplb_group, is_profile=is_profile, cuda_stream=cuda_stream, - rank_mapping=rank_mapping, ) event = torch.cuda.Event(blocking=False) cuda_stream.record_event(event) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 7c3701b..b417c2b 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -40,6 +40,7 @@ from vllm.distributed.parallel_state import ( get_node_count, in_the_same_node_as, ) +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger from vllm.model_executor.models.interfaces import MixtureOfExperts @@ -159,7 +160,7 @@ class EplbModelState: NOTE: The expert_load_view now records load for all physical experts rather than just local experts. This ensures consistent load statistics - across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels). + across different dispatch methods (naive all-to-all, DeepEP). The recorded load will be multiplied by dp_size when using naive all-to-all due to each DP rank contributing the same token set to the calculation. See: @@ -302,6 +303,14 @@ class EplbState: """ CUDA device index for the async EPLB worker thread. """ + self.num_valid_physical_experts: int = 0 + """ + Number of valid physical experts. + This is the number of physical experts that are + actually mapped to logical experts. In elastic EP, + newly started EP ranks may not have physical experts + mapped yet. + """ if self.device.type == "cuda": self.cuda_device_index = self.device.index if self.cuda_device_index is None and torch.cuda.is_available(): @@ -367,9 +376,6 @@ class EplbState: self, model: MixtureOfExperts, model_config: ModelConfig, - global_expert_load: torch.Tensor | None = None, - old_global_expert_indices: torch.Tensor | None = None, - rank_mapping: dict[int, int] | None = None, ): """ Build the initial EPLB state. @@ -462,75 +468,15 @@ class EplbState: ) self.expert_rearrangement_step_interval = eplb_step_interval - # Set the policy based on the selected eplb algorithm type. 
policy_type = self.parallel_config.eplb_config.policy self.policy = EPLB_POLICIES[policy_type] logger.debug("Selected EPLB policy: %s", policy_type) - if global_expert_load is not None: - ep_group = get_ep_group().device_group - assert global_expert_load.shape == ( - model.num_moe_layers, - model.num_logical_experts, - ) - assert global_expert_load.dtype == torch.int64 - num_replicas = model.num_physical_experts - num_groups = model.num_expert_groups - num_nodes = get_node_count() - num_gpus = ep_group.size() - - if num_gpus % num_nodes != 0: - num_nodes = 1 - logger.warning_once( - f"num_gpus % num_nodes != 0, " - "not using hierarchical rearrangement algorithm.\n" - f"{num_gpus=}, {num_nodes=}" - ) - - # Get new expert mappings - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = self.policy.rebalance_experts( - global_expert_load, - num_replicas, - num_groups, - num_nodes, - num_gpus, - ) - - max_physical_slots = new_logical_to_physical_map.shape[-1] - assert max_physical_slots <= logical_to_physical_map.shape[-1] - new_logical_to_physical_map = torch.nn.functional.pad( - new_logical_to_physical_map, - (0, logical_to_physical_map.shape[-1] - max_physical_slots), - value=-1, - ) - physical_to_logical_map = new_physical_to_logical_map.to(self.device) - logical_to_physical_map.copy_(new_logical_to_physical_map) - logical_replica_count.copy_(new_logical_replica_count) - else: - new_physical_to_logical_map = None - - new_logical_to_physical_map = None - - new_logical_replica_count = None model.set_eplb_state( expert_load_pass, logical_to_physical_map, logical_replica_count, ) - if global_expert_load is not None: - rearrange_expert_weights_inplace( - old_global_expert_indices, - new_physical_to_logical_map, - model.expert_weights, - ep_group, - False, - rank_mapping, - ) - self.expert_rearrangement_step = 0 expert_buffer = [torch.empty_like(w) for w in model.expert_weights[0]] @@ -561,11 +507,12 @@ class EplbState: recv_dst_rows=np.array([]), ), cuda_device_index=self.cuda_device_index, - new_physical_to_logical_map=new_physical_to_logical_map, - new_logical_to_physical_map=new_logical_to_physical_map, - new_logical_replica_count=new_logical_replica_count, + new_physical_to_logical_map=None, + new_logical_to_physical_map=None, + new_logical_replica_count=None, ) self.model_states[model_config.compute_hash()] = model_state + self.num_valid_physical_experts = model.num_physical_experts def step( self, @@ -696,8 +643,6 @@ class EplbState: def rearrange( self, is_profile: bool = False, - execute_shuffle: bool = True, - global_expert_loads: list[torch.Tensor] | None = None, rank_mapping: dict[int, int] | None = None, ) -> torch.Tensor | None: """ @@ -707,12 +652,6 @@ class EplbState: is_profile (bool): If `True`, perform a dummy rearrangement. This is used in `profile_run` to reserve enough memory, no memory movement will be performed. Default is False. - execute_shuffle (bool): If `True`, execute the shuffle - in elastic expert parallel (EEP). Default is True. - global_expert_loads (list[torch.Tensor] | None): The global expert - loads when scaling is done in EEP. - List of expert loads for the main and drafter - (when spec decode is used) models. rank_mapping (dict[int, int] | None): The rank mapping when scaling is done in EEP. 
""" @@ -734,67 +673,34 @@ class EplbState: "(profile)" if is_profile else "", ) - if global_expert_loads is None: - # Map the physical expert load to global logical experts - global_expert_load_windows = [] - if not execute_shuffle: - num_models = torch.tensor( - [len(self.model_states)], dtype=torch.int32, device="cpu" - ) - torch.distributed.broadcast( - num_models, group=get_ep_group().cpu_group, group_src=0 - ) - - for eplb_model_state in self.model_states.values(): - logical_expert_load_window = torch.zeros( - self.expert_load_window_size, - eplb_model_state.model.num_moe_layers, - eplb_model_state.model.num_logical_experts, - dtype=eplb_model_state.expert_load_window.dtype, - device=eplb_model_state.expert_load_window.device, - ) - logical_expert_load_window.scatter_add_( - dim=-1, - index=eplb_model_state.physical_to_logical_map.unsqueeze(0) - .expand_as(eplb_model_state.expert_load_window) - .long(), - src=eplb_model_state.expert_load_window, - ) - - if not execute_shuffle: - metadata = torch.tensor( - [ - eplb_model_state.model.num_moe_layers, - eplb_model_state.model.num_logical_experts, - eplb_model_state.physical_to_logical_map.shape[1], - ], - dtype=torch.int32, - device="cpu", - ) - torch.distributed.broadcast( - metadata, group=get_ep_group().cpu_group, group_src=0 - ) - - global_expert_load_window = logical_expert_load_window.sum(dim=0) - global_expert_load_windows.append(global_expert_load_window) - # Perform all-reduce to get the expert load across all ranks for each model - global_expert_load_windows = self._allreduce_list( - global_expert_load_windows + # Map the physical expert load to global logical experts + global_expert_load_windows = [] + for eplb_model_state in self.model_states.values(): + expert_load_window = eplb_model_state.expert_load_window[ + :, :, : self.num_valid_physical_experts + ] + logical_expert_load_window = torch.zeros( + self.expert_load_window_size, + eplb_model_state.model.num_moe_layers, + eplb_model_state.model.num_logical_experts, + dtype=eplb_model_state.expert_load_window.dtype, + device=eplb_model_state.expert_load_window.device, ) - if not execute_shuffle: - for eplb_model_state, global_expert_load_window in zip( - self.model_states.values(), global_expert_load_windows - ): - # (num_moe_layers, old_num_physical_experts) - old_global_expert_indices = eplb_model_state.physical_to_logical_map - torch.distributed.broadcast( - old_global_expert_indices, group=ep_group, group_src=0 - ) - if not execute_shuffle: - return global_expert_load_windows - else: - assert execute_shuffle - global_expert_load_windows = global_expert_loads + logical_expert_load_window.scatter_add_( + dim=-1, + index=eplb_model_state.physical_to_logical_map[ + :, : self.num_valid_physical_experts + ] + .unsqueeze(0) + .expand_as(expert_load_window) + .long(), + src=expert_load_window, + ) + + global_expert_load_window = logical_expert_load_window.sum(dim=0) + global_expert_load_windows.append(global_expert_load_window) + # Perform all-reduce to get the expert load across all ranks for each model + global_expert_load_windows = self._allreduce_list(global_expert_load_windows) # TODO(bowen): Treat differently for prefill and decode nodes eplb_model_state = next(iter(self.model_states.values())) @@ -806,8 +712,10 @@ class EplbState: # NOTE(yongji): scale down, we need to rebalance the experts on # remaining GPUs, transfer the experts while we haven't shutdown # the GPUs to be released. 
- cpu_group = get_ep_group().cpu_group - num_nodes = _node_count_with_rank_mapping(cpu_group, rank_mapping) + coordinator = get_ep_group() + assert isinstance(coordinator, StatelessGroupCoordinator) + tcp_store_group = coordinator.tcp_store_group + num_nodes = _node_count_with_rank_mapping(tcp_store_group, rank_mapping) num_gpus = sum(new_rank != -1 for new_rank in rank_mapping.values()) num_replicas = ( num_replicas // ep_group.size() * num_gpus @@ -933,7 +841,6 @@ class EplbState: if self.async_worker is None: self.async_worker = start_async_worker( self, - rank_mapping=rank_mapping, is_profile=is_profile, ) @@ -1089,83 +996,6 @@ class EplbState: model_state.new_logical_to_physical_map = None model_state.new_logical_replica_count = None - @staticmethod - def recv_state() -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """ - Receive the expert load and old placement from the master rank. - """ - ep_group = get_ep_group() - num_models = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(num_models, group=ep_group.cpu_group, group_src=0) - num_models = num_models.item() - global_expert_loads = [] - old_global_expert_indices_per_model = [] - for _ in range(num_models): - metadata = torch.empty(3, dtype=torch.int32, device="cpu") - torch.distributed.broadcast(metadata, group=ep_group.cpu_group, group_src=0) - num_moe_layers, num_logical_experts, num_old_physical_experts = ( - metadata.tolist() - ) - global_expert_load = torch.zeros( - (num_moe_layers, num_logical_experts), - dtype=torch.int64, - device=ep_group.device, - ) - all_reduce(global_expert_load, group=ep_group.device_group) - old_global_expert_indices = torch.empty( - (num_moe_layers, num_old_physical_experts), - dtype=torch.int64, - device=ep_group.device, - ) - torch.distributed.broadcast( - old_global_expert_indices, - group=ep_group.device_group, - group_src=0, - ) - global_expert_loads.append(global_expert_load) - old_global_expert_indices_per_model.append(old_global_expert_indices) - return global_expert_loads, old_global_expert_indices_per_model - - @classmethod - def get_eep_state( - cls, parallel_config: ParallelConfig - ) -> tuple[ - list[torch.Tensor] | None, - list[torch.Tensor] | None, - dict[int, int] | None, - ]: - num_local_physical_experts = torch.empty(1, dtype=torch.int32, device="cpu") - torch.distributed.broadcast( - num_local_physical_experts, - group=get_ep_group().cpu_group, - group_src=0, - ) - num_local_physical_experts = int(num_local_physical_experts.item()) - new_ep_size = get_ep_group().world_size - global_expert_loads, old_global_expert_indices_per_model = ( - EplbState.recv_state() - ) - - # EP configuration for all models has to be the same so as eplb config - num_logical_experts = global_expert_loads[0].shape[1] - parallel_config.eplb_config.num_redundant_experts = ( - num_local_physical_experts * new_ep_size - num_logical_experts - ) - assert ( - old_global_expert_indices_per_model[0].shape[1] % num_local_physical_experts - == 0 - ) - old_ep_size = ( - old_global_expert_indices_per_model[0].shape[1] - // num_local_physical_experts - ) - rank_mapping = {old_ep_rank: old_ep_rank for old_ep_rank in range(old_ep_size)} - return ( - global_expert_loads, - old_global_expert_indices_per_model, - rank_mapping, - ) - def _allreduce_list(self, tensor_list: list[torch.Tensor]) -> list[torch.Tensor]: """ All-reduce a list of tensors. 
@@ -1203,6 +1033,60 @@ class EplbState: load_pass_list.append(eplb_model_state.expert_load_pass.clone()) return self._allreduce_list(load_pass_list) + @classmethod + def from_mapping( + cls, + model: MixtureOfExperts, + model_config: ModelConfig, + device: torch.device, + parallel_config: ParallelConfig, + expanded_physical_to_logical: torch.Tensor, + num_valid_physical_experts: int, + ) -> "EplbState": + eplb_state = cls( + parallel_config=parallel_config, + device=device, + ) + eplb_state.add_model( + model=model, + model_config=model_config, + ) + eplb_state.num_valid_physical_experts = num_valid_physical_experts + num_moe_layers = expanded_physical_to_logical.shape[0] + num_physical_experts = expanded_physical_to_logical.shape[1] + eplb_model_state = eplb_state.model_states[model_config.compute_hash()] + eplb_model_state.physical_to_logical_map.copy_(expanded_physical_to_logical) + + logical_to_physical_map = torch.full( + ( + num_moe_layers, + model.num_logical_experts, + eplb_model_state.logical_to_physical_map.shape[2], + ), + -1, + dtype=torch.int64, + ) + logical_replica_count = torch.zeros( + (num_moe_layers, model.num_logical_experts), + dtype=torch.int64, + ) + expanded_physical_to_logical_numpy = expanded_physical_to_logical.cpu().numpy() + for layer_idx in range(num_moe_layers): + for phys_idx in range(num_physical_experts): + logical_idx = expanded_physical_to_logical_numpy[layer_idx, phys_idx] + if logical_idx >= 0: + replica_idx = logical_replica_count[layer_idx, logical_idx] + logical_to_physical_map[layer_idx, logical_idx, replica_idx] = ( + phys_idx + ) + logical_replica_count[layer_idx, logical_idx] += 1 + + logical_to_physical_map = logical_to_physical_map.to(device) + logical_replica_count = logical_replica_count.to(device) + eplb_model_state.logical_to_physical_map.copy_(logical_to_physical_map) + eplb_model_state.logical_replica_count.copy_(logical_replica_count) + return eplb_state + @dataclass class EplbLayerState: diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 1be1e24..777f9c5 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -19,6 +19,8 @@ from torch.distributed import ( get_global_rank, ) +from vllm.distributed.parallel_state import get_ep_group +from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator from vllm.logger import init_logger logger = init_logger(__name__) @@ -249,10 +251,18 @@ def move_to_buffer( b[dst].copy_(w[src_local], non_blocking=True) p2p_ops: list[P2POp] = [] + if isinstance(get_ep_group(), StatelessGroupCoordinator): + ep_group = get_ep_group() + is_stateless = True + else: + is_stateless = False - # Pre-compute global ranks mapping + # Pre-compute global ranks mapping (only needed for non-stateless groups) ep_size = ep_group.size() - rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)} + if not is_stateless: + rank_to_global = { + rank: get_global_rank(ep_group, rank) for rank in range(ep_size) + } # 2. 
Post sends if send_count > 0: @@ -284,15 +294,23 @@ def move_to_buffer( if recver_pos < len(ranks_to_recv): recv_ranks.append(ranks_to_recv[recver_pos]) for dst in recv_ranks: - dst_global = rank_to_global[dst] - p2p_ops += [ - P2POp( - torch.distributed.isend, - w[src], - dst_global, - ) - for w in expert_weights - ] + if is_stateless: + for w in expert_weights: + op = object.__new__(P2POp) + op.op = torch.distributed.isend + op.tensor = w[src] + op.group_peer = dst + p2p_ops.append(op) + else: + dst_global = rank_to_global[dst] + p2p_ops += [ + P2POp( + torch.distributed.isend, + w[src], + dst_global, + ) + for w in expert_weights + ] # 3. Post recvs if recv_count > 0: @@ -321,26 +339,40 @@ def move_to_buffer( src = ranks_to_send[recver_pos // num_dst_per_sender] else: src = ranks_to_send[recver_pos - remainder_start] - src_global = rank_to_global[src] - p2p_ops += [ - P2POp( - torch.distributed.irecv, - b[dst], - src_global, - ) - for b in expert_weights_buffers - ] + if is_stateless: + for b in expert_weights_buffers: + op = object.__new__(P2POp) + op.op = torch.distributed.irecv + op.tensor = b[dst] + op.group_peer = src + p2p_ops.append(op) + else: + src_global = rank_to_global[src] + p2p_ops += [ + P2POp( + torch.distributed.irecv, + b[dst], + src_global, + ) + for b in expert_weights_buffers + ] # 4. Execute the P2P operations. The real communication happens here. if p2p_ops and cuda_stream is not None: with torch.cuda.stream(cuda_stream): + if is_stateless: + ep_group.device_communicator.batch_isend_irecv(p2p_ops) + else: + reqs = batch_isend_irecv(p2p_ops) + for req in reqs: + req.wait() + elif p2p_ops: + if is_stateless: + ep_group.device_communicator.batch_isend_irecv(p2p_ops) + else: reqs = batch_isend_irecv(p2p_ops) for req in reqs: req.wait() - elif p2p_ops: - reqs = batch_isend_irecv(p2p_ops) - for req in reqs: - req.wait() # wait for the communication to finish return ( is_unchanged, diff --git a/vllm/distributed/kv_events.py b/vllm/distributed/kv_events.py index 096ed44..21ec7a3 100644 --- a/vllm/distributed/kv_events.py +++ b/vllm/distributed/kv_events.py @@ -209,6 +209,10 @@ class KVConnectorKVEvents(ABC): def clear_events(self) -> None: raise NotImplementedError + def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents": + self.add_events(other.get_all_events()) + return self + class EventPublisher(ABC): """Lightweight publisher for EventBatch batches with data parallelism diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 1ceac39..d5a40fc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -149,6 +149,12 @@ KVConnectorFactory.register_connector( "ExampleConnector", ) +KVConnectorFactory.register_connector( + "ExampleHiddenStatesConnector", + "vllm.distributed.kv_transfer.kv_connector.v1.example_hidden_states_connector", + "ExampleHiddenStatesConnector", +) + KVConnectorFactory.register_connector( "P2pNcclConnector", "vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_connector", diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index f9367da..1b012d4 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -413,7 +413,20 @@ class TpKVTopology: f"by local tensor parallel size {self.tp_size}." 
) # P TP > D TP case, return the ratio as negative - return -remote_tp_size // self.tp_size + return remote_tp_size // self.tp_size + + def pp_ratio( + self, + remote_pp_size: int, + ) -> int: + """ + Calculate the pipeline parallel ratio between local and remote PP. + """ + assert self.pp_size % remote_pp_size == 0 or remote_pp_size % self.pp_size == 0, ( + f"Local pipeline parallel size {self.pp_size} is not divisible " + f"by remote pipeline parallel size {remote_pp_size} or vice versa." + ) + return self.pp_size // remote_pp_size if self.pp_size % remote_pp_size == 0 else remote_pp_size // self.pp_size def block_size_ratio( self, @@ -457,6 +470,7 @@ class TpKVTopology: def get_target_remote_ranks( self, remote_tp_size: int, + remote_pp_size: int ) -> list[int]: """ Get the remote TP rank (on P) that the current local TP rank @@ -464,19 +478,36 @@ + read from multiple remote ranks. """ tp_ratio = self.tp_ratio(remote_tp_size) - if tp_ratio > 0: - return [self.tp_rank // tp_ratio] + pp_ratio = self.pp_ratio(remote_pp_size) + target_pp_rank_list = [] + target_tp_rank_list = [] + if self.pp_size < remote_pp_size: + for i in range(pp_ratio): + target_pp_rank_list.append(self.pp_rank * pp_ratio + i) + else: + target_pp_rank_list.append(self.pp_rank // pp_ratio) - # P TP > D TP case, D reads from |tp_ratio| remote workers. - tp_ratio = -tp_ratio - return [self.tp_rank * tp_ratio + i for i in range(tp_ratio)] + if self.tp_size < remote_tp_size: + for i in range(tp_ratio): + target_tp_rank_list.append(self.tp_rank * tp_ratio + i) + else: + target_tp_rank_list.append(self.tp_rank // tp_ratio) + + target_rank_list = [] + for pp_rank in target_pp_rank_list: + for tp_rank in target_tp_rank_list: + target_rank = pp_rank * remote_tp_size + tp_rank + target_rank_list.append((target_rank, pp_rank, tp_rank)) + + return target_rank_list def get_target_remote_ranks_from_engine_id( self, remote_engine_id: EngineId, ) -> list[int]: remote_tp_size = self.remote_tp_size[remote_engine_id] - return self.get_target_remote_ranks(remote_tp_size) + remote_pp_size = self.remote_pp_size[remote_engine_id] + return self.get_target_remote_ranks(remote_tp_size, remote_pp_size) def get_current_attn_backend(vllm_config: VllmConfig): diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index a0e03b0..c096827 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -543,6 +543,28 @@ class KVConnectorBase_V1(ABC): ) return None + @classmethod + def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool: + """ + Check if this connector requires PIECEWISE CUDA graph mode. + + Connectors that use asynchronous layer-by-layer operations + (wait_for_layer_load/save_kv_layer) should override this method + to return True when those operations are enabled. These operations + cannot be captured in CUDA graphs and will be skipped during replay, + causing data races. PIECEWISE mode allows Python code to execute + between graph pieces, ensuring proper synchronization. + + Args: + extra_config: The kv_connector_extra_config dict from + KVTransferConfig. + + Returns: + True if this connector requires PIECEWISE CUDA graph mode, + False otherwise.
+ """ + return False + def get_finished_count(self) -> int | None: """ Get the count of requests expected to complete send/receive operations diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py index d4a99cf..14feafc 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_connector.py @@ -17,6 +17,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata from vllm.utils.hashing import safe_hash from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: @@ -118,12 +119,12 @@ class ExampleConnector(KVConnectorBase_V1): The number of elements in kv_caches and layer_names should be the same. """ - attn_metadata = forward_context.attn_metadata def inject_kv_into_layer( dst_kv_cache_layer: torch.Tensor, src_kv_cache: torch.Tensor, slot_mapping: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> None: """Inject the KV cache into the layer. @@ -145,6 +146,10 @@ class ExampleConnector(KVConnectorBase_V1): num_pages * page_size, -1 ) dst_kv_cache_layer[slot_mapping, ...] = src_kv_cache + elif isinstance(attn_metadata, TritonAttentionMetadata): + block_idxs = slot_mapping // self._block_size + offsets = slot_mapping % self._block_size + dst_kv_cache_layer[block_idxs, :, offsets] = src_kv_cache else: num_pages = dst_kv_cache_layer_shape[1] page_size = dst_kv_cache_layer_shape[2] @@ -186,7 +191,13 @@ class ExampleConnector(KVConnectorBase_V1): layer_name, request.token_ids, request.mm_hashes ) kv_cache = safetensors.torch.load_file(filename)["kv_cache"].cuda() - inject_kv_into_layer(kv_cache_layer, kv_cache, request.slot_mapping) + if isinstance(attn_metadata, dict): + inject_kv_into_layer( + kv_cache_layer, + kv_cache, + request.slot_mapping, + attn_metadata[layer_name], + ) def wait_for_layer_load(self, layer_name: str) -> None: """Blocking until the KV for a specific layer is loaded into vLLM's @@ -229,6 +240,10 @@ class ExampleConnector(KVConnectorBase_V1): if isinstance(attn_metadata, MLACommonMetadata): num_pages, page_size = layer.shape[0], layer.shape[1] return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...] + elif isinstance(attn_metadata, TritonAttentionMetadata): + block_idxs = slot_mapping // self._block_size + offsets = slot_mapping % self._block_size + return layer[block_idxs, :, offsets] num_pages, page_size = layer.shape[1], layer.shape[2] return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...] 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py new file mode 100644 index 0000000..945f8d9 --- /dev/null +++ b/vllm/distributed/kv_transfer/kv_connector/v1/example_hidden_states_connector.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Optional + +import safetensors +import torch + +from vllm.config import VllmConfig, get_layers_from_vllm_config +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, + KVConnectorMetadata, + KVConnectorRole, +) +from vllm.logger import init_logger +from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + +logger = init_logger(__name__) + + +def extract_from_kv_cache( + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """Extract data from KV cache + Assume the shape of the kv_cache is (num_pages, page_size, num_heads, head_size) + """ + + padded_kv = kv_cache.flatten(0, 1)[slot_mapping] + # shape: [len(slot_mapping), num_heads, head_size] + return padded_kv[:num_tokens] # shape: [num_tokens, num_heads, head_size] + + +@dataclass +class ReqMeta: + # Request ID + req_id: str + # Request filename + filename: str + # Request tokens + token_ids: torch.Tensor + # Slot mappings, should have the same length as token_ids + slot_mapping: torch.Tensor + # Whether this request is a new request or partially computed already + new_req: bool + + @staticmethod + def make_meta( + req_id: str, + filename: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + new_req: bool, + ) -> "ReqMeta": + token_ids_tensor = torch.tensor(token_ids) + block_ids_tensor = torch.tensor(block_ids) + num_blocks = block_ids_tensor.shape[0] + block_offsets = torch.arange(0, block_size) + slot_mapping = ( + block_offsets.reshape((1, block_size)) + + block_ids_tensor.reshape((num_blocks, 1)) * block_size + ) + slot_mapping = slot_mapping.flatten() + return ReqMeta( + req_id=req_id, + filename=filename, + token_ids=token_ids_tensor, + slot_mapping=slot_mapping, + new_req=new_req, + ) + + +@dataclass +class ExampleHiddenStatesConnectorMetadata(KVConnectorMetadata): + requests: list[ReqMeta] = field(default_factory=list) + + def add_request( + self, + req_id: str, + filename: str, + token_ids: list[int], + block_ids: list[int], + block_size: int, + new_req: bool = True, + ) -> None: + self.requests.append( + ReqMeta.make_meta( + req_id, filename, token_ids, block_ids, block_size, new_req + ) + ) + + +class ExampleHiddenStatesConnector(KVConnectorBase_V1): + """ + Simple debug implementation of a HiddenStatesConnector. + + Simply extracts the hidden states from the kv cache and stores them to disk. + Must be used in conjunction with the `extract_hidden_states` spec decoding method. + """ + + @property + def prefer_cross_layer_blocks(self) -> bool: + """ + Indicates whether this connector prefers KV blocks that hold KV data for all + layers, which can speed up KV data transfers. Defaults to False. 
+ """ + # Must be False so that drafter kv cache isn't merged with verifier's + return False + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) + self._block_size = vllm_config.cache_config.block_size + self._storage_path = self._kv_transfer_config.get_from_extra_config( + "shared_storage_path", "/tmp" + ) + self.cache_layers: list[str] = [] # set by self.register_kv_caches + logger.info(self._kv_transfer_config) + logger.info("Shared storage path is %s", self._storage_path) + + assert self._vllm_config.speculative_config is not None, ( + "ExampleHiddenStatesConnector only works when using " + "'extract_hidden_states' speculative method" + ) + spec_config = self._vllm_config.speculative_config.draft_model_config.hf_config + self.num_hidden_states = len( + getattr(spec_config, "eagle_aux_hidden_state_layer_ids", []) + ) + + self._request_filenames: dict[str, str] = {} + self._active_requests: dict[str, NewRequestData] = {} + self._req_blocks: dict[str, list[int]] = {} + + # ============================== + # Worker-side methods + # ============================== + def start_load_kv(self, *args, **kwargs: Any) -> None: + pass # Empty implementation of abstract method + + def wait_for_layer_load(self, layer_name: str) -> None: + pass # Empty implementation of abstract method + + def wait_for_save(self): + pass # Empty implementation of abstract method + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + from vllm.model_executor.models.extract_hidden_states import ( + CacheOnlyAttentionLayer, + ) + + # Filter layers to only include CacheOnlyAttentionLayers + layers = get_layers_from_vllm_config( + self._vllm_config, CacheOnlyAttentionLayer, list(kv_caches.keys()) + ) + self.cache_layers = list(layers.keys()) + assert len(self.cache_layers) == 1, ( + f"Expected 1 CacheOnlyAttentionLayer, got {len(self.cache_layers)}" + ) + + def save_kv_layer( + self, + layer_name: str, + kv_layer: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs: Any, + ) -> None: + """Start saving the KV cache of the layer from vLLM's paged buffer + to the connector. + + Args: + layer_name (str): the name of the layer. + kv_layer (torch.Tensor): the paged KV buffer of the current + layer in vLLM. + attn_metadata (AttentionMetadata): the attention metadata. + **kwargs: additional arguments for the save operation. 
+ """ + if layer_name not in self.cache_layers: + return + + from vllm.model_executor.models.extract_hidden_states import ( + CacheOnlyAttentionMetadata, + ) + + assert isinstance(attn_metadata, CacheOnlyAttentionMetadata), ( + "ExampleHiddenStatesConnector only supports CacheOnlyAttentionBackend" + ) + + connector_metadata = self._get_connector_metadata() + assert isinstance(connector_metadata, ExampleHiddenStatesConnectorMetadata) + + os.makedirs(self._storage_path, exist_ok=True) + for request in connector_metadata.requests: + hidden_states = extract_from_kv_cache( + kv_layer, request.slot_mapping, request.token_ids.shape[0] + ) + tensors = { + "hidden_states": hidden_states.detach().cpu(), + "token_ids": request.token_ids.detach().cpu(), + } + safetensors.torch.save_file(tensors, request.filename) + + # ============================== + # Scheduler-side methods + # ============================== + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int | None, bool]: + """ + Get number of new tokens that can be loaded from the + external KV cache beyond the num_computed_tokens. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + # This connector is store-only, so we don't need to load any tokens + return 0, False + + def update_state_after_alloc( + self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int + ): + # Usually used to handle allocation of new blocks for requests that are loading + # tokens from connector's external kv cache. We never load from external cache + # so this is a no-op. + assert num_external_tokens == 0, "This connector is store-only" + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + """Build the connector metadata for this step. + + This function should NOT modify any fields in the scheduler_output. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. 
+ """ + meta = ExampleHiddenStatesConnectorMetadata() + for new_req in scheduler_output.scheduled_new_reqs: + token_ids = new_req.prompt_token_ids or [] + filename = os.path.join(self._storage_path, f"{new_req.req_id}.safetensors") + meta.add_request( + new_req.req_id, + filename=filename, + token_ids=token_ids, + block_ids=new_req.block_ids[0], + block_size=self._block_size, + ) + self._request_filenames[new_req.req_id] = filename + self._active_requests[new_req.req_id] = new_req + self._req_blocks[new_req.req_id] = list(new_req.block_ids[0]) + + cached_reqs = scheduler_output.scheduled_cached_reqs + for i, req_id in enumerate(cached_reqs.req_ids): + if req_id not in self._active_requests: + continue + + new_block_ids = cached_reqs.new_block_ids[i] + + cached_req = self._active_requests[req_id] + req_block_ids = self._req_blocks[req_id] + + assert new_block_ids is not None + block_ids = new_block_ids[0] + + req_block_ids.extend(block_ids) + filename = os.path.join(self._storage_path, f"{req_id}.safetensors") + + meta.add_request( + req_id=req_id, + filename=filename, + token_ids=cached_req.prompt_token_ids or [], + block_ids=req_block_ids, + block_size=self._block_size, + new_req=False, + ) + + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + """ + Called exactly once when a request has finished, before its blocks are + freed. + + The connector may assumes responsibility for freeing the blocks + asynchronously by returning True. + + Returns: + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + """ + req_id = request.request_id + req_filename = self._request_filenames.pop(req_id, None) + _ = self._active_requests.pop(req_id, None) + _ = self._req_blocks.pop(req_id, None) + + return False, {"hidden_states_path": req_filename} + + @classmethod + def get_required_kvcache_layout(cls, vllm_config: "VllmConfig") -> str | None: + """ + Get the required KV cache layout for this connector. + Args: + vllm_config (VllmConfig): the vllm config. + + Returns: + str: the required KV cache layout. e.g. HND, or NHD. + None if the connector does not require a specific layout. + """ + + if cls is KVConnectorBase_V1: + raise TypeError( + "get_required_kvcache_layout should not be called " + "on the abstract base class" + ) + # NHD means we have (num_tokens, num_heads) + # HND means we have (num_heads, num_tokens) + # For now, we only support NHD layout since this keeps the + # hidden states for each token together in memory. + # HND is primarily used when sharding heads across devices. + return "NHD" diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 376215e..64aee2b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -70,6 +70,16 @@ class LMCacheKVEvents(KVConnectorKVEvents): class LMCacheConnectorV1(KVConnectorBase_V1): + @classmethod + def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool: + """ + LMCache requires PIECEWISE CUDA graph mode when layerwise + operations are enabled. The wait_for_layer_load and save_kv_layer + methods perform actual async synchronization that cannot be + captured in CUDA graphs. 
+ """ + return extra_config.get("use_layerwise", False) + def __init__( self, vllm_config: "VllmConfig", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py index f105d34..744d763 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake/mooncake_connector.py @@ -173,6 +173,29 @@ class MooncakeConnector(KVConnectorBase_V1): self.connector_scheduler = None self.connector_worker = MooncakeConnectorWorker(vllm_config, self.engine_id) + + ############################################################ + # Class Methods + ############################################################ + @classmethod + def get_required_kvcache_layout(cls, vllm_config: VllmConfig): + if vllm_config.model_config is None: + logger.warning_once( + "Unable to detect current VLLM config. " + "Fallback to default kv cache layout." + ) + return None + use_mla = vllm_config.model_config.use_mla + if use_mla: + # return None when we have mla + # as the layout should not matter in that case, + # which fallback to the default behavior. + return None + logger.info_once( + "MooncakeConnector setting KV cache layout to HND for better xfer performance." + ) + return "HND" + ############################################################ # Scheduler Side Methods ############################################################ @@ -941,7 +964,13 @@ class MooncakeConnectorWorker: assert tensor_size_bytes == curr_tensor_size_bytes, ( "All kv cache tensors must have the same size" ) - kernel_block_size = cache.shape[-2 if self.use_mla else -3] + + cache_layout = get_kv_cache_layout() + if cache_layout == "HND": + kernel_block_size = cache.shape[-2] + else: + kernel_block_size = cache.shape[-3] + assert self.block_size == kernel_block_size kv_data_ptrs.append(base_addr) kv_data_lens.append(tensor_size_bytes) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 3f0c983..7052886 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -112,6 +112,21 @@ class MultiConnector(KVConnectorBase_V1): - Save to all connectors. """ + @classmethod + def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool: + """ + MultiConnector requires PIECEWISE CUDA graph mode if any of its + child connectors require it. 
+ """ + connectors_config = extra_config.get("connectors", []) + for conn_config in connectors_config: + temp_ktc = KVTransferConfig(**conn_config) + connector_cls = KVConnectorFactory.get_connector_class(temp_ktc) + child_extra_config = conn_config.get("kv_connector_extra_config", {}) + if connector_cls.requires_piecewise_for_cudagraph(child_extra_config): + return True + return False + def __init__( self, vllm_config: "VllmConfig", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index b3f2ae7..4cafe12 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -30,7 +30,6 @@ from vllm.distributed.kv_transfer.kv_connector.utils import ( kv_postprocess_blksize_and_layout_on_receive, kv_postprocess_blksize_on_receive, kv_postprocess_layout_on_receive, - yield_req_data, ) from vllm.distributed.kv_transfer.kv_connector.v1.base import ( CopyBlocksOp, @@ -49,6 +48,7 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, + get_pp_group, ) from vllm.forward_context import ForwardContext from vllm.logger import init_logger @@ -135,7 +135,10 @@ _NIXL_SUPPORTED_DEVICE = { "cpu", ), "tpu": ("cpu",), - "xpu": ("cpu",), + "xpu": ( + "cpu", + "xpu", + ), "cpu": ("cpu",), } # support for oot platform by providing mapping in current_platform @@ -150,8 +153,12 @@ class NixlAgentMetadata: device_id: int num_blocks: int block_lens: list[int] + attn_backend_name: str kv_cache_layout: str block_size: int + pp_split: str + pp_rank: int + tp_rank: int @dataclass @@ -245,10 +252,26 @@ class RemoteMeta: @dataclass class ReqMeta: local_block_ids: list[int] - # To be used when logical block size does not match the kernel block size local_physical_block_ids: list[int] + remote_block_ids: list[int] + remote_host: str + remote_port: int + remote_engine_id: str tp_size: int - remote: RemoteMeta | None = None + remote_pp_size: int = 1 + remote_tp_size: int = 1 + remote_pp_rank: int = 0 + remote_tp_rank: int = 0 + + @property + def remote(self) -> RemoteMeta: + return RemoteMeta( + block_ids=self.remote_block_ids, + host=self.remote_host, + port=self.remote_port, + engine_id=self.remote_engine_id, + request_id="", + ) class NixlConnectorMetadata(KVConnectorMetadata): @@ -259,43 +282,32 @@ class NixlConnectorMetadata(KVConnectorMetadata): self.reqs_in_batch: set[ReqId] = set() self.reqs_not_processed: set[ReqId] = set() - def _add_new_req( + def add_new_req( self, + request_id: ReqId, local_block_ids: list[int], kv_transfer_params: dict[str, Any], - ) -> ReqMeta: - return ReqMeta( + load_remote_cache: bool = True, + save_to_host: bool = False, + ): + assert load_remote_cache ^ save_to_host + _req = ReqMeta( local_block_ids=local_block_ids, local_physical_block_ids=local_block_ids, - # P workers don't need to receive tp_size from proxy here. 
+ remote_block_ids=kv_transfer_params["remote_block_ids"], + remote_engine_id=kv_transfer_params["remote_engine_id"], + remote_host=kv_transfer_params["remote_host"], + remote_port=kv_transfer_params["remote_port"], tp_size=kv_transfer_params.get("tp_size", 1), + remote_tp_size=kv_transfer_params.get("tp_size", 1), + remote_pp_size=kv_transfer_params.get("pp_size", 1), + remote_tp_rank=0, + remote_pp_rank=0, ) - - def add_new_req_to_save( - self, - request_id: ReqId, - local_block_ids: list[int], - kv_transfer_params: dict[str, Any], - ): - self.reqs_to_save[request_id] = self._add_new_req( - local_block_ids, kv_transfer_params - ) - - def add_new_req_to_recv( - self, - request_id: ReqId, - local_block_ids: list[int], - kv_transfer_params: dict[str, Any], - ): - req = self._add_new_req(local_block_ids, kv_transfer_params) - req.remote = RemoteMeta( - block_ids=kv_transfer_params["remote_block_ids"], - engine_id=kv_transfer_params["remote_engine_id"], - request_id=kv_transfer_params["remote_request_id"], - host=kv_transfer_params["remote_host"], - port=kv_transfer_params["remote_port"], - ) - self.reqs_to_recv[request_id] = req + if save_to_host: + self.reqs_to_save[request_id] = _req + if load_remote_cache: + self.reqs_to_recv[request_id] = _req class NixlConnector(KVConnectorBase_V1): @@ -543,7 +555,7 @@ class NixlConnectorScheduler: # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. self._reqs_need_recv: dict[ReqId, tuple[Request, list[int]]] = {} - self._reqs_need_save: dict[ReqId, Request] = {} + self._reqs_need_save: dict[ReqId, tuple[Request, list[int]]] = {} # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} self._reqs_in_batch: set[ReqId] = set() @@ -568,17 +580,17 @@ class NixlConnectorScheduler: """ encoded_data: dict[int, bytes] = {} encoder = msgspec.msgpack.Encoder() - for tp_rank, rank_metadata in metadata.items(): + for global_rank, rank_metadata in metadata.items(): if not isinstance(rank_metadata, NixlHandshakePayload): raise ValueError( "NixlConnectorScheduler expects NixlHandshakePayload for " "handshake metadata." ) - encoded_data[tp_rank] = encoder.encode(rank_metadata) + encoded_data[global_rank] = encoder.encode(rank_metadata) logger.debug( - "Tp rank %d: encoded NixlHandshakePayload size: %s bytes", - tp_rank, - str(len(encoded_data[tp_rank])), + "Global rank %d: encoded NixlHandshakePayload size: %s bytes", + global_rank, + str(len(encoded_data[global_rank])), ) self._encoded_xfer_handshake_metadata = encoded_data @@ -625,14 +637,14 @@ class NixlConnectorScheduler: break continue # Decode the message which contains (GET_META_MSG, rank) - msg, target_tp_rank = msgspec.msgpack.decode(msg) + msg, target_rank = msgspec.msgpack.decode(msg) logger.debug( - "Received message for tp rank %s", - target_tp_rank, + "Received message for rank %s", + target_rank, ) if msg != GET_META_MSG: logger.warning("Connection listener got unexpected message %s", msg) - sock.send_multipart((identity, b"", encoded_data[target_tp_rank])) + sock.send_multipart((identity, b"", encoded_data[target_rank])) def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int @@ -689,7 +701,9 @@ class NixlConnectorScheduler: if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. 
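            # Only requests that already have allocated blocks are tracked
            # for saving here.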
- self._reqs_need_save[request.request_id] = request + block_ids = blocks.get_block_ids()[0] + if block_ids: + self._reqs_need_save[request.request_id] = (request, block_ids) elif params.get("do_remote_prefill"): if params.get("remote_block_ids"): if all( @@ -735,38 +749,23 @@ class NixlConnectorScheduler: # Loop through scheduled reqs and convert to ReqMeta. for req_id, (req, block_ids) in self._reqs_need_recv.items(): assert req.kv_transfer_params is not None - meta.add_new_req_to_recv( + meta.add_new_req( request_id=req_id, local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, + load_remote_cache=True, + save_to_host=False, ) - # NOTE: For the prefill side, there might be a chance that an early added - # request is a chunked prefill, so we need to check if new blocks are added - for req_id, new_block_id_groups, _ in yield_req_data(scheduler_output): - req_to_save = self._reqs_need_save.get(req_id) - if req_to_save is None or new_block_id_groups is None: - continue - req = req_to_save - + for req_id, (req, block_ids) in self._reqs_need_save.items(): assert req.kv_transfer_params is not None - meta.add_new_req_to_save( + meta.add_new_req( request_id=req_id, - local_block_ids=new_block_id_groups[0], + local_block_ids=block_ids, kv_transfer_params=req.kv_transfer_params, + load_remote_cache=False, + save_to_host=True, ) - assert scheduler_output.num_scheduled_tokens is not None - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] - is_partial = ( - req.num_computed_tokens + num_scheduled_tokens - ) < req.num_prompt_tokens - if not is_partial: - # For non-partial prefills, once new req_meta is scheduled, it - # can be removed from _reqs_need_save. - # For partial prefill case, we will retain the request in - # _reqs_need_save until all blocks are scheduled with req_meta. - # Therefore, only pop if `not is_partial`. - self._reqs_need_save.pop(req_id) meta.reqs_to_send = self._reqs_need_send meta.reqs_in_batch = self._reqs_in_batch @@ -774,6 +773,7 @@ class NixlConnectorScheduler: # Clear the list once workers start the transfers self._reqs_need_recv.clear() + self._reqs_need_save.clear() self._reqs_in_batch = set() self._reqs_not_processed = set() self._reqs_need_send = {} @@ -819,8 +819,6 @@ class NixlConnectorScheduler: # Also include the case of a P/D Prefill request with immediate # block free (eg abort). Stop tracking this request. self._reqs_not_processed.add(request.request_id) - # Clear _reqs_need_save if a request is aborted as partial prefill. - self._reqs_need_save.pop(request.request_id, None) return False, None # TODO: check whether block_ids actually ever be 0. If not we could @@ -848,6 +846,7 @@ class NixlConnectorScheduler: remote_host=self.side_channel_host, remote_port=self.side_channel_port, tp_size=self.vllm_config.parallel_config.tensor_parallel_size, + pp_size=self.vllm_config.parallel_config.pipeline_parallel_size, ) @@ -901,6 +900,12 @@ class NixlConnectorWorker: # Metadata. 
self.engine_id: EngineId = engine_id self.tp_rank = get_tensor_model_parallel_rank() + self.pp_rank = get_pp_group().rank_in_group + self.tp_world_size = get_tp_group().world_size + self.pp_world_size = get_pp_group().world_size + self.dp_rank = vllm_config.parallel_config.data_parallel_rank + self.dp_world_size = vllm_config.parallel_config.data_parallel_size + self.global_rank = self.pp_rank * self.tp_world_size + self.tp_rank self.world_size = get_tensor_model_parallel_world_size() self.tp_group = get_tp_group() self.num_blocks = 0 @@ -945,7 +950,7 @@ class NixlConnectorWorker: # type based on kv_buffer_device nixl_memory_type = current_platform.get_nixl_memory_type() if nixl_memory_type is None: - if self.kv_buffer_device == "cuda": + if self.kv_buffer_device in ["cuda", "xpu"]: nixl_memory_type = "VRAM" elif self.kv_buffer_device == "cpu": nixl_memory_type = "DRAM" @@ -960,10 +965,9 @@ class NixlConnectorWorker: self.copy_blocks: CopyBlocksOp | None = None # Map of engine_id -> kv_caches_base_addr. For TP case, each local + # rank will still only pull from a single remote TP worker. + self.kv_caches_base_addr: dict[EngineId, list[int]] = {} self.device_id: int = 0 - # Current rank may pull from multiple remote TP workers. - # EngineId, dict[int, list[int]] -> engine_id, tp_rank, base_addr_for_layer - self.kv_caches_base_addr = defaultdict[EngineId, dict[int, list[int]]](dict) # Number of NIXL regions. Currently one region per cache # (so 1 per layer for MLA, otherwise 2 per layer) @@ -971,12 +975,12 @@ class NixlConnectorWorker: self.num_layers = 0 # nixl_prepped_dlist_handle. - self.src_xfer_handles_by_block_size: dict[int, int] = {} - # Populated dynamically during handshake based on remote configuration. - # Keep track of regions at different tp_ratio values. tp_ratio->handles - self.src_xfer_handles_by_tp_ratio: dict[int, list[int]] = {} - # Map of engine_id -> {tp_rank: nixl_prepped_dlist_handle (int)}. - self.dst_xfer_side_handles = defaultdict[EngineId, dict[int, int]](dict) + self.src_xfer_side_handle: int = 0 + # Map of engine_id -> nixl_prepped_dlist_handle (int)]. + self.dst_xfer_side_handles: dict[EngineId, dict[int, int]] = {} + self.src_xfer_side_handles: dict[EngineId, dict[int, int]] = {} + # Map of (offset, len) -> nixl_prepped_dlist_handle (int). + self.src_registered_xfer_side_headles: dict[tuple[int], int] = {} # Map of engine_id -> num_blocks. All ranks in the same deployment will # have the same number of blocks. @@ -986,7 +990,7 @@ class NixlConnectorWorker: # In progress transfers. # [req_id -> list[handle]] self._recving_metadata: dict[ReqId, ReqMeta] = {} - self._recving_transfers = defaultdict[ReqId, list[TransferHandle]](list) + self._recving_transfers = defaultdict[ReqId, list[tuple[int, float]]](list) # Track the expiration time of requests that are waiting to be sent. self._reqs_to_send: dict[ReqId, float] = {} # Set of requests that have been part of a batch, regardless of status. @@ -1034,7 +1038,8 @@ class NixlConnectorWorker: self.compat_hash: str | None = None self.kv_topo: TpKVTopology | None = None - self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size} + self._tp_size: dict[EngineId, int] = {self.engine_id: self.tp_world_size} + self._pp_size: dict[EngineId, int] = {self.engine_id: self.pp_world_size} self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size} # With heterogeneous TP, P must wait for all assigned D TP workers to # finish reading before safely freeing the blocks. 
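
The bookkeeping above keys remote workers by a flattened global rank (PP-major,
TP-minor) rather than a bare TP rank; the handshake metadata and transfer
handles below are indexed by that value. A minimal sketch of the mapping, using
a hypothetical flatten_rank helper that mirrors the global_rank assignment:

    def flatten_rank(pp_rank: int, tp_rank: int, tp_world_size: int) -> int:
        # Same formula as self.global_rank above: PP-major, TP-minor.
        return pp_rank * tp_world_size + tp_rank

    # With pp_world_size=2 and tp_world_size=4:
    #   pp_rank 0 covers global ranks 0..3, pp_rank 1 covers 4..7.
    assert flatten_rank(0, 3, 4) == 3
    assert flatten_rank(1, 2, 4) == 6
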
@@ -1047,35 +1052,50 @@ class NixlConnectorWorker: "enforce_handshake_compat", True ) + pp_split = envs.VLLM_PP_LAYER_PARTITION + if pp_split is None and self.pp_world_size > 1: + raise RuntimeError( + "VLLM_PP_LAYER_PARTITION must be set when using nixl PD with pp > 1." + ) + self.pp_split = pp_split or "0" + if self.pp_world_size > 1: + assert len(self.pp_split.split(",")) == self.pp_world_size, ( + "VLLM_PP_LAYER_PARTITION split count must equal pp_size" + ) + self.local_start_index: dict[EngineId, dict[int, tuple]] = {} + self.local_end_index: dict[EngineId, dict[int, tuple]] = {} + self.remote_nums: dict[EngineId, dict[int, int]] = {} + self.is_kv = 1 if self.use_mla else 2 + self.kv_element_size = None + self.ds_v32 = hasattr(self.model_config.hf_config, "index_topk") + def _nixl_handshake( self, host: str, port: int, remote_tp_size: int, + remote_pp_size: int, expected_engine_id: str, - ) -> dict[int, str]: + ) -> dict[int, tuple[str, bool]]: """Do a NIXL handshake with a remote instance.""" - # When target instance TP > local TP, we need to perform multiple - # handshakes. Do it in a single background job for simplicity. - # Regardless, only handshake with the remote TP rank(s) that current - # local rank will read from. Note that With homogeneous TP, - # this happens to be the same single rank_i. assert self.kv_topo is not None - p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size) - remote_rank_to_agent_name = {} + p_remote_rank_list = self.kv_topo.get_target_remote_ranks( + remote_tp_size, remote_pp_size + ) + remote_agent_dict: dict[int, tuple[str, bool]] = {} path = make_zmq_path("tcp", host, port) with zmq_ctx(zmq.REQ, path) as sock: - for remote_rank in p_remote_ranks: + for p_remote_rank, remote_pp_rank, remote_tp_rank in p_remote_rank_list: logger.debug( - "Querying metadata on path: %s at remote tp rank %s", - path, - remote_rank, + "Querying metadata on path: %s at remote rank %s " + "(pp_rank=%s, tp_rank=%s)", + path, p_remote_rank, remote_pp_rank, remote_tp_rank, ) start_time = time.perf_counter() # Send query for the request. - msg = msgspec.msgpack.encode((GET_META_MSG, remote_rank)) + msg = msgspec.msgpack.encode((GET_META_MSG, p_remote_rank)) # Set receive timeout to 5 seconds to avoid hanging on dead server sock.setsockopt(zmq.RCVTIMEO, 5000) # milliseconds sock.send(msg) @@ -1108,10 +1128,7 @@ class NixlConnectorWorker: f"Local: {self.compat_hash}, " f"Remote: {handshake_payload.compatibility_hash}. " f"Prefill and decode instances have incompatible " - f"configurations. This may be due to: different vLLM versions," - f" models, dtypes, KV cache layouts, attention backends, etc. " - f"Both instances must use identical configurations." - f"Disable this check using " + f"configurations. Disable this check using " f'--kv-transfer-config \'{{"kv_connector_extra_config": ' f'{{"enforce_handshake_compat": false}}}}\'' ) @@ -1140,18 +1157,22 @@ class NixlConnectorWorker: f"Expected {expected_engine_id}," f"received {metadata.engine_id}." ) - setup_agent_time = time.perf_counter() - # Register Remote agent. - remote_agent_name = self.add_remote_agent( - metadata, remote_rank, remote_tp_size + assert metadata.block_size <= self.block_size, ( + "nP > nD is not supported yet." 
) + remote_agent_name, read = self.add_remote_agent( + metadata, p_remote_rank, remote_tp_rank, + remote_tp_size, remote_pp_rank, remote_pp_size, + ) + + setup_agent_time = time.perf_counter() logger.debug( "NIXL handshake: add agent took: %s", setup_agent_time - got_metadata_time, ) - remote_rank_to_agent_name[remote_rank] = remote_agent_name - return remote_rank_to_agent_name + remote_agent_dict[p_remote_rank] = (remote_agent_name, read) + return remote_agent_dict def initialize_host_xfer_buffer(self, kv_caches: dict[str, torch.Tensor]) -> None: """ @@ -1228,15 +1249,14 @@ class NixlConnectorWorker: # Try to get metadata from in progress transfers when not provided meta = self._recving_metadata.get(req_id) - if meta and meta.remote: + if meta: context.update( { - "remote_engine_id": meta.remote.engine_id, - "remote_request_id": meta.remote.request_id, - "remote_host": meta.remote.host, - "remote_port": meta.remote.port, + "remote_engine_id": meta.remote_engine_id, + "remote_host": meta.remote_host, + "remote_port": meta.remote_port, "num_local_blocks": len(meta.local_block_ids), - "num_remote_blocks": len(meta.remote.block_ids), + "num_remote_blocks": len(meta.remote_block_ids), "local_block_ids_sample": meta.local_block_ids[:10], } ) @@ -1259,12 +1279,12 @@ class NixlConnectorWorker: # Do NIXL handshake in background and add to _ready_requests when done. fut = self._handshake_futures.get(remote_engine_id) if fut is None: - assert meta.remote is not None fut = self._handshake_initiation_executor.submit( self._nixl_handshake, - meta.remote.host, - meta.remote.port, - meta.tp_size, + meta.remote_host, + meta.remote_port, + meta.remote_tp_size, + meta.remote_pp_size, remote_engine_id, ) self._handshake_futures[remote_engine_id] = fut @@ -1317,6 +1337,48 @@ class NixlConnectorWorker: attn_backend=self.attn_backend, tensor_shape=next(iter(kv_caches.values())).shape, ) + original_get_target = self.kv_topo.get_target_remote_ranks + + def _pp_aware_get_target_remote_ranks( + remote_tp_size, remote_pp_size=1 + ): + if remote_pp_size == 1 and self.pp_world_size == 1: + tp_ranks = original_get_target(remote_tp_size) + return [(r, 0, r) for r in tp_ranks] + + tp_ratio = self.kv_topo.tp_ratio(remote_tp_size) + pp_ratio = ( + self.pp_world_size // remote_pp_size + if self.pp_world_size >= remote_pp_size + else remote_pp_size // self.pp_world_size + ) + target_pp_rank_list = [] + target_tp_rank_list = [] + if self.pp_world_size < remote_pp_size: + for i in range(pp_ratio): + target_pp_rank_list.append( + self.pp_rank * pp_ratio + i) + else: + target_pp_rank_list.append( + self.pp_rank // pp_ratio) + + if self.tp_world_size < remote_tp_size: + for i in range(tp_ratio): + target_tp_rank_list.append( + self.tp_rank * tp_ratio + i) + else: + target_tp_rank_list.append( + self.tp_rank // tp_ratio) + + result = [] + for pp_rank in target_pp_rank_list: + for tp_rank in target_tp_rank_list: + global_rank = pp_rank * remote_tp_size + tp_rank + result.append((global_rank, pp_rank, tp_rank)) + return result + + self.kv_topo.get_target_remote_ranks = ( + _pp_aware_get_target_remote_ranks) self.compat_hash = compute_nixl_compatibility_hash( self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks ) @@ -1422,7 +1484,7 @@ class NixlConnectorWorker: assert len(self.block_len_per_layer) == len(seen_base_addresses) assert self.num_blocks != 0 - self.kv_caches_base_addr[self.engine_id][self.tp_rank] = seen_base_addresses + self.kv_caches_base_addr[self.engine_id] = seen_base_addresses self.num_regions = 
len(caches_data) self.num_layers = len(xfer_buffers.keys()) @@ -1450,8 +1512,9 @@ class NixlConnectorWorker: # Register local/src descr for NIXL xfer. self.seen_base_addresses = seen_base_addresses - self.src_xfer_handles_by_block_size[self.block_size], self.src_blocks_data = ( - self.register_local_xfer_handler(self.block_size) + first_kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=0) + self.src_registered_xfer_side_headles[(0, first_kv_block_len)] = ( + self.register_local_xfer_handler(self.block_size)[0] ) # TODO(mgoin): Hybrid memory allocator is currently disabled for @@ -1481,13 +1544,17 @@ class NixlConnectorWorker: engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), device_id=self.device_id, - kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id][self.tp_rank], + kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id], num_blocks=self.num_blocks, block_lens=self.block_len_per_layer, + attn_backend_name=self.backend_name, kv_cache_layout=self.kv_cache_layout if not self.use_host_buffer else self.host_buffer_kv_cache_layout, block_size=self.block_size, + pp_split=self.pp_split, + pp_rank=self.pp_rank, + tp_rank=self.tp_rank, ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -1555,63 +1622,33 @@ class NixlConnectorWorker: def add_remote_agent( self, nixl_agent_meta: NixlAgentMetadata, + p_remote_rank: int = 0, remote_tp_rank: int = 0, remote_tp_size: int = 1, - ) -> str: + remote_pp_rank: int = 0, + remote_pp_size: int = 1, + ) -> tuple[str, bool]: """ Add the remote NIXL agent and prepare the descriptors for reading cache - blocks from remote. - - In particular, handle both homogeneous and heterogeneous TP. The former - requires local rank_i to read from remote rank_i. - The latter, in the case of D.world_size < P.world_size, requires that a - local (D) TP worker reads from multiple remote (P) TP workers. - Conversely, assuming D.world_size > P.world_size, two or more local TP - workers will read from a single remote TP worker. - - Here's an example for the last case described above (non-MLA): - - rank_offset p_remote_tp_rank - (kv split no) - -------------------------------- - 0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ] - / - 1 0 Worker1 ---- 2nd half of KV -----/ - - 0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ] - / - 1 1 Worker3 ---- 2nd half of KV -----/ - - - Decoder TP workers Prefix TP workers - (world_size=4) (world_size=2) - tp_ratio = 4 // 2 = 2 - - Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim] - then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format. - Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio - first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split - along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0. - - Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1. - - Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0 - so that the whole cache is shared by "tp_ratio" D TP workers. - """ # noqa: E501 + blocks from remote, with PP support. 
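+
+        Returns:
+            A ``(remote_agent_name, read)`` tuple, where ``read`` indicates
+            whether this local rank actually pulls KV blocks from that remote
+            rank (it can be False when the remote KV cache is replicated
+            across TP ranks).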
+ """ engine_id = nixl_agent_meta.engine_id - # TODO re-evaluate refreshing for scaling/recovery - if remote_tp_rank in self._remote_agents.get(engine_id, {}): + if p_remote_rank in self._remote_agents.get(engine_id, {}): logger.debug( - "Remote agent with engine_id %s and rank" + "Remote agent with engine_id %s and rank " "%s already exchanged metadata, skip handshake.", - engine_id, - remote_tp_rank, + engine_id, p_remote_rank, ) - return self._remote_agents[engine_id][remote_tp_rank] + return self._remote_agents[engine_id][p_remote_rank] + + if (self.tp_world_size == remote_tp_size + and self.pp_world_size == remote_pp_size): + assert self.tp_rank == remote_tp_rank + assert self.pp_rank == remote_pp_rank - ### Register remote agent metadata if engine_id not in self._tp_size: self._tp_size[engine_id] = remote_tp_size + self._pp_size[engine_id] = remote_pp_size if engine_id not in self._block_size: self._block_size[engine_id] = nixl_agent_meta.block_size @@ -1619,139 +1656,275 @@ class NixlConnectorWorker: nixl_agent_meta.agent_metadata ) - # Create dst descs and xfer side handles. TP workers have same #blocks - # so we only register once per engine_id. - # Example: - # block_size_ratio > 1: - # remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12| - # local origin:| 0| 1| 8| 12| - # local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15| + remote_replicates_kv_cache = self.kv_topo.replicates_kv_cache(engine_id) + local_replicates_kv_cache = self.kv_topo.replicates_kv_cache(self.engine_id) + assert self.kv_topo is not None block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id) + assert block_size_ratio == 1, "block_size_ratio != 1 not supported with PP" if engine_id not in self.dst_num_blocks: self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks - # Keep track of remote agent kv caches base addresses. - self.kv_caches_base_addr[engine_id][remote_tp_rank] = ( - nixl_agent_meta.kv_caches_base_addr + self._validate_remote_agent_handshake( + nixl_agent_meta, remote_tp_size, remote_pp_size ) - self._validate_remote_agent_handshake(nixl_agent_meta, remote_tp_size) - # This is 1 when P and D `--tensor-parallel-size` match. Otherwise, - # this is the ratio between the two sizes. tp_ratio = self.kv_topo.tp_ratio_from_engine_id(engine_id) + total_num_kv_heads = self.model_config.get_total_num_kv_heads() - # Handle tp_size>num_kv_heads: replicate KV cache. - indexes_into_remote = ( - not self.kv_topo.replicates_kv_cache(engine_id) and tp_ratio > 0 - ) - - logger.debug( - "Registering remote agent (%s, rank %s) memory regions with tp_ratio %s", - engine_id, - remote_tp_rank, - tp_ratio, - ) - - ### (Optional) Register local agent memory regions. MLA is not split. - if ( - tp_ratio < 0 - and not self.use_mla - and tp_ratio not in self.src_xfer_handles_by_tp_ratio - ): - # Remote tp_size > local tp_size: read from multiple remote ranks. - # Logically "split" own regions into |tp_ratio| chunks. Mind that - # we only do this once per remote tp_size (replica-friendly). - self.src_xfer_handles_by_tp_ratio[tp_ratio] = [] - for i in range(-tp_ratio): - blocks_data = [] - for memory_region in self.src_blocks_data: - addr, local_block_len, own_tp_rank = memory_region - # Computing block len layer by layer allows for different - # block sizes to be used. 
- remote_block_len = local_block_len // (-tp_ratio) - addr = addr + i * remote_block_len - blocks_data.append((addr, remote_block_len, own_tp_rank)) - descs = self.nixl_wrapper.get_xfer_descs( - blocks_data, self.nixl_memory_type - ) - handle = self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs) - self.src_xfer_handles_by_tp_ratio[tp_ratio].append(handle) - - ### Register remote agent memory regions blocks_data = [] - # With homogeneous TP, D pulls the whole kv cache from corresponding - # rank. With heterogeneous TP, prepare the descriptors by splitting the - # P KV cache along kv_head dim, of D worker's kv_head size (D>P). - # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..]. + if engine_id not in self.local_start_index: + self.local_start_index[engine_id] = {} + self.local_end_index[engine_id] = {} + self.remote_nums[engine_id] = {} + self.local_start_index[engine_id][p_remote_rank] = 0 + self.local_end_index[engine_id][p_remote_rank] = self.num_regions + self.remote_nums[engine_id][p_remote_rank] = self.num_regions + remote_start = 0 + remote_end = self.num_regions - # Register all remote blocks, but only the corresponding kv heads. - for i, base_addr in enumerate(nixl_agent_meta.kv_caches_base_addr): - # Read our whole local region size from remote. - local_block_len = self.get_backend_aware_kv_block_len(layer_idx=i) - remote_kv_block_len = local_block_len // block_size_ratio - if block_size_ratio > 1: - # using remote kv_block_len as transfer unit - local_block_len = remote_kv_block_len - - if tp_ratio < 0 and not self.use_mla: - # Remote tp is bigger: read a chunk of local region from remote - local_block_len = local_block_len // (-tp_ratio) - rank_offset = ( - self.tp_rank % tp_ratio * remote_kv_block_len - if indexes_into_remote - else 0 + if self.pp_world_size != remote_pp_size: + assert self.pp_split != "" or nixl_agent_meta.pp_split != "", ( + "pp_split must be set when pp_size > 1" ) + assert (self.pp_world_size % remote_pp_size == 0 + or remote_pp_size % self.pp_world_size == 0) + + local_pp_layer_split = [int(x) for x in self.pp_split.split(",")] + local_pp_layer_index_cum = [0] + for x in local_pp_layer_split: + local_pp_layer_index_cum.append( + local_pp_layer_index_cum[-1] + x) + + remote_pp_layer_split = [ + int(x) for x in nixl_agent_meta.pp_split.split(",") + ] + remote_pp_layer_index_cum = [0] + for x in remote_pp_layer_split: + remote_pp_layer_index_cum.append( + remote_pp_layer_index_cum[-1] + x) + + if self.pp_world_size < remote_pp_size: + ratio = remote_pp_size // self.pp_world_size + assert remote_pp_rank == int(nixl_agent_meta.pp_rank) + assert (local_pp_layer_index_cum[self.pp_rank + 1] + >= remote_pp_layer_index_cum[remote_pp_rank + 1]) + + if (remote_pp_rank + 1) % ratio == 0: + assert (local_pp_layer_index_cum[self.pp_rank + 1] + == remote_pp_layer_index_cum[remote_pp_rank + 1]) + + lst = (0 if remote_pp_rank == 0 + else (remote_pp_layer_index_cum[remote_pp_rank] + - local_pp_layer_index_cum[ + remote_pp_rank // ratio])) + lst = lst * self.is_kv + led = (remote_pp_layer_split[remote_pp_rank] + * self.is_kv + lst) + self.local_start_index[engine_id][p_remote_rank] = lst + self.local_end_index[engine_id][p_remote_rank] = led + if self.ds_v32: + local_pp_layers = local_pp_layer_split[self.pp_rank] + self.local_start_index[engine_id][p_remote_rank] = ( + lst, lst + local_pp_layers) + self.local_end_index[engine_id][p_remote_rank] = ( + led, led + local_pp_layers) + remote_start = 0 + remote_end = (remote_pp_layer_split[remote_pp_rank] + * 
self.is_kv) + self.remote_nums[engine_id][p_remote_rank] = ( + remote_end if not self.ds_v32 else remote_end * 2) + else: + ratio = self.pp_world_size // remote_pp_size + assert remote_pp_rank == int(nixl_agent_meta.pp_rank) + assert self.pp_rank // ratio == remote_pp_rank + assert (local_pp_layer_index_cum[self.pp_rank + 1] + <= remote_pp_layer_index_cum[remote_pp_rank + 1]) + + if (self.pp_rank + 1) % ratio == 0: + assert (local_pp_layer_index_cum[self.pp_rank + 1] + == remote_pp_layer_index_cum[ + remote_pp_rank + 1]) + + remote_start = ( + 0 if self.pp_rank == 0 + else (local_pp_layer_index_cum[self.pp_rank] + - remote_pp_layer_index_cum[remote_pp_rank])) + remote_start = remote_start * self.is_kv + remote_end = (local_pp_layer_split[self.pp_rank] + * self.is_kv + remote_start) + self.remote_nums[engine_id][p_remote_rank] = ( + remote_end - remote_start if not self.ds_v32 + else (remote_end - remote_start) * 2) + else: + if self.pp_world_size > 1: + assert self.pp_split == nixl_agent_meta.pp_split, ( + f"pp_split must match: {self.pp_split} != " + f"{nixl_agent_meta.pp_split}") + + kv_block_len = self.get_backend_aware_kv_block_len(layer_idx=0) + read = True + local_rank_offset = 0 + local_block_len = kv_block_len + remote_rank_offset = 0 + default_local_xfer_dlist_key = (0, kv_block_len) + current_local_xfer_dlist = self.src_registered_xfer_side_headles[ + default_local_xfer_dlist_key] + remote_block_len = nixl_agent_meta.block_lens[0] + block_len_ratio = 1 + + if self.use_mla or (local_replicates_kv_cache + and remote_replicates_kv_cache): + assert self.block_len_per_layer[0] == nixl_agent_meta.block_lens[0], ( + "KV cache sizes must match between P and D when replicated" + ) + if self.tp_world_size < remote_tp_size: + read = bool(remote_tp_rank % tp_ratio == 0) + else: + if self.tp_world_size >= remote_tp_size: + assert not remote_replicates_kv_cache + if not local_replicates_kv_cache: + remote_rank_offset = (self.tp_rank % tp_ratio + * local_block_len) + assert remote_block_len == ( + self.block_len_per_layer[0] * tp_ratio) + else: + local_repeat_times = (self._tp_size[self.engine_id] + // total_num_kv_heads) + remote_rank_offset = ( + (self.tp_rank // local_repeat_times) + % (tp_ratio // local_repeat_times) + * local_block_len) + assert remote_block_len == ( + self.block_len_per_layer[0] + * (tp_ratio // local_repeat_times)) + block_len_ratio = max(remote_block_len // local_block_len, 1) + else: + assert not local_replicates_kv_cache + if not remote_replicates_kv_cache: + assert remote_block_len == local_block_len // tp_ratio + local_rank_offset = (remote_block_len + * (remote_tp_rank % tp_ratio)) + assert remote_block_len == ( + self.block_len_per_layer[0] // tp_ratio) + else: + remote_repeat_times = (remote_tp_size + // total_num_kv_heads) + assert remote_block_len == ( + local_block_len + // (tp_ratio // remote_repeat_times)) + read = bool( + remote_tp_rank % remote_repeat_times == 0) + local_rank_offset = ( + remote_block_len + * ((remote_tp_rank // remote_repeat_times) + % (tp_ratio // remote_repeat_times))) + assert remote_block_len == ( + self.block_len_per_layer[0] + // (tp_ratio // remote_repeat_times)) + local_block_len = remote_block_len + + if not read: + local_rank_offset = 0 + + def register_new_local_xfer_dlist(offset, block_len): + bdata = [] + for base_addr in self.kv_caches_base_addr[self.engine_id]: + for block_id in range(self.num_blocks): + boffset = (block_id + * self.get_backend_aware_kv_block_len(0)) + addr = base_addr + boffset + offset + 
bdata.append((addr, block_len, self.global_rank)) + descs = self.nixl_wrapper.get_xfer_descs( + bdata, self.nixl_memory_type) + return self.nixl_wrapper.prep_xfer_dlist( + "NIXL_INIT_AGENT", descs) + + if (local_rank_offset, local_block_len) != default_local_xfer_dlist_key: + key = (local_rank_offset, local_block_len) + if key in self.src_registered_xfer_side_headles: + current_local_xfer_dlist = ( + self.src_registered_xfer_side_headles[key]) + else: + current_local_xfer_dlist = register_new_local_xfer_dlist( + local_rank_offset, local_block_len) + self.src_registered_xfer_side_headles[key] = ( + current_local_xfer_dlist) + + self.kv_caches_base_addr[engine_id] = ( + nixl_agent_meta.kv_caches_base_addr[remote_start:remote_end]) + + for i, base_addr in enumerate( + nixl_agent_meta.kv_caches_base_addr): + if i < remote_start or i >= remote_end: + if self.ds_v32: + half = len(nixl_agent_meta.kv_caches_base_addr) // 2 + if (i < remote_start + half + or i >= remote_end + half): + continue + else: + continue for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_lens[i] - # For each block, grab the heads chunk belonging to rank_i - # of size remote_nheads // tp_ratio, which correspond to - # self.block_len == remote_block_len//tp_ratio bytes. - addr = base_addr + block_offset + rank_offset - # (addr, len, device id) - blocks_data.append((addr, local_block_len, nixl_agent_meta.device_id)) + cur_layer_block_len = nixl_agent_meta.block_lens[i] + block_offset = block_id * cur_layer_block_len + addr = base_addr + block_offset + remote_rank_offset + blocks_data.append(( + addr, + cur_layer_block_len // block_len_ratio, + nixl_agent_meta.device_id, + )) if self.kv_topo.is_kv_layout_blocks_first: - # With FlashInfer index V separately to allow head splitting. - for block_id in range(nixl_agent_meta.num_blocks): - block_offset = block_id * nixl_agent_meta.block_lens[i] - addr = base_addr + block_offset + rank_offset - v_addr = addr + nixl_agent_meta.block_lens[i] // 2 - blocks_data.append( - (v_addr, local_block_len, nixl_agent_meta.device_id) - ) + raise NotImplementedError( + "FlashInfer not supported with PP yet") logger.debug( - "Created %s blocks for dst engine %s with remote rank %s and local rank %s", - len(blocks_data), - engine_id, - remote_tp_rank, - self.tp_rank, + "Created %s blocks for dst engine %s with remote rank %s " + "and local rank %s", + len(blocks_data), engine_id, p_remote_rank, self.global_rank, ) - # Register with NIXL. 
- descs = self.nixl_wrapper.get_xfer_descs(blocks_data, self.nixl_memory_type) - self.dst_xfer_side_handles[engine_id][remote_tp_rank] = ( - self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs) + descs = self.nixl_wrapper.get_xfer_descs( + blocks_data, self.nixl_memory_type) + if engine_id not in self.dst_xfer_side_handles: + self.dst_xfer_side_handles[engine_id] = { + p_remote_rank: self.nixl_wrapper.prep_xfer_dlist( + remote_agent_name, descs)} + self.src_xfer_side_handles[engine_id] = { + p_remote_rank: current_local_xfer_dlist} + else: + self.dst_xfer_side_handles[engine_id][p_remote_rank] = ( + self.nixl_wrapper.prep_xfer_dlist(remote_agent_name, descs)) + self.src_xfer_side_handles[engine_id][p_remote_rank] = ( + current_local_xfer_dlist) + + logger.info( + "create remote info: " + "local: tp_rank %s, tp_size %s, pp_rank %s, pp_size %s, " + "remote: tp_rank %s, tp_size %s, pp_rank %s, pp_size %s, " + "local_rank_offset: %s, local_block_len: %s, " + "local start: %s, local end: %s, " + "remote_rank_offset %s, remote_block_len %s, " + "remote start: %s, remote end: %s, read: %s", + self.tp_rank, self.tp_world_size, self.pp_rank, + self.pp_world_size, + remote_tp_rank, remote_tp_size, remote_pp_rank, + remote_pp_size, + local_rank_offset, local_block_len, + self.local_start_index[engine_id][p_remote_rank], + self.local_end_index[engine_id][p_remote_rank], + remote_rank_offset, remote_block_len, + remote_start, remote_end, read, ) - if block_size_ratio > 1: - # when prefill with smaller block_size, we need to init a - # new handler with same block_len to match - self.src_xfer_handles_by_block_size[nixl_agent_meta.block_size] = ( - self.register_local_xfer_handler(nixl_agent_meta.block_size)[0] - ) - - return remote_agent_name + return (remote_agent_name, read) def _validate_remote_agent_handshake( - self, nixl_agent_meta: NixlAgentMetadata, remote_tp_size: int + self, nixl_agent_meta: NixlAgentMetadata, + remote_tp_size: int, remote_pp_size: int ): - """ - Validate the remote agent handshake metadata ensuring the - invariants hold true. - """ + """Validate the remote agent handshake metadata.""" remote_engine_id = nixl_agent_meta.engine_id assert self._tp_size[remote_engine_id] == remote_tp_size @@ -1761,8 +1934,6 @@ class NixlConnectorWorker: block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( remote_engine_id ) - # Num kv_heads > tp_size and P TP > D TP case, not supported - assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)) kv_cache_layout = ( self.kv_cache_layout @@ -1786,46 +1957,51 @@ class NixlConnectorWorker: "setting 'enable_permute_local_kv'=True in --kv-transfer-config." ) - # Block len can only vary across layers when using MLA. remote_block_len = nixl_agent_meta.block_lens[0] if self.use_mla or self.kv_topo.is_kv_replicated(remote_engine_id): - # With replicated KV cache, only the number of blocks can differ. 
- for i in range(len(self.block_len_per_layer)): - assert ( - self.block_len_per_layer[i] // block_size_ratio - == nixl_agent_meta.block_lens[i] - ), "KV cache sizes must match between P and D when replicated" + local_pp_split = [int(i) for i in self.pp_split.split(",")] + local_pp_layer_start = sum(local_pp_split[:self.pp_rank]) + local_pp_layers = local_pp_split[self.pp_rank] + local_pp_layer_end = local_pp_layer_start + local_pp_layers + + remote_pp_split_vals = [ + int(i) for i in nixl_agent_meta.pp_split.split(",")] + rpp_rank = nixl_agent_meta.pp_rank + remote_pp_layer_start = sum(remote_pp_split_vals[:rpp_rank]) + remote_pp_layers = remote_pp_split_vals[rpp_rank] + remote_pp_layer_end = remote_pp_layer_start + remote_pp_layers + + for i in range(local_pp_layer_start, local_pp_layer_end): + if i >= remote_pp_layer_start and i < remote_pp_layer_end: + assert ( + self.block_len_per_layer[i - local_pp_layer_start] + // block_size_ratio + == nixl_agent_meta.block_lens[ + i - remote_pp_layer_start] + ), "KV cache sizes must match between P and D" + if self.ds_v32: + assert ( + self.block_len_per_layer[ + local_pp_layers + i - local_pp_layer_start] + == nixl_agent_meta.block_lens[ + remote_pp_layers + i - remote_pp_layer_start] + ), "indexer_key_cache sizes must match" else: - # When MLA is not used, this is a list of the same block length for block_len in nixl_agent_meta.block_lens: assert block_len == remote_block_len, ( "All remote layers must have the same block size" ) + assert ( + remote_block_len + == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio + ), ( + "Remote P worker KV layer cache shape mismatch." + ) - if tp_ratio > 0: - # Remote tp is smaller: remote block_len size is bigger - assert ( - remote_block_len - == (self.block_len_per_layer[0] * tp_ratio) // block_size_ratio - ), ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads*tp_ratio, page_size, head_dim] and same dtype." - ) # noqa: E501 - else: - assert block_size_ratio == 1, ( - "Different local/remote block sizes are not supported when" - " P TP > D TP." - ) - # Remote tp is bigger: remote block_len size is smaller - assert remote_block_len == self.block_len_per_layer[0] // (-tp_ratio), ( - "Remote P worker KV layer cache must be of shape [2, N, " - "local_kv_heads/tp_ratio, page_size, head_dim] and same dtype." - ) # noqa: E501 - - # TP workers that handhshake with same remote have same #blocks. assert self.dst_num_blocks[remote_engine_id] == nixl_agent_meta.num_blocks - # Same number of regions/~layers. 
- assert len(nixl_agent_meta.kv_caches_base_addr) == len(self.block_len_per_layer) + if self.pp_world_size == remote_pp_size: + assert (len(nixl_agent_meta.kv_caches_base_addr) + == len(self.block_len_per_layer)) def sync_recved_kv_to_device(self, req_id: str, meta: ReqMeta): """copy recved kv from host buffer to device.""" @@ -1959,13 +2135,12 @@ class NixlConnectorWorker: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) assert meta is not None, f"{req_id} not found in recving_metadata list" - assert meta.remote is not None if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) # post processing for heteroblocksize block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( - meta.remote.engine_id + meta.remote_engine_id ) if not self.use_mla and ( block_size_ratio > 1 or self.enable_permute_local_kv @@ -2004,14 +2179,28 @@ class NixlConnectorWorker: def _get_new_notifs(self) -> set[str]: """ Get req_ids which got a remote xfer message. When multiple consumers - are reading from the same producer (heterogeneous TP scenario), wait - for all consumers to be done pulling. + are reading from the same producer (heterogeneous TP/PP scenario), + wait for all consumers to be done pulling. """ assert self.kv_topo is not None notified_req_ids: set[str] = set() for notifs in self.nixl_wrapper.get_new_notifs().values(): for notif in notifs: - req_id, tp_size = notif.decode("utf-8").rsplit(":", 1) + parts = notif.decode("utf-8").rsplit(":", 2) + if len(parts) == 3: + req_id, tp_ratio_str, pp_ratio_str = parts + consumers_per_producer = ( + int(tp_ratio_str) * int(pp_ratio_str)) + elif len(parts) == 2: + req_id, tp_size_str = parts + n_consumers = int(tp_size_str) + tp_ratio = self.kv_topo.tp_ratio(n_consumers) + consumers_per_producer = ( + -tp_ratio if n_consumers > self.world_size else 1) + else: + logger.warning("Unexpected notif format: %s", notif) + continue + if ( req_id not in self._reqs_to_send and req_id not in self._reqs_to_process @@ -2024,18 +2213,7 @@ class NixlConnectorWorker: ) continue - # NOTE: `tp_ratio` is the opposite when swapping local<>remote - n_consumers = int(tp_size) - tp_ratio = self.kv_topo.tp_ratio(n_consumers) - - # Number of reads *per producer* to wait for. - # When remote D TP > local P TP we expect `tp_ratio` reads. - consumers_per_producer = ( - -tp_ratio if n_consumers > self.world_size else 1 - ) - self.consumer_notification_counts_by_req[req_id] += 1 - # Wait all consumers (D) to be done reading before freeing. if ( self.consumer_notification_counts_by_req[req_id] == consumers_per_producer @@ -2046,27 +2224,28 @@ class NixlConnectorWorker: self._reqs_to_send.pop(req_id, None) return notified_req_ids - def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]: + def _pop_done_transfers( + self, transfers: dict[str, list[tuple[int, float]]] + ) -> set[str]: """ Pop completed xfers by checking for DONE state. 
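        Telemetry for each completed handle is recorded via the NIXL wrapper
        before the handle is released.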
Args: - transfers: dict of req_id -> list[running_xfer] + transfers: dict of req_id -> list[(handle, start_time)] Returns: set of req_ids that have all done xfers """ done_req_ids: set[str] = set() - for req_id, handles in list(transfers.items()): + for req_id, handle_list in list(transfers.items()): in_progress = [] - for handle in handles: + for handle, start_time in handle_list: try: xfer_state = self.nixl_wrapper.check_xfer_state(handle) if xfer_state == "DONE": - # Get telemetry from NIXL res = self.nixl_wrapper.get_xfer_telemetry(handle) self.xfer_stats.record_transfer(res) self.nixl_wrapper.release_xfer_handle(handle) elif xfer_state == "PROC": - in_progress.append(handle) + in_progress.append((handle, start_time)) continue else: self._log_failure( @@ -2167,63 +2346,16 @@ class NixlConnectorWorker: self._reqs_to_send[req_id] = expiration_time def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): - assert meta.remote is not None and self.kv_topo is not None - remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id( - meta.remote.engine_id + logger.debug( + "Remote agent %s available, calling _read_blocks for req %s", + meta.remote_engine_id, req_id, + ) + self._read_blocks( + request_id=req_id, + dst_engine_id=meta.remote_engine_id, + local_block_ids=meta.local_physical_block_ids, + remote_block_ids=meta.remote_block_ids, ) - tp_ratio = self.kv_topo.tp_ratio_from_engine_id(meta.remote.engine_id) - # D may have to perform multiple reads from different remote ranks. - for i, remote_rank in enumerate(remote_ranks): - if self.use_mla and tp_ratio < 0 and i > 0: - # MLA opt: when P TP > D TP, only a single read is executed for - # the first remote rank (cache is duplicated).. - break - - remote_block_size = self.kv_topo.remote_block_size[meta.remote.engine_id] - logger.debug( - "Remote agent %s available, calling _read_blocks" - " on remote rank %s with remote block size %s for req %s", - meta.remote.engine_id, - remote_rank, - remote_block_size, - req_id, - ) - # Get side handles. - if tp_ratio < 0 and not self.use_mla: - assert remote_block_size == self.block_size - # Remote tp_size > local tp_size: we must perform multiple - # reads. Get the memory chunk onto which we will write to. - local_xfer_side_handle = self.src_xfer_handles_by_tp_ratio[tp_ratio][i] - else: - # Single read from remote, we write to the whole memory region. - # Also handle remote block size different from local block size. - local_xfer_side_handle = self.src_xfer_handles_by_block_size[ - remote_block_size - ] - - # Destination handle: remote_engine_id -> remote_rank -> handle. - remote_xfer_side_handle = self.dst_xfer_side_handles[meta.remote.engine_id][ - remote_rank - ] - self._read_blocks( - request_id=req_id, - dst_engine_id=meta.remote.engine_id, - remote_request_id=meta.remote.request_id, - local_block_ids=meta.local_physical_block_ids, - remote_block_ids=meta.remote.block_ids, - remote_rank=remote_rank, - local_xfer_side_handle=local_xfer_side_handle, - remote_xfer_side_handle=remote_xfer_side_handle, - ) - - if self.use_mla and tp_ratio < 0: - # ..but we still need to notify the other remote ranks that we - # have the blocks we need so they can update the request state. 
- notif_id = f"{req_id}:{self.world_size}".encode() - remote_agents = self._remote_agents[meta.remote.engine_id] - for rank_to_notify, agent in remote_agents.items(): - if rank_to_notify != remote_rank: - self.nixl_wrapper.send_notif(agent, notif_msg=notif_id) def _read_blocks( self, @@ -2231,165 +2363,147 @@ class NixlConnectorWorker: remote_block_ids: list[int], dst_engine_id: str, request_id: str, - remote_request_id: str, - remote_rank: int, - local_xfer_side_handle: int, - remote_xfer_side_handle: int, ): """ - Post a READ point-to-point xfer request from a single local worker to - a single remote worker. + Post READ xfer requests from local worker to remote workers. + With PP, may read from multiple remote PP ranks. """ assert self.kv_topo is not None - block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id) + block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id( + dst_engine_id) if block_size_ratio > 1: local_block_ids = self.get_mapped_blocks( np.asarray(local_block_ids), block_size_ratio ) if len(local_block_ids) > len(remote_block_ids): - # NOTE: - # get_mapped_blocks will always expand block_ids for n times. - # ex: - # prefill block_ids with block_size as 4: - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # Local decode block_ids with block_size as 16: [1, 2, 3] - # expland ecode block_ids with get_mapped_blocks from [1, 2, 3] to - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] - # Then we clip local to align with prefill - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] to - # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - local_block_ids = local_block_ids[: len(remote_block_ids)] - # NOTE(rob): having the staging blocks be on the READER side is - # not going to work well (since we will have to call rearrange tensors). - # after we detect the txn is complete (which means we cannot make the - # read trxn async easily). If we want to make "READ" happen cleanly, - # then we will need to have the staging blocks on the remote side. + local_block_ids = local_block_ids[:len(remote_block_ids)] - # NOTE(rob): according to nvidia the staging blocks are used to - # saturate IB with heterogeneous TP sizes. We should remove the staging - # blocks until we are ready. - - # Number of D TP workers that will read from dst P. Propagate info - # on notification so that dst worker can wait before freeing blocks. - notif_id = f"{remote_request_id}:{self.world_size}".encode() - - # Full prefix cache hit: do not need to read remote blocks, - # just notify P worker that we have the blocks we need. - num_local_blocks = len(local_block_ids) - if num_local_blocks == 0: - agent_name = self._remote_agents[dst_engine_id][remote_rank] - try: - self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id) - except Exception as e: - self._log_failure( - failure_type="notification_failed", - msg="P worker blocks will be freed after timeout. " - "This may indicate network issues.", - req_id=request_id, - error=e, - dst_engine_id=dst_engine_id, - remote_rank=remote_rank, - remote_agent_name=agent_name, - ) - self.xfer_stats.record_failed_notification() - return - - # Partial prefix cache hit: just read uncomputed blocks. - num_remote_blocks = len(remote_block_ids) - assert num_local_blocks <= num_remote_blocks - if num_local_blocks < num_remote_blocks: - remote_block_ids = remote_block_ids[-num_local_blocks:] - - # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from - # corresponding rank. 
With heterogeneous TP, fixing D>P, the D tp - # workers will issue xfers to parts of the P worker remote kv caches. - - # Get descs ids. - local_block_descs_ids: np.ndarray - remote_block_descs_ids: np.ndarray - - if not self.block_window_per_layer: - # Default case: assume global attention - remote_block_descs_ids = self._get_block_descs_ids( - dst_engine_id, - remote_block_ids, - ) - local_block_descs_ids = self._get_block_descs_ids( - self.engine_id, - local_block_ids, - block_size_ratio=block_size_ratio, - ) + if self._tp_size[self.engine_id] > self._tp_size[dst_engine_id]: + tp_ratio = (self._tp_size[self.engine_id] + // self._tp_size[dst_engine_id]) else: - # TODO(mgoin): remove this once we have hybrid memory allocator - # Optimization for models with local attention (Llama 4) - local_descs_list = [] - remote_descs_list = [] - for layer_idx, block_window in enumerate(self.block_window_per_layer): - # For each layer: - if block_window is None: - # If not chunked, we just use the - # full block lists (global attention) - layer_local_block_ids = local_block_ids - layer_remote_block_ids = remote_block_ids + tp_ratio = 1 + + if self._pp_size[self.engine_id] > self._pp_size[dst_engine_id]: + pp_ratio = (self._pp_size[self.engine_id] + // self._pp_size[dst_engine_id]) + else: + pp_ratio = 1 + + notif_id = f"{request_id}:{tp_ratio}:{pp_ratio}".encode() + + remote_infos = self._remote_agents[dst_engine_id] + local_side_handles = self.src_xfer_side_handles[dst_engine_id] + remote_side_handles = self.dst_xfer_side_handles[dst_engine_id] + + for remote_index, (remote_agent_name, read) in remote_infos.items(): + num_local_blocks = len(local_block_ids) + if num_local_blocks == 0: + try: + self.nixl_wrapper.send_notif( + remote_agent_name, notif_msg=notif_id) + except Exception as e: + self._log_failure( + failure_type="notification_failed", + msg="P worker blocks will be freed after timeout.", + req_id=request_id, + error=e, + dst_engine_id=dst_engine_id, + ) + self.xfer_stats.record_failed_notification() + continue + + if read: + num_remote_blocks = len(remote_block_ids) + assert num_local_blocks <= num_remote_blocks + if num_local_blocks < num_remote_blocks: + remote_block_ids_used = remote_block_ids[ + -num_local_blocks:] else: - # If chunked, get the last block_window blocks - layer_local_block_ids = local_block_ids[-block_window:] - layer_remote_block_ids = remote_block_ids[-block_window:] + remote_block_ids_used = remote_block_ids - # Get descs ids for the layer. 
- layer_local_desc_ids = self._get_block_descs_ids( - self.engine_id, - layer_local_block_ids, - layer_idx, - block_size_ratio=block_size_ratio, - ) - layer_remote_desc_ids = self._get_block_descs_ids( - dst_engine_id, - layer_remote_block_ids, - layer_idx, - ) + local_xfer_side_handle = local_side_handles[remote_index] + remote_xfer_side_handle = remote_side_handles[remote_index] - local_descs_list.append(layer_local_desc_ids) - remote_descs_list.append(layer_remote_desc_ids) + local_block_descs_ids: np.ndarray + remote_block_descs_ids: np.ndarray - local_block_descs_ids = np.concatenate(local_descs_list) - remote_block_descs_ids = np.concatenate(remote_descs_list) + if not self.block_window_per_layer: + remote_block_descs_ids = self._get_block_descs_ids( + dst_engine_id, remote_block_ids_used, + is_local=False, + remote_engine_id=dst_engine_id, + remote_index=remote_index, + ) + local_block_descs_ids = self._get_block_descs_ids( + self.engine_id, local_block_ids, + is_local=True, + remote_engine_id=dst_engine_id, + remote_index=remote_index, + block_size_ratio=block_size_ratio, + ) + else: + local_descs_list = [] + remote_descs_list = [] + for layer_idx, block_window in enumerate( + self.block_window_per_layer): + if block_window is None: + layer_local_block_ids = local_block_ids + layer_remote_block_ids = remote_block_ids_used + else: + layer_local_block_ids = local_block_ids[ + -block_window:] + layer_remote_block_ids = remote_block_ids_used[ + -block_window:] - assert len(local_block_descs_ids) == len(remote_block_descs_ids) + layer_local_desc_ids = self._get_block_descs_ids( + self.engine_id, layer_local_block_ids, layer_idx, + block_size_ratio=block_size_ratio, + ) + layer_remote_desc_ids = self._get_block_descs_ids( + dst_engine_id, layer_remote_block_ids, layer_idx, + ) + local_descs_list.append(layer_local_desc_ids) + remote_descs_list.append(layer_remote_desc_ids) - # Prepare transfer with Nixl. - handle = None - try: - handle = self.nixl_wrapper.make_prepped_xfer( - "READ", - local_xfer_side_handle, - local_block_descs_ids, - remote_xfer_side_handle, - remote_block_descs_ids, - notif_msg=notif_id, - ) + local_block_descs_ids = np.concatenate(local_descs_list) + remote_block_descs_ids = np.concatenate( + remote_descs_list) - # Begin async xfer. - self.nixl_wrapper.transfer(handle) + assert len(local_block_descs_ids) == len( + remote_block_descs_ids) - # Use handle to check completion in future step(). 
- self._recving_transfers[request_id].append(handle) - except Exception as e: - # mark all (logical) blocks for this request as invalid - self._log_failure( - failure_type="transfer_setup_failed", - req_id=request_id, - msg="Marking blocks as invalid", - error=e, - dst_engine_id=dst_engine_id, - remote_rank=remote_rank, - ) - if meta := self._recving_metadata.get(request_id): - self._invalid_block_ids.update(meta.local_block_ids) - self.xfer_stats.record_failed_transfer() - if handle is not None: - self.nixl_wrapper.release_xfer_handle(handle) - self._failed_recv_reqs.add(request_id) + handle = None + try: + handle = self.nixl_wrapper.make_prepped_xfer( + "READ", + local_xfer_side_handle, + local_block_descs_ids, + remote_xfer_side_handle, + remote_block_descs_ids, + notif_msg=notif_id, + ) + self.nixl_wrapper.transfer(handle) + + self._recving_transfers[request_id].append( + (handle, time.perf_counter())) + except Exception as e: + self._log_failure( + failure_type="transfer_setup_failed", + req_id=request_id, + msg="Marking blocks as invalid", + error=e, + dst_engine_id=dst_engine_id, + ) + if meta := self._recving_metadata.get(request_id): + self._invalid_block_ids.update(meta.local_block_ids) + self.xfer_stats.record_failed_transfer() + if handle is not None: + self.nixl_wrapper.release_xfer_handle(handle) + self._failed_recv_reqs.add(request_id) + else: + self.nixl_wrapper.send_notif( + remote_agent_name, notif_msg=notif_id) def get_mapped_blocks(self, block_ids, block_size_ratio): """ @@ -2415,25 +2529,52 @@ class NixlConnectorWorker: engine_id: str, block_ids: list[int], layer_idx: int | None = None, + is_local: bool = True, + remote_engine_id: str | None = None, + remote_index: int | None = None, block_size_ratio: float | None = None, ) -> np.ndarray: """ Get the descs ids for a set of block ids. - If layer_idx is provided, we use the region_ids for the given layer. - Otherwise, we use all regions. + With PP support, uses local_start_index/local_end_index for region + selection when reading from specific remote PP ranks. """ if layer_idx is None: - region_ids = np.arange(self.num_regions) + if is_local: + if (remote_engine_id is not None + and remote_index is not None + and remote_engine_id in self.local_start_index + and remote_index in self.local_start_index[ + remote_engine_id]): + si = self.local_start_index[remote_engine_id][ + remote_index] + ei = self.local_end_index[remote_engine_id][ + remote_index] + if isinstance(si, tuple): + region_ids = np.arange(0) + for start, end in zip(si, ei): + region_ids = np.concatenate( + [region_ids, np.arange(start, end)]) + else: + region_ids = np.arange(si, ei) + else: + region_ids = np.arange(self.num_regions) + else: + if (remote_engine_id is not None + and remote_index is not None + and remote_engine_id in self.remote_nums + and remote_index in self.remote_nums[ + remote_engine_id]): + region_ids = np.arange( + self.remote_nums[remote_engine_id][remote_index]) + else: + region_ids = np.arange(self.num_regions) else: assert layer_idx < self.num_layers if self.num_layers < self.num_regions: - # If we have more regions than layers, we assume that - # the regions are organized as [K0, V0, K1, V1, ...] 
- # and we select K_i and V_i assert 2 * self.num_layers == self.num_regions region_ids = np.arange(2 * layer_idx, 2 * layer_idx + 2) else: - # Otherwise, we assume we have MLA and select i-th layer assert self.num_layers == self.num_regions region_ids = np.arange(layer_idx, layer_idx + 1) @@ -2441,10 +2582,9 @@ class NixlConnectorWorker: if block_size_ratio is not None: num_blocks = int(num_blocks * block_size_ratio) - # Compute the desc ids for each block. region_ids = region_ids[:, None] - block_ids = np.array(block_ids)[None, :] - descs_ids = region_ids * num_blocks + block_ids + block_ids_arr = np.array(block_ids)[None, :] + descs_ids = region_ids * num_blocks + block_ids_arr return descs_ids.flatten() def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]: @@ -2506,21 +2646,24 @@ class NixlConnectorWorker: def shutdown(self): """Shutdown the connector worker.""" + if not hasattr(self, "_handshake_initiation_executor"): + # error happens during init, no need to shutdown + return self._handshake_initiation_executor.shutdown(wait=False) - for handles in self._recving_transfers.values(): - for handle in handles: + for handle_list in self._recving_transfers.values(): + for handle, _ in handle_list: self.nixl_wrapper.release_xfer_handle(handle) self._recving_transfers.clear() - for handle in self.src_xfer_handles_by_block_size.values(): + for handle in self.src_registered_xfer_side_headles.values(): self.nixl_wrapper.release_dlist_handle(handle) - self.src_xfer_handles_by_block_size.clear() - for handles in self.src_xfer_handles_by_tp_ratio.values(): - for handle in handles: + self.src_registered_xfer_side_headles.clear() + for src_handles in self.src_xfer_side_handles.values(): + for handle in src_handles.values(): + self.nixl_wrapper.release_dlist_handle(handle) + self.src_xfer_side_handles.clear() + for dst_handles in self.dst_xfer_side_handles.values(): + for handle in dst_handles.values(): self.nixl_wrapper.release_dlist_handle(handle) - self.src_xfer_handles_by_tp_ratio.clear() - for dst_xfer_side_handles in self.dst_xfer_side_handles.values(): - for dst_xfer_side_handle in dst_xfer_side_handles.values(): - self.nixl_wrapper.release_dlist_handle(dst_xfer_side_handle) self.dst_xfer_side_handles.clear() for remote_agents in self._remote_agents.values(): for agent_name in remote_agents.values(): diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 61646dd..ee170d4 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -25,6 +25,7 @@ If you only need to use the distributed environment without model/pipeline import contextlib import gc +import os import pickle import weakref from collections import namedtuple @@ -33,7 +34,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from datetime import timedelta from multiprocessing import shared_memory -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol from unittest.mock import patch import torch @@ -54,6 +55,10 @@ from vllm.utils.system_utils import suppress_stdout from vllm.utils.torch_utils import ( direct_register_custom_op, ) + +if TYPE_CHECKING: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + import ixformer.distributed as ixfd import vllm._custom_ops as ops @@ -327,6 +332,8 @@ class GroupCoordinator: self.rank = torch.distributed.get_rank() self.local_rank = local_rank + use_vllm_comm = os.environ.get("VLLM_FORCE_NCCL_COMM", None) not in {"1", 
"Y", "y"} + self_device_group = None self_cpu_group = None @@ -339,7 +346,7 @@ class GroupCoordinator: with suppress_stdout(): cpu_group = torch.distributed.new_group(ranks, backend="gloo") if self.rank in ranks: - self.ixfd_group = ixfd.init_comm_with_store(device_group) + self.ixfd_group = ixfd.init_comm_with_store(device_group) if use_vllm_comm else None self.ranks = ranks self.world_size = len(ranks) self.rank_in_group = ranks.index(self.rank) @@ -372,8 +379,7 @@ class GroupCoordinator: self.device_communicator = device_comm_cls( cpu_group=self.cpu_group, device=self.device, - # device_group=self.device_group, - device_group=self.ixfd_group if envs.VLLM_FORCE_NCCL_COMM else self.device_group, + device_group=self.ixfd_group if use_vllm_comm else self.device_group, unique_name=self.unique_name, ) @@ -385,11 +391,6 @@ class GroupCoordinator: self.cpu_group, 1 << 22, 6 ) - from vllm.platforms import current_platform - - # self.use_custom_op_call = ( - # current_platform.is_cuda_alike() or current_platform.is_tpu() - # ) self.use_custom_op_call = False self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr( @@ -468,14 +469,12 @@ class GroupCoordinator: # only cuda uses this function, # so we don't abstract it into the base class maybe_ca_context = nullcontext() - # from vllm.distributed.device_communicators.cuda_communicator import ( - # CudaCommunicator, - # ) - from vllm.distributed.device_communicators.base_device_communicator import DeviceCommunicatorBase + from vllm.distributed.device_communicators.cuda_communicator import ( + CudaCommunicator, + ) if self.device_communicator is not None: - # assert isinstance(self.device_communicator, CudaCommunicator) - assert isinstance(self.device_communicator, DeviceCommunicatorBase) + assert isinstance(self.device_communicator, CudaCommunicator) ca_comm = self.device_communicator.ca_comm if ca_comm is not None: maybe_ca_context = ca_comm.capture() # type: ignore @@ -608,9 +607,9 @@ class GroupCoordinator: src=self.ranks[src], group=self.device_group) else: - torch.distributed.broadcast(input_, - src=self.ranks[src], - group=self.device_group) + torch.distributed.broadcast( + input_, src=self.ranks[src], group=self.device_group + ) return input_ def broadcast_object(self, obj: Any | None = None, src: int = 0): @@ -764,10 +763,9 @@ class GroupCoordinator: group=group, async_op=True) else: - handle = torch.distributed.broadcast(tensor, - src=self.ranks[src], - group=group, - async_op=True) + handle = torch.distributed.broadcast( + tensor, src=self.ranks[src], group=group, async_op=True + ) async_handles.append(handle) for async_handle in async_handles: async_handle.wait() @@ -802,10 +800,8 @@ class GroupCoordinator: async_op=True) else: handle = torch.distributed.broadcast( - tensor, - src=self.ranks[src], - group=group, - async_op=True) + tensor, src=self.ranks[src], group=group, async_op=True + ) async_handles.append(handle) tensor_dict[key] = tensor else: @@ -876,6 +872,10 @@ class GroupCoordinator: if self.world_size <= 1: return [] + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + if self.use_cpu_custom_send_recv: if self.device_communicator is None: raise ValueError("No device communicator found") @@ -893,10 +893,6 @@ class GroupCoordinator: group = self.device_group metadata_group = self.cpu_group - if dst is None: - dst = (self.rank_in_group + 1) % self.world_size - assert dst < self.world_size, f"Invalid dst rank ({dst})" - metadata_list, 
tensor_list = _split_tensor_dict(tensor_dict) self.send_object(metadata_list, dst=dst) @@ -917,6 +913,7 @@ class GroupCoordinator: handle = torch.distributed.isend( tensor, dst=self.ranks[dst], group=comm_group ) + if tensor.is_cuda: tensor.record_stream(torch.cuda.current_stream(tensor.device)) handles.append(handle) @@ -973,6 +970,11 @@ class GroupCoordinator: ]: if not torch.distributed.is_initialized() or self.world_size == 1: return None, [], [] + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + if self.use_cpu_custom_send_recv: if self.device_communicator is None: raise ValueError("No device communicator found") @@ -990,10 +992,6 @@ class GroupCoordinator: group = self.device_group metadata_group = self.cpu_group - if src is None: - src = (self.rank_in_group - 1) % self.world_size - assert src < self.world_size, f"Invalid src rank ({src})" - recv_metadata_list = self.recv_object(src=src) tensor_dict: dict[str, Any] = {} handles: list[Handle] = [] @@ -1072,14 +1070,13 @@ class GroupCoordinator: return self.device_communicator.recv(size, dtype, src) def destroy(self): - if hasattr(self, "device_group"): - # torch.distributed.destroy_process_group(self.device_group) + if self.device_group is not None: if self.device_communicator and self.device_communicator.use_vllm_comm: ixfd.destroy_process_group(self.device_group) else: torch.distributed.destroy_process_group(self.device_group) - del self.device_group - if hasattr(self, "cpu_group"): + self.device_group = None + if self.cpu_group is not None: torch.distributed.destroy_process_group(self.cpu_group) del self.cpu_group if self.device_communicator is not None: @@ -1094,7 +1091,6 @@ class GroupCoordinator: def dispatch_router_logits( self, hidden_states: torch.Tensor, - extra_residual:torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, extra_tensors: list[torch.Tensor] | None = None, @@ -1105,13 +1101,12 @@ class GroupCoordinator: if self.device_communicator is not None: return self.device_communicator.dispatch_router_logits( hidden_states, - extra_residual, router_logits, is_sequence_parallel, extra_tensors, ) else: - return hidden_states, extra_residual, router_logits + return hidden_states, router_logits def dispatch( self, @@ -1189,6 +1184,55 @@ def init_model_parallel_group( ) +def _init_stateless_group( + group_ranks: list[list[int]], + group_name: str, + group_ports: list[list[int]], + host: str, + backend: str, + use_device_communicator: bool = True, +) -> "StatelessGroupCoordinator": + """Create a StatelessGroupCoordinator with the given parameters.""" + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + world = get_world_group() + return StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=world.local_rank, + torch_distributed_backend=backend, + use_device_communicator=use_device_communicator, + group_name=group_name, + host=host, + group_ports=group_ports, + global_rank=world.rank, + global_world_size=world.world_size, + ) + + +def _replace_active_groups( + *, + world: GroupCoordinator | None, + dp: GroupCoordinator | None, + ep: GroupCoordinator | None, + eplb: GroupCoordinator | None, + node_count: int | None, +) -> None: + """Destroy the current DP/EP/WORLD/EPLB groups and replace them. + + Destruction is collective — all ranks in the old groups must call this + function together. Pass all-``None`` to tear down without replacement. 
+ """ + global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT + for group in (_DP, _EP, _WORLD, _EPLB): + if group is not None: + group.destroy() + _WORLD = world + _DP = dp + _EP = ep + _EPLB = eplb + _NODE_COUNT = node_count + + _TP: GroupCoordinator | None = None @@ -1286,6 +1330,39 @@ def set_custom_all_reduce(enable: bool): _ENABLE_CUSTOM_ALL_REDUCE = enable +def _init_elastic_ep_world( + config, local_rank: int, backend: str, rank: int, world_size: int +) -> None: + from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator + + global _WORLD, _NODE_COUNT + assert _WORLD is None, "world group already initialized" + parallel_config = config.parallel_config + global_rank = parallel_config.data_parallel_rank * world_size + rank + global_world_size = parallel_config.world_size_across_dp + all_ranks = list(range(global_world_size)) + group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)] + if global_rank in all_ranks: + group_ranks = [all_ranks] + group_ports = [parallel_config.get_next_stateless_world_group_port()] + world = StatelessGroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_device_communicator=False, + group_name="world", + host=parallel_config.data_parallel_master_ip, + group_ports=group_ports, + global_rank=global_rank, + global_world_size=global_world_size, + ) + assert parallel_config.nnodes_within_dp == 1, ( + "Elastic EP is not supported with multi-node TP/PP" + ) + _NODE_COUNT = _node_count(world.tcp_store_group) + _WORLD = world + + def init_distributed_environment( world_size: int = -1, rank: int = -1, @@ -1305,6 +1382,7 @@ def init_distributed_environment( from vllm.config import get_current_vllm_config_or_none config = get_current_vllm_config_or_none() + enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep if ( config is not None and config.parallel_config.distributed_executor_backend != "external_launcher" @@ -1312,6 +1390,7 @@ def init_distributed_environment( config.parallel_config.nnodes > 1 or config.parallel_config.data_parallel_size > 1 ) + and not enable_elastic_ep ): parallel_config = config.parallel_config # adjust to take into account data parallelism @@ -1365,6 +1444,18 @@ def init_distributed_environment( rank=rank, timeout=timeout, ) + if enable_elastic_ep: + tp_pp_cpu_group = torch.distributed.new_group( + backend="gloo", timeout=timeout + ) + if _node_count(tp_pp_cpu_group) > 1: + # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip + # to initialize all DP/EP groups, hence all ranks within TP/PP group + # must reside on the same node + raise RuntimeError( + "Elastic EP is not yet supported with multi-node TP/PP" + ) + # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -1373,6 +1464,9 @@ def init_distributed_environment( # setting, where we can use rank as local rank local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank global _WORLD, _NODE_COUNT, _INNER_DP_WORLD + if enable_elastic_ep: + _init_elastic_ep_world(config, local_rank, backend, rank, world_size) + return if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) _WORLD = init_world_group(ranks, local_rank, backend) @@ -1436,16 +1530,33 @@ def initialize_model_parallel( """ # Get world size and rank. Ensure some consistencies. 
assert torch.distributed.is_initialized() - world_size: int = torch.distributed.get_world_size() - rank = torch.distributed.get_rank() - backend = backend or torch.distributed.get_backend(get_world_group().device_group) - data_parallel_size = 1 - from vllm.config import get_current_vllm_config_or_none + from vllm.config import get_current_vllm_config - config = get_current_vllm_config_or_none() - if config is not None: - data_parallel_size = config.parallel_config.data_parallel_size + config = get_current_vllm_config() + data_parallel_size = config.parallel_config.data_parallel_size + enable_elastic_ep = config.parallel_config.enable_elastic_ep + if enable_elastic_ep: + # Use stateless world group for global information + world_size = get_world_group().world_size + rank = get_world_group().rank + backend = backend or "nccl" + tp_pp_pcp_size = ( + tensor_model_parallel_size + * pipeline_model_parallel_size + * prefill_context_model_parallel_size + ) + local_all_ranks = torch.arange(tp_pp_pcp_size).reshape( + pipeline_model_parallel_size, + prefill_context_model_parallel_size, + tensor_model_parallel_size, + ) + else: + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group + ) # the layout order is: ExternalDP x DP x PP x TP # ExternalDP is the data parallel group that is not part of the model, @@ -1469,7 +1580,9 @@ def initialize_model_parallel( assert _TP is None, "tensor model parallel group is already initialized" group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - + if enable_elastic_ep: + group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] # message queue broadcaster is only used in tensor model parallel group _TP = init_model_parallel_group( group_ranks, @@ -1488,6 +1601,11 @@ def initialize_model_parallel( # TP group into tp_size//dcp_size DCP groups. 
group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + group_ranks = local_all_ranks.reshape( + -1, decode_context_model_parallel_size + ).unbind(0) + group_ranks = [x.tolist() for x in group_ranks] _DCP = init_model_parallel_group( group_ranks, get_world_group().local_rank, @@ -1504,6 +1622,13 @@ def initialize_model_parallel( .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + group_ranks = ( + local_all_ranks.transpose(1, 2) + .reshape(-1, prefill_context_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] _PCP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="pcp" ) @@ -1515,6 +1640,13 @@ def initialize_model_parallel( all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] + if enable_elastic_ep: + group_ranks = ( + local_all_ranks.transpose(0, 2) + .reshape(-1, pipeline_model_parallel_size) + .unbind(0) + ) + group_ranks = [x.tolist() for x in group_ranks] _PP = init_model_parallel_group( group_ranks, get_world_group().local_rank, backend, group_name="pp" ) @@ -1523,14 +1655,27 @@ def initialize_model_parallel( assert _DP is None, "data parallel group is already initialized" group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] - _DP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="dp" - ) + if enable_elastic_ep: + parallel_config = config.parallel_config + dp_ports = [ + parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks + ] + _DP = _init_stateless_group( + group_ranks, + "dp", + dp_ports, + parallel_config.data_parallel_master_ip, + backend, + ) + else: + _DP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="dp" + ) global _EP assert _EP is None, "expert parallel group is already initialized" # Don't create EP group for dense models. - if config is None or config.model_config is None or config.model_config.is_moe: + if config.model_config is None or config.model_config.is_moe: group_ranks = ( all_ranks.transpose(1, 2) .reshape( @@ -1542,9 +1687,22 @@ def initialize_model_parallel( .unbind(0) ) group_ranks = [x.tolist() for x in group_ranks] - _EP = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="ep" - ) + if enable_elastic_ep: + parallel_config = config.parallel_config + ep_ports = [ + parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks + ] + _EP = _init_stateless_group( + group_ranks, + "ep", + ep_ports, + parallel_config.data_parallel_master_ip, + backend, + ) + else: + _EP = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="ep" + ) # Create EPLB group with the same ranks as EP if EPLB is enabled. 
# This is a separate process group to isolate EPLB communications @@ -1557,10 +1715,25 @@ def initialize_model_parallel( and config.parallel_config is not None and config.parallel_config.enable_eplb ): - # Reuse the same group_ranks from EP - _EPLB = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="eplb" - ) + if enable_elastic_ep: + eplb_ports = [ + parallel_config.get_next_stateless_eplb_group_port() + for _ in group_ranks + ] + _EPLB = _init_stateless_group( + group_ranks, + "eplb", + eplb_ports, + parallel_config.data_parallel_master_ip, + backend, + ) + else: + _EPLB = init_model_parallel_group( + group_ranks, + get_world_group().local_rank, + backend, + group_name="eplb", + ) # If no EP group needed, _EP remains None # If no EPLB group needed, _EPLB remains None @@ -1590,7 +1763,11 @@ def ensure_model_parallel_initialized( or ensure tensor-parallel and pipeline-parallel sizes are equal to expected values if the model parallel groups are initialized. """ - backend = backend or torch.distributed.get_backend(get_world_group().device_group) + world_group = get_world_group() + if hasattr(world_group, "backend"): + backend = backend or world_group.backend + else: + backend = backend or torch.distributed.get_backend(world_group.device_group) if not model_parallel_is_initialized(): initialize_model_parallel( tensor_model_parallel_size, diff --git a/vllm/distributed/stateless_coordinator.py b/vllm/distributed/stateless_coordinator.py new file mode 100644 index 0000000..f2126fd --- /dev/null +++ b/vllm/distributed/stateless_coordinator.py @@ -0,0 +1,322 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any, Optional + +import torch +from torch.distributed import Backend, ProcessGroup + +from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator +from vllm.distributed.parallel_state import ( + GroupCoordinator, + TensorMetadata, + _get_unique_name, + _register_group, + _split_tensor_dict, +) +from vllm.distributed.utils import ( + StatelessProcessGroup, + stateless_destroy_torch_distributed_process_group, + stateless_init_torch_distributed_process_group, +) +from vllm.logger import init_logger +from vllm.utils.import_utils import resolve_obj_by_qualname + +logger = init_logger(__name__) + + +class StatelessGroupCoordinator(GroupCoordinator): + """ + A stateless version of the GroupCoordinator class in parallel_state, + It will create CPU, device and TCPStore based communication groups + that are independent of PyTorch's WORLD group. Hence, + communication groups with a different set of participants GPUs + can be created without destroying the existing ones. 
+ """ + + def __init__( + self, + group_ranks: list[list[int]], + local_rank: int, + torch_distributed_backend: str | Backend, + use_device_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: str | None = None, + host: str = "127.0.0.1", + group_ports: list[list[int]] | None = None, + global_rank: int = 0, + global_world_size: int = 1, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = global_rank + self.local_rank = local_rank + + self_device_group = None + self_cpu_group = None + self_tcp_store_group = None + + from vllm.platforms import current_platform + + backend = str(torch_distributed_backend) + self.backend = backend + assert group_ports is not None, "group_ports is not provided" + for idx, ranks in enumerate(group_ranks): + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + + ports = group_ports[idx] + device_port = ports[0] + cpu_port = ports[1] + tcp_store_port = ports[2] + + device_group = stateless_init_torch_distributed_process_group( + host=host, + port=device_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend=backend, + group_name=f"{self.unique_name}_device", + ) + cpu_group = stateless_init_torch_distributed_process_group( + host=host, + port=cpu_port, + rank=self.rank_in_group, + world_size=self.world_size, + backend="gloo", + group_name=f"{self.unique_name}_cpu", + ) + tcp_store_group = StatelessProcessGroup.create( + host=host, + port=tcp_store_port, + rank=self.rank_in_group, + world_size=self.world_size, + ) + + self_device_group = device_group + self_cpu_group = cpu_group + self_tcp_store_group = tcp_store_group + + assert self_cpu_group is not None + assert self_device_group is not None + assert self_tcp_store_group is not None + + self.cpu_group = self_cpu_group + self.device_group = self_device_group + self.tcp_store_group = self_tcp_store_group + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_xpu(): + self.device = torch.device(f"xpu:{local_rank}") + elif current_platform.is_out_of_tree(): + self.device = torch.device(f"{current_platform.device_name}:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_device_communicator = use_device_communicator + self.device_communicator = None + if use_device_communicator and self.world_size > 1: + device_comm_cls = resolve_obj_by_qualname( + current_platform.get_device_communicator_cls() + ) + assert device_comm_cls == CudaCommunicator + self.device_communicator = CudaCommunicator( + cpu_group=self.cpu_group, + device=self.device, + device_group=self.device_group, + unique_name=self.unique_name, + global_ranks=self.ranks, + global_world_size=global_world_size, + tcp_store_group=self.tcp_store_group, + ) + + self.mq_broadcaster = None + + self.use_custom_op_call = ( + current_platform.is_cuda_alike() or current_platform.is_tpu() + ) + self.use_cpu_custom_send_recv = False + + def destroy(self): + if self.device_communicator: + self.device_communicator.destroy() + if self.device_group: + stateless_destroy_torch_distributed_process_group(self.device_group) + if self.cpu_group: + stateless_destroy_torch_distributed_process_group(self.cpu_group) + + def size(self) -> int: + """Return the world size of this group.""" + return self.world_size + + def broadcast(self, input_: torch.Tensor, src: int = 0): + if self.world_size == 1: + return 
input_ + + if self.device_communicator and input_.is_cuda: + return self.device_communicator.broadcast(input_, src) + else: + return self.tcp_store_group.broadcast(input_, src) + + def broadcast_object(self, obj=None, src: int = 0): + if self.world_size == 1: + return obj + return self.tcp_store_group.broadcast_obj(obj, src) + + def broadcast_object_list( + self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None + ): + assert src < self.world_size + + if self.world_size == 1: + return obj_list + + if self.rank_in_group == src: + for obj in obj_list: + self.tcp_store_group.broadcast_obj(obj, src) + else: + for i in range(len(obj_list)): + obj_list[i] = self.tcp_store_group.broadcast_obj(None, src) + + return obj_list + + def broadcast_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any] | None = None, + src: int = 0, + group: ProcessGroup | None = None, + metadata_group: ProcessGroup | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return tensor_dict + + if self.rank_in_group == src: + assert isinstance(tensor_dict, dict), ( + f"Expecting a dictionary, got {type(tensor_dict)}" + ) + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + else: + metadata_list = None + tensor_list = [] + + recv_metadata_list: list[tuple[str, Any]] = self.tcp_store_group.broadcast_obj( + metadata_list, src + ) + + if self.rank_in_group != src: + tensor_dict = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty( + value.size, dtype=value.dtype, device=value.device + ) + tensor_list.append(tensor) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + + for tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + tensor.copy_(self.device_communicator.broadcast(tensor, src)) + else: + tensor.copy_(self.tcp_store_group.broadcast(tensor, src)) + + return tensor_dict + + def send_object(self, obj, dst: int) -> None: + assert dst < self.world_size + assert dst != self.rank_in_group + self.tcp_store_group.send_obj(obj, dst) + + def recv_object(self, src: int): + assert src < self.world_size + assert src != self.rank_in_group + return self.tcp_store_group.recv_obj(src) + + def send_tensor_dict( + self, + tensor_dict: dict[str, torch.Tensor | Any], + dst: int | None = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return tensor_dict + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size + + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + self.tcp_store_group.send_obj(metadata_list, dst) + + for tensor in tensor_list: + if tensor.numel() == 0: + continue + if self.device_communicator and tensor.is_cuda: + self.device_communicator.send(tensor, dst) + else: + self.tcp_store_group.send(tensor, dst) + + return None + + def recv_tensor_dict( + self, + src: int | None = None, + all_gather_group: Optional["GroupCoordinator"] = None, + all_gather_tensors: dict[str, bool] | None = None, + ) -> dict[str, torch.Tensor | Any] | None: + if self.world_size == 1: + return None + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size + + recv_metadata_list = self.tcp_store_group.recv_obj(src) + tensor_dict = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + 
tensor = torch.empty(value.size, dtype=value.dtype, device=value.device) + if tensor.numel() > 0: + if self.device_communicator and tensor.is_cuda: + tensor = self.device_communicator.recv( + tensor.size(), tensor.dtype, src + ) + else: + tensor = self.tcp_store_group.recv(tensor, src) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + self.tcp_store_group.barrier() + + def gather( + self, input_: torch.Tensor, dst: int = 0, dim: int = -1 + ) -> torch.Tensor | None: + if self.world_size == 1: + return input_ + + if self.device_communicator is None: + raise ValueError("No device communicator found") + + if self.rank_in_group == dst: + gathered_list = [torch.empty_like(input_) for _ in range(self.world_size)] + gathered_list[self.rank_in_group] = input_ + for src_rank in range(self.world_size): + if src_rank != self.rank_in_group: + gathered_list[src_rank] = self.device_communicator.recv( + input_.size(), input_.dtype, src_rank + ) + return torch.cat(gathered_list, dim=dim) + else: + self.device_communicator.send(input_, dst) + return None diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 1737525..102f2f7 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -18,7 +18,7 @@ from datetime import timedelta from typing import Any import torch -from torch.distributed import ProcessGroup, TCPStore +from torch.distributed import ProcessGroup, Store, TCPStore from torch.distributed.distributed_c10d import ( Backend, PrefixStore, @@ -228,6 +228,55 @@ class StatelessProcessGroup: gathered_objs.append(recv_obj) return gathered_objs + def broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + """Broadcast a tensor from source rank to all other ranks.""" + if self.rank == src: + tensor_bytes = pickle.dumps(tensor) + self.expire_data() + key = f"broadcast_tensor/{src}/{self.broadcast_send_counter}" + self.store.set(key, tensor_bytes) + self.broadcast_send_counter += 1 + self.entries.append((key, time.time())) + return tensor + else: + key = f"broadcast_tensor/{src}/{self.broadcast_recv_src_counter[src]}" + tensor = pickle.loads(self.store.get(key)) + self.broadcast_recv_src_counter[src] += 1 + return tensor + + def send(self, tensor: torch.Tensor, dst: int): + """Send a tensor to a destination rank.""" + self.expire_data() + key = f"send_tensor/{dst}/{self.send_dst_counter[dst]}" + self.store.set(key, pickle.dumps(tensor)) + self.send_dst_counter[dst] += 1 + self.entries.append((key, time.time())) + + def recv(self, tensor: torch.Tensor, src: int) -> torch.Tensor: + """Receive a tensor from a source rank.""" + key = f"send_tensor/{self.rank}/{self.recv_src_counter[src]}" + received = pickle.loads(self.store.get(key)) + self.recv_src_counter[src] += 1 + tensor.copy_(received) + return tensor + + def all_reduce( + self, tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM + ) -> torch.Tensor: + """All-reduce a tensor across all ranks.""" + tensors = self.all_gather_obj(tensor) + result = tensors[0].clone() + for t in tensors[1:]: + if op == torch.distributed.ReduceOp.SUM: + result.add_(t) + elif op == torch.distributed.ReduceOp.PRODUCT: + result.mul_(t) + elif op == torch.distributed.ReduceOp.MAX: + result = torch.maximum(result, t) + elif op == torch.distributed.ReduceOp.MIN: + result = torch.minimum(result, t) + return result + def barrier(self, timeout: float = 30.0): """A robust barrier to synchronize all ranks. 
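# A minimal usage sketch for the TCPStore-backed helpers added above. It assumes
# two ranks running in separate processes; the host, port, and the _stateless_demo
# name are illustrative placeholders rather than anything defined in vLLM.
import torch

from vllm.distributed.utils import StatelessProcessGroup


def _stateless_demo(rank: int) -> None:
    # Every participant calls create() with the same host/port and world_size.
    pg = StatelessProcessGroup.create(
        host="127.0.0.1", port=29555, rank=rank, world_size=2
    )
    if rank == 0:
        pg.send(torch.ones(4), dst=1)  # tensor is pickled through the TCPStore
    else:
        buf = pg.recv(torch.empty(4), src=0)  # received data is copied into buf
        assert torch.equal(buf, torch.ones(4))
    # all_reduce gathers every rank's tensor and reduces locally (SUM by default).
    total = pg.all_reduce(torch.tensor([float(rank)]))
    assert total.item() == 1.0  # 0.0 + 1.0 across the two ranks
    pg.barrier()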
@@ -448,8 +497,14 @@ def init_gloo_process_group( def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, backend: str -) -> ProcessGroup: + host: str, + port: int, + rank: int, + world_size: int, + backend: str, + group_name: str | None = None, + return_store: bool = False, +) -> ProcessGroup | tuple[ProcessGroup, Store]: """ A replacement for `torch.distributed.init_process_group` that does not pollute the global state. The created ProcessGroup object can be used for @@ -496,25 +551,35 @@ def stateless_init_torch_distributed_process_group( # Use a PrefixStore to avoid accidental overrides of keys used by # different systems (e.g. RPC) in case the store is multi-tenant. prefix_store = PrefixStore(init_method, store) - try: + + if backend == "gloo": + pg = init_gloo_process_group( + prefix_store=prefix_store, + group_rank=group_rank, + group_size=group_size, + timeout=timeout, + ) + else: from vllm.platforms import current_platform - return current_platform.stateless_init_device_torch_dist_pg( + pg = current_platform.stateless_init_device_torch_dist_pg( backend=backend, prefix_store=prefix_store, group_rank=group_rank, group_size=group_size, timeout=timeout, ) - except NotImplementedError: - # If platform doesn't implement stateless_init_device_torch_dist_pg, it - # will raise a NotImplementedError. In this case, we fall back to gloo. - return init_gloo_process_group( - prefix_store=prefix_store, - group_rank=group_rank, - group_size=group_size, - timeout=timeout, - ) + + if group_name is not None: + from torch._C._distributed_c10d import _register_process_group + + pg._set_group_name(group_name) + _register_process_group(group_name, pg) + + if return_store: + return pg, store + else: + return pg def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: diff --git a/vllm/distributed/weight_transfer/base.py b/vllm/distributed/weight_transfer/base.py index b87f190..788dcef 100644 --- a/vllm/distributed/weight_transfer/base.py +++ b/vllm/distributed/weight_transfer/base.py @@ -3,7 +3,7 @@ """Base class for weight transfer engines.""" from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Iterator from dataclasses import KW_ONLY, dataclass, field from typing import Any, Generic, TypeVar @@ -156,3 +156,30 @@ class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]): This should be called when the worker is shutting down. """ raise NotImplementedError + + @staticmethod + @abstractmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + trainer_args: dict[str, Any] | Any, + ) -> None: + """ + Send weights from trainer to inference workers. + + This is a static method that can be called from the trainer process + to send weights to all inference workers. + + Args: + iterator: Iterator of model parameters. Returns (name, tensor) tuples. + The tensors should be on the appropriate device for the backend. + trainer_args: Dictionary containing backend-specific arguments needed + to send weights. The structure depends on the backend: + - NCCL: Contains 'group', 'src', 'packed', etc. + - IPC: Contains 'mode' ('http' or 'ray'), + 'llm_handle' (for Ray), 'url' (for HTTP), etc. 
+ + Example: + >>> param_iter = ((n, p) for n, p in model.named_parameters()) + >>> engine.trainer_send_weights(param_iter, trainer_args) + """ + raise NotImplementedError diff --git a/vllm/distributed/weight_transfer/factory.py b/vllm/distributed/weight_transfer/factory.py index 7235e30..f8e9c86 100644 --- a/vllm/distributed/weight_transfer/factory.py +++ b/vllm/distributed/weight_transfer/factory.py @@ -114,3 +114,9 @@ WeightTransferEngineFactory.register_engine( "vllm.distributed.weight_transfer.nccl_engine", "NCCLWeightTransferEngine", ) + +WeightTransferEngineFactory.register_engine( + "ipc", + "vllm.distributed.weight_transfer.ipc_engine", + "IPCWeightTransferEngine", +) diff --git a/vllm/distributed/weight_transfer/ipc_engine.py b/vllm/distributed/weight_transfer/ipc_engine.py new file mode 100644 index 0000000..2edbec6 --- /dev/null +++ b/vllm/distributed/weight_transfer/ipc_engine.py @@ -0,0 +1,291 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""IPC-based weight transfer engine using CUDA IPC for communication.""" + +import base64 +import pickle +from collections.abc import Callable, Iterator +from dataclasses import asdict, dataclass +from typing import Any + +import requests +import torch +from torch.multiprocessing.reductions import reduce_tensor + +from vllm.config.parallel import ParallelConfig +from vllm.config.weight_transfer import WeightTransferConfig +from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine, + WeightTransferInitInfo, + WeightTransferUpdateInfo, +) + + +@dataclass +class IPCTrainerSendWeightsArgs: + """Arguments for IPC trainer_send_weights method.""" + + mode: str + """Transport mode: 'http' or 'ray'.""" + llm_handle: Any = None + """Ray ObjectRef to LLM handle (required for 'ray' mode).""" + url: str | None = None + """Base URL for HTTP endpoint (required for 'http' mode).""" + + def __post_init__(self): + """Validate that required arguments are provided for the selected mode.""" + if self.mode == "ray" and self.llm_handle is None: + raise ValueError("llm_handle is required for 'ray' mode") + if self.mode == "http" and self.url is None: + raise ValueError("url is required for 'http' mode") + if self.mode not in ("ray", "http"): + raise ValueError(f"mode must be 'ray' or 'http', got {self.mode}") + + +@dataclass +class IPCWeightTransferInitInfo(WeightTransferInitInfo): + """Initialization info for IPC weight transfer backend. No init needed for IPC.""" + + pass + + +@dataclass +class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo): + """Update info for IPC weight transfer backend. + + Accepts IPC handles either directly via ``ipc_handles`` (Ray transport) + or as a base64-encoded pickle via ``ipc_handles_pickled`` (HTTP transport). + Exactly one of the two must be provided; if ``ipc_handles_pickled`` is set + it is unpickled into ``ipc_handles`` during ``__post_init__``. + """ + + names: list[str] + dtype_names: list[str] + shapes: list[list[int]] + ipc_handles: list[dict[str, tuple[Callable, tuple]]] | None = None + """IPC handles mapping physical GPU UUID to (func, args) tuple. 
+ Each handle is a dictionary mapping GPU UUID strings to IPC handle tuples.""" + ipc_handles_pickled: str | None = None + """Base64-encoded pickled IPC handles, used for HTTP transport.""" + + def __post_init__(self): + if self.ipc_handles_pickled is not None: + if self.ipc_handles is not None: + raise ValueError( + "Cannot specify both `ipc_handles` and `ipc_handles_pickled`" + ) + self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled)) + self.ipc_handles_pickled = None + + if self.ipc_handles is None: + raise ValueError( + "Either `ipc_handles` or `ipc_handles_pickled` must be provided" + ) + + num_params = len(self.names) + if len(self.dtype_names) != num_params: + raise ValueError( + f"`dtype_names` should be of the same size as `names`: " + f"got {len(self.dtype_names)} and {len(self.names)}" + ) + if len(self.shapes) != num_params: + raise ValueError( + f"`shapes` should be of the same size as `names`: " + f"got {len(self.shapes)} and {len(self.names)}" + ) + if len(self.ipc_handles) != num_params: + raise ValueError( + f"`ipc_handles` should be of the same size as `names`: " + f"got {len(self.ipc_handles)} and {len(self.names)}" + ) + + +class IPCWeightTransferEngine( + WeightTransferEngine[IPCWeightTransferInitInfo, IPCWeightTransferUpdateInfo] +): + """ + Weight transfer engine using CUDA IPC for communication between trainer and workers. + + This implementation uses CUDA IPC to transfer weights from the trainer (rank 0) + to all inference workers in a process group. IPC handles are used to share + memory between processes on the same node. + """ + + # Define backend-specific dataclass types + init_info_cls = IPCWeightTransferInitInfo + update_info_cls = IPCWeightTransferUpdateInfo + + def __init__( + self, config: WeightTransferConfig, parallel_config: ParallelConfig + ) -> None: + """ + Initialize the IPC weight transfer engine. + + Args: + config: The configuration for the weight transfer engine + parallel_config: The configuration for the parallel setup + """ + super().__init__(config, parallel_config) + + def init_transfer_engine(self, init_info: IPCWeightTransferInitInfo) -> None: + """ + Initialize the weight transfer mechanism. + This is called once at the beginning of training. + No initialization needed for IPC backend. + + Args: + init_info: IPC initialization info (empty) + """ + pass + + def receive_weights( + self, + update_info: IPCWeightTransferUpdateInfo, + load_weights: Callable[[list[tuple[str, torch.Tensor]]], None], + ) -> None: + """ + Receive weights from the trainer via CUDA IPC handles. + + Args: + update_info: IPC update info containing parameter names, dtypes, shapes, + and IPC handles. Each IPC handle is a mapping between physical + GPU UUID and the IPC handle tuple (func, args). + load_weights: Callable that loads weights into the model. Called + incrementally for each weight to avoid OOM. + """ + assert update_info.ipc_handles is not None + weights = [] + for name, _dtype_name, _shape, ipc_handle in zip( + update_info.names, + update_info.dtype_names, + update_info.shapes, + update_info.ipc_handles, + ): + device_index = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device_index) + physical_gpu_id = str(props.uuid) + + if physical_gpu_id not in ipc_handle: + raise ValueError( + f"IPC handle not found for GPU UUID {physical_gpu_id}. 
" + f"Available UUIDs: {list(ipc_handle.keys())}" + ) + + handle = ipc_handle[physical_gpu_id] + + func, args = handle + list_args = list(args) # type: ignore + # Index 6 is the device_index parameter in torch's + # IPC handle tuple (rebuild_cuda_tensor). Update it + # to the current device since the logical index can + # differ between sender and receiver. + list_args[6] = device_index + weight = func(*list_args) # type: ignore + weights.append((name, weight)) + + load_weights(weights) + + def shutdown(self) -> None: + """ + Shutdown the weight transfer engine. + """ + pass + + @staticmethod + def trainer_send_weights( + iterator: Iterator[tuple[str, torch.Tensor]], + trainer_args: dict[str, Any] | IPCTrainerSendWeightsArgs, + ) -> None: + """ + Send weights from trainer to inference workers via CUDA IPC. + + Supports two modes: + - 'ray': Sends weights via Ray RPC to a Ray-based LLM handle + - 'http': Sends weights via HTTP POST to a vLLM HTTP server + + Args: + iterator: Iterator of model parameters. Returns (name, tensor) tuples. + Tensors should be on the same GPU as the inference workers. + trainer_args: Dictionary containing IPC-specific arguments. + Should contain keys from IPCTrainerSendWeightsArgs: + - mode: 'ray' or 'http' + - llm_handle: Ray ObjectRef (for 'ray' mode) + - url: Base URL string (for 'http' mode) + + Example (Ray mode): + >>> from vllm.distributed.weight_transfer.ipc_engine import ( + ... IPCWeightTransferEngine, + ... IPCTrainerSendWeightsArgs, + ... ) + >>> param_iter = ((n, p) for n, p in model.named_parameters()) + >>> args = IPCTrainerSendWeightsArgs(mode="ray", llm_handle=llm_handle) + >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args)) + + Example (HTTP mode): + >>> args = IPCTrainerSendWeightsArgs( + ... mode="http", url="http://localhost:8000" + ... 
) + >>> IPCWeightTransferEngine.trainer_send_weights(param_iter, asdict(args)) + """ + # Parse trainer args - accept either dict or dataclass instance + if isinstance(trainer_args, dict): + args = IPCTrainerSendWeightsArgs(**trainer_args) + else: + args = trainer_args + + # Get physical GPU UUID + device_index = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device_index) + gpu_uuid = str(props.uuid) + + # Collect weight metadata and create IPC handles + names = [] + dtype_names = [] + shapes = [] + ipc_handles = [] + + for name, tensor in iterator: + names.append(name) + dtype_names.append(str(tensor.dtype).split(".")[-1]) + shapes.append(list(tensor.shape)) + + # Create IPC handle for this weight tensor + # The tensor must remain in memory for IPC to work + weight = tensor.detach().contiguous() + ipc_handle = reduce_tensor(weight) + ipc_handles.append({gpu_uuid: ipc_handle}) + + # Send weights based on mode + if args.mode == "ray": + # Ray mode: send via Ray RPC + import ray + + update_info = asdict( + IPCWeightTransferUpdateInfo( + names=names, + dtype_names=dtype_names, + shapes=shapes, + ipc_handles=ipc_handles, + ) + ) + ray.get( + args.llm_handle.update_weights.remote(dict(update_info=update_info)) + ) + elif args.mode == "http": + # HTTP mode: send via HTTP POST with pickled handles + # Pickle and base64 encode IPC handles for HTTP transmission + pickled_handles = base64.b64encode(pickle.dumps(ipc_handles)).decode( + "utf-8" + ) + + url = f"{args.url}/update_weights" + payload = { + "update_info": { + "names": names, + "dtype_names": dtype_names, + "shapes": shapes, + "ipc_handles_pickled": pickled_handles, + } + } + response = requests.post(url, json=payload, timeout=300) + response.raise_for_status() diff --git a/vllm/distributed/weight_transfer/nccl_engine.py b/vllm/distributed/weight_transfer/nccl_engine.py index 5c90198..e8a1091 100644 --- a/vllm/distributed/weight_transfer/nccl_engine.py +++ b/vllm/distributed/weight_transfer/nccl_engine.py @@ -35,6 +35,32 @@ class NCCLWeightTransferInitInfo(WeightTransferInitInfo): world_size: int +@dataclass +class NCCLTrainerSendWeightsArgs: + """Arguments for NCCL trainer_send_weights method.""" + + group: Any + """Process group (PyNcclCommunicator) for NCCL communication.""" + src: int = 0 + """Source rank (default 0, trainer is typically rank 0).""" + post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] | None = None + """Optional function to apply to each (name, tensor) pair before broadcasting. + If None, extracts just the tensor.""" + packed: bool = False + """Whether to use packed tensor broadcasting for efficiency. + When True, multiple tensors are batched together before broadcasting + to reduce NCCL communication overhead.""" + stream: torch.cuda.Stream | None = None + """CUDA stream to use for broadcasting if packed is False. + If packed is True, new streams will be created for each buffer.""" + packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES + """Size in bytes for each packed tensor buffer. + Must match the value used in NCCLWeightTransferUpdateInfo.""" + packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS + """Number of buffers for double/triple buffering during packed transfer. 
+ Must match the value used in NCCLWeightTransferUpdateInfo.""" + + @dataclass class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): """Update info for NCCL weight transfer backend.""" @@ -47,7 +73,7 @@ class NCCLWeightTransferUpdateInfo(WeightTransferUpdateInfo): When True, multiple tensors are batched together before broadcasting to reduce NCCL communication overhead.""" packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES - """Size in bytes for each packed tensor buffer. Default is 1GB. + """Size in bytes for each packed tensor buffer. Both producer and consumer must use the same value.""" packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS """Number of buffers for double/triple buffering during packed transfer. @@ -186,47 +212,38 @@ class NCCLWeightTransferEngine( @staticmethod def trainer_send_weights( iterator: Iterator[tuple[str, torch.Tensor]], - group: Any, - src: int = 0, - post_iter_func: Callable[[tuple[str, torch.Tensor]], torch.Tensor] - | None = None, - packed: bool = False, - stream: torch.cuda.Stream | None = None, - packed_buffer_size_bytes: int = DEFAULT_PACKED_BUFFER_SIZE_BYTES, - packed_num_buffers: int = DEFAULT_PACKED_NUM_BUFFERS, + trainer_args: dict[str, Any] | NCCLTrainerSendWeightsArgs, ) -> None: """Broadcast weights from trainer to vLLM workers. Args: iterator: Iterator of model parameters. Returns (name, tensor) tuples - group: Process group (PyNcclCommunicator) - src: Source rank (default 0, trainer is typically rank 0) - post_iter_func: Optional function to apply to each (name, tensor) pair - before broadcasting. If None, extracts just the tensor. - packed: Whether to use packed tensor broadcasting for efficiency. - When True, multiple tensors are batched together before - broadcasting to reduce NCCL communication overhead. - stream: CUDA stream to use for broadcasting if packed is False. - If packed is True, new streams will be created for each buffer. - packed_buffer_size_bytes: Size in bytes for each packed tensor buffer. - Must match the value used in NCCLWeightTransferUpdateInfo. - packed_num_buffers: Number of buffers for double/triple buffering. - Must match the value used in NCCLWeightTransferUpdateInfo. + trainer_args: Dictionary or NCCLTrainerSendWeightsArgs instance containing + NCCL-specific arguments. If a dict, should contain keys from + NCCLTrainerSendWeightsArgs. Example: >>> from vllm.distributed.weight_transfer.nccl_engine import ( ... NCCLWeightTransferEngine, + ... NCCLTrainerSendWeightsArgs, ... ) >>> param_iter = ((n, p) for n, p in model.named_parameters()) - >>> NCCLWeightTransferEngine.trainer_send_weights( - ... param_iter, group, packed=True - ... 
) + >>> args = NCCLTrainerSendWeightsArgs(group=group, packed=True) + >>> NCCLWeightTransferEngine.trainer_send_weights(param_iter, args) """ - if post_iter_func is None: + # Parse trainer args - accept either dict or dataclass instance + if isinstance(trainer_args, dict): + args = NCCLTrainerSendWeightsArgs(**trainer_args) + else: + args = trainer_args + + if args.post_iter_func is None: # Default: extract just the tensor from (name, tensor) tuple post_iter_func = lambda x: x[1] + else: + post_iter_func = args.post_iter_func - if packed: + if args.packed: # Use packed tensor broadcasting for efficiency from vllm.distributed.weight_transfer.packed_tensor import ( packed_broadcast_producer, @@ -234,18 +251,20 @@ class NCCLWeightTransferEngine( packed_broadcast_producer( iterator=iterator, - group=group, - src=src, + group=args.group, + src=args.src, post_iter_func=post_iter_func, - buffer_size_bytes=packed_buffer_size_bytes, - num_buffers=packed_num_buffers, + buffer_size_bytes=args.packed_buffer_size_bytes, + num_buffers=args.packed_num_buffers, ) else: # Use simple one-by-one broadcasting for item in iterator: tensor = post_iter_func(item) - group.broadcast( - tensor, src=src, stream=stream or torch.cuda.current_stream() + args.group.broadcast( + tensor, + src=args.src, + stream=args.stream or torch.cuda.current_stream(), ) @staticmethod diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 8c48661..c4d3c03 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -419,6 +419,7 @@ class EngineArgs: enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel moe_backend: MoEBackend = KernelConfig.moe_backend all2all_backend: All2AllBackend = ParallelConfig.all2all_backend + enable_elastic_ep: bool = ParallelConfig.enable_elastic_ep enable_dbo: bool = ParallelConfig.enable_dbo ubatch_size: int = ParallelConfig.ubatch_size dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold @@ -896,6 +897,9 @@ class EngineArgs: "--ubatch-size", **parallel_kwargs["ubatch_size"], ) + parallel_group.add_argument( + "--enable-elastic-ep", **parallel_kwargs["enable_elastic_ep"] + ) parallel_group.add_argument( "--dbo-decode-token-threshold", **parallel_kwargs["dbo_decode_token_threshold"], @@ -1321,6 +1325,7 @@ class EngineArgs: "launched vLLM.", self.seed, ) + return ModelConfig( model=self.model, model_weights=self.model_weights, @@ -1697,6 +1702,7 @@ class EngineArgs: is_moe_model=model_config.is_moe, enable_expert_parallel=self.enable_expert_parallel, all2all_backend=self.all2all_backend, + enable_elastic_ep=self.enable_elastic_ep, enable_dbo=self.enable_dbo, ubatch_size=self.ubatch_size, dbo_decode_token_threshold=self.dbo_decode_token_threshold, @@ -1905,6 +1911,7 @@ class EngineArgs: performance_mode=self.performance_mode, weight_transfer_config=self.weight_transfer_config, ) + return config def _check_feature_supported(self): @@ -2074,20 +2081,19 @@ class EngineArgs: ) # Disable chunked prefill and prefix caching for: - # POWER (ppc64le)/RISCV CPUs in V1 + # RISCV CPUs in V1 if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( - CpuArchEnum.POWERPC, CpuArchEnum.RISCV, ): logger.info( - "Chunked prefill is not supported for POWER, " - "and RISC-V CPUs; " + "Chunked prefill is not supported for" + "RISC-V CPUs; " "disabling it for V1 backend." 
) self.enable_chunked_prefill = False logger.info( - "Prefix caching is not supported for POWER, " - "and RISC-V CPUs; " + "Prefix caching is not supported for " + "RISC-V CPUs; " "disabling it for V1 backend." ) self.enable_prefix_caching = False @@ -2181,14 +2187,10 @@ class AsyncEngineArgs(EngineArgs): "--enable-log-requests", action=argparse.BooleanOptionalAction, default=AsyncEngineArgs.enable_log_requests, - help="Enable logging requests.", - ) - parser.add_argument( - "--disable-log-requests", - action=argparse.BooleanOptionalAction, - default=not AsyncEngineArgs.enable_log_requests, - help="[DEPRECATED] Disable logging requests.", - deprecated=True, + help="Enable logging request information, dependant on log level:\n" + "- INFO: Request ID, parameters and LoRA request.\n" + "- DEBUG: Prompt inputs (e.g: text, token IDs).\n" + "You can set the minimum log level via `VLLM_LOGGING_LEVEL`.", ) current_platform.pre_register_and_update(parser) return parser diff --git a/vllm/entrypoints/anthropic/api_router.py b/vllm/entrypoints/anthropic/api_router.py index 1494dd7..2b65fff 100644 --- a/vllm/entrypoints/anthropic/api_router.py +++ b/vllm/entrypoints/anthropic/api_router.py @@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, FastAPI, Request from fastapi.responses import JSONResponse, StreamingResponse from vllm.entrypoints.anthropic.protocol import ( + AnthropicCountTokensRequest, + AnthropicCountTokensResponse, AnthropicError, AnthropicErrorResponse, AnthropicMessagesRequest, @@ -31,6 +33,18 @@ def messages(request: Request) -> AnthropicServingMessages: return request.app.state.anthropic_serving_messages +def translate_error_response(response: ErrorResponse) -> JSONResponse: + anthropic_error = AnthropicErrorResponse( + error=AnthropicError( + type=response.error.type, + message=response.error.message, + ) + ) + return JSONResponse( + status_code=response.error.code, content=anthropic_error.model_dump() + ) + + @router.post( "/v1/messages", dependencies=[Depends(validate_json_request)], @@ -44,17 +58,6 @@ def messages(request: Request) -> AnthropicServingMessages: @with_cancellation @load_aware_call async def create_messages(request: AnthropicMessagesRequest, raw_request: Request): - def translate_error_response(response: ErrorResponse) -> JSONResponse: - anthropic_error = AnthropicErrorResponse( - error=AnthropicError( - type=response.error.type, - message=response.error.message, - ) - ) - return JSONResponse( - status_code=response.error.code, content=anthropic_error.model_dump() - ) - handler = messages(raw_request) if handler is None: base_server = raw_request.app.state.openai_serving_tokenization @@ -88,5 +91,46 @@ async def create_messages(request: AnthropicMessagesRequest, raw_request: Reques return StreamingResponse(content=generator, media_type="text/event-stream") +@router.post( + "/v1/messages/count_tokens", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"model": AnthropicCountTokensResponse}, + HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse}, + }, +) +@load_aware_call +@with_cancellation +async def count_tokens(request: AnthropicCountTokensRequest, raw_request: Request): + handler = messages(raw_request) + if handler is None: + base_server = raw_request.app.state.openai_serving_tokenization + error = base_server.create_error_response( + message="The model does not support 
Messages API" + ) + return translate_error_response(error) + + try: + response = await handler.count_tokens(request, raw_request) + except Exception as e: + logger.exception("Error in count_tokens: %s", e) + return JSONResponse( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + content=AnthropicErrorResponse( + error=AnthropicError( + type="internal_error", + message=str(e), + ) + ).model_dump(), + ) + + if isinstance(response, ErrorResponse): + return translate_error_response(response) + + return JSONResponse(content=response.model_dump(exclude_none=True)) + + def attach_router(app: FastAPI): app.include_router(router) diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py index af9430e..c541db5 100644 --- a/vllm/entrypoints/anthropic/protocol.py +++ b/vllm/entrypoints/anthropic/protocol.py @@ -34,7 +34,7 @@ class AnthropicUsage(BaseModel): class AnthropicContentBlock(BaseModel): """Content block in message""" - type: Literal["text", "image", "tool_use", "tool_result"] + type: Literal["text", "image", "tool_use", "tool_result", "thinking"] text: str | None = None # For image content source: dict[str, Any] | None = None @@ -45,6 +45,9 @@ class AnthropicContentBlock(BaseModel): input: dict[str, Any] | None = None content: str | list[dict[str, Any]] | None = None is_error: bool | None = None + # For thinking content + thinking: str | None = None + signature: str | None = None class AnthropicMessage(BaseModel): @@ -74,7 +77,7 @@ class AnthropicTool(BaseModel): class AnthropicToolChoice(BaseModel): """Tool Choice definition""" - type: Literal["auto", "any", "tool"] + type: Literal["auto", "any", "tool", "none"] name: str | None = None @model_validator(mode="after") @@ -118,9 +121,14 @@ class AnthropicMessagesRequest(BaseModel): class AnthropicDelta(BaseModel): """Delta for streaming responses""" - type: Literal["text_delta", "input_json_delta"] | None = None + type: ( + Literal["text_delta", "input_json_delta", "thinking_delta", "signature_delta"] + | None + ) = None text: str | None = None + thinking: str | None = None partial_json: str | None = None + signature: str | None = None # Message delta stop_reason: ( @@ -167,3 +175,33 @@ class AnthropicMessagesResponse(BaseModel): def model_post_init(self, __context): if not self.id: self.id = f"msg_{int(time.time() * 1000)}" + + +class AnthropicContextManagement(BaseModel): + """Context management information for token counting.""" + + original_input_tokens: int + + +class AnthropicCountTokensRequest(BaseModel): + """Anthropic messages.count_tokens request""" + + model: str + messages: list[AnthropicMessage] + system: str | list[AnthropicContentBlock] | None = None + tool_choice: AnthropicToolChoice | None = None + tools: list[AnthropicTool] | None = None + + @field_validator("model") + @classmethod + def validate_model(cls, v): + if not v: + raise ValueError("Model is required") + return v + + +class AnthropicCountTokensResponse(BaseModel): + """Anthropic messages.count_tokens response""" + + input_tokens: int + context_management: AnthropicContextManagement | None = None diff --git a/vllm/entrypoints/anthropic/serving.py b/vllm/entrypoints/anthropic/serving.py index dc03731..85232e9 100644 --- a/vllm/entrypoints/anthropic/serving.py +++ b/vllm/entrypoints/anthropic/serving.py @@ -8,6 +8,7 @@ import json import logging import time +import uuid from collections.abc import AsyncGenerator from typing import Any @@ -16,6 +17,9 @@ from fastapi import Request from vllm.engine.protocol import EngineClient from 
vllm.entrypoints.anthropic.protocol import ( AnthropicContentBlock, + AnthropicContextManagement, + AnthropicCountTokensRequest, + AnthropicCountTokensResponse, AnthropicDelta, AnthropicError, AnthropicMessagesRequest, @@ -85,94 +89,225 @@ class AnthropicServingMessages(OpenAIServingChat): "tool_calls": "tool_use", } + @staticmethod + def _convert_image_source_to_url(source: dict[str, Any]) -> str: + """Convert an Anthropic image source to an OpenAI-compatible URL. + + Anthropic supports two image source types: + - base64: {"type": "base64", "media_type": "image/jpeg", "data": "..."} + - url: {"type": "url", "url": "https://..."} + + For base64 sources, this constructs a proper data URI that + downstream processors (e.g. vLLM's media connector) can handle. + """ + source_type = source.get("type") + if source_type == "url": + return source.get("url", "") + # Default to base64 processing if type is "base64" + # or missing, ensuring a proper data URI is always + # constructed for non-URL sources. + media_type = source.get("media_type", "image/jpeg") + data = source.get("data", "") + return f"data:{media_type};base64,{data}" + + @classmethod def _convert_anthropic_to_openai_request( - self, anthropic_request: AnthropicMessagesRequest + cls, anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest ) -> ChatCompletionRequest: """Convert Anthropic message format to OpenAI format""" - openai_messages = [] + openai_messages: list[dict[str, Any]] = [] - # Add system message if provided - if anthropic_request.system: - if isinstance(anthropic_request.system, str): - openai_messages.append( - {"role": "system", "content": anthropic_request.system} - ) - else: - system_prompt = "" - for block in anthropic_request.system: - if block.type == "text" and block.text: - system_prompt += block.text - openai_messages.append({"role": "system", "content": system_prompt}) + cls._convert_system_message(anthropic_request, openai_messages) + cls._convert_messages(anthropic_request.messages, openai_messages) + req = cls._build_base_request(anthropic_request, openai_messages) + cls._handle_streaming_options(req, anthropic_request) + cls._convert_tool_choice(anthropic_request, req) + cls._convert_tools(anthropic_request, req) + return req - for msg in anthropic_request.messages: + @classmethod + def _convert_system_message( + cls, + anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest, + openai_messages: list[dict[str, Any]], + ) -> None: + """Convert Anthropic system message to OpenAI format""" + if not anthropic_request.system: + return + + if isinstance(anthropic_request.system, str): + openai_messages.append( + {"role": "system", "content": anthropic_request.system} + ) + else: + system_prompt = "" + for block in anthropic_request.system: + if block.type == "text" and block.text: + system_prompt += block.text + openai_messages.append({"role": "system", "content": system_prompt}) + + @classmethod + def _convert_messages( + cls, messages: list, openai_messages: list[dict[str, Any]] + ) -> None: + """Convert Anthropic messages to OpenAI format""" + for msg in messages: openai_msg: dict[str, Any] = {"role": msg.role} # type: ignore + if isinstance(msg.content, str): openai_msg["content"] = msg.content else: - # Handle complex content blocks - content_parts: list[dict[str, Any]] = [] - tool_calls: list[dict[str, Any]] = [] - - for block in msg.content: - if block.type == "text" and block.text: - content_parts.append({"type": "text", "text": block.text}) - elif block.type == "image" 
and block.source: - content_parts.append( - { - "type": "image_url", - "image_url": {"url": block.source.get("data", "")}, - } - ) - elif block.type == "tool_use": - # Convert tool use to function call format - tool_call = { - "id": block.id or f"call_{int(time.time())}", - "type": "function", - "function": { - "name": block.name or "", - "arguments": json.dumps(block.input or {}), - }, - } - tool_calls.append(tool_call) - elif block.type == "tool_result": - if msg.role == "user": - openai_messages.append( - { - "role": "tool", - "tool_call_id": block.tool_use_id or "", - "content": str(block.content) - if block.content - else "", - } - ) - else: - # Assistant tool result becomes regular text - tool_result_text = ( - str(block.content) if block.content else "" - ) - content_parts.append( - { - "type": "text", - "text": f"Tool result: {tool_result_text}", - } - ) - - # Add tool calls to the message if any - if tool_calls: - openai_msg["tool_calls"] = tool_calls # type: ignore - - # Add content parts if any - if content_parts: - if len(content_parts) == 1 and content_parts[0]["type"] == "text": - openai_msg["content"] = content_parts[0]["text"] - else: - openai_msg["content"] = content_parts # type: ignore - elif not tool_calls: - continue + cls._convert_message_content(msg, openai_msg, openai_messages) openai_messages.append(openai_msg) - req = ChatCompletionRequest( + @classmethod + def _convert_message_content( + cls, + msg, + openai_msg: dict[str, Any], + openai_messages: list[dict[str, Any]], + ) -> None: + """Convert complex message content blocks""" + content_parts: list[dict[str, Any]] = [] + tool_calls: list[dict[str, Any]] = [] + reasoning_parts: list[str] = [] + + for block in msg.content: + cls._convert_block( + block, + msg.role, + content_parts, + tool_calls, + reasoning_parts, + openai_messages, + ) + + if reasoning_parts: + openai_msg["reasoning"] = "".join(reasoning_parts) + + if tool_calls: + openai_msg["tool_calls"] = tool_calls # type: ignore + + if content_parts: + if len(content_parts) == 1 and content_parts[0]["type"] == "text": + openai_msg["content"] = content_parts[0]["text"] + else: + openai_msg["content"] = content_parts # type: ignore + elif not tool_calls and not reasoning_parts: + return + + @classmethod + def _convert_block( + cls, + block, + role: str, + content_parts: list[dict[str, Any]], + tool_calls: list[dict[str, Any]], + reasoning_parts: list[str], + openai_messages: list[dict[str, Any]], + ) -> None: + """Convert individual content block""" + if block.type == "text" and block.text: + content_parts.append({"type": "text", "text": block.text}) + elif block.type == "image" and block.source: + image_url = cls._convert_image_source_to_url(block.source) + content_parts.append({"type": "image_url", "image_url": {"url": image_url}}) + elif block.type == "thinking" and block.thinking is not None: + reasoning_parts.append(block.thinking) + elif block.type == "tool_use": + cls._convert_tool_use_block(block, tool_calls) + elif block.type == "tool_result": + cls._convert_tool_result_block(block, role, openai_messages, content_parts) + + @classmethod + def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None: + """Convert tool_use block to OpenAI function call format""" + tool_call = { + "id": block.id or f"call_{int(time.time())}", + "type": "function", + "function": { + "name": block.name or "", + "arguments": json.dumps(block.input or {}), + }, + } + tool_calls.append(tool_call) + + @classmethod + def _convert_tool_result_block( + cls, + 
block, + role: str, + openai_messages: list[dict[str, Any]], + content_parts: list[dict[str, Any]], + ) -> None: + """Convert tool_result block to OpenAI format""" + if role == "user": + cls._convert_user_tool_result(block, openai_messages) + else: + tool_result_text = str(block.content) if block.content else "" + content_parts.append( + {"type": "text", "text": f"Tool result: {tool_result_text}"} + ) + + @classmethod + def _convert_user_tool_result( + cls, block, openai_messages: list[dict[str, Any]] + ) -> None: + """Convert user tool_result with text and image support""" + tool_text = "" + tool_image_urls: list[str] = [] + + if isinstance(block.content, str): + tool_text = block.content + elif isinstance(block.content, list): + text_parts: list[str] = [] + for item in block.content: + if not isinstance(item, dict): + continue + item_type = item.get("type") + if item_type == "text": + text_parts.append(item.get("text", "")) + elif item_type == "image": + source = item.get("source", {}) + url = cls._convert_image_source_to_url(source) + if url: + tool_image_urls.append(url) + tool_text = "\n".join(text_parts) + + openai_messages.append( + { + "role": "tool", + "tool_call_id": block.tool_use_id or "", + "content": tool_text or "", + } + ) + + if tool_image_urls: + openai_messages.append( + { + "role": "user", + "content": [ # type: ignore[dict-item] + {"type": "image_url", "image_url": {"url": img}} + for img in tool_image_urls + ], + } + ) + + @classmethod + def _build_base_request( + cls, + anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest, + openai_messages: list[dict[str, Any]], + ) -> ChatCompletionRequest: + """Build base ChatCompletionRequest""" + if isinstance(anthropic_request, AnthropicCountTokensRequest): + return ChatCompletionRequest( + model=anthropic_request.model, + messages=openai_messages, + ) + + return ChatCompletionRequest( model=anthropic_request.model, messages=openai_messages, max_tokens=anthropic_request.max_tokens, @@ -183,19 +318,40 @@ class AnthropicServingMessages(OpenAIServingChat): top_k=anthropic_request.top_k, ) + @classmethod + def _handle_streaming_options( + cls, + req: ChatCompletionRequest, + anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest, + ) -> None: + """Handle streaming configuration""" + if isinstance(anthropic_request, AnthropicCountTokensRequest): + return if anthropic_request.stream: req.stream = anthropic_request.stream - req.stream_options = StreamOptions.validate( + req.stream_options = StreamOptions.model_validate( {"include_usage": True, "continuous_usage_stats": True} ) + @classmethod + def _convert_tool_choice( + cls, + anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest, + req: ChatCompletionRequest, + ) -> None: + """Convert Anthropic tool_choice to OpenAI format""" if anthropic_request.tool_choice is None: req.tool_choice = None - elif anthropic_request.tool_choice.type == "auto": + return + + tool_choice_type = anthropic_request.tool_choice.type + if tool_choice_type == "auto": req.tool_choice = "auto" - elif anthropic_request.tool_choice.type == "any": + elif tool_choice_type == "any": req.tool_choice = "required" - elif anthropic_request.tool_choice.type == "tool": + elif tool_choice_type == "none": + req.tool_choice = "none" + elif tool_choice_type == "tool": req.tool_choice = ChatCompletionNamedToolChoiceParam.model_validate( { "type": "function", @@ -203,9 +359,17 @@ class AnthropicServingMessages(OpenAIServingChat): } ) - tools = [] + @classmethod + def 
_convert_tools( + cls, + anthropic_request: AnthropicMessagesRequest | AnthropicCountTokensRequest, + req: ChatCompletionRequest, + ) -> None: + """Convert Anthropic tools to OpenAI format""" if anthropic_request.tools is None: - return req + return + + tools = [] for tool in anthropic_request.tools: tools.append( ChatCompletionToolsParam.model_validate( @@ -219,10 +383,10 @@ class AnthropicServingMessages(OpenAIServingChat): } ) ) + if req.tool_choice is None: req.tool_choice = "auto" req.tools = tools - return req async def create_messages( self, @@ -263,23 +427,32 @@ class AnthropicServingMessages(OpenAIServingChat): output_tokens=generator.usage.completion_tokens, ), ) - if generator.choices[0].finish_reason == "stop": + choice = generator.choices[0] + if choice.finish_reason == "stop": result.stop_reason = "end_turn" - elif generator.choices[0].finish_reason == "length": + elif choice.finish_reason == "length": result.stop_reason = "max_tokens" - elif generator.choices[0].finish_reason == "tool_calls": + elif choice.finish_reason == "tool_calls": result.stop_reason = "tool_use" - content: list[AnthropicContentBlock] = [ - AnthropicContentBlock( - type="text", - text=generator.choices[0].message.content - if generator.choices[0].message.content - else "", + content: list[AnthropicContentBlock] = [] + if choice.message.reasoning: + content.append( + AnthropicContentBlock( + type="thinking", + thinking=choice.message.reasoning, + signature=uuid.uuid4().hex, + ) + ) + if choice.message.content: + content.append( + AnthropicContentBlock( + type="text", + text=choice.message.content, + ) ) - ] - for tool_call in generator.choices[0].message.tool_calls: + for tool_call in choice.message.tool_calls: anthropic_tool_call = AnthropicContentBlock( type="tool_use", id=tool_call.id, @@ -297,10 +470,85 @@ class AnthropicServingMessages(OpenAIServingChat): generator: AsyncGenerator[str, None], ) -> AsyncGenerator[str, None]: try: + + class _ActiveBlockState: + def __init__(self) -> None: + self.content_block_index = 0 + self.block_type: str | None = None + self.block_index: int | None = None + self.block_signature: str | None = None + self.signature_emitted: bool = False + self.tool_use_id: str | None = None + + def reset(self) -> None: + self.block_type = None + self.block_index = None + self.block_signature = None + self.signature_emitted = False + self.tool_use_id = None + + def start(self, block: AnthropicContentBlock) -> None: + self.block_type = block.type + self.block_index = self.content_block_index + if block.type == "thinking": + self.block_signature = uuid.uuid4().hex + self.signature_emitted = False + self.tool_use_id = None + elif block.type == "tool_use": + self.block_signature = None + self.signature_emitted = True + self.tool_use_id = block.id + else: + self.block_signature = None + self.signature_emitted = True + self.tool_use_id = None + first_item = True finish_reason = None - content_block_index = 0 - content_block_started = False + state = _ActiveBlockState() + # Map from tool call index to tool_use_id + tool_index_to_id: dict[int, str] = {} + + def stop_active_block(): + events: list[str] = [] + if state.block_type is None: + return events + if ( + state.block_type == "thinking" + and state.block_signature is not None + and not state.signature_emitted + ): + chunk = AnthropicStreamEvent( + index=state.block_index, + type="content_block_delta", + delta=AnthropicDelta( + type="signature_delta", + signature=state.block_signature, + ), + ) + data = 
chunk.model_dump_json(exclude_unset=True) + events.append(wrap_data_with_event(data, "content_block_delta")) + state.signature_emitted = True + stop_chunk = AnthropicStreamEvent( + index=state.block_index, + type="content_block_stop", + ) + data = stop_chunk.model_dump_json(exclude_unset=True) + events.append(wrap_data_with_event(data, "content_block_stop")) + state.reset() + state.content_block_index += 1 + return events + + def start_block(block: AnthropicContentBlock): + chunk = AnthropicStreamEvent( + index=state.content_block_index, + type="content_block_start", + content_block=block, + ) + data = chunk.model_dump_json(exclude_unset=True) + event = wrap_data_with_event(data, "content_block_start") + state.start(block) + return event async for item in generator: if item.startswith("data:"): @@ -326,6 +574,8 @@ class AnthropicServingMessages(OpenAIServingChat): id=origin_chunk.id, content=[], model=origin_chunk.model, + stop_reason=None, + stop_sequence=None, usage=AnthropicUsage( input_tokens=origin_chunk.usage.prompt_tokens if origin_chunk.usage @@ -341,13 +591,8 @@ class AnthropicServingMessages(OpenAIServingChat): # last chunk including usage info if len(origin_chunk.choices) == 0: - if content_block_started: - stop_chunk = AnthropicStreamEvent( - index=content_block_index, - type="content_block_stop", - ) - data = stop_chunk.model_dump_json(exclude_unset=True) - yield wrap_data_with_event(data, "content_block_stop") + for event in stop_active_block(): + yield event stop_reason = self.stop_reason_map.get( finish_reason or "stop" ) @@ -369,96 +614,139 @@ class AnthropicServingMessages(OpenAIServingChat): if origin_chunk.choices[0].finish_reason is not None: finish_reason = origin_chunk.choices[0].finish_reason - continue - - # content - if origin_chunk.choices[0].delta.content is not None: - if not content_block_started: - chunk = AnthropicStreamEvent( - index=content_block_index, - type="content_block_start", - content_block=AnthropicContentBlock( - type="text", text="" - ), - ) - data = chunk.model_dump_json(exclude_unset=True) - yield wrap_data_with_event(data, "content_block_start") - content_block_started = True - - if origin_chunk.choices[0].delta.content == "": - continue - chunk = AnthropicStreamEvent( - index=content_block_index, - type="content_block_delta", - delta=AnthropicDelta( - type="text_delta", - text=origin_chunk.choices[0].delta.content, - ), - ) - data = chunk.model_dump_json(exclude_unset=True) - yield wrap_data_with_event(data, "content_block_delta") - continue - - # tool calls - elif len(origin_chunk.choices[0].delta.tool_calls) > 0: - tool_call = origin_chunk.choices[0].delta.tool_calls[0] - if tool_call.id is not None: - if content_block_started: - stop_chunk = AnthropicStreamEvent( - index=content_block_index, - type="content_block_stop", - ) - data = stop_chunk.model_dump_json( - exclude_unset=True - ) - yield wrap_data_with_event( - data, "content_block_stop" - ) - content_block_started = False - content_block_index += 1 - - chunk = AnthropicStreamEvent( - index=content_block_index, - type="content_block_start", - content_block=AnthropicContentBlock( - type="tool_use", - id=tool_call.id, - name=tool_call.function.name - if tool_call.function - else None, - input={}, - ), - ) - data = chunk.model_dump_json(exclude_unset=True) - yield wrap_data_with_event(data, "content_block_start") - content_block_started = True - if tool_call.function and tool_call.function.arguments: - chunk = AnthropicStreamEvent( - index=content_block_index, - 
type="content_block_delta", - delta=AnthropicDelta( - type="input_json_delta", - partial_json=tool_call.function.arguments, - ), - ) - data = chunk.model_dump_json(exclude_unset=True) - yield wrap_data_with_event( - data, "content_block_delta" - ) + # continue + # thinking / text content + reasoning_delta = origin_chunk.choices[0].delta.reasoning + if reasoning_delta is not None: + if reasoning_delta == "": + pass else: + if state.block_type != "thinking": + for event in stop_active_block(): + yield event + start_event = start_block( + AnthropicContentBlock( + type="thinking", thinking="" + ) + ) + yield start_event chunk = AnthropicStreamEvent( - index=content_block_index, + index=( + state.block_index + if state.block_index is not None + else state.content_block_index + ), type="content_block_delta", delta=AnthropicDelta( - type="input_json_delta", - partial_json=tool_call.function.arguments - if tool_call.function - else None, + type="thinking_delta", + thinking=reasoning_delta, ), ) data = chunk.model_dump_json(exclude_unset=True) yield wrap_data_with_event(data, "content_block_delta") + + if origin_chunk.choices[0].delta.content is not None: + if origin_chunk.choices[0].delta.content == "": + pass + else: + if state.block_type != "text": + for event in stop_active_block(): + yield event + start_event = start_block( + AnthropicContentBlock(type="text", text="") + ) + yield start_event + chunk = AnthropicStreamEvent( + index=( + state.block_index + if state.block_index is not None + else state.content_block_index + ), + type="content_block_delta", + delta=AnthropicDelta( + type="text_delta", + text=origin_chunk.choices[0].delta.content, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event(data, "content_block_delta") + + # tool calls - process all tool calls in the delta + if len(origin_chunk.choices[0].delta.tool_calls) > 0: + for tool_call in origin_chunk.choices[0].delta.tool_calls: + if tool_call.id is not None: + # Update mapping for incremental updates + tool_index_to_id[tool_call.index] = tool_call.id + # Only create new block if different tool call + # AND has a name + tool_name = ( + tool_call.function.name + if tool_call.function + else None + ) + if ( + state.tool_use_id != tool_call.id + and tool_name is not None + ): + for event in stop_active_block(): + yield event + start_event = start_block( + AnthropicContentBlock( + type="tool_use", + id=tool_call.id, + name=tool_name, + input={}, + ) + ) + yield start_event + # Handle initial arguments if present + if ( + tool_call.function + and tool_call.function.arguments + and state.tool_use_id == tool_call.id + ): + chunk = AnthropicStreamEvent( + index=( + state.block_index + if state.block_index is not None + else state.content_block_index + ), + type="content_block_delta", + delta=AnthropicDelta( + type="input_json_delta", + partial_json=tool_call.function.arguments, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event( + data, "content_block_delta" + ) + else: + # Incremental update - use index to find tool_use_id + tool_use_id = tool_index_to_id.get(tool_call.index) + if ( + tool_use_id is not None + and tool_call.function + and tool_call.function.arguments + and state.tool_use_id == tool_use_id + ): + chunk = AnthropicStreamEvent( + index=( + state.block_index + if state.block_index is not None + else state.content_block_index + ), + type="content_block_delta", + delta=AnthropicDelta( + type="input_json_delta", + 
partial_json=tool_call.function.arguments, + ), + ) + data = chunk.model_dump_json(exclude_unset=True) + yield wrap_data_with_event( + data, "content_block_delta" + ) continue else: error_response = AnthropicStreamEvent( @@ -481,3 +769,31 @@ class AnthropicServingMessages(OpenAIServingChat): data = error_response.model_dump_json(exclude_unset=True) yield wrap_data_with_event(data, "error") yield "data: [DONE]\n\n" + + async def count_tokens( + self, + request: AnthropicCountTokensRequest, + raw_request: Request | None = None, + ) -> AnthropicCountTokensResponse | ErrorResponse: + """Implements Anthropic's messages.count_tokens endpoint.""" + chat_req = self._convert_anthropic_to_openai_request(request) + result = await self.render_chat_request(chat_req) + if isinstance(result, ErrorResponse): + return result + + _, engine_prompts = result + + input_tokens = sum( # type: ignore + len(prompt["prompt_token_ids"]) # type: ignore[typeddict-item, misc] + for prompt in engine_prompts + if "prompt_token_ids" in prompt + ) + + response = AnthropicCountTokensResponse( + input_tokens=input_tokens, + context_management=AnthropicContextManagement( + original_input_tokens=input_tokens + ), + ) + + return response diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c48d7be..1d10aa6 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -7,6 +7,7 @@ import warnings from abc import ABC, abstractmethod from collections import Counter, defaultdict from collections.abc import Awaitable, Callable, Iterable +from dataclasses import dataclass from functools import cached_property, lru_cache, partial from itertools import accumulate from pathlib import Path @@ -1024,6 +1025,13 @@ class AsyncMultiModalContentParser(BaseMultiModalContentParser): self._add_placeholder("video", placeholder) +@dataclass +class ChatTemplateConfig: + chat_template: str | None = None + chat_template_content_format: ChatTemplateContentFormatOption = "auto" + trust_request_chat_template: bool = False + + def validate_chat_template(chat_template: Path | str | None): """Raises if the provided chat template appears invalid.""" if chat_template is None: diff --git a/vllm/entrypoints/cli/__init__.py b/vllm/entrypoints/cli/__init__.py index ff5a928..704d94d 100644 --- a/vllm/entrypoints/cli/__init__.py +++ b/vllm/entrypoints/cli/__init__.py @@ -1,12 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.entrypoints.cli.benchmark.latency import BenchmarkLatencySubcommand +from vllm.entrypoints.cli.benchmark.mm_processor import ( + BenchmarkMMProcessorSubcommand, +) +from vllm.entrypoints.cli.benchmark.serve import BenchmarkServingSubcommand +from vllm.entrypoints.cli.benchmark.startup import BenchmarkStartupSubcommand +from vllm.entrypoints.cli.benchmark.sweep import BenchmarkSweepSubcommand +from vllm.entrypoints.cli.benchmark.throughput import BenchmarkThroughputSubcommand -# Keep this package init import-free. -# -# The `vllm` console script imports `vllm.entrypoints.cli.main`, which causes -# Python to import this package before loading the `main` submodule. -# Eagerly importing benchmark subcommands here makes every `vllm serve ...` -# startup depend on optional benchmark-only modules. -# -# Benchmark subcommands are loaded on demand in -# `vllm.entrypoints.cli.benchmark.main`. 
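Editor's note on the `cli/__init__.py` hunk above and the `cli/benchmark/main.py` hunk below: the bench CLI discovers its subcommands via `BenchmarkSubcommandBase.__subclasses__()`, which only sees classes whose defining modules have already been imported, hence the switch from on-demand loading to eager imports plus `__all__`. A minimal, self-contained sketch of that discovery pattern follows; the class names are illustrative and not part of vLLM.

    class SubcommandBase:
        name: str = "base"

    class LatencySubcommand(SubcommandBase):
        name = "latency"

    class ThroughputSubcommand(SubcommandBase):
        name = "throughput"

    def registered_subcommands() -> list[str]:
        # Only subclasses whose modules have been imported (here: defined
        # in the same file) appear in __subclasses__(), which is why the
        # package __init__ now imports every benchmark module eagerly.
        return [cls.name for cls in SubcommandBase.__subclasses__()]

    print(registered_subcommands())  # ['latency', 'throughput']
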
+__all__: list[str] = [ + "BenchmarkLatencySubcommand", + "BenchmarkMMProcessorSubcommand", + "BenchmarkServingSubcommand", + "BenchmarkStartupSubcommand", + "BenchmarkSweepSubcommand", + "BenchmarkThroughputSubcommand", +] diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py index ae490ba..48f34fc 100644 --- a/vllm/entrypoints/cli/benchmark/main.py +++ b/vllm/entrypoints/cli/benchmark/main.py @@ -2,8 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse -import importlib -import logging import typing from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase @@ -15,30 +13,6 @@ if typing.TYPE_CHECKING: else: FlexibleArgumentParser = argparse.ArgumentParser -logger = logging.getLogger(__name__) - - -def _load_benchmark_subcommands() -> None: - modules = [ - "vllm.entrypoints.cli.benchmark.latency", - "vllm.entrypoints.cli.benchmark.mm_processor", - "vllm.entrypoints.cli.benchmark.serve", - "vllm.entrypoints.cli.benchmark.startup", - "vllm.entrypoints.cli.benchmark.sweep", - "vllm.entrypoints.cli.benchmark.throughput", - ] - - for module_name in modules: - try: - importlib.import_module(module_name) - except ModuleNotFoundError as e: - logger.warning( - "Skipping benchmark subcommand module %s because an optional " - "dependency could not be imported: %r", - module_name, - e, - ) - class BenchmarkSubcommand(CLISubcommand): """The `bench` subcommand for the vLLM CLI.""" @@ -64,8 +38,6 @@ class BenchmarkSubcommand(CLISubcommand): ) bench_subparsers = bench_parser.add_subparsers(required=True, dest="bench_type") - _load_benchmark_subcommands() - for cmd_cls in BenchmarkSubcommandBase.__subclasses__(): cmd_subparser = bench_subparsers.add_parser( cmd_cls.name, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index c12cc7f..944fb88 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -220,6 +220,12 @@ def run_multi_api_server(args: argparse.Namespace): num_api_servers: int = args.api_server_count assert num_api_servers > 0 + if num_api_servers > 1 and getattr(args, "use_gpu_for_pooling_score", False): + # TODO(wentao): remove this once well tested + raise ValueError( + "--use-gpu-for-pooling-score cannot be used with api_server_count > 1 now" + ) + if num_api_servers > 1: setup_multiprocess_prometheus() @@ -246,8 +252,12 @@ def run_multi_api_server(args: argparse.Namespace): api_server_manager: APIServerProcessManager | None = None + from vllm.v1.engine.utils import get_engine_zmq_addresses + + addresses = get_engine_zmq_addresses(vllm_config, num_api_servers) + with launch_core_engines( - vllm_config, executor_class, log_stats, num_api_servers + vllm_config, executor_class, log_stats, addresses, num_api_servers ) as (local_engine_manager, coordinator, addresses): # Construct common args for the APIServerProcessManager up-front. 
api_server_manager_kwargs = dict( diff --git a/vllm/entrypoints/grpc_server.py b/vllm/entrypoints/grpc_server.py index 1fc3354..ec8f480 100644 --- a/vllm/entrypoints/grpc_server.py +++ b/vllm/entrypoints/grpc_server.py @@ -101,11 +101,15 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer): sampling_params = self._sampling_params_from_proto( request.sampling_params, stream=request.stream ) + tokenization_kwargs = self._tokenization_kwargs_from_proto( + request.sampling_params + ) async for output in self.async_llm.generate( prompt=prompt, sampling_params=sampling_params, request_id=request_id, + tokenization_kwargs=tokenization_kwargs, ): # Convert vLLM output to protobuf # For streaming, always send chunks @@ -308,9 +312,6 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer): seed=params.seed if params.HasField("seed") else None, include_stop_str_in_output=params.include_stop_str_in_output, logit_bias=dict(params.logit_bias) if params.logit_bias else None, - truncate_prompt_tokens=params.truncate_prompt_tokens - if params.HasField("truncate_prompt_tokens") - else None, structured_outputs=structured_outputs, # detokenize must be True if stop strings are used detokenize=bool(stop), @@ -319,6 +320,14 @@ class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer): else RequestOutputKind.FINAL_ONLY, ) + @staticmethod + def _tokenization_kwargs_from_proto( + params: vllm_engine_pb2.SamplingParams, + ) -> dict[str, int] | None: + if params.HasField("truncate_prompt_tokens"): + return {"truncate_prompt_tokens": params.truncate_prompt_tokens} + return None + @staticmethod def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse: """ diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ee78d4d..d5a51a6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools -import warnings from collections.abc import Callable, Iterable, Sequence +from pathlib import Path from typing import TYPE_CHECKING, Any import cloudpickle @@ -41,8 +41,11 @@ from vllm.distributed.weight_transfer.base import ( from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.chat_utils import ( ChatCompletionMessageParam, + ChatTemplateConfig, ChatTemplateContentFormatOption, + load_chat_template, ) +from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors from vllm.entrypoints.pooling.score.utils import ( ScoreData, ScoreMultiModalParam, @@ -146,6 +149,7 @@ class LLM: a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a branch name, a tag name, or a commit id. + chat_template: The chat template to apply. seed: The seed to initialize the random number generator for sampling. gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache. 
Higher @@ -233,6 +237,7 @@ class LLM: quantization: QuantizationMethods | None = None, revision: str | None = None, tokenizer_revision: str | None = None, + chat_template: Path | str | None = None, seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: float = 4, @@ -385,9 +390,16 @@ class LLM: self.model_config = self.llm_engine.model_config self.renderer = self.llm_engine.renderer + self.chat_template = load_chat_template(chat_template) self.io_processor = self.llm_engine.io_processor self.input_processor = self.llm_engine.input_processor - + self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template) + self.init_pooling_io_processors = init_pooling_io_processors( + supported_tasks=supported_tasks, + model_config=self.model_config, + renderer=self.renderer, + chat_template_config=self.chat_template_config, + ) # Cache for __repr__ to avoid repeated collective_rpc calls self._cached_repr: str | None = None @@ -1030,7 +1042,6 @@ class LLM: prompts: PromptType | Sequence[PromptType] | DataPrompt, pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, *, - truncate_prompt_tokens: int | None = None, use_tqdm: bool | Callable[..., tqdm] = True, lora_request: list[LoRARequest] | LoRARequest | None = None, pooling_task: PoolingTask | None = None, @@ -1088,21 +1099,7 @@ class LLM: "pooling model." ) - if truncate_prompt_tokens is not None: - warnings.warn( - "The `truncate_prompt_tokens` parameter in `LLM.encode()` " - "is deprecated and will be removed in v0.16. " - "Please pass it via `tokenization_kwargs` instead.", - DeprecationWarning, - stacklevel=2, - ) - - tokenization_kwargs = merge_kwargs( - tokenization_kwargs, - dict(truncate_prompt_tokens=truncate_prompt_tokens), - ) - - if use_io_processor := (isinstance(prompts, dict) and "data" in prompts): + if isinstance(prompts, dict) and "data" in prompts: if self.io_processor is None: raise ValueError( "No IOProcessor plugin installed. Please refer " @@ -1136,6 +1133,31 @@ class LLM: for p in params_seq: if p.task is None: p.task = "plugin" + + outputs = self._run_completion( + prompts=prompts_seq, + params=params_seq, + output_type=PoolingRequestOutput, + use_tqdm=use_tqdm, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) + + # get the post-processed model outputs + assert self.io_processor is not None + processed_outputs = self.io_processor.post_process(outputs) + + return [ + PoolingRequestOutput[Any]( + request_id="", + outputs=processed_outputs, + num_cached_tokens=getattr( + processed_outputs, "num_cached_tokens", 0 + ), + prompt_token_ids=[], + finished=True, + ) + ] else: if pooling_params is None: # Use default pooling params. 
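Editor's note on the `LLM.encode()` refactor above: when a pooling task has a matching entry in the per-task IO processors initialized in the constructor, requests are pre-processed, rendered, run, and post-processed through that processor instead of the generic completion path. A rough offline usage sketch, with the caveats that the model name is a placeholder, "embed" is assumed to be a supported pooling task for it, and output attributes may differ by model:

    from vllm import LLM

    # Placeholder checkpoint; any pooling-capable embedding model would do.
    llm = LLM(model="intfloat/e5-small-v2")

    outputs = llm.encode(
        ["vLLM pooling example"],
        pooling_task="embed",  # selects the matching per-task IO processor, if registered
    )
    for out in outputs:
        print(out.outputs)  # pooled output (e.g. the embedding vector)
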
@@ -1153,39 +1175,42 @@ class LLM: ) raise ValueError(msg) - outputs = self._run_completion( - prompts=prompts_seq, - params=params_seq, - output_type=PoolingRequestOutput, - use_tqdm=use_tqdm, - lora_request=lora_request, - tokenization_kwargs=tokenization_kwargs, - ) - - if use_io_processor: - # get the post-processed model outputs - assert self.io_processor is not None - processed_outputs = self.io_processor.post_process(outputs) - - return [ - PoolingRequestOutput[Any]( - request_id="", - outputs=processed_outputs, - num_cached_tokens=getattr( - processed_outputs, "num_cached_tokens", 0 - ), - prompt_token_ids=[], - finished=True, + if pooling_task in self.init_pooling_io_processors: + io_processor = self.init_pooling_io_processors[pooling_task] + processor_inputs = io_processor.pre_process_offline( + prompts_seq, tokenization_kwargs ) - ] + seq_lora_requests = self._lora_request_to_seq( + lora_request, len(prompts_seq) + ) + seq_priority = self._priority_to_seq(None, len(prompts)) + self._render_and_add_requests( + prompts=processor_inputs, + params=params_seq, + lora_requests=seq_lora_requests, + priorities=seq_priority, + ) + + outputs = self._run_engine( + use_tqdm=use_tqdm, output_type=PoolingRequestOutput + ) + outputs = io_processor.post_process(outputs) + else: + outputs = self._run_completion( + prompts=prompts_seq, + params=params_seq, + output_type=PoolingRequestOutput, + use_tqdm=use_tqdm, + lora_request=lora_request, + tokenization_kwargs=tokenization_kwargs, + ) return outputs def embed( self, prompts: PromptType | Sequence[PromptType], *, - truncate_prompt_tokens: int | None = None, use_tqdm: bool | Callable[..., tqdm] = True, pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, lora_request: list[LoRARequest] | LoRARequest | None = None, @@ -1221,12 +1246,6 @@ class LLM: "Try converting the model using `--convert embed`." ) - if truncate_prompt_tokens is not None: - tokenization_kwargs = merge_kwargs( - tokenization_kwargs, - dict(truncate_prompt_tokens=truncate_prompt_tokens), - ) - items = self.encode( prompts, use_tqdm=use_tqdm, @@ -1294,7 +1313,6 @@ class LLM: /, *, pooling_params: PoolingParams | Sequence[PoolingParams] | None = None, - truncate_prompt_tokens: int | None = None, use_tqdm: bool | Callable[..., tqdm] = True, lora_request: list[LoRARequest] | LoRARequest | None = None, tokenization_kwargs: dict[str, Any] | None = None, @@ -1319,13 +1337,11 @@ class LLM: A list of `PoolingRequestOutput` objects containing the pooled hidden states in the same order as the input prompts. 
""" - return self.encode( prompts, use_tqdm=use_tqdm, lora_request=lora_request, pooling_params=pooling_params, - truncate_prompt_tokens=truncate_prompt_tokens, pooling_task="token_classify", tokenization_kwargs=tokenization_kwargs, ) @@ -1771,23 +1787,15 @@ class LLM: seq_prompts = prompt_to_seq(prompts) seq_params = self._params_to_seq(params, len(seq_prompts)) seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts)) - seq_tok_kwargs = [ - merge_kwargs( - tokenization_kwargs, - dict(truncate_prompt_tokens=param.truncate_prompt_tokens), - ) - for param in seq_params - ] seq_priority = self._priority_to_seq(priority, len(prompts)) return self._render_and_add_requests( prompts=( - self._preprocess_cmpl_one(prompt, tok_kwargs) - for prompt, tok_kwargs in zip( - maybe_tqdm( - seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts" - ), - seq_tok_kwargs, + self._preprocess_cmpl_one(prompt, tokenization_kwargs) + for prompt in maybe_tqdm( + seq_prompts, + use_tqdm=use_tqdm, + desc="Rendering prompts", ) ), params=seq_params, @@ -1841,13 +1849,6 @@ class LLM: seq_convs = conversation_to_seq(messages) seq_params = self._params_to_seq(params, len(seq_convs)) seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_convs)) - seq_tok_kwargs = [ - merge_kwargs( - tokenization_kwargs, - dict(truncate_prompt_tokens=param.truncate_prompt_tokens), - ) - for param in seq_params - ] return self._render_and_run_requests( prompts=( @@ -1859,16 +1860,13 @@ class LLM: add_generation_prompt=add_generation_prompt, continue_final_message=continue_final_message, tools=tools, - tokenization_kwargs=tok_kwargs, + tokenization_kwargs=tokenization_kwargs, mm_processor_kwargs=mm_processor_kwargs, ) - for conversation, tok_kwargs in zip( - maybe_tqdm( - seq_convs, - use_tqdm=use_tqdm, - desc="Rendering conversations", - ), - seq_tok_kwargs, + for conversation in maybe_tqdm( + seq_convs, + use_tqdm=use_tqdm, + desc="Rendering conversations", ) ), params=seq_params, diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index c9e8093..c2a77fb 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -18,6 +18,20 @@ class RequestLogger: def __init__(self, *, max_log_len: int | None) -> None: self.max_log_len = max_log_len + if not logger.isEnabledFor(logging.INFO): + logger.warning_once( + "`--enable-log-requests` is set but " + "the minimum log level is higher than INFO. " + "No request information will be logged." + ) + elif not logger.isEnabledFor(logging.DEBUG): + logger.info_once( + "`--enable-log-requests` is set but " + "the minimum log level is higher than DEBUG. " + "Only limited information will be logged to minimize overhead. " + "To view more details, set `VLLM_LOGGING_LEVEL=DEBUG`." 
+ ) + def log_inputs( self, request_id: str, diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index 1bf0de5..0abe85a 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -38,6 +38,7 @@ from vllm.logprobs import Logprob from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs from vllm.sampling_params import ( BeamSearchParams, + RepetitionDetectionParams, RequestOutputKind, SamplingParams, StructuredOutputsParams, @@ -336,6 +337,16 @@ class ChatCompletionRequest(OpenAIBaseModel): ), ) + repetition_detection: RepetitionDetectionParams | None = Field( + default=None, + description="Parameters for detecting repetitive N-gram patterns " + "in output tokens. If such repetition is detected, generation will " + "be ended early. LLMs can sometimes generate repetitive, unhelpful " + "token patterns, stopping only when they hit the maximum output length " + "(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). This feature " + "can detect such behavior and terminate early, saving time and tokens.", + ) + # --8<-- [end:chat-completion-extra-params] def build_chat_params( @@ -490,7 +501,6 @@ class ChatCompletionRequest(OpenAIBaseModel): skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, - truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY, @@ -500,8 +510,37 @@ class ChatCompletionRequest(OpenAIBaseModel): allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, skip_clone=True, # Created fresh per request, safe to skip clone + repetition_detection=self.repetition_detection, ) + @model_validator(mode="before") + @classmethod + def validate_response_format(cls, data): + response_format = data.get("response_format") + if response_format is None: + return data + + rf_type = ( + response_format.get("type") + if isinstance(response_format, dict) + else getattr(response_format, "type", None) + ) + + if rf_type == "json_schema": + json_schema = ( + response_format.get("json_schema") + if isinstance(response_format, dict) + else getattr(response_format, "json_schema", None) + ) + if json_schema is None: + raise VLLMValidationError( + "When response_format type is 'json_schema', the " + "'json_schema' field must be provided.", + parameter="response_format", + ) + + return data + @model_validator(mode="before") @classmethod def validate_stream_options(cls, data): diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 39f8635..06b16cd 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing): ) # get the expected call based on partial JSON - # parsing which "autocompletes" the JSON - expected_call = json.dumps( - tool_parser.prev_tool_call_arr[index].get( - "arguments", {} - ), - ensure_ascii=False, + # parsing which "autocompletes" the JSON. + # Tool parsers (e.g. Qwen3Coder) store + # arguments as a JSON string in + # prev_tool_call_arr. Calling json.dumps() + # on an already-serialized string would + # double-serialize it (e.g. 
'{"k":1}' becomes + # '"{\\"k\\":1}"'), which then causes the + # replace() below to fail and append the + # entire double-serialized string as a + # spurious final delta. + args = tool_parser.prev_tool_call_arr[index].get( + "arguments", {} ) + if isinstance(args, str): + expected_call = args + else: + expected_call = json.dumps(args, ensure_ascii=False) # get what we've streamed so far for arguments # for the current tool diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index eac581e..d3a66c1 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -143,7 +143,8 @@ class BaseFrontendArgs: templates and other tokenizer configuration.""" enable_log_outputs: bool = False """If set to True, log model outputs (generations). - Requires --enable-log-requests.""" + Requires `--enable-log-requests`. As with `--enable-log-requests`, + information is only logged at INFO level at maximum.""" enable_log_deltas: bool = True """If set to False, output deltas will not be logged. Relevant only if --enable-log-outputs is set. @@ -277,6 +278,10 @@ class FrontendArgs(BaseFrontendArgs): Enable offline FastAPI documentation for air-gapped environments. Uses vendored static assets bundled with vLLM. """ + use_gpu_for_pooling_score: bool = False + """If set, run pooling score MaxSim on GPU in the API server process. + Can significantly improve late-interaction scoring performance. + https://github.com/vllm-project/vllm/pull/35330""" @classmethod def _customize_cli_kwargs( diff --git a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py index 226dd6c..af13204 100644 --- a/vllm/entrypoints/openai/completion/protocol.py +++ b/vllm/entrypoints/openai/completion/protocol.py @@ -26,6 +26,7 @@ from vllm.logprobs import Logprob from vllm.renderers import TokenizeParams from vllm.sampling_params import ( BeamSearchParams, + RepetitionDetectionParams, RequestOutputKind, SamplingParams, StructuredOutputsParams, @@ -166,6 +167,16 @@ class CompletionRequest(OpenAIBaseModel): ), ) + repetition_detection: RepetitionDetectionParams | None = Field( + default=None, + description="Parameters for detecting repetitive N-gram patterns " + "in output tokens. If such repetition is detected, generation will " + "be ended early. LLMs can sometimes generate repetitive, unhelpful " + "token patterns, stopping only when they hit the maximum output length " + "(e.g. 'abcdabcdabcd...' or '\emoji \emoji \emoji ...'). 
This feature " + "can detect such behavior and terminate early, saving time and tokens.", + ) + # --8<-- [end:completion-extra-params] def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: @@ -259,7 +270,7 @@ class CompletionRequest(OpenAIBaseModel): structured_outputs_kwargs["json"] = json_schema.json_schema elif response_format.type == "structural_tag": structural_tag = response_format - assert structural_tag is not None and isinstance( + assert isinstance( structural_tag, ( LegacyStructuralTagResponseFormat, @@ -302,7 +313,6 @@ class CompletionRequest(OpenAIBaseModel): skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, - truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY, @@ -311,8 +321,37 @@ class CompletionRequest(OpenAIBaseModel): allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, skip_clone=True, # Created fresh per request, safe to skip clone + repetition_detection=self.repetition_detection, ) + @model_validator(mode="before") + @classmethod + def validate_response_format(cls, data): + response_format = data.get("response_format") + if response_format is None: + return data + + rf_type = ( + response_format.get("type") + if isinstance(response_format, dict) + else getattr(response_format, "type", None) + ) + + if rf_type == "json_schema": + json_schema = ( + response_format.get("json_schema") + if isinstance(response_format, dict) + else getattr(response_format, "json_schema", None) + ) + if json_schema is None: + raise VLLMValidationError( + "When response_format type is 'json_schema', the " + "'json_schema' field must be provided.", + parameter="response_format", + ) + + return data + @model_validator(mode="before") @classmethod def check_structured_outputs_count(cls, data): diff --git a/vllm/entrypoints/openai/engine/serving.py b/vllm/entrypoints/openai/engine/serving.py index 3e376ba..e864f56 100644 --- a/vllm/entrypoints/openai/engine/serving.py +++ b/vllm/entrypoints/openai/engine/serving.py @@ -62,11 +62,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import ( TranscriptionResponse, TranslationRequest, ) -from vllm.entrypoints.pooling.classify.protocol import ( - ClassificationChatRequest, - ClassificationCompletionRequest, - ClassificationResponse, -) from vllm.entrypoints.pooling.embed.protocol import ( EmbeddingBytesResponse, EmbeddingChatRequest, @@ -161,7 +156,6 @@ CompletionLikeRequest: TypeAlias = ( | TokenizeCompletionRequest | DetokenizeRequest | EmbeddingCompletionRequest - | ClassificationCompletionRequest | RerankRequest | ScoreRequest | PoolingCompletionRequest @@ -171,7 +165,6 @@ ChatLikeRequest: TypeAlias = ( ChatCompletionRequest | TokenizeChatRequest | EmbeddingChatRequest - | ClassificationChatRequest | PoolingChatRequest ) @@ -194,12 +187,10 @@ AnyResponse: TypeAlias = ( | TranscriptionResponse | TokenizeResponse | PoolingResponse - | ClassificationResponse | ScoreResponse | GenerateResponse ) - RequestT = TypeVar("RequestT", bound=AnyRequest) @@ -223,8 +214,8 @@ class ServeContext(Generic[RequestT]): class OpenAIServing: request_id_prefix: ClassVar[str] = """ - A short string prepended to every request’s ID (e.g. "embd", "classify") - so you can easily tell “this ID came from Embedding vs Classification.” + A short string prepended to every request’s ID (e.g. 
"embd") + so you can easily tell “this ID came from Embedding.” """ def __init__( @@ -456,7 +447,7 @@ class OpenAIServing: ) -> ErrorResponse | None: """ Default preprocessing hook. Subclasses may override - to prepare `ctx` (classification, embedding, etc.). + to prepare `ctx` (embedding, etc.). """ return None @@ -817,7 +808,7 @@ class OpenAIServing: token_num = len(input_ids) max_model_len = self.model_config.max_model_len - # Note: EmbeddingRequest, ClassificationRequest, + # Note: EmbeddingRequest, # and ScoreRequest doesn't have max_tokens if isinstance( request, @@ -828,8 +819,6 @@ class OpenAIServing: ScoreTextRequest, ScoreQueriesDocumentsRequest, RerankRequest, - ClassificationCompletionRequest, - ClassificationChatRequest, ), ): # Note: input length can be up to the entire model context length @@ -839,8 +828,6 @@ class OpenAIServing: ScoreDataRequest: "score", ScoreTextRequest: "score", ScoreQueriesDocumentsRequest: "score", - ClassificationCompletionRequest: "classification", - ClassificationChatRequest: "classification", } operation = operations.get(type(request), "embedding generation") raise VLLMValidationError( diff --git a/vllm/entrypoints/openai/responses/protocol.py b/vllm/entrypoints/openai/responses/protocol.py index b0ffd03..1ec88cc 100644 --- a/vllm/entrypoints/openai/responses/protocol.py +++ b/vllm/entrypoints/openai/responses/protocol.py @@ -328,8 +328,9 @@ class ResponsesRequest(OpenAIBaseModel): # Also check text.format for OpenAI-style json_schema if self.text is not None and self.text.format is not None: if structured_outputs is not None: - raise ValueError( - "Cannot specify both structured_outputs and text.format" + raise VLLMValidationError( + "Cannot specify both structured_outputs and text.format", + parameter="structured_outputs", ) response_format = self.text.format if ( @@ -378,14 +379,19 @@ class ResponsesRequest(OpenAIBaseModel): ) @model_validator(mode="before") + @classmethod def validate_background(cls, data): if not data.get("background"): return data if not data.get("store", True): - raise ValueError("background can only be used when `store` is true") + raise VLLMValidationError( + "background can only be used when `store` is true", + parameter="background", + ) return data @model_validator(mode="before") + @classmethod def validate_prompt(cls, data): if data.get("prompt") is not None: raise VLLMValidationError( @@ -394,16 +400,19 @@ class ResponsesRequest(OpenAIBaseModel): return data @model_validator(mode="before") + @classmethod def check_cache_salt_support(cls, data): if data.get("cache_salt") is not None and ( not isinstance(data["cache_salt"], str) or not data["cache_salt"] ): - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." + raise VLLMValidationError( + "Parameter 'cache_salt' must be a non-empty string if provided.", + parameter="cache_salt", ) return data @model_validator(mode="before") + @classmethod def function_call_parsing(cls, data): """Parse function_call dictionaries into ResponseFunctionToolCall objects. This ensures Pydantic can properly resolve union types in the input field. 
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py index b9d526e..3cfb6ff 100644 --- a/vllm/entrypoints/openai/responses/serving.py +++ b/vllm/entrypoints/openai/responses/serving.py @@ -85,6 +85,8 @@ from vllm.entrypoints.openai.responses.protocol import ( ResponseCreatedEvent, ResponseInProgressEvent, ResponseInputOutputMessage, + ResponseReasoningPartAddedEvent, + ResponseReasoningPartDoneEvent, ResponsesRequest, ResponsesResponse, ResponseUsage, @@ -1339,6 +1341,19 @@ class OpenAIServingResponses(OpenAIServing): ), ) ) + yield _increment_sequence_number_and_return( + ResponseReasoningPartAddedEvent( + type="response.reasoning_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseReasoningTextContent( + text="", + type="reasoning_text", + ), + ) + ) else: yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( @@ -1354,22 +1369,21 @@ class OpenAIServingResponses(OpenAIServing): ), ) ) - yield _increment_sequence_number_and_return( - ResponseContentPartAddedEvent( - type="response.content_part.added", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - content_index=current_content_index, - part=ResponseOutputText( - type="output_text", - text="", - annotations=[], - logprobs=[], - ), + yield _increment_sequence_number_and_return( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + ) ) - ) - current_content_index += 1 first_delta_sent = True # todo(kebe7jun) tool call support @@ -1397,6 +1411,19 @@ class OpenAIServingResponses(OpenAIServing): text=reason_content, ) ) + yield _increment_sequence_number_and_return( + ResponseReasoningPartDoneEvent( + type="response.reasoning_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ) + ) current_content_index = 0 reasoning_item = ResponseReasoningItem( type="reasoning", @@ -1418,6 +1445,8 @@ class OpenAIServingResponses(OpenAIServing): item=reasoning_item, ) ) + current_output_index += 1 + current_item_id = str(uuid.uuid4()) yield _increment_sequence_number_and_return( ResponseOutputItemAddedEvent( type="response.output_item.added", @@ -1432,8 +1461,6 @@ class OpenAIServingResponses(OpenAIServing): ), ) ) - current_output_index += 1 - current_item_id = str(uuid.uuid4()) yield _increment_sequence_number_and_return( ResponseContentPartAddedEvent( type="response.content_part.added", @@ -1449,7 +1476,6 @@ class OpenAIServingResponses(OpenAIServing): ), ) ) - current_content_index += 1 # reset previous delta messages previous_delta_messages = [] @@ -1485,7 +1511,6 @@ class OpenAIServingResponses(OpenAIServing): ), ) ) - current_content_index += 1 previous_delta_messages.append(delta_message) if previous_delta_messages: @@ -1505,7 +1530,19 @@ class OpenAIServingResponses(OpenAIServing): text=reason_content, ) ) - current_content_index += 1 + yield _increment_sequence_number_and_return( + ResponseReasoningPartDoneEvent( + type="response.reasoning_part.done", + sequence_number=-1, + item_id=current_item_id, + 
output_index=current_output_index, + content_index=current_content_index, + part=ResponseReasoningTextContent( + text=reason_content, + type="reasoning_text", + ), + ) + ) reasoning_item = ResponseReasoningItem( type="reasoning", content=[ @@ -1543,7 +1580,6 @@ class OpenAIServingResponses(OpenAIServing): item_id=current_item_id, ) ) - current_content_index += 1 part = ResponseOutputText( text=final_content, type="output_text", @@ -1559,7 +1595,6 @@ class OpenAIServingResponses(OpenAIServing): part=part, ) ) - current_content_index += 1 item = ResponseOutputMessage( type="message", role="assistant", diff --git a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py index 966e6d4..1c56f09 100644 --- a/vllm/entrypoints/openai/speech_to_text/speech_to_text.py +++ b/vllm/entrypoints/openai/speech_to_text/speech_to_text.py @@ -11,6 +11,7 @@ from typing import Final, Literal, TypeAlias, TypeVar, cast import numpy as np from fastapi import Request +from soundfile import LibsndfileError from transformers import PreTrainedTokenizerBase import vllm.envs as envs @@ -57,6 +58,14 @@ try: except ImportError: librosa = PlaceholderModule("librosa") # type: ignore[assignment] +# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`, soundfile +# being librosa's main backend. Used to validate if an audio loading error is due to a +# server error vs a client error (invalid audio file). +# 1 = unrecognised format (file is not a supported audio container) +# 3 = malformed file (corrupt or structurally invalid audio) +# 4 = unsupported encoding (codec not supported by this libsndfile build) +_BAD_SF_CODES = {1, 3, 4} + SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse SpeechToTextResponseVerbose: TypeAlias = ( TranscriptionResponseVerbose | TranslationResponseVerbose @@ -315,9 +324,15 @@ class OpenAISpeechToText(OpenAIServing): ) with io.BytesIO(audio_data) as bytes_: - # NOTE resample to model SR here for efficiency. This is also a - # pre-requisite for chunking, as it assumes Whisper SR. - y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + try: + # NOTE resample to model SR here for efficiency. This is also a + # pre-requisite for chunking, as it assumes Whisper SR. + y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate) + except LibsndfileError as exc: + # Distinguish client errors (invalid audio) from server errors + if exc.code in _BAD_SF_CODES: + raise ValueError("Invalid or unsupported audio file.") from exc + raise duration = librosa.get_duration(y=y, sr=sr) do_split_audio = ( diff --git a/vllm/entrypoints/openai/translations/__init__.py b/vllm/entrypoints/openai/translations/__init__.py deleted file mode 100644 index cf210d5..0000000 --- a/vllm/entrypoints/openai/translations/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import warnings - -warnings.warn( - "The 'vllm.entrypoints.openai.translations' module has been renamed to " - "'vllm.entrypoints.openai.speech_to_text'. Please update your imports. 
" - "This backward-compatible alias will be removed in version 0.17+.", - DeprecationWarning, - stacklevel=2, -) diff --git a/vllm/entrypoints/openai/translations/api_router.py b/vllm/entrypoints/openai/translations/api_router.py deleted file mode 100644 index 4a43bf8..0000000 --- a/vllm/entrypoints/openai/translations/api_router.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import warnings - -warnings.warn( - "'vllm.entrypoints.openai.translations.api_router' has been moved to " - "'vllm.entrypoints.openai.speech_to_text.api_router'. Please update your " - "imports. This backward-compatible alias will be removed in version 0.17+.", - DeprecationWarning, - stacklevel=2, -) - -from vllm.entrypoints.openai.speech_to_text.api_router import * # noqa: F401,F403,E402 diff --git a/vllm/entrypoints/openai/translations/protocol.py b/vllm/entrypoints/openai/translations/protocol.py deleted file mode 100644 index c8ec156..0000000 --- a/vllm/entrypoints/openai/translations/protocol.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import warnings - -warnings.warn( - "'vllm.entrypoints.openai.translations.protocol' has been moved to " - "'vllm.entrypoints.openai.speech_to_text.protocol'. Please update your " - "imports. This backward-compatible alias will be removed in version 0.17+.", - DeprecationWarning, - stacklevel=2, -) - -from vllm.entrypoints.openai.speech_to_text.protocol import * # noqa: F401,F403,E402 diff --git a/vllm/entrypoints/openai/translations/serving.py b/vllm/entrypoints/openai/translations/serving.py deleted file mode 100644 index 1749d61..0000000 --- a/vllm/entrypoints/openai/translations/serving.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import warnings - -warnings.warn( - "'vllm.entrypoints.openai.translations.serving' has been moved to " - "'vllm.entrypoints.openai.speech_to_text.serving'. Please update your " - "imports. This backward-compatible alias will be removed in version 0.17+.", - DeprecationWarning, - stacklevel=2, -) - -from vllm.entrypoints.openai.speech_to_text.serving import * # noqa: F401,F403,E402 diff --git a/vllm/entrypoints/openai/translations/speech_to_text.py b/vllm/entrypoints/openai/translations/speech_to_text.py deleted file mode 100644 index eb26c6a..0000000 --- a/vllm/entrypoints/openai/translations/speech_to_text.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import warnings - -warnings.warn( - "'vllm.entrypoints.openai.translations.speech_to_text' has been moved to " - "'vllm.entrypoints.openai.speech_to_text.speech_to_text'. Please update " - "your imports. 
This backward-compatible alias will be removed in version " - "0.17+.", - DeprecationWarning, - stacklevel=2, -) - -from vllm.entrypoints.openai.speech_to_text.speech_to_text import * # noqa: F401,F403,E402 diff --git a/vllm/entrypoints/pooling/__init__.py b/vllm/entrypoints/pooling/__init__.py index 1108be1..3ba131d 100644 --- a/vllm/entrypoints/pooling/__init__.py +++ b/vllm/entrypoints/pooling/__init__.py @@ -115,6 +115,7 @@ def init_pooling_state( request_logger=request_logger, score_template=resolved_chat_template, log_error_stack=args.log_error_stack, + use_gpu_for_pooling_score=getattr(args, "use_gpu_for_pooling_score", False), ) if any(t in supported_tasks for t in ("embed", "score", "token_embed")) else None diff --git a/vllm/entrypoints/pooling/base/io_processor.py b/vllm/entrypoints/pooling/base/io_processor.py new file mode 100644 index 0000000..254c3d6 --- /dev/null +++ b/vllm/entrypoints/pooling/base/io_processor.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable, Sequence +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Final + +from vllm import PoolingRequestOutput, PromptType +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ( + ChatCompletionMessageParam, + ChatTemplateConfig, + ChatTemplateContentFormatOption, + ConversationMessage, +) +from vllm.entrypoints.openai.engine.serving import RendererChatRequest, RendererRequest +from vllm.inputs import ProcessorInputs, SingletonPrompt +from vllm.renderers import BaseRenderer, merge_kwargs +from vllm.renderers.inputs import TokPrompt +from vllm.renderers.inputs.preprocess import parse_model_prompt, prompt_to_seq +from vllm.tokenizers import TokenizerLike +from vllm.tool_parsers import ToolParser +from vllm.utils.mistral import is_mistral_tokenizer + + +class PoolingIOProcessor: + def __init__( + self, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ): + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self.model_config = model_config + self.renderer = renderer + + self.chat_template = chat_template_config.chat_template + self.chat_template_content_format: Final = ( + chat_template_config.chat_template_content_format + ) + self.trust_request_chat_template = ( + chat_template_config.trust_request_chat_template + ) + + def pre_process_online(self, *args, **kwargs): + raise NotImplementedError + + async def pre_process_online_async(self, *args, **kwargs): + return self.pre_process_online(*args, **kwargs) + + def pre_process_offline(self, *args, **kwargs): + raise NotImplementedError + + async def pre_process_offline_async(self, *args, **kwargs): + return self.pre_process_offline(*args, **kwargs) + + def post_process( + self, outputs: list[PoolingRequestOutput] + ) -> list[PoolingRequestOutput]: + return outputs + + async def post_process_async( + self, outputs: list[PoolingRequestOutput] + ) -> list[PoolingRequestOutput]: + return self.post_process(outputs) + + def create_pooling_params(self, request): + return request.to_pooling_params() + + def _preprocess_completion_online( + self, + request: RendererRequest, + prompt_input: str | list[str] | list[int] | list[list[int]] | None, + prompt_embeds: bytes | list[bytes] | None, + ) -> list[TokPrompt]: + renderer = self.renderer + model_config = self.model_config + + prompts = list[SingletonPrompt | bytes]() + if prompt_embeds is not None: # 
embeds take higher priority + prompts.extend(prompt_to_seq(prompt_embeds)) + if prompt_input is not None: + prompts.extend(prompt_to_seq(prompt_input)) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = request.build_tok_params(model_config) + + return renderer.render_cmpl( + parsed_prompts, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) + + def _preprocess_chat_online( + self, + request: RendererChatRequest, + messages: list[ChatCompletionMessageParam], + default_template: str | None, + default_template_content_format: ChatTemplateContentFormatOption, + default_template_kwargs: dict[str, Any] | None, + tool_dicts: list[dict[str, Any]] | None = None, + tool_parser: Callable[[TokenizerLike], ToolParser] | None = None, + ) -> tuple[list[ConversationMessage], list[TokPrompt]]: + renderer = self.renderer + + default_template_kwargs = merge_kwargs( + default_template_kwargs, + dict( + tools=tool_dicts, + tokenize=is_mistral_tokenizer(renderer.tokenizer), + ), + ) + + tok_params = request.build_tok_params(self.model_config) + chat_params = request.build_chat_params( + default_template, default_template_content_format + ).with_defaults(default_template_kwargs) + + (conversation,), (engine_prompt,) = renderer.render_chat( + [messages], + chat_params, + tok_params, + prompt_extras={ + k: v + for k in ("mm_processor_kwargs", "cache_salt") + if (v := getattr(request, k, None)) is not None + }, + ) + + return conversation, [engine_prompt] + + def _preprocess_completion_offline( + self, + prompts: PromptType | Sequence[PromptType], + tokenization_kwargs: dict[str, Any] | None = None, + ) -> Sequence[ProcessorInputs]: + renderer = self.renderer + model_config = self.model_config + + prompts = prompt_to_seq(prompts) + + parsed_prompts = [ + ( + prompt + if isinstance(prompt, bytes) + else parse_model_prompt(model_config, prompt) + ) + for prompt in prompts + ] + tok_params = renderer.default_cmpl_tok_params.with_kwargs( + **(tokenization_kwargs or {}) + ) + + return renderer.render_cmpl( + parsed_prompts, + tok_params, + ) + + def _validate_chat_template( + self, + request_chat_template: str | None, + chat_template_kwargs: dict[str, Any] | None, + trust_request_chat_template: bool, + ): + if not trust_request_chat_template and ( + request_chat_template is not None + or ( + chat_template_kwargs + and chat_template_kwargs.get("chat_template") is not None + ) + ): + raise ValueError( + "Chat template is passed with request, but " + "--trust-request-chat-template is not set. " + "Refused request with untrusted chat template." + ) + return None diff --git a/vllm/entrypoints/pooling/base/protocol.py b/vllm/entrypoints/pooling/base/protocol.py index 86dc12c..5394510 100644 --- a/vllm/entrypoints/pooling/base/protocol.py +++ b/vllm/entrypoints/pooling/base/protocol.py @@ -190,10 +190,6 @@ class EmbedRequestMixin(EncodingRequestMixin): description="Whether to use activation for the pooler outputs. 
" "`None` uses the pooler's default, which is `True` in most cases.", ) - normalize: bool | None = Field( - default=None, - description="Deprecated; please pass `use_activation` instead", - ) # --8<-- [end:embed-extra-params] diff --git a/vllm/entrypoints/pooling/base/serving.py b/vllm/entrypoints/pooling/base/serving.py new file mode 100644 index 0000000..813282d --- /dev/null +++ b/vllm/entrypoints/pooling/base/serving.py @@ -0,0 +1,378 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import time +from collections.abc import AsyncGenerator, Mapping +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import ClassVar, Generic, TypeVar + +from fastapi import Request +from pydantic import ConfigDict +from starlette.datastructures import Headers +from starlette.responses import JSONResponse + +from vllm import ( + PoolingParams, + PoolingRequestOutput, + PromptType, + SamplingParams, + envs, +) +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import ( + ChatTemplateConfig, + ChatTemplateContentFormatOption, +) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.engine.protocol import ErrorResponse +from vllm.entrypoints.openai.models.serving import OpenAIServingModels +from vllm.entrypoints.pooling.typing import AnyPoolingRequest, AnyPoolingResponse +from vllm.inputs import ProcessorInputs +from vllm.lora.request import LoRARequest +from vllm.renderers import BaseRenderer +from vllm.renderers.inputs.preprocess import extract_prompt_components +from vllm.sampling_params import BeamSearchParams +from vllm.tracing import ( + contains_trace_headers, + extract_trace_headers, + log_tracing_disabled_warning, +) +from vllm.utils import random_uuid +from vllm.utils.async_utils import merge_async_iterators + +from ...utils import create_error_response +from .io_processor import PoolingIOProcessor + +PoolingRequestT = TypeVar("PoolingRequestT", bound=AnyPoolingRequest) + + +@dataclass(kw_only=True) +class PoolingServeContext(Generic[PoolingRequestT]): + request: PoolingRequestT + raw_request: Request | None = None + model_name: str + request_id: str + created_time: int = field(default_factory=lambda: int(time.time())) + lora_request: LoRARequest | None = None + engine_prompts: list[ProcessorInputs] | None = None + + result_generator: AsyncGenerator[tuple[int, PoolingRequestOutput], None] | None = ( + None + ) + final_res_batch: list[PoolingRequestOutput] = field(default_factory=list) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class PoolingServing: + request_id_prefix: ClassVar[str] + + def __init__( + self, + engine_client: EngineClient, + models: OpenAIServingModels, + *, + request_logger: RequestLogger | None, + chat_template: str | None = None, + chat_template_content_format: ChatTemplateContentFormatOption = "auto", + trust_request_chat_template: bool = False, + return_tokens_as_token_ids: bool = False, + log_error_stack: bool = False, + ): + super().__init__() + self.engine_client = engine_client + self.models = models + self.model_config = models.model_config + self.max_model_len = self.model_config.max_model_len + self.request_logger = request_logger + self.return_tokens_as_token_ids = return_tokens_as_token_ids + self.log_error_stack = log_error_stack + self.chat_template_config = ChatTemplateConfig( + chat_template=chat_template, + 
chat_template_content_format=chat_template_content_format, + trust_request_chat_template=trust_request_chat_template, + ) + self.io_processor = self.init_io_processor( + model_config=models.model_config, + renderer=models.renderer, + chat_template_config=self.chat_template_config, + ) + + def init_io_processor( + self, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ) -> PoolingIOProcessor: + raise NotImplementedError + + async def __call__( + self, + request: AnyPoolingRequest, + raw_request: Request, + ) -> JSONResponse: + try: + model_name = self.models.model_name() + request_id = ( + f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" + ) + + await self._check_model(request) + + ctx = PoolingServeContext( + request=request, + raw_request=raw_request, + model_name=model_name, + request_id=request_id, + ) + + self._validate_request(ctx) + self._maybe_get_adapters(ctx) + await self._preprocess(ctx) + await self._prepare_generators(ctx) + await self._collect_batch(ctx) + response = await self._build_response(ctx) + return JSONResponse(content=response.model_dump()) + except Exception as e: + error_response = create_error_response(e) + return JSONResponse( + content=error_response.model_dump(), + status_code=error_response.error.code, + ) + + async def _preprocess( + self, + ctx: PoolingServeContext, + ): + ctx.engine_prompts = await self.io_processor.pre_process_online_async( + ctx.request + ) + + async def _prepare_generators( + self, + ctx: PoolingServeContext, + ): + if ctx.engine_prompts is None: + raise ValueError("Engine prompts not available") + + generators: list[AsyncGenerator[PoolingRequestOutput, None]] = [] + + trace_headers = ( + None + if ctx.raw_request is None + else await self._get_trace_headers(ctx.raw_request.headers) + ) + + pooling_params = self.io_processor.create_pooling_params(ctx.request) + + for i, engine_prompt in enumerate(ctx.engine_prompts): + request_id_item = f"{ctx.request_id}-{i}" + + self._log_inputs( + request_id_item, + engine_prompt, + params=pooling_params, + lora_request=ctx.lora_request, + ) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=ctx.lora_request, + trace_headers=trace_headers, + priority=getattr(ctx.request, "priority", 0), + ) + + generators.append(generator) + + ctx.result_generator = merge_async_iterators(*generators) + + async def _collect_batch( + self, + ctx: PoolingServeContext, + ): + if ctx.engine_prompts is None: + raise ValueError("Engine prompts not available") + + if ctx.result_generator is None: + raise ValueError("Result generator not available") + + num_prompts = len(ctx.engine_prompts) + final_res_batch: list[PoolingRequestOutput | None] + final_res_batch = [None] * num_prompts + + async for i, res in ctx.result_generator: + final_res_batch[i] = res + + if None in final_res_batch: + raise ValueError("Failed to generate results for all prompts") + + ctx.final_res_batch = [res for res in final_res_batch if res is not None] + + async def _build_response( + self, + ctx: PoolingServeContext, + ) -> AnyPoolingResponse: + raise NotImplementedError + + @staticmethod + def _base_request_id( + raw_request: Request | None, default: str | None = None + ) -> str | None: + """Pulls the request id to use from a header, if provided""" + if raw_request is not None and ( + (req_id := raw_request.headers.get("X-Request-Id")) is not None + ): + return req_id + + return random_uuid() if default is None else default 
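
For context on the generator fan-in used by `_prepare_generators` and `_collect_batch` above: per-prompt outputs are merged as `(index, output)` pairs so the batch can be rebuilt in prompt order regardless of completion order. The sketch below is a hedged, self-contained illustration of that pattern; `merge_indexed`, `fake_encode`, and `main` are stand-ins invented here, not vLLM APIs (vLLM's own helper is `merge_async_iterators`).

import asyncio
from collections.abc import AsyncGenerator


async def merge_indexed(*iterators: AsyncGenerator) -> AsyncGenerator:
    # Simplified stand-in for merge_async_iterators: tag every item with the
    # index of the generator it came from, so the consumer can slot results
    # back into a preallocated batch.
    queue: asyncio.Queue = asyncio.Queue()
    sentinel = object()

    async def drain(i: int, it: AsyncGenerator) -> None:
        async for item in it:
            await queue.put((i, item))
        await queue.put(sentinel)

    tasks = [asyncio.create_task(drain(i, it)) for i, it in enumerate(iterators)]
    finished = 0
    while finished < len(tasks):
        item = await queue.get()
        if item is sentinel:
            finished += 1
        else:
            yield item


async def fake_encode(prompt: str, delay: float) -> AsyncGenerator:
    # Stand-in for engine_client.encode(): yields one pooling output per prompt.
    await asyncio.sleep(delay)
    yield f"pooled({prompt})"


async def main() -> None:
    prompts = ["a", "b", "c"]
    # Later prompts finish first, to show that ordering is preserved anyway.
    generators = [fake_encode(p, delay=0.03 - 0.01 * i) for i, p in enumerate(prompts)]

    final_res_batch: list[str | None] = [None] * len(prompts)
    async for i, res in merge_indexed(*generators):
        final_res_batch[i] = res

    assert None not in final_res_batch
    print(final_res_batch)  # ['pooled(a)', 'pooled(b)', 'pooled(c)']


asyncio.run(main())
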
+ + def _is_model_supported(self, model_name: str | None) -> bool: + if not model_name: + return True + return self.models.is_base_model(model_name) + + async def _check_model( + self, + request: AnyPoolingRequest, + ) -> ErrorResponse | None: + if self._is_model_supported(request.model): + return None + if request.model in self.models.lora_requests: + return None + if ( + envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING + and request.model + and (load_result := await self.models.resolve_lora(request.model)) + ): + if isinstance(load_result, LoRARequest): + return None + if ( + isinstance(load_result, ErrorResponse) + and load_result.error.code == HTTPStatus.BAD_REQUEST.value + ): + raise ValueError(load_result.error.message) + return None + + def _validate_request(self, ctx: PoolingServeContext) -> None: + truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens", None) + + if ( + truncate_prompt_tokens is not None + and truncate_prompt_tokens > self.max_model_len + ): + raise ValueError( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size." + ) + return None + + async def _get_trace_headers( + self, + headers: Headers, + ) -> Mapping[str, str] | None: + is_tracing_enabled = await self.engine_client.is_tracing_enabled() + + if is_tracing_enabled: + return extract_trace_headers(headers) + + if contains_trace_headers(headers): + log_tracing_disabled_warning() + + return None + + def _maybe_get_adapters( + self, + ctx: PoolingServeContext, + supports_default_mm_loras: bool = False, + ): + request = ctx.request + if request.model in self.models.lora_requests: + ctx.lora_request = self.models.lora_requests[request.model] + + # Currently only support default modality specific loras + # if we have exactly one lora matched on the request. + if supports_default_mm_loras: + default_mm_lora = self._get_active_default_mm_loras(request) + if default_mm_lora is not None: + ctx.lora_request = default_mm_lora + + if self._is_model_supported(request.model): + return None + + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") + + def _get_active_default_mm_loras( + self, request: AnyPoolingRequest + ) -> LoRARequest | None: + """Determine if there are any active default multimodal loras.""" + # TODO: Currently this is only enabled for chat completions + # to be better aligned with only being enabled for .generate + # when run offline. It would be nice to support additional + # tasks types in the future. + message_types = self._get_message_types(request) + default_mm_loras = set() + + for lora in self.models.lora_requests.values(): + # Best effort match for default multimodal lora adapters; + # There is probably a better way to do this, but currently + # this matches against the set of 'types' in any content lists + # up until '_', e.g., to match audio_url -> audio + if lora.lora_name in message_types: + default_mm_loras.add(lora) + + # Currently only support default modality specific loras if + # we have exactly one lora matched on the request. + if len(default_mm_loras) == 1: + return default_mm_loras.pop() + return None + + def _get_message_types(self, request: AnyPoolingRequest) -> set[str]: + """Retrieve the set of types from message content dicts up + until `_`; we use this to match potential multimodal data + with default per modality loras. 
+ """ + message_types: set[str] = set() + + if not hasattr(request, "messages"): + return message_types + + messages = request.messages + if messages is None or isinstance(messages, (str, bytes)): + return message_types + + for message in messages: + if ( + isinstance(message, dict) + and "content" in message + and isinstance(message["content"], list) + ): + for content_dict in message["content"]: + if "type" in content_dict: + message_types.add(content_dict["type"].split("_")[0]) + return message_types + + def _log_inputs( + self, + request_id: str, + inputs: PromptType | ProcessorInputs, + params: SamplingParams | PoolingParams | BeamSearchParams | None, + lora_request: LoRARequest | None, + ) -> None: + if self.request_logger is None: + return + + components = extract_prompt_components(self.model_config, inputs) + + self.request_logger.log_inputs( + request_id, + components.text, + components.token_ids, + components.embeds, + params=params, + lora_request=lora_request, + ) diff --git a/vllm/entrypoints/pooling/classify/api_router.py b/vllm/entrypoints/pooling/classify/api_router.py index 8a1513e..0e99a86 100644 --- a/vllm/entrypoints/pooling/classify/api_router.py +++ b/vllm/entrypoints/pooling/classify/api_router.py @@ -3,16 +3,17 @@ from fastapi import APIRouter, Depends, Request from starlette.responses import JSONResponse -from typing_extensions import assert_never -from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.utils import validate_json_request from vllm.entrypoints.pooling.classify.protocol import ( ClassificationRequest, - ClassificationResponse, ) from vllm.entrypoints.pooling.classify.serving import ServingClassification -from vllm.entrypoints.utils import load_aware_call, with_cancellation +from vllm.entrypoints.utils import ( + create_error_response, + load_aware_call, + with_cancellation, +) router = APIRouter() @@ -24,25 +25,17 @@ def classify(request: Request) -> ServingClassification | None: @router.post("/classify", dependencies=[Depends(validate_json_request)]) @with_cancellation @load_aware_call -async def create_classify(request: ClassificationRequest, raw_request: Request): +async def create_classify( + request: ClassificationRequest, raw_request: Request +) -> JSONResponse: handler = classify(raw_request) if handler is None: - base_server = raw_request.app.state.openai_serving_tokenization - return base_server.create_error_response( + error_response = create_error_response( message="The model does not support Classification API" ) - - try: - generator = await handler.create_classify(request, raw_request) - except Exception as e: - generator = handler.create_error_response(e) - - if isinstance(generator, ErrorResponse): return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code + content=error_response.model_dump(), + status_code=error_response.error.code, ) - elif isinstance(generator, ClassificationResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) + return await handler(request, raw_request) diff --git a/vllm/entrypoints/pooling/classify/io_processor.py b/vllm/entrypoints/pooling/classify/io_processor.py new file mode 100644 index 0000000..90d5b0e --- /dev/null +++ b/vllm/entrypoints/pooling/classify/io_processor.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Sequence +from typing import Any + +from vllm import PromptType +from 
vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.entrypoints.pooling.classify.protocol import ( + ClassificationChatRequest, + ClassificationCompletionRequest, +) +from vllm.inputs import ProcessorInputs +from vllm.renderers.inputs import TokPrompt + + +class ClassifyIOProcessor(PoolingIOProcessor): + def pre_process_online( + self, request: ClassificationCompletionRequest | ClassificationChatRequest + ) -> list[TokPrompt] | None: + if isinstance(request, ClassificationChatRequest): + self._validate_chat_template( + request_chat_template=request.chat_template, + chat_template_kwargs=request.chat_template_kwargs, + trust_request_chat_template=self.trust_request_chat_template, + ) + _, engine_prompts = self._preprocess_chat_online( + request, + request.messages, + default_template=self.chat_template, + default_template_content_format=self.chat_template_content_format, + default_template_kwargs=None, + ) + elif isinstance(request, ClassificationCompletionRequest): + engine_prompts = self._preprocess_completion_online( + request, + prompt_input=request.input, + prompt_embeds=None, + ) + else: + raise ValueError("Invalid classification request type") + return engine_prompts + + def pre_process_offline( + self, + prompts: PromptType | Sequence[PromptType], + tokenization_kwargs: dict[str, Any] | None = None, + ) -> Sequence[ProcessorInputs]: + return self._preprocess_completion_offline( + prompts=prompts, tokenization_kwargs=tokenization_kwargs + ) diff --git a/vllm/entrypoints/pooling/classify/protocol.py b/vllm/entrypoints/pooling/classify/protocol.py index 3c4bbd8..bfc38eb 100644 --- a/vllm/entrypoints/pooling/classify/protocol.py +++ b/vllm/entrypoints/pooling/classify/protocol.py @@ -40,7 +40,6 @@ class ClassificationCompletionRequest( def to_pooling_params(self): return PoolingParams( task="classify", - truncate_prompt_tokens=self.truncate_prompt_tokens, use_activation=self.use_activation, ) @@ -63,7 +62,6 @@ class ClassificationChatRequest( def to_pooling_params(self): return PoolingParams( task="classify", - truncate_prompt_tokens=self.truncate_prompt_tokens, use_activation=self.use_activation, ) diff --git a/vllm/entrypoints/pooling/classify/serving.py b/vllm/entrypoints/pooling/classify/serving.py index 8cdbbde..efd4be7 100644 --- a/vllm/entrypoints/pooling/classify/serving.py +++ b/vllm/entrypoints/pooling/classify/serving.py @@ -1,116 +1,57 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Final, TypeAlias +from typing import TypeAlias -import jinja2 import numpy as np -from fastapi import Request -from vllm.engine.protocol import EngineClient -from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption -from vllm.entrypoints.logger import RequestLogger -from vllm.entrypoints.openai.engine.protocol import ErrorResponse, UsageInfo -from vllm.entrypoints.openai.engine.serving import OpenAIServing, ServeContext -from vllm.entrypoints.openai.models.serving import OpenAIServingModels -from vllm.entrypoints.pooling.classify.protocol import ( - ClassificationChatRequest, - ClassificationCompletionRequest, +from vllm import ClassificationOutput +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ChatTemplateConfig +from vllm.entrypoints.openai.engine.protocol import UsageInfo +from vllm.entrypoints.pooling.base.serving import PoolingServeContext, PoolingServing +from vllm.logger import init_logger +from vllm.renderers import BaseRenderer + +from 
.io_processor import ClassifyIOProcessor +from .protocol import ( ClassificationData, ClassificationRequest, ClassificationResponse, ) -from vllm.logger import init_logger -from vllm.outputs import ClassificationOutput logger = init_logger(__name__) -ClassificationServeContext: TypeAlias = ServeContext[ClassificationRequest] +ClassificationServeContext: TypeAlias = PoolingServeContext[ClassificationRequest] -class ServingClassification(OpenAIServing): +class ServingClassification(PoolingServing): request_id_prefix = "classify" - def __init__( + def init_io_processor( self, - engine_client: EngineClient, - models: OpenAIServingModels, - *, - request_logger: RequestLogger | None, - chat_template: str | None = None, - chat_template_content_format: ChatTemplateContentFormatOption = "auto", - trust_request_chat_template: bool = False, - log_error_stack: bool = False, - ) -> None: - super().__init__( - engine_client=engine_client, - models=models, - request_logger=request_logger, - log_error_stack=log_error_stack, + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, + ) -> ClassifyIOProcessor: + return ClassifyIOProcessor( + model_config=model_config, + renderer=renderer, + chat_template_config=chat_template_config, ) - self.chat_template = chat_template - self.chat_template_content_format: Final = chat_template_content_format - self.trust_request_chat_template = trust_request_chat_template - - async def _preprocess( + async def _build_response( self, ctx: ClassificationServeContext, - ) -> ErrorResponse | None: - """ - Process classification inputs: tokenize text, resolve adapters, - and prepare model-specific inputs. - """ - try: - ctx.lora_request = self._maybe_get_adapters(ctx.request) + ) -> ClassificationResponse: + final_res_batch_checked = await self.io_processor.post_process_async( + ctx.final_res_batch + ) - if isinstance(ctx.request, ClassificationChatRequest): - error_check_ret = self._validate_chat_template( - request_chat_template=ctx.request.chat_template, - chat_template_kwargs=ctx.request.chat_template_kwargs, - trust_request_chat_template=self.trust_request_chat_template, - ) - if error_check_ret: - return error_check_ret - - _, ctx.engine_prompts = await self._preprocess_chat( - ctx.request, - ctx.request.messages, - default_template=self.chat_template, - default_template_content_format=self.chat_template_content_format, - default_template_kwargs=None, - ) - elif isinstance(ctx.request, ClassificationCompletionRequest): - ctx.engine_prompts = await self._preprocess_completion( - ctx.request, - prompt_input=ctx.request.input, - prompt_embeds=None, - ) - else: - return self.create_error_response("Invalid classification request type") - - return None - - except (ValueError, TypeError, jinja2.TemplateError) as e: - logger.exception("Error in preprocessing prompt inputs") - return self.create_error_response(str(e)) - - def _build_response( - self, - ctx: ClassificationServeContext, - ) -> ClassificationResponse | ErrorResponse: - """ - Convert model outputs to a formatted classification response - with probabilities and labels. 
- """ id2label = getattr(self.model_config.hf_config, "id2label", {}) - - items: list[ClassificationData] = [] num_prompt_tokens = 0 - - final_res_batch_checked = ctx.final_res_batch - + items: list[ClassificationData] = [] for idx, final_res in enumerate(final_res_batch_checked): classify_res = ClassificationOutput.from_base(final_res.outputs) @@ -141,20 +82,3 @@ class ServingClassification(OpenAIServing): data=items, usage=usage, ) - - async def create_classify( - self, - request: ClassificationRequest, - raw_request: Request, - ) -> ClassificationResponse | ErrorResponse: - model_name = self.models.model_name() - request_id = f"{self.request_id_prefix}-{self._base_request_id(raw_request)}" - - ctx = ClassificationServeContext( - request=request, - raw_request=raw_request, - model_name=model_name, - request_id=request_id, - ) - - return await self.handle(ctx) # type: ignore[return-value] diff --git a/vllm/entrypoints/pooling/embed/protocol.py b/vllm/entrypoints/pooling/embed/protocol.py index 4f83105..4b47c65 100644 --- a/vllm/entrypoints/pooling/embed/protocol.py +++ b/vllm/entrypoints/pooling/embed/protocol.py @@ -14,12 +14,9 @@ from vllm.entrypoints.pooling.base.protocol import ( EmbedRequestMixin, PoolingBasicRequestMixin, ) -from vllm.logger import init_logger from vllm.renderers import TokenizeParams from vllm.utils import random_uuid -logger = init_logger(__name__) - def _get_max_total_output_tokens( model_config: ModelConfig, @@ -60,18 +57,10 @@ class EmbeddingCompletionRequest( ) def to_pooling_params(self): - if self.normalize is not None: - logger.warning_once( - "`normalize` is deprecated and will be removed in v0.17. " - "Please pass `use_activation` instead." - ) - self.use_activation = self.normalize - return PoolingParams( task="embed", dimensions=self.dimensions, use_activation=self.use_activation, - truncate_prompt_tokens=self.truncate_prompt_tokens, ) @@ -97,18 +86,10 @@ class EmbeddingChatRequest( ) def to_pooling_params(self): - if self.normalize is not None: - logger.warning_once( - "`normalize` is deprecated and will be removed in v0.17. " - "Please pass `use_activation` instead." 
- ) - self.use_activation = self.normalize - return PoolingParams( task="embed", dimensions=self.dimensions, use_activation=self.use_activation, - truncate_prompt_tokens=self.truncate_prompt_tokens, ) diff --git a/vllm/entrypoints/pooling/io_processor_factories.py b/vllm/entrypoints/pooling/io_processor_factories.py new file mode 100644 index 0000000..9747676 --- /dev/null +++ b/vllm/entrypoints/pooling/io_processor_factories.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import ChatTemplateConfig +from vllm.entrypoints.pooling.base.io_processor import PoolingIOProcessor +from vllm.renderers import BaseRenderer +from vllm.tasks import SupportedTask + + +def init_pooling_io_processors( + supported_tasks: tuple[SupportedTask, ...], + model_config: ModelConfig, + renderer: BaseRenderer, + chat_template_config: ChatTemplateConfig, +) -> dict[str, PoolingIOProcessor]: + pooling_io_processors: dict[str, PoolingIOProcessor] = {} + + if "classify" in supported_tasks: + from vllm.entrypoints.pooling.classify.io_processor import ( + ClassifyIOProcessor, + ) + + pooling_io_processors["classify"] = ClassifyIOProcessor( + model_config=model_config, + renderer=renderer, + chat_template_config=chat_template_config, + ) + + return pooling_io_processors diff --git a/vllm/entrypoints/pooling/pooling/protocol.py b/vllm/entrypoints/pooling/pooling/protocol.py index a8c1c59..b99f989 100644 --- a/vllm/entrypoints/pooling/pooling/protocol.py +++ b/vllm/entrypoints/pooling/pooling/protocol.py @@ -16,13 +16,10 @@ from vllm.entrypoints.pooling.base.protocol import ( EncodingRequestMixin, PoolingBasicRequestMixin, ) -from vllm.logger import init_logger from vllm.renderers import TokenizeParams from vllm.tasks import PoolingTask from vllm.utils import random_uuid -logger = init_logger(__name__) - class PoolingCompletionRequest( PoolingBasicRequestMixin, @@ -45,16 +42,8 @@ class PoolingCompletionRequest( ) def to_pooling_params(self): - if self.normalize is not None: - logger.warning_once( - "`normalize` is deprecated and will be removed in v0.17. " - "Please pass `use_activation` instead." - ) - self.use_activation = self.normalize - return PoolingParams( task=self.task, - truncate_prompt_tokens=self.truncate_prompt_tokens, use_activation=self.use_activation, dimensions=self.dimensions, ) @@ -78,16 +67,8 @@ class PoolingChatRequest( ) def to_pooling_params(self): - if self.normalize is not None: - logger.warning_once( - "`normalize` is deprecated and will be removed in v0.17. " - "Please pass `use_activation` instead." 
- ) - self.use_activation = self.normalize - return PoolingParams( task=self.task, - truncate_prompt_tokens=self.truncate_prompt_tokens, use_activation=self.use_activation, dimensions=self.dimensions, ) diff --git a/vllm/entrypoints/pooling/score/protocol.py b/vllm/entrypoints/pooling/score/protocol.py index a85ed5d..643eeed 100644 --- a/vllm/entrypoints/pooling/score/protocol.py +++ b/vllm/entrypoints/pooling/score/protocol.py @@ -37,7 +37,6 @@ class ScoreRequestMixin(PoolingBasicRequestMixin, ClassifyRequestMixin): def to_pooling_params(self, task: PoolingTask = "score"): return PoolingParams( task=task, - truncate_prompt_tokens=self.truncate_prompt_tokens, use_activation=self.use_activation, ) @@ -113,7 +112,6 @@ class RerankRequest(PoolingBasicRequestMixin, ClassifyRequestMixin): def to_pooling_params(self, task: PoolingTask = "score"): return PoolingParams( task=task, - truncate_prompt_tokens=self.truncate_prompt_tokens, use_activation=self.use_activation, ) diff --git a/vllm/entrypoints/pooling/score/serving.py b/vllm/entrypoints/pooling/score/serving.py index 3fe18ca..60d6db6 100644 --- a/vllm/entrypoints/pooling/score/serving.py +++ b/vllm/entrypoints/pooling/score/serving.py @@ -31,7 +31,7 @@ from vllm.entrypoints.pooling.score.utils import ( ScoreInputs, _cosine_similarity, compress_token_type_ids, - compute_maxsim_score, + compute_maxsim_scores, get_score_prompt, parse_score_data_single, validate_score_input, @@ -56,6 +56,7 @@ class ServingScores(OpenAIServing): request_logger: RequestLogger | None, score_template: str | None = None, log_error_stack: bool = False, + use_gpu_for_pooling_score: bool = False, ) -> None: super().__init__( engine_client=engine_client, @@ -64,6 +65,7 @@ class ServingScores(OpenAIServing): log_error_stack=log_error_stack, ) self.score_template = score_template + self.use_gpu_for_pooling_score = use_gpu_for_pooling_score self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) @@ -311,19 +313,18 @@ class ServingScores(OpenAIServing): # Compute MaxSim scores from vllm.outputs import PoolingOutput + maxsim_scores = compute_maxsim_scores( + [emb.outputs.data for emb in emb_data_1], + [emb.outputs.data for emb in emb_data_2], + use_gpu_for_pooling_score=self.use_gpu_for_pooling_score, + ) + scores: list[PoolingRequestOutput] = [] padding: list[int] = [] if (pad_token_id := tokenizer.pad_token_id) is not None: padding = [pad_token_id] - for emb_1, emb_2 in zip(emb_data_1, emb_data_2): - # emb_1.outputs.data: [query_len, dim] - # emb_2.outputs.data: [doc_len, dim] - q_emb = emb_1.outputs.data - d_emb = emb_2.outputs.data - - maxsim_score = compute_maxsim_score(q_emb, d_emb) - + for emb_1, emb_2, maxsim_score in zip(emb_data_1, emb_data_2, maxsim_scores): tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids scores.append( diff --git a/vllm/entrypoints/pooling/score/utils.py b/vllm/entrypoints/pooling/score/utils.py index 60e71ff..65611dc 100644 --- a/vllm/entrypoints/pooling/score/utils.py +++ b/vllm/entrypoints/pooling/score/utils.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from typing import Any, TypeAlias, cast import torch @@ -25,6 +25,7 @@ from vllm.inputs.data import PromptType, TextPrompt from vllm.model_executor.models.interfaces import supports_score_template from vllm.multimodal.inputs import MultiModalDataDict, MultiModalUUIDDict from vllm.outputs import 
PoolingRequestOutput +from vllm.platforms import current_platform from vllm.renderers.hf import safe_apply_chat_template from vllm.tokenizers import TokenizerLike @@ -53,6 +54,91 @@ def compute_maxsim_score(q_emb: torch.Tensor, d_emb: torch.Tensor) -> torch.Tens return token_scores.amax(dim=-1).sum() +def _should_use_gpu_for_maxsim(use_gpu_for_pooling_score: bool) -> bool: + return use_gpu_for_pooling_score and not current_platform.is_cpu() + + +def compute_maxsim_scores( + q_embs: Sequence[torch.Tensor], + d_embs: Sequence[torch.Tensor], + max_batch_size: int = 16, + max_score_matrix_elements: int = 16_000_000, + use_gpu_for_pooling_score: bool = False, +) -> list[torch.Tensor]: + """Compute ColBERT MaxSim scores in padded mini-batches.""" + if len(q_embs) != len(d_embs): + raise ValueError("q_embs and d_embs must have the same length") + + num_pairs = len(q_embs) + if num_pairs == 0: + return [] + + for q_emb, d_emb in zip(q_embs, d_embs): + if q_emb.ndim != 2 or d_emb.ndim != 2: + raise ValueError("Each embedding tensor must be 2-D") + if q_emb.shape[1] != d_emb.shape[1]: + raise ValueError("Query and document embeddings must have same dim") + + compute_device = torch.device( + current_platform.device_type + if _should_use_gpu_for_maxsim(use_gpu_for_pooling_score) + else "cpu" + ) + scores: list[torch.Tensor] = [] + start = 0 + while start < num_pairs: + end = min(start + max_batch_size, num_pairs) + max_q = max(int(x.shape[0]) for x in q_embs[start:end]) + max_d = max(int(x.shape[0]) for x in d_embs[start:end]) + + # keep score matrix bounded to avoid oversized allocations. + while ( + end - start > 1 + and (end - start) * max_q * max_d > max_score_matrix_elements + ): + end -= 1 + max_q = max(int(x.shape[0]) for x in q_embs[start:end]) + max_d = max(int(x.shape[0]) for x in d_embs[start:end]) + + batch_q = q_embs[start:end] + batch_d = d_embs[start:end] + batch_size = end - start + dim = int(batch_q[0].shape[1]) + dtype = batch_q[0].dtype + + q_batch = torch.zeros( + (batch_size, max_q, dim), dtype=dtype, device=compute_device + ) + d_batch = torch.zeros( + (batch_size, max_d, dim), dtype=dtype, device=compute_device + ) + q_mask = torch.zeros( + (batch_size, max_q), dtype=torch.bool, device=compute_device + ) + d_mask = torch.zeros( + (batch_size, max_d), dtype=torch.bool, device=compute_device + ) + + # copy to padded tensors + for i, (q_emb, d_emb) in enumerate(zip(batch_q, batch_d)): + q_len = int(q_emb.shape[0]) + d_len = int(d_emb.shape[0]) + q_batch[i, :q_len] = q_emb.to(device=compute_device, dtype=dtype) + d_batch[i, :d_len] = d_emb.to(device=compute_device, dtype=dtype) + q_mask[i, :q_len] = True + d_mask[i, :d_len] = True + + token_scores = torch.bmm(q_batch, d_batch.transpose(1, 2)) + token_scores.masked_fill_(~d_mask.unsqueeze(1), float("-inf")) + max_per_query = token_scores.amax(dim=-1) + max_per_query.masked_fill_(~q_mask, 0) + batch_scores = max_per_query.sum(dim=-1).to("cpu") + scores.extend(batch_scores.unbind(0)) + start = end + + return [cast(torch.Tensor, score) for score in scores] + + class ScoreMultiModalParam(TypedDict, total=False): """ A specialized parameter type for scoring multimodal content diff --git a/vllm/entrypoints/pooling/typing.py b/vllm/entrypoints/pooling/typing.py new file mode 100644 index 0000000..87d6487 --- /dev/null +++ b/vllm/entrypoints/pooling/typing.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import TypeAlias + +from 
vllm.entrypoints.pooling.classify.protocol import ( + ClassificationChatRequest, + ClassificationCompletionRequest, + ClassificationResponse, +) +from vllm.entrypoints.pooling.embed.protocol import ( + EmbeddingBytesResponse, + EmbeddingChatRequest, + EmbeddingCompletionRequest, + EmbeddingResponse, +) +from vllm.entrypoints.pooling.pooling.protocol import ( + IOProcessorRequest, + PoolingChatRequest, + PoolingCompletionRequest, + PoolingResponse, +) +from vllm.entrypoints.pooling.score.protocol import ( + RerankRequest, + ScoreRequest, + ScoreResponse, +) + +PoolingCompletionLikeRequest: TypeAlias = ( + EmbeddingCompletionRequest + | ClassificationCompletionRequest + | RerankRequest + | ScoreRequest + | PoolingCompletionRequest +) + +PoolingChatLikeRequest: TypeAlias = ( + EmbeddingChatRequest | ClassificationChatRequest | PoolingChatRequest +) + +AnyPoolingRequest: TypeAlias = ( + PoolingCompletionLikeRequest | PoolingChatLikeRequest | IOProcessorRequest +) + +AnyPoolingResponse: TypeAlias = ( + ClassificationResponse + | EmbeddingResponse + | EmbeddingBytesResponse + | PoolingResponse + | ScoreResponse +) diff --git a/vllm/entrypoints/sagemaker/api_router.py b/vllm/entrypoints/sagemaker/api_router.py index 1138225..32faaa0 100644 --- a/vllm/entrypoints/sagemaker/api_router.py +++ b/vllm/entrypoints/sagemaker/api_router.py @@ -13,6 +13,7 @@ from fastapi.responses import JSONResponse, Response from vllm.entrypoints.openai.engine.protocol import ErrorResponse from vllm.entrypoints.openai.engine.serving import OpenAIServing from vllm.entrypoints.openai.utils import validate_json_request +from vllm.entrypoints.pooling.base.serving import PoolingServing from vllm.entrypoints.serve.instrumentator.basic import base from vllm.entrypoints.serve.instrumentator.health import health from vllm.tasks import POOLING_TASKS, SupportedTask @@ -20,7 +21,7 @@ from vllm.tasks import POOLING_TASKS, SupportedTask # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) RequestType = Any -GetHandlerFn = Callable[[Request], OpenAIServing | None] +GetHandlerFn = Callable[[Request], OpenAIServing | PoolingServing | None] EndpointFn = Callable[[RequestType, Request], Awaitable[Any]] diff --git a/vllm/entrypoints/utils.py b/vllm/entrypoints/utils.py index 34df85f..6390a72 100644 --- a/vllm/entrypoints/utils.py +++ b/vllm/entrypoints/utils.py @@ -5,7 +5,10 @@ import asyncio import dataclasses import functools import os +import sys +import traceback from argparse import Namespace +from http import HTTPStatus from logging import Logger from string import Template from typing import TYPE_CHECKING @@ -17,17 +20,23 @@ from starlette.background import BackgroundTask, BackgroundTasks from vllm import envs from vllm.engine.arg_utils import EngineArgs +from vllm.exceptions import VLLMValidationError from vllm.logger import current_formatter_type, init_logger from vllm.platforms import current_platform from vllm.utils.argparse_utils import FlexibleArgumentParser if TYPE_CHECKING: - from vllm.entrypoints.openai.engine.protocol import StreamOptions + from vllm.entrypoints.openai.engine.protocol import ( + ErrorInfo, + ErrorResponse, + StreamOptions, + ) from vllm.entrypoints.openai.models.protocol import LoRAModulePath else: - StreamOptions = object + ErrorResponse = object + ErrorInfo = object LoRAModulePath = object - + StreamOptions = object logger = init_logger(__name__) @@ -291,3 +300,59 @@ def log_version_and_model(lgr: Logger, version: str, model_name: str) -> None: 
message = logo_template.substitute(colors) lgr.info(message, version, model_name) + + +def create_error_response( + message: str | Exception, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST, + param: str | None = None, + log_error_stack: bool = False, +) -> "ErrorResponse": + exc: Exception | None = None + + from vllm.entrypoints.openai.engine.protocol import ErrorInfo, ErrorResponse + + if isinstance(message, Exception): + exc = message + + if isinstance(exc, VLLMValidationError): + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = exc.parameter + elif isinstance(exc, (ValueError, TypeError, RuntimeError, OverflowError)): + # Common validation errors from user input + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + elif isinstance(exc, NotImplementedError): + err_type = "NotImplementedError" + status_code = HTTPStatus.NOT_IMPLEMENTED + param = None + elif exc.__class__.__name__ == "TemplateError": + # jinja2.TemplateError (avoid importing jinja2) + err_type = "BadRequestError" + status_code = HTTPStatus.BAD_REQUEST + param = None + else: + err_type = "InternalServerError" + status_code = HTTPStatus.INTERNAL_SERVER_ERROR + param = None + + message = str(exc) + + if log_error_stack: + exc_type, _, _ = sys.exc_info() + if exc_type is not None: + traceback.print_exc() + else: + traceback.print_stack() + + return ErrorResponse( + error=ErrorInfo( + message=sanitize_message(message), + type=err_type, + code=status_code.value, + param=param, + ) + ) diff --git a/vllm/env_override.py b/vllm/env_override.py index 181d000..2799221 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -482,3 +482,44 @@ if is_torch_equal("2.9.0"): PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched GraphLowering._update_scheduler = _update_scheduler_patched + +# =================================================== +# torch 2.11 Inductor constrain_to_fx_strides monkeypatch +# =================================================== +# Patch the inductor's `constrain_to_fx_strides` to handle opaque +# (non-tensor) arguments. The original calls `.stride()` on every FX +# arg's meta value, which crashes on FakeScriptObject (the compile-time +# proxy for hoisted opaque types). The patched version skips args +# whose meta value is not a torch.Tensor. 
+# Upstream issue: https://github.com/pytorch/pytorch/issues/175973 + +from vllm.utils.torch_utils import is_torch_equal_or_newer + +if is_torch_equal_or_newer("2.11.0.dev"): + import torch._inductor.ir as _ir + import torch._inductor.lowering as _lowering + from torch._inductor.virtualized import V as _V + + _orig_constrain = _lowering.constrain_to_fx_strides + + def _patched_constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, _ir.IRNode): + meta_val = fx_arg.meta.get("val") + if isinstance(meta_val, torch.Tensor): + stride_order = _ir.get_stride_order( + meta_val.stride(), _V.graph.sizevars.shape_env + ) + return _ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg} + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + _lowering.constrain_to_fx_strides = _patched_constrain_to_fx_strides diff --git a/vllm/envs.py b/vllm/envs.py index be3a568..9f1f5d8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -26,6 +26,9 @@ if TYPE_CHECKING: VLLM_ENGINE_READY_TIMEOUT_S: int = 600 VLLM_API_KEY: str | None = None VLLM_DEBUG_LOG_API_SERVER_RESPONSE: bool = False + VLLM_ENABLE_PP_MIX_ILU_SCHEDULING: bool = False + VLLM_ENABLE_PP_ILU_OPT: bool = False + VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE: int = 0 S3_ACCESS_KEY_ID: str | None = None S3_SECRET_ACCESS_KEY: str | None = None S3_ENDPOINT_URL: str | None = None @@ -35,7 +38,7 @@ if TYPE_CHECKING: VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" VLLM_NO_USAGE_STATS: bool = False VLLM_DO_NOT_TRACK: bool = False - VLLM_USAGE_SOURCE: str = "" + VLLM_USAGE_SOURCE: str = "production" VLLM_CONFIGURE_LOGGING: bool = True VLLM_LOGGING_LEVEL: str = "INFO" VLLM_LOGGING_PREFIX: str = "" @@ -48,7 +51,7 @@ if TYPE_CHECKING: VLLM_USE_FLASHINFER_SAMPLER: bool | None = None VLLM_PP_LAYER_PARTITION: str | None = None VLLM_CPU_KVCACHE_SPACE: int | None = 0 - VLLM_CPU_OMP_THREADS_BIND: str = "" + VLLM_CPU_OMP_THREADS_BIND: str = "auto" VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_SGL_KERNEL: bool = False VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") @@ -89,13 +92,14 @@ if TYPE_CHECKING: VLLM_LORA_RESOLVER_CACHE_DIR: str | None = None VLLM_LORA_RESOLVER_HF_REPO_LIST: str | None = None VLLM_USE_AOT_COMPILE: bool = False - VLLM_USE_BYTECODE_HOOK: bool = False + VLLM_USE_BYTECODE_HOOK: bool = True VLLM_FORCE_AOT_LOAD: bool = False VLLM_USE_MEGA_AOT_ARTIFACT: bool = False VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: list[str] = [] + VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True VLLM_DISABLE_PYNCCL: bool = False VLLM_USE_OINK_OPS: bool = False VLLM_ROCM_USE_AITER: bool = False @@ -106,7 +110,7 @@ if TYPE_CHECKING: VLLM_ROCM_USE_AITER_MLA: bool = True VLLM_ROCM_USE_AITER_MHA: bool = True VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False - VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = True + VLLM_ROCM_USE_AITER_TRITON_ROPE: bool = False VLLM_ROCM_USE_AITER_FP8BMM: bool = True VLLM_ROCM_USE_AITER_FP4BMM: bool = True VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION: bool = False @@ -168,7 +172,7 @@ if TYPE_CHECKING: VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = ( "latency" ) - VLLM_FLASHINFER_ALLREDUCE_BACKEND: 
Literal["auto", "trtllm", "mnnvl"] = "auto" + VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "trtllm" VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024 VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 @@ -231,7 +235,7 @@ if TYPE_CHECKING: VLLM_USE_FBGEMM: bool = False VLLM_GC_DEBUG: str = "" VLLM_DEBUG_WORKSPACE: bool = False - VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = False + VLLM_DISABLE_SHARED_EXPERTS_STREAM: bool = True VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False @@ -243,17 +247,33 @@ if TYPE_CHECKING: VLLM_LORA_DISABLE_PDL: bool = False VLLM_ENABLE_CUDA_COMPATIBILITY: bool = False VLLM_CUDA_COMPATIBILITY_PATH: str | None = None + VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False + VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False + VLLM_FLAT_LOGPROBS: bool = False # optional envs we add. VLLM_W8A8_MOE_USE_W4A8: bool = False + VLLM_WNA16_MOE_USE_W4A8: bool = False + VLLM_W8A8_FORMAT: str = "TN" VLLM_W4A8_FORMAT: str = "TN" VLLM_W4A8_VERSION: int = 2 VLLM_MIX_QUANTIZATION_TYPE: str = "" VLLM_MLA_CUSTOMIZE: bool = True VLLM_USE_INT8_MLA: bool = False - VLLM_W8A8_LINEAR_USE_W4A8: bool = False - VLLM_FORCE_NCCL_COMM: bool =False - VLLM_KV_DISABLE_CROSS_GROUP_SHARE: bool = False + # support Iluvatar IxServer + VLLM_ATTN_OPT_LEVEL: int = 0 + VLLM_MOE_OPT_LEVEL: int = 0 + VLLM_LINEAR_OPT_LEVEL: int = 0 + VLLM_OPT_EXCLUDE_LAYERS: str = "" + VLLM_LINEAR_SPECIFIED_LAYERS: str = "" + VLLM_LINEAR_SPECIFIED_KEYS: str = "" + VLLM_LINEAR_SPECIFIED_OPT_LEVEL: int = 0 + + VLLM_USE_LORA_FUSION: bool = False + + VLLM_USE_SILU_QUANT_FUSION: bool = False + # static quant for attention + VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH: str = "" def get_default_cache_root(): return os.getenv( @@ -478,6 +498,8 @@ def get_env_or_set_default( logger = logging.getLogger(__name__) +IGNORED_UNKNOWN_VARS = {"VLLM_ENFORCE_CUDA_GRAPH"} + environment_variables: dict[str, Callable[[], Any]] = { # ================== Installation Time Env Vars ================== # Target device of vLLM, supporting [cuda (by default), @@ -648,6 +670,22 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" ).lower() == "true", + # When set to 1, scheduler.schedule() delegates to schedule_opt() + # for PP mix ILU scheduling. + "VLLM_ENABLE_PP_MIX_ILU_SCHEDULING": lambda: os.environ.get( + "VLLM_ENABLE_PP_MIX_ILU_SCHEDULING", "0" + ) == "1", + # When set to 1, use step_with_batch_queue_ilu_opt (async batch queue with + # background thread). Batch queue size can be controlled via + # VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE. + "VLLM_ENABLE_PP_ILU_OPT": lambda: os.environ.get( + "VLLM_ENABLE_PP_ILU_OPT", "0" + ) == "1", + # Batch queue size used when VLLM_ENABLE_PP_ILU_OPT=1. + # If <= 0, EngineCore falls back to base_batch_queue_size * 2. 
+ "VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE": lambda: int( + os.environ.get("VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE", "0") + ), # S3 access information, used for tensorizer to load model from S3 "S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None), "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), @@ -862,7 +900,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Time in ms for the zmq client to wait for a response from the backend # server for simple data operations - "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), + "VLLM_RPC_TIMEOUT": lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "1000000")), # Timeout in seconds for keeping HTTP connections alive in API server "VLLM_HTTP_TIMEOUT_KEEP_ALIVE": lambda: int( os.environ.get("VLLM_HTTP_TIMEOUT_KEEP_ALIVE", "5") @@ -907,6 +945,9 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DISABLED_KERNELS": lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ["VLLM_DISABLED_KERNELS"].split(","), + "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool( + int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1")) + ), # Disable pynccl (using torch.distributed instead) "VLLM_DISABLE_PYNCCL": lambda: ( os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1") @@ -957,9 +998,9 @@ environment_variables: dict[str, Callable[[], Any]] = { os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in ("true", "1") ), # Whether to use aiter rope. - # By default is enabled. + # By default is disabled. "VLLM_ROCM_USE_AITER_TRITON_ROPE": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "True").lower() in ("true", "1") + os.getenv("VLLM_ROCM_USE_AITER_TRITON_ROPE", "False").lower() in ("true", "1") ), # Whether to use aiter triton fp8 bmm kernel # By default is enabled. @@ -1305,9 +1346,12 @@ environment_variables: dict[str, Callable[[], Any]] = { # Flashinfer fused allreduce backend. # "auto" will default to "mnnvl", which performs mostly same/better than "trtllm". # But "mnnvl" backend does not support fuse with quantization. + # TODO: Default is "trtllm" right now because "mnnvl" has issues with cudagraph: + # https://github.com/vllm-project/vllm/issues/35772 + # Should switch back to "auto" if the issue is resolved. "VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices( "VLLM_FLASHINFER_ALLREDUCE_BACKEND", - "auto", + "trtllm", ["auto", "trtllm", "mnnvl"], ), # Control the workspace buffer size for the FlashInfer backend. @@ -1478,41 +1522,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ), # Allows vllm to find tuned config under customized folder "VLLM_TUNED_CONFIG_FOLDER": lambda: os.getenv("VLLM_TUNED_CONFIG_FOLDER", None), - # vLLM do not support W4A8 and W8A8, we add it. For MOE, we default use W8A8, If set to true, we use W4A8. - - "VLLM_W8A8_LINEAR_USE_W4A8": - lambda: os.environ.get("VLLM_W8A8_LINEAR_USE_W4A8", "0").lower() in - ("1", "true"), - - "VLLM_W8A8_MOE_USE_W4A8": - lambda: os.environ.get("VLLM_W8A8_MOE_USE_W4A8", "0").lower() in - ("1", "true"), - - "VLLM_FORCE_NCCL_COMM": - lambda: os.environ.get("VLLM_FORCE_NCCL_COMM", "0").lower() in - ("1", "true"), - # If set to true, we use int8 mla attention for decode stage. - - "VLLM_USE_INT8_MLA": - lambda: os.environ.get("VLLM_USE_INT8_MLA", "0").lower() in - ("1", "true"), - - # For W4A8 MOE, we default use TN gemm format, choices: [TN, NN]. 
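A small sketch (not part of the patch; the helper name and base_batch_queue_size argument are hypothetical) of how the two PP ILU envs above combine, following the documented fallback when the override is <= 0:

    import vllm.envs as envs

    def resolve_pp_ilu_batch_queue_size(base_batch_queue_size: int) -> int:
        # Hypothetical helper illustrating only the behaviour documented above.
        if not envs.VLLM_ENABLE_PP_ILU_OPT:
            return base_batch_queue_size
        override = envs.VLLM_PP_ILU_OPT_BATCH_QUEUE_SIZE
        # <= 0 means "unset": fall back to base_batch_queue_size * 2.
        return override if override > 0 else base_batch_queue_size * 2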
- "VLLM_W4A8_FORMAT": - lambda: os.environ.get("VLLM_W4A8_FORMAT", "TN").upper(), - - "VLLM_W4A8_VERSION": - # For W4A8 MOE, we default use version 2, choices: [1, 2]. - lambda: int(os.environ.get("VLLM_W4A8_VERSION", "2")), - - # temp param to support compressed-tensor's multi-quantization - "VLLM_MIX_QUANTIZATION_TYPE": - lambda: os.environ.get("VLLM_MIX_QUANTIZATION_TYPE", "").upper(), - - # Use Customize mlp impl for faster speed and less gpu memory usage. - "VLLM_MLA_CUSTOMIZE": - lambda: os.environ.get("VLLM_MLA_CUSTOMIZE", "1").lower() in - ("1", "true"), # Valid values are container,code_interpreter,web_search_preview # ex VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS=container,code_interpreter # If the server_label of your mcp tool is not in this list it will @@ -1607,7 +1616,7 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DEBUG_WORKSPACE": lambda: bool(int(os.getenv("VLLM_DEBUG_WORKSPACE", "0"))), # Disables parallel execution of shared_experts via separate cuda stream "VLLM_DISABLE_SHARED_EXPERTS_STREAM": lambda: bool( - int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "0")) + int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1")) ), # Limits when we run shared_experts in a separate stream. # We found out that for large batch sizes, the separate stream @@ -1662,10 +1671,93 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_CUDA_COMPATIBILITY_PATH": lambda: os.environ.get( "VLLM_CUDA_COMPATIBILITY_PATH", None ), - # control kv cache share cross group - "VLLM_KV_DISABLE_CROSS_GROUP_SHARE": - lambda: os.environ.get("VLLM_KV_DISABLE_CROSS_GROUP_SHARE", "0").lower() in - ("1", "true") + # Whether it is a scale up launch engine for elastic EP, + # Should only be set by EngineCoreClient. + "VLLM_ELASTIC_EP_SCALE_UP_LAUNCH": lambda: bool( + int(os.getenv("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH", "0")) + ), + # Whether to wait for all requests to drain before sending the + # scaling command in elastic EP. + "VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool( + int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0")) + ), + # Flag to enable FlatLogprobs whose GC overhead is significantly smaller than + # the original list[dict[int, Logprob]] approach. + # After enabled, PromptLogprobs and SampleLogprobs would populated as + # FlatLogprobs. + "VLLM_FLAT_LOGPROBS": lambda: bool(int(os.getenv("VLLM_FLAT_LOGPROBS", "0"))), + + # vLLM do not support W4A8 and W8A8, we add it. For MOE, we default use W8A8, If set to true, we use W4A8. + + "VLLM_W8A8_MOE_USE_W4A8": + lambda: os.environ.get("VLLM_W8A8_MOE_USE_W4A8", "0").lower() in + ("1", "true"), + + "VLLM_WNA16_MOE_USE_W4A8": + lambda: os.environ.get("VLLM_WNA16_MOE_USE_W4A8", "0").lower() in + ("1", "true"), + + # If set to true, we use int8 mla attention for decode stage. + + "VLLM_USE_INT8_MLA": + lambda: os.environ.get("VLLM_USE_INT8_MLA", "0").lower() in + ("1", "true"), + + # For attn, 0 for f16qkv, 1 for i8qkv, 2 for i8qkf16v + "VLLM_ATTN_OPT_LEVEL": + lambda: int(os.environ.get("VLLM_ATTN_OPT_LEVEL", "0")), + + # For W8A8 MOE, we default use TN gemm format, choices: [TN, NN].However, GEMV(Cuda Graph) only supports NN. + "VLLM_W8A8_FORMAT": + lambda: os.environ.get("VLLM_W8A8_FORMAT", "TN").upper(), + + # For W4A8 MOE, we default use TN gemm format, choices: [TN, NN]. + "VLLM_W4A8_FORMAT": + lambda: os.environ.get("VLLM_W4A8_FORMAT", "TN").upper(), + + "VLLM_W4A8_VERSION": + # For W4A8 MOE, we default use version 2, choices: [1, 2]. 
+ lambda: int(os.environ.get("VLLM_W4A8_VERSION", "2")), + + # temp param to support compressed-tensor's multi-quantization + "VLLM_MIX_QUANTIZATION_TYPE": + lambda: os.environ.get("VLLM_MIX_QUANTIZATION_TYPE", "").upper(), + + # Use Customize mlp impl for faster speed and less gpu memory usage. + "VLLM_MLA_CUSTOMIZE": + lambda: os.environ.get("VLLM_MLA_CUSTOMIZE", "1").lower() in + ("1", "true"), + + # support Iluvatar IxServer + # Does vLLM support Iluvatar IxServer which is a distributed inference framework. + "VLLM_MOE_OPT_LEVEL": + lambda: int(os.getenv("VLLM_MOE_OPT_LEVEL", "0")), + + "VLLM_LINEAR_OPT_LEVEL": + lambda: int(os.getenv("VLLM_LINEAR_OPT_LEVEL", "0")), + + "VLLM_OPT_EXCLUDE_LAYERS": + lambda: os.environ.get("VLLM_OPT_EXCLUDE_LAYERS", "").upper(), + + "VLLM_LINEAR_SPECIFIED_LAYERS": + lambda: os.environ.get("VLLM_LINEAR_SPECIFIED_LAYERS", "").upper(), + + "VLLM_LINEAR_SPECIFIED_KEYS": + lambda: os.environ.get("VLLM_LINEAR_SPECIFIED_KEYS", "").lower(), + + "VLLM_LINEAR_SPECIFIED_OPT_LEVEL": + lambda: int(os.getenv("VLLM_LINEAR_SPECIFIED_OPT_LEVEL", "0")), + + "VLLM_USE_LORA_FUSION": + lambda: os.environ.get("VLLM_USE_LORA_FUSION", "0").lower() in + ("1", "true"), + + "VLLM_USE_SILU_QUANT_FUSION": + lambda: os.environ.get("VLLM_USE_SILU_QUANT_FUSION", "0").lower() in + ("1", "true"), + + "VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH": + lambda: os.environ.get("VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH", ""), } @@ -1738,6 +1830,8 @@ def is_set(name: str): def validate_environ(hard_fail: bool) -> None: for env in os.environ: if env.startswith("VLLM_") and env not in environment_variables: + if env in IGNORED_UNKNOWN_VARS: + continue if hard_fail: raise ValueError(f"Unknown vLLM environment variable detected: {env}") else: @@ -1801,10 +1895,7 @@ def compile_factors() -> dict[str, object]: "VLLM_ENABLE_V1_MULTIPROCESSING", "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "VLLM_CPU_KVCACHE_SPACE", - "VLLM_CPU_OMP_THREADS_BIND", - "VLLM_CPU_NUM_OF_RESERVED_CPU", "VLLM_CPU_MOE_PREPACK", - "VLLM_CPU_SGL_KERNEL", "VLLM_TEST_FORCE_LOAD_FORMAT", "VLLM_ENABLE_CUDA_COMPATIBILITY", "VLLM_CUDA_COMPATIBILITY_PATH", diff --git a/vllm/forward_context.py b/vllm/forward_context.py index a0753b1..15e3263 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -241,7 +241,7 @@ class ForwardContext: additional_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): - assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( + assert self.cudagraph_runtime_mode.is_valid_runtime_mode(), ( f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" ) @@ -347,7 +347,6 @@ def set_forward_context( num_tokens_unpadded=num_tokens, parallel_config=vllm_config.parallel_config, allow_microbatching=False, - allow_dp_padding=False, ) assert num_tokens_across_dp is not None dp_metadata = DPMetadata.make( diff --git a/vllm/kernels/helion/register.py b/vllm/kernels/helion/register.py index 3114631..cd0ef83 100644 --- a/vllm/kernels/helion/register.py +++ b/vllm/kernels/helion/register.py @@ -31,8 +31,8 @@ by key matches the config returned by the autotuner. 
Key Classes ----------- -- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops -- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op +- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels +- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs - PresetConfigSearch: Custom autotuner that returns pre-tuned configs """ @@ -53,10 +53,27 @@ if not has_helion(): ) import helion +from helion._compat import requires_torch_version from helion.autotuner.base_search import BaseAutotuner from helion.runtime.config import Config from helion.runtime.settings import default_autotuner_fn +# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op, +# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11. +_HOP_AVAILABLE = requires_torch_version("2.11") + +if _HOP_AVAILABLE: + import torch.utils._pytree as pytree + from helion._compiler._dynamo.higher_order_ops import ( + helion_kernel_side_table, + helion_kernel_wrapper_mutation, + ) + from helion._compiler._dynamo.variables import infer_output_spec + from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + get_proxy_mode, + ) + logger = init_logger(__name__) vllm_helion_lib = Library("vllm_helion", "FRAGMENT") # noqa @@ -233,7 +250,7 @@ class ConfiguredHelionKernel: class HelionKernelWrapper: - """Wrapper for Helion kernels that creates config-specific PyTorch custom ops.""" + """Wrapper for Helion kernels with pre-tuned config selection and HOP support.""" def __init__( self, @@ -252,11 +269,86 @@ class HelionKernelWrapper: self._config_picker: ( Callable[[tuple[Any, ...], list[str]], str | None] | None ) = None + self._configured_kernel: ConfiguredHelionKernel | None = None self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None def __call__(self, *args, **kwargs): - configured_op = self.get_configured_op() - return configured_op(*args, **kwargs) + # CustomOp fallback: register as torch custom op for torch.compile + # compatibility on older PyTorch lacking HOP/EffectType support + if not _HOP_AVAILABLE: + custom_op = self._get_or_register_custom_op() + return custom_op(*args, **kwargs) + # HOP tracing: record HigherOrderOp in the FX graph + if get_proxy_mode() is not None: + return self._call_via_hop(args, kwargs) + # Eager: run the configured kernel directly + return self.get_configured_op()(*args, **kwargs) + + def _call_via_hop( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + kernel = self.get_configured_op()._decorated_kernel + kernel_idx = helion_kernel_side_table.add_kernel(kernel) + + constant_args, tensor_args = self._partition_args(kernel, args, kwargs) + + all_named = {**constant_args, **tensor_args} + full_args = tuple( + all_named.get(n, p.default) + for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined] + if n in all_named or p.default is not p.empty + ) + + with disable_proxy_modes_tracing(): + output_spec = infer_output_spec(kernel, full_args) + + hop_result = helion_kernel_wrapper_mutation( + kernel_idx=kernel_idx, + constant_args=constant_args, + tensor_args=tensor_args, + output_spec=output_spec, + ) + + tree_spec_str = output_spec.get("tree_spec_str") + if tree_spec_str is None: + return None + tree_spec = pytree.treespec_loads(tree_spec_str) + + hop_iter = iter(hop_result) + reconstructed = [] + for spec in output_spec["leaf_specs"]: + is_constant_scalar = spec["type"] == "scalar" and not 
isinstance( + spec.get("scalar_value"), torch.SymInt + ) + if is_constant_scalar: + reconstructed.append(spec["scalar_value"]) + else: + reconstructed.append(next(hop_iter)) + return pytree.tree_unflatten(reconstructed, tree_spec) + + @staticmethod + def _partition_args( + kernel: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any]]: + constant_args: dict[str, Any] = {} + tensor_args: dict[str, Any] = {} + params = list(kernel.signature.parameters.keys()) + for i, val in enumerate(args): + name = params[i] + if isinstance(val, torch.Tensor): + tensor_args[name] = val + else: + constant_args[name] = val + for name, val in kwargs.items(): + if isinstance(val, torch.Tensor): + tensor_args[name] = val + else: + constant_args[name] = val + return constant_args, tensor_args def register_config_picker( self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None] @@ -309,29 +401,32 @@ class HelionKernelWrapper: ) return autotune_kernel.autotune(inputs) - def get_configured_op(self) -> Any: + def get_configured_op(self) -> ConfiguredHelionKernel: assert self._config_picker is not None, ( f"No config picker registered for kernel '{self.op_name}'. " f"Use @{self.op_name}.register_config_picker to register one." ) + if self._configured_kernel is None: + self._configured_kernel = ConfiguredHelionKernel( + op_name=self.op_name, + config_picker=self._config_picker, + raw_kernel_func=self.raw_kernel_func, + helion_settings=self.helion_settings, + ) + + return self._configured_kernel + + def _get_or_register_custom_op(self) -> Any: if hasattr(torch.ops.vllm_helion, self.op_name): - logger.debug("Op vllm_helion::%s already registered", self.op_name) return getattr(torch.ops.vllm_helion, self.op_name) - configured_kernel = ConfiguredHelionKernel( - op_name=self.op_name, - config_picker=self._config_picker, - raw_kernel_func=self.raw_kernel_func, - helion_settings=self.helion_settings, - ) + configured_kernel = self.get_configured_op() logger.info("Registering op: vllm_helion::%s", self.op_name) direct_register_custom_op( op_name=self.op_name, - op_func=configured_kernel._decorated_kernel, # Register decorated kernel - # TODO(gmagogsfm): Implement automatic mutation/aliasing detection - # for Helion kernels. 
+ op_func=configured_kernel._decorated_kernel, mutates_args=None, fake_impl=self._fake_impl, target_lib=vllm_helion_lib, diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index e08dcc8..eff05b5 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -32,10 +32,10 @@ from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( UnfusedOAITritonExperts, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel, + FusedMoEKernel, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, ) from .utils import _get_lora_device, try_get_optimal_moe_lora_config @@ -83,7 +83,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): ): if envs.VLLM_TUNED_CONFIG_FOLDER: hidden_size = layer.hidden_size - intermediate_size = layer.intermediate_size_per_partition + intermediate_size = ( + self.w2_lora_a_stacked[0].shape[-1] + if op_prefix == "w2" + else self.w13_lora_b_stacked[0].shape[-2] + ) shrink_config = get_lora_op_configs( op_type=f"fused_moe_lora_{op_prefix}_shrink", max_loras=num_loras, @@ -132,7 +136,7 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): if getattr(self.base_layer.quant_method, "supports_internal_mk", False): # Use the existing modular kernel from the quant method - m_fused_moe_fn = self.base_layer.quant_method.moe_mk + m_fused_moe_fn = self.base_layer.quant_method.moe_kernel # Don't let the kernel own shared experts so the runner can # overlap them with routed experts via a separate CUDA stream. m_fused_moe_fn.shared_experts = None @@ -140,8 +144,8 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): # Create a new modular kernel via select_gemm_impl. # Don't pass shared_experts to the kernel so the runner can # overlap them with routed experts via a separate CUDA stream. 
- prepare_finalize = MoEPrepareAndFinalizeNoEP() - m_fused_moe_fn = FusedMoEModularKernel( + prepare_finalize = MoEPrepareAndFinalizeNoDPEPModular() + m_fused_moe_fn = FusedMoEKernel( prepare_finalize, self.base_layer.quant_method.select_gemm_impl( prepare_finalize, self.base_layer @@ -150,10 +154,11 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): if quant_config.use_mxfp4_w4a16: assert isinstance( - m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts) + m_fused_moe_fn.impl.fused_experts, + (MarlinExperts, UnfusedOAITritonExperts), ) else: - assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts) + assert isinstance(m_fused_moe_fn.impl.fused_experts, TritonExperts) def fwd_decorator(layer, func): def wrapper(*args, **kwargs): @@ -333,9 +338,9 @@ class FusedMoEWithLoRA(BaseLayerWithLoRA): return wrapper - fused_experts = m_fused_moe_fn.fused_experts + fused_experts = m_fused_moe_fn.impl.fused_experts - m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward) + m_fused_moe_fn.apply = fwd_decorator(self.base_layer, m_fused_moe_fn.apply) fused_experts.activation = act_decorator( self.base_layer, fused_experts.activation ) diff --git a/vllm/lora/layers/logits_processor.py b/vllm/lora/layers/logits_processor.py index 217c46f..237a61e 100644 --- a/vllm/lora/layers/logits_processor.py +++ b/vllm/lora/layers/logits_processor.py @@ -88,10 +88,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): model_config: PretrainedConfig | None = None, ) -> None: # TODO: Verify if this condition can be further relaxed - if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048: - raise ValueError( - "When using LoRA, vocab size must be > 32000 and <= 258048" - ) + if self.base_layer.vocab_size > 258048: + raise ValueError("When using LoRA, vocab size must be <= 258048") self.lora_a_stacked = torch.zeros( ( max_loras, diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index c9c85c1..7fc49d8 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -8,9 +8,10 @@ from vllm.distributed import ( tensor_model_parallel_all_reduce, ) from vllm.triton_utils import tl, triton +from vllm.triton_utils.allocation import set_triton_allocator from vllm.utils.torch_utils import direct_register_custom_op -from .utils import supports_pdl +from .utils import supports_pdl, supports_tma @triton.jit @@ -70,6 +71,37 @@ def _get_token_offs( ) +@triton.jit +def _get_c_ptrs( + cur_c_ptr, + lora_id, + pid_m, + offs, + offs_token, + offs_cn, + stride_cm, + stride_cn, + EM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + sort_c: tl.constexpr, +): + # When sort_c is true, store the output in c_ptr using token order defined + # in sorted_token_ids_ptr; otherwise, use the original token order from the prompt + if sort_c: + offs_token_id = pid_m * BLOCK_SIZE_M + offs + c_ptrs = ( + cur_c_ptr + + lora_id * EM * stride_cm + + stride_cm * offs_token_id[:, None] + + stride_cn * offs_cn[None, :] + ) + else: + c_ptrs = ( + cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + ) + return c_ptrs + + _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} @@ -95,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): def _adjust_kernel_inputs( - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs sorted_token_ids: torch.Tensor | None, expert_ids: torch.Tensor, ): @@ -109,7 +141,7 @@ 
def _adjust_kernel_inputs( else: stride_tl = sorted_token_ids.stride(0) stride_el = expert_ids.stride(0) - grid_lora_dim = num_active_loras + grid_lora_dim = num_active_loras.item() return grid_lora_dim, stride_tl, stride_el @@ -125,7 +157,9 @@ def _adjust_kernel_inputs( ) def _fused_moe_lora_kernel( a_ptr, + a_desc, b_ptr, + b_desc, c_ptr, topk_weights_ptr, sorted_token_ids_ptr, @@ -177,6 +211,18 @@ def _fused_moe_lora_kernel( USE_GDC: tl.constexpr, launch_pdl: tl.constexpr, IS_PRIMARY: tl.constexpr, + USE_TMA: tl.constexpr, + # sort_c determines whether tokens are stored in C in the order determined + # by sorted_token_ids to enable later TMA loads from this tensor. + # + # When USE_TMA is enabled, the parameter combinations are: + # a_desc | b_desc | sort_c | Use Case + # --------|---------|--------|----------------------------- + # yes | yes | False | expand kernel (num_slices=1) + # no | yes | True | shrink kernel (num_slices=1) + # yes | no | False | expand kernel (num_slices>1) + # no | no | True | shrink kernel (num_slices>1) + sort_c: tl.constexpr, ): pid = tl.program_id(axis=0) slice_id = tl.program_id(axis=1) @@ -250,58 +296,90 @@ def _fused_moe_lora_kernel( cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size - # remove modulo wrap-around - offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32) offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) token_mask = offs_token < num_valid_tokens - # get a_ptrs,b_ptrs - a_ptrs = cur_a_ptr + ( - offs_token[:, None] // token_mapping_factor * stride_am - + offs_k[None, :] * stride_ak - ) + if USE_TMA and a_desc is not None: + # Expand path - with TMA enabled, load from A using TMA descriptor + offs_am = ( + slice_id * max_loras * EM + + lora_id * EM + + pid_m * BLOCK_SIZE_M // token_mapping_factor + ) + offs_ak = pid_sk * BLOCK_SIZE_K + else: + # Shrink path - load hidden states based on order defined in + # 'sorted_token_ids_ptr' then store them in c_ptr in this same sorted order + tl.static_assert(a_desc is None, "a_desc must be none") + a_ptrs = cur_a_ptr + ( + offs_token[:, None] // token_mapping_factor * stride_am + + offs_k[None, :] * stride_ak + ) - b_ptrs = ( - cur_b_ptr - + lora_id * stride_bl - + expert_id * stride_be - + offs_k[:, None] * stride_bk - + offs_bn[None, :] * stride_bn - ) + if USE_TMA: + offs_bn = pid_n * BLOCK_SIZE_N + offs_bk = pid_sk * BLOCK_SIZE_K + if b_desc is None: + # Note(@gnovack) - Allocation of TMA descriptors on-device + # can cause conflicts when running in parallel via PDL + if USE_GDC and not IS_PRIMARY: + tl.extra.cuda.gdc_wait() + + b_desc = tl.make_tensor_descriptor( + cur_b_ptr, + shape=[max_loras, num_experts, N, K], + strides=[stride_bl, stride_be, stride_bn, stride_bk], + block_shape=[1, 1, BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + else: + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32) + b_ptrs = ( + cur_b_ptr + + lora_id * stride_bl + + expert_id * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) if USE_GDC and IS_PRIMARY: # GDC launch dependents hints the runtime system to launch dependent kernels. 
tl.extra.cuda.gdc_launch_dependents() - # accumulator accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) if USE_GDC and not IS_PRIMARY: tl.extra.cuda.gdc_wait() for k in range(0, grid_k): - k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) - # GDC wait waits for ALL programs in the prior kernel to complete - # before continuing. + cur_k_offset = k * (BLOCK_SIZE_K * SPLIT_K) + k_remaining = K - cur_k_offset # pre-fetch lora weight - # add (offs_bn < N) mask; optional .ca for B - b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N) - if USE_B_L2_CACHE: - b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca") + if b_desc is not None: + b = ( + b_desc.load([lora_id, expert_id, offs_bn, offs_bk + cur_k_offset]) + .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K) + .T + ) else: - b = tl.load(b_ptrs, mask=b_mask, other=0.0) + # add (offs_bn < N) mask; optional .ca for B + b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N) + if USE_B_L2_CACHE: + b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca") + else: + b = tl.load(b_ptrs, mask=b_mask, other=0.0) + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + + if a_desc is not None: + a = a_desc.load([offs_am, offs_ak + cur_k_offset]) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), + other=0.0, + ) + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - if USE_GDC and not IS_PRIMARY: - tl.extra.cuda.gdc_wait() - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), - other=0.0, - ) accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak - b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk if MUL_ROUTED_WEIGHT: moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) @@ -309,7 +387,19 @@ def _fused_moe_lora_kernel( accumulator = accumulator.to(c_ptr.dtype.element_ty) # Write back the block of the output offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_ptrs = _get_c_ptrs( + cur_c_ptr, + lora_id, + pid_m, + offs, + offs_token, + offs_cn, + stride_cm, + stride_cn, + EM, + BLOCK_SIZE_M, + sort_c, + ) c_mask = token_mask[:, None] & (offs_cn[None, :] < N) if SPLIT_K == 1: @@ -354,9 +444,10 @@ def _fused_moe_lora_shrink( num_warps: int, num_stages: int, split_k: int, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs mul_routed_weight: bool = False, use_gdc: bool = False, + use_tma: bool = False, ) -> None: w1_lora_a_stacked = lora_a_stacked[0] shrink_config = { @@ -369,6 +460,7 @@ def _fused_moe_lora_shrink( "SPLIT_K": split_k, "USE_GDC": use_gdc, "launch_pdl": use_gdc, # triton kernel metadata + "USE_TMA": use_tma, } b_ptr = _get_ptr(lora_a_stacked, device) @@ -383,9 +475,20 @@ def _fused_moe_lora_shrink( len(lora_a_stacked), grid_lora_dim, ) + + a_desc = None + b_desc = None + if use_tma and num_slices == 1: + b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + lora_a_stacked[0], + [1, 1, shrink_config["BLOCK_SIZE_N"], shrink_config["BLOCK_SIZE_K"]], + ) + _fused_moe_lora_kernel[grid]( qcurr_hidden_states, + a_desc, b_ptr, + b_desc, a_intermediate_cache1, topk_weights, sorted_token_ids, @@ -407,8 +510,8 @@ def _fused_moe_lora_shrink( w1_lora_a_stacked.stride(1), w1_lora_a_stacked.stride(3), w1_lora_a_stacked.stride(2), - a_intermediate_cache1.stride(2), - a_intermediate_cache1.stride(3), + 
a_intermediate_cache1.stride(-2), + a_intermediate_cache1.stride(-1), stride_tl, stride_el, slice_a_size=qcurr_hidden_states.numel(), @@ -419,7 +522,8 @@ def _fused_moe_lora_shrink( naive_block_assignment=sorted_token_ids is None, MUL_ROUTED_WEIGHT=False, ADD_INPUTS=False, - USE_B_L2_CACHE=True, # new + USE_B_L2_CACHE=True, + sort_c=use_tma and sorted_token_ids is not None, IS_PRIMARY=True, **shrink_config, ) @@ -458,10 +562,11 @@ def _fused_moe_lora_expand( num_warps: int, num_stages: int, split_k: int, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs mul_routed_weight: bool = False, offset: int = 0, use_gdc: bool = False, + use_tma: bool = False, ) -> None: b_ptr = _get_ptr(lora_b_stacked, device) K = max_lora_rank @@ -470,7 +575,7 @@ def _fused_moe_lora_expand( w1_lora_b_stacked = lora_b_stacked[0] a_intermediate_cache1 = a_intermediate_cache1.view( - -1, a_intermediate_cache1.shape[3] + -1, a_intermediate_cache1.shape[-1] ) expand_config = { @@ -483,6 +588,7 @@ def _fused_moe_lora_expand( "SPLIT_K": 1, # Set split_k = 1 for expand calls "USE_GDC": use_gdc, "launch_pdl": use_gdc, # triton kernel metadata + "USE_TMA": use_tma, } grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs( @@ -498,10 +604,27 @@ def _fused_moe_lora_expand( # Fast path: directly accumulate into the corresponding slice interval of output. out_view = output[:, :, offset : offset + num_slices * N] slice_c_size = N * out_view.stride(2) + a_desc = None + b_desc = None + if use_tma: + if sorted_token_ids is not None: + a_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + a_intermediate_cache1, + [expand_config["BLOCK_SIZE_M"], expand_config["BLOCK_SIZE_K"]], + ) + if num_slices == 1: + b_desc = triton.tools.tensor_descriptor.TensorDescriptor.from_tensor( + lora_b_stacked[0], + [1, 1, expand_config["BLOCK_SIZE_N"], expand_config["BLOCK_SIZE_K"]], + ) + else: + b_desc = None _fused_moe_lora_kernel[grid]( a_intermediate_cache1, + a_desc, b_ptr, + b_desc, out_view, topk_weights, sorted_token_ids, @@ -535,7 +658,8 @@ def _fused_moe_lora_expand( naive_block_assignment=sorted_token_ids is None, MUL_ROUTED_WEIGHT=mul_routed_weight, ADD_INPUTS=True, - USE_B_L2_CACHE=True, # new + USE_B_L2_CACHE=True, + sort_c=False, IS_PRIMARY=False, **expand_config, ) @@ -559,7 +683,7 @@ def _fused_moe_lora( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, @@ -616,8 +740,34 @@ def _fused_moe_lora( else num_tokens * shrink_block_size_m ) + # TMA is not currently compatiple with fully_sharded due to the non-determinism + # of token id sorting across ranks. 
+ use_tma = supports_tma(device) and not fully_sharded + + intermediate_cache_shape = ( + num_slices, + M, + top_k_num, + max_lora_rank, + ) + if use_tma: + if num_slices > 1: + # if num_slices > 1, we construct TMA descriptors for LoRA + # weights within the kernel, which requires us to first set an allocator + set_triton_allocator(device) + + # When storing intermediate data in sorted order for TMA, we + # need an extra 'num_active_loras' dim in the cache to avoid conflicts + if sorted_token_ids is not None: + intermediate_cache_shape = ( + num_slices, + sorted_token_ids.shape[0], + EM, + max_lora_rank, + ) + a_intermediate_cache1 = torch.zeros( - (num_slices, M, top_k_num, max_lora_rank), + intermediate_cache_shape, dtype=output.dtype, device=device, ) @@ -654,6 +804,7 @@ def _fused_moe_lora( num_active_loras, mul_routed_weight, use_gdc=use_gdc, + use_tma=use_tma, ) if fully_sharded: @@ -703,6 +854,7 @@ def _fused_moe_lora( mul_routed_weight, offset, use_gdc=use_gdc, + use_tma=use_tma, ) @@ -719,7 +871,7 @@ def _fused_moe_lora_fake( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, @@ -769,9 +921,10 @@ def _fused_moe_lora_shrink_fake( num_warps: int, num_stages: int, split_k: int, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs mul_routed_weight: bool = False, use_gdc: bool = False, + use_tma: bool = False, ) -> None: return @@ -805,10 +958,11 @@ def _fused_moe_lora_expand_fake( num_warps: int, num_stages: int, split_k: int, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs mul_routed_weight: bool = False, offset: int = 0, use_gdc: bool = False, + use_tma: bool = False, ) -> None: return diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index 1557d37..343e0c8 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -138,7 +138,7 @@ def _lora_expand( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] - num_active_loras: int, # number of active LoRAs (unused here, for API compat) + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs offset_start: int = 0, add_inputs: bool = False, ) -> None: @@ -235,7 +235,7 @@ def _lora_expand( grid = ( triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), NUM_SLICES, - num_active_loras, + num_active_loras.item(), ) # We disable PDL temporarily because LoRA kernels are not launching back-to-back, # making PDL invalid and affecting the kernel performance. @@ -289,7 +289,7 @@ def _lora_expand_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs offset_start: int = 0, add_inputs: bool = False, ) -> None: diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index 1fec1d5..dd7c2c7 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -29,9 +29,16 @@ class LoRAKernelMeta: # to early exit from inside the lora_expand / lora_shrink torch operation. 
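A minimal sketch of the CPU-tensor pattern that the num_active_loras changes in these hunks rely on, assuming only what the surrounding comments state (a 1-element CPU tensor keeps the value dynamic under torch.compile, and .item() is read at launch time); the helper names are hypothetical:

    import torch

    # Keep a per-step scalar in a 1-element CPU tensor instead of a Python int,
    # so torch.compile treats it as data rather than baking it into the graph.
    num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")

    def prepare(num_active: int) -> None:
        # In-place update keeps the same tensor object across forward passes.
        num_active_loras_cpu[0] = num_active

    def launch_grid(m_blocks: int, num_slices: int) -> tuple[int, int, int]:
        # .item() is only evaluated at kernel-launch time, outside the traced
        # graph, mirroring how _lora_shrink/_lora_expand size their grids.
        return (m_blocks, num_slices, num_active_loras_cpu.item())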
no_lora_flag_cpu: torch.Tensor - # Number of active LoRAs (unique non-(-1) values in token_lora_mapping) - # Stored as a Python int to avoid GPU->CPU sync during forward pass - num_active_loras: int = 0 + # Number of active LoRAs (unique non-(-1) values in token_lora_mapping). + # Stored as a CPU tensor (not a Python int) so that torch.compile treats + # it as a dynamic value rather than baking it as a constant at trace time. + # This follows the same pattern as no_lora_flag_cpu above. + num_active_loras_cpu: torch.Tensor + + # Default num_active_loras value (max_loras + 1) as a CPU tensor, + # used when specialize_active_lora is False to avoid allocating a + # new tensor on every meta_args() call. + default_num_active_loras_cpu: torch.Tensor # Captured LoRA counts for cudagraph specialization (sorted list). # When specialize_active_lora is enabled, num_active_loras is rounded up @@ -73,6 +80,11 @@ class LoRAKernelMeta: no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu") + num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu") + default_num_active_loras_cpu = torch.tensor( + [max_loras + 1], dtype=torch.int32, device="cpu" + ) + return LoRAKernelMeta( token_lora_mapping=token_lora_mapping, token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, @@ -80,6 +92,8 @@ class LoRAKernelMeta: num_tokens_per_lora=num_tokens_per_lora, lora_token_start_loc=lora_token_start_loc, no_lora_flag_cpu=no_lora_flag_cpu, + num_active_loras_cpu=num_active_loras_cpu, + default_num_active_loras_cpu=default_num_active_loras_cpu, captured_lora_counts=sorted(captured_lora_counts) if captured_lora_counts else [], @@ -90,8 +104,7 @@ class LoRAKernelMeta: self.num_tokens_per_lora.fill_(0) self.lora_token_start_loc.fill_(0) self.no_lora_flag_cpu.fill_(False) - self.num_active_loras = 0 - self.captured_lora_counts = [] + self.num_active_loras_cpu.fill_(0) def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: """ @@ -137,14 +150,16 @@ class LoRAKernelMeta: num_tokens_per_lora, non_blocking=True ) - self.num_active_loras = lora_ids.size(0) + num_active_loras = lora_ids.size(0) # Round up num_active_loras to match cudagraph capture keys. # This ensures the kernel grid dimension matches the captured graph. - if self.captured_lora_counts and self.num_active_loras > 0: - idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras) + if self.captured_lora_counts and num_active_loras > 0: + idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras) if idx < len(self.captured_lora_counts): - self.num_active_loras = self.captured_lora_counts[idx] + num_active_loras = self.captured_lora_counts[idx] + + self.num_active_loras_cpu[0] = num_active_loras # lora_token_start_loc lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) @@ -163,7 +178,7 @@ class LoRAKernelMeta: torch.Tensor, torch.Tensor, torch.Tensor, - int, + torch.Tensor, ]: """ This function returns the kernel metadata required for the current @@ -175,7 +190,10 @@ class LoRAKernelMeta: token_nums (int): Number of input tokens in the current forward pass of the kernel. 
""" - max_loras = self.active_lora_ids.size(0) - 1 + if specialize_active_lora: + num_active_loras = self.num_active_loras_cpu + else: + num_active_loras = self.default_num_active_loras_cpu return ( self.token_lora_mapping[:token_nums], self.token_indices_sorted_by_lora_ids[:token_nums], @@ -183,5 +201,5 @@ class LoRAKernelMeta: self.lora_token_start_loc, self.active_lora_ids, self.no_lora_flag_cpu, - self.num_active_loras if specialize_active_lora else max_loras + 1, + num_active_loras, ) diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 8dbd988..ea850ba 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -134,7 +134,7 @@ def _lora_shrink( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] - num_active_loras: int, # number of active LoRAs (unused here, for API compat) + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs scaling: float, ) -> None: """ @@ -157,6 +157,9 @@ def _lora_shrink( lora_ids (torch.Tensor): LoRA ids to process. no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates if there are any requests that require LoRA. + num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the + number of active LoRAs. Stored as a tensor (not int) so + torch.compile treats it as dynamic rather than a constant. scaling (float): Scaling factor. """ @@ -215,7 +218,7 @@ def _lora_shrink( grid = ( SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), NUM_SLICES, - num_active_loras, + num_active_loras.item(), ) # We disable PDL temporarily because LoRA kernels are not launching back-to-back, # making PDL invalid and affecting the kernel performance. 
@@ -267,7 +270,7 @@ def _lora_shrink_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, - num_active_loras: int, + num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs scaling: float, ) -> None: return diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py index c7ac591..a863b97 100644 --- a/vllm/lora/ops/triton_ops/utils.py +++ b/vllm/lora/ops/triton_ops/utils.py @@ -316,3 +316,9 @@ def supports_pdl(device: torch.device | None = None) -> bool: and current_platform.has_device_capability(90) and not envs.VLLM_LORA_DISABLE_PDL ) + + +@lru_cache +def supports_tma(device: torch.device | None = None) -> bool: + # TMA requires compute capability SM90 or above + return current_platform.is_cuda() and current_platform.has_device_capability(90) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index b75d297..2964a0d 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -233,6 +233,17 @@ class PunicaWrapperGPU(PunicaWrapperBase): assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + import vllm.envs as env + if env.VLLM_USE_LORA_FUSION: + import ixformer.inference.functions as ops + + num_token, m = x.size(0), x.size(-1) + k, n = lora_b_stacked[0].size(-1), y.size(-1) + if len(lora_a_stacked) == 1 and ops.lora_gemv_optim_condition(num_token, m, k, n): + ops.add_lora_linear(y, x, lora_a_stacked, lora_b_stacked, + lora_bias_stacked = None, scale = 1.0, output_slices = (1,)) + return + assert buffer is None, ( "To minimize overhead, the buffer should be created by " ".add_lora_linear() instead of being passed in." @@ -351,6 +362,8 @@ class PunicaWrapperGPU(PunicaWrapperBase): max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) if pad_sorted_ids: max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + if topk_ids.numel() < num_experts: + max_num_tokens_padded = topk_ids.numel() * block_size sorted_ids = torch.empty( (max_loras * max_num_tokens_padded,), dtype=torch.int32, diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 1b4b7dc..cbdc53d 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -127,10 +127,10 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = # in priority/performance order (when available) _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { PlatformEnum.CUDA: [ + MarlinLinearKernel, CutlassW4A8LinearKernel, MacheteLinearKernel, AllSparkLinearKernel, - MarlinLinearKernel, ConchLinearKernel, ExllamaLinearKernel, ], diff --git a/vllm/model_executor/kernels/linear/mixed_precision/machete.py b/vllm/model_executor/kernels/linear/mixed_precision/machete.py index 7953ed5..b756c8a 100644 --- a/vllm/model_executor/kernels/linear/mixed_precision/machete.py +++ b/vllm/model_executor/kernels/linear/mixed_precision/machete.py @@ -69,7 +69,6 @@ class MacheteLinearKernel(MPLinearKernel): # `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1} def process_weights_after_loading(self, layer: torch.nn.Module): c = self.config - if c.has_g_idx: assert self.w_gidx_name is not None perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int) @@ -86,19 +85,17 @@ class MacheteLinearKernel(MPLinearKernel): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, 
output_dim=1, packed_dim=0) if c.has_g_idx: - x_unpacked = unpack_quantized_values_into_int32( - x.data, c.weight_type, packed_dim=0 - ) + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=0) x_perm = x_unpacked[perm, :] - x.data = pack_quantized_values_into_int32( - x_perm, c.weight_type, packed_dim=0 - ) - x.data = ops.machete_prepack_B( - x.data.t().contiguous().t(), - a_type=c.act_type, - b_type=c.weight_type, - group_scales_type=c.act_type, - ) + x.data = pack_quantized_values_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type) return x def transform_w_s(x): @@ -144,16 +141,14 @@ class MacheteLinearKernel(MPLinearKernel): else: w_zp = None - output = ops.machete_mm( - a=x_2d, - b_q=w_q, - b_type=c.weight_type, - b_group_zeros=w_zp, - b_group_scales=w_s, - b_group_size=c.group_size, - ) + output = ops.machete_mm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=w_zp, + b_group_scales=w_s, + b_group_size=c.group_size) if bias is not None: output.add_(bias) # In-place add - return output.reshape(out_shape) + return output.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/kernels/linear/mixed_precision/marlin.py b/vllm/model_executor/kernels/linear/mixed_precision/marlin.py index eb14f9e..b955324 100644 --- a/vllm/model_executor/kernels/linear/mixed_precision/marlin.py +++ b/vllm/model_executor/kernels/linear/mixed_precision/marlin.py @@ -23,9 +23,95 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32, unpack_quantized_values_into_int32) from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig +from vllm.scalar_type import ScalarType, scalar_types +import ixformer.inference.functions as ixf_ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.logger import init_logger +logger = init_logger(__name__) + + +def unpack_rows(packed_w: torch.Tensor, num_bits: int) -> torch.Tensor: + """ + Efficient vectorized unpacking. + Converts [K // pack_factor, N] int32 tensor → [K, N] int8 tensor. + + Args: + packed_w: torch.int32 tensor of shape [K // pack_factor, N]. + num_bits: Number of bits per packed element (e.g., 4). + + Returns: + unpacked: torch.int8 tensor of shape [K, N]. + """ + pack_factor = 32 // num_bits + k_packed, n = packed_w.shape + k = k_packed * pack_factor + + mask = (1 << num_bits) - 1 + + # [pack_factor, 1, 1] + shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1) + + # [pack_factor, k_packed, n] + packed_expanded = packed_w.unsqueeze(0) + + # Extract each group of num_bits using bitwise ops + unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8) + # [pack_factor, k_packed, n] → [k, n] + unpacked = unpacked_groups.permute(1, 0, 2).reshape(k, n) + + return unpacked + + +def pack_cols(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor: + """ + Efficient vectorized version: pack int4 values (0–15) into int32. + Each int32 element contains `pack_num` 4-bit values. + + Args: + x: Tensor of shape [rows, cols * pack_num], dtype=int32. + Represents unpacked int4 values. 
+ pack_num: Number of 4-bit elements to pack into each int32. + order_map: Index mapping defining the order of 4-bit packing, + must match the unpack order used in `unpack_tensor`. + + Returns: + Tensor of shape [rows, cols], dtype=int32 — packed result. + """ + # Default sequential order if none provided + if order_map is None: + order_map = list(range(pack_num)) + order_map = torch.tensor(order_map, device=x.device) + + # Number of bits per packed element (e.g., 32 / 8 = 4 bits) + unit = 32 // pack_num + rows, cols_pack = x.shape + assert cols_pack % pack_num == 0, "Number of columns must be a multiple of pack_num" + cols = cols_pack // pack_num + + # Reshape input into groups of `pack_num` int4 values + # Shape: [rows, cols, pack_num] + x_reshape = x.view(rows, cols, pack_num) + + # Reorder elements according to order_map + # order_map is broadcasted to match shape [rows, cols, pack_num] + x_reorder = torch.gather(x_reshape, 2, order_map.view(1, 1, -1).expand(rows, cols, -1)) + + # Keep only the lower 4 bits of each value + x_reorder = x_reorder & 0xF + + # Compute bit shifts for each position (e.g., [0, 4, 8, 12, 16, 20, 24, 28]) + shifts = (unit * torch.arange(pack_num, device=x.device)).view(1, 1, -1) + + # Shift and combine (bitwise OR) along the last dimension + # Using sum() is safe since bits don't overlap between 4-bit slots + res = (x_reorder << shifts).sum(dim=-1).to(torch.int32) + + return res class MarlinLinearKernel(MPLinearKernel): @classmethod @@ -79,96 +165,133 @@ class MarlinLinearKernel(MPLinearKernel): getattr(layer, self.w_s_name).data = ( getattr(layer, self.w_s_name).data * 512 ) + assert (c.weight_type.size_bits == 4) , f"MarlinLinearKernel now only support uint4, uint4b8, \ + now quant weight_type {c.weight_typ}" + + # device = getattr(layer, self.w_q_name).device + - row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] - self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + # row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0] + # self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) # Allocate marlin workspace. 
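As a sanity check of the pack_cols helper defined above, a tiny worked example (illustrative only; the import path assumes the helper remains in this module):

    import torch
    from vllm.model_executor.kernels.linear.mixed_precision.marlin import pack_cols

    # Pack one row of eight int4 values with the interleaved order map used by
    # the kernel; nibbles land in order 0, 2, 4, 6, 1, 3, 5, 7 from LSB to MSB.
    vals = torch.arange(8, dtype=torch.int32).view(1, 8)
    packed = pack_cols(vals, pack_num=8, order_map=[0, 2, 4, 6, 1, 3, 5, 7])
    assert packed.shape == (1, 1)
    assert packed.item() == 0x75316420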
- self.workspace = marlin_make_workspace_new(device) + # self.workspace = marlin_make_workspace_new(device) # Default names since marlin requires empty parameters for these, # TODO: remove this requirement from marlin (allow optional tensors) - if self.w_gidx_name is None: - self.w_gidx_name = "g_idx" - if self.w_zp_name is None: - self.w_zp_name = "w_zp" + # if self.w_gidx_name is None: + # self.w_gidx_name = "g_idx" + # if self.w_zp_name is None: + # self.w_zp_name = "w_zp" + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int) + + self.act_perm = lambda x: x[:, perm] def transform_w_q(x): - assert isinstance(x, BasevLLMParameter) - permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) - x.data = ops.gptq_marlin_repack( - x.data.contiguous(), - perm=layer.g_idx_sort_indices, - size_k=c.partition_weight_shape[0], - size_n=c.partition_weight_shape[1], - num_bits=c.weight_type.size_bits, - is_a_8bit=is_a_8bit, - ) + # assert isinstance(x, BasevLLMParameter) + # permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + # x.data = ops.gptq_marlin_repack( + # x.data.contiguous(), + # perm=layer.g_idx_sort_indices, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits, + # is_a_8bit=is_a_8bit, + # ) + assert x.data.ndim == 2 + if x._packed_dim == 1: #CompressedTensorsWNA16 + #[oc, ic // 8] - > [oc, ic] + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=1) + if c.has_g_idx: + x_unpacked = x_unpacked[:,perm] + #[oc, ic] -> [ic, oc] + x_unpacked = x_unpacked.t().contiguous() + + elif x._packed_dim == 0: #GPTQMarlinLinearMethod + + #[ic // 8, oc] -> [ic , oc] + x_unpacked = unpack_rows(x.data,c.weight_type.size_bits) + if c.has_g_idx: + x_unpacked = x_unpacked[perm:] + raise NotImplementedError(f"GPTQMarlinLinearMethod has_g_idx not test, \ + Please check whether the model's inference results are correct, and annotate/modify the statement. 
") + else: + raise NotImplementedError(f"transform_w_q pack_dim {x._packed_dim} not implement") + + #[ic, oc]-> [ic, oc//8] + x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7]) + x.data = x_packed.contiguous() + x._packed_dim = 1 return x def transform_w_s(x): assert isinstance(x, BasevLLMParameter) permute_param_layout_(x, input_dim=0, output_dim=1) - x.data = marlin_permute_scales( - x.data.contiguous(), - size_k=c.partition_weight_shape[0], - size_n=c.partition_weight_shape[1], - group_size=c.group_size, - is_a_8bit=is_a_8bit, - ) + x.data = x.data.contiguous() + return x.to(dtype=c.act_type) - if c.group_size == -1: - num_groups = 1 - else: - num_groups = c.partition_weight_shape[0] // c.group_size - - if c.act_type == torch.int8 and num_groups > 1: - x.data, input_global_scale = marlin_act_int8_process_scales(x.data) - layer.register_parameter( - "input_global_scale", - torch.nn.Parameter(input_global_scale, requires_grad=False), - ) - else: - layer.input_global_scale = None + # if c.has_g_idx: + # g_idx, g_idx_sort_indices = marlin_sort_g_idx( + # getattr(layer, self.w_gidx_name) + # ) + # self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + # layer.g_idx_sort_indices = g_idx_sort_indices + # else: + # setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + # layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + def transform_w_zp(x): + grouped_k = (c.partition_weight_shape[0] // + c.group_size if c.group_size != -1 else 1) + x_unpacked = unpack_cols(x.clone().t(), c.weight_type.size_bits, grouped_k, c.partition_weight_shape[1]) + x_packed = pack_cols(x_unpacked, order_map=[0, 2, 4, 6, 1, 3, 5, 7]) + x.data = x_packed.contiguous() return x - - if c.has_g_idx: - g_idx, g_idx_sort_indices = marlin_sort_g_idx( - getattr(layer, self.w_gidx_name) - ) - self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - else: - setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + if c.zero_points: - grouped_k = ( - c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1 - ) - self._transform_param( - layer, - self.w_zp_name, - lambda x: marlin_zero_points( - unpack_cols( - x.t(), - c.weight_type.size_bits, - grouped_k, - c.partition_weight_shape[1], - ), - size_k=grouped_k, - size_n=c.partition_weight_shape[1], - num_bits=c.weight_type.size_bits, - is_a_8bit=is_a_8bit, - ), - ) + # grouped_k = ( + # c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1 + # ) + # self._transform_param( + # layer, + # self.w_zp_name, + # lambda x: marlin_zero_points( + # unpack_cols( + # x.t(), + # c.weight_type.size_bits, + # grouped_k, + # c.partition_weight_shape[1], + # ), + # size_k=grouped_k, + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits, + # ), + # ) + self._transform_param(layer, self.w_zp_name, transform_w_zp) else: - setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + # setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + #weight_type = uint4b8, using c.weight_type.bias as zero point,according quant method. 
+ #[ic, oc]-> [ic, oc//8] + w_zp = torch.full_like(getattr(layer, self.w_s_name), c.weight_type.bias, dtype=torch.int32) + w_zp_pack = pack_cols(w_zp, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous() + weight_zero_point = torch.nn.Parameter( + w_zp_pack, + requires_grad=False) + + if hasattr(layer, self.w_zp_name): + replace_parameter(layer, self.w_zp_name, weight_zero_point) #GPTQMarlinLinearMethod + else: + layer.register_parameter("weight_zero_point", weight_zero_point) #CompressedTensorsWNA16 + self._transform_param(layer, self.w_q_name, transform_w_q) self._transform_param(layer, self.w_s_name, transform_w_s) - if hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = marlin_permute_bias(layer.bias) + # if hasattr(layer, "bias") and layer.bias is not None: + # layer.bias.data = marlin_permute_bias(layer.bias) def apply_weights( self, @@ -179,22 +302,39 @@ class MarlinLinearKernel(MPLinearKernel): c = self.config w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) - # `process_weights_after_loading` will ensure w_zp and w_gidx are not - # None for marlin + pack_factor = 32 // c.weight_type.size_bits + + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + x_2d = x.reshape(-1, x.shape[-1]) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + out = ops.custom_gptq_marlin_gemm(input = x_2d, + qweight = w_q, + scales = w_s, + qzeros = w_zp, + pack_factor = pack_factor, + group_size = c.group_size, + bias = bias) + out = out.reshape(out_shape) + # if bias is not None: + # out.add_(bias) + return out + - return apply_gptq_marlin_linear( - input=x, - weight=w_q, - weight_scale=w_s, - weight_zp=w_zp, # type: ignore - g_idx=w_gidx, # type: ignore - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=self.workspace, - wtype=c.weight_type, - input_size_per_partition=c.partition_weight_shape[0], - output_size_per_partition=c.partition_weight_shape[1], - is_k_full=self.is_k_full, - input_global_scale=getattr(layer, "input_global_scale", None), - bias=bias, - input_dtype=c.act_type, - ) + # # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # # None for marlin + # return apply_gptq_marlin_linear( + # input=x, + # weight=w_q, + # weight_scale=w_s, + # weight_zp=w_zp, # type: ignore + # g_idx=w_gidx, # type: ignore + # g_idx_sort_indices=layer.g_idx_sort_indices, + # workspace=self.workspace, + # wtype=c.weight_type, + # input_size_per_partition=c.partition_weight_shape[0], + # output_size_per_partition=c.partition_weight_shape[1], + # is_k_full=self.is_k_full, + # bias=bias) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index e4b14ea..46dd4c6 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -18,7 +18,6 @@ from .ScaledMMLinearKernel import ( Int8ScaledMMLinearLayerConfig, ) -import vllm.envs as envs class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod @@ -38,13 +37,28 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): config = self.config # WEIGHT # Cutlass kernels need transposed weight. 
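Two small reference sketches for the weight preprocessing and the GEMM above. First, the act-order handling relies on a simple identity: applying the same input-channel permutation (argsort of g_idx) to both the activations and the weight's input dimension leaves the matmul result unchanged, which is why the kernel can pre-sort the weight once and permute activations at run time. Shapes and names here are illustrative only.

import torch

x = torch.randn(4, 32)                  # activations [tokens, in_channels]
w = torch.randn(16, 32)                 # weight [out_channels, in_channels]
g_idx = torch.randint(0, 4, (32,))      # per-input-channel group ids (stand-in)
perm = torch.argsort(g_idx)
ref = x @ w.t()
permuted = x[:, perm] @ w[:, perm].t()  # permute activations and weight input dim alike
assert torch.allclose(ref, permuted, atol=1e-5)

Second, a pure-PyTorch sketch of the math the repacked-weight GEMM computes: group-wise dequantization followed by a plain matmul. w4a16_ref is a hypothetical helper, and the constant zero point of 8 mirrors the uint4b8 case above where c.weight_type.bias is used.

import torch

def w4a16_ref(x, w_q, scales, zeros, group_size):
    # x: [tokens, K], w_q: [K, N] ints in [0, 15], scales/zeros: [K // group_size, N]
    g = torch.arange(w_q.shape[0]) // group_size              # group id of each K row
    w = (w_q.float() - zeros[g].float()) * scales[g]          # dequantized weight [K, N]
    return x @ w.to(x.dtype)

x = torch.randn(4, 128)
w_q = torch.randint(0, 16, (128, 64))
scales = torch.rand(128 // 32, 64) * 0.1
zeros = torch.full((128 // 32, 64), 8)                        # constant zp = weight_type.bias for uint4b8
out = w4a16_ref(x, w_q, scales, zeros, group_size=32)         # [4, 64]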
- weight = getattr(layer, w_q_name) - replace_parameter( - layer, - w_q_name, - # torch.nn.Parameter(weight.t().data, requires_grad=False), - torch.nn.Parameter(weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data, requires_grad=False), - ) + weight = getattr(layer, w_q_name) + if layer.scheme.is_w4a8_linear: + self.format = "NN" + replace_parameter(layer, w_q_name, torch.nn.Parameter(weight.data.contiguous(), requires_grad=False)) + else: + self.format = "TN" # by default the weight is laid out as T ([m, k]) + m, k = weight.shape + if (m % 64 == 0 and k % 64 == 0): + self.format = "NN" + replace_parameter( + layer, w_q_name, + torch.nn.Parameter(weight.t().data.contiguous(), requires_grad=False)) # original layout is T [m, k]; after processing it is N [k, m] + else: + if k % 64 != 0: + pad_k = (k // 64 + 1) * 64 + weight_pad = torch.empty((m, pad_k), dtype=weight.dtype, device=weight.device) + _weight = weight_pad[:, :k] + _weight.copy_(weight) + weight = _weight + replace_parameter( + layer, w_q_name, + torch.nn.Parameter(weight.t(), requires_grad=False)) # WEIGHT SCALE # Cutlass kernels support only per-tensor and per-channel. @@ -114,6 +128,7 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, + is_w4a8_linear: bool = False, ) -> torch.Tensor: w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer) @@ -121,9 +136,15 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): # * dynamic, i_s is None and x_s computed from x. # * static, i_s is scalar and x_s is i_s. symmetric = azp_adj is None - x_q, x_s, x_zp = ops.scaled_int8_quant( - x.contiguous(), i_s, i_zp, symmetric=symmetric - ) + if isinstance(x, tuple): + x_q, x_s, out_dtype = x + x_zp = None + else: + out_dtype = x.dtype + x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), + i_s, + i_zp, + symmetric=symmetric) if x_zp is not None: # Currently, static is always per-tensor and dynamic is per-token @@ -134,14 +155,21 @@ class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): w_q, scale_a=x_s, scale_b=w_s, - out_dtype=x.dtype, + out_dtype=out_dtype, azp_adj=azp_adj, azp=azp, bias=bias, ) + if self.format == "NN" and x_q.shape[-1] != w_q.shape[0]: + padding = w_q.shape[0] - x_q.shape[-1] + x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0) + elif self.format == "TN" and x_q.shape[-1] != w_q.shape[-1]: + padding = w_q.shape[-1] - x_q.shape[-1] + x_align = torch.nn.functional.pad(x_q, (0, padding), mode='constant', value=0) + else: + x_align = x_q return ops.cutlass_scaled_mm( - # x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias - x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias, format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN" + x_align, w_q, scale_a=x_s, scale_b=w_s, out_dtype=out_dtype, bias=bias, format=self.format, is_w4a8_linear=is_w4a8_linear ) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 11cdc73..aee1621 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -8,6 +8,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from vllm import envs from vllm.distributed import ( divide, get_tensor_model_parallel_rank, @@ -130,13 +131,12 @@ class SiluAndMul(CustomOp): def __init__(self, *, compile_native: bool = True): super().__init__(compile_native=compile_native) - if current_platform.is_cuda_alike(): + if current_platform.is_cuda_alike() or current_platform.is_xpu(): from vllm import
_custom_ops as ops - self.op = ops.silu_and_mul - elif current_platform.is_xpu(): - from vllm._ipex_ops import ipex_ops - - self.op = ipex_ops.silu_and_mul + if envs.VLLM_USE_SILU_QUANT_FUSION: + self.op = ops.silu_and_mul_quant + else: + self.op = ops.silu_and_mul elif current_platform.is_cpu(): self._forward_method = self.forward_native @@ -146,11 +146,15 @@ class SiluAndMul(CustomOp): d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:] - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + def forward_cuda(self, x: torch.Tensor, out_dim: int = 0) -> torch.Tensor: d = x.shape[-1] // 2 output_shape = x.shape[:-1] + (d,) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - self.op(out, x) + if envs.VLLM_USE_SILU_QUANT_FUSION: + quant_out, out_scales = self.op(x, out_dim) + out = (quant_out, out_scales, x.dtype) + else: + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: @@ -174,7 +178,6 @@ class MulAndSilu(CustomOp): def __init__(self): super().__init__() if current_platform.is_cuda_alike() or current_platform.is_xpu(): - # self.op = torch.ops._C.mul_and_silu from vllm import _custom_ops as ops self.op = ops.mul_and_silu elif current_platform.is_cpu(): @@ -397,7 +400,6 @@ class NewGELU(CustomOp): or current_platform.is_cpu() or current_platform.is_xpu() ): - # self.op = torch.ops._C.gelu_new from vllm import _custom_ops as ops self.op = ops.gelu_new @@ -427,7 +429,8 @@ class FastGELU(CustomOp): or current_platform.is_cpu() or current_platform.is_xpu() ): - self.op = torch.ops._C.gelu_fast + from vllm import _custom_ops as ops + self.op = ops.gelu_fast def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" @@ -455,7 +458,6 @@ class QuickGELU(CustomOp): or current_platform.is_cpu() or current_platform.is_xpu() ): - # self.op = torch.ops._C.gelu_quick from vllm import _custom_ops as ops self.op = ops.gelu_quick diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 38f1099..2948d46 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -12,7 +12,7 @@ from vllm.config.vllm import VllmConfig from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention.kv_transfer_utils import ( - maybe_transfer_kv_layer, + maybe_transfer_kv_layer ) from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant @@ -40,6 +40,9 @@ from vllm.v1.kv_cache_interface import ( KVCacheSpec, SlidingWindowSpec, ) +from .extra_cache import StaticQuantManager +from ixformer.core import config +_USE_TORCH_OPS = config.IXFORMER_USE_TORCH_OPS if TYPE_CHECKING: from vllm.model_executor.layers.attention import MLAAttention @@ -202,6 +205,7 @@ class Attention(nn.Module, AttentionLayerBase): kv_sharing_target_layer_name: str | None = None, attn_backend: type[AttentionBackend] | None = None, head_size_v: int | None = None, + extra_cache_para: dict = None, **extra_impl_args, ) -> None: """ @@ -258,6 +262,7 @@ class Attention(nn.Module, AttentionLayerBase): self.num_heads = num_heads self.head_size = head_size + self.hidden_size = head_size * num_heads self.head_size_v = self.head_size if head_size_v is None else head_size_v 
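The VLLM_USE_SILU_QUANT_FUSION path above makes SiluAndMul return a (quantized activation, scales, original dtype) tuple instead of a plain tensor, so the following int8 GEMM (see the tuple branch in the Cutlass kernel earlier in this patch) can skip a separate quantization pass. A rough, unfused PyTorch sketch of what such a tuple carries; per-token symmetric int8 quantization is an assumption for illustration, not a statement about the fused silu_and_mul_quant kernel.

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x):
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def per_token_int8(y):
    scale = y.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(y / scale), -127, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 256)
y = silu_and_mul_ref(x)
q, s = per_token_int8(y)
out_tuple = (q, s, x.dtype)          # roughly what the fused path hands to the next layer
assert torch.allclose(q.float() * s, y, atol=s.max().item())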
self.num_kv_heads = num_kv_heads self.sliding_window = sliding_window @@ -326,6 +331,15 @@ class Attention(nn.Module, AttentionLayerBase): kv_sharing_target_layer_name, **extra_impl_args, ) + if extra_cache_para is not None: + self.quant_manager = StaticQuantManager( + layer_id=extra_cache_para.get("layer_id", None), + shape=(self.num_kv_heads, self.head_size_v), + dtype=torch.float32, + total_layer_num=extra_cache_para.get("total_layer_num", None) + ) + else: + self.quant_manager = None self.backend = AttentionBackendEnum[self.attn_backend.get_name()] self.dtype = dtype @@ -333,7 +347,10 @@ class Attention(nn.Module, AttentionLayerBase): # torch.compile works by registering the attention as one giant # opaque custom op. For other platforms, we directly call them # and let torch.compile handle them. - self.use_direct_call = not current_platform.opaque_attention_op() + if _USE_TORCH_OPS: + self.use_direct_call = False + else: + self.use_direct_call = True self.use_output = self.attn_backend.accept_output_buffer compilation_config = vllm_config.compilation_config @@ -349,14 +366,26 @@ class Attention(nn.Module, AttentionLayerBase): compilation_config.static_forward_context, ) self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - # use a placeholder kv cache tensor during init, which will be replaced - # by bind_kv_cache - # this variable will not be accessed if use_direct_call is True self.kv_cache = [ torch.tensor([]) for _ in range(vllm_config.parallel_config.pipeline_parallel_size) ] + self.is_i8qi8ki8v = envs.VLLM_ATTN_OPT_LEVEL == 1 + self.is_i8qi8kf16v = envs.VLLM_ATTN_OPT_LEVEL == 2 + if self.is_i8qi8kf16v: + self.kv_cache_scale = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] + elif self.is_i8qi8ki8v: + self.kv_cache_scale = [ + [torch.tensor([]), torch.tensor([])] for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] + + # use a placeholder kv cache tensor during init, which will be replaced + # by bind_kv_cache + # this variable will not be accessed if use_direct_call is True # Initialize KV cache quantization attributes _init_kv_cache_quant(self, quant_config, prefix) @@ -396,6 +425,7 @@ class Attention(nn.Module, AttentionLayerBase): context using `vllm.forward_context.get_forward_context().attn_metadata`. """ + optional_args = {} if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name) output_dtype = query.dtype @@ -412,15 +442,8 @@ class Attention(nn.Module, AttentionLayerBase): query, _ = self.query_quant(query, self._q_scale) if self.use_output: - if output_shape is None: - # Handle both 2D [num_tokens, hidden] and - # 3D [num_tokens, heads, head_dim] query - num_tokens = query.shape[0] - output_shape = torch.Size( - (num_tokens, self.num_heads * self.head_size_v) - ) + output_shape = output_shape if output_shape is not None else query.shape output = torch.empty(output_shape, dtype=output_dtype, device=query.device) - hidden_size = output_shape[-1] # Reshape the query, key, and value tensors. # NOTE(woosuk): We do this outside the custom op to minimize the # CPU overheads from the non-CUDA-graph regions. @@ -430,46 +453,50 @@ class Attention(nn.Module, AttentionLayerBase): key = key.view(-1, self.num_kv_heads, self.head_size) if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size_v) - kv_cache_dummy_dep = None if self.use_direct_call: - # Skip this if sharing KV cache with an earlier attention layer. 
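The VLLM_ATTN_OPT_LEVEL branches above keep extra kv_cache_scale tensors alongside the KV cache: as the flag names i8qi8ki8v and i8qi8kf16v suggest, level 1 stores both K and V in int8 (a pair of scales per virtual engine) while level 2 stores only K in int8. The underlying idea, reduced to plain PyTorch with a per-tensor scale chosen purely for illustration:

import torch

k = torch.randn(32, 8, 128)                       # [tokens, kv_heads, head_dim]
k_scale = k.abs().amax() / 127.0                  # the kind of value kv_cache_scale holds
k_int8 = torch.clamp(torch.round(k / k_scale), -127, 127).to(torch.int8)
k_deq = k_int8.float() * k_scale                  # what the attention kernel reconstructs
assert (k - k_deq).abs().max() <= k_scale / 2 + 1e-6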
- if ( - not self.attn_backend.forward_includes_kv_cache_update - and self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - kv_cache_dummy_dep = unified_kv_cache_update( - key, value, self.layer_name + def direct_forward(layer_name: str, output: torch.Tensor): + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + # Skip this if sharing KV cache with an earlier attention layer. + if self.is_i8qi8ki8v or self.is_i8qi8kf16v: + optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine] + output = self.impl.forward( + self, + query, + key, + value, + self_kv_cache, + attn_metadata, + output=output, + **optional_args ) - unified_attention_with_output( - query, - key, - value, - output, - self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, - ) + return output + return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output) else: - # Skip this if sharing KV cache with an earlier attention layer. - if ( - not self.attn_backend.forward_includes_kv_cache_update - and self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( - key, value, self.layer_name - ) + if self.is_i8qi8ki8v: + forward_context: ForwardContext = get_forward_context() + kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][0] + v_cache_scale = self.kv_cache_scale[forward_context.virtual_engine][1] + elif self.is_i8qi8kf16v: + forward_context: ForwardContext = get_forward_context() + kv_cache_scale = self.kv_cache_scale[forward_context.virtual_engine] + v_cache_scale = None + else: + kv_cache_scale = None + v_cache_scale = None torch.ops.vllm.unified_attention_with_output( query, key, value, output, self.layer_name, - kv_cache_dummy_dep=kv_cache_dummy_dep, + kv_cache_scale, + v_cache_scale ) - return output.view(-1, hidden_size) + return output.view(-1, self.hidden_size) else: assert self.attn_backend.forward_includes_kv_cache_update, ( "Split KV cache update not supported when output tensor not provided." @@ -521,6 +548,7 @@ class Attention(nn.Module, AttentionLayerBase): block_size = vllm_config.cache_config.block_size # Should not be called for enc-dec or encoder-only attention. assert self.attn_type == AttentionType.DECODER + # TODO : kernel unsupport kvcache for sliding_window, use FullAttentionSpec replace if self.sliding_window is not None: assert not vllm_config.model_config.use_mla, ( "MLA is not supported for slidingwindow" @@ -689,6 +717,8 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, + kv_cache_scale: torch.Tensor | None = None, + v_cache_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, @@ -696,9 +726,7 @@ def unified_attention_with_output( # kv_cache_dummy_dep is not used but accepting it creates a data dependency # that ensures torch.compile preserves ordering between KV cache update and # attention forward. 
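The direct-call path above defines a local direct_forward closure, wraps it with maybe_transfer_kv_layer, and invokes the wrapped function immediately, so any KV-connector hooks run around the per-layer attention call without changing its arguments. A minimal sketch of that wrap-then-call pattern; the hook bodies are placeholders, not the actual connector logic.

import functools
from typing import Callable

def transfer_hook_sketch(fn: Callable) -> Callable:
    @functools.wraps(fn)
    def wrapper(layer_name: str, *args, **kwargs):
        # a connector could start an async KV load for `layer_name` here
        result = fn(layer_name, *args, **kwargs)
        # ...and schedule a KV save for `layer_name` here
        return result
    return wrapper

def direct_forward(layer_name: str, output):
    return output                     # stand-in for self.impl.forward(...)

out = transfer_hook_sketch(direct_forward)("layers.0.attn", 42)
assert out == 42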
- del kv_cache_dummy_dep attn_metadata, self, kv_cache, _ = get_attention_context(layer_name) - self.impl.forward( self, query, @@ -707,6 +735,7 @@ def unified_attention_with_output( kv_cache, attn_metadata, output=output, + kv_cache_scale = [kv_cache_scale, v_cache_scale] if envs.VLLM_ATTN_OPT_LEVEL==1 else kv_cache_scale, output_scale=output_scale, output_block_scale=output_block_scale, ) @@ -718,6 +747,8 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, layer_name: str, + kv_cache_scale: torch.Tensor | None = None, + v_cache_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, kv_cache_dummy_dep: torch.Tensor | None = None, diff --git a/vllm/model_executor/layers/attention/extra_cache.py b/vllm/model_executor/layers/attention/extra_cache.py new file mode 100644 index 0000000..36d2c1f --- /dev/null +++ b/vllm/model_executor/layers/attention/extra_cache.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os + +import torch +from filelock import FileLock + +import vllm.envs as envs +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class StaticQuantManager: + def __init__( + self, + layer_id: int, + shape: tuple, + dtype: torch.dtype, + total_layer_num: int, + device: str = None, + tp_size: int = None, + tp_rank: int = None, + file_save_path: str = None, + save_step: int = 100, + info_step: int = 100, + ): + # update parament + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if file_save_path is None: + file_save_path = envs.VLLM_ATTN_STATIC_QUANT_SCALE_FILE_PATH + if device is None: + device = "cuda" + + # check parament + if file_save_path in [None, ""]: + self.disable = True + return + + para_dir = os.path.dirname(file_save_path) + assert os.path.exists(para_dir), ( + f"StaticQuantManager workdir {para_dir} not exist!" 
+ ) + self.disable = os.path.exists(file_save_path) + if self.disable: + return + + assert layer_id is not None + assert total_layer_num is not None + + world_rank = torch.distributed.get_rank() + work_dir = os.path.join(para_dir, "StaticQuantManagerWorkdir") + self.operator = world_rank == 0 and layer_id == 0 + if not os.path.exists(work_dir): + if self.operator: + logger.debug(f"StaticQuantManager Creat {work_dir}!") + os.mkdir(work_dir) + self.file_save_path = file_save_path + self.work_dir = work_dir + + self.tp_size = tp_size + self.tp_rank = tp_rank + self.world_rank = world_rank + self.layer_id = layer_id + self.total_layer_num = total_layer_num + self.save_step = save_step + self.info_step = info_step + self.update_count = 0 + self.save_flag = False + self.scales = torch.zeros(shape, dtype=dtype, device=device) + logger.debug( + f"StaticQuantManager info: world_rank:{self.world_rank} tp_rank:{self.tp_rank} layer_id:{self.layer_id} scale shape:{shape} self.scales:{self.scales.device}" + ) + + def check_enable(self): + return not self.disable + + def update_data(self, data): + if self.disable: + return + + self.scales = torch.max(data, self.scales) + + # save file + self.update_count += 1 + if self.update_count % self.info_step == 0 and self.operator: + logger.info(f"StaticQuantManager run update_data {self.update_count} step") + + if self.update_count % self.save_step == 0: + # step1: save to disk + save_file_path = os.path.join( + self.work_dir, f"{self.layer_id}_{self.tp_rank}.pt" + ) + lock_file_path = os.path.join( + self.work_dir, f"{self.layer_id}_{self.tp_rank}.lock" + ) + lock = FileLock(lock_file_path) + cpu_data = self.scales.cpu() + with lock: + torch.save(cpu_data, save_file_path) + + # step2: merge and save + if self.save_flag and self.operator: + save_dict = {} + for idx in range(self.total_layer_num): + tp_datas = [] + for tp_rank in range(self.tp_size): + load_file = os.path.join(self.work_dir, f"{idx}_{tp_rank}.pt") + lock_file_path = os.path.join( + self.work_dir, f"{idx}_{tp_rank}.lock" + ) + lock = FileLock(lock_file_path) + with lock: + cur_data = torch.load(load_file) + tp_datas.append(cur_data) + + layer_data = torch.concat(tp_datas) + save_dict[f"layer_{idx}"] = layer_data + + torch.save(save_dict, self.file_save_path) + logger.info( + f"StaticQuantManager save to {self.file_save_path} with {self.update_count} step" + ) + self.save_flag = True diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index f9cfcd2..2b74e56 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -191,7 +191,7 @@ import functools from abc import abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast +from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast, Any if TYPE_CHECKING: from flashinfer import BatchPrefillWithRaggedKVCacheWrapper @@ -221,8 +221,6 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.linear import ( ColumnParallelLinear, - UnquantizedLinearMethod, - LinearBase, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 @@ -230,6 +228,11 @@ from 
vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, get_and_maybe_dequant_weights, ) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearBase, + UnquantizedLinearMethod, +) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer, has_nvidia_artifactory from vllm.utils.math_utils import cdiv, round_down @@ -245,14 +248,16 @@ from vllm.v1.attention.backend import ( AttentionType, CommonAttentionMetadata, MLAAttentionImpl, + SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.utils import ( get_dcp_local_seq_lens, get_per_layer_parameters, infer_global_hyperparameters, - split_decodes_and_prefills, + split_decodes_and_prefills ) +from vllm.v1.attention.backend import AttentionCGSupport from vllm.v1.attention.ops.common import cp_lse_ag_out_rs from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.selector import get_attn_backend @@ -262,14 +267,15 @@ from vllm.v1.kv_cache_interface import ( MLAAttentionSpec, ) -from vllm.v1.attention.backend import AttentionCGSupport - logger = init_logger(__name__) class MLAAttention(nn.Module, AttentionLayerBase): """Multi-Head Latent Attention layer. + NOTE: Please read the comment at the top of the file before trying to + understand this class + This class takes query, and compressed key/value tensors as input. The class does the following: @@ -293,6 +299,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): prefix: str = "", use_sparse: bool = False, indexer: object | None = None, + rotary_emb: object | None = None, **extra_impl_args, ): super().__init__() @@ -303,8 +310,20 @@ class MLAAttention(nn.Module, AttentionLayerBase): self.v_head_dim = v_head_dim self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank + self.kv_b_proj = kv_b_proj self.head_size = kv_lora_rank + qk_rope_head_dim self.layer_name = prefix + self.indexer = indexer + self.rotary_emb = rotary_emb + speculative_config = get_current_vllm_config().speculative_config + self.use_spec_decode = ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + and speculative_config.num_speculative_tokens > 0 + ) + + self.num_kv_heads = 1 + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype @@ -329,6 +348,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): block_size, use_mla=True, use_sparse=use_sparse, + num_heads=self.num_heads, ) if ( @@ -370,8 +390,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): indexer=indexer, **extra_impl_args, ) - - # self.use_direct_call = not current_platform.opaque_attention_op() + self.q_pad_num_heads = getattr(self.impl, "q_pad_num_heads", None) self.use_direct_call = True compilation_config = get_current_vllm_config().compilation_config @@ -385,6 +404,12 @@ class MLAAttention(nn.Module, AttentionLayerBase): get_current_vllm_config().parallel_config.pipeline_parallel_size ) ] + if envs.VLLM_USE_INT8_MLA: + self.kv_cache_scale = [ + torch.tensor([]) for _ in range(get_current_vllm_config( + ).parallel_config.pipeline_parallel_size) + ] + self.is_int8_mla = envs.VLLM_USE_INT8_MLA self.use_sparse = use_sparse @@ -393,85 +418,76 @@ class MLAAttention(nn.Module, AttentionLayerBase): self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) - 
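A small arithmetic note on the dimensions set up above: each cached MLA token stores the latent plus the rope part (head_size = kv_lora_rank + qk_rope_head_dim), while per-head queries and keys combine the nope and rope parts (qk_head_dim). The DeepSeek-V3-style values below, in particular kv_lora_rank = 512, are assumptions used only as an example.

kv_lora_rank = 512           # latent width written to the KV cache (example value)
qk_rope_head_dim = 64
qk_nope_head_dim = 128
v_head_dim = 128

head_size = kv_lora_rank + qk_rope_head_dim          # per-token cache width, as set above
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim    # per-head q/k width, as set above
assert (head_size, qk_head_dim) == (576, 192)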
def forward_corex( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - output_shape: torch.Size | None = None, - ) -> torch.Tensor: - if self.calculate_kv_scales: - torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) + self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() - optional_args = {} - output_shape = output_shape if output_shape is not None else q.shape - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) + # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + self.is_aiter_triton_fp4_bmm_enabled = ( + rocm_aiter_ops.is_fp4bmm_enabled() + and self.kv_b_proj.weight.dtype == torch.bfloat16 + ) - if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - if self.attn_backend.accept_output_buffer: - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - output = self.impl.forward( - self, - q, - kv_c_normed, - k_pe, - self_kv_cache, - attn_metadata, - output=output, - **optional_args, - ) - else: - output = self.impl.forward( - self, - q, - kv_c_normed, - k_pe, - self_kv_cache, - attn_metadata, - **optional_args, - ) - return output + # Attributes for forward_impl method + self.chunked_prefill_workspace_size = ( + MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( + get_current_vllm_config() + ) + ) + self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + compile_native=True, + ) def forward( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, + position: torch.Tensor, output_shape: torch.Size | None = None, ) -> torch.Tensor: - return self.forward_corex(q, kv_c_normed, k_pe, output_shape) + optional_args = {} if self.calculate_kv_scales: torch.ops.vllm.maybe_calc_kv_scales(q, kv_c_normed, k_pe, self.layer_name) if self.use_direct_call: - forward_context: ForwardContext = get_forward_context() - attn_metadata = forward_context.attn_metadata - if isinstance(attn_metadata, dict): - attn_metadata = attn_metadata[self.layer_name] - self_kv_cache = self.kv_cache[forward_context.virtual_engine] - - if self.attn_backend.accept_output_buffer: - output = torch.empty(output_shape, dtype=q.dtype, device=q.device) - self.impl.forward( - self, - q, - kv_c_normed, - k_pe, - self_kv_cache, - attn_metadata, - output=output, - ) - return output - else: - return self.impl.forward( - self, q, kv_c_normed, k_pe, self_kv_cache, attn_metadata - ) + def direct_forward(layer_name: str, output_shape: torch.Size | None): + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[layer_name] + self_kv_cache = self.kv_cache[forward_context.virtual_engine] + if self.is_int8_mla: + optional_args["kv_cache_scale"] = self.kv_cache_scale[forward_context.virtual_engine] + if self.attn_backend.accept_output_buffer: + output_shape = (output_shape if output_shape is not None else q.shape) + output = torch.zeros(output_shape, + dtype=q.dtype, + device=q.device) + output = self.forward_impl( + q, + kv_c_normed, + k_pe, + self_kv_cache, + attn_metadata, + output=output, + positions=position, + **optional_args + ) + return output + else: + return self.forward_impl( + q, 
kv_c_normed, k_pe, self_kv_cache, attn_metadata, position=position + ) + return maybe_transfer_kv_layer(direct_forward)(self.layer_name, output_shape) else: + kv_cache_dummy_dep = torch.ops.vllm.unified_mla_kv_cache_update( + kv_c_normed, + k_pe, + self.layer_name, + self.kv_cache_dtype, + self._k_scale, + ) if self.attn_backend.accept_output_buffer: output = torch.empty(output_shape, dtype=q.dtype, device=q.device) torch.ops.vllm.unified_mla_attention_with_output( @@ -480,6 +496,7 @@ class MLAAttention(nn.Module, AttentionLayerBase): k_pe, output, self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) return output else: @@ -488,9 +505,210 @@ class MLAAttention(nn.Module, AttentionLayerBase): kv_c_normed, k_pe, self.layer_name, + kv_cache_dummy_dep=kv_cache_dummy_dep, ) + def forward_impl( + self, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: "MLACommonMetadata", + output: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + kv_cache_scale: torch.Tensor | None = None, + output_scale: torch.Tensor | None = None, + output_block_scale: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + assert positions is not None, "positions tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported for MLA" + ) + + if attn_metadata is None: + # During the profile run try to simulate to worse case output size + # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` + # since this can be large + _ = torch.empty( + ( + self.chunked_prefill_workspace_size, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ), + device=k_c_normed.device, + dtype=k_c_normed.dtype, + ) + + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. 
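Before the cache-write branches below, a shape-level reference for what concat_and_cache_mla stores: each slot holds the latent kv_c concatenated with the rope part k_pe, addressed through a flat slot_mapping. This sketch deliberately ignores the paged cache layout and the int8 variant.

import torch

num_blocks, block_size = 4, 16
kv_lora_rank, rope_dim = 512, 64
kv_cache = torch.zeros(num_blocks * block_size, kv_lora_rank + rope_dim)

kv_c = torch.randn(3, kv_lora_rank)         # latent per token
k_pe = torch.randn(3, rope_dim)             # rope part per token
slot_mapping = torch.tensor([5, 6, 37])     # flat slot = block_id * block_size + offset

kv_cache[slot_mapping] = torch.cat([kv_c, k_pe], dim=-1)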
+ output = torch.empty(output.shape[0], self.v_head_dim * self.num_heads, device=q.device, + dtype=q.dtype) + return output + + if self.impl.dcp_world_size == -1: + self.impl.dcp_world_size = get_dcp_group().world_size + + fp8_attention = self.kv_cache_dtype.startswith("fp8") + + is_sparse_impl = isinstance(self.impl, SparseMLAAttentionImpl) + + if not is_sparse_impl: + assert ( + attn_metadata.num_decodes is not None + and attn_metadata.num_prefills is not None + and attn_metadata.num_decode_tokens is not None + ) + + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + decode_q = q[:num_decode_tokens] + k_pe = k_pe.unsqueeze(1) + + prefill_q = q[num_decode_tokens:] + prefill_k_pe = k_pe[num_decode_tokens:] + prefill_k_c_normed = k_c_normed[num_decode_tokens:] + prefill_k = torch.empty_like(prefill_q) + + # write the latent and rope to kv cache + write_kv_cache = (None, None) + if kv_cache.numel() > 0: + if has_decode: + decode_q_pe, decode_k_pe = self.rotary_emb(positions[:num_decode_tokens], decode_q[..., self.qk_nope_head_dim:], k_pe[:num_decode_tokens],) + if envs.VLLM_USE_INT8_MLA: + k_c_normed_int8, k_c_normed_scale,_ = ops.scaled_int8_quant(k_c_normed[:num_decode_tokens]) + decode_k_pe_int8, decode_k_pe_scale,_ = ops.scaled_int8_quant(decode_k_pe.contiguous()) + ops.concat_and_cache_mla_int8( + kv_c_int8 = k_c_normed_int8, + kv_c_scale = k_c_normed_scale[...,0], + k_pe_int8 = decode_k_pe_int8, + k_pe_scale = decode_k_pe_scale[...,0].view(-1,decode_k_pe_int8.shape[-2]), + kv_cache = kv_cache, + kv_cache_scale = kv_cache_scale, + slot_mapping = attn_metadata.slot_mapping.flatten()[:num_decode_tokens], + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) + else: + if self.impl.dcp_world_size > 1 or self.use_spec_decode: + ops.concat_and_cache_mla( + k_c_normed[:num_decode_tokens], + decode_k_pe, + kv_cache, + attn_metadata.slot_mapping.flatten()[:num_decode_tokens], + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) + else: + write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe) + if has_prefill: + ixf_ops.mla_rope(positions[num_decode_tokens:], prefill_q[..., self.qk_nope_head_dim:], prefill_k_pe.squeeze(1), prefill_k[...,self.qk_nope_head_dim:], self.rotary_emb.cos_sin_cache) + if envs.VLLM_USE_INT8_MLA: + prefill_k_c_normed_int8, prefill_k_c_normed_scale,_ = ops.scaled_int8_quant(prefill_k_c_normed) + prefill_k_pe_int8, prefill_k_pe_scale,_ = ops.scaled_int8_quant(prefill_k[...,self.qk_nope_head_dim:].contiguous()) + ops.concat_and_cache_mla_int8( + prefill_k_c_normed_int8, + prefill_k_c_normed_scale[...,0], + prefill_k_pe_int8, + prefill_k_pe_scale[...,0].view(-1,prefill_k_pe_int8.shape[-2]), + kv_cache, + kv_cache_scale, + attn_metadata.slot_mapping.flatten()[num_decode_tokens:], + self.kv_cache_dtype, + self._k_scale, + ) + else: + ops.concat_and_cache_mla( + prefill_k_c_normed, + prefill_k[...,self.qk_nope_head_dim:], + kv_cache, + attn_metadata.slot_mapping.flatten()[num_decode_tokens:], + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) + output = torch.empty(output.shape[0], + self.num_heads, self.v_head_dim, + device=q.device, + dtype=q.dtype) + + if fp8_attention and self.kv_cache_dtype != "fp8_ds_mla": + kv_cache = kv_cache.view(current_platform.fp8_dtype()) + + # Sparse MLA impls only support forward_mqa (decode-style attention) + + if has_prefill: + output[num_decode_tokens:] = self.impl.forward_mha( + prefill_q, + 
prefill_k_c_normed, + prefill_k, + kv_cache, + attn_metadata, + kv_c_and_k_pe_cache_scale=kv_cache_scale + ) + + if has_decode: + # For sparse impl, we always use forward_mqa for all tokens + # For non-sparse impl, we only use forward_mqa for decode tokens + assert attn_metadata.decode is not None + mqa_q = decode_q + + attn_out, lse = self.impl.forward_mqa( + mqa_q[..., :self.qk_nope_head_dim], decode_q_pe, kv_cache, attn_metadata, + *write_kv_cache, kv_c_and_k_pe_cache_scale=kv_cache_scale) + if self.impl.dcp_world_size > 1: + assert lse is not None + attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) + output[:num_decode_tokens] = self.impl._v_up_proj(attn_out) + else: + assert lse is None + output[:num_decode_tokens] = attn_out + return output.view(output.shape[0], self.v_head_dim * self.num_heads) + else: + num_actual_toks = attn_metadata.num_actual_tokens + # Inputs and outputs may be padded for CUDA graphs + k_pe = k_pe.unsqueeze(1) + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] + positions = positions[:num_actual_toks, ...] + num_mqa_tokens = q.size(0) + if num_mqa_tokens > 0: + q = q[:num_mqa_tokens] + q_nope, q_pe = q.split( + [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe, k_pe = self.rotary_emb(positions[:num_actual_toks], q_pe, k_pe) + + q_nope = self.impl._k_up_proj(q_nope) + q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank) + q = torch.cat([q_nope, q_pe], dim=-1) + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe, + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=self._k_scale, + ) + attn_out= self.impl.forward_mqa(q, kv_cache, attn_metadata, self) + output = torch.empty(output.shape[0], + self.num_heads, self.v_head_dim, + device=q.device, + dtype=q.dtype) + output[:num_actual_toks] = self.impl._v_up_proj(attn_out) + + return output.view(output.shape[0], self.v_head_dim * self.num_heads) + + def process_weights_after_loading(self, act_dtype: torch.dtype): + if hasattr(self.impl, "process_weights_after_loading"): self.impl.process_weights_after_loading(act_dtype) @@ -543,15 +761,21 @@ class MLAAttention(nn.Module, AttentionLayerBase): ) + @maybe_transfer_kv_layer def unified_mla_attention( q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: - attn_metadata, self, kv_cache = get_attention_context(layer_name) - output = self.impl.forward(self, q, kv_c_normed, k_pe, kv_cache, attn_metadata) + # kv_cache_dummy_dep is not used but accepting it creates a data dependency + # that ensures torch.compile preserves ordering between KV cache update and + # attention forward. 
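The kv_cache_dummy_dep plumbing around these ops exists purely to keep torch.compile from reordering the cache write and the attention call: the update op returns a tiny tensor that the attention op accepts and discards, which creates an explicit data dependency between the two. In miniature:

import torch

def cache_update_sketch(k: torch.Tensor) -> torch.Tensor:
    # ... write k into the paged cache here (side effect) ...
    return torch.empty(0, device=k.device, dtype=k.dtype)   # dummy handle

def attention_sketch(q: torch.Tensor, kv_cache_dummy_dep: torch.Tensor | None = None):
    del kv_cache_dummy_dep            # unused; it only orders this call after the update
    return q                          # stand-in for the real attention math

q, k = torch.randn(2, 4), torch.randn(2, 4)
out = attention_sketch(q, kv_cache_dummy_dep=cache_update_sketch(k))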
+ del kv_cache_dummy_dep + attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) + output = layer.forward_impl(q, kv_c_normed, k_pe, kv_cache, attn_metadata) return output @@ -561,6 +785,7 @@ def unified_mla_attention_fake( kv_c_normed: torch.Tensor, k_pe: torch.Tensor, layer_name: str, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty_like(q).contiguous() @@ -574,6 +799,56 @@ direct_register_custom_op( ) +def unified_mla_kv_cache_update( + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, + kv_cache_dtype: str, + k_scale: torch.Tensor, +) -> torch.Tensor: + """ + Returns a dummy that is passed to unified_attention to signal a side effect and + the data dependency between them to ensure torch.compile preserves ordering. + """ + forward_context = get_forward_context() + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + + slot_mapping = forward_context.slot_mapping + assert isinstance(slot_mapping, dict), ( + f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " + ) + layer_slot_mapping = slot_mapping.get(layer_name) + if layer_slot_mapping is not None: + attn_layer.impl.do_kv_cache_update( + kv_c_normed, + k_pe, + kv_cache, + layer_slot_mapping, + kv_cache_dtype, + k_scale, + ) + + return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) + + +def unified_mla_kv_cache_update_fake( + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + layer_name: str, + kv_cache_dtype: str, + k_scale: torch.Tensor, +) -> torch.Tensor: + return torch.empty(0, device=kv_c_normed.device, dtype=kv_c_normed.dtype) + + +direct_register_custom_op( + op_name="unified_mla_kv_cache_update", + op_func=unified_mla_kv_cache_update, + fake_impl=unified_mla_kv_cache_update_fake, +) + + @maybe_transfer_kv_layer def unified_mla_attention_with_output( q: torch.Tensor, @@ -583,10 +858,14 @@ def unified_mla_attention_with_output( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: - attn_metadata, self, kv_cache = get_attention_context(layer_name) - self.impl.forward( - self, + # kv_cache_dummy_dep is not used but accepting it creates a data dependency + # that ensures torch.compile preserves ordering between KV cache update and + # attention forward. + del kv_cache_dummy_dep + attn_metadata, layer, kv_cache, _ = get_attention_context(layer_name) + layer.forward_impl( q, kv_c_normed, k_pe, @@ -606,6 +885,7 @@ def unified_mla_attention_with_output_fake( layer_name: str, output_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None, + kv_cache_dummy_dep: torch.Tensor | None = None, ) -> None: return @@ -637,20 +917,23 @@ class QueryLenSupport(Enum): try: - # from vllm.vllm_flash_attn import ( # type: ignore[attr-defined] - # flash_attn_varlen_func, - # ) - from ixformer.contrib.vllm_flash_attn import ( - flash_attn_varlen_func, - merge_attn_states, - ) - + from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func,merge_attn_states is_vllm_fa = True except ImportError: - # For rocm use upstream flash attention - if current_platform.is_rocm(): - from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] is_vllm_fa = False + flash_attn_varlen_func = None # type: ignore[assignment] + # On ROCm, vllm_flash_attn is not available, try upstream flash_attn instead. 
+ # On CUDA, vllm_flash_attn should always be available (built with vLLM), + # so we don't attempt the fallback there. + if current_platform.is_rocm(): + try: + from flash_attn import flash_attn_varlen_func # type: ignore[no-redef] + except ImportError: + logger.debug( + "flash_attn not available on ROCm; " + "MLA models using TRITON_MLA will require flash_attn. " + "AITER_MLA backends use aiter kernels instead." + ) def dynamic_per_batched_tensor_quant( @@ -667,7 +950,10 @@ def dynamic_per_batched_tensor_quant( logger = init_logger(__name__) -@CustomOp.register("mla_decode_concat_quant_fp8") +@CustomOp.register( + "mla_decode_concat_quant_fp8", + dynamic_arg_dims={"decode_ql_nope": 0, "decode_q_pe": 0}, +) class _DecodeConcatQuantFP8(QuantFP8): """ QuantFP8 variant that concatenates decode_ql_nope and decode_q_pe before @@ -704,7 +990,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import scaled_deq import ixformer.inference.functions as ixf_ops import numpy as np - class MLACommonBackend(AttentionBackend): accept_output_buffer: bool = True @@ -775,6 +1060,7 @@ class MLACommonPrefillMetadata: query_seq_lens: torch.Tensor | None = None workspace_buffer: torch.Tensor | None = None q_data_type: torch.dtype | None = None + output_dtype: torch.dtype | None = None @dataclass @@ -870,6 +1156,7 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool: return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128 +@functools.cache def use_flashinfer_prefill() -> bool: # For blackwell default to flashinfer prefill if it's available since # it is faster than FA2. @@ -887,6 +1174,7 @@ def use_flashinfer_prefill() -> bool: return is_deepseek_r1_mla_compatible(vllm_config) +@functools.cache def use_cudnn_prefill() -> bool: from vllm.config import get_current_vllm_config @@ -899,6 +1187,7 @@ def use_cudnn_prefill() -> bool: ) +@functools.cache def use_trtllm_ragged_deepseek_prefill() -> bool: """Check if TRT-LLM ragged DeepSeek prefill should be used.""" from vllm.config import get_current_vllm_config @@ -935,6 +1224,27 @@ def get_mla_dims(model_config: ModelConfig) -> MLADims: ) +@functools.cache +def backend_supports_prefill_query_quantization() -> bool: + """Check if the selected MLA backend supports prefill query quantization. + + Currently supported backends: + - FlashInfer prefill + - TRT-LLM ragged DeepSeek prefill + + Not supported: + - cuDNN Prefill + - FlashAttention + - Non-GB200 devices (FP8 prefill requires device capability 100) + """ + # FP8 prefill query quantization requires GB200 (device capability 100) + # for the necessary FP8 kernels at the moment. + if not current_platform.is_device_capability_family(100): + return False + + return use_flashinfer_prefill() or use_trtllm_ragged_deepseek_prefill() + + class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): """ NOTE: Please read the comment at the top of the file before trying to @@ -947,7 +1257,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # - VARLEN: Supports variable-length queries (spec decode with mixed lengths) # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when # speculative decoding is enabled. - query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY + query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM # The threshold for reordering the batch into decode and prefill requests. 
# If > 1, the batch will be reordered such that requests with @@ -956,7 +1266,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): # when speculative decoding is enabled. reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE + AttentionCGSupport.UNIFORM_BATCH @staticmethod def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int: @@ -989,6 +1299,40 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): return chunked_prefill_workspace_size + @staticmethod + def determine_prefill_query_data_type( + vllm_config: VllmConfig, + model_dtype: torch.dtype, + ) -> torch.dtype: + """ + Determine the query data type for prefill queries. + Return FP8 dtype if cache is FP8 and prefill query quantization + is enabled, else model dtype. + """ + use_fp8 = ( + vllm_config.cache_config.cache_dtype.startswith("fp8") + and vllm_config.attention_config.use_prefill_query_quantization + and backend_supports_prefill_query_quantization() + ) + + if use_fp8: + fp8_dtype = current_platform.fp8_dtype() + logger.info_once( + "FP8 prefill attention enabled: query data type is FP8", scope="local" + ) + return fp8_dtype + elif vllm_config.attention_config.use_prefill_query_quantization: + logger.info_once( + "Unable to perform FP8 prefill attention when" + " use_prefill_query_quantization is enabled. Please" + " ensure that --kv-cache-dtype is set to fp8 and your prefill" + " backend is compatible with FP8 attention.", + scope="local", + ) + return model_dtype + + return model_dtype + def __init__( self, kv_cache_spec: AttentionSpec, @@ -1006,12 +1350,24 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self.compilation_config = vllm_config.compilation_config + self.decode_use_graph = ( + vllm_config.compilation_config.cudagraph_mode.decode_use_graph() + ) + self.use_full_cuda_graph = ( + vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) self.vllm_config = vllm_config self.device = device self.num_heads = self.model_config.get_num_attention_heads(parallel_config) self.mla_dims = get_mla_dims(self.model_config) self.aot_schedule = current_platform.is_cuda() + + self.kv_cache_spec = kv_cache_spec + self.q_data_type = self.determine_prefill_query_data_type( + vllm_config, self.model_config.dtype + ) + try: self.dcp_world_size = get_dcp_group().world_size self.dcp_rank = get_dcp_group().rank_in_group @@ -1052,7 +1408,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): self.chunked_prefill_workspace_size, self.model_config.get_head_size(), ), - dtype=self.model_config.dtype, + dtype=self.q_data_type, device=device, ) @@ -1106,6 +1462,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): f"reorder_batch_threshold must be 1 when query_len_support is " f"SINGLE_ONLY, got {self.reorder_batch_threshold}" ) + + # Spec-decode expands block_table / seq_lens via index_select; CUDA graph + # replay expects stable tensor storage. Copy expanded rows into these + # buffers when decode or full CUDA graph is enabled. 
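The two buffers declared below exist because CUDA graph replay captures fixed tensor storage: metadata whose row count changes from step to step has to be copied into a preallocated, max-size buffer and handed out as a slice of that same storage. The pattern, reduced to its essentials (sizes are illustrative):

import torch

max_rows, ncols = 1024, 32
block_table_buf = torch.zeros(max_rows, ncols, dtype=torch.int32)

def stage(block_table: torch.Tensor) -> torch.Tensor:
    n = block_table.shape[0]
    block_table_buf[:n].copy_(block_table)    # same storage on every step
    return block_table_buf[:n]                # a view, not a fresh allocation

bt = stage(torch.randint(0, 100, (5, ncols), dtype=torch.int32))
assert bt.data_ptr() == block_table_buf.data_ptr()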
+ self._spec_decode_expand_block_table_buf: torch.Tensor | None = None + self._spec_decode_expand_seq_lens_buf: torch.Tensor | None = None def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): qo_indptr = prefill.query_start_loc @@ -1162,7 +1524,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, - q_data_type=self.model_config.dtype, + q_data_type=self.q_data_type, + o_data_type=prefill.output_dtype, ) # Prepare context prefills @@ -1181,7 +1544,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): sm_scale=self._global_hyperparameters.sm_scale, window_left=self._global_hyperparameters.window_left, logits_soft_cap=self._global_hyperparameters.logits_soft_cap, - q_data_type=self.model_config.dtype, + q_data_type=self.q_data_type, + o_data_type=prefill.output_dtype, ) prefill.prefill_main = self._fi_prefill_main @@ -1195,16 +1559,17 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - max_decode_seq_len: int, - use_cuda_graph: bool, dcp_tot_seq_lens_device: torch.Tensor | None, + max_decode_seq_len: int, + use_cuda_graph: bool ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, max_decode_seq_len=max_decode_seq_len, use_cuda_graph=use_cuda_graph, - dcp_tot_seq_lens=dcp_tot_seq_lens_device, + ) def build_for_cudagraph_capture( @@ -1246,7 +1611,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens - + seq_lens_np = common_attn_metadata.seq_lens_np num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, @@ -1440,6 +1805,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): query_start_loc=prefill_query_start_loc, max_query_len=max_query_len, chunked_context=chunked_context_metadata, + output_dtype=self.model_config.dtype, + q_data_type=self.q_data_type, ) if self._use_cudnn_prefill: @@ -1471,17 +1838,135 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): max_seq_len = ( (max_seq_len + num_partitions - 1) // num_partitions ) * self.cp_kv_cache_interleave_size + + decode_block_table = block_table_tensor[:num_decodes, ...] + decode_seq_lens_np = seq_lens_np[:num_decodes] + decode_seq_lens_device = seq_lens[:num_decodes] + decode_query_start_loc_cpu = query_start_loc_cpu[: num_decodes + 1] + decode_query_start_loc = query_start_loc[: num_decodes + 1] + max_decode_seq_len = np.max(decode_seq_lens_np).item() + + speculative_config = self.vllm_config.speculative_config + use_spec_decode = ( + speculative_config is not None + and speculative_config.num_speculative_tokens is not None + and speculative_config.num_speculative_tokens > 0 + ) + # For non-DCP speculative decode, expand decode metadata from + # request-level rows to token-level rows so block_table/seq_lens + # align with num_decode_tokens. 
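The expansion below turns request-level decode metadata into token-level rows: each request contributes one row per scheduled (target plus draft) token, reusing its block table, and every token gets the sequence length it would observe at its position. A small worked example of the same index arithmetic, with made-up numbers:

import torch

block_table = torch.tensor([[10, 11], [20, 21]])   # [num_decode_reqs, max_blocks]
seq_lens = torch.tensor([7, 5])                    # per-request length incl. this step's tokens
decode_lens = torch.tensor([3, 2])                 # tokens scheduled this step (target + drafts)

req_ids = torch.repeat_interleave(torch.arange(2), decode_lens)        # [0, 0, 0, 1, 1]
expanded_block_table = block_table.index_select(0, req_ids)            # [5, 2]

starts = torch.repeat_interleave(torch.cumsum(decode_lens, 0) - decode_lens, decode_lens)
pos_in_req = torch.arange(int(decode_lens.sum())) - starts             # [0, 1, 2, 0, 1]
context_lens = seq_lens - decode_lens                                  # [4, 3]
per_token_seq_lens = context_lens.index_select(0, req_ids) + pos_in_req + 1
assert per_token_seq_lens.tolist() == [5, 6, 7, 4, 5]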
+ if ( + use_spec_decode + and self.dcp_world_size == 1 + and num_decode_tokens > num_decodes + ): + decode_lens_cpu = ( + decode_query_start_loc_cpu[1:] - decode_query_start_loc_cpu[:-1] + ) + decode_lens_device_long = decode_lens_cpu.to( + device=decode_block_table.device, dtype=torch.long + ) + req_ids = torch.repeat_interleave( + torch.arange( + num_decodes, device=decode_block_table.device, dtype=torch.long + ), + decode_lens_device_long, + ) + if req_ids.numel() != num_decode_tokens: + pass + else: + expanded_block_table = decode_block_table.index_select( + 0, req_ids + ) + + decode_query_start_loc_long = decode_query_start_loc.to( + device=expanded_block_table.device, dtype=torch.long + ) + token_starts = torch.repeat_interleave( + decode_query_start_loc_long[:-1], decode_lens_device_long + ) + token_pos_in_req = ( + torch.arange( + num_decode_tokens, + device=expanded_block_table.device, + dtype=torch.long, + ) + - token_starts + ) + + decode_seq_lens_device_long = decode_seq_lens_device.to( + device=decode_block_table.device, dtype=torch.long + ) + context_lens_long = ( + decode_seq_lens_device_long - decode_lens_device_long + ) + decode_seq_lens_device_long = ( + context_lens_long.index_select(0, req_ids) + + token_pos_in_req + + 1 + ) + + decode_seq_lens_device = decode_seq_lens_device_long.to( + dtype=seq_lens.dtype + ) + + if self.decode_use_graph or self.use_full_cuda_graph: + n = expanded_block_table.shape[0] + ncols = expanded_block_table.shape[1] + max_rows = ( + self.vllm_config.scheduler_config.max_num_batched_tokens + ) + assert n <= max_rows, ( + f"spec-decode expanded decode rows {n} > " + f"max_num_batched_tokens {max_rows}" + ) + if ( + self._spec_decode_expand_block_table_buf is None + or self._spec_decode_expand_block_table_buf.shape[0] + < max_rows + or self._spec_decode_expand_block_table_buf.shape[1] + != ncols + ): + self._spec_decode_expand_block_table_buf = torch.zeros( + max_rows, + ncols, + dtype=block_table_tensor.dtype, + device=self.device, + ) + self._spec_decode_expand_seq_lens_buf = torch.zeros( + max_rows, + dtype=seq_lens.dtype, + device=self.device, + ) + self._spec_decode_expand_block_table_buf[:n].copy_( + expanded_block_table + ) + self._spec_decode_expand_seq_lens_buf[:n].copy_( + decode_seq_lens_device + ) + decode_block_table = self._spec_decode_expand_block_table_buf[ + :n + ] + decode_seq_lens_device = ( + self._spec_decode_expand_seq_lens_buf[:n] + ) + else: + decode_block_table = expanded_block_table + decode_seq_lens_cpu = decode_seq_lens_device.to( + device="cpu", dtype=seq_lens.dtype + ) + max_decode_seq_len = torch.max(decode_seq_lens_cpu).item() decode_metadata = self._build_decode( - block_table_tensor=block_table_tensor[:num_decodes, ...], - seq_lens_device=seq_lens[:num_decodes], + block_table_tensor=decode_block_table, + seq_lens_device=decode_seq_lens_device, max_seq_len=max_seq_len, - query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], - query_start_loc_device=query_start_loc[: num_decodes + 1], + query_start_loc_cpu=decode_query_start_loc_cpu, + query_start_loc_device=decode_query_start_loc, num_decode_tokens=num_decode_tokens, - max_decode_seq_len=torch.max(seq_lens[:num_decodes]).item(), - use_cuda_graph=False, dcp_tot_seq_lens_device=dcp_tot_seq_lens_device, + max_decode_seq_len=max_decode_seq_len, + use_cuda_graph=num_prefills == 0 and self.decode_use_graph, ) attn_metadata = self.metadata_cls( @@ -1580,9 +2065,7 @@ def reorg_kvcache( return reorganized_kv_c_normed, reorganized_k_pe -# TODO(Lucas): rename 
MLACommonBaseImpl -> MLACommonImpl, -# and MLACommonImpl -> MLACommonDenseImpl or somthing like that -class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): +class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -1608,9 +2091,8 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, - indexer=None, + indexer: object | None = None, q_pad_num_heads: int | None = None, - rotary_emb = None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -1630,23 +2112,321 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): self.kv_b_proj = kv_b_proj self.indexer = indexer self.q_pad_num_heads = q_pad_num_heads - self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled() - self.rotary_emb = rotary_emb + self.supports_quant_query_input = True - # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported - self.is_aiter_triton_fp4_bmm_enabled = ( - rocm_aiter_ops.is_fp4bmm_enabled() - and self.kv_b_proj.weight.dtype == torch.bfloat16 + # Use flashinfer's optimized concat_mla_k kernel when available. + # The kernel is optimized for DeepSeek V3 dimensions: + # num_heads=128, nope_dim=128, rope_dim=64 + self._use_flashinfer_concat_mla_k = ( + has_flashinfer() + and (self.num_heads == 128) + and (self.qk_nope_head_dim == 128) + and (self.qk_rope_head_dim == 64) ) - def process_weights_after_loading(self, act_dtype: torch.dtype): - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - # kv_b_proj_weight = get_and_maybe_dequant_weights( - # self.kv_b_proj, out_dtype=act_dtype - # ).T + if use_trtllm_ragged_deepseek_prefill(): + logger.info_once( + "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local" + ) + self._run_prefill_context_chunk = ( + self._run_prefill_context_chunk_trtllm_ragged + ) + self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged + self._pad_v = False + elif use_flashinfer_prefill(): + logger.info_once("Using FlashInfer prefill for MLA", scope="local") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi + self._pad_v = False + elif use_cudnn_prefill(): + logger.info_once("Using CUDNN prefill for MLA", scope="local") + self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn + self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn + self._pad_v = False + else: # Use FlashAttention + if flash_attn_varlen_func is None: + raise RuntimeError( + "MLA attention requires FlashAttention but it is not " + "available. Please install flash_attn or use " + "--attention-backend ROCM_AITER_MLA." + ) + logger.info_once("Using FlashAttention prefill for MLA", scope="local") + self.positions = None + self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa + self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa + # Handle the differences between the flash_attn_varlen from + # flash_attn and the one from vllm_flash_attn. 
The former is used on + # RoCM and the latter has an additional parameter to control + # FA2 vs FA3 + self.flash_attn_varlen_func = flash_attn_varlen_func + self.vllm_flash_attn_version = get_flash_attn_version() + # if self.vllm_flash_attn_version is not None: + # self.flash_attn_varlen_func = functools.partial( + # flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version + # ) + + # For MLA the v head dim is smaller than qk head dim so we pad out + # v with 0s to match the qk head dim for attention backends that do + # not support different headdims + # We don't need to pad V if we are on a hopper system with FA3 + # self._pad_v = self.vllm_flash_attn_version is None or not ( + # self.vllm_flash_attn_version == 3 + # and current_platform.get_device_capability()[0] == 9 + # ) + self._pad_v = False + + self.dcp_world_size: int = -1 + + self.cp_kv_cache_interleave_size: int = ( + get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size + ) + + def _flash_attn_varlen_diff_headdims( + self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs + ): + maybe_padded_v = v + if self._pad_v: + maybe_padded_v = torch.nn.functional.pad( + v, [0, q.shape[-1] - v.shape[-1]], value=0 + ) + + if is_vllm_fa: + kwargs["return_softmax_lse"] = return_softmax_lse + else: + # ROCm leverages the upstream flash_attn, which takes a parameter + # called "return_attn_probs" instead of return_softmax_lse + kwargs["return_attn_probs"] = return_softmax_lse + if vllm_is_batch_invariant(): + kwargs["num_splits"] = 1 + + attn_out = self.flash_attn_varlen_func( + q=q, + k=k, + v=maybe_padded_v, + softmax_scale=softmax_scale, + **kwargs, + ) + + # Unpack the output if there is multiple results + lse = None + if isinstance(attn_out, tuple): + attn_out, lse = attn_out[0], attn_out[1] + + # Remain consistent with old `flash_attn_varlen_func` where there + # is only one output tensor if `return_softmax_lse` is False. 
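The zero-padding of v mentioned above is lossless: appending zero columns to v only appends zero columns to the attention output, so slicing back to the original v head dim recovers the unpadded result exactly. A self-contained check with toy shapes and a naive unmasked attention (for illustration only; the real path goes through flash_attn_varlen_func):

    import torch
    import torch.nn.functional as F

    T, H, D_qk, D_v = 4, 2, 64, 32
    q = torch.randn(T, H, D_qk)
    k = torch.randn(T, H, D_qk)
    v = torch.randn(T, H, D_v)

    def sdpa(q, k, v):
        scores = torch.einsum("thd,shd->hts", q, k) / D_qk**0.5
        return torch.einsum("hts,shd->thd", scores.softmax(-1), v)

    v_padded = F.pad(v, [0, D_qk - D_v], value=0)
    out_ref = sdpa(q, k, v)
    out_pad = sdpa(q, k, v_padded)[..., :D_v]
    assert torch.allclose(out_ref, out_pad, atol=1e-6)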
+ if return_softmax_lse: + return attn_out, lse + return attn_out + + def _run_prefill_new_tokens_fa( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse, out + ): + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.query_start_loc, + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.max_query_len, + softmax_scale=self.scale, + causal=True, + return_softmax_lse=return_softmax_lse, + out=out, + ) + + def _run_prefill_new_tokens_fi( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + assert isinstance(prefill, FlashInferPrefillMetadata) + assert prefill.prefill_main is not None + + ret = prefill.prefill_main.run( + q=q, + k=k, + v=v, + return_lse=return_softmax_lse, + ) + + if isinstance(ret, tuple): + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_new_tokens_cudnn( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + assert isinstance(prefill, CudnnPrefillMetadata) + assert prefill.query_seq_lens is not None + from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache + + output, lse = cudnn_batch_prefill_with_kv_cache( + q=q, + k_cache=k, + v_cache=v, + scale=self.scale, + workspace_buffer=prefill.cudnn_workspace, + max_token_per_sequence=prefill.max_query_len, + max_sequence_kv=prefill.max_query_len, + actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), + causal=True, + # Do not support False for now + return_lse=True, + # Indicates actual_seq_lens are on GPU or CPU. + is_cuda_graph_compatible=True, + ) + if return_softmax_lse: + return output, lse + return output + + def _run_prefill_context_chunk_fa( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v, out + ): + assert prefill.chunked_context is not None + return self._flash_attn_varlen_diff_headdims( + q=q, + k=k, + v=v, + cu_seqlens_q=prefill.query_start_loc, + cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx], + max_seqlen_q=prefill.max_query_len, + max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx], + softmax_scale=self.scale, + causal=False, # Context is unmasked + return_softmax_lse=True, + out=out, + ) + + def _run_prefill_context_chunk_fi( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + assert isinstance(prefill, FlashInferPrefillMetadata) + + attn_out, lse = prefill.prefill_chunks[chunk_idx].run( + q=q, + k=k, + v=v, + return_lse=True, + ) + + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() + + def _run_prefill_context_chunk_cudnn( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + assert isinstance(prefill, CudnnPrefillMetadata) + assert prefill.chunked_context is not None + assert prefill.chunked_context.seq_lens[chunk_idx] is not None + assert prefill.query_seq_lens is not None + from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache + + return cudnn_batch_prefill_with_kv_cache( + q=q, + k_cache=k, + v_cache=v, + scale=self.scale, + workspace_buffer=prefill.cudnn_workspace, + max_token_per_sequence=prefill.max_query_len, + max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], + actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), + actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view( + -1, 1, 1, 1 + ), + causal=False, + return_lse=True, + # Indicates actual_seq_lens are 
on GPU or CPU. + is_cuda_graph_compatible=True, + ) + + def _run_prefill_new_tokens_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse + ): + """TRT-LLM ragged attention for new tokens (causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.query_seq_lens is not None + assert prefill.workspace_buffer is not None + # allocate BF16 / FP16 output tensor for TRT-LLM ragged attention + out = torch.empty( + q.shape[0], + q.shape[1], + v.shape[2], + device=q.device, + dtype=prefill.output_dtype, + ) + + ret = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=prefill.workspace_buffer, + seq_lens=prefill.query_seq_lens, + max_q_len=prefill.max_query_len, + max_kv_len=prefill.max_query_len, + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.query_seq_lens.shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + cum_seq_lens_kv=prefill.query_start_loc, + enable_pdl=False, + is_causal=True, + return_lse=return_softmax_lse, + out=out, + ) + + if isinstance(ret, tuple): + # Convert from (q_len, num_heads) to (num_heads, q_len) + return ret[0], ret[1].transpose(0, 1).contiguous() + return ret + + def _run_prefill_context_chunk_trtllm_ragged( + self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v + ): + """TRT-LLM ragged attention for context chunks (non-causal).""" + from flashinfer.prefill import trtllm_ragged_attention_deepseek + + assert prefill.chunked_context is not None + assert prefill.chunked_context.seq_lens[chunk_idx] is not None + assert prefill.workspace_buffer is not None + + out = torch.zeros( + q.shape[0], + q.shape[1], + v.shape[2], + device=q.device, + dtype=prefill.output_dtype, + ) + prefill.workspace_buffer.fill_(0) + + attn_out, lse = trtllm_ragged_attention_deepseek( + query=q, + key=k, + value=v, + workspace_buffer=prefill.workspace_buffer, + seq_lens=prefill.chunked_context.seq_lens[chunk_idx], + max_q_len=prefill.max_query_len, + max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], + bmm1_scale=self.scale, + bmm2_scale=1.0, + o_sf_scale=1.0, + batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0], + window_left=-1, + cum_seq_lens_q=prefill.query_start_loc, + cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx], + enable_pdl=False, + is_causal=False, + return_lse=True, + out=out, + ) + + # Convert from (q_len, num_heads) to (num_heads, q_len) + return attn_out, lse.transpose(0, 1).contiguous() + + def process_weights_after_loading(self, act_dtype: torch.dtype): def get_layer_weight(layer): WEIGHT_NAMES = ("weight", "qweight", "weight_packed") for attr in WEIGHT_NAMES: @@ -1657,15 +2437,17 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): ) def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): + if layer.quant_method is not None and not isinstance( + layer.quant_method, UnquantizedLinearMethod + ): # we already prepare a dequant weight for GGUF, skip it. 
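The W8A8 branch a little further down rebuilds a floating-point weight from an int8 weight and its per-output-channel scale before absorbing it into the MLA projections. A generic sketch of that dequantization, with hypothetical names and shapes rather than the exact vLLM code path:

    import torch

    def dequant_w8a8_per_channel(
        weight_int8: torch.Tensor,    # (out_features, in_features), int8
        weight_scale: torch.Tensor,   # (out_features,) per-output-channel scales
        out_dtype: torch.dtype = torch.bfloat16,
    ) -> torch.Tensor:
        # W_fp[o, i] = W_int8[o, i] * scale[o]
        return (weight_int8.to(torch.float32) * weight_scale.view(-1, 1)).to(out_dtype)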
if layer.quant_method.__class__.__name__ == "GGUFLinearMethod": weight = layer.weight.clone() del layer.weight return weight if layer.quant_method.__class__.__name__ == "AWQMarlinLinearMethod": - from ixformer.inference.functions import ref_wui4a16 - return ref_wui4a16(None, layer.qweight, layer.scales, layer.qzeros, None, layer.quant_method.quant_config.group_size, only_return_weight=True) + from ixformer.inference.functions import wui4a16 + return wui4a16(None, layer.qweight, layer.scales, layer.qzeros, None, layer.quant_method.quant_config.group_size, only_return_weight=True) # for W8A8, we directly dequantize it here to avoiding quantization errors if hasattr(layer, "scheme") and layer.scheme.__class__.__name__ == "CompressedTensorsW8A8Int8" and not layer.scheme.is_static_input_scheme: quant_weight = layer.weight.T # output, input @@ -1682,27 +2464,12 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): # standardize to (output, input) return dequant_weights.T return layer.weight - - # assert kv_b_proj_weight.shape == ( - # self.kv_lora_rank, - # self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - # ), ( - # f"{kv_b_proj_weight.shape=}, " - # f"{self.kv_lora_rank=}, " - # f"{self.num_heads=}, " - # f"{self.qk_nope_head_dim=}, " - # f"{self.v_head_dim=}" - # ) - # kv_b_proj_weight = kv_b_proj_weight.view( - # self.kv_lora_rank, - # self.num_heads, - # self.qk_nope_head_dim + self.v_head_dim, - # ) - - # W_UK, W_UV = kv_b_proj_weight.split( - # [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - # ) + back_to_vllm = False + weight_dtype = get_layer_weight(self.kv_b_proj).dtype + + # when use customize forward, we do not care the specific data types and shape, just reset the + # _v_up_proj_and_o_proj and _q_proj_and_k_up_proj funs to get the correct results. if envs.VLLM_MLA_CUSTOMIZE: layer = self.kv_b_proj quant_method = layer.quant_method.__class__.__name__ @@ -1866,19 +2633,18 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): back_to_vllm = True if back_to_vllm: - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform # the bmm's in 16-bit, the extra memory overhead of this is fairly low kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - ), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}" - ) + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, @@ -1886,478 +2652,59 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): ) W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - if self.is_aiter_triton_fp8_bmm_enabled: - W_K = W_UK.transpose(0, 1) # 16 512 128 - W_V = W_UV.permute(1, 2, 0) # 16 128 512 - self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype() - ) - self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype() - ) - - # The kernel operates on non-padded inputs. 
Hence, pre-compiling - # triton kernel to avoid runtime compilation for unseen batch sizes - # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # On DS-R1, this step adds roughly 50s to the model loading time. - max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) - if is_global_first_rank(): - pre_compilation_list = tqdm( - pre_compilation_list, - desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - total=max_batch_size, - ) - - for m in pre_compilation_list: - x = torch.empty( - (self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - ) - - x = torch.empty( - (self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device, - ) - rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - ) - else: - # # Convert from (L, N, V) to (N, L, V) - # self.W_UV = W_UV.transpose(0, 1) - # # Convert from (L, N, P) to (N, P, L) - # self.W_UK_T = W_UK.permute(1, 2, 0) - self.W_UV = W_UV - self.W_UK = W_UK - - # # If kv_b_proj_weight is unquantized, quantize it to mxfp4 if supported + # # Convert from (L, N, V) to (N, L, V) + # self.W_UV = W_UV.transpose(0, 1) + # # Convert from (L, N, P) to (N, P, L) + # self.W_UK_T = W_UK.permute(1, 2, 0) + self.W_UV = W_UV + self.W_UK = W_UK + + def _v_up_proj(self, x: torch.Tensor): + # Convert from (B, N, L) to (N, B, L) + # x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + # out = out.view(-1, self.num_heads, self.v_head_dim) # if self.is_aiter_triton_fp4_bmm_enabled: - # from vllm.model_executor.layers.quantization.quark.utils import ( - # quark_quantize_weight_to_mxfp4, - # ) - - # self.W_K, self.W_K_scale = quark_quantize_weight_to_mxfp4(W_UK) - # # Convert from (L, N, P) to (N, L, P) - # self.W_K = self.W_K.transpose(0, 1) - # self.W_K_scale = self.W_K_scale.transpose(0, 1) - - # self.W_V, self.W_V_scale = quark_quantize_weight_to_mxfp4( - # W_UV.permute(1, 2, 0) + # out = rocm_aiter_ops.batched_gemm_a16wfp4( + # x, + # self.W_V, + # self.W_V_scale, + # out, + # transpose_bm=True, + # prequant=True, + # y_scale=None, # ) + # x = out.view(-1, self.num_heads * self.v_head_dim) # elif self.is_aiter_triton_fp8_bmm_enabled: - # W_K = W_UK.transpose(0, 1) # 16 512 128 - # W_V = W_UV.permute(1, 2, 0) # 16 128 512 - # self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - # W_K, dtype=current_platform.fp8_dtype() + # # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + # x = rocm_aiter_ops.triton_fp8_bmm( + # x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out # ) - # self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - # W_V, dtype=current_platform.fp8_dtype() - # ) - - # # The kernel operates on non-padded inputs. Hence, pre-compiling - # # triton kernel to avoid runtime compilation for unseen batch sizes - # # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # # On DS-R1, this step adds roughly 50s to the model loading time. 
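The rewritten _v_up_proj/_k_up_proj below keep W_UV and W_UK in their native (kv_lora_rank, num_heads, head_dim) layout and express the per-head up-projections with torch.einsum instead of the transposed torch.bmm path. A small sketch with made-up shapes showing the two formulations agree:

    import torch

    B, N, L, P, V = 4, 16, 512, 128, 128   # tokens, heads, kv_lora_rank, nope dim, v dim
    x = torch.randn(B, N, L, dtype=torch.float64)        # latent attention output
    q_nope = torch.randn(B, N, P, dtype=torch.float64)
    W_UV = torch.randn(L, N, V, dtype=torch.float64)
    W_UK = torch.randn(L, N, P, dtype=torch.float64)

    # V up-projection: einsum form vs. per-head bmm form (N, B, L) x (N, L, V) -> (B, N, V)
    out_v = torch.einsum("bnl,lnv->bnv", x, W_UV)
    ref_v = torch.bmm(x.transpose(0, 1), W_UV.permute(1, 0, 2)).transpose(0, 1)
    assert torch.allclose(out_v, ref_v)

    # K up-projection: einsum form vs. (N, B, P) x (N, P, L) -> (B, N, L)
    out_k = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)
    ref_k = torch.bmm(q_nope.transpose(0, 1), W_UK.permute(1, 2, 0)).transpose(0, 1)
    assert torch.allclose(out_k, ref_k)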
- # max_batch_size = 1024 # [ToDo] Find the optimal upper limit - # pre_compilation_list = list(range(1, max_batch_size + 1)) - # if is_global_first_rank(): - # pre_compilation_list = tqdm( - # pre_compilation_list, - # desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - # total=max_batch_size, - # ) - - # for m in pre_compilation_list: - # x = torch.empty( - # (self.W_K.shape[0], m, self.W_K.shape[2]), - # dtype=torch.bfloat16, - # device=self.W_K.device, - # ) - # rocm_aiter_ops.triton_fp8_bmm( - # x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True - # ) - - # x = torch.empty( - # (self.W_V.shape[0], m, self.W_V.shape[2]), - # dtype=torch.bfloat16, - # device=self.W_V.device, - # ) - # rocm_aiter_ops.triton_fp8_bmm( - # x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - # ) # else: - # # Convert from (L, N, V) to (N, L, V) - # self.W_UV = W_UV.transpose(0, 1) - # # Convert from (L, N, P) to (N, P, L) - # self.W_UK_T = W_UK.permute(1, 2, 0) + # # Convert from (B, N * V) to (N, B, V) + # out = out.transpose(0, 1) - def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - out = out.view(-1, self.num_heads, self.v_head_dim) - if self.is_aiter_triton_fp4_bmm_enabled: - out = rocm_aiter_ops.batched_gemm_a16wfp4( - x, - self.W_V, - self.W_V_scale, - out, - transpose_bm=True, - prequant=True, - y_scale=None, - ) - x = out.view(-1, self.num_heads * self.v_head_dim) - elif self.is_aiter_triton_fp8_bmm_enabled: - # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = rocm_aiter_ops.triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out - ) - else: - # Convert from (B, N * V) to (N, B, V) - out = out.transpose(0, 1) + # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + # torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot" + # # Convert from (N, B, V) to (B, N * V) + # out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - # Convert from (N, B, V) to (B, N * V) - out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + # # Adjust output buffer shape back to the original (B, N * V) + # N, B, V = out.shape + # out.resize_((B, N * V)) + # out.copy_(out_new) # Copy result + return torch.einsum("bnl,lnv->bnv", x, self.W_UV) + def _k_up_proj(self, q_nope): + # # Convert from (B, N, P) to (N, B, P) + # q_nope = q_nope.transpose(0, 1) + # # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + # ql_nope = torch.bmm(q_nope, self.W_UK_T) + # # Convert from (N, B, L) to (B, N, L) + # return ql_nope.transpose(0, 1), q_pe + return torch.einsum("bnp,lnp->bnl", q_nope, self.W_UK).view(-1, self.num_heads, self.kv_lora_rank) - # Adjust output buffer shape back to the original (B, N * V) - N, B, V = out.shape - out.resize_((B, N * V)) - out.copy_(out_new) # Copy result - - -class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - - if use_trtllm_ragged_deepseek_prefill(): - logger.info_once( - "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local" - ) - self._run_prefill_context_chunk = ( - self._run_prefill_context_chunk_trtllm_ragged - ) - self._run_prefill_new_tokens = 
self._run_prefill_new_tokens_trtllm_ragged - self._pad_v = False - elif use_flashinfer_prefill(): - logger.info_once("Using FlashInfer prefill for MLA", scope="local") - self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi - self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi - self._pad_v = False - elif use_cudnn_prefill(): - logger.info_once("Using CUDNN prefill for MLA", scope="local") - self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn - self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn - self._pad_v = False - else: # Use FlashAttention - logger.info_once("Using FlashAttention prefill for MLA", scope="local") - self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa - self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa - - # Handle the differences between the flash_attn_varlen from - # flash_attn and the one from vllm_flash_attn. The former is used on - # RoCM and the latter has an additional parameter to control - # FA2 vs FA3 - self.flash_attn_varlen_func = flash_attn_varlen_func - self.vllm_flash_attn_version = get_flash_attn_version() - # if self.vllm_flash_attn_version is not None: - # self.flash_attn_varlen_func = functools.partial( - # flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version - # ) - - # For MLA the v head dim is smaller than qk head dim so we pad out - # v with 0s to match the qk head dim for attention backends that do - # not support different headdims - # We don't need to pad V if we are on a hopper system with FA3 - # device_capability = current_platform.get_device_capability() - # self._pad_v = self.vllm_flash_attn_version is None or not ( - # self.vllm_flash_attn_version == 3 - # and device_capability is not None - # and device_capability[0] == 9 - # ) - self._pad_v = False - - self.dcp_world_size: int = -1 - - self.chunked_prefill_workspace_size = ( - MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size( - get_current_vllm_config() - ) - ) - self.cp_kv_cache_interleave_size: int = ( - get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size - ) - self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8( - static=True, - group_shape=GroupShape.PER_TENSOR, - compile_native=True, - ) - - def _flash_attn_varlen_diff_headdims( - self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs - ): - maybe_padded_v = v - if self._pad_v: - maybe_padded_v = torch.nn.functional.pad( - v, [0, q.shape[-1] - v.shape[-1]], value=0 - ) - - if is_vllm_fa: - kwargs["return_softmax_lse"] = return_softmax_lse - else: - # ROCm leverages the upstream flash_attn, which takes a parameter - # called "return_attn_probs" instead of return_softmax_lse - kwargs["return_attn_probs"] = return_softmax_lse - if vllm_is_batch_invariant(): - kwargs["num_splits"] = 1 - - attn_out = self.flash_attn_varlen_func( - q=q, - k=k, - v=maybe_padded_v, - softmax_scale=softmax_scale, - **kwargs, - ) - - # Unpack the output if there is multiple results - lse = None - if isinstance(attn_out, tuple): - attn_out, lse = attn_out[0], attn_out[1] - - # Remain consistent with old `flash_attn_varlen_func` where there - # is only one output tensor if `return_softmax_lse` is False. 
- if return_softmax_lse: - return attn_out, lse - return attn_out - - def _run_prefill_new_tokens_fa( - self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse, out - ): - return self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill.query_start_loc, - cu_seqlens_k=prefill.query_start_loc, - max_seqlen_q=prefill.max_query_len, - max_seqlen_k=prefill.max_query_len, - softmax_scale=self.scale, - causal=True, - return_softmax_lse=return_softmax_lse, - out=out - ) - - def _run_prefill_new_tokens_fi( - self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse - ): - assert isinstance(prefill, FlashInferPrefillMetadata) - assert prefill.prefill_main is not None - - ret = prefill.prefill_main.run( - q=q, - k=k, - v=v, - return_lse=return_softmax_lse, - ) - - if isinstance(ret, tuple): - return ret[0], ret[1].transpose(0, 1).contiguous() - return ret - - def _run_prefill_new_tokens_cudnn( - self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse - ): - assert isinstance(prefill, CudnnPrefillMetadata) - assert prefill.query_seq_lens is not None - from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache - - output, lse = cudnn_batch_prefill_with_kv_cache( - q=q, - k_cache=k, - v_cache=v, - scale=self.scale, - workspace_buffer=prefill.cudnn_workspace, - max_token_per_sequence=prefill.max_query_len, - max_sequence_kv=prefill.max_query_len, - actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), - actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1), - causal=True, - # Do not support False for now - return_lse=True, - # Indicates actual_seq_lens are on GPU or CPU. - is_cuda_graph_compatible=True, - ) - if return_softmax_lse: - return output, lse - return output - - def _run_prefill_context_chunk_fa( - self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v, out - ): - assert prefill.chunked_context is not None - return self._flash_attn_varlen_diff_headdims( - q=q, - k=k, - v=v, - cu_seqlens_q=prefill.query_start_loc, - cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx], - max_seqlen_q=prefill.max_query_len, - max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx], - softmax_scale=self.scale, - causal=False, # Context is unmasked - return_softmax_lse=True, - out=out - ) - - def _run_prefill_context_chunk_fi( - self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v - ): - assert isinstance(prefill, FlashInferPrefillMetadata) - - attn_out, lse = prefill.prefill_chunks[chunk_idx].run( - q=q, - k=k, - v=v, - return_lse=True, - ) - - # Convert from (q_len, num_heads) to (num_heads, q_len) - return attn_out, lse.transpose(0, 1).contiguous() - - def _run_prefill_context_chunk_cudnn( - self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v - ): - assert isinstance(prefill, CudnnPrefillMetadata) - assert prefill.chunked_context is not None - assert prefill.chunked_context.seq_lens[chunk_idx] is not None - assert prefill.query_seq_lens is not None - from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache - - return cudnn_batch_prefill_with_kv_cache( - q=q, - k_cache=k, - v_cache=v, - scale=self.scale, - workspace_buffer=prefill.cudnn_workspace, - max_token_per_sequence=prefill.max_query_len, - max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx], - actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1), - actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view( - -1, 1, 1, 1 - ), - causal=False, - return_lse=True, - # Indicates actual_seq_lens are on 
GPU or CPU. - is_cuda_graph_compatible=True, - ) - - def _v_up_proj(self, x): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - if is_rocm_aiter_fp8bmm_enabled(): - # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm( - x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True - ) - # Convert from (B, N, V) to (B, N * V) - x = x.reshape(-1, self.num_heads * self.v_head_dim) - else: - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return x - - def _run_prefill_new_tokens_trtllm_ragged( - self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse - ): - """TRT-LLM ragged attention for new tokens (causal).""" - from flashinfer.prefill import trtllm_ragged_attention_deepseek - - assert prefill.query_seq_lens is not None - assert prefill.workspace_buffer is not None - - ret = trtllm_ragged_attention_deepseek( - query=q, - key=k, - value=v, - workspace_buffer=prefill.workspace_buffer, - seq_lens=prefill.query_seq_lens, - max_q_len=prefill.max_query_len, - max_kv_len=prefill.max_query_len, - bmm1_scale=self.scale, - bmm2_scale=1.0, - o_sf_scale=1.0, - batch_size=prefill.query_seq_lens.shape[0], - window_left=-1, - cum_seq_lens_q=prefill.query_start_loc, - cum_seq_lens_kv=prefill.query_start_loc, - enable_pdl=False, - is_causal=True, - return_lse=return_softmax_lse, - ) - - if isinstance(ret, tuple): - # Convert from (q_len, num_heads) to (num_heads, q_len) - return ret[0], ret[1].transpose(0, 1).contiguous() - return ret - - def _run_prefill_context_chunk_trtllm_ragged( - self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v - ): - """TRT-LLM ragged attention for context chunks (non-causal).""" - from flashinfer.prefill import trtllm_ragged_attention_deepseek - - assert prefill.chunked_context is not None - assert prefill.chunked_context.seq_lens[chunk_idx] is not None - assert prefill.workspace_buffer is not None - - out = torch.zeros( - q.shape[0], - q.shape[1], - v.shape[2], - device=q.device, - dtype=q.dtype, - ) - prefill.workspace_buffer.fill_(0) - - attn_out, lse = trtllm_ragged_attention_deepseek( - query=q, - key=k, - value=v, - workspace_buffer=prefill.workspace_buffer, - seq_lens=prefill.chunked_context.seq_lens[chunk_idx], - max_q_len=prefill.max_query_len, - max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx], - bmm1_scale=self.scale, - bmm2_scale=1.0, - o_sf_scale=1.0, - batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0], - window_left=-1, - cum_seq_lens_q=prefill.query_start_loc, - cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx], - enable_pdl=False, - is_causal=False, - return_lse=True, - out=out, - ) - - # Convert from (q_len, num_heads) to (num_heads, q_len) - return attn_out, lse.transpose(0, 1).contiguous() def _concat_k_nope_k_pe( self, k_nope: torch.Tensor, k_pe: torch.Tensor @@ -2381,9 +2728,13 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): dtype=k_nope.dtype, device=k_nope.device, ) - # Direct copies with efficient broadcasting - k[..., : k_nope.shape[-1]] = k_nope - k[..., k_nope.shape[-1] :] = k_pe + + if self._use_flashinfer_concat_mla_k: + torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe) + else: + # Fallback: Direct copies with efficient broadcasting + k[..., : k_nope.shape[-1]] = k_nope + k[..., k_nope.shape[-1] :] = k_pe return k def 
_compute_prefill_context( @@ -2392,34 +2743,27 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache_scale: torch.Tensor, attn_metadata: MLACommonMetadata, - k_scale: torch.Tensor | None = None, ): assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill assert prefill_metadata.chunked_context is not None + use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype() + output = None iters = len(prefill_metadata.chunked_context.seq_tot) workspace = prefill_metadata.chunked_context.workspace + + if use_fp8_prefill: + q = q.to(prefill_metadata.q_data_type) + for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] - # ops.gather_and_maybe_dequant_cache( - # src_cache=kv_c_and_k_pe_cache, - # dst=workspace, - # block_table=prefill_metadata.block_table, - # cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], - # token_to_seq=prefill_metadata.chunked_context.token_to_seq[i], - # num_tokens=prefill_metadata.chunked_context.chunk_total_token[i], - # kv_cache_dtype=self.kv_cache_dtype, - # scale=k_scale, - # seq_starts=prefill_metadata.chunked_context.starts[i], - # ) - if envs.VLLM_USE_INT8_MLA: ops.gather_cache_int8( src_cache=kv_c_and_k_pe_cache, src_cache_scale=kv_c_and_k_pe_cache_scale, - kv_lora_rank=self.kv_lora_rank, + kv_lora_rank = self.kv_lora_rank, dst=workspace, block_table=prefill_metadata.block_table, cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], @@ -2439,20 +2783,20 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): kv_c_normed = workspace[:toks][..., : self.kv_lora_rank].contiguous() k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) + k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1) kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) + + # To Do: Use epilogue of kv_b_proj to generate fp8 kv_nope. 
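In the chunked-context loop just below, each chunk's keys are formed by splitting the kv_b_proj output into the per-head k_nope and v parts and concatenating k_nope with the shared rope part k_pe, broadcast across heads. A minimal sketch with hypothetical dimensions:

    import torch

    T, N = 6, 16                        # tokens in this chunk, attention heads
    P, R, V = 128, 64, 128              # qk_nope_head_dim, qk_rope_head_dim, v_head_dim

    kv_nope = torch.randn(T, N, P + V)  # output of kv_b_proj, already reshaped per head
    k_pe = torch.randn(T, 1, R)         # rope part, shared by all heads

    k_nope, v = kv_nope.split([P, V], dim=-1)
    k = torch.cat((k_nope, k_pe.expand(T, N, R)), dim=-1)   # (T, N, P + R)

    assert k.shape == (T, N, P + R) and v.shape == (T, N, V)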
+ if use_fp8_prefill: + kv_nope = kv_nope.to(prefill_metadata.q_data_type) + k_pe = k_pe.to(prefill_metadata.q_data_type) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # k = self._concat_k_nope_k_pe(k_nope, k_pe) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) - attn_output = torch.empty( - q.shape[0], - self.num_heads, - self.v_head_dim, - dtype=q.dtype, - device=q.device, - ) + + attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -2467,24 +2811,12 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): output = attn_output output_lse = attn_softmax_lse else: - # output_tmp = torch.empty_like(output) - # output_lse_tmp = torch.empty_like(output_lse) - # merge_attn_states( - # output=output_tmp, - # output_lse=output_lse_tmp, - # prefix_output=output, - # prefix_lse=output_lse, - # suffix_output=attn_output, - # suffix_lse=attn_softmax_lse, - # ) - # output = output_tmp - # output_lse = output_lse_tmp - merge_attn_states( + output,output_lse = merge_attn_states( prefix_output=output, prefix_lse=output_lse, suffix_output=attn_output, suffix_lse=attn_softmax_lse, - return_lse=True + return_lse=True, ) return output, output_lse @@ -2564,8 +2896,8 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - # k = self._concat_k_nope_k_pe(k_nope, k_pe) - k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) + k = self._concat_k_nope_k_pe(k_nope, k_pe) + attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) attn_output, attn_softmax_lse = self._run_prefill_context_chunk( prefill=prefill_metadata, @@ -2573,67 +2905,54 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): q=q, k=k, v=v, + out=attn_output, ) if output is None: output = attn_output output_lse = attn_softmax_lse else: - output_tmp = torch.empty_like(output) - output_lse_tmp = torch.empty_like(output_lse) merge_attn_states( - output=output_tmp, - output_lse=output_lse_tmp, prefix_output=output, prefix_lse=output_lse, suffix_output=attn_output, suffix_lse=attn_softmax_lse, + return_lse=True, ) - output = output_tmp - output_lse = output_lse_tmp - return output, output_lse def forward_mha( self, q: torch.Tensor, kv_c_normed: torch.Tensor, - # k_pe: torch.Tensor, k: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, - # k_scale: torch.Tensor, - # output: torch.Tensor, - kv_c_and_k_pe_cache_scale: torch.Tensor, - output: torch.Tensor | None = None, - ) -> None: + kv_c_and_k_pe_cache_scale: torch.Tensor |None = None, + ) -> torch.Tensor: # TODO (zyongye): Prefill function here assert attn_metadata.prefill is not None assert self.dcp_world_size != -1 - has_context = attn_metadata.prefill.chunked_context is not None + prefill_metadata = attn_metadata.prefill + use_fp8_prefill = prefill_metadata.q_data_type == current_platform.fp8_dtype() + + # Convert q to FP8 if FP8 prefill attention is enabled + if use_fp8_prefill: + q = q.to(prefill_metadata.q_data_type) + + has_context = prefill_metadata.chunked_context is not None + kv_nope = self.kv_b_proj(kv_c_normed)[0].view( -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim ) - # k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - 
- # k = self._concat_k_nope_k_pe(k_nope, k_pe) - - # output_prefill = self._run_prefill_new_tokens( - # prefill=attn_metadata.prefill, - # q=q, - # k=k, - # v=v, - # return_softmax_lse=has_context, - # ) k_nope, v_nope = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + v = v_nope - k[..., : self.qk_nope_head_dim] = k_nope - - attn_output = torch.empty( - q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device - ) - + k[...,:self.qk_nope_head_dim] = k_nope + + attn_output = torch.empty(q.shape[0], self.num_heads, self.v_head_dim, dtype=q.dtype, device=q.device) + output = self._run_prefill_new_tokens( prefill=attn_metadata.prefill, q=q, @@ -2644,7 +2963,6 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) if has_context: - # suffix_output, suffix_lse = output_prefill suffix_output, suffix_lse = output if self.dcp_world_size > 1: context_output, context_lse = ( @@ -2657,26 +2975,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): ) ) else: - # context_output, context_lse = self._compute_prefill_context( - # q, kv_c_and_k_pe_cache, attn_metadata, k_scale - # ) - context_output, context_lse = self._compute_prefill_context( - q, kv_c_and_k_pe_cache, kv_c_and_k_pe_cache_scale, attn_metadata - ) + context_output, context_lse = self._compute_prefill_context( \ + q, kv_c_and_k_pe_cache, kv_c_and_k_pe_cache_scale, attn_metadata) - # unpad if necessary - if self._pad_v: - context_output = context_output[..., : v.shape[-1]] - suffix_output = suffix_output[..., : v.shape[-1]] - - # output = output.view(-1, self.num_heads, self.v_head_dim) - # merge_attn_states( - # output=output, - # prefix_output=context_output, - # prefix_lse=context_lse, - # suffix_output=suffix_output, - # suffix_lse=suffix_lse, - # ) output = torch.empty_like(suffix_output) output = merge_attn_states( output=output, @@ -2685,322 +2986,22 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): suffix_output=suffix_output, suffix_lse=suffix_lse, ) - # else: - # output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2) - # output.copy_(output_prefill) + + # unpad if necessary + if self._pad_v: + output = output[..., : v.shape[-1]] + return output @abstractmethod def forward_mqa( self, - q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + ql_nope: torch.Tensor, + q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, - layer: AttentionLayer, + k_c_normed: torch.Tensor | None, + k_pe: torch.Tensor | None, + kv_c_and_k_pe_cache_scale: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: raise NotImplementedError - - def forward_prepare( - self, - positions: torch.Tensor, - ) -> None: - self.positions = positions - - def forward( - self, - layer: AttentionLayer, - q: torch.Tensor, - k_c_normed: torch.Tensor, # key in unified attn - k_pe: torch.Tensor, # value in unified attn - kv_cache: torch.Tensor, - attn_metadata: M, - output: torch.Tensor | None = None, - kv_cache_scale: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." 
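merge_attn_states, used above in _compute_prefill_context and forward_mha, combines the attention over the new tokens (suffix) with the attention over cached context chunks (prefix) via each part's log-sum-exp, so the result equals attention over the full key set. A sketch of the underlying math only, not the exact signature of vLLM's merge_attn_states; assumed layout is outputs (tokens, heads, v_dim) and LSEs (heads, tokens):

    import torch

    def merge_with_lse(o1, lse1, o2, lse2):
        # o*: (T, H, V) partial outputs; lse*: (H, T) log of each softmax denominator
        m = torch.maximum(lse1, lse2)
        w1 = torch.exp(lse1 - m)
        w2 = torch.exp(lse2 - m)
        denom = w1 + w2
        # broadcast per-(head, token) weights to (T, H, 1) before mixing outputs
        w1 = (w1 / denom).transpose(0, 1).unsqueeze(-1)
        w2 = (w2 / denom).transpose(0, 1).unsqueeze(-1)
        merged = o1 * w1 + o2 * w2
        merged_lse = m + torch.log(torch.exp(lse1 - m) + torch.exp(lse2 - m))
        return merged, merged_lse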
- - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported for MLACommonImpl" - ) - - if attn_metadata is None: - # During the profile run try to simulate to worse case output size - # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context` - # since this can be large - _ = torch.empty( - ( - self.chunked_prefill_workspace_size, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ), - device=k_c_normed.device, - dtype=k_c_normed.dtype, - ) - - # The zero fill is required when used with DP + EP - # to ensure all ranks within a DP group compute the - # same expert outputs. - # return output.fill_(0) - output = torch.empty( - output.shape[0], - self.v_head_dim * self.num_heads, - device=q.device, - dtype=q.dtype, - ) - return output - - if self.dcp_world_size == -1: - self.dcp_world_size = get_dcp_group().world_size - - fp8_attention = self.kv_cache_dtype.startswith("fp8") - - # num_actual_toks = attn_metadata.num_actual_tokens - - # Inputs and outputs may be padded for CUDA graphs - # output_padded = output - # output = output[:num_actual_toks, ...] - # q = q[:num_actual_toks, ...] - # k_c_normed = k_c_normed[:num_actual_toks, ...] - # k_pe = k_pe[:num_actual_toks, ...] - - assert ( - attn_metadata.num_decodes is not None - and attn_metadata.num_prefills is not None - and attn_metadata.num_decode_tokens is not None - ) - - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - num_decode_tokens = attn_metadata.num_decode_tokens - - decode_q = q[:num_decode_tokens] - - k_pe = k_pe.unsqueeze(1) - prefill_q = q[num_decode_tokens:] - prefill_k_pe = k_pe[num_decode_tokens:] - prefill_k_c_normed = k_c_normed[num_decode_tokens:] - prefill_k = torch.empty_like(prefill_q) - - # write the latent and rope to kv cache - write_kv_cache = (None, None) - if kv_cache.numel() > 0: - # ops.concat_and_cache_mla( - # k_c_normed, - # k_pe.squeeze(1), - # kv_cache, - # attn_metadata.slot_mapping.flatten(), - # kv_cache_dtype=self.kv_cache_dtype, - # scale=layer._k_scale, - # ) - if has_decode: - decode_q_pe, decode_k_pe = self.rotary_emb( - self.positions[:num_decode_tokens], - decode_q[..., self.qk_nope_head_dim :], - k_pe[:num_decode_tokens], - ) - if envs.VLLM_USE_INT8_MLA: - write_kv_cache = (None, None) - k_c_normed_int8, k_c_normed_scale, _ = ops.scaled_int8_quant( - k_c_normed[:num_decode_tokens] - ) - decode_k_pe_int8, decode_k_pe_scale, _ = ops.scaled_int8_quant( - decode_k_pe - ) - ops.concat_and_cache_mla_int8( - kv_c_int8=k_c_normed_int8, - kv_c_scale=k_c_normed_scale[..., 0], - k_pe_int8=decode_k_pe_int8, - k_pe_scale=decode_k_pe_scale[..., 0].view( - -1, decode_k_pe_int8.shape[-2] - ), - kv_cache=kv_cache, - kv_cache_scale=kv_cache_scale, - slot_mapping=attn_metadata.slot_mapping.flatten()[ - :num_decode_tokens - ], - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - else: - # write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe) - if self.dcp_world_size > 1: - ops.concat_and_cache_mla( - k_c_normed[:num_decode_tokens], - decode_k_pe, - kv_cache, - attn_metadata.slot_mapping.flatten()[:num_decode_tokens], - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - else: - write_kv_cache = (k_c_normed[:num_decode_tokens], decode_k_pe) - if has_prefill: - ixf_ops.mla_rope( - self.positions[num_decode_tokens:], - prefill_q[..., self.qk_nope_head_dim :], - prefill_k_pe.squeeze(1), - prefill_k[..., self.qk_nope_head_dim :], - 
self.rotary_emb.cos_sin_cache, - ) - if envs.VLLM_USE_INT8_MLA: - prefill_k_c_normed_int8, prefill_k_c_normed_scale, _ = ( - ops.scaled_int8_quant(prefill_k_c_normed) - ) - prefill_k_pe_int8, prefill_k_pe_scale, _ = ops.scaled_int8_quant( - prefill_k[..., self.qk_nope_head_dim :].contiguous() - ) - - ops.concat_and_cache_mla_int8( - prefill_k_c_normed_int8, - prefill_k_c_normed_scale[..., 0], - prefill_k_pe_int8, - prefill_k_pe_scale[..., 0].view( - -1, prefill_k_pe_int8.shape[-2] - ), - kv_cache, - kv_cache_scale, - attn_metadata.slot_mapping.flatten()[num_decode_tokens:], - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - else: - ops.concat_and_cache_mla( - prefill_k_c_normed, - prefill_k[..., self.qk_nope_head_dim :], - kv_cache, - attn_metadata.slot_mapping.flatten()[num_decode_tokens:], - kv_cache_dtype=self.kv_cache_dtype, - scale=layer._k_scale, - ) - output = torch.empty( - output.shape[0], - self.num_heads, - self.v_head_dim, - device=q.device, - dtype=q.dtype, - ) - - if fp8_attention: - kv_cache = kv_cache.view(current_platform.fp8_dtype()) - - if has_prefill: - # self.forward_mha( - # prefill_q, - # prefill_k_c_normed, - # prefill_k_pe, - # kv_cache, - # attn_metadata, - # layer._k_scale, - # output=output[num_decode_tokens:], - # ) - output[num_decode_tokens:] = self.forward_mha( - prefill_q, - prefill_k_c_normed, - prefill_k, - kv_cache, - attn_metadata, - kv_c_and_k_pe_cache_scale=kv_cache_scale, - ) - - if has_decode: - assert attn_metadata.decode is not None - - # decode_q_nope, decode_q_pe = decode_q.split( - # [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - # ) - - # # Convert from (B, N, P) to (N, B, P) - # decode_q_nope = decode_q_nope.transpose(0, 1) - - # if self.q_pad_num_heads is not None: - # B, N, L = decode_q_pe.shape - # decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L)) - # decode_pe_padded.resize_((B, N, L)) - # decode_pe_padded.copy_(decode_q_pe) - # decode_q_pe = decode_pe_padded - - # if self.is_aiter_triton_fp4_bmm_enabled: - # from aiter.ops.triton.batched_gemm_a16wfp4 import batched_gemm_a16wfp4 - - # decode_ql_nope = batched_gemm_a16wfp4( - # decode_q_nope, - # self.W_K, - # self.W_K_scale, - # transpose_bm=True, - # prequant=True, - # y_scale=layer._q_scale if fp8_attention else None, - # ) - # elif self.is_aiter_triton_fp8_bmm_enabled: - # # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) - # decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm( - # decode_q_nope, - # self.W_K, - # self.W_K_scale, - # group_size=128, - # transpose_bm=True, - # ) - # else: - # # Pads the head_dim if necessary (for the underlying kernel) - # N, B, P = decode_q_nope.shape - # _, _, L = self.W_UK_T.shape - - # if self.q_pad_num_heads is not None: - # decode_ql_nope = decode_q_nope.new_empty( - # (self.q_pad_num_heads, B, L) - # ) - # decode_ql_nope.resize_((N, B, L)) - # else: - # decode_ql_nope = decode_q_nope.new_empty((N, B, L)) - - # # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - # torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope) - - # # Convert from (N, B, L) to (B, N, L) - # decode_ql_nope = decode_ql_nope.transpose(0, 1) - - # if fp8_attention: - # assert decode_ql_nope.shape[0] == decode_q_pe.shape[0] - # assert decode_ql_nope.shape[1] == decode_q_pe.shape[1] - # decode_q = self._decode_concat_quant_fp8_op( - # decode_ql_nope, decode_q_pe, layer._q_scale - # ) - # else: - # decode_q = (decode_ql_nope, decode_q_pe) - # if self.dcp_world_size > 1: - # assert not fp8_attention, "DCP not support fp8 
kvcache now." - # # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P) - # decode_q = torch.cat(decode_q, dim=-1) - # # decode_q do allgather in head dim. - # decode_q = get_dcp_group().all_gather(decode_q, dim=1) - - # call decode attn - # decode_q, kv_cache, attn_metadata, layer - - attn_out, lse = self.forward_mqa( - decode_q[..., :self.qk_nope_head_dim], decode_q_pe, kv_cache, attn_metadata, - *write_kv_cache, kv_c_and_k_pe_cache_scale=kv_cache_scale - ) - - # correct dcp attn_out with lse. - if self.dcp_world_size > 1: - assert lse is not None - attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) - output[:num_decode_tokens] = self._v_up_proj(attn_out) - else: - assert lse is None - output[:num_decode_tokens] = attn_out - - return output.view(output.shape[0], self.v_head_dim * self.num_heads) - # attn_out = cp_lse_ag_out_rs( - # attn_out, - # lse, - # get_dcp_group(), - # is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False), - # ) - - # # v_up projection - # self._v_up_proj(attn_out, out=output[:num_decode_tokens]) - # return output_padded diff --git a/vllm/model_executor/layers/attention/mm_encoder_attention.py b/vllm/model_executor/layers/attention/mm_encoder_attention.py index e59806a..64370a6 100644 --- a/vllm/model_executor/layers/attention/mm_encoder_attention.py +++ b/vllm/model_executor/layers/attention/mm_encoder_attention.py @@ -2,21 +2,94 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np import torch from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.models.vision import get_vit_attn_backend +from vllm.utils.math_utils import round_up from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.ops.vit_attn_wrappers import ( vit_flash_attn_wrapper, + vit_flashinfer_wrapper, vit_torch_sdpa_wrapper, vit_triton_attn_wrapper, ) +import ixformer.contrib.vllm_flash_attn as ops logger = init_logger(__name__) +# Batch buckets for cuDNN graph caching. +# Graphs use batch size and max sequence length as cache key. +# This avoids creating a new graph for each unique set of +# batch size and max sequence length at runtime. +# From the cuDNN team's performance measurements, there +# is no significant kernel performance difference between padding +# to a smaller batch size/seq length and padding to larger +# ones. The bucketing here is solely used to avoid memory +# operation overhead, which won't be needed if we have CUDA +# graph support in the future. +# TODO: Remove buckets after issue #34763 +# (cuda graph support) is addressed. 
+FLASHINFER_BATCH_BUCKETS = [8, 16, 32, 64] +FLASHINFER_MAX_SEQLEN_BUCKETS = [ + 1 * 1024, + 2 * 1024, + 4 * 1024, + 8 * 1024, + 16 * 1024, + 32 * 1024, + 64 * 1024, + 128 * 1024, +] + +# Workspace buffer for FlashInfer CuDNN backend +FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES = 128 * 1024 * 1024 +_flashinfer_workspace_buffer: torch.Tensor | None = None + + +def _get_flashinfer_workspace_buffer() -> torch.Tensor: + global _flashinfer_workspace_buffer + if _flashinfer_workspace_buffer is None: + _flashinfer_workspace_buffer = torch.zeros( + FLASHINFER_CUDNN_WORKSPACE_SIZE_BYTES, + dtype=torch.uint8, + device="cuda", + ) + return _flashinfer_workspace_buffer + + +def add_padding_to_seqlens( + seq: np.ndarray, + batch_size: int, + padding_value: int, +) -> np.ndarray: + batch_size_padded = next( + (b for b in FLASHINFER_BATCH_BUCKETS if b >= batch_size), + round_up(batch_size, FLASHINFER_BATCH_BUCKETS[0]), + ) + if batch_size_padded == batch_size: + return seq + return np.concatenate( + [ + seq, + np.full((batch_size_padded - batch_size,), padding_value, dtype=seq.dtype), + ] + ) + + +def bucket_flashinfer_max_seqlen( + real_max_seqlen: int, +) -> int: + if real_max_seqlen <= 0: + return FLASHINFER_MAX_SEQLEN_BUCKETS[0] + return next( + (s for s in FLASHINFER_MAX_SEQLEN_BUCKETS if s >= real_max_seqlen), + round_up(real_max_seqlen, FLASHINFER_MAX_SEQLEN_BUCKETS[-1]), + ) + # --8<-- [start:mm_encoder_attn] @CustomOp.register("mm_encoder_attn") @@ -24,6 +97,67 @@ class MMEncoderAttention(CustomOp): """Multi-headed attention without any cache, used for multimodal encoder.""" # --8<-- [end:mm_encoder_attn] + @classmethod + def compute_max_seqlen( + cls, + attn_backend: AttentionBackendEnum, + cu_seqlens: np.ndarray, + ) -> int: + max_seqlen = 0 + if ( + attn_backend + in ( + AttentionBackendEnum.FLASH_ATTN, + AttentionBackendEnum.ROCM_AITER_FA, + AttentionBackendEnum.TRITON_ATTN, + AttentionBackendEnum.FLASHINFER, + ) + and len(cu_seqlens) >= 2 + ): + max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max()) + if attn_backend == AttentionBackendEnum.FLASHINFER: + max_seqlen = bucket_flashinfer_max_seqlen(max_seqlen) + return max_seqlen + + @classmethod + def maybe_compute_sequence_lengths( + cls, + attn_backend: AttentionBackendEnum, + cu_seqlens: np.ndarray, + ) -> np.ndarray | None: + if attn_backend != AttentionBackendEnum.FLASHINFER: + return None + sequence_lengths = cu_seqlens[1:] - cu_seqlens[:-1] + sequence_lengths = add_padding_to_seqlens( + sequence_lengths, len(sequence_lengths), 0 + ) + return sequence_lengths + + @classmethod + def maybe_recompute_cu_seqlens( + cls, + attn_backend: AttentionBackendEnum, + cu_seqlens: np.ndarray, + hidden_size: int, + tp_size: int, + ) -> np.ndarray: + if attn_backend != AttentionBackendEnum.FLASHINFER: + return cu_seqlens + + batch_size = len(cu_seqlens) - 1 + scale = hidden_size // tp_size + cu_seqlens = cu_seqlens * scale + + cu_seqlens_qko = cu_seqlens + cu_seqlens_v = cu_seqlens * 3 + + cu_seqlens_qko = add_padding_to_seqlens( + cu_seqlens_qko, batch_size, cu_seqlens_qko[-1] + ) + cu_seqlens_v = add_padding_to_seqlens( + cu_seqlens_v, batch_size, cu_seqlens_v[-1] + ) + return np.concatenate([cu_seqlens_qko, cu_seqlens_v]) def __init__( self, @@ -46,10 +180,9 @@ class MMEncoderAttention(CustomOp): self.num_heads = num_heads self.head_size = head_size - self.scale = scale + self.scale = 1.0 / (head_size**0.5) if scale is None else scale self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.layer_name = prefix - assert 
self.num_heads % self.num_kv_heads == 0, ( f"num_heads ({self.num_heads}) is not " f"divisible by num_kv_heads ({self.num_kv_heads})" @@ -72,9 +205,14 @@ class MMEncoderAttention(CustomOp): } self._fa_version = ( - get_flash_attn_version() if self.is_flash_attn_backend else None + get_flash_attn_version(head_size=head_size) + if self.is_flash_attn_backend + else None ) + if self.attn_backend == AttentionBackendEnum.FLASHINFER: + _get_flashinfer_workspace_buffer() + logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.") @classmethod @@ -148,23 +286,27 @@ class MMEncoderAttention(CustomOp): bsz, q_len = query.size()[:2] kv_len = key.size(1) is_reshaped = query.dim() != 4 + query = query.view(bsz * q_len, self.num_heads, self.head_size) + key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size) + value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size) - query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len) - - output = vit_flash_attn_wrapper( - q=query, - k=key, - v=value, - batch_size=bsz, - is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA), - fa_version=self._fa_version, - scale=self.scale, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, + cu_q = torch.tensor([0,] + [q_len for _ in range(bsz)], device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + cu_kv = torch.tensor([0,] + [kv_len for _ in range(bsz)], device=query.device, dtype=torch.int32).cumsum(dim=0, dtype=torch.int32) + out = ops.flash_attn_varlen_func( + query, + key, + value, + cu_q, + cu_kv, + q_len, + kv_len, + softmax_scale=self.scale, + causal=False, ) + out = out.view(bsz, q_len, self.num_heads, self.head_size) if is_reshaped: - output = output.reshape(bsz, q_len, -1) - return output + out = out.reshape(bsz, q_len, -1) + return out def _forward_triton( self, @@ -201,6 +343,27 @@ class MMEncoderAttention(CustomOp): output = output.reshape(bsz, q_len, -1) return output + def _forward_flashinfer( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + max_seqlen: torch.Tensor | None = None, + sequence_lengths: torch.Tensor + | None = None, # Only used for FlashInfer CuDNN backend + ) -> torch.Tensor: + return vit_flashinfer_wrapper( + q=query, + k=key, + v=value, + scale=self.scale, + workspace_buffer=_get_flashinfer_workspace_buffer(), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + sequence_lengths=sequence_lengths, + ) + def forward_native( self, query: torch.Tensor, @@ -208,6 +371,8 @@ class MMEncoderAttention(CustomOp): value: torch.Tensor, cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + sequence_lengths: torch.Tensor + | None = None, # Only used for FlashInfer CuDNN backend ) -> torch.Tensor: return self._forward_sdpa(query, key, value, cu_seqlens) @@ -218,11 +383,17 @@ class MMEncoderAttention(CustomOp): value: torch.Tensor, cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + sequence_lengths: torch.Tensor + | None = None, # Only used for FlashInfer CuDNN backend ) -> torch.Tensor: if self.is_flash_attn_backend: return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN: return self._forward_triton(query, key, value, cu_seqlens, max_seqlen) + elif self.attn_backend == AttentionBackendEnum.FLASHINFER: + return self._forward_flashinfer( + query, key, value, 
cu_seqlens, max_seqlen, sequence_lengths + ) elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA: return self._forward_sdpa(query, key, value, cu_seqlens) else: @@ -238,6 +409,8 @@ class MMEncoderAttention(CustomOp): value: torch.Tensor, cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + sequence_lengths: torch.Tensor + | None = None, # Only used for FlashInfer CuDNN backend ) -> torch.Tensor: return self._forward_sdpa(query, key, value, cu_seqlens) @@ -248,6 +421,8 @@ class MMEncoderAttention(CustomOp): value: torch.Tensor, cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention + sequence_lengths: torch.Tensor + | None = None, # Only used for FlashInfer CuDNN backend ) -> torch.Tensor: if self.attn_backend == AttentionBackendEnum.FLASH_ATTN: return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) diff --git a/vllm/model_executor/layers/fla/ops/__init__.py b/vllm/model_executor/layers/fla/ops/__init__.py index c19cc14..e52387a 100644 --- a/vllm/model_executor/layers/fla/ops/__init__.py +++ b/vllm/model_executor/layers/fla/ops/__init__.py @@ -7,11 +7,17 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang from .chunk import chunk_gated_delta_rule -from .fused_recurrent import fused_recurrent_gated_delta_rule +from .fused_recurrent import ( + fused_recurrent_gated_delta_rule, + fused_recurrent_gated_delta_rule_packed_decode, +) +from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update from .layernorm_guard import RMSNormGated __all__ = [ "RMSNormGated", "chunk_gated_delta_rule", "fused_recurrent_gated_delta_rule", + "fused_recurrent_gated_delta_rule_packed_decode", + "fused_sigmoid_gating_delta_rule_update", ] diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 40f8c3c..9261885 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd( scale: float, initial_state: torch.Tensor, output_final_state: bool, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ): g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. @@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function): scale: float, initial_state: torch.Tensor, output_final_state: bool, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): if use_qk_l2norm_in_kernel: @@ -117,7 +117,7 @@ def chunk_gated_delta_rule( scale: float = None, initial_state: torch.Tensor = None, output_final_state: bool = False, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): r""" @@ -141,7 +141,7 @@ def chunk_gated_delta_rule( Default: `None`. output_final_state (Optional[bool]): Whether to output the final state of shape `[N, H, V, K]`. Default: `False`. - cu_seqlens (torch.LongTensor): + cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. Returns: @@ -171,7 +171,7 @@ def chunk_gated_delta_rule( # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... 
-> 1 (b t) ...'), (q, k, v, beta, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32) >>> o_var, ht_var = chunk_gated_delta_rule( q, k, v, g, beta, initial_state=h0, diff --git a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py index 98a3d61..ce60ca4 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_delta_h.py +++ b/vllm/model_executor/layers/fla/ops/chunk_delta_h.py @@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h( output_final_state: bool = False, chunk_size: int = 64, # SY: remove this argument and force chunk size 64? save_new_value: bool = True, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # This kernel is slightly different from fla to support Q/K with different head numbers. # In fla, Q/K always have the same head number, so Hg is always equal to H. diff --git a/vllm/model_executor/layers/fla/ops/chunk_o.py b/vllm/model_executor/layers/fla/ops/chunk_o.py index 16890af..1307812 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_o.py +++ b/vllm/model_executor/layers/fla/ops/chunk_o.py @@ -89,7 +89,7 @@ def chunk_fwd_kernel_o( b_o = tl.zeros([BT, BV], dtype=tl.float32) b_A = tl.zeros([BT, BT], dtype=tl.float32) - + for i_k in range(tl.cdiv(K, BK)): p_q = tl.make_block_ptr( q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) @@ -145,7 +145,7 @@ def chunk_fwd_o( h: torch.Tensor, g: torch.Tensor | None = None, # cumsum of log decay scale: float | None = None, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, chunk_size: int = 64, ) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] diff --git a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py index 7724fa5..31bd489 100644 --- a/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py +++ b/vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py @@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd( k: torch.Tensor, g: torch.Tensor | None = None, beta: torch.Tensor | None = None, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, chunk_size: int = 64, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: @@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd( The beta tensor of shape `[B, T, H]`. g (torch.Tensor): The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. - cu_seqlens (torch.LongTensor): + cu_seqlens (torch.Tensor): The cumulative sequence lengths of the input tensor. 
Default: None chunk_size (int): diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index 67d77e8..920efa4 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 else: i_t = 0 - # Load state index and check for PAD_SLOT_ID (-1) + # Load state index and check for invalid entries state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( tl.int64 ) - # Skip if state index is invalid (PAD_SLOT_ID = -1) - if state_idx < 0: + # Skip if state index is invalid (NULL_BLOCK_ID=0) + if state_idx <= 0: return p_h0 = h0 + state_idx * stride_init_state_token else: @@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( # keep the states for multi-query tokens if INPLACE_FINAL_STATE: - # Load state index and check for PAD_SLOT_ID (-1) + # Load state index and check for invalid entries final_state_idx = tl.load( ssm_state_indices + i_n * stride_indices_seq + i_t ).to(tl.int64) - # Only store if state index is valid (not PAD_SLOT_ID) - if final_state_idx >= 0: + # Only store if state index is valid (not NULL_BLOCK_ID=0) + if final_state_idx > 0: p_ht = ht + final_state_idx * stride_final_state_token p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) @@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd( scale: float, initial_state: torch.Tensor, inplace_final_state: bool = True, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, @@ -252,6 +252,232 @@ def fused_recurrent_gated_delta_rule_fwd( return o, final_state +@triton.jit +def fused_recurrent_gated_delta_rule_packed_decode_kernel( + mixed_qkv, + a, + b, + A_log, + dt_bias, + o, + h0, + ht, + ssm_state_indices, + scale, + stride_mixed_qkv_tok: tl.constexpr, + stride_a_tok: tl.constexpr, + stride_b_tok: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + SOFTPLUS_THRESHOLD: tl.constexpr, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + + o_k = tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_v[:, None] & mask_k[None, :] + + state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64) + p_o = o + (i_n * HV + i_hv) * V + o_v + + # Skip if state index is invalid (NULL_BLOCK_ID=0) + if state_idx <= 0: + zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty) + tl.store(p_o, zero, mask=mask_v) + return + + p_h0 = h0 + state_idx * stride_init_state_token + p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok + q_off = i_h * K + o_k + k_off = (H * K) + i_h * K + o_k + v_off = (2 * H * K) + i_hv * V + o_v + b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_mixed + k_off, 
mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + + a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32) + b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32) + A_log_val = tl.load(A_log + i_hv).to(tl.float32) + dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32) + x = a_val + dt_bias_val + softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x) + g_val = -tl.exp(A_log_val) * softplus_x + beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32) + + b_h *= exp(g_val) + b_v -= tl.sum(b_h * b_k[None, :], 1) + b_v *= beta_val + b_h += b_v[:, None] * b_k[None, :] + b_o = tl.sum(b_h * b_q[None, :], 1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + p_ht = ht + state_idx * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + +def fused_recurrent_gated_delta_rule_packed_decode( + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + out: torch.Tensor, + ssm_state_indices: torch.Tensor, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + if mixed_qkv.ndim != 2: + raise ValueError( + f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})." + ) + if mixed_qkv.stride(-1) != 1: + raise ValueError("`mixed_qkv` must be contiguous in the last dim.") + if a.ndim != 2 or b.ndim != 2: + raise ValueError( + f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})." + ) + if a.stride(-1) != 1 or b.stride(-1) != 1: + raise ValueError("`a`/`b` must be contiguous in the last dim.") + if A_log.ndim != 1 or dt_bias.ndim != 1: + raise ValueError("`A_log`/`dt_bias` must be 1D tensors.") + if A_log.stride(0) != 1 or dt_bias.stride(0) != 1: + raise ValueError("`A_log`/`dt_bias` must be contiguous.") + if ssm_state_indices.ndim != 1: + raise ValueError( + f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})." + ) + if not out.is_contiguous(): + raise ValueError("`out` must be contiguous.") + + dev = mixed_qkv.device + if ( + a.device != dev + or b.device != dev + or A_log.device != dev + or dt_bias.device != dev + or initial_state.device != dev + or out.device != dev + or ssm_state_indices.device != dev + ): + raise ValueError("All inputs must be on the same device.") + + B = mixed_qkv.shape[0] + if a.shape[0] != B or b.shape[0] != B: + raise ValueError( + "Mismatched batch sizes: " + f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}." + ) + if ssm_state_indices.shape[0] != B: + raise ValueError( + f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))." + ) + + if initial_state.ndim != 4: + raise ValueError( + f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})." + ) + if initial_state.stride(-1) != 1: + raise ValueError("`initial_state` must be contiguous in the last dim.") + HV, V, K = initial_state.shape[-3:] + if a.shape[1] != HV or b.shape[1] != HV: + raise ValueError( + f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})." 
+ ) + if A_log.numel() != HV or dt_bias.numel() != HV: + raise ValueError( + f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})." + ) + if out.shape != (B, 1, HV, V): + raise ValueError( + f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})." + ) + + qkv_dim = mixed_qkv.shape[1] + qk_dim = qkv_dim - HV * V + if qk_dim <= 0 or qk_dim % 2 != 0: + raise ValueError( + f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}." + ) + q_dim = qk_dim // 2 + if q_dim % K != 0: + raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.") + H = q_dim // K + if H <= 0 or HV % H != 0: + raise ValueError( + f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}." + ) + + BK = triton.next_power_of_2(K) + if triton.cdiv(K, BK) != 1: + raise ValueError( + f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})." + ) + BV = min(triton.next_power_of_2(V), 32) + num_stages = 3 + num_warps = 1 + + stride_mixed_qkv_tok = mixed_qkv.stride(0) + stride_a_tok = a.stride(0) + stride_b_tok = b.stride(0) + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = initial_state.stride(0) + stride_indices_seq = ssm_state_indices.stride(0) + + NV = triton.cdiv(V, BV) + grid = (NV, B * HV) + fused_recurrent_gated_delta_rule_packed_decode_kernel[grid]( + mixed_qkv=mixed_qkv, + a=a, + b=b, + A_log=A_log, + dt_bias=dt_bias, + o=out, + h0=initial_state, + ht=initial_state, + ssm_state_indices=ssm_state_indices, + scale=scale, + stride_mixed_qkv_tok=stride_mixed_qkv_tok, + stride_a_tok=stride_a_tok, + stride_b_tok=stride_b_tok, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + SOFTPLUS_THRESHOLD=20.0, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + num_warps=num_warps, + num_stages=num_stages, + ) + return out, initial_state + + class FusedRecurrentFunction(torch.autograd.Function): @staticmethod def forward( @@ -264,7 +490,7 @@ class FusedRecurrentFunction(torch.autograd.Function): scale: float, initial_state: torch.Tensor, inplace_final_state: bool = True, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, @@ -296,7 +522,7 @@ def fused_recurrent_gated_delta_rule( scale: float = None, initial_state: torch.Tensor = None, inplace_final_state: bool = True, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, @@ -324,7 +550,7 @@ def fused_recurrent_gated_delta_rule( inplace_final_state: bool: Whether to store the final state in-place to save memory. Default: `True`. - cu_seqlens (torch.LongTensor): + cu_seqlens (torch.Tensor): Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. ssm_state_indices (Optional[torch.Tensor]): @@ -358,7 +584,7 @@ def fused_recurrent_gated_delta_rule( # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... 
-> 1 (b t) ...'), (q, k, v, g, beta)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32) >>> o_var, ht_var = fused_gated_recurrent_delta_rule( q, k, v, g, beta, initial_state=h0, diff --git a/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py b/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py new file mode 100644 index 0000000..7e0c7e0 --- /dev/null +++ b/vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py @@ -0,0 +1,279 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_sigmoid_gating_delta_rule_update_kernel( + A_log, + a, + b, + dt_bias, + beta, + threshold, + q, + k, + v, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + + p_A_log = A_log + i_hv + if not IS_KDA: + p_a = a + bos * HV + i_hv + p_dt_bias = dt_bias + i_hv + else: + p_a = a + (bos * HV + i_hv) * K + o_k + p_dt_bias = dt_bias + i_hv * K + o_k + + p_b = b + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_v[:, None] & mask_k[None, :] + + b_h = tl.zeros([BV, BK], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + # Load state index and check for invalid entries + state_idx = 
tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to( + tl.int64 + ) + # Skip if state index is invalid (NULL_BLOCK_ID=0) + if state_idx <= 0: + return + p_h0 = h0 + state_idx * stride_init_state_token + else: + p_h0 = h0 + bos * HV * V * K + p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_b = tl.load(p_b).to(tl.float32) + + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x + + # compute beta_output = sigmoid(b) + b_beta = tl.sigmoid(b_b.to(tl.float32)) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)) + b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)) + b_q = b_q * scale + # [BV, BK] + if not IS_KDA: + b_h *= tl.exp(b_g) + else: + b_h *= tl.exp(b_g[None, :]) + # [BV] + b_v -= tl.sum(b_h * b_k[None, :], 1) + b_v *= b_beta + # [BV, BK] + b_h += b_v[:, None] * b_k[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[None, :], 1) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + # Load state index and check for invalid entries + final_state_idx = tl.load( + ssm_state_indices + i_n * stride_indices_seq + i_t + ).to(tl.int64) + # Only store if state index is valid (not NULL_BLOCK_ID=0) + if final_state_idx > 0: + p_ht = ht + final_state_idx * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + # Update pointers for next timestep + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_b += HV + p_a += HV + + +def fused_sigmoid_gating_delta_rule_update( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: torch.Tensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + is_kda: bool = False, +): + """ + Fused triton implementation of sigmoid gating delta rule update. + This function uses a single fused kernel that combines both sigmoid gating + computation and the recurrent delta rule update for better performance. + """ + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 4 + + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]}" + f" when using `cu_seqlens`. 
Please flatten variable-length" + f" inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_sigmoid_gating_delta_rule_update_kernel[grid]( + A_log=A_log, + a=a.contiguous(), + b=b.contiguous(), + dt_bias=dt_bias, + beta=beta, + threshold=threshold, + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + INPLACE_FINAL_STATE=inplace_final_state, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + IS_KDA=is_kda, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state diff --git a/vllm/model_executor/layers/fla/ops/index.py b/vllm/model_executor/layers/fla/ops/index.py index f023e13..810d32c 100644 --- a/vllm/model_executor/layers/fla/ops/index.py +++ b/vllm/model_executor/layers/fla/ops/index.py @@ -15,14 +15,12 @@ from .utils import tensor_cache @tensor_cache -def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: +def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor: return cu_seqlens[1:] - cu_seqlens[:-1] @tensor_cache -def prepare_chunk_indices( - cu_seqlens: torch.LongTensor, chunk_size: int -) -> torch.LongTensor: +def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor: indices = torch.cat( [ torch.arange(n) @@ -33,9 +31,7 @@ def prepare_chunk_indices( @tensor_cache -def prepare_chunk_offsets( - cu_seqlens: torch.LongTensor, chunk_size: int -) -> torch.LongTensor: +def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor: return torch.cat( [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)] ).cumsum(-1) diff --git a/vllm/model_executor/layers/fla/ops/kda.py b/vllm/model_executor/layers/fla/ops/kda.py index 7145933..e686471 100644 --- a/vllm/model_executor/layers/fla/ops/kda.py +++ b/vllm/model_executor/layers/fla/ops/kda.py @@ -37,7 +37,7 @@ def fused_recurrent_kda_fwd( scale: float, initial_state: torch.Tensor, inplace_final_state: bool = True, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ssm_state_indices: torch.Tensor | None = None, num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, @@ -115,7 +115,7 @@ def fused_recurrent_kda( initial_state: torch.Tensor = None, inplace_final_state: bool = True, use_qk_l2norm_in_kernel: bool = True, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ssm_state_indices: torch.LongTensor | None = None, **kwargs, ) -> 
tuple[torch.Tensor, torch.Tensor]: @@ -692,7 +692,7 @@ def chunk_kda_scaled_dot_kkt_fwd( gk: torch.Tensor | None = None, beta: torch.Tensor | None = None, scale: float | None = None, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, chunk_size: int = 64, output_dtype: torch.dtype = torch.float32, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -706,7 +706,7 @@ def chunk_kda_scaled_dot_kkt_fwd( The beta tensor of shape `[B, T, H]`. gk (torch.Tensor): The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`. - cu_seqlens (torch.LongTensor): + cu_seqlens (torch.Tensor): The cumulative sequence lengths of the input tensor. Default: None chunk_size (int): @@ -936,7 +936,7 @@ def recompute_w_u_fwd( A: torch.Tensor, q: torch.Tensor | None = None, gk: torch.Tensor | None = None, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] BT = A.shape[-1] @@ -1104,7 +1104,7 @@ def chunk_gla_fwd_o_gk( h: torch.Tensor, o: torch.Tensor, scale: float, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, chunk_size: int = 64, ): B, T, H, K, V = *q.shape, v.shape[-1] @@ -1148,7 +1148,7 @@ def chunk_kda_fwd( scale: float, initial_state: torch.Tensor, output_final_state: bool, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, ): chunk_size = 64 g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) @@ -1208,7 +1208,7 @@ def chunk_kda( initial_state: torch.Tensor = None, output_final_state: bool = False, use_qk_l2norm_in_kernel: bool = False, - cu_seqlens: torch.LongTensor | None = None, + cu_seqlens: torch.Tensor | None = None, **kwargs, ): if scale is None: diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py index 74c08e0..3abfbff 100644 --- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -84,6 +84,7 @@ def layer_norm_fwd_kernel( HAS_Z: tl.constexpr, NORM_BEFORE_GATE: tl.constexpr, IS_RMS_NORM: tl.constexpr, + ACTIVATION: tl.constexpr, ): # Map the program id to the starting row of X and Y it should compute. 
row_start = tl.program_id(0) * ROWS_PER_BLOCK @@ -112,7 +113,10 @@ def layer_norm_fwd_kernel( if HAS_Z and not NORM_BEFORE_GATE: Z_base = Z + rows[:, None] * stride_z_row + col_offsets z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) - x *= z * tl.sigmoid(z) + if ACTIVATION == "swish" or ACTIVATION == "silu": + x *= z * tl.sigmoid(z) + elif ACTIVATION == "sigmoid": + x *= tl.sigmoid(z) # Compute mean and variance per row (reduce along axis 1) if not IS_RMS_NORM: @@ -155,7 +159,10 @@ def layer_norm_fwd_kernel( if HAS_Z and NORM_BEFORE_GATE: Z_base = Z + rows[:, None] * stride_z_row + col_offsets z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32) - y *= z * tl.sigmoid(z) + if ACTIVATION == "swish" or ACTIVATION == "silu": + y *= z * tl.sigmoid(z) + elif ACTIVATION == "sigmoid": + y *= tl.sigmoid(z) # Write output tl.store(Y_base, y, mask=mask) @@ -178,6 +185,7 @@ def layer_norm_fwd( group_size: int = None, norm_before_gate: bool = True, is_rms_norm: bool = False, + activation: str = "swish", ): M, N = x.shape if group_size is None: @@ -232,9 +240,12 @@ def layer_norm_fwd( eps, BLOCK_N=BLOCK_N, ROWS_PER_BLOCK=rows_per_block, + HAS_BIAS=bias is not None, + HAS_Z=z is not None, NORM_BEFORE_GATE=norm_before_gate, IS_RMS_NORM=is_rms_norm, num_warps=num_warps, + ACTIVATION=activation, ) return out, mean, rstd @@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function): group_size=None, norm_before_gate=True, is_rms_norm=False, + activation: str = "swish", ): """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" @@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function): group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm, + activation=activation, ) ctx.save_for_backward(x, weight, bias, mean, rstd, z) ctx.x_shape_og = x_shape_og @@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function): ctx.group_size = group_size ctx.norm_before_gate = norm_before_gate ctx.is_rms_norm = is_rms_norm + ctx.activation = activation return y.reshape(x_shape_og) @@ -296,17 +310,25 @@ def layernorm_fn( group_size=None, norm_before_gate=True, is_rms_norm=False, + activation: str = "swish", ): return LayerNormFn.apply( - x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm + x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation ) def rmsnorm_fn( - x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + activation: str = "swish", ): return LayerNormFn.apply( - x, weight, bias, z, eps, group_size, norm_before_gate, True + x, weight, bias, z, eps, group_size, norm_before_gate, True, activation ) @@ -359,6 +381,7 @@ class RMSNormGated(nn.Module): norm_before_gate: bool = False, device: torch.device | None = None, dtype: torch.dtype | None = None, + activation: str = "swish", ): """If group_size is not None, we do GroupNorm with each group having group_size elements. group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). 
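The new ACTIVATION constexpr threaded through layer_norm_fwd_kernel -> layer_norm_fwd -> LayerNormFn -> RMSNormGated only changes how the gate tensor z modulates the input: "swish"/"silu" keeps the existing z * sigmoid(z) gating, while "sigmoid" applies sigmoid(z) alone. A minimal eager-mode sketch of the RMS-norm path, for reference only (rmsnorm_gated_ref is an illustrative name, not part of this patch, and the optional bias and group_size handling are omitted):

import torch

def rmsnorm_gated_ref(x, weight, z=None, eps=1e-6,
                      norm_before_gate=False, activation="swish"):
    # Gate nonlinearity selected by the new `activation` argument.
    def gate(t):
        t = t.float()
        if activation in ("swish", "silu"):
            return t * torch.sigmoid(t)
        return torch.sigmoid(t)

    x = x.float()
    if z is not None and not norm_before_gate:
        x = x * gate(z)                       # norm(x * act(z))
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight.float()
    if z is not None and norm_before_gate:
        y = y * gate(z)                       # norm(x) * act(z)
    return y

The Triton kernel computes the same quantity per row block in fp32; this sketch is only meant to document what the new constexpr selects.
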
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module): factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps + self.activation = activation self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.group_size = group_size @@ -385,4 +409,5 @@ class RMSNormGated(nn.Module): eps=self.eps, group_size=self.group_size, norm_before_gate=self.norm_before_gate, + activation=self.activation, ) diff --git a/vllm/model_executor/layers/fla/ops/wy_fast.py b/vllm/model_executor/layers/fla/ops/wy_fast.py index a66ec1d..6baa08a 100644 --- a/vllm/model_executor/layers/fla/ops/wy_fast.py +++ b/vllm/model_executor/layers/fla/ops/wy_fast.py @@ -122,7 +122,7 @@ def recompute_w_u_fwd( beta: torch.Tensor, g_cumsum: torch.Tensor, A: torch.Tensor, - cu_seqlens: torch.LongTensor | None, + cu_seqlens: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, v.shape[-1] H = v.shape[-2] diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index c6cb31b..d26ac6e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -22,12 +22,13 @@ from vllm.model_executor.layers.fused_moe.layer import ( ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, + FusedMoEExpertsModular, + FusedMoEPrepareAndFinalizeModular, ) from vllm.model_executor.layers.fused_moe.router.fused_moe_router import ( FusedMoERouter, ) +from vllm.model_executor.layers.fused_moe.router.gate_linear import GateLinear from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, @@ -61,9 +62,10 @@ __all__ = [ "MoEActivation", "UnquantizedFusedMoEMethod", "FusedMoeWeightScaleSupported", - "FusedMoEPermuteExpertsUnpermute", + "FusedMoEExpertsModular", "FusedMoEActivationFormat", - "FusedMoEPrepareAndFinalize", + "FusedMoEPrepareAndFinalizeModular", + "GateLinear", "RoutingMethodType", "SharedFusedMoE", "ZeroExpertFusedMoE", @@ -137,4 +139,4 @@ else: raise NotImplementedError(f"{method} is not implemented as lack of triton.") fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk") - fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts") + fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts") \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/activation.py b/vllm/model_executor/layers/fused_moe/activation.py index c11f16d..021b066 100644 --- a/vllm/model_executor/layers/fused_moe/activation.py +++ b/vllm/model_executor/layers/fused_moe/activation.py @@ -6,8 +6,7 @@ from enum import Enum import torch import torch.nn.functional as F - -from vllm._custom_ops import silu_and_mul, gelu_and_mul, swigluoai_and_mul +from vllm import _custom_ops as ops class MoEActivation(Enum): @@ -114,14 +113,11 @@ def apply_moe_activation( # Activations with gated multiplication (gate × activation(up)) if activation == MoEActivation.SILU: - # torch.ops._C.silu_and_mul(output, input) - silu_and_mul(output, input) + ops.silu_and_mul(output, input) elif activation == MoEActivation.GELU: - # torch.ops._C.gelu_and_mul(output, input) - gelu_and_mul(output, input) + ops.gelu_and_mul(output, input) elif activation == MoEActivation.SWIGLUOAI: - # 
torch.ops._C.swigluoai_and_mul(output, input) - swigluoai_and_mul(output, input) + ops.swigluoai_and_mul(output, input) elif activation == MoEActivation.SWIGLUSTEP: from vllm.model_executor.layers.activation import swiglustep_and_mul_triton diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index bf8ec2d..47ca95e 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any import torch @@ -20,20 +21,15 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNaiveEP, - MoEPrepareAndFinalizeNoEP, + make_moe_prepare_and_finalize_naive_dp_ep, + make_moe_prepare_and_finalize_no_dp_ep, ) from vllm.platforms import current_platform -from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx +from vllm.utils.import_utils import has_deep_ep, has_mori logger = init_logger(__name__) if current_platform.is_cuda_alike(): - if has_pplx(): - from .pplx_prepare_finalize import ( - PplxPrepareAndFinalize, - pplx_hidden_dim_scale_bytes, - ) if has_deep_ep(): from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize from .deepep_ll_prepare_finalize import ( @@ -81,6 +77,7 @@ def maybe_make_prepare_finalize( quant_config: FusedMoEQuantConfig | None, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, allow_new_interface: bool = False, + use_monolithic: bool = False, ) -> FusedMoEPrepareAndFinalize | None: # NOTE(rob): we are migrating each quant_method to hold the MK # in all cases. The allow_new_interface=False flag allow us to fall @@ -106,65 +103,25 @@ def maybe_make_prepare_finalize( "Detected DP deployment with no --enable-expert-parallel. " "Falling back to AllGather+ReduceScatter dispatch/combine." ) - return MoEPrepareAndFinalizeNaiveEP( + return make_moe_prepare_and_finalize_naive_dp_ep( is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, num_dispatchers=( get_ep_group().device_communicator.all2all_manager.world_size ), + use_monolithic=use_monolithic, ) else: - return MoEPrepareAndFinalizeNoEP() + return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic) all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None prepare_finalize: FusedMoEPrepareAndFinalize | None = None - if moe.use_pplx_kernels: - assert quant_config is not None - - hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( - moe.max_num_tokens, - moe.hidden_dim, - moe.in_dtype, - quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - - all_to_all_args = dict( - max_num_tokens=moe.max_num_tokens, - num_experts=moe.num_experts, - experts_per_token=moe.experts_per_token, # topk - rank=all2all_manager.rank, - world_size=all2all_manager.world_size, - # dp_size actually means tp_size, bug in pplx kernels - dp_size=all2all_manager.tp_group.world_size, - hidden_dim=moe.hidden_dim, - hidden_dim_bytes=hidden_dim_bytes, - hidden_dim_scale_bytes=hidden_scale_bytes, - ) - - num_dispatchers = ( - all2all_manager.world_size // all2all_manager.tp_group.world_size - ) - - # Intranode pplx a2a takes a group name while internode does not. 
- if not all2all_manager.internode: - all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name - - handle = all2all_manager.get_handle(all_to_all_args) - - prepare_finalize = PplxPrepareAndFinalize( - handle, - max_num_tokens=moe.max_num_tokens, - num_local_experts=moe.num_local_experts, - num_dispatchers=num_dispatchers, - ) - elif moe.use_deepep_ht_kernels: + if moe.use_deepep_ht_kernels: assert moe.dp_size == all2all_manager.dp_world_size - all_to_all_args = dict() + all_to_all_args: dict[str, Any] = dict() handle = all2all_manager.get_handle(all_to_all_args) prepare_finalize = DeepEPHTPrepareAndFinalize( handle, @@ -246,8 +203,9 @@ def maybe_make_prepare_finalize( ) elif moe.use_naive_all2all_kernels and allow_new_interface: - prepare_finalize = MoEPrepareAndFinalizeNaiveEP( - is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel), + prepare_finalize = make_moe_prepare_and_finalize_naive_dp_ep( + use_monolithic=use_monolithic, + is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel, num_dispatchers=all2all_manager.world_size, ) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 405965c..5397125 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -261,7 +261,7 @@ def persistent_masked_m_silu_mul_quant( return y_q, y_s -class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): +class BatchedDeepGemmExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 87e1e24..e0ed913 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -228,6 +228,7 @@ class FusedMoEQuantConfig: _a2: FusedMoEQuantDesc _w1: FusedMoEQuantDesc _w2: FusedMoEQuantDesc + is_nvfp4_scale_swizzled: bool = True def __post_init__(self): assert not self.per_act_token_quant or self.block_shape is None, ( @@ -475,6 +476,7 @@ class FusedMoEQuantConfig: w1_zp: torch.Tensor | None = None, w2_zp: torch.Tensor | None = None, weight_dtype: torch.dtype | str | None = None, + is_nvfp4_scale_swizzled: bool = True, ) -> "FusedMoEQuantConfig": """ General builder function for a FusedMoEQuantConfig. @@ -504,6 +506,7 @@ class FusedMoEQuantConfig: - w2_bias: Optional biases for w1 (GPT OSS Triton). - w1_zp: Optional w1 zero points for int4/int8 quantization. - w2_zp: Optional w2 zero points for int4/int8 quantization. + - is_nvfp4_scale_swizzled: Whether to swizzle the nvfp4 scale swizzling. """ assert not isinstance(quant_dtype, str) or quant_dtype in { "nvfp4", @@ -536,6 +539,7 @@ class FusedMoEQuantConfig: _w2=FusedMoEQuantDesc( weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias ), + is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, ) assert quant_config.per_act_token_quant == per_act_token_quant assert quant_config.per_out_ch_quant == per_out_ch_quant @@ -737,6 +741,7 @@ def nvfp4_moe_quant_config( w2_scale: torch.Tensor, w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, + is_nvfp4_scale_swizzled: bool = True, ) -> FusedMoEQuantConfig: """ Construct a quant config for mxfp4 activations and nvp4 weights. 
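The "AllGather+ReduceScatter dispatch/combine" that maybe_make_prepare_finalize falls back to (see the all2all_utils.py hunk above) amounts to gathering every DP rank's tokens before the experts run and reduce-scattering the partial expert outputs afterwards. A simplified torch.distributed sketch, assuming equal token counts per rank (the real prepare/finalize pads to the maximum DP chunk and lets routing zero out experts a rank does not own; naive_dp_dispatch/naive_dp_combine are illustrative names, not the patch's API):

import torch
import torch.distributed as dist

def naive_dp_dispatch(local_tokens: torch.Tensor, dp_group) -> torch.Tensor:
    # "prepare": collect every DP rank's tokens so each rank can apply its
    # local experts to the whole DP batch.
    world = dist.get_world_size(dp_group)
    gathered = [torch.empty_like(local_tokens) for _ in range(world)]
    dist.all_gather(gathered, local_tokens, group=dp_group)
    return torch.cat(gathered, dim=0)

def naive_dp_combine(expert_out: torch.Tensor, dp_group) -> torch.Tensor:
    # "finalize": sum partial expert outputs across ranks and keep only this
    # rank's shard of tokens.
    world = dist.get_world_size(dp_group)
    local = torch.empty(
        (expert_out.shape[0] // world, *expert_out.shape[1:]),
        dtype=expert_out.dtype, device=expert_out.device,
    )
    dist.reduce_scatter_tensor(local, expert_out, group=dp_group)
    return local
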
@@ -754,6 +759,7 @@ def nvfp4_moe_quant_config( per_act_token_quant=False, per_out_ch_quant=False, block_shape=None, + is_nvfp4_scale_swizzled=is_nvfp4_scale_swizzled, ) @@ -939,10 +945,6 @@ class FusedMoEParallelConfig: def use_all2all_kernels(self): return self.dp_size > 1 and self.use_ep - @property - def use_pplx_kernels(self): - return self.use_all2all_kernels and self.all2all_backend == "pplx" - @property def use_deepep_ht_kernels(self): return ( @@ -962,7 +964,7 @@ class FusedMoEParallelConfig: @property def use_batched_activation_format(self): - return self.use_deepep_ll_kernels or self.use_pplx_kernels + return self.use_deepep_ll_kernels @property def use_naive_all2all_kernels(self): @@ -1221,10 +1223,6 @@ class FusedMoEConfig: def use_ep(self): return self.moe_parallel_config.use_ep - @property - def use_pplx_kernels(self): - return self.moe_parallel_config.use_pplx_kernels - @property def use_deepep_ht_kernels(self): return self.moe_parallel_config.use_deepep_ht_kernels diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 0000000..620fe93 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, 
+ "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json new file mode 100644 index 0000000..fc7dda8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H200.json @@ -0,0 +1,147 @@ +{ + "triton_version": "3.6.0", + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index ae9430d..64848bf 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( moe_unpermute, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, ) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, @@ -166,7 +166,7 @@ def run_cutlass_moe_fp8( problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) - ops.get_cutlass_pplx_moe_mm_data( + ops.get_cutlass_batched_moe_mm_data( expert_offsets, problem_sizes1, problem_sizes2, @@ -262,7 +262,7 @@ def run_cutlass_moe_fp8( ) -class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsFp8Base(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, @@ -661,7 +661,7 @@ def run_cutlass_moe_fp4( return -class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsFp4(mk.FusedMoEExpertsModular): """CUTLASS FP4 fused MoE expert implementation.""" @property @@ -928,7 +928,7 @@ def run_cutlass_moe_w4a8_fp8( ) -class CutlassExpertsW4A8Fp8(mk.FusedMoEPermuteExpertsUnpermute): +class CutlassExpertsW4A8Fp8(mk.FusedMoEExpertsModular): def __init__( self, out_dtype: torch.dtype | None, @@ -1170,8 +1170,8 @@ def cutlass_moe_w4a8_fp8( num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(0) - fn = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + fn = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), CutlassExpertsW4A8Fp8( out_dtype=a.dtype, a_strides1=a_strides1, @@ -1186,10 +1186,9 @@ def cutlass_moe_w4a8_fp8( quant_config=quant_config, group_size=group_size, ), - inplace=False, ) - return fn( + return fn.apply( a, w1_q, w2_q, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 69ca7c9..8af439a 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -113,7 +113,7 @@ def _valid_deep_gemm( return True -class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): +class DeepGemmExperts(mk.FusedMoEExpertsModular): """DeepGemm-based fused MoE expert implementation.""" def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig): diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 514aa20..1f9e79c 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -25,7 +25,7 @@ from vllm.v1.worker.ubatching import ( ) -class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ Prepare/Finalize using DeepEP High-Throughput kernels. 
""" @@ -123,7 +123,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): is_token_in_rank, event, ) = self.buffer.get_dispatch_layout( - topk_idx=rank_topk_ids, + topk_idx=rank_topk_ids.long(), num_experts=num_experts, previous_event=previous_event, async_finish=False, @@ -148,7 +148,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, is_token_in_rank=is_token_in_rank, num_tokens_per_expert=dispatch_expert_num_tokens, - topk_idx=rank_topk_ids, + topk_idx=rank_topk_ids.long(), topk_weights=rank_topk_weights, # expert_alignment rounds the number of tokens per expert # to this value. @@ -169,7 +169,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): event, has_scales, token_data, - expert_topk_ids, + expert_topk_ids.int(), num_experts, expert_num_tokens_per_expert_list, expert_topk_weights, @@ -239,6 +239,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): quant_dtype=quant_config.quant_dtype, per_act_token_quant=False, block_shape=quant_config.block_shape, + is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled, ) return ( diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index a4cee76..d95e481 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -49,7 +49,7 @@ def dequant_fp8( return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.size()) -class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ Prepare/Finalize using DeepEP low-latency kernels. """ @@ -119,7 +119,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): # time. This setting is handled by post_init_setup. self.use_ue8m0_dispatch = False - def post_init_setup(self, fused_experts: mk.FusedMoEPermuteExpertsUnpermute): + def post_init_setup(self, fused_experts: mk.FusedMoEExperts): if not fused_experts.supports_packed_ue8m0_act_scales(): # Early exit. 
return @@ -297,12 +297,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids) expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch( a1, - dispatch_topk_ids, + dispatch_topk_ids.long(), self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, - round_scale=self.use_ue8m0_dispatch, - use_ue8m0=self.use_ue8m0_dispatch, + # round_scale=self.use_ue8m0_dispatch, + # use_ue8m0=self.use_ue8m0_dispatch, **(dict(use_nvfp4=True) if use_nvfp4 else dict()), **( dict(x_global_scale=qc_a1_gscale_or_scale) @@ -398,7 +398,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): dbo_maybe_run_recv_hook() _, _, recv_hook = self.buffer.low_latency_combine( fused_expert_output, - combine_topk_ids, + combine_topk_ids.long(), combine_topk_weights, handle, async_finish=False, diff --git a/vllm/model_executor/layers/fused_moe/experts/__init__.py b/vllm/model_executor/layers/fused_moe/experts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py new file mode 100644 index 0000000..febb3b2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -0,0 +1,335 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + activation_to_flashinfer_int, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kFp8Dynamic128Sym, + kFp8Static128BlockSym, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform + + +class TrtLlmFp8Experts(mk.FusedMoEExpertsMonolithic): + """ + Fp8 TRTLLM-Gen MoE kernels. Supports monolithic interface. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(moe_config, quant_config) + + if moe_config.moe_parallel_config.use_ep and quant_config.is_per_tensor: + raise NotImplementedError( + "EP parallelism is not supported with TRTLLM" + "per-tensor FP8 quantization." + ) + + self.routing_method_type = moe_config.routing_method + self.topk = moe_config.experts_per_token + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.hidden_dim = moe_config.hidden_dim + self.local_num_experts = moe_config.num_local_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + + # Make additional scales for per-tensor interface. 
+ if self.quant_config.is_per_tensor: + w1_scale = self.quant_config.w1_scale + assert w1_scale is not None + a1_scale = self.quant_config.a1_scale + assert a1_scale is not None + w2_scale = self.quant_config.w2_scale + assert w2_scale is not None + a2_scale = self.quant_config.a2_scale + assert a2_scale is not None + + self._g1_alphas = (w1_scale * a1_scale).squeeze() + self._g2_alphas = (w2_scale * a2_scale).squeeze() + self._g1_scale_c = ( + self._g1_alphas / self.quant_config.a2_scale + if moe_config.is_act_and_mul + else torch.ones_like(self._g1_alphas) / self.quant_config.a2_scale + ) + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + @staticmethod + def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + # Add check flashinfer trtllm is available + return p.is_cuda() and p.is_device_capability_family(100) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + """Does not support non-gated MoE (i.e. Nanotron-3-Nano).""" + return True + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Fp8 per-tensor and Fp8 block.""" + SUPPORTED_W_A = [ + (kFp8Static128BlockSym, kFp8Dynamic128Sym), + (kFp8StaticTensorSym, kFp8StaticTensorSym), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + """Supports only SiLU and RELU^2 non-gated activation.""" + return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Monolithic kernels need to express router support.""" + # NOTE(dbari): TopK routing could also be enabled, but need to validate models + # NOTE(dbari): Default is not implemented and should not be enabled until it is + if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): + # NOTE(rob): potentially allow others here. This is a conservative list. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): + # NOTE(dbari): as above, potentially allow others here. + return routing_method in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Llama4, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + ] + else: + raise ValueError("Unsupported quantization scheme.") + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """Monolithic kernel so only use with naive DP/EP and TP.""" + return ( + not moe_parallel_config.use_all2all_kernels + or moe_parallel_config.use_naive_all2all_kernels + ) and not moe_parallel_config.enable_eplb + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. + Only DeepSeekV3 routing supports float32 router_logits (which is converted + internally in the kernel). 
+ """ + if router_logits_dtype == torch.float32: + # Only DeepSeekV3 routing handles float32 logits + # https://github.com/flashinfer-ai/flashinfer/issues/2469 + return routing_method == RoutingMethodType.DeepSeekV3 + return True + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + def _apply_per_block( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + # Delay import for non-CUDA. + import flashinfer + + assert not apply_router_weight_on_input + assert activation == MoEActivation.SILU + + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.to(hidden_states.dtype) + + if self.routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + assert self.topk <= global_num_experts + assert self.topk <= 10 + assert global_num_experts % 4 == 0 + assert self.quant_config.block_shape == [128, 128] + # Routing kernel expects #experts <= #threads 512 + assert global_num_experts <= 512 + + # Kernel requires transposed hidden state scales + # TODO: fuse into the quant kernel. + assert a1q_scale is not None + a1q_scale_t = a1q_scale.t().contiguous() + + return flashinfer.fused_moe.trtllm_fp8_block_scale_moe( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale_t, + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale, + num_experts=global_num_experts, + top_k=self.topk, + n_group=(num_expert_group or 0), + topk_group=(topk_group or 0), + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.routing_method_type, + use_shuffled_weight=False, + ) + + def _apply_per_tensor( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + # Delay import for non-CUDA. + import flashinfer + from flashinfer.fused_moe.core import ActivationType + + # Confirm supported activation function. + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + + activation_type = ActivationType(activation_to_flashinfer_int(activation)) + + # Confirm Llama-4 routing is proper. + if self.routing_method_type == RoutingMethodType.Llama4: + assert apply_router_weight_on_input + else: + assert not apply_router_weight_on_input + + # The DeepSeekV3 routing method requires float32 router logits. 
+ if self.routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = router_logits.to(torch.float32) + + out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe( + routing_logits=router_logits, + routing_bias=e_score_correction_bias, + hidden_states=hidden_states, + gemm1_weights=w1, + output1_scales_scalar=self._g1_scale_c, + output1_scales_gate_scalar=self._g1_alphas, + gemm2_weights=w2, + output2_scales_scalar=self._g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=num_expert_group or 0, + topk_group=topk_group or 0, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=apply_router_weight_on_input, + routing_method_type=self.routing_method_type, + activation_type=activation_type, + ) + return out + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + if self.quant_config.block_shape is not None: + return self._apply_per_block( + hidden_states, + w1, + w2, + router_logits, + activation, + global_num_experts, + expert_map, + a1q_scale, + apply_router_weight_on_input, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + topk_group=topk_group, + ) + elif self.quant_config.is_per_tensor: + return self._apply_per_tensor( + hidden_states, + w1, + w2, + router_logits, + activation, + global_num_experts, + expert_map, + a1q_scale, + apply_router_weight_on_input, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + else: + raise NotImplementedError( + "Only per-block and per-tensor quantization are supported in " + f"{self.__class__.__name__}." + ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py new file mode 100644 index 0000000..5026717 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -0,0 +1,326 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import flashinfer +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.fused_moe.config import ( + FusedMoEConfig, + FusedMoEParallelConfig, + FusedMoEQuantConfig, + RoutingMethodType, +) +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + activation_to_flashinfer_int, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kNvfp4Dynamic, + kNvfp4Static, +) +from vllm.platforms import current_platform + + +class TrtLlmNvFp4ExpertsBase: + """ + NvFp4 TRTLLM-Gen MoE kernels. 
Supports modular and monolithic interface. + """ + + def __init__( + self, + moe_config: FusedMoEConfig, + quant_config: FusedMoEQuantConfig, + ): + self.moe_config = moe_config + self.quant_config = quant_config + + self.routing_method_type = self.moe_config.routing_method + self.topk = moe_config.experts_per_token + self.intermediate_size_per_partition = ( + moe_config.intermediate_size_per_partition + ) + self.hidden_dim = moe_config.hidden_dim + self.local_num_experts = moe_config.num_local_experts + self.ep_rank = moe_config.moe_parallel_config.ep_rank + + assert self.quant_config.g1_alphas is not None + assert self.quant_config.a2_gscale is not None + if moe_config.is_act_and_mul: + # g1_alpha_s = a13_scale * w13_scale_2 + # a2_gscale = (1 / a2_scale) + # g1_scale_c = a13_scale * w13_scale_2 / a2_scale + self.g1_scale_c = self.quant_config.g1_alphas * self.quant_config.a2_gscale + else: + self.g1_scale_c = ( + torch.ones_like(self.quant_config.a1_gscale) + * self.quant_config.a2_gscale + ) + + @staticmethod + def _supports_current_device() -> bool: + """Supports only Blackwell-family GPUs.""" + p = current_platform + return p.is_cuda() and p.is_device_capability_family(100) + + @staticmethod + def _supports_no_act_and_mul() -> bool: + """Supports non-gated MoE (i.e. Nemotron-Nano).""" + return True + + @staticmethod + def _supports_quant_scheme( + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """Supports Nvfp4 quantization.""" + SUPPORTED_W_A = [ + (kNvfp4Static, kNvfp4Dynamic), + ] + return (weight_key, activation_key) in SUPPORTED_W_A + + @staticmethod + def _supports_activation(activation: MoEActivation) -> bool: + """Supports only SiLU and RELU^2 non-gated activation.""" + return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + + @staticmethod + def _supports_shape(hidden_dim: int) -> bool: + """Requires hidden dim to be multiple of 512.""" + return hidden_dim % 512 == 0 + + @staticmethod + def activation_format() -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def supports_chunking(self) -> bool: + return False + + def supports_expert_map(self) -> bool: + return False + + +class TrtLlmNvFp4ExpertsModular(TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsModular): + """ + Modular version of the implementation (just the experts). + """ + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """The modular implementation supports all parallel configs.""" + return True + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + activation: MoEActivation, + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # The workspaces for this implementation are managed by flashinfer. + workspace1 = (0,) + workspace2 = (0,) + + # Hidden states are Nvfp4, packed into int8 dtype, so we + # need to multiply K by 2 to get the output shape right. 
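# [Editor's note] Illustrative sketch, not part of the patch: NVFP4 activations pack
# two 4-bit values into each int8 element, so a packed tensor whose last dimension is
# K represents a logical hidden size of K * 2, which is why the `hidden_dim == K * 2`
# assert just below holds and the unpacked output buffer is allocated as (M, hidden_dim).
import torch

M, hidden_dim = 8, 1024
packed = torch.empty(M, hidden_dim // 2, dtype=torch.int8)  # nvfp4-packed activations
K = packed.shape[-1]
assert hidden_dim == K * 2
output = torch.empty(M, hidden_dim, dtype=torch.bfloat16)   # full-width MoE output
assert output.shape == (M, hidden_dim)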
+ assert self.hidden_dim == K * 2 + output = (M, self.hidden_dim) + + return (workspace1, workspace2, output) + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + return TopKWeightAndReduceNoOP() + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool, + ): + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + assert a1q_scale is not None + assert self.quant_config.w1_scale is not None + assert self.quant_config.w2_scale is not None + + # Pack topk ids and weights into format expected by the kernel. + packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( + torch.bfloat16 + ).view(torch.int16) + + # trtllm_fp4_block_scale_routed_moe does not support autotuning + # so skip this kernel during dummy run for autotuning. + import vllm.utils.flashinfer as fi_utils + + if fi_utils._is_fi_autotuning: + return hidden_states + + # Invoke kernel. + flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( + topk_ids=packed_tensor, + routing_bias=None, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape( + *hidden_states.shape[:-1], -1 + ), + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=self.g1_scale_c, + output1_scale_gate_scalar=self.quant_config.g1_alphas, + output2_scale_scalar=self.quant_config.g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=0, + topk_group=0, + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=None, + routing_method_type=1, + do_finalize=True, + activation_type=activation_to_flashinfer_int(activation), + output=output, + ) + + +class TrtLlmNvFp4ExpertsMonolithic( + TrtLlmNvFp4ExpertsBase, mk.FusedMoEExpertsMonolithic +): + """ + Monolithic version of the kernel (router + experts). + """ + + @staticmethod + def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: + """The modular implementation should be used for the Dp/Ep or EPLB case.""" + return ( + not moe_parallel_config.use_all2all_kernels + and not moe_parallel_config.enable_eplb + ) + + @staticmethod + def _supports_routing_method( + routing_method_type: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + # NOTE(rob): this is a conservative list. + return routing_method_type in [ + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Renormalize, + RoutingMethodType.RenormalizeNaive, + RoutingMethodType.Llama4, + ] + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default. 
+ Only DeepSeekV3 routing supports float32 router_logits (which is converted + internally in the kernel). + """ + if router_logits_dtype == torch.float32: + # Only DeepSeekV3 routing handles float32 logits + # https://github.com/flashinfer-ai/flashinfer/issues/2469 + return routing_method == RoutingMethodType.DeepSeekV3 + return True + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + assert activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] + assert a1q_scale is not None + assert self.quant_config.w1_scale is not None + assert self.quant_config.w2_scale is not None + assert ( + apply_router_weight_on_input + and self.routing_method_type == RoutingMethodType.Llama4 + ) or ( + not apply_router_weight_on_input + and self.routing_method_type != RoutingMethodType.Llama4 + ) + + # Prepare routing bias into kernel format. + routing_bias = e_score_correction_bias + if routing_bias is not None: + routing_bias = routing_bias.to(torch.bfloat16) + router_logits = ( + router_logits.to(torch.float32) + if self.routing_method_type == RoutingMethodType.DeepSeekV3 + else router_logits + ) + + # Invoke kernel. + return flashinfer.fused_moe.trtllm_fp4_block_scale_moe( + routing_logits=router_logits, + routing_bias=routing_bias, + hidden_states=hidden_states, + hidden_states_scale=a1q_scale.view(torch.float8_e4m3fn).reshape( + *hidden_states.shape[:-1], -1 + ), + gemm1_weights=w1, + gemm1_weights_scale=self.quant_config.w1_scale.view(torch.float8_e4m3fn), + gemm1_bias=None, + gemm1_alpha=None, + gemm1_beta=None, + gemm1_clamp_limit=None, + gemm2_weights=w2, + gemm2_weights_scale=self.quant_config.w2_scale.view(torch.float8_e4m3fn), + gemm2_bias=None, + output1_scale_scalar=self.g1_scale_c, + output1_scale_gate_scalar=self.quant_config.g1_alphas, + output2_scale_scalar=self.quant_config.g2_alphas, + num_experts=global_num_experts, + top_k=self.topk, + n_group=(num_expert_group or 0), + topk_group=(topk_group or 0), + intermediate_size=self.intermediate_size_per_partition, + local_expert_offset=self.ep_rank * self.local_num_experts, + local_num_experts=self.local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=self.routing_method_type, + do_finalize=True, + )[0] diff --git a/vllm/model_executor/layers/fused_moe/fallback.py b/vllm/model_executor/layers/fused_moe/fallback.py index 4b6458e..403a71e 100644 --- a/vllm/model_executor/layers/fused_moe/fallback.py +++ b/vllm/model_executor/layers/fused_moe/fallback.py @@ -11,13 +11,13 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey -class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): +class FallbackExperts(mk.FusedMoEExpertsModular, ABC): """Base class for runtime dispatching of expert implementations.""" def __init__( self, - experts: mk.FusedMoEPermuteExpertsUnpermute, - fallback_experts: mk.FusedMoEPermuteExpertsUnpermute, + experts: mk.FusedMoEExpertsModular, + fallback_experts: mk.FusedMoEExpertsModular, ): super().__init__( 
moe_config=experts.moe_config, quant_config=experts.quant_config @@ -27,8 +27,8 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): @staticmethod def get_clses() -> tuple[ - type[mk.FusedMoEPermuteExpertsUnpermute], - type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEExpertsModular], + type[mk.FusedMoEExpertsModular], ]: """ Get the cls for the experts and fallback experts. @@ -149,7 +149,7 @@ class FallbackExperts(mk.FusedMoEPermuteExpertsUnpermute, ABC): hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: raise NotImplementedError def apply( diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py index 39b3738..465d0ae 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_a2a_prepare_finalize.py @@ -18,7 +18,7 @@ def get_local_sizes(): return get_forward_context().dp_metadata.get_chunk_sizes_across_dp_rank() -class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class FlashInferA2APrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """Base class for FlashInfer MoE prepare and finalize operations.""" def __init__( @@ -185,8 +185,8 @@ def flashinfer_alltoall_dispatch( ep_size, ) - # Swizzle after the A2A if nvfp4. - if quant_config.quant_dtype == "nvfp4": + # Swizzle after the A2A if MoE kernel expects swizzled scales. + if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled: if x_sf.element_size() == 1: x_sf = x_sf.view(torch.uint8) x_sf = nvfp4_block_scale_interleave(x_sf) diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index d0cf753..730dc0c 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -30,7 +30,7 @@ from vllm.utils.flashinfer import ( logger = init_logger(__name__) -class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): +class FlashInferCuteDSLExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index b9566a3..02c31fd 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -60,7 +60,7 @@ def is_valid_flashinfer_cutlass_fused_moe( return True -class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): +class FlashInferExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: mk.FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py index 732ab8e..6765e36 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py @@ -10,16 +10,6 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEParallelConfig, RoutingMethodType, ) -from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input -from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - 
QuantKey, - kFp8Dynamic128Sym, - kFp8Static128BlockSym, - kFp8StaticTensorSym, -) from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op @@ -39,49 +29,10 @@ def _supports_no_act_and_mul() -> bool: return True -def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, -) -> bool: - """Supports Fp8 per-tensor and Fp8 block.""" - SUPPORTED_W_A = [ - (kFp8Static128BlockSym, kFp8Dynamic128Sym), - (kFp8StaticTensorSym, kFp8StaticTensorSym), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - - def _supports_activation(activation: MoEActivation) -> bool: return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] -def _supports_routing_method( - weight_key: QuantKey | None, - activation_key: QuantKey | None, - routing_method: RoutingMethodType, -) -> bool: - """Monolithic kernels need to express router support.""" - # NOTE(dbari): TopK routing could also be enabled, but need to validate models - # NOTE(dbari): Default is not implemented and should not be enabled until it is - if (weight_key, activation_key) == (kFp8Static128BlockSym, kFp8Dynamic128Sym): - # NOTE(rob): potentially allow others here. This is a conservative list. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): - # NOTE(dbari): as above, potentially allow others here. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Llama4, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - ] - else: - raise ValueError("Unsupported quantization scheme.") - - def _supports_routing_method_bf16( routing_method: RoutingMethodType, ) -> bool: @@ -99,62 +50,6 @@ def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bo return not moe_parallel_config.enable_eplb -def _supports_router_logits_dtype( - router_logits_dtype: torch.dtype | None, - routing_method: RoutingMethodType, -) -> bool: - """ - The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. - Only DeepSeekV3 routing supports float32 router_logits (which is converted - internally in the kernel). 
- """ - if router_logits_dtype == torch.float32: - # Only DeepSeekV3 routing handles float32 logits - # https://github.com/flashinfer-ai/flashinfer/issues/2469 - return routing_method == RoutingMethodType.DeepSeekV3 - return True - - -def is_supported_config_trtllm_fp8( - moe_config: FusedMoEConfig, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - activation_format: mk.FusedMoEActivationFormat, -) -> tuple[bool, str | None]: - """ - This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config - """ - - def _make_reason(reason: str) -> str: - return f"kernel does not support {reason}" - - if not _supports_current_device(): - return False, _make_reason(f"current device {current_platform.device_name}") - elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): - return False, _make_reason("no act_and_mul MLP layer") - elif not _supports_activation(moe_config.activation): - return False, _make_reason(f"{moe_config.activation} activation") - elif not _supports_quant_scheme(weight_key, activation_key): - return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}") - elif not _supports_parallel_config(moe_config.moe_parallel_config): - return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}") - elif not _supports_routing_method( - weight_key, activation_key, moe_config.routing_method - ): - return False, _make_reason(f"routing method {moe_config.routing_method}") - elif activation_format != mk.FusedMoEActivationFormat.Standard: - return False, _make_reason(f"activation format {activation_format}") - elif not _supports_router_logits_dtype( - moe_config.router_logits_dtype, moe_config.routing_method - ): - return False, _make_reason( - "float32 router_logits with non-DeepSeekV3 routing " - f"{moe_config.router_logits_dtype}x{moe_config.routing_method}" - ) - - return True, None - - def is_supported_config_trtllm_bf16( moe_config: FusedMoEConfig, activation_format: mk.FusedMoEActivationFormat, @@ -183,199 +78,6 @@ def is_supported_config_trtllm_bf16( return True, None -def flashinfer_fused_moe_blockscale_fp8( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routing_method_type: int, - routed_scaling: float | None = 1.0, -) -> torch.Tensor: - from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe - - num_expert_group = num_expert_group if num_expert_group is not None else 0 - topk_group = topk_group if topk_group is not None else 0 - assert top_k <= global_num_experts - assert top_k <= 10 - assert global_num_experts % 4 == 0 - assert block_shape == [128, 128] - # Routing kernel expects #experts <= #threads 512 - assert global_num_experts <= 512 - - # The DeepSeekV3 routing method requires float32 router logits. - if routing_method_type == RoutingMethodType.DeepSeekV3: - routing_logits = routing_logits.to(torch.float32) - - if routing_bias is not None: - routing_bias = routing_bias.to(x.dtype) - - a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1]) - # NOTE: scales of hidden states have to be transposed! 
- a_sf_t = a_sf.t().contiguous() - return flashinfer_trtllm_fp8_block_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=a_q, - hidden_states_scale=a_sf_t, - gemm1_weights=w13_weight, - gemm1_weights_scale=w13_weight_scale_inv, - gemm2_weights=w2_weight, - gemm2_weights_scale=w2_weight_scale_inv, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling, - routing_method_type=routing_method_type, - use_shuffled_weight=False, - ) - - -def flashinfer_fused_moe_blockscale_fp8_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - x: torch.Tensor, - w13_weight: torch.Tensor, - w13_weight_scale_inv: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_scale_inv: torch.Tensor, - global_num_experts: int, - top_k: int, - num_expert_group: int, - topk_group: int, - intermediate_size: int, - expert_offset: int, - local_num_experts: int, - block_shape: list[int], - routing_method_type: int, - routed_scaling: float = 1.0, -) -> torch.Tensor: - return torch.empty_like(x) - - -# TODO(bnell): Does this really need to be a torch.op? -direct_register_custom_op( - op_name="flashinfer_fused_moe_blockscale_fp8", - op_func=flashinfer_fused_moe_blockscale_fp8, - fake_impl=flashinfer_fused_moe_blockscale_fp8_fake, - tags=(torch.Tag.needs_fixed_stride_order,), -) - - -def fi_trtllm_fp8_per_tensor_moe( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - activation_type: int, - routed_scaling_factor: float = 1.0, -) -> torch.Tensor: - num_expert_group = num_expert_group if num_expert_group is not None else 0 - topk_group = topk_group if topk_group is not None else 0 - - quant_hidden_states, _ = moe_kernel_quantize_input( - hidden_states, - input_scale, - quant_dtype=torch.float8_e4m3fn, - per_act_token_quant=False, - ) - - from flashinfer.fused_moe.core import ActivationType - - from vllm.utils.flashinfer import flashinfer_trtllm_fp8_per_tensor_scale_moe - - # The DeepSeekV3 routing method requires float32 router logits. 
- if routing_method_type == RoutingMethodType.DeepSeekV3: - routing_logits = routing_logits.to(torch.float32) - - return flashinfer_trtllm_fp8_per_tensor_scale_moe( - routing_logits=routing_logits, - routing_bias=routing_bias, - hidden_states=quant_hidden_states, - gemm1_weights=gemm1_weights, - output1_scales_scalar=output1_scales_scalar, - output1_scales_gate_scalar=output1_scales_gate_scalar, - gemm2_weights=gemm2_weights, - output2_scales_scalar=output2_scales_scalar, - num_experts=num_experts, - top_k=top_k, - n_group=num_expert_group, - topk_group=topk_group, - intermediate_size=intermediate_size, - local_expert_offset=local_expert_offset, - local_num_experts=local_num_experts, - routed_scaling_factor=routed_scaling_factor, - use_routing_scales_on_input=use_routing_scales_on_input, - routing_method_type=routing_method_type, - # TODO: enum type Required for flashinfer==0.6.3, remove with update - # https://github.com/flashinfer-ai/flashinfer/pull/2508 - activation_type=ActivationType(activation_type), - ) - - -def fi_trtllm_fp8_per_tensor_moe_fake( - routing_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - hidden_states: torch.Tensor, - input_scale: torch.Tensor, - gemm1_weights: torch.Tensor, - gemm2_weights: torch.Tensor, - output1_scales_scalar: torch.Tensor, - output1_scales_gate_scalar: torch.Tensor, - output2_scales_scalar: torch.Tensor, - num_experts: int, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - intermediate_size: int, - local_expert_offset: int, - local_num_experts: int, - use_routing_scales_on_input: bool, - routing_method_type: int, - activation_type: int, - routed_scaling_factor: float = 1.0, -) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -# TODO(bnell): Does this really need to be a torch.op? -direct_register_custom_op( - op_name="fi_trtllm_fp8_per_tensor_moe", - op_func=fi_trtllm_fp8_per_tensor_moe, - mutates_args=["hidden_states"], - fake_impl=fi_trtllm_fp8_per_tensor_moe_fake, - tags=(torch.Tag.needs_fixed_stride_order,), -) - - def flashinfer_fused_moe_bf16( routing_logits: torch.Tensor, routing_bias: torch.Tensor | None, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index fbd47f8..68393f7 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -489,11 +489,11 @@ def invoke_moe_batched_triton_kernel( ) -class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ A reference prepare/finalize class that reorganizes the tokens into expert batched format, i.e. E x max_num_tokens x K. This is the format - that the PPLX dispatch/combine kernels use. + that the batched dispatch/combine kernels use. """ def __init__( @@ -645,10 +645,10 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ) -class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): +class NaiveBatchedExperts(mk.FusedMoEExpertsModular): """ A reference MoE expert class that operates on expert batched format, - i.e. E x max_num_tokens x K. This is the format that the pplx + i.e. E x max_num_tokens x K. This is the format that the batched dispatch/combine kernels use. 
""" @@ -877,10 +877,10 @@ def batched_moe_kernel_quantize_input( return A_q, A_q_scale -class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +class BatchedTritonExperts(mk.FusedMoEExpertsModular): """ A Triton based MoE expert class that operates on expert batched format, - i.e. E x max_num_tokens x K. This is the format that the pplx + i.e. E x max_num_tokens x K. This is the format that the batched dispatch/combine kernels use. """ diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 4a8f312..280d090 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -526,7 +526,7 @@ def batched_fused_marlin_moe( return output -class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute): +class MarlinExpertsBase(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 64fa9c9..27dcd40 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -53,7 +53,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op - +import vllm._custom_ops as ops +import ixformer.inference.functions as ixfops +from vllm.forward_context import ForwardContext, get_forward_context +from vllm.distributed import get_ep_group logger = init_logger(__name__) @@ -575,56 +578,6 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) -def invoke_fused_moe_kernel( - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - A_scale: torch.Tensor | None, - B_scale: torch.Tensor | None, - B_zp: torch.Tensor | None, - topk_weights: torch.Tensor | None, - topk_ids: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - mul_routed_weight: bool, - top_k: int, - config: dict[str, Any], - compute_type: tl.dtype, - use_fp8_w8a8: bool, - use_int8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - per_channel_quant: bool, - block_shape: list[int] | None = None, - B_bias: torch.Tensor | None = None, -) -> None: - assert topk_weights is not None or not mul_routed_weight - assert topk_weights is None or topk_weights.stride(1) == 1 - assert sorted_token_ids.stride(0) == 1 - ops.invoke_fused_moe_kernel( - A, - B, - C, - A_scale, - B_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - mul_routed_weight, - top_k, - config, - compute_type, - use_fp8_w8a8, - use_int8_w8a16, - block_shape, - ) - # ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16,block_shape,B_bias) - return - - # NOTE(zyongye): we can remove all the wna16 kernel # once we drop off sm75 support def invoke_fused_moe_wna16_cuda_kernel( @@ -782,6 +735,7 @@ def invoke_fused_moe_triton_kernel( A_scale: torch.Tensor | None, B_scale: torch.Tensor | None, topk_weights: torch.Tensor | None, + topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor | None, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -799,7 +753,9 @@ def invoke_fused_moe_triton_kernel( ): assert topk_weights 
is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 - assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1 + assert sorted_token_ids.stride(0) == 1 + ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16,block_shape,B_bias) + return if use_fp8_w8a8 or use_int8_w8a8: assert B_scale is not None @@ -910,32 +866,6 @@ def dispatch_fused_moe_kernel( block_shape: list[int] | None = None, B_bias: torch.Tensor | None = None, ) -> None: - invoke_fused_moe_kernel( - A, - B, - C, - A_scale, - B_scale, - B_zp, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - mul_routed_weight, - top_k, - config, - compute_type, - use_fp8_w8a8, - use_int8_w8a8, - use_int8_w8a16, - use_int4_w4a16, - per_channel_quant, - block_shape, - B_bias - ) - return - assert topk_weights is not None or not mul_routed_weight assert topk_weights is None or topk_weights.stride(1) == 1 assert sorted_token_ids is None or sorted_token_ids.stride(0) == 1 @@ -999,6 +929,7 @@ def dispatch_fused_moe_kernel( A_scale, B_scale, topk_weights, + topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, @@ -1397,14 +1328,13 @@ def get_default_config( "num_warps": num_warps, "num_stages": num_stages, } - # TODO numel = M * topk if numel <= 64: - config["BLOCK_SIZE_M"] = 32 + config['BLOCK_SIZE_M'] = 32 elif numel <= 1024: - config["BLOCK_SIZE_M"] = 64 + config['BLOCK_SIZE_M'] = 64 else: - config["BLOCK_SIZE_M"] = 256 + config['BLOCK_SIZE_M'] = 256 return config @@ -1424,14 +1354,12 @@ def try_get_optimal_moe_config( else: # First try to load optimal config from the file E, _, N = w2_shape - if dtype == "int4_w4a16": - N = N * 2 - block_n = block_shape[0] if block_shape else 0 - block_k = block_shape[1] if block_shape else 0 - configs = get_moe_configs(E, N, dtype, block_n, block_k) + # block_n = block_shape[0] if block_shape else 0 + # block_k = block_shape[1] if block_shape else 0 + # configs = get_moe_configs(E, N, dtype, block_n, block_k) configs = None - + if configs: # If an optimal configuration map has been found, look up the # optimal config @@ -1560,13 +1488,12 @@ def outplace_fused_experts( w1_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> torch.Tensor: - return fused_experts_impl( + return fused_experts_impl_opt( hidden_states, w1, w2, topk_weights, topk_ids, - False, activation, apply_router_weight_on_input, use_fp8_w8a8, @@ -1626,14 +1553,12 @@ direct_register_custom_op( def torch_vllm_inplace_fused_experts(**kwargs) -> torch.Tensor: - # torch.ops.vllm.inplace_fused_experts(**kwargs) inplace_fused_experts(**kwargs) - hidden_states = kwargs["hidden_states"] + hidden_states = kwargs['hidden_states'] return hidden_states def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: - # return torch.ops.vllm.outplace_fused_experts(**kwargs) return outplace_fused_experts(**kwargs) @@ -1661,7 +1586,6 @@ def fused_experts( """Run fused MoE expert computation using Triton kernels.""" if quant_config is None: quant_config = FUSED_MOE_UNQUANTIZED_CONFIG - assert not inplace or not disable_inplace() return dispatch_fused_experts_func(inplace)( @@ -1691,6 +1615,245 @@ def fused_experts( w2_bias=quant_config.w2_bias, ) +def fused_experts_impl_opt( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + 
activation: str = "silu", + apply_router_weight_on_input: bool = False, + use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, + use_int8_w8a16: bool = False, + use_int4_w4a16: bool = False, + ocp_mx_scheme: str | None = None, + per_channel_quant: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + w1_zp: torch.Tensor | None = None, + w2_zp: torch.Tensor | None = None, + a1_scale: torch.Tensor | None = None, + a2_scale: torch.Tensor | None = None, + block_shape: torch.Tensor | None = None, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + output: torch.Tensor | None = None +) -> torch.Tensor: + # check constraints + if use_fp8_w8a8 or use_int8_w8a8 or use_int8_w8a16 or use_int4_w4a16 or w1_scale or \ + w2_scale or w1_zp or w2_zp or a1_scale or a2_scale: + raise ValueError("Quantized MoE is not supported") + + attn_metadata = get_forward_context().attn_metadata + use_ep = expert_map is not None + + # unsupported ep now + if attn_metadata: + only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values()))) + else: + only_decode = False + + assert topk_weights.size() == topk_ids.size(), "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.stride(-1) == 1, "Stride of last dimension must be 1" + assert w2.stride(-1) == 1, "Stride of last dimension must be 1" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + num_tokens = hidden_states.size(0) + num_experts = w1.size(0) + top_k = topk_weights.size(1) + + if use_ep: + local_num_experts = w1.size(0) + start_eid = get_ep_group().device_group.rank() * local_num_experts + end_eid = min((get_ep_group().device_group.rank() + 1) * local_num_experts, global_num_experts) + hidden_size = hidden_states.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=global_num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + else: + expand_tokens = num_tokens * top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + + if only_decode: + # expand + reorder + hidden_states = ixfops.moe_expand_input( + hidden_states=hidden_states, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=top_k, + src_to_dst=src_to_dst, + ) + + # group gemm 1 + pt_output_1 = ixfops.moe_w16a16_group_gemv( + input=hidden_states, + weight=w1, + output_dtype=hidden_states.dtype, + tokens_per_experts_gpu=expert_sizes_gpu, + dst_to_src=None, + bias=w1_bias, + format="TN", + ) + + # act + if activation == "silu": + pt_output_2 = ixfops.silu_and_mul(pt_output_1) + elif activation == "gelu": + pt_output_2 = ixfops.gelu_and_mul(pt_output_1) + elif activation == "swigluoai": + pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1) + elif activation == "swiglustep": + from vllm.model_executor.layers.activation import swiglustep_and_mul_triton + output_dim = pt_output_1.shape[1] + pt_output_2 = torch.empty( + (num_tokens * top_k, output_dim//2), + device=pt_output_1.device, + 
dtype=pt_output_1.dtype, + ) + swiglustep_and_mul_triton(pt_output_2, pt_output_1) + else: + raise ValueError(f"Unsupported activation: {activation}") + + # group gemm 2 + reorder + pt_output_3 = ixfops.moe_w16a16_group_gemv( + input=pt_output_2, + weight=w2, + output_dtype=hidden_states.dtype, + tokens_per_experts_gpu=expert_sizes_gpu, + dst_to_src=sorted_token_ids, + bias=w2_bias, + format="TN", + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weights, + ) + + else: + expert_sizes_cpu = expert_sizes_gpu.cpu() + # expand + reorder + hidden_states = ixfops.moe_expand_input( + hidden_states=hidden_states, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=top_k, + src_to_dst=src_to_dst, + ) + # group gemm 1 + pt_output_1 = ixfops.moe_w16a16_group_gemm( + input=hidden_states, + weight=w1, + output_dtype=hidden_states.dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=None, + bias=w1_bias, + format="TN", + ) + + # act + if activation == "silu": + pt_output_2 = ixfops.silu_and_mul(pt_output_1) + elif activation == "gelu": + pt_output_2 = ixfops.gelu_and_mul(pt_output_1) + elif activation == "swigluoai": + pt_output_2 = ixfops.swigluoai_and_mul(pt_output_1) + elif activation == "swiglustep": + from vllm.model_executor.layers.activation import swiglustep_and_mul_triton + output_dim = pt_output_1.shape[1] + pt_output_2 = torch.empty( + (num_tokens * top_k, output_dim//2), + device=pt_output_1.device, + dtype=pt_output_1.dtype, + ) + swiglustep_and_mul_triton(pt_output_2, pt_output_1) + else: + raise ValueError(f"Unsupported activation: {activation}") + + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * top_k, hidden_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + # group gemm 2 + reorder + pt_output_3 = ixfops.moe_w16a16_group_gemm( + input=pt_output_2, + weight=w2, + output_dtype=hidden_states.dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + format="TN", + bias=w2_bias, + output=pt_output_3, + ) + + # mul + reduce_sum + reduce_mask = src_to_dst == -1 + if output != None: + ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weights, + output=output, + mask=reduce_mask, + ) + else: + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weights, + mask=reduce_mask, + ) + else: + # group gemm 2 + reorder + pt_output_3 = ixfops.moe_w16a16_group_gemm( + input=pt_output_2, + weight=w2, + output_dtype=hidden_states.dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + bias=w2_bias, + format="TN", + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, top_k, -1), + topk_weight=topk_weights, + ) + + if output == None: + return final_hidden_states + def _get_config_quant_dtype( use_fp8_w8a8: bool, @@ -1825,7 +1988,7 @@ def fused_experts_impl( intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K) # This needs separate memory since it's used concurrently with cache1 - activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation( + activation_out_dim = mk.FusedMoEExpertsModular.adjust_N_for_activation( N, activation_enum ) intermediate_cache2 = torch.empty( @@ -1910,28 +2073,28 @@ def fused_experts_impl( ocp_mx_scheme=ocp_mx_scheme, ) - # SPARSITY_FACTOR is a heuristic margin ensuring 
tokens_in_chunk * top_k - # activates only a small fraction of total experts - SPARSITY_FACTOR = 4 - # block quantized code path is not implemented yet. - naive_block_assignment = ( - expert_map is None - and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts - and not ( - (use_int8_w8a16 or use_int4_w4a16) - and block_shape is not None - and block_shape[1] > 0 - ) - ) + # # SPARSITY_FACTOR is a heuristic margin ensuring tokens_in_chunk * top_k + # # activates only a small fraction of total experts + # SPARSITY_FACTOR = 4 + # # block quantized code path is not implemented yet. + # naive_block_assignment = ( + # expert_map is None + # and tokens_in_chunk * top_k_num * SPARSITY_FACTOR <= global_num_experts + # and not ( + # (use_int8_w8a16 or use_int4_w4a16) + # and block_shape is not None + # and block_shape[1] > 0 + # ) + # ) # if not naive_block_assignment: - # sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - # curr_topk_ids, - # config["BLOCK_SIZE_M"], - # global_num_experts, - # expert_map, - # ignore_invalid_experts=True, - # ) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, + config["BLOCK_SIZE_M"], + global_num_experts, + expert_map, + ignore_invalid_experts=True, + ) # else: # max_num_tokens_padded = topk_ids.numel() * config["BLOCK_SIZE_M"] # expert_ids = curr_topk_ids.view(-1) @@ -1941,14 +2104,6 @@ def fused_experts_impl( # num_tokens_post_padded.fill_(max_num_tokens_padded) # sorted_token_ids = None - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, - config["BLOCK_SIZE_M"], - global_num_experts, - expert_map, - ignore_invalid_experts=True, - ) - dispatch_fused_moe_kernel( qcurr_hidden_states, w1, @@ -2015,20 +2170,14 @@ def fused_experts_impl( B_bias=w2_bias, ) - # ops.moe_sum( - # intermediate_cache3.view(*intermediate_cache3.size()), - # out_hidden_states[begin_chunk_idx:end_chunk_idx], - # ) - torch.sum( - intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=out_hidden_states[begin_chunk_idx:end_chunk_idx], - ) + torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) return out_hidden_states -class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +class TritonExperts(mk.FusedMoEExpertsModular): """Triton-based fused MoE expert implementation.""" def __init__( @@ -2091,8 +2240,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): @staticmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - # return not moe_parallel_config.use_fi_all2allv_kernels - return True + return not moe_parallel_config.use_fi_all2allv_kernels def supports_chunking(self) -> bool: return True @@ -2138,157 +2286,31 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): - # Check constraints. 
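# [Editor's note] Illustrative sketch, not part of the patch: the chunk loop in
# fused_experts_impl above now always builds its launch metadata via
# moe_align_block_size(), which sorts token slots by expert and pads each expert's
# run to a multiple of BLOCK_SIZE_M so every Triton block touches exactly one expert.
# The pure-Python toy below only illustrates that grouping; the real op also returns
# the total padded slot count (num_tokens_post_padded).
import torch

def align_by_expert(topk_ids: torch.Tensor, block_m: int, num_experts: int):
    pad = topk_ids.numel()  # sentinel: one past the last valid flattened slot index
    buckets = [[] for _ in range(num_experts)]
    for slot, expert in enumerate(topk_ids.flatten().tolist()):
        buckets[expert].append(slot)
    sorted_ids, expert_ids = [], []
    for expert, slots in enumerate(buckets):
        while len(slots) % block_m:
            slots.append(pad)
        sorted_ids += slots
        expert_ids += [expert] * (len(slots) // block_m)
    return torch.tensor(sorted_ids), torch.tensor(expert_ids)

slot_ids, block_experts = align_by_expert(torch.tensor([[0, 1], [1, 1]]), 4, 2)
assert block_experts.tolist() == [0, 1]  # one BLOCK_SIZE_M block per active expert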
- if self.quant_config.use_int4_w4a16: - assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch" - else: - assert hidden_states.size(-1) == w1.size(2), ( - f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}" - ) - - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert hidden_states.dim() == 2 - assert w1.stride(-1) == 1, "Stride of last dimension must be 1" - assert w2.stride(-1) == 1, "Stride of last dimension must be 1" - assert hidden_states.dtype in [ - torch.float32, - torch.float16, - torch.bfloat16, - torch.float8_e4m3fn, - torch.float8_e4m3fnuz, - ] - - E, num_tokens, N, K, top_k_num = self.moe_problem_size( - hidden_states, w1, w2, topk_ids - ) - - if global_num_experts == -1: - global_num_experts = E - - config = try_get_optimal_moe_config( - w1.size(), - w2.size(), - top_k_num, - self.quant_config.config_name(hidden_states.dtype), - num_tokens, - block_shape=self.block_shape, - ) - - if hidden_states.dtype == torch.bfloat16: - compute_type = tl.bfloat16 - elif hidden_states.dtype == torch.float16: - compute_type = tl.float16 - elif hidden_states.dtype == torch.float32: - compute_type = tl.float32 - elif ( - hidden_states.dtype == torch.float8_e4m3fn - or hidden_states.dtype == torch.float8_e4m3fnuz - ): - compute_type = tl.bfloat16 - else: - raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") - - # Note that the output tensor might be in workspace1 - intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N)) - cache2_dim = self.adjust_N_for_activation(N, activation) - intermediate_cache2 = _resize_cache( - workspace13, (num_tokens * top_k_num, cache2_dim) - ) - intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K)) - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map - ) - - invoke_fused_moe_triton_kernel( - hidden_states, - w1, - intermediate_cache1, - a1q_scale, - self.w1_scale, - None, # topk_weights - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, # mul_routed_weights - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - use_int8_w8a8=self.quant_config.use_int8_w8a8, - use_int8_w8a16=self.quant_config.use_int8_w8a16, - use_int4_w4a16=self.quant_config.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape, - B_bias=self.w1_bias, - ) - - self.activation( - activation, intermediate_cache2, intermediate_cache1.view(-1, N) - ) - - a2q_scale: torch.Tensor | None = None - - qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( - intermediate_cache2, - a2_scale, - self.quant_dtype, - self.per_act_token_quant, - self.block_shape, - ) - - # invoke_fused_moe_triton_kernel( - # qintermediate_cache2, - # w2, - # intermediate_cache3, - # a2q_scale, - # self.w2_scale, - # topk_weights, - # sorted_token_ids, - # expert_ids, - # num_tokens_post_padded, - # not apply_router_weight_on_input, - # 1, - # config, - # compute_type=compute_type, - # use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - # use_int8_w8a8=self.quant_config.use_int8_w8a8, - # use_int8_w8a16=self.quant_config.use_int8_w8a16, - # use_int4_w4a16=self.quant_config.use_int4_w4a16, - # per_channel_quant=self.per_act_token_quant, - # block_shape=self.block_shape, - # B_bias=self.w2_bias, - # ) - - invoke_fused_moe_kernel( - qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - self.w2_scale, - self.w2_zp, 
- topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - use_int8_w8a8=self.quant_config.use_int8_w8a8, - use_int8_w8a16=self.quant_config.use_int8_w8a16, - use_int4_w4a16=self.quant_config.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape, - B_bias=self.w2_bias, - ) - - # separate function is required for MoE + LoRA - self.moe_sum(intermediate_cache3, output) - - def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: - ops.moe_sum(input, output) + fused_experts_impl_opt(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + activation, + apply_router_weight_on_input, + self.quant_config.use_fp8_w8a8, + self.quant_config.use_int8_w8a8, + self.quant_config.use_int8_w8a16, + self.quant_config.use_int4_w4a16, + self.quant_config.ocp_mx_scheme, + self.quant_config.per_act_token_quant, + global_num_experts, + expert_map, + self.quant_config.w1_scale, + self.quant_config.w2_scale, + self.quant_config.w1_zp, + self.quant_config.w2_zp, + self.quant_config.a1_scale, + self.quant_config.a2_scale, + self.quant_config.block_shape, + self.quant_config.w1_bias, + self.quant_config.w2_bias, + output) class TritonWNA16Experts(TritonExperts): diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index ac7c71e..88cd173 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, + FusedMoEExpertsModular, + FusedMoEPrepareAndFinalizeModular, ) from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, @@ -27,19 +27,21 @@ class FusedMoEMethodBase(QuantizeMethodBase): super().__init__() self.moe: FusedMoEConfig = moe self.moe_quant_config: FusedMoEQuantConfig | None = None - self.moe_mk: mk.FusedMoEModularKernel | None = None + self.moe_kernel: mk.FusedMoEKernel | None = None @property def supports_internal_mk(self) -> bool: # NOTE(rob): temporary attribute to indicate support for # completed migration to the new internal MK interface. - return self.moe_mk is not None + return self.moe_kernel is not None @property def mk_owns_shared_expert(self) -> bool: # NOTE(rob): temporary attribute to indicate support for # completed migration to the new internal MK interface. 
- return self.moe_mk is not None and self.moe_mk.shared_experts is not None + return ( + self.moe_kernel is not None and self.moe_kernel.shared_experts is not None + ) @abstractmethod def create_weights( @@ -66,35 +68,25 @@ class FusedMoEMethodBase(QuantizeMethodBase): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> FusedMoEPrepareAndFinalize | None: + ) -> FusedMoEPrepareAndFinalizeModular | None: from .all2all_utils import maybe_make_prepare_finalize - return maybe_make_prepare_finalize( + pf = maybe_make_prepare_finalize( self.moe, self.moe_quant_config, routing_tables ) + assert pf is None or isinstance(pf, FusedMoEPrepareAndFinalizeModular) + return pf def select_gemm_impl( self, - prepare_finalize: FusedMoEPrepareAndFinalize, + prepare_finalize: FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: + ) -> FusedMoEExpertsModular: # based on the all2all implementation, select the appropriate # gemm implementation - raise NotImplementedError( - f"{self.__class__.__name__} must select appropriate gemm " - "implementation based on the prepare_finalize" - ) - - def prepare_dp_allgather_tensor( - self, - layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.""" - raise NotImplementedError( - "Method 'prepare_dp_allgather_tensor' is not implemented in " - f"{self.__class__.__name__}." + raise ValueError( + f"{self.__class__.__name__} uses the new modular kernel initialization " + "logic. This function should not be called." ) @abstractmethod @@ -105,8 +97,8 @@ class FusedMoEMethodBase(QuantizeMethodBase): @property def topk_indices_dtype(self) -> torch.dtype | None: - if self.moe_mk is not None: - return self.moe_mk.prepare_finalize.topk_indices_dtype() + if self.moe_kernel is not None: + return self.moe_kernel.prepare_finalize.topk_indices_dtype() return None @property @@ -119,7 +111,12 @@ class FusedMoEMethodBase(QuantizeMethodBase): @property def is_monolithic(self) -> bool: - return False + if self.moe_kernel is None: + if hasattr(self, "experts_cls"): + return self.experts_cls.is_monolithic() + else: + return False + return self.moe_kernel.is_monolithic def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py index 187464c..0065c11 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py @@ -13,8 +13,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( - FusedMoEModularKernel, - FusedMoEPrepareAndFinalize, + FusedMoEKernel, + FusedMoEPrepareAndFinalizeModular, ) logger = init_logger(__name__) @@ -26,15 +26,15 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): # --8<-- [end:modular_fused_moe] def __init__( - self, old_quant_method: FusedMoEMethodBase, experts: FusedMoEModularKernel + self, old_quant_method: FusedMoEMethodBase, moe_kernel: FusedMoEKernel ): super().__init__(old_quant_method.moe) self.moe_quant_config = old_quant_method.moe_quant_config - self.moe_mk = experts + self.moe_kernel = moe_kernel self.disable_expert_map = getattr( old_quant_method, 
"disable_expert_map", - not self.moe_mk.supports_expert_map(), + not self.moe_kernel.supports_expert_map(), ) self.old_quant_method = old_quant_method logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) @@ -43,13 +43,13 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): def make( moe_layer: torch.nn.Module, old_quant_method: FusedMoEMethodBase, - prepare_finalize: FusedMoEPrepareAndFinalize, + prepare_finalize: FusedMoEPrepareAndFinalizeModular, shared_experts: torch.nn.Module | None, inplace: bool = False, ) -> "FusedMoEModularMethod": return FusedMoEModularMethod( old_quant_method, - FusedMoEModularKernel( + FusedMoEKernel( prepare_finalize, old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), shared_experts, @@ -90,8 +90,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 5617156..8d6f716 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -6,6 +6,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops +from vllm._aiter_ops import rocm_aiter_ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( @@ -178,7 +179,40 @@ def triton_kernel_moe_forward( apply_router_weight_on_input: bool = False, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, + unpadded_N_w1=None, + unpadded_K_w1=None, + unpadded_N_w2=None, + unpadded_K_w2=None, ) -> torch.Tensor: + if ( + quant_config is not None + and quant_config.use_mxfp4_w4a8 + and rocm_aiter_ops.is_enabled() + ): + from aiter.ops.triton.moe_routing.routing import routing as aiter_routing + + routing_data, gather_idx, scatter_idx = aiter_routing( + gating_output, topk, sm_first=not renormalize + ) + return triton_kernel_fused_mxfp4_w4a8_experts( + None, + hidden_states, + w1, + w2, + routing_data, + gather_idx, + scatter_idx, + activation=activation.value, + quant_config=quant_config, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + unpadded_N_w1=unpadded_N_w1, + unpadded_K_w1=unpadded_K_w1, + unpadded_N_w2=unpadded_N_w2, + unpadded_K_w2=unpadded_K_w2, + ) + if expert_map is not None: # With expert parallelism, legacy_routing produces routing data # using global expert IDs which don't correspond to local weight @@ -210,6 +244,9 @@ def triton_kernel_moe_forward( effective_global_num_experts = global_num_experts output = torch.empty_like(hidden_states) + effective_quant_config = ( + quant_config if quant_config is not None else FUSED_MOE_UNQUANTIZED_CONFIG + ) return triton_kernel_fused_experts( output, @@ -221,7 +258,7 @@ def triton_kernel_moe_forward( scatter_idx, topk=topk, activation=activation, - quant_config=quant_config, + quant_config=effective_quant_config, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=effective_global_num_experts, 
expert_map=effective_expert_map, @@ -252,8 +289,7 @@ def triton_kernel_fused_experts( assert activation == MoEActivation.SWIGLUOAI, ( "Only SWIGLUOAI activation is supported" ) - if quant_config is None: - quant_config = FUSED_MOE_UNQUANTIZED_CONFIG + assert quant_config is not None # type check, uint8 means mxfp4 assert hidden_states.dtype == torch.bfloat16 @@ -330,6 +366,98 @@ def triton_kernel_fused_experts( return output_tensor +# This is a triton implementation of the fused_experts function +def triton_kernel_fused_mxfp4_w4a8_experts( + output_tensor: torch.Tensor, + hidden_states: torch.Tensor, + w1, # Tensor or triton_kernels.Tensor + w2, # Tensor or triton_kernels.Tensor + routing_data, # RoutingData + gather_indx, # GatherIndx + scatter_indx, # ScatterIndx + activation: str = "silu", + quant_config: FusedMoEQuantConfig | None = None, + swiglu_alpha: float = 1.702, + swiglu_limit: float = 7.0, + apply_router_weight_on_input: bool = False, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + a1q_scale: torch.Tensor | None = None, + unpadded_N_w1=None, + unpadded_K_w1=None, + unpadded_N_w2=None, + unpadded_K_w2=None, +) -> torch.Tensor: + assert quant_config is not None + # type check, uint8 means mxfp4 + assert hidden_states.dtype == torch.bfloat16 + assert quant_config.w1_bias is None or quant_config.w1_bias.dtype == torch.float32 + assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32 + + # Shape check, only check non-mxfp4 + assert hidden_states.shape[-1] == w1.shape[-2] + assert w2.shape[-1] == w1.shape[1] + + E, _, N = w1.shape + + if global_num_experts == -1: + global_num_experts = E + + gammas = routing_data.gate_scal if routing_data else None + + from aiter.ops.triton.moe_op_gemm_a8w4 import moe_gemm_a8w4 + from aiter.ops.triton.quant_moe import downcast_to_static_fp8 + + assert quant_config.w1_precision is not None, ( + "w1_precision in quant config can't be None" + ) + assert quant_config.w2_precision is not None, ( + "w2_precision in quant config can't be None" + ) + + hidden_states = downcast_to_static_fp8( + hidden_states, quant_config.w1_precision.flex_ctx.lhs_data.scale + ) + + intermediate_cache1 = moe_gemm_a8w4( + hidden_states, + w1.storage.data, + None, + quant_config.w1_precision.weight_scale.storage.data, + quant_config.w1_precision.flex_ctx.lhs_data.scale, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + quant_config.w1_bias, + routing_data, + gather_indx=gather_indx, + gammas=gammas if apply_router_weight_on_input else None, + swizzle_mx_scale="CDNA4_SCALE", + out_dtype=torch.float8_e4m3fn, + apply_swiglu=True, + alpha=swiglu_alpha, + limit=swiglu_limit, + unpadded_N=unpadded_N_w1, + unpadded_K=unpadded_K_w1, + ) + + intermediate_cache3 = moe_gemm_a8w4( + intermediate_cache1, + w2.storage.data, + None, + quant_config.w2_precision.weight_scale.storage.data, + quant_config.w2_precision.flex_ctx.lhs_data.scale, + None, + quant_config.w2_bias, + routing_data, + scatter_indx=scatter_indx, + gammas=None if apply_router_weight_on_input else gammas, + swizzle_mx_scale="CDNA4_SCALE", + unpadded_N=unpadded_N_w2, + unpadded_K=unpadded_K_w2, + ) + + return intermediate_cache3 + + def make_routing_data( topk_ids: torch.Tensor, topk_weights: torch.Tensor, @@ -383,7 +511,7 @@ def make_routing_data( return routing_data, gather_indx, scatter_indx -class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute): +class BaseOAITritonExperts(mk.FusedMoEExpertsModular): @staticmethod def _supports_current_device() -> bool: 
raise NotImplementedError( @@ -520,6 +648,9 @@ class OAITritonExperts(BaseOAITritonExperts): expert_tokens_meta: mk.ExpertTokensMetadata | None, apply_router_weight_on_input: bool, ): + if self.quant_config is None: + self.quant_config: FusedMoEQuantConfig = FUSED_MOE_UNQUANTIZED_CONFIG + if expert_map is not None: topk_ids = expert_map[topk_ids] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 6ac1c19..69c9c17 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -5,8 +5,8 @@ from collections.abc import Callable, Iterable from enum import Enum from typing import Literal, cast, get_args, overload +import ast, re import torch -import torch.nn.functional as F from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -54,10 +54,14 @@ from vllm.model_executor.layers.quantization.base_config import ( ) from vllm.platforms import current_platform from vllm.utils.math_utils import round_up +from vllm.model_executor.layers.utils import ( + parse_opt_exclude_layers, + weight_quant_l1, + weight_quant_l2, +) logger = init_logger(__name__) - class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" @@ -333,6 +337,7 @@ class FusedMoE(CustomOp): gate: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None, routed_input_transform: torch.nn.Module | None = None, + fused_shared_output: bool = False, ): super().__init__() @@ -483,6 +488,8 @@ class FusedMoE(CustomOp): (expert_mask == 0) | (expert_mask == 1) ), "Aiter Fused MoE kernel only supports expert_map with 0 and 1s." + self.hidden_size = hidden_size + self.num_experts = num_experts assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results @@ -526,16 +533,18 @@ class FusedMoE(CustomOp): # Round up hidden size before creating moe_config. # This way moe_config is created with the correct hidden_size from the start. 
+ unpadded_hidden_size = hidden_size + self.model_type = ( + self.vllm_config.model_config.hf_config.model_type + if self.vllm_config.model_config is not None + else None + ) hidden_size = maybe_roundup_hidden_size( hidden_size=hidden_size, act_dtype=moe_in_dtype, moe_parallel_config=self.moe_parallel_config, is_lora_enabled=vllm_config.lora_config is not None, - model_type=( - self.vllm_config.model_config.hf_config.model_type - if self.vllm_config.model_config is not None - else None - ), + model_type=self.model_type, is_mxfp4_quant=( quant_config is not None and quant_config.is_mxfp4_quant(prefix, self) ), @@ -581,14 +590,27 @@ class FusedMoE(CustomOp): """ quant_method = None if self.quant_config is not None: + self.opt_level = 0 quant_method = self.quant_config.get_quant_method(self, prefix) if quant_method is None: - quant_method = UnquantizedFusedMoEMethod(self.moe_config) + from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( + CompressedTensorsL1OptMoEMethod, CompressedTensorsL2OptMoEMethod) + if self.opt_level == 1: + quant_method = CompressedTensorsL1OptMoEMethod(self.moe_config) + elif self.opt_level == 2: + quant_method = CompressedTensorsL2OptMoEMethod(self.moe_config) + else: + quant_method = UnquantizedFusedMoEMethod(self.moe_config) assert isinstance(quant_method, FusedMoEMethodBase) return quant_method # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. + self.opt_level = envs.VLLM_MOE_OPT_LEVEL + if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, prefix): + self.opt_flag = False + logger.info(f"Excluding layer {prefix} from optimization") + self.quant_method: FusedMoEMethodBase = _get_quant_method() if not self.moe_config.is_act_and_mul and not current_platform.is_cuda_alike(): @@ -611,6 +633,7 @@ class FusedMoE(CustomOp): moe_quant_params = { "num_experts": self.local_num_experts, "hidden_size": hidden_size, + "unpadded_hidden_size": unpadded_hidden_size, "intermediate_size_per_partition": self.intermediate_size_per_partition, "params_dtype": params_dtype, "weight_loader": self.weight_loader, @@ -625,6 +648,7 @@ class FusedMoE(CustomOp): moe_quant_params["intermediate_size_full"] = intermediate_size self.quant_method.create_weights(layer=self, **moe_quant_params) + self.base_quant_method = self.quant_method # Disable shared expert overlap if: # - we are using eplb with non-default backend, because of correctness issues @@ -638,7 +662,10 @@ class FusedMoE(CustomOp): ) and self._shared_experts is not None ) - + if fused_shared_output: + assert self.use_ep == False, "Fused shared output is only supported when EP is disabled." + assert shared_experts is not None, "Shared experts must be provided when fused_shared_output is True." + self.fused_shared_output = fused_shared_output self.runner = self._init_runner() def _init_runner(self): @@ -655,6 +682,7 @@ class FusedMoE(CustomOp): quant_method=self.quant_method, reduce_results=self.reduce_results, enable_dbo=self.vllm_config.parallel_config.enable_dbo, + fused_shared_output=self.fused_shared_output, ) # TODO(bnell): This method is provided as a hook so vllm/lora/layers/fused_moe.py @@ -681,7 +709,7 @@ class FusedMoE(CustomOp): # routing_tables only needed for round-robin expert placement with # DeepEP all2all backend. 
routing_tables = self._maybe_init_expert_routing_tables() - prepare_finalize = self.quant_method.maybe_make_prepare_finalize( + prepare_finalize = self.base_quant_method.maybe_make_prepare_finalize( routing_tables=routing_tables ) if prepare_finalize is not None: @@ -691,7 +719,7 @@ class FusedMoE(CustomOp): self._replace_quant_method( FusedMoEModularMethod.make( self, - self.quant_method, + self.base_quant_method, prepare_finalize, self.shared_experts, inplace=not self.moe_config.disable_inplace, @@ -959,11 +987,7 @@ class FusedMoE(CustomOp): else: assert shard_id == "w3" expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - try: - expert_data.copy_(loaded_weight) - except Exception as e: - print(expert_data.shape, expert_data.dtype, loaded_weight.shape, loaded_weight.dtype) - raise e + expert_data.copy_(loaded_weight) def _load_w2( self, @@ -976,7 +1000,7 @@ class FusedMoE(CustomOp): # Index the loaded weight for tp sharding. # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. - shard_size = expert_data.shape[shard_dim] + shard_size = loaded_weight.shape[shard_dim] // self.tp_size # Only narrow if the loaded_weight is not a scalar (0-dim tensor) # and we're not loading the full weight if not load_full and loaded_weight.ndim > 0: @@ -984,7 +1008,55 @@ class FusedMoE(CustomOp): shard_dim, shard_size * tp_rank, shard_size ) # w2, down_proj: Load into only logical weight of w2. - expert_data.copy_(loaded_weight) + expert_data.narrow(shard_dim, 0, shard_size).copy_(loaded_weight) + + def _load_model_opt_weight_or_group_weight_scale(self, + shard_dim: int, + shard_dim_scale: int, + expert_data: torch.Tensor, + scale_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.Tensor, + tp_rank: int, + opt_level: int, + load_full_w2: bool = False): + """ + Load grouped weight scales for group quantization or model weights + :param shard_dim: dimension to shard + :param expert_data: parameter for a particular expert + :param shard_id: either w1, w2, or w3 + :param loaded_weight: checkpoint weight to load into the param + :param tp_rank: tensor parallel rank + :param load_full_w2: whether or not the w2 loaded should be sharded. 
+ """ + + assert opt_level in [1, 2] + if opt_level == 1: + weight, scale = weight_quant_l1(loaded_weight) + else: + weight, scale = weight_quant_l2(loaded_weight) + scale = scale.view(1, -1) + + if shard_id == "w2": + # In the case where we have actorder/g_idx, we do not partition the + # w2 scales, as indicated by `load_full` argument, for all tp cases + self._load_w2(shard_dim=shard_dim, + loaded_weight=weight, + expert_data=expert_data, + tp_rank=tp_rank, + load_full=load_full_w2) + scale_data.copy_(scale) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=weight, + expert_data=expert_data, + tp_rank=tp_rank) + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim_scale, + loaded_weight=scale, + expert_data=scale_data, + tp_rank=tp_rank) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int @@ -1147,7 +1219,6 @@ class FusedMoE(CustomOp): shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: shard_dim = int(not shard_dim) - shard_dim_force = getattr(param, "shard_dim", None) shard_dim = shard_dim_force if shard_dim_force is not None else shard_dim @@ -1309,13 +1380,28 @@ class FusedMoE(CustomOp): # Case model weights if "weight" in weight_name: - self._load_model_weight_or_group_weight_scale( - shard_id=shard_id, - shard_dim=shard_dim, - loaded_weight=loaded_weight, - expert_data=expert_data, - tp_rank=self.tp_rank, - ) + if self.opt_level != 0: + scale_name = weight_name.split('.')[-1] + "_scale" + params_dict = dict(self.named_parameters()) + scale_param = params_dict[scale_name] + shard_dim_scale = getattr(scale_param, "shard_dim", None) + scale_expert_data = scale_param.data if full_load else scale_param.data[expert_id] + self._load_model_opt_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + shard_dim_scale=shard_dim_scale, + loaded_weight=loaded_weight, + expert_data=expert_data, + scale_data=scale_expert_data, + opt_level=self.opt_level, + tp_rank=self.tp_rank) + else: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=self.tp_rank) return True if return_success else None return False if return_success else None diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 9a5f4a2..7b49282 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig, + RoutingMethodType, ) from vllm.model_executor.layers.fused_moe.utils import ( _resize_cache, @@ -56,25 +57,25 @@ logger = init_logger(__name__) # MoE kernel implementations. # # The following main classes are defined: -# * FusedMoEPrepareAndFinalize - an abstract base class for preparation of MoE +# * FusedMoEPrepareAndFinalizeModular - an abstract base class for preparation of MoE # inputs (e.g. quantization, distribution) and finalization of Moe outputs. # The prepare method must take care of any needed quantization and the -# finalize method, informed by the FusedMoEPermuteExpertsUnpermute method, +# finalize method, informed by the FusedMoEExpertsModular method, # may apply weights and/or do the final reduction of the output. 
-# * FusedMoEPermuteExpertsUnpermute - an abstract base class for the main fused +# * FusedMoEExpertsModular - an abstract base class for the main fused # MoE operation, i.e matmul + act_mul + optionally quant + matmul. -# Some FusedMoEPermuteExpertsUnpermute implementations may choose to do +# Some FusedMoEExpertsModular implementations may choose to do # the weight application and/or reduction. The class communicates this # to [Finalize] via a TopKWeightAndReduce object. # * FusedMoEModularKernel - an interface class that combines a -# FusedMoEPrepareAndFinalize and a FusedMoEPermuteExpertsUnpermute to +# FusedMoEPrepareAndFinalizeModular and a FusedMoEExpertsModular to # provide the standard fused MoE kernel interface. # * TopKWeightAndReduce - A TopKWeightAndReduce implementation chosen -# by the FusedMoEPermuteExpertsUnpermute implementation that is passed +# by the FusedMoEExpertsModular implementation that is passed # on to [Finalize]. # # [Quantize-Prepare] and [Finalize] functionality are bundled into a single -# class `FusedMoEPrepareAndFinalize` since they could use collective +# class `FusedMoEPrepareAndFinalizeModular` since they could use collective # communication mechanisms that need to be consistent. # @@ -155,25 +156,96 @@ PrepareResultType = tuple[ torch.Tensor | None, ] +# +# PrepareResultType is a tuple of: +# - quantized + dispatched a. +# - quantized + dispatched a1_scales. +# - dispatched router logits. +# +# See `prepare_monolithic` method below. +# +PrepareMonolithicResultType = tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor, +] + ReceiverType = Callable[[], PrepareResultType] +################################################################################ +# Prepare/Finalize +################################################################################ + -# TODO: pass FusedMoEParallelConfig in as ctor parameter? class FusedMoEPrepareAndFinalize(ABC): """ An abstract base class for the [Quantize-Prepare] and [Finalize] steps described above. + + There are two variants of this class: + * FusedMoEPrepareAndFinalizeModular - this operates on topk ids and weights + * FusedMoEPrepareAndFinalizeMonolithic - this operates on router_logits """ - def post_init_setup(self, fused_experts: "FusedMoEPermuteExpertsUnpermute"): + def post_init_setup(self, fused_experts: "FusedMoEExperts"): """ - Initialize FusedMoEPrepareAndFinalize settings that depend on - FusedMoEPermuteExpertsUnpermute experts object. - The FusedMoEPrepareAndFinalize implementations that have such + Initialize FusedMoEPrepareAndFinalizeModular settings that depend on + FusedMoEExpertsModular experts object. + The FusedMoEPrepareAndFinalizeModular implementations that have such dependencies may choose to override this function. """ return + @property + @abstractmethod + def activation_format(self) -> FusedMoEActivationFormat: + """ + A property indicating the output format of the activations for the + 'prepare' method. + """ + raise NotImplementedError + + @abstractmethod + def topk_indices_dtype(self) -> torch.dtype | None: + """ + The PrepareFinalize All2All implementations generally constrain the + dtype of the topk_ids they support. This function returns the + required topk indices dtype so it can be respected. + Return None if there are no such restrictions. + """ + raise NotImplementedError + + @abstractmethod + def max_num_tokens_per_rank(self) -> int | None: + """ + Some PrepareFinalize All2All implementations are batched. Meaning, + they can process only a set of tokens at a time. 
This + function returns the batch size i.e. the maximum number of tokens + the implementation can process at a time. + Return None if there are no such restrictions. + """ + raise NotImplementedError + + @abstractmethod + def num_dispatchers(self) -> int: + raise NotImplementedError + + @abstractmethod + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of finalize is reduced across all + ranks. + """ + raise NotImplementedError + + +# TODO: pass FusedMoEParallelConfig in as ctor parameter? +class FusedMoEPrepareAndFinalizeModular(FusedMoEPrepareAndFinalize): + """ + An abstract base class for the [Quantize-Prepare] and [Finalize] steps + described above for the Modular case. + """ + @abstractmethod def prepare( self, @@ -198,7 +270,7 @@ class FusedMoEPrepareAndFinalize(ABC): activations, before quantization + dispatching. - quant_config: Quantization info provided by the fused experts. - defer_input_quant: Runtime parameter indicating whether or not to - defer input quantization to the FusedMoEPermuteExpertsUnpermute + defer input quantization to the FusedMoEExpertsModular in cases where the compute kernel expects unquantized inputs Returns a tuple of: @@ -245,7 +317,7 @@ class FusedMoEPrepareAndFinalize(ABC): - apply_router_weight_on_input: When True, apply the weights to the activations, before quantization + dispatching. - defer_input_quant: Runtime parameter indicating whether or not to - defer input quantization to the FusedMoEPermuteExpertsUnpermute + defer input quantization to the FusedMoEExpertsModular in cases where the compute kernel expects unquantized inputs Returns a callback or a hook callback pair that when invoked waits for @@ -338,56 +410,58 @@ class FusedMoEPrepareAndFinalize(ABC): """ raise NotImplementedError - @property + +class FusedMoEPrepareAndFinalizeMonolithic(FusedMoEPrepareAndFinalize): + """ + An abstract base class for the [Quantize-Prepare] and [Finalize] steps + described above for the monolithic case. + """ + @abstractmethod - def activation_format(self) -> FusedMoEActivationFormat: + def prepare( + self, + a1: torch.Tensor, + router_logits: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> PrepareMonolithicResultType: """ - A property indicating the output format of the activations for the - 'prepare' method. + Optional method for subclasses compatible with + FusedMoEExpertsMonolithic kernels. + + Perform any quantization (and/or) dispatching needed for this kernel. + - a1: The (unquantized) input to the MoE layer. + - quant_config: Quantization info provided by the fused experts. + - defer_input_quant: Runtime parameter indicating whether or not to + defer input quantization to the FusedMoEExpertsMonolithic + + Returns a tuple of: + - quantized + dispatched a. + - Optional quantized + dispatched a1_scales. + - dispatched router logits. """ raise NotImplementedError @abstractmethod - def topk_indices_dtype(self) -> torch.dtype | None: + def finalize(self, fused_expert_output: torch.Tensor) -> torch.Tensor: """ - The PrepareFinalize All2All implementations generally constrain the - dtype of the topk_ids they support. This function returns the - required topk indices dtype so it can be respected. - Return None if there are no such restrictions. + Optional method for subclasses compatible with + FusedMoEExpertsMonolithic kernels. + + Perform any combine plus apply weights and perform a reduction on the + fused experts output. 
+ - fused_expert_output: The unweighted, unreduced output of the fused + experts, it will have (M, topk, K) shape. """ raise NotImplementedError - @abstractmethod - def max_num_tokens_per_rank(self) -> int | None: - """ - Some PrepareFinalize All2All implementations are batched. Meaning, - they can process only as set of tokens at a time. This - function returns the batch size i.e the maximum number of tokens - the implementation can process at a time. - Return None if there are no such restrictions. - """ - raise NotImplementedError - @abstractmethod - def num_dispatchers(self) -> int: - raise NotImplementedError - - @abstractmethod - def output_is_reduced(self) -> bool: - """ - Indicates whether or not the output of finalize is reduced across all - ranks. - """ - raise NotImplementedError +################################################################################ +# Experts +################################################################################ # TODO: add supported activations method (return string) -class FusedMoEPermuteExpertsUnpermute(ABC): - """ - An abstract base class for the [Permute-Experts-Unpermute] step described - above. - """ - +class FusedMoEExperts(ABC): def __init__( self, moe_config: FusedMoEConfig, @@ -419,6 +493,10 @@ class FusedMoEPermuteExpertsUnpermute(ABC): self.max_num_tokens = max_num_tokens self.num_dispatchers = num_dispatchers + @staticmethod + def is_monolithic() -> bool: + raise NotImplementedError("Implemented by subclasses.") + @property def expects_unquantized_inputs(self) -> bool: """ @@ -439,49 +517,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ raise NotImplementedError - def moe_problem_size( - self, - a1: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - ) -> tuple[int, int, int, int, int]: - """ - Extract the MoE problem size from the given tensor arguments: - - a: The hidden states, input to the MoE layer. - - w1: The first set of expert weights. - - w2: The second set of expert weights. - - topk_ids: The topk ids. - - Note: extracting the problem shape from the weight and activation - tensors is not obvious. It needs to be done this way specifically - due to subtle issues with particular kernels, e.g. the int4 kernels - divide the trailing dimension by two, so it's not "correct" to - extract N or K from the trailing dimension of w1 or w2. Similarly, - some kernels transpose the weights, so this needs to be kept in mind. - - Note: This implementation covers most cases. However, if experts - require a specialized implementation, like MarlinExperts, they are free - to override this function. - """ - assert w1.dim() == 3 and w2.dim() == 3 - E, N, _ = w1.size() - K = a1.size(-1) - - if a1.dim() == 2: - # Make sure we are using the correct a1 (pre-permute). - assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" - M = a1.size(0) - else: - assert a1.dim() == 3 - assert a1.size(0) == E, f"{a1.size(0)} == {E}" - M = a1.size(1) # This is max_num_tokens - - assert topk_ids.dim() == 2 - topk = topk_ids.size(1) - - return E, M, N, K, topk - # # Various helpers for registering support for various features. # Used by the oracle to select a particular kernel for a deployment. 
@@ -489,7 +524,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): @staticmethod def is_supported_config( - cls: type["FusedMoEPermuteExpertsUnpermute"], + cls: type["FusedMoEExperts"], moe_config: FusedMoEConfig, weight_key: QuantKey | None, activation_key: QuantKey | None, @@ -512,6 +547,21 @@ class FusedMoEPermuteExpertsUnpermute(ABC): return False, _make_reason( f"parallel config {moe_config.moe_parallel_config}" ) + elif not cls._supports_routing_method( + moe_config.routing_method, weight_key, activation_key + ): + return False, _make_reason(f"routing method {moe_config.routing_method}") + elif not cls._supports_router_logits_dtype( + moe_config.router_logits_dtype, + moe_config.routing_method, + ): + return False, _make_reason( + f"router logits dtype {moe_config.router_logits_dtype}" + ) + elif not cls._supports_shape(moe_config.hidden_dim): + return False, _make_reason( + f"{moe_config.hidden_dim} hidden dim is not supported" + ) elif activation_format != cls.activation_format(): return False, _make_reason(f"{activation_format.value} activation format") return True, None @@ -554,10 +604,48 @@ class FusedMoEPermuteExpertsUnpermute(ABC): @staticmethod @abstractmethod def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: """ - Whether the kernel supports deployment in expert parallel. + Whether the kernel supports deployment in a particular parallel config. + + Can be overridden if a kernel does not support EP, SP or some other + configuration. """ raise NotImplementedError + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """ + Whether the kernel supports a routing method (e.g. GroupedTopK). + + Can be overridden by monolithic kernels that execute the router + in addition to the experts if certain routers are not supported. + """ + return True + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + Whether a kernel supports a particular dtype for router logits input. + + Can be overridden by monolithic kernels that execute the router + in addition to the experts if certain dtypes are not supported. + """ + return True + + @staticmethod + def _supports_shape(hidden_dim: int) -> bool: + """ + Whether a kernel supports a particular shape. Can be overridden if a kernel + has specific shape requirements. + """ + return True + # # Various helpers for accessing quantization parameters from the # quant_config. @@ -654,6 +742,65 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ return False + def enable_chunking(self): + return ( + envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking() + ) + + +class FusedMoEExpertsModular(FusedMoEExperts): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above. + """ + + @staticmethod + def is_monolithic() -> bool: + return False + + def moe_problem_size( + self, + a1: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + ) -> tuple[int, int, int, int, int]: + """ + Extract the MoE problem size from the given tensor arguments: + - a: The hidden states, input to the MoE layer. + - w1: The first set of expert weights. + - w2: The second set of expert weights. + - topk_ids: The topk ids. + + Note: extracting the problem shape from the weight and activation + tensors is not obvious. 
It needs to be done this way specifically + due to subtle issues with particular kernels, e.g. the int4 kernels + divide the trailing dimension by two, so it's not "correct" to + extract N or K from the trailing dimension of w1 or w2. Similarly, + some kernels transpose the weights, so this needs to be kept in mind. + + Note: This implementation covers most cases. However, if experts + require a specialized implementation, like MarlinExperts, they are free + to override this function. + """ + assert w1.dim() == 3 and w2.dim() == 3 + E, N, _ = w1.size() + K = a1.size(-1) + + if a1.dim() == 2: + # Make sure we are using the correct a1 (pre-permute). + assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}" + M = a1.size(0) + else: + assert a1.dim() == 3 + assert a1.size(0) == E, f"{a1.size(0)} == {E}" + M = a1.size(1) # This is max_num_tokens + + assert topk_ids.dim() == 2 + topk = topk_ids.size(1) + + return E, M, N, K, topk + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: """ Workspace type: The dtype to use for the workspace tensors. @@ -726,11 +873,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ) -> None: apply_moe_activation(activation, output, input) - def enable_chunking(self): - return ( - envs.VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING and self.supports_chunking() - ) - + @abstractmethod def finalize_weight_and_reduce_impl(self) -> TopKWeightAndReduce: raise NotImplementedError @@ -791,6 +934,67 @@ class FusedMoEPermuteExpertsUnpermute(ABC): raise NotImplementedError + +class FusedMoEExpertsMonolithic(FusedMoEExperts): + """ + An abstract base class for the [Permute-Experts-Unpermute] step described + above, but with the monolithic interface (accepts router logits + rather than topk ids and weights). + """ + + @staticmethod + def _supports_routing_method( + routing_method: RoutingMethodType, + weight_key: QuantKey | None, + activation_key: QuantKey | None, + ) -> bool: + """ + Whether the kernel supports a routing method (e.g. GroupedTopK). + + Monolithic kernels should explicitly opt-in to support. + """ + raise NotImplementedError + + @staticmethod + def _supports_router_logits_dtype( + router_logits_dtype: torch.dtype | None, + routing_method: RoutingMethodType, + ) -> bool: + """ + Whether the kernel supports a dtype for router logits. + + Monolithic kernels should opt-in to support. + """ + raise NotImplementedError + + @staticmethod + def is_monolithic() -> bool: + return True + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + """ + Same as FusedMoEExpertsModular.apply(), except it uses router_logits + as opposed to the topk_ids and topk_weights. This is useful for kernels + with fused router and fused_experts (e.g. FLASHINFER_TRTLLM). 
+ """ + raise NotImplementedError + + def _slice_scales( scales: torch.Tensor | None, start: int, end: int ) -> torch.Tensor | None: @@ -802,75 +1006,32 @@ def _slice_scales( return None +################################################################################ +# Kernel +################################################################################ + + @final -class FusedMoEModularKernel(torch.nn.Module): - """ - This class combines a FusedMoEPrepareAndFinalize instance and - a FusedMoEPermuteExpertsUnpermute to provide an interface that - is compatible with the `fused_experts` function in fused_moe.py. - - It takes care of managing any required scratch space. - - Note: Instances of this class should only be used for a single model - layer due to any layer specific state that may be used by the component - objects. - """ - +class FusedMoEKernelModularImpl: def __init__( self, - prepare_finalize: FusedMoEPrepareAndFinalize, - fused_experts: FusedMoEPermuteExpertsUnpermute, + prepare_finalize: FusedMoEPrepareAndFinalizeModular, + fused_experts: FusedMoEExpertsModular, shared_experts: torch.nn.Module | None = None, moe_parallel_config: FusedMoEParallelConfig | None = None, inplace: bool = False, ): - super().__init__() self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts self.shared_experts = shared_experts + self.moe_parallel_config = moe_parallel_config self.inplace = inplace - - # prefer an explicit FusedMoEParallelConfig when available (from - # FusedMoE layers / tests). - # if not provided, assume this kernel is - # running in a non-DP+EP context - self.moe_parallel_config: FusedMoEParallelConfig | None = moe_parallel_config self.is_dp_ep = ( moe_parallel_config is not None and moe_parallel_config.dp_size > 1 and moe_parallel_config.use_ep ) - self._post_init_setup() - assert ( - prepare_finalize.activation_format == fused_experts.activation_format() - ), ( - f"{prepare_finalize.__class__.__name__}." - f"{prepare_finalize.activation_format} == " - f"{fused_experts.__class__.__name__}." - f"{fused_experts.activation_format()}" - ) - - def _post_init_setup(self): - """ - Resolve any leftover setup dependencies between self.prepare_finalize - and self.fused_experts here. - """ - self.prepare_finalize.post_init_setup(self.fused_experts) - - def supports_expert_map(self) -> bool: - """ - A flag indicating whether or not this class supports expert maps. - """ - return self.fused_experts.supports_expert_map() - - def output_is_reduced(self) -> bool: - """ - Indicates whether or not the output of fused MoE kernel - is reduced across all ranks. - """ - return self.prepare_finalize.output_is_reduced() - def _chunk_info(self, M: int) -> tuple[int, int]: """ Compute number of chunks and chunk size for given M. @@ -919,7 +1080,7 @@ class FusedMoEModularKernel(torch.nn.Module): workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) # Force worst-case allocation in profiling run for - # "mk.FusedMoEModularKernel.Standard" formats where this is only bounded + # "mk.FusedMoEKernel.Standard" formats where this is only bounded # by `VLLM_FUSED_MOE_CHUNK_SIZE` and may not be seen during profiling with # DP+EP due to the random token routing. is_profile_run = ( @@ -1172,9 +1333,9 @@ class FusedMoEModularKernel(torch.nn.Module): # This happens when none of the tokens from the all2all reach this # EP rank. Also, note that this is only relevant for CUDAGraph # incompatible all2all kernels like the DeepEP high-throughput - # kernels. 
CUDAGraph compatible all2all kernels like the pplx - # kernels and the DeepEP low-latency kernels are always batched - # and can never run into the tensor.numel() == 0 case. + # kernels. CUDAGraph compatible all2all kernels like the DeepEP + # low-latency kernels are always batched and can never run into + # the tensor.numel() == 0 case. if M_full == 0: assert num_chunks == 0 workspace13 = None @@ -1313,19 +1474,18 @@ class FusedMoEModularKernel(torch.nn.Module): assert shared_output is not None return shared_output, output - def forward( + def apply( self, hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, + topk_weights: torch.Tensor, activation: MoEActivation = MoEActivation.SILU, global_num_experts: int = -1, expert_map: torch.Tensor | None = None, apply_router_weight_on_input: bool = False, shared_experts_input: torch.Tensor | None = None, - **kwargs ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ This function computes a Mixture of Experts (MoE) layer using two sets @@ -1335,8 +1495,7 @@ class FusedMoEModularKernel(torch.nn.Module): - hidden_states: (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - - topk_weights (torch.Tensor): The topk weights applied at the end of - the layer. + - topk_weights (torch.Tensor): The topk weights applied at the end of the layer. - topk_ids (torch.Tensor): A map of row to expert id. - activation (MoEActivation): The activation function to apply after the first MoE layer. @@ -1355,23 +1514,6 @@ class FusedMoEModularKernel(torch.nn.Module): Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - from .fused_moe import fused_experts as fused_experts_kernel - - result = fused_experts_kernel( - hidden_states=hidden_states, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - quant_config=kwargs.get("quant_config", None), - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) - - return result if self.inplace: assert self.shared_experts is None assert not disable_inplace() @@ -1417,3 +1559,206 @@ class FusedMoEModularKernel(torch.nn.Module): apply_router_weight_on_input, shared_experts_input=shared_experts_input, ) + + +@final +class FusedMoEKernelMonolithicImpl: + def __init__( + self, + prepare_finalize: FusedMoEPrepareAndFinalizeMonolithic, + fused_experts: FusedMoEExpertsMonolithic, + ): + self.prepare_finalize = prepare_finalize + self.fused_experts = fused_experts + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + """ + Same as forward(), except uses router_logits as opposed + to the topk_ids and topk_weights. This is used for kernels + that have fused router + experts (e.g. FLASHINFER_TRTLLM). + """ + + # TODO(rob): add inplace support. 
+ a1q, a1q_scale, router_logits = self.prepare_finalize.prepare( + hidden_states, + router_logits=router_logits, + quant_config=self.fused_experts.quant_config, + defer_input_quant=self.fused_experts.expects_unquantized_inputs, + ) + + fused_out = self.fused_experts.apply( + hidden_states=a1q, + w1=w1, + w2=w2, + router_logits=router_logits, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + a1q_scale=a1q_scale, + # grouped topk + fused topk bias parameters + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + topk_group=topk_group, + ) + + output = self.prepare_finalize.finalize(fused_out) + + return output + + +@final +class FusedMoEKernel: + def __init__( + self, + prepare_finalize: FusedMoEPrepareAndFinalize, + fused_experts: FusedMoEExperts, + shared_experts: torch.nn.Module | None = None, + moe_parallel_config: FusedMoEParallelConfig | None = None, + inplace: bool = False, + ): + super().__init__() + self.shared_experts = shared_experts # NOTE: check if we can remove + + # Initialize the implementation (monolithic or modular). + self.impl: FusedMoEKernelModularImpl | FusedMoEKernelMonolithicImpl + if isinstance( + prepare_finalize, FusedMoEPrepareAndFinalizeModular + ) and isinstance(fused_experts, FusedMoEExpertsModular): + self.impl = FusedMoEKernelModularImpl( + prepare_finalize, + fused_experts, + shared_experts, + moe_parallel_config, + inplace, + ) + + elif isinstance( + prepare_finalize, FusedMoEPrepareAndFinalizeMonolithic + ) and isinstance(fused_experts, FusedMoEExpertsMonolithic): + assert shared_experts is None + assert not inplace + self.impl = FusedMoEKernelMonolithicImpl( + prepare_finalize, + fused_experts, + ) + + else: + raise ValueError( + "prepare_finalize and fused_experts must both be either monolithic " + f"or non-monolithic but got {prepare_finalize.__class__.__name__} " + f"and {fused_experts.__class__.__name__}" + ) + + self._post_init_setup() + + @property + def is_monolithic(self) -> bool: + return isinstance(self.impl, FusedMoEKernelMonolithicImpl) + + @property + def prepare_finalize(self) -> FusedMoEPrepareAndFinalize: + return self.impl.prepare_finalize + + @property + def fused_experts(self) -> FusedMoEExperts: + return self.impl.fused_experts + + def _post_init_setup(self): + """ + Resolve any leftover setup dependencies between self.prepare_finalize + and self.fused_experts here. + """ + self.prepare_finalize.post_init_setup(self.impl.fused_experts) + assert ( + self.prepare_finalize.activation_format + == self.fused_experts.activation_format() + ) + + def supports_expert_map(self) -> bool: + """ + A flag indicating whether or not this class supports expert maps. + """ + return self.fused_experts.supports_expert_map() + + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of fused MoE kernel + is reduced across all ranks. 
+ """ + return self.prepare_finalize.output_is_reduced() + + def apply_monolithic( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + # grouped topk + fused topk bias parameters + num_expert_group: int | None = None, + e_score_correction_bias: torch.Tensor | None = None, + routed_scaling_factor: float | None = None, + topk_group: int | None = None, + ) -> torch.Tensor: + assert isinstance(self.impl, FusedMoEKernelMonolithicImpl) + return self.impl.apply( + hidden_states=hidden_states, + w1=w1, + w2=w2, + router_logits=router_logits, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + num_expert_group=num_expert_group, + e_score_correction_bias=e_score_correction_bias, + routed_scaling_factor=routed_scaling_factor, + topk_group=topk_group, + ) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: MoEActivation, + global_num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + shared_experts_input: torch.Tensor | None = None, + ) -> torch.Tensor: + assert isinstance(self.impl, FusedMoEKernelModularImpl) + return self.impl.apply( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py index dc0f32d..164605d 100644 --- a/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/mori_prepare_finalize.py @@ -12,7 +12,7 @@ from vllm.platforms import current_platform logger = init_logger(__name__) -class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): +class MoriPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): """ Prepare/Finalize using MoRI kernels. 
""" diff --git a/vllm/model_executor/layers/fused_moe/oracle/fp8.py b/vllm/model_executor/layers/fused_moe/oracle/fp8.py index 6f961df..0ed159b 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/fp8.py +++ b/vllm/model_executor/layers/fused_moe/oracle/fp8.py @@ -18,13 +18,9 @@ from vllm.model_executor.layers.fused_moe.config import ( fp8_w8a8_moe_quant_config, fp8_w8a16_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( - is_supported_config_trtllm_fp8, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, get_flashinfer_moe_backend, - make_fp8_moe_alpha_scales_for_fi, prepare_fp8_moe_layer_for_fi, ) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( @@ -103,9 +99,13 @@ def _get_priority_backends( def backend_to_kernel_cls( backend: Fp8MoeBackend, -) -> type[mk.FusedMoEPermuteExpertsUnpermute]: +) -> type[mk.FusedMoEExperts]: if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - raise NotImplementedError + from vllm.model_executor.layers.fused_moe.experts.trtllm_fp8_moe import ( # noqa: E501 + TrtLlmFp8Experts, + ) + + return TrtLlmFp8Experts elif backend == Fp8MoeBackend.FLASHINFER_CUTLASS: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( @@ -205,13 +205,11 @@ def select_fp8_moe_backend( weight_key: QuantKey | None, activation_key: QuantKey | None, allow_vllm_cutlass: bool = False, -) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: +) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts] | None]: """ Select the primary FP8 MoE backend Note: Shape-specific fallbacks may still occur at runtime. """ - k_cls: type[mk.FusedMoEPermuteExpertsUnpermute] | None = None - if config.is_lora_enabled: return Fp8MoeBackend.TRITON, backend_to_kernel_cls(Fp8MoeBackend.TRITON) @@ -252,7 +250,7 @@ def select_fp8_moe_backend( weight_key: QuantKey | None, activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, - ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: + ) -> tuple[Fp8MoeBackend, type[mk.FusedMoEExperts]]: k_cls = backend_to_kernel_cls(backend) supported, reason = k_cls.is_supported_config( k_cls, config, weight_key, activation_key, activation_format @@ -287,16 +285,6 @@ def select_fp8_moe_backend( "vLLM CUTLASS FP8 MoE backend is disabled for this configuration." ) - # Handle FLASHINFER_TRTLLM specially (no kernel class). - if requested_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - supported, reason = is_supported_config_trtllm_fp8( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(requested_backend)) - return requested_backend, None - raise ValueError(_make_log_unsupported(requested_backend, reason)) - return _return_or_raise( requested_backend, config, weight_key, activation_key, activation_format ) @@ -311,51 +299,32 @@ def select_fp8_moe_backend( elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): # If user is explicit about backend, validate it. 
fi_backend = get_flashinfer_moe_backend() - - if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: - backend = Fp8MoeBackend.FLASHINFER_TRTLLM - supported, reason = is_supported_config_trtllm_fp8( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, None - else: - raise ValueError(_make_log_unsupported(backend, reason)) - - elif fi_backend == FlashinferMoeBackend.CUTLASS: + if fi_backend == FlashinferMoeBackend.CUTLASS: backend = Fp8MoeBackend.FLASHINFER_CUTLASS - return _return_or_raise( - backend, config, weight_key, activation_key, activation_format - ) - + elif fi_backend == FlashinferMoeBackend.TENSORRT_LLM: + backend = Fp8MoeBackend.FLASHINFER_TRTLLM else: - assert fi_backend == FlashinferMoeBackend.CUTEDSL - raise ValueError("FlashInfer MaskedGEMM not supported for FP8") - + raise ValueError( + f"FlashInfer MOE backend {fi_backend} does not support FP8 MoE." + ) + k_cls = backend_to_kernel_cls(backend) + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) else: # If the user is not explicit about the backend, try both. for backend in [ Fp8MoeBackend.FLASHINFER_TRTLLM, Fp8MoeBackend.FLASHINFER_CUTLASS, ]: - if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - k_cls = None - supported, reason = is_supported_config_trtllm_fp8( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) if supported: logger.info_once(_make_log_backend(backend), scope="local") @@ -408,23 +377,14 @@ def select_fp8_moe_backend( # Select kernels in order of backend. for backend in AVAILABLE_BACKENDS: - if backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - k_cls = None - supported, reason = is_supported_config_trtllm_fp8( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, - config, - weight_key, - activation_key, - activation_format, - ) + k_cls = backend_to_kernel_cls(backend) + supported, reason = k_cls.is_supported_config( + k_cls, + config, + weight_key, + activation_key, + activation_format, + ) if supported: logger.info_once(_make_log_backend(backend), scope="local") @@ -510,7 +470,7 @@ def make_fp8_moe_quant_config( block_shape: list[int] | None = None, per_act_token_quant: bool = False, per_out_ch_quant: bool = False, -) -> FusedMoEQuantConfig | None: +) -> FusedMoEQuantConfig: """ Create FusedMoEQuantConfig for the specified FP8 Backend. The FusedMoEQuantConfig holds the scales that are used @@ -523,9 +483,6 @@ def make_fp8_moe_quant_config( In a future PR, we will have this function should be a method of the modular kernel itself. """ - # TRTLLM does not use Modular Kernel abstraction yet. - if fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None # MARLIN is mixed precision W8A16 config. if fp8_backend == Fp8MoeBackend.MARLIN: @@ -539,12 +496,6 @@ def make_fp8_moe_quant_config( # (alpha = w_scale * a_scale) and inverse a2 scale. 
if fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS and block_shape is None: assert a1_scale is not None and a2_scale is not None - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w1_scale, - a1_scale, - w2_scale, - a2_scale, - ) return fp8_w8a8_moe_quant_config( w1_scale=w1_scale, w2_scale=w2_scale, @@ -552,8 +503,8 @@ def make_fp8_moe_quant_config( a2_scale=a2_scale, a1_gscale=(1.0 / a1_scale), a2_gscale=(1.0 / a2_scale), - g1_alphas=g1_alphas, - g2_alphas=g2_alphas, + g1_alphas=(w1_scale * a1_scale).squeeze(), + g2_alphas=(w2_scale * a2_scale).squeeze(), ) # All other backends use normal config. return fp8_w8a8_moe_quant_config( @@ -570,17 +521,18 @@ def make_fp8_moe_quant_config( def make_fp8_moe_kernel( moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], + experts_cls: type[mk.FusedMoEExperts], fp8_backend: Fp8MoeBackend, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, shared_experts: torch.nn.Module | None = None, -) -> mk.FusedMoEModularKernel: +) -> mk.FusedMoEKernel: # Create Prepare/Finalize. prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, quant_config=moe_quant_config, routing_tables=routing_tables, allow_new_interface=True, + use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic), ) assert prepare_finalize is not None @@ -603,9 +555,9 @@ def make_fp8_moe_kernel( ) # NOTE(rob): we only want the mk to control the shared_expert - # if using all2all (for SBO). bnell is making this explict in + # if using all2all (for SBO). bnell is making this explicit in # the new MoE runner class. - kernel = mk.FusedMoEModularKernel( + kernel = mk.FusedMoEKernel( prepare_finalize, experts, shared_experts=( diff --git a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py index ee7db88..dd1a24d 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/nvfp4.py @@ -19,7 +19,6 @@ from vllm.model_executor.layers.fused_moe.config import ( nvfp4_w4a16_moe_quant_config, ) from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - is_supported_config_trtllm, prepare_nvfp4_moe_layer_for_fi_or_cutlass, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( @@ -67,39 +66,46 @@ def is_global_sf_supported_for_nvfp4_backend(backend: NvFp4MoeBackend) -> bool: def backend_to_kernel_cls( backend: NvFp4MoeBackend, -) -> type[mk.FusedMoEPermuteExpertsUnpermute]: +) -> list[type[mk.FusedMoEExperts]]: if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - raise NotImplementedError( - "FLASHINFER_TRTLLM doesn't support Modular Kernel Interface" + from vllm.model_executor.layers.fused_moe.experts.trtllm_nvfp4_moe import ( + TrtLlmNvFp4ExpertsModular, + TrtLlmNvFp4ExpertsMonolithic, ) + # NOTE: prefer Monolthic > Modular, so return Monolithic first. 
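+    # The monolithic variant is driven by the raw router logits (routing runs
+    # inside the kernel), while the modular variant consumes pre-computed topk
+    # ids/weights; callers iterate this list and keep the first class whose
+    # is_supported_config() accepts the deployment, e.g.
+    # backend_to_kernel_cls(NvFp4MoeBackend.VLLM_CUTLASS) -> [CutlassExpertsFp4].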
+ return [ + TrtLlmNvFp4ExpertsMonolithic, + TrtLlmNvFp4ExpertsModular, + ] + elif backend == NvFp4MoeBackend.FLASHINFER_CUTLASS: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) - return FlashInferExperts + return [FlashInferExperts] elif backend == NvFp4MoeBackend.FLASHINFER_CUTEDSL: from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( FlashInferCuteDSLExperts, ) - return FlashInferCuteDSLExperts + return [FlashInferCuteDSLExperts] elif backend == NvFp4MoeBackend.VLLM_CUTLASS: from vllm.model_executor.layers.fused_moe.cutlass_moe import ( CutlassExpertsFp4, ) - return CutlassExpertsFp4 + return [CutlassExpertsFp4] elif backend == NvFp4MoeBackend.MARLIN: from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( MarlinExperts, ) - return MarlinExperts + return [MarlinExperts] else: raise ValueError(f"Unknown NvFP4 MoE backend: {backend.value}") @@ -125,7 +131,7 @@ def select_nvfp4_moe_backend( config: FusedMoEConfig, weight_key: QuantKey | None, activation_key: QuantKey | None, -) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: +) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: """ Select the primary NvFP4 MoE backend Note: Shape-specific fallbacks may still occur at runtime. @@ -143,10 +149,7 @@ def select_nvfp4_moe_backend( # NOTE(rob): this is kind of a hack. We need to peak into # the prepare-finalize selection to determine if we are using # the batched or standard expert format. - use_batched = ( - config.moe_parallel_config.use_deepep_ll_kernels - or config.moe_parallel_config.use_pplx_kernels - ) + use_batched = config.moe_parallel_config.use_deepep_ll_kernels activation_format = ( mk.FusedMoEActivationFormat.BatchedExperts if use_batched @@ -178,29 +181,21 @@ def select_nvfp4_moe_backend( weight_key: QuantKey | None, activation_key: QuantKey | None, activation_format: mk.FusedMoEActivationFormat, - ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: - k_cls = backend_to_kernel_cls(backend) - supported, reason = k_cls.is_supported_config( - k_cls, config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, k_cls + ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEExperts]]: + for k_cls in backend_to_kernel_cls(backend): + supported, reason = k_cls.is_supported_config( + k_cls, config, weight_key, activation_key, activation_format + ) + if supported: + logger.info_once(_make_log_backend(backend)) + return backend, k_cls + raise ValueError(_make_log_unsupported(backend, reason)) # Handle explicit moe_backend from user. runner_backend = config.moe_backend if runner_backend != "auto": requested_backend = map_nvfp4_backend(runner_backend) - if requested_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - supported, reason = is_supported_config_trtllm( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(requested_backend)) - return requested_backend, None - raise ValueError(_make_log_unsupported(requested_backend, reason)) - return _return_or_raise( requested_backend, config, weight_key, activation_key, activation_format ) @@ -213,36 +208,14 @@ def select_nvfp4_moe_backend( elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): # If user is explicit about backend, validate it. 
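+        # The explicit FlashInfer backend is translated through
+        # fi_2_vllm_backend_map and then validated the same way as a
+        # user-supplied moe_backend, via _return_or_raise().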
- fi_backend = get_flashinfer_moe_backend() - - if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: - backend = NvFp4MoeBackend.FLASHINFER_TRTLLM - supported, reason = is_supported_config_trtllm( - config, weight_key, activation_key, activation_format - ) - if supported: - logger.info_once(_make_log_backend(backend)) - return backend, None - else: - raise ValueError(_make_log_unsupported(backend, reason)) - else: - backend = fi_2_vllm_backend_map[fi_backend] - return _return_or_raise( - backend, config, weight_key, activation_key, activation_format - ) + backend = fi_2_vllm_backend_map[get_flashinfer_moe_backend()] + return _return_or_raise( + backend, config, weight_key, activation_key, activation_format + ) else: # If the user is not explicit about the backend, try each. for backend in FLASHINFER_NVFP4_MOE_BACKENDS: - if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - k_cls = None - supported, reason = is_supported_config_trtllm( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) + for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( k_cls, config, @@ -250,13 +223,13 @@ def select_nvfp4_moe_backend( activation_key, activation_format, ) - if supported: - logger.info_once(_make_log_backend(backend), scope="local") - return backend, None - else: - logger.debug_once( - _make_log_unsupported(backend, reason), scope="local" - ) + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once( + _make_log_unsupported(backend, reason), scope="local" + ) raise NotImplementedError( "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " @@ -271,16 +244,7 @@ def select_nvfp4_moe_backend( # Select kernels in order of backend. for backend in AVAILABLE_BACKENDS: - if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - k_cls = None # type: ignore[assignment] - supported, reason = is_supported_config_trtllm( - config, - weight_key, - activation_key, - activation_format, - ) - else: - k_cls = backend_to_kernel_cls(backend) + for k_cls in backend_to_kernel_cls(backend): supported, reason = k_cls.is_supported_config( k_cls, config, @@ -289,11 +253,11 @@ def select_nvfp4_moe_backend( activation_format, ) - if supported: - logger.info_once(_make_log_backend(backend), scope="local") - return backend, k_cls - else: - logger.debug_once(_make_log_unsupported(backend, reason), scope="local") + if supported: + logger.info_once(_make_log_backend(backend), scope="local") + return backend, k_cls + else: + logger.debug_once(_make_log_unsupported(backend, reason), scope="local") raise NotImplementedError( "No NvFp4 MoE backend supports the deployment configuration." @@ -401,12 +365,8 @@ def make_nvfp4_moe_quant_config( w2_scale_2: torch.Tensor, a13_scale: torch.Tensor, a2_scale: torch.Tensor, -) -> FusedMoEQuantConfig | None: - UNSUPPORTED = [NvFp4MoeBackend.FLASHINFER_TRTLLM] - if backend in UNSUPPORTED: - return None - - elif backend == NvFp4MoeBackend.MARLIN: +) -> FusedMoEQuantConfig: + if backend == NvFp4MoeBackend.MARLIN: return nvfp4_w4a16_moe_quant_config( g1_alphas=w13_scale_2, g2_alphas=w2_scale_2, @@ -423,22 +383,27 @@ def make_nvfp4_moe_quant_config( a2_gscale=(1.0 / a2_scale), w1_scale=w13_scale, w2_scale=w2_scale, + # NOTE(rob): this is a hack until the MoE kernels + # create their own quant configs. TRTLLM kernel + # does not accept swizzled input quant scales. 
+ is_nvfp4_scale_swizzled=(backend != NvFp4MoeBackend.FLASHINFER_TRTLLM), ) def make_nvfp4_moe_kernel( moe_quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, - experts_cls: type[mk.FusedMoEPermuteExpertsUnpermute], + experts_cls: type[mk.FusedMoEExperts], routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, shared_experts: torch.nn.Module | None = None, -) -> mk.FusedMoEModularKernel: +) -> mk.FusedMoEKernel: # Create Prepare/Finalize. prepare_finalize = maybe_make_prepare_finalize( moe=moe_config, quant_config=moe_quant_config, routing_tables=routing_tables, allow_new_interface=True, + use_monolithic=issubclass(experts_cls, mk.FusedMoEExpertsMonolithic), ) assert prepare_finalize is not None @@ -461,9 +426,9 @@ def make_nvfp4_moe_kernel( ) # NOTE(rob): we only want the mk to control the shared_expert - # if using all2all (for SBO). bnell is making this explict in + # if using all2all (for SBO). bnell is making this explicit in # the new MoE runner class. - kernel = mk.FusedMoEModularKernel( + kernel = mk.FusedMoEKernel( prepare_finalize, experts, shared_experts=( diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 1c582bc..3d8d56e 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -19,7 +19,7 @@ from vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe import ( is_supported_config_trtllm_bf16, ) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( - MoEPrepareAndFinalizeNoEP, + MoEPrepareAndFinalizeNoDPEPModular, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( swap_w13_to_w31, @@ -209,7 +209,7 @@ def make_unquantized_moe_kernel( backend: UnquantizedMoeBackend, quant_config: FusedMoEQuantConfig, moe_config: FusedMoEConfig, -) -> mk.FusedMoEModularKernel | None: +) -> mk.FusedMoEKernel | None: if backend in UNSUPPORTED_BACKEND: return None @@ -218,8 +218,8 @@ def make_unquantized_moe_kernel( FlashInferExperts, ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), FlashInferExperts( moe_config=moe_config, quant_config=quant_config, @@ -232,8 +232,8 @@ def make_unquantized_moe_kernel( AiterExperts, ) - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), + kernel = mk.FusedMoEKernel( + MoEPrepareAndFinalizeNoDPEPModular(), AiterExperts( moe_config=moe_config, quant_config=quant_config, @@ -241,25 +241,6 @@ def make_unquantized_moe_kernel( inplace=not moe_config.disable_inplace, ) elif backend == UnquantizedMoeBackend.TRITON: - from vllm.model_executor.layers.fused_moe import TritonExperts - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - TritonExperts( - moe_config=moe_config, - quant_config=quant_config, - ), - inplace=not moe_config.disable_inplace, - ) - elif backend == UnquantizedMoeBackend.XPU: - from vllm.model_executor.layers.fused_moe import XPUExperts - - kernel = mk.FusedMoEModularKernel( - MoEPrepareAndFinalizeNoEP(), - XPUExperts( - moe_config=moe_config, - quant_config=quant_config, - ), - inplace=not moe_config.disable_inplace, - ) + from vllm.model_executor.layers.fused_moe import fused_experts + kernel = fused_experts return kernel diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py deleted file mode 100644 
index 289ac0d..0000000 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ /dev/null @@ -1,373 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable - -import pplx_kernels as pplx -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceDelegate, -) -from vllm.model_executor.layers.fused_moe.utils import ( - _validate_scale_shape, - moe_kernel_quantize_input, -) -from vllm.utils.math_utils import cdiv, round_up - -logger = init_logger(__name__) - - -def pplx_hidden_dim_scale_bytes( - max_num_tokens: int, - hidden_dim: int, - in_dtype: torch.dtype, - quant_dtype: torch.dtype | str | None, - per_act_token_quant: bool, - block_shape: list[int] | None, -): - # All pplx byte sizes must be 16-byte aligned. - align = 16 - - # For blocked per token: set to - # cdiv(hidden_dim, block_size) * sizeof(float32) - # For per-token: set to 4 * sizeof(float32) (x4 for alignment) - if quant_dtype is not None: - assert isinstance(quant_dtype, torch.dtype) - assert quant_dtype.itemsize == 1 - hidden_dim_bytes = hidden_dim * quant_dtype.itemsize - elem_size = torch.float32.itemsize - - if per_act_token_quant: - # per-token (M x 1) - assert block_shape is None - hidden_scale_bytes = elem_size - elif block_shape is not None: - # per-group (M x K_tiles) - block_size = block_shape[1] - num_blocks = cdiv(hidden_dim, block_size) - hidden_scale_bytes = num_blocks * elem_size - else: - # per-tensor (1 x 1) - hidden_scale_bytes = elem_size - else: - hidden_dim_bytes = hidden_dim * in_dtype.itemsize - hidden_scale_bytes = 0 - - return ( - round_up(hidden_dim_bytes, align), - round_up(hidden_scale_bytes, align), - ) - - -class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): - """PPLX-based prepare and finalize for expert parallelism.""" - - def __init__( - self, - a2a: pplx.AllToAll, - max_num_tokens: int, - num_local_experts: int, - num_dispatchers: int, - ): - super().__init__() - assert max_num_tokens > 0 - assert num_local_experts > 0 - self.a2a = a2a - self.max_num_tokens = max_num_tokens - self.num_local_experts = num_local_experts - self.num_dispatchers_ = num_dispatchers - - @property - def activation_format(self) -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.BatchedExperts - - def max_num_tokens_per_rank(self) -> int | None: - return self.max_num_tokens - - def topk_indices_dtype(self) -> torch.dtype | None: - return torch.uint32 - - def num_dispatchers(self) -> int: - return self.num_dispatchers_ - - def output_is_reduced(self) -> bool: - return True - - def supports_async(self) -> bool: - return True - - def prepare_async( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - defer_input_quant: bool = False, - ) -> tuple[Callable, mk.ReceiverType]: - if defer_input_quant: - raise NotImplementedError( - f"{self.__class__.__name__} does not support defer_input_quant=True. " - "Please select an MoE kernel that accepts quantized inputs." 
- ) - - num_tokens = a1.size(0) # M - hidden_dim = a1.size(-1) # K - - assert topk_ids.size(0) == num_tokens - # expert_map should be None because with expert map, -1 id is used for - # non-local token; this causes error when casting ids to the - # topk_indices_dtype() int32 - # - if expert_map is not None: - logger.warning_once( - "The PPLX backend does not support expert mapping. " - "The provided `expert_map` will be ignored." - ) - expert_map = None # noqa: F841 - - # Is this always going to be a1.device? - device = a1.device - - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - a1 = a1 * topk_weights.to(a1.dtype) - - repeat_cols = 4 - repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0) - # TODO(bnell): always pass quant_config.a1_scale? - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - (None if quant_config.per_act_token_quant else quant_config.a1_scale), - quant_dtype=quant_config.quant_dtype, - per_act_token_quant=quant_config.per_act_token_quant, - block_shape=quant_config.block_shape, - ) - - _validate_scale_shape( - a1q, a1q_scale, quant_config.per_act_token_quant, quant_config.block_shape - ) - - orig_a_scale_block_shape: int | None = None - - if a1q_scale is not None: - scalar_scales = a1q_scale.numel() == 1 - - # pplx requires 2-d scales even for scalar scales - if a1q_scale.dim() <= 1: - assert scalar_scales - a1q_scale = a1q_scale.view(1, 1) - - orig_a_scale_block_shape = a1q_scale.shape[-1] - - if not quant_config.is_block_quantized: - # TODO (bnell): use group_broadcast instead? - a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols) - - assert a1q_scale is None or a1q_scale.ndim == 2, ( - f"{0 if a1q_scale is None else (a1q_scale.ndim, a1q_scale.shape)}" - ) - - expert_num_tokens = torch.empty( - self.num_local_experts, - dtype=torch.int32, - device=device, - ) - - expert_x = torch.empty( - ( - self.num_local_experts, - self.max_num_tokens * self.num_dispatchers(), - hidden_dim, - ), - dtype=a1q.dtype, - device=device, - ) - - expert_x_scale: torch.Tensor | None = None - if a1q.dtype.itemsize == 1: - if quant_config.is_per_act_token: - # (M x 1) -> (E x M x K) - final_dim = expert_x.size(2) - elif quant_config.is_per_tensor: - # (1 x 1) -> (E x 1 x 1) - final_dim = 1 - else: - # (M x K_tiles) -> (E x M x K_tiles) - assert quant_config.block_shape is not None - num_blocks = cdiv(expert_x.size(2), quant_config.block_shape[1]) - final_dim = num_blocks - - expert_x_scale_shape = ( - self.num_local_experts, - expert_x.size(1), - round_up(final_dim, 4), # round up for alignment - ) - - expert_x_scale = torch.empty( - expert_x_scale_shape, - dtype=torch.float32, - device=expert_x.device, - ) - - # This argument is optional, defaults to indices.size(0) - # There's not much point setting this unless it is != indices.size(0) - bound_m: torch.Tensor | None = None - - self.a2a.dispatch( - out_expert_num_tokens=expert_num_tokens, - out_expert_x=expert_x, - out_expert_x_scale=expert_x_scale, - dp_x=a1q, - dp_x_scale=a1q_scale, - indices=topk_ids, - bound_m=bound_m, - do_send=True, - do_recv=False, - ) - - hook = lambda: self.a2a.dispatch( - out_expert_num_tokens=expert_num_tokens, - out_expert_x=expert_x, - out_expert_x_scale=expert_x_scale, - dp_x=a1q, - dp_x_scale=a1q_scale, - indices=topk_ids, - bound_m=bound_m, - do_send=False, - do_recv=True, - ) - - return ( - hook, - lambda: self._receiver( 
- expert_num_tokens, - expert_x, - expert_x_scale, - orig_a_scale_block_shape, - ), - ) - - def _receiver( - self, - expert_num_tokens: torch.Tensor, - expert_x: torch.Tensor, - expert_x_scale: torch.Tensor | None, - orig_a_scale_block_shape: int | None, - ) -> mk.PrepareResultType: - if expert_x_scale is not None: - expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape] - assert expert_x_scale.ndim == 3 - - expert_tokens_meta = mk.ExpertTokensMetadata( - expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None - ) - - return expert_x, expert_x_scale, expert_tokens_meta, None, None - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - defer_input_quant: bool = False, - ) -> mk.PrepareResultType: - hook, receiver = self.prepare_async( - a1, - topk_weights, - topk_ids, - num_experts, - expert_map, - apply_router_weight_on_input, - quant_config, - defer_input_quant=defer_input_quant, - ) - hook() - return receiver() - - def finalize_async( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> Callable: - assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate), ( - "Weight application and reduction happens in the combine kernel." - ) - - # This argument is optional - # There's not much point setting this unless it is != topk_ids.size(0) - bound_m: torch.Tensor | None = None - - # TODO (bnell): fails in test_pplx_moe.py, figure out what's going on - # num_tokens = output.size(0) # M - # assert topk_ids.size(0) == num_tokens, ( - # f"{topk_ids.size(0)} == {num_tokens}") - assert topk_ids.size() == topk_weights.size(), ( - f"{topk_ids.size()} == {topk_weights.size()}" - ) - assert output.size(0) <= self.max_num_tokens, ( - f"{output.size(0)} <= {self.max_num_tokens}" - ) - assert output.size(1) == fused_expert_output.size(-1) - - # Set weights to 1 if we did them in dispatch. This is hacky. 
- if apply_router_weight_on_input: - topk_weights = torch.ones_like(topk_weights) - - topk_ids_u32 = topk_ids.view(dtype=torch.uint32) - - self.a2a.combine( - out_tokens=output, - indices=topk_ids_u32, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m, - do_send=True, - do_recv=False, - ) - - return lambda: self.a2a.combine( - out_tokens=output, - indices=topk_ids_u32, - weights=topk_weights, - expert_y=fused_expert_output, - bound_m=bound_m, - do_send=False, - do_recv=True, - ) - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - receiver = self.finalize_async( - output, - fused_expert_output, - topk_weights, - topk_ids, - apply_router_weight_on_input, - weight_and_reduce_impl, - ) - receiver() diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py deleted file mode 100644 index 7b8dd3b..0000000 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ /dev/null @@ -1,209 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch - -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import get_ep_group -from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig -from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( - TopKWeightAndReduceContiguous, - TopKWeightAndReduceDelegate, -) -from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input -from vllm.utils.flashinfer import nvfp4_block_scale_interleave - - -class MoEPrepareAndFinalizeNaiveEP(mk.FusedMoEPrepareAndFinalize): - def __init__( - self, - is_sequence_parallel: bool = False, - num_dispatchers: int = 1, - ) -> None: - super().__init__() - self.is_sequence_parallel = is_sequence_parallel - self._num_dispatchers = num_dispatchers - - @property - def activation_format(self) -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - def max_num_tokens_per_rank(self) -> int | None: - return None - - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - def num_dispatchers(self) -> int: - return self._num_dispatchers - - def output_is_reduced(self) -> bool: - return False - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - defer_input_quant: bool = False, - ) -> mk.PrepareResultType: - if apply_router_weight_on_input: - topk = topk_ids.size(1) - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - # Note: do not use inplace for shared experts overlap - a1 = a1 * topk_weights.to(a1.dtype) - - # Defer input quantization to the MoE kernel. - use_nvfp4 = quant_config.use_nvfp4_w4a4 - if defer_input_quant: - a1q = a1 - a1q_scale = None - else: - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - quant_config.a1_gscale if use_nvfp4 else quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - # NOTE: swizzling pads the scales to multiple of 128 - # which makes the scales tensor different shape than - # the hidden states, breaking the A2A kernel. 
So, we - # delay the swizzling until after the A2A. - is_fp4_scale_swizzled=False, - ) - - # Skip gathering scales if we have static quantization - # (the scale is a scalar, replicated on all ranks) or - # if quantization is deferred. - skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 - scales = None if skip_gather_scales else [a1q_scale] - - res = get_ep_group().dispatch( - a1q, - topk_weights, - topk_ids, - is_sequence_parallel=self.is_sequence_parallel, - extra_tensors=scales, - ) - if skip_gather_scales: - a1q, topk_weights, topk_ids = res - else: - a1q, topk_weights, topk_ids, scales = res - assert scales is not None and len(scales) == 1 - a1q_scale = scales[0] - if quant_config.quant_dtype == "nvfp4": - assert a1q_scale is not None - if a1q_scale.element_size() == 1: - a1q_scale = a1q_scale.view(torch.uint8) - a1q_scale = nvfp4_block_scale_interleave(a1q_scale) - - return a1q, a1q_scale, None, topk_ids, topk_weights - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - - out = weight_and_reduce_impl.apply( - output=None, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - output.copy_( - get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel) - ) - - -class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): - """MoE prepare and finalize without expert parallelism.""" - - @property - def activation_format(self) -> mk.FusedMoEActivationFormat: - return mk.FusedMoEActivationFormat.Standard - - def max_num_tokens_per_rank(self) -> int | None: - return None - - def topk_indices_dtype(self) -> torch.dtype | None: - return None - - def num_dispatchers(self) -> int: - return 1 - - def output_is_reduced(self) -> bool: - return False - - def prepare( - self, - a1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - num_experts: int, - expert_map: torch.Tensor | None, - apply_router_weight_on_input: bool, - quant_config: FusedMoEQuantConfig, - defer_input_quant: bool = False, - ) -> mk.PrepareResultType: - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, ( - "apply_router_weight_on_input is only implemented for topk=1" - ) - # Note: do not use inplace for shared experts overlap - a1 = a1 * topk_weights.to(a1.dtype) - - # Defer input quant to moe kernel for backends (e.g. AITER, FI) - # which use a single kernel call for quant + experts. 
- if defer_input_quant: - return a1, None, None, None, None - - input_sf = ( - quant_config.a1_gscale - if quant_config.use_nvfp4_w4a4 - else quant_config.a1_scale - ) - a1q, a1q_scale = moe_kernel_quantize_input( - a1, - input_sf, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - ) - - return a1q, a1q_scale, None, None, None - - def finalize( - self, - output: torch.Tensor, - fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce, - ) -> None: - if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): - weight_and_reduce_impl = TopKWeightAndReduceContiguous() - weight_and_reduce_impl.apply( - output=output, - fused_expert_output=fused_expert_output, - topk_weights=topk_weights, - topk_ids=topk_ids, - apply_router_weight_on_input=apply_router_weight_on_input, - ) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py new file mode 100644 index 0000000..03fea7c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.layers.fused_moe.prepare_finalize.naive_dp_ep import ( + MoEPrepareAndFinalizeNaiveDPEPModular, + MoEPrepareAndFinalizeNaiveDPEPMonolithic, + make_moe_prepare_and_finalize_naive_dp_ep, +) +from vllm.model_executor.layers.fused_moe.prepare_finalize.no_dp_ep import ( + MoEPrepareAndFinalizeNoDPEPModular, + MoEPrepareAndFinalizeNoDPEPMonolithic, + make_moe_prepare_and_finalize_no_dp_ep, +) + +__all__ = [ + "MoEPrepareAndFinalizeNaiveDPEPMonolithic", + "MoEPrepareAndFinalizeNaiveDPEPModular", + "make_moe_prepare_and_finalize_naive_dp_ep", + "MoEPrepareAndFinalizeNoDPEPMonolithic", + "MoEPrepareAndFinalizeNoDPEPModular", + "make_moe_prepare_and_finalize_no_dp_ep", +] diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py new file mode 100644 index 0000000..6dc9f69 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/naive_dp_ep.py @@ -0,0 +1,253 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.distributed import get_ep_group +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input +from vllm.utils.flashinfer import nvfp4_block_scale_interleave + + +def _quantize_and_setup_dispatch( + a1: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, +) -> tuple[torch.Tensor, list[torch.Tensor] | None]: + # Defer input quantization to the MoE kernel. + if defer_input_quant: + a1q = a1 + a1q_scale = None + else: + input_sf = ( + quant_config.a1_gscale + if quant_config.use_nvfp4_w4a4 + else quant_config.a1_scale + ) + + # NOTE: swizzling pads the scales to multiple of 128 + # which makes the scales tensor different shape than + # the hidden states, breaking the A2A kernel. 
So, we + # delay the swizzling until after the A2A. + a1q, a1q_scale = a1q, a1q_scale = moe_kernel_quantize_input( + a1, + input_sf, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + is_fp4_scale_swizzled=False, + ) + + # Skip gathering scales if we have static quantization + # (the scale is a scalar, replicated on all ranks) or + # if quantization is deferred. + skip_gather_scales = a1q_scale is None or a1q_scale.ndim == 0 + scales = None if skip_gather_scales else [a1q_scale] + + return a1q, scales + + +def _unwrap_scale_and_prepare_for_moe( + scales: list[torch.Tensor] | None, + quant_config: FusedMoEQuantConfig, +) -> torch.Tensor: + assert scales is not None and len(scales) == 1 + a1q_scale = scales[0] + # Apply swizzling after a2a if the MoE kernel needs it. + if quant_config.quant_dtype == "nvfp4" and quant_config.is_nvfp4_scale_swizzled: + assert a1q_scale is not None + if a1q_scale.element_size() == 1: + a1q_scale = a1q_scale.view(torch.uint8) + a1q_scale = nvfp4_block_scale_interleave(a1q_scale) + + return a1q_scale + + +class MoEPrepareAndFinalizeNaiveDPEPModular(mk.FusedMoEPrepareAndFinalizeModular): + """ + Naive Prepare/Finalize for Dp/Ep case for Modular Kernels. + + Uses Torch AR/RS or AR for dispatch/combine operations, applied + to the topk weights and ids. + """ + + def __init__( + self, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, + ) -> None: + super().__init__() + self.is_sequence_parallel = is_sequence_parallel + self._num_dispatchers = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return self._num_dispatchers + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + """Quantize and Dispatch Topk Weights and Topk Ids.""" + + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + # Note: do not use inplace for shared experts overlap + a1 = a1 * topk_weights.to(a1.dtype) + + a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + + res = get_ep_group().dispatch( + a1q, + topk_weights, + topk_ids, + is_sequence_parallel=self.is_sequence_parallel, + extra_tensors=scales, + ) + + if scales is None: + a1q, topk_weights, topk_ids = res + a1q_scale = None + else: + a1q, topk_weights, topk_ids, scales = res + a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + + out = weight_and_reduce_impl.apply( + output=None, + fused_expert_output=fused_expert_output, + 
topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + output.copy_( + get_ep_group().combine(out, is_sequence_parallel=self.is_sequence_parallel) + ) + + +class MoEPrepareAndFinalizeNaiveDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic): + """ + Naive Prepare/Finalize for Dp/Ep case for Modular Kernels. + + Uses Torch AR/RS or AR for dispatch/combine operations, applied + to the router logits (the MoE kernel runs the router internally). + """ + + def __init__( + self, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, + ) -> None: + super().__init__() + self.is_sequence_parallel = is_sequence_parallel + self._num_dispatchers = num_dispatchers + + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return self._num_dispatchers + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + router_logits: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareMonolithicResultType: + """Quantize and Dispatch Router Logits.""" + + a1q, scales = _quantize_and_setup_dispatch(a1, quant_config, defer_input_quant) + + res = get_ep_group().dispatch_router_logits( + a1q, + router_logits, + is_sequence_parallel=self.is_sequence_parallel, + extra_tensors=scales, + ) + + if scales is None: + a1q, router_logits = res + a1q_scale = None + else: + a1q, router_logits, scales = res + a1q_scale = _unwrap_scale_and_prepare_for_moe(scales, quant_config) + + return a1q, a1q_scale, router_logits + + def finalize( + self, + fused_expert_output: torch.Tensor, + ) -> torch.Tensor: + out = get_ep_group().combine( + fused_expert_output, is_sequence_parallel=self.is_sequence_parallel + ) + return out + + +def make_moe_prepare_and_finalize_naive_dp_ep( + use_monolithic: bool, + is_sequence_parallel: bool = False, + num_dispatchers: int = 1, +) -> MoEPrepareAndFinalizeNaiveDPEPModular | MoEPrepareAndFinalizeNaiveDPEPMonolithic: + return ( + MoEPrepareAndFinalizeNaiveDPEPMonolithic( + is_sequence_parallel=is_sequence_parallel, + num_dispatchers=num_dispatchers, + ) + if use_monolithic + else MoEPrepareAndFinalizeNaiveDPEPModular( + is_sequence_parallel=is_sequence_parallel, + num_dispatchers=num_dispatchers, + ) + ) diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py new file mode 100644 index 0000000..b9d57da --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize/no_dp_ep.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceContiguous, + TopKWeightAndReduceDelegate, +) +from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input + + +def _quantize_input( + a1: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + # Defer input quant to moe kernel for backends (e.g. 
AITER, FI) + # which use a single kernel call for quant + experts. + if defer_input_quant: + return a1, None + + input_sf = ( + quant_config.a1_gscale if quant_config.use_nvfp4_w4a4 else quant_config.a1_scale + ) + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + input_sf, + quant_dtype=quant_config.quant_dtype, + per_act_token_quant=quant_config.per_act_token_quant, + block_shape=quant_config.block_shape, + is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled, + ) + + return a1q, a1q_scale + + +class MoEPrepareAndFinalizeNoDPEPModular(mk.FusedMoEPrepareAndFinalizeModular): + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return 1 + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: torch.Tensor | None, + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareResultType: + if apply_router_weight_on_input: + topk = topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1" + ) + # Note: do not use inplace for shared experts overlap + a1 = a1 * topk_weights.to(a1.dtype) + + a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant) + + return a1q, a1q_scale, None, None, None + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate): + weight_and_reduce_impl = TopKWeightAndReduceContiguous() + weight_and_reduce_impl.apply( + output=output, + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=apply_router_weight_on_input, + ) + + +class MoEPrepareAndFinalizeNoDPEPMonolithic(mk.FusedMoEPrepareAndFinalizeMonolithic): + @property + def activation_format(self) -> mk.FusedMoEActivationFormat: + return mk.FusedMoEActivationFormat.Standard + + def max_num_tokens_per_rank(self) -> int | None: + return None + + def topk_indices_dtype(self) -> torch.dtype | None: + return None + + def num_dispatchers(self) -> int: + return 1 + + def output_is_reduced(self) -> bool: + return False + + def prepare( + self, + a1: torch.Tensor, + router_logits: torch.Tensor, + quant_config: FusedMoEQuantConfig, + defer_input_quant: bool = False, + ) -> mk.PrepareMonolithicResultType: + a1q, a1q_scale = _quantize_input(a1, quant_config, defer_input_quant) + return a1q, a1q_scale, router_logits + + def finalize( + self, + fused_expert_output: torch.Tensor, + ) -> torch.Tensor: + return fused_expert_output + + +def make_moe_prepare_and_finalize_no_dp_ep( + use_monolithic: bool, +) -> MoEPrepareAndFinalizeNoDPEPModular | MoEPrepareAndFinalizeNoDPEPMonolithic: + return ( + MoEPrepareAndFinalizeNoDPEPMonolithic() + if use_monolithic + else MoEPrepareAndFinalizeNoDPEPModular() + ) diff --git a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py index 
8c8439d..c550cad 100644 --- a/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py @@ -292,7 +292,7 @@ def rocm_aiter_fused_experts( ) -class AiterExperts(mk.FusedMoEPermuteExpertsUnpermute): +class AiterExperts(mk.FusedMoEExpertsModular): @property def expects_unquantized_inputs(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py index 7608e06..b061b3d 100644 --- a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py +++ b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py @@ -20,6 +20,7 @@ import torch from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.forward_context import get_forward_context +from vllm.platforms import current_platform logger = logging.getLogger(__name__) @@ -132,7 +133,7 @@ class RoutedExpertsCapturer: self._device_buffer = torch.zeros( (max_num_batched_tokens, num_layers, num_experts_per_tok), dtype=torch.int32, - device="cuda", + device=current_platform.device_type, ) self.dp_rank = vllm_config.parallel_config.data_parallel_rank diff --git a/vllm/model_executor/layers/fused_moe/router/base_router.py b/vllm/model_executor/layers/fused_moe/router/base_router.py index 52005d4..d5c83ce 100644 --- a/vllm/model_executor/layers/fused_moe/router/base_router.py +++ b/vllm/model_executor/layers/fused_moe/router/base_router.py @@ -64,7 +64,7 @@ if current_platform.is_cuda_alike(): # TODO(bowen): When using `FusedMoEModularKernel`, this # can be done in a more unified way, since - # `FusedMoEPrepareAndFinalize` will return the expert + # `FusedMoEPrepareAndFinalizeModular` will return the expert # token count, in some cases directly from the kernel. # However, now there are many code paths not using # the modular kernel, e.g. calling `fused_experts`, @@ -175,6 +175,7 @@ class BaseRouter(FusedMoERouter): topk_ids = topk_ids.to(dtype=indices_type) assert topk_ids.dtype == indices_type or indices_type is None + topk_ids = topk_ids.to(torch.int32) return topk_ids @abstractmethod diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 584e044..3ceb5a8 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -31,7 +31,7 @@ def vllm_topk_softmax( token_expert_indices, gating_output, renormalize, - e_score_correction_bias, + e_score_correction_bias ) return topk_weights, topk_indices @@ -85,13 +85,14 @@ def fused_topk_bias( token_expert_indices = torch.empty( M, topk, dtype=torch.int32, device=hidden_states.device ) + gating_output_float = gating_output.float() # TODO(woosuk): Optimize this. 
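+    # The gating logits are upcast to fp32 once up front; the scoring call
+    # below consumes gating_output_float.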
if scoring_func == "softmax": topk_weights, topk_ids = vllm_topk_softmax( topk_weights, topk_ids, token_expert_indices, - gating_output, + gating_output_float, renormalize, e_score_correction_bias, ) @@ -186,7 +187,7 @@ class FusedTopKBiasRouter(BaseRouter): indices_type=indices_type, ) - if self.routed_scaling_factor != 1.0: - topk_weights *= self.routed_scaling_factor + # if self.routed_scaling_factor != 1.0: + # topk_weights *= self.routed_scaling_factor return topk_weights, topk_ids diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py index 0ae867a..d0ad5c5 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_router.py @@ -26,8 +26,9 @@ def vllm_topk_softmax( topk_indices, token_expert_indices, gating_output, - renormalize, ) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_indices @@ -90,13 +91,14 @@ def fused_topk( token_expert_indices = torch.empty( M, topk, dtype=torch.int32, device=hidden_states.device ) + gating_output_float = gating_output.float() if scoring_func == "softmax": topk_func = dispatch_topk_softmax_func( use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled() ) topk_weights, topk_ids = topk_func( - topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize + topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize ) return topk_weights, topk_ids, token_expert_indices @@ -105,7 +107,7 @@ def fused_topk( use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled() ) topk_weights, topk_ids = topk_func( - topk_weights, topk_ids, token_expert_indices, gating_output.float(), renormalize + topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize ) return topk_weights, topk_ids, token_expert_indices diff --git a/vllm/model_executor/layers/fused_moe/router/gate_linear.py b/vllm/model_executor/layers/fused_moe/router/gate_linear.py new file mode 100644 index 0000000..26fedd5 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/router/gate_linear.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from torch.nn.parameter import Parameter + +from vllm.model_executor.custom_op import PluggableLayer +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.platforms import current_platform + + +@PluggableLayer.register("gate_linear") +class GateLinear(ReplicatedLinear): + """MoE gate linear layer with three-tier GEMM dispatch: + + 1. DSV3 specialized kernel (SM90+, batch<=16, supported dims) + 2. cuBLAS bf16×bf16→fp32 (SM90+ + bf16 + fp32 out_dtype) + 3. F.linear via ReplicatedLinear (ultimate fallback) + + The ``out_dtype`` attribute is mutable and can be set after init + (e.g. when the required dtype depends on the expert quantization + method which is only known later). 
+ """ + + # Dimensions supported by the DSV3 specialized kernel + DSV3_SUPPORTED_NUM_EXPERTS = [256, 384] + DSV3_SUPPORTED_HIDDEN_SIZES = [7168] + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = False, + out_dtype: torch.dtype | None = None, + params_dtype: torch.dtype | None = None, + force_fp32_compute: bool = False, + prefix: str = "", + ): + is_hopper_or_blackwell = current_platform.is_device_capability( + (9, 0) + ) or current_platform.is_device_capability_family(100) + can_use_specialized_kernels = False + + # If fp32 compute is required and no specialized kernel is available, + # store weights in fp32 so Tier 3 computes in fp32 natively. + if force_fp32_compute and not can_use_specialized_kernels: + params_dtype = torch.float32 + + super().__init__( + input_size, + output_size, + bias=bias, + params_dtype=params_dtype, + quant_config=None, + prefix=prefix, + ) + self.out_dtype = out_dtype + + # DSV3 specialized kernel eligibility (SM90+, exact dims) + self.allow_specialized_router_gemm = can_use_specialized_kernels + self.allow_dsv3_router_gemm = ( + self.allow_specialized_router_gemm + and output_size in self.DSV3_SUPPORTED_NUM_EXPERTS + and input_size in self.DSV3_SUPPORTED_HIDDEN_SIZES + ) + + # cuBLAS bf16→fp32 eligibility + self.allow_cublas_router_gemm = ( + self.allow_specialized_router_gemm + and self.weight.dtype == torch.bfloat16 + and self.out_dtype == torch.float32 + ) + + def set_out_dtype(self, out_dtype: torch.dtype) -> None: + """Set output dtype for the router logits after init. + + Useful when the required dtype depends on the expert quantization + method which is only known after the gate is constructed. + """ + if self.out_dtype is not None: + raise ValueError("out_dtype has already been set") + self.out_dtype = out_dtype + + if ( + not self.allow_cublas_router_gemm + and self.allow_specialized_router_gemm + and out_dtype == torch.float32 + ): + self.allow_cublas_router_gemm = self.weight.dtype == torch.bfloat16 + + def forward( + self, x: torch.Tensor + ) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]: + import vllm._custom_ops as ops + + # Tier 1: DSV3 specialized kernel + if self.allow_dsv3_router_gemm and x.shape[0] <= 16: + output = ops.dsv3_router_gemm( + hidden_states=x, + router_weight=self.weight, + output_dtype=self.out_dtype, + ) + return output, None + + # Tier 2: cuBLAS bf16→fp32 + if self.allow_cublas_router_gemm and x.dtype == torch.bfloat16: + output = ops.router_gemm_bf16_fp32(x, self.weight) + return output, None + + # Tier 3: F.linear (ReplicatedLinear) + if self.out_dtype is not None and x.dtype != self.weight.dtype: + x = x.to(self.weight.dtype) + output, output_bias = super().forward(x) + if self.out_dtype is not None and output.dtype != self.out_dtype: + output = output.to(self.out_dtype) + return output, output_bias diff --git a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py index 486610f..99d8891 100644 --- a/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py +++ b/vllm/model_executor/layers/fused_moe/router/grouped_topk_router.py @@ -92,77 +92,9 @@ def grouped_topk( routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if ( - envs.VLLM_USE_FUSED_MOE_GROUPED_TOPK - and current_platform.is_cuda() - and num_expert_group <= 32 - and topk <= 32 - and e_score_correction_bias is not None - ): - return fused_grouped_topk( - 
hidden_states=hidden_states, - gating_output=gating_output, - topk=topk, - renormalize=renormalize, - e_score_correction_bias=e_score_correction_bias, - num_expert_group=num_expert_group, - topk_group=topk_group, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - ) - - assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch" - - if scoring_func == "softmax": - scores = torch.softmax(gating_output, dim=-1) - elif scoring_func == "sigmoid": - scores = gating_output.sigmoid() - else: - raise ValueError(f"Unsupported scoring function: {scoring_func}") - - num_token = scores.size(0) - if e_score_correction_bias is not None: - # Store original scores before applying correction bias. We use biased - # scores for expert selection but original scores for routing weights - original_scores = scores - scores = scores + e_score_correction_bias.unsqueeze(0) - group_scores = ( - scores.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1) - ) - else: - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] - - # For batch invariance, use sorted=True to ensure deterministic expert selection - use_sorted = vllm_is_batch_invariant() - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[ - 1 - ] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, scores.size(-1) // num_expert_group) - .reshape(num_token, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) # [n, e] - - if e_score_correction_bias is not None: - topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=use_sorted)[1] - # Use original unbiased scores for the routing weights - topk_weights = original_scores.gather(1, topk_ids) - else: - topk_weights, topk_ids = torch.topk( - tmp_scores, k=topk, dim=-1, sorted=use_sorted - ) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - if routed_scaling_factor != 1.0: - topk_weights = topk_weights * routed_scaling_factor - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + from ixformer.inference.functions import moe_grouped_topk as grouped_topk + topk_weights, topk_ids = grouped_topk(gating_output, topk, num_expert_group, topk_group, scoring_func, e_score_correction_bias,renormalize = renormalize) + return topk_weights, topk_ids # --8<-- [start:grouped_topk] @@ -246,7 +178,6 @@ class GroupedTopk(CustomOp): hidden_states, gating_output, e_score_correction_bias ) -from ixformer.inference.functions import moe_grouped_topk as grouped_topk class GroupedTopKRouter(BaseRouter): """Router using grouped top-k routing (e.g., DeepSeekV2/V3).""" @@ -316,8 +247,8 @@ class GroupedTopKRouter(BaseRouter): topk=self.top_k, renormalize=self.renormalize, ) - if self.routed_scaling_factor != 1.0: - topk_weights *= self.routed_scaling_factor + # if self.routed_scaling_factor != 1.0: + # topk_weights *= self.routed_scaling_factor else: topk_weights, topk_ids, token_expert_indices = fused_topk( hidden_states=hidden_states, @@ -340,14 +271,14 @@ class GroupedTopKRouter(BaseRouter): grouped_topk_impl = grouped_topk topk_weights, topk_ids = grouped_topk_impl( - # hidden_states=hidden_states, + hidden_states=hidden_states, gating_output=router_logits, topk=self.top_k, renormalize=self.renormalize, num_expert_group=self.num_expert_group, 
topk_group=self.topk_group, scoring_func=self.scoring_func, - # routed_scaling_factor=self.routed_scaling_factor, + routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.e_score_correction_bias, ) diff --git a/vllm/model_executor/layers/fused_moe/router/router_factory.py b/vllm/model_executor/layers/fused_moe/router/router_factory.py index a0733ba..11027e8 100644 --- a/vllm/model_executor/layers/fused_moe/router/router_factory.py +++ b/vllm/model_executor/layers/fused_moe/router/router_factory.py @@ -44,7 +44,7 @@ def create_fused_moe_router( # grouped topk + fused topk bias parameters routed_scaling_factor: float = 1.0, e_score_correction_bias: torch.Tensor | None = None, - # custom routing paramaters + # custom routing parameters custom_routing_function: Callable | None = None, # eplb parameters enable_eplb: bool = False, diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index 087b703..b69744d 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import nullcontext +from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import ( + HAS_OPAQUE_TYPE, + ModuleName, aux_stream, current_stream, direct_register_custom_op, @@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module: return forward_context.no_compile_layers[layer_name] +# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object; +# on older versions it remains a plain str. +if TYPE_CHECKING: + from typing import TypeAlias + + _layer_name_type: TypeAlias = str | ModuleName +else: + _layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str + + +def _resolve_layer_name(layer_name: str | ModuleName) -> str: + return layer_name.value if isinstance(layer_name, ModuleName) else layer_name + + def _moe_forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> torch.Tensor: - layer = get_layer_from_name(layer_name) + layer = get_layer_from_name(_resolve_layer_name(layer_name)) # TODO(bnell): this can be removed after MK migration is complete. layer.ensure_moe_quant_config_init() return layer.runner.forward_impl( @@ -74,7 +91,7 @@ def _moe_forward_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -83,9 +100,9 @@ def _moe_forward_shared( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> tuple[torch.Tensor, torch.Tensor]: - layer = get_layer_from_name(layer_name) + layer = get_layer_from_name(_resolve_layer_name(layer_name)) # TODO(bnell): this can be removed after MK migration is complete. 
layer.ensure_moe_quant_config_init() return layer.runner.forward_impl( @@ -97,7 +114,7 @@ def _moe_forward_shared_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> tuple[torch.Tensor, torch.Tensor]: # Output shapes: # - fused_out: same as hidden_states (routed experts use transformed size) @@ -105,12 +122,10 @@ def _moe_forward_shared_fake( # hidden_states # (For latent MoE: shared experts use original hidden_size, not latent size) fused_out = torch.empty_like(hidden_states) - if shared_experts_input is not None: shared_out = torch.empty_like(shared_experts_input) else: shared_out = torch.empty_like(hidden_states) - return shared_out, fused_out @@ -165,6 +180,7 @@ class DefaultMoERunner(MoERunner): quant_method: FusedMoEMethodBase, reduce_results: bool, enable_dbo: bool, + fused_shared_output: bool = False, ): super().__init__() self.moe_config = moe_config @@ -175,6 +191,9 @@ class DefaultMoERunner(MoERunner): self.quant_method = quant_method self.reduce_results = reduce_results self.enable_dbo = enable_dbo + self.fused_shared_output = fused_shared_output + if self.fused_shared_output: + assert self.shared_experts is not None, "Shared experts must be provided when fused_shared_output is True." # Allow disabling of the separate shared experts stream for # debug purposes. @@ -195,19 +214,19 @@ class DefaultMoERunner(MoERunner): # Needed for string -> FusedMoE layer lookup in custom ops. self.layer_name = layer.layer_name - if current_platform.is_tpu() or current_platform.is_cpu(): + # if current_platform.is_tpu() or current_platform.is_cpu(): # TODO: Once the OOM issue for the TPU backend is resolved, we # will switch to using the moe_forward custom op. # Note: CPU doesn't require wrapped forward_impl. 
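# With the platform check above commented out, this overlay always binds the
# plain Python wrappers (_moe_forward / _moe_forward_shared) below instead of
# the torch.ops.vllm custom ops.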
- if self.shared_experts is None: - self.moe_forward = _moe_forward - else: - self.moe_forward = _moe_forward_shared + if self.shared_experts is None: + self.moe_forward = _moe_forward else: - if self.shared_experts is None: - self.moe_forward = torch.ops.vllm.moe_forward - else: - self.moe_forward = torch.ops.vllm.moe_forward_shared + self.moe_forward = _moe_forward_shared + # else: + # if self.shared_experts is None: + # self.moe_forward = torch.ops.vllm.moe_forward + # else: + # self.moe_forward = torch.ops.vllm.moe_forward_shared # Chunked all2all staging tensor self.batched_hidden_states: torch.Tensor | None = None @@ -216,8 +235,7 @@ class DefaultMoERunner(MoERunner): @property def use_dp_chunking(self) -> bool: return ( - self.moe_config.moe_parallel_config.use_pplx_kernels - or self.moe_config.moe_parallel_config.use_deepep_ll_kernels + self.moe_config.moe_parallel_config.use_deepep_ll_kernels or self.moe_config.moe_parallel_config.use_mori_kernels or self.moe_config.moe_parallel_config.use_fi_all2allv_kernels ) and envs.VLLM_ENABLE_MOE_DP_CHUNK @@ -306,8 +324,8 @@ class DefaultMoERunner(MoERunner): """ assert self.quant_method is not None return ( - self.quant_method.moe_mk is not None - and self.quant_method.moe_mk.output_is_reduced() + self.quant_method.moe_kernel is not None + and self.quant_method.moe_kernel.output_is_reduced() ) def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): @@ -362,13 +380,15 @@ class DefaultMoERunner(MoERunner): if isinstance(states, tuple): return tuple( - [func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)] + [None if s is None else func(s, trunc_size) for s, trunc_size in zip(states, trunc_sizes)] ) else: assert len(trunc_sizes) == 1 return func(states, trunc_sizes[0]) - def _encode_layer_name(self) -> str: + def _encode_layer_name(self) -> str | ModuleName: + if HAS_OPAQUE_TYPE: + return ModuleName(self.layer_name) # Can be unavailable or None in unittests if ( is_forward_context_available() @@ -624,53 +644,27 @@ class DefaultMoERunner(MoERunner): ) with sp_ctx: - extra_tensors = None - if do_naive_dispatch_combine: - post_quant_allgather = ( - self.quant_method is not None - and self.moe_config.dp_size > 1 - and self.moe_config.use_ep - and getattr(self.quant_method, "do_post_quant_allgather", False) - ) - if post_quant_allgather: - hidden_states_to_dispatch, extra_tensors = ( - self.quant_method.prepare_dp_allgather_tensor( - layer, hidden_states, router_logits - ) - ) - else: - hidden_states_to_dispatch = hidden_states - - dispatch_res = get_ep_group().dispatch_router_logits( - hidden_states_to_dispatch, - router_logits, - self.moe_config.is_sequence_parallel, - extra_tensors=extra_tensors, - ) - if extra_tensors is not None: - ( - orig_hidden_states, - router_logits, - extra_tensors_combined, - ) = dispatch_res - hidden_states_combined = ( - orig_hidden_states, - extra_tensors_combined[0], - ) - else: - hidden_states_combined, router_logits = dispatch_res - orig_hidden_states = hidden_states_combined - else: - orig_hidden_states = hidden_states - # Run shared experts before matrix multiply. # because matrix multiply maybe modify the hidden_states. 
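# In this overlay the shared experts always run eagerly at this point; when
# fused_shared_output is set, their result is later handed to the quant method
# through shared_experts_input so the routed-expert kernel can fuse the add.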
- if has_separate_shared_experts and not use_shared_experts_stream: + if has_separate_shared_experts: # and not use_shared_experts_stream: assert self.shared_experts is not None shared_input = ( shared_input if shared_input is not None else hidden_states ) shared_output = self.shared_experts(shared_input) + else: + assert self.fused_shared_output == False, "fused_shared_output is only supported when has_separate_shared_experts is True" + shared_output = None + # For naive dispatch/combine Dp/Ep, dispatch the hidden states and + # router logits to all experts. + # NOTE: this will be removed once all kernels are migrated into the + # MoEKernel framework. + if do_naive_dispatch_combine: + hidden_states, router_logits = get_ep_group().dispatch_router_logits( + hidden_states, + router_logits, + self.moe_config.is_sequence_parallel, + ) # NOTE: Similar with DP, PCP also needs dispatch and combine. For # simplicity, AgRsAll2All was added separately for PCP here. Maybe @@ -685,42 +679,33 @@ class DefaultMoERunner(MoERunner): dim=0, ) - # TODO(bnell): deal with fp4 flashinfer tuple hidden states hack (#30014). - # Figure out nicer way to do this. - if do_naive_dispatch_combine: - x = hidden_states_combined - x_orig = orig_hidden_states - else: - x = hidden_states - x_orig = hidden_states - # Matrix multiply. if self.quant_method.is_monolithic: final_hidden_states = self.quant_method.apply_monolithic( layer=layer, - x=x, + x=hidden_states, router_logits=router_logits, ) else: topk_weights, topk_ids = self.router.select_experts( - hidden_states=x_orig, + hidden_states=hidden_states, router_logits=router_logits, ) final_hidden_states = self.quant_method.apply( layer=layer, - x=x, # The type signture of this is wrong due to the hack. + x=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, - shared_experts_input=shared_input, - router_logits=router_logits, - top_k=topk_ids.shape[-1] + # Assign the value of shared_experts_output to variable shared_experts_input for fusion + shared_experts_input=shared_output if self.fused_shared_output else None, ) if has_separate_shared_experts: assert self.shared_experts is not None if use_shared_experts_stream: + assert use_shared_experts_stream == False, "Running shared experts in parallel with the main MoE execution is currently not supported!" # Run shared experts in parallel on a separate stream # NOTE: We start the separate stream here and mark the # sync end point immediately after it is done. This is @@ -733,7 +718,7 @@ class DefaultMoERunner(MoERunner): current_stream().wait_stream(self.shared_experts_stream) final_hidden_states = ( - shared_output, + None if self.fused_shared_output else shared_output, final_hidden_states, ) diff --git a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py index 99d4038..4cebe60 100644 --- a/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py +++ b/vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py @@ -10,14 +10,15 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk class TopKWeightAndReduceDelegate(mk.TopKWeightAndReduce): """ - Useful in the case when some FusedMoEPermuteExpertsUnpermute + Useful in the case when some FusedMoEExpertsModular implementation does not perform weight application and reduction but cannot address the needs of all the compatible PrepareAndFinalize implementations. - For example, BatchedTritonExperts is compatible with both - PplxPrepareAndFinalize and BatchedPrepareAndFinalize. 
PplxPrepareAndFinalize - does the weight-application + reduction as part of the pplx combine kernel. - But the BatchedPrepareAndFinalize needs an implementation. To facilitate + For example, BatchedTritonExperts is compatible with both batched + PrepareAndFinalize implementations like DeepEPLLPrepareAndFinalize and + BatchedPrepareAndFinalize. Some PrepareAndFinalize implementations do + the weight-application + reduction as part of the combine kernel, while + BatchedPrepareAndFinalize needs an explicit implementation. To facilitate this case, the BatchedTritonExperts could use TopKWeightAndReduceDelegate so the PrepareAndFinalize implementations could choose how to weight + reduce. @@ -61,7 +62,7 @@ class TopKWeightAndReduceNoOP(mk.TopKWeightAndReduce): if output is None: return fused_expert_output - # MoEPrepareAndFinalizeNoEP needs the output to be in the `output` + # MoEPrepareAndFinalizeNoDPEPModular needs the output to be in the `output` # tensor. assert output.size() == fused_expert_output.size(), ( "output shape is expected to match the fused_expert_output shape. " diff --git a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py index 21a3d05..4aa396d 100644 --- a/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_cutlass_moe.py @@ -32,8 +32,8 @@ class TritonOrCutlassExperts(FallbackExperts): @staticmethod def get_clses() -> tuple[ - type[mk.FusedMoEPermuteExpertsUnpermute], - type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEExpertsModular], + type[mk.FusedMoEExpertsModular], ]: return (CutlassExpertsFp8, TritonExperts) @@ -77,7 +77,7 @@ class TritonOrCutlassExperts(FallbackExperts): hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: # Small batch fallback for sm100. 
if self.is_sm100 and hidden_states.shape[0] <= 8: return self.fallback_experts diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index a3f2f59..b601806 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -32,8 +32,8 @@ class TritonOrDeepGemmExperts(FallbackExperts): @staticmethod def get_clses() -> tuple[ - type[mk.FusedMoEPermuteExpertsUnpermute], - type[mk.FusedMoEPermuteExpertsUnpermute], + type[mk.FusedMoEExpertsModular], + type[mk.FusedMoEExpertsModular], ]: return (DeepGemmExperts, TritonExperts) @@ -79,7 +79,7 @@ class TritonOrDeepGemmExperts(FallbackExperts): hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: if is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2): return self.experts else: diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 2bd4cd7..5160840 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) -class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): +class TrtLlmGenExperts(mk.FusedMoEExpertsModular): """TensorRT-LLM-based fused MoE expert implementation.""" def __init__( diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index a293145..00d4124 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -24,8 +24,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( ) from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEActivationFormat, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, + FusedMoEExpertsModular, + FusedMoEPrepareAndFinalizeModular, ) from vllm.model_executor.layers.fused_moe.oracle.unquantized import ( UnquantizedMoeBackend, @@ -42,9 +42,9 @@ from vllm.platforms.interface import CpuArchEnum if current_platform.is_cuda_alike() or current_platform.is_xpu(): from .fused_batched_moe import BatchedTritonExperts - from .fused_moe import TritonExperts else: TritonExperts = None # type: ignore + fused_experts = None logger = init_logger(__name__) @@ -70,7 +70,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self.rocm_aiter_moe_enabled = ( rocm_aiter_ops.is_fused_moe_enabled() and moe.is_act_and_mul ) - self.kernel: mk.FusedMoEModularKernel | None = None + self.kernel: mk.FusedMoEKernel | None = None self._is_monolithic = ( current_platform.is_cpu() or self.unquantized_backend == UnquantizedMoeBackend.FLASHINFER_TRTLLM @@ -107,7 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> FusedMoEPrepareAndFinalize | None: + ) -> FusedMoEPrepareAndFinalizeModular | None: if self.unquantized_backend == UnquantizedMoeBackend.AITER: return None else: @@ -115,9 +115,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def select_gemm_impl( self, - prepare_finalize: FusedMoEPrepareAndFinalize, + prepare_finalize: 
FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: + ) -> FusedMoEExpertsModular: assert self.moe_quant_config is not None if ( prepare_finalize.activation_format @@ -296,16 +296,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + # Assign the value of shared_experts_output to variable shared_experts_input for fusion shared_experts_input: torch.Tensor | None, - **kwargs ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - return self.forward( + result = self.forward( layer=layer, x=x, topk_weights=topk_weights, topk_ids=topk_ids, + # not used shared_experts_input=shared_experts_input, - ) + ) * layer.routed_scaling_factor + if shared_experts_input is not None: + result += shared_experts_input + return result def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: if self.moe.has_bias: @@ -333,10 +337,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_weights=topk_weights, topk_ids=topk_ids, activation=layer.activation, + quant_config=self.moe_quant_config, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - shared_experts_input=shared_experts_input, + expert_map=layer.expert_map ) def forward_monolithic_cuda( diff --git a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py index e6f8b8e..0693a25 100644 --- a/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/xpu_fused_moe.py @@ -23,7 +23,7 @@ if current_platform.is_xpu(): from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe -class XPUExperts(mk.FusedMoEPermuteExpertsUnpermute): +class XPUExperts(mk.FusedMoEExpertsModular): def __init__( self, moe_config: FusedMoEConfig, diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index e55cd39..04e35d4 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -82,11 +82,12 @@ def fused_add_rms_norm( return rms_norm_batch_invariant( x + residual, weight, variance_epsilon ), x + residual - ops.fused_add_rms_norm( + x, residual = ops.fused_add_rms_norm( x, residual, weight, variance_epsilon, + residual_alpha, ) return x, residual @@ -125,7 +126,7 @@ def dispatch_rocm_rmsnorm_func( return fused_add_rms_norm return rms_norm - + def rms_norm_qk( input_q: torch.Tensor, input_k: torch.Tensor, @@ -140,11 +141,7 @@ def rms_norm_qk( output_q, output_k, input_q, input_k, weight_q, weight_k, epsilon ) return output_q, output_k - - -def dispatch_cuda_rmsnorm_qk_func() -> callable: - return rms_norm_qk - + @CustomOp.register("rms_norm_qk") class RMSNormQK(CustomOp): @@ -226,8 +223,7 @@ class RMSNormQK(CustomOp): f"[RMSNormQK] Expected input_q and input_k to have same dtype, " f"but got {input_q.dtype} vs {input_k.dtype}" ) - norm_func = dispatch_cuda_rmsnorm_qk_func() - return norm_func( + return rms_norm_qk( input_q, input_k, weight_q, @@ -264,7 +260,7 @@ class RMSNormQK(CustomOp): f"eps={self.variance_epsilon}, " ) - +# --8<-- [start:rms_norm] @CustomOp.register("rms_norm") class RMSNorm(CustomOp): """Root mean square normalization. 
@@ -375,7 +371,7 @@ class RMSNorm(CustomOp): # otherwise Inductor eliminates the casts to and from f16, # increasing memory usage (and complicating pattern matching) x = x + residual - residual = x.to(orig_dtype).contiguous() + residual = x.to(orig_dtype) if x.shape[-1] != hidden_size: raise ValueError( @@ -425,6 +421,7 @@ class RMSNorm(CustomOp): self, x: torch.Tensor, residual: torch.Tensor | None = None, + residual_alpha: float = 1.0, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if self.variance_size_override is not None: return self.forward_native(x, residual) @@ -499,7 +496,7 @@ class RMSNorm(CustomOp): add_residual = residual is not None if add_residual: return fused_add_rms_norm( - x, residual, self.weight.data, self.variance_epsilon + x, residual, self.weight.data, self.variance_epsilon,residual_alpha ) else: return rms_norm(x, self.weight.data, self.variance_epsilon) @@ -649,6 +646,7 @@ class RMSNormGated(CustomOp): norm_before_gate: bool = False, device: torch.device | None = None, dtype: torch.dtype | None = None, + activation: str = "swish", ): """Initialize RMSNormGated. @@ -663,10 +661,12 @@ class RMSNormGated(CustomOp): If False and z is provided: out = norm(x * silu(z)) device: Device to create parameters on dtype: Data type for parameters + activation: Activation function name for gating """ factory_kwargs = {"device": device, "dtype": dtype} super().__init__() self.eps = eps + self.activation = activation self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) self.register_parameter("bias", None) self.group_size = group_size @@ -693,6 +693,11 @@ class RMSNormGated(CustomOp): - norm_before_gate=True: out = norm(x) * silu(z) - norm_before_gate=False: out = norm(x * silu(z)) """ + orig_dtype = x.dtype + x = x.float() + weight = self.weight.float() + z = z.float() if z is not None else None + # Apply gating before normalization if needed if z is not None and not self.norm_before_gate: x = x * F.silu(z) @@ -702,7 +707,7 @@ class RMSNormGated(CustomOp): # Standard RMS norm across the last dimension variance = x.pow(2).mean(dim=-1, keepdim=True) x_normed = x * torch.rsqrt(variance + self.eps) - out = x_normed * self.weight + out = x_normed * weight else: # Group RMS norm from einops import rearrange @@ -710,13 +715,13 @@ class RMSNormGated(CustomOp): x_group = rearrange(x, "... (g d) -> ... g d", d=self.group_size) variance = x_group.pow(2).mean(dim=-1, keepdim=True) x_normed = x_group * torch.rsqrt(variance + self.eps) - out = rearrange(x_normed, "... g d -> ... (g d)") * self.weight + out = rearrange(x_normed, "... g d -> ... 
(g d)") * weight # Apply gating after normalization if needed if z is not None and self.norm_before_gate: out = out * F.silu(z) - return out + return out.to(orig_dtype) def forward_cuda( self, x: torch.Tensor, z: torch.Tensor | None = None @@ -731,6 +736,7 @@ class RMSNormGated(CustomOp): eps=self.eps, group_size=self.group_size, norm_before_gate=self.norm_before_gate, + activation=self.activation, ) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index a3dcdf8..5d0bfe0 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +import ast, re from abc import abstractmethod -from typing import Any import torch from torch.nn.parameter import Parameter, UninitializedParameter @@ -16,6 +16,7 @@ from vllm.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) +import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.batch_invariant import ( @@ -28,7 +29,9 @@ from vllm.model_executor.layers.quantization.base_config import ( ) from vllm.model_executor.layers.utils import ( dispatch_unquantized_gemm, - is_layer_moe_router_gate, + parse_opt_exclude_layers, + weight_quant_l1, + weight_quant_l2, ) from vllm.model_executor.parameter import ( BasevLLMParameter, @@ -41,12 +44,11 @@ from vllm.model_executor.parameter import ( ) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -import vllm.envs as envs +from compressed_tensors.quantization import QuantizationStrategy logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ - "UnquantizedLinearMethod", "CompressedTensorsLinearMethod", "CompressedTensorsLinearTransformMethod", "AWQMarlinLinearMethod", @@ -66,6 +68,14 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "PetitNvFp4LinearMethod", ] +LINEAR_OPT_SUPPORTED = [ + "ColumnParallelLinear", + "ReplicatedLinear", + "RowParallelLinear", + "QKVParallelLinear", + "MergedColumnParallelLinear", +] + def adjust_marlin_shard( param: Parameter, @@ -135,44 +145,6 @@ def adjust_scalar_to_fused_array( return param_data[shard_id], loaded_weight -# TODO(Isotr0py): We might need a more flexible structure to handle -# bitsandbytes shard offsets. -def left_shift_bitsandbytes_4bit_shard( - bnb_weight_attrs: dict[str, Any], -) -> tuple[dict[str, Any], dict[str, Any]]: - """ - Separate the BitsAndBytes 4-bit shard. 
- - For example, given bnb weight attributes as below: - { - 'bnb_shard_offsets': array([0, 4, 8, 16]), - 'bnb_quant_state': {0: ..., 1: ..., 2: ...}, - } - - The function will return: - { - 'bnb_shard_offsets': array([0, 4]), - 'bnb_quant_state': {0: ...}, - } - and - { - 'bnb_shard_offsets': array([0, 4, 12]), - 'bnb_quant_state': {0: ..., 1: ...}, - } - """ - shard_offsets = bnb_weight_attrs["bnb_shard_offsets"] - offset_l = shard_offsets[:2] - offset_r = shard_offsets[1:] - shard_offsets[1] - quant_state_l = {0: bnb_weight_attrs["bnb_quant_state"][0]} - quant_state_r = { - i - 1: bnb_weight_attrs["bnb_quant_state"][i] - for i in range(1, len(shard_offsets) - 1) - } - left = dict(bnb_shard_offsets=offset_l, bnb_quant_state=quant_state_l) - right = dict(bnb_shard_offsets=offset_r, bnb_quant_state=quant_state_r) - return left, right - - class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @@ -231,17 +203,11 @@ class UnquantizedLinearMethod(LinearMethodBase): # The weights are not quantized, and they are not sharded. # The amount of memory allocated for the weights is # sum(output_partition_sizes) * input_size_per_partition. - weight_loader = extra_weight_attrs.pop("weight_loader") - weight = ModelWeightParameter( - data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) @@ -258,11 +224,7 @@ class UnquantizedLinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - if ( - vllm_is_batch_invariant() - and current_platform.is_cuda_alike() - and is_layer_moe_router_gate(getattr(layer, "prefix", "")) - ): + if vllm_is_batch_invariant() and current_platform.is_cuda_alike(): return linear_batch_invariant(x, layer.weight, bias) return dispatch_unquantized_gemm()(layer, x, layer.weight, bias) @@ -305,15 +267,31 @@ class LinearBase(PluggableLayer): self.quant_config = quant_config self.prefix = prefix self.allow_fp8_block_shape_mismatch = False - if quant_config is None: + self.opt_level = envs.VLLM_LINEAR_OPT_LEVEL + if parse_opt_exclude_layers(envs.VLLM_LINEAR_SPECIFIED_LAYERS, self.prefix) or \ + (envs.VLLM_LINEAR_SPECIFIED_KEYS != "" and envs.VLLM_LINEAR_SPECIFIED_KEYS in self.prefix): + self.opt_level = envs.VLLM_LINEAR_SPECIFIED_OPT_LEVEL + self.opt_flag = quant_config is None and self.opt_level != 0 and \ + self.__class__.__name__ in LINEAR_OPT_SUPPORTED + + if parse_opt_exclude_layers(envs.VLLM_OPT_EXCLUDE_LAYERS, self.prefix): + self.opt_flag = False + logger.info(f"Excluding layer {self.prefix} from optimization") + + if self.opt_flag: + from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsLinearMethod + from vllm.model_executor.layers.quantization.compressed_tensors.schemes import CompressedTensorsW8A8Int8 + self.quant_method: QuantizeMethodBase | None = CompressedTensorsLinearMethod(None) + self.scheme = CompressedTensorsW8A8Int8(QuantizationStrategy.CHANNEL, False, True, is_w4a8_linear=True if self.opt_level == 2 else False) + elif quant_config is None: self.quant_method: QuantizeMethodBase | None = UnquantizedLinearMethod() else: 
self.quant_method = quant_config.get_quant_method(self, prefix=prefix) self.return_bias = return_bias + self.output_padding_size = 0 self.disable_tp = disable_tp self.tp_rank = get_tensor_model_parallel_rank() if not disable_tp else 0 self.tp_size = get_tensor_model_parallel_world_size() if not disable_tp else 1 - self.output_padding_size = 0 def update_param_tp_status(self): for param in self.parameters(): @@ -402,7 +380,7 @@ class ReplicatedLinear(LinearBase): # If the weight on disk does not have a shape, give it one # (such scales for AutoFp8). # Special case for GGUF - + is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) if is_gguf_weight_type: @@ -419,7 +397,17 @@ class ReplicatedLinear(LinearBase): f"Tried to load weights of size {loaded_weight.size()}" f"to a parameter of size {param.size()}" ) + if self.opt_flag: + if self.opt_level == 1: + loaded_weight, scale = weight_quant_l1(loaded_weight) + elif self.opt_level == 2: + loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN") + param.data.copy_(loaded_weight) + if self.opt_flag: + params_dict = dict(self.named_parameters()) + scale_param = params_dict["weight_scale"] + scale_param.data.copy_(scale) def forward( self, @@ -609,7 +597,18 @@ class ColumnParallelLinear(LinearBase): if len(loaded_weight.shape) == 0: assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) + + if self.opt_flag: + if self.opt_level == 1: + loaded_weight, scale = weight_quant_l1(loaded_weight) + elif self.opt_level == 2: + loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN") + param.load_column_parallel_weight(loaded_weight=loaded_weight) + if self.opt_flag: + params_dict = dict(self.named_parameters()) + scale_param = params_dict["weight_scale"] + scale_param.load_column_parallel_weight(loaded_weight=scale) def forward( self, @@ -733,16 +732,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear): loaded_shard_id: tuple[int, ...] | int | None = None, ): self.validate_shard_id(loaded_shard_id) - # FIXME(Isotr0py): Enable tuple shard_id for BNB quantization. - if isinstance(loaded_shard_id, tuple): - raise NotImplementedError( - "Shard id with multiple indices is not supported in weight_loader, " - "please use weight_loader_v2 instead." - ) # Special case for GGUF # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if isinstance(loaded_shard_id, tuple) and ( + is_gguf_weight or is_gguf_weight_type + ): + raise NotImplementedError( + "Shard id with multiple indices is not supported for GGUF." + ) if is_gguf_weight_type: if loaded_shard_id is not None: param.data[loaded_shard_id].copy_(loaded_weight) @@ -770,7 +769,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): # Special case for per-tensor scale to load scalar into fused array. needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) - if loaded_shard_id is None: + if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): # Loaded weight is already fused on disk (mlp). # (e.g., Phi-3's gate_up_proj). 
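# A tuple shard id now also takes this fused-load path: the slice of
# self.output_sizes covered by the tuple (computed below) drives the per-shard
# offsets instead of assuming every sub-shard is present.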
if output_dim is None: @@ -782,10 +781,25 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) return + + output_sizes = ( + self.output_sizes[loaded_shard_id[0] : loaded_shard_id[-1] + 1] + if loaded_shard_id is not None + else self.output_sizes + ) current_shard_offset = 0 use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + if ( + use_bitsandbytes_4bit + and isinstance(loaded_shard_id, tuple) + and self.tp_size > 1 + ): + raise NotImplementedError( + "Shard id with multiple indices is not supported " + "for BNB quantization with TP yet." + ) shard_offsets: list[tuple[int, int, int]] = [] - for i, output_size in enumerate(self.output_sizes): + for i, output_size in enumerate(output_sizes): shard_offsets.append((i, current_shard_offset, output_size)) current_shard_offset += output_size packed_dim = getattr(param, "packed_dim", None) @@ -850,9 +864,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear): is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit if use_bitsandbytes_4bit: - shard_size = loaded_weight.shape[output_dim] - shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id - + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(loaded_shard_id) + ) param_data = param_data.narrow(output_dim, shard_offset, shard_size) start_idx = self.tp_rank * shard_size if not is_sharded_weight: @@ -921,12 +940,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): loaded_weight: torch.Tensor, loaded_shard_id: tuple[int, ...] 
| int | None = None, ): + if self.opt_flag: + if self.opt_level == 1: + loaded_weight, scale = weight_quant_l1(loaded_weight) + elif self.opt_level == 2: + loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN") self.validate_shard_id(loaded_shard_id) - dtype = loaded_weight.dtype - if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8: - load_sizes = [self.output_sizes[i] // 2 for i in range(len(self.output_sizes))] - else: - load_sizes = self.output_sizes if loaded_shard_id is None or isinstance(loaded_shard_id, tuple): if isinstance(param, PerTensorScaleParameter): param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) @@ -953,19 +972,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear): assert loaded_shard_id < len(self.output_sizes) - # shard_offset = sum(self.output_sizes[:loaded_shard_id]) - # shard_size = self.output_sizes[loaded_shard_id] - shard_offset = sum(load_sizes[:loaded_shard_id]) - shard_size = load_sizes[loaded_shard_id] + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] shard_offset //= self.tp_size shard_size //= self.tp_size + scale_shard_offset = shard_offset + scale_shard_size = shard_size + if self.opt_flag and self.opt_level == 2: + shard_offset = shard_offset // 2 + shard_size = shard_size // 2 if isinstance(param, BlockQuantScaleParameter): weight_block_size = getattr(self, "weight_block_size", None) shard_size, shard_offset = adjust_block_scale_shard( weight_block_size, shard_size, shard_offset ) - param.load_merged_column_weight( loaded_weight=loaded_weight, shard_id=loaded_shard_id, @@ -973,6 +994,16 @@ class MergedColumnParallelLinear(ColumnParallelLinear): shard_size=shard_size, tp_rank=self.tp_rank, ) + if self.opt_flag: + params_dict = dict(self.named_parameters()) + scale_param = params_dict["weight_scale"] + scale_param.load_merged_column_weight( + loaded_weight=scale, + shard_id=loaded_shard_id, + shard_offset=scale_shard_offset, + shard_size=scale_shard_size, + tp_rank=self.tp_rank, + ) class QKVParallelLinear(ColumnParallelLinear): @@ -1128,12 +1159,24 @@ class QKVParallelLinear(ColumnParallelLinear): shard_size, shard_offset = param.adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset ) - - loaded_weight_shard = loaded_weight.narrow( - param.output_dim, shard_offset, shard_size - ) + if self.opt_level == 2: + loaded_weight_shard = loaded_weight.narrow( + 0, shard_offset, shard_size + ) + else: + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) + def quant(self, loaded_weight: torch.Tensor): + if self.opt_flag: + if self.opt_level == 1: + return weight_quant_l1(loaded_weight) + elif self.opt_level == 2: + return weight_quant_l2(loaded_weight, format="NN") + return loaded_weight, None + def weight_loader_v2( self, param: BasevLLMParameter, @@ -1141,15 +1184,27 @@ class QKVParallelLinear(ColumnParallelLinear): loaded_shard_id: str | None = None, ): self.validate_shard_id(loaded_shard_id) + params_dict = dict(self.named_parameters()) if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): + loaded_weight, scale = self.quant(loaded_weight) param.load_qkv_weight( loaded_weight=loaded_weight, shard_id=0, tp_rank=self.tp_rank ) + if self.opt_flag: + scale_param = params_dict["weight_scale"] + scale_param.load_qkv_weight( + 
loaded_weight=scale, shard_id=0, tp_rank=self.tp_rank + ) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): + loaded_weight, scale = self.quant(loaded_weight) param.load_qkv_weight(loaded_weight=loaded_weight, tp_rank=self.tp_rank) + if self.opt_flag: + scale_param = params_dict["weight_scale"] + scale_param.load_qkv_weight(loaded_weight=scale, tp_rank=self.tp_rank) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -1158,11 +1213,15 @@ class QKVParallelLinear(ColumnParallelLinear): shard_offset = self._get_shard_offset_mapping(loaded_shard_id) shard_size = self._get_shard_size_mapping(loaded_shard_id) - dtype = loaded_weight.dtype - # w4a8 gemm需要除2,scale 不需要 - if envs.VLLM_W8A8_LINEAR_USE_W4A8 and not (param.shape[0] == 1 or param.shape[1] == 1) and dtype == torch.int8: - shard_offset //= 2 - shard_size //= 2 + + scale_shard_offset = shard_offset + scale_shard_size = shard_size + + loaded_weight, scale = self.quant(loaded_weight) + + if self.opt_flag and self.opt_level == 2: + shard_offset = shard_offset // 2 + shard_size = shard_size // 2 if isinstance(param, BlockQuantScaleParameter): weight_block_size = getattr(self, "weight_block_size", None) @@ -1179,6 +1238,15 @@ class QKVParallelLinear(ColumnParallelLinear): tp_rank=self.tp_rank, ) + if self.opt_flag: + scale_param = params_dict["weight_scale"] + scale_param.load_qkv_weight(loaded_weight=scale, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=scale_shard_offset, + shard_size=scale_shard_size, + tp_rank=self.tp_rank) + def weight_loader( self, param: Parameter, @@ -1525,7 +1593,17 @@ class RowParallelLinear(LinearBase): assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) + if self.opt_flag: + if self.opt_level == 1: + loaded_weight, scale = weight_quant_l1(loaded_weight) + elif self.opt_level == 2: + loaded_weight, scale = weight_quant_l2(loaded_weight, format="NN") + param.load_row_parallel_weight(loaded_weight=loaded_weight) + if self.opt_flag: + params_dict = dict(self.named_parameters()) + scale_param = params_dict["weight_scale"] + scale_param.load_row_parallel_weight(loaded_weight=scale) def forward( self, diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index dd2a61b..2e6fd19 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -61,7 +61,10 @@ class LogitsProcessor(CustomOp): logits = hidden_states else: # Get the logits for the next tokens. 
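# Guard the zero-token case: when the batch has no rows, skip the projection
# and return an empty [0, vocab_size] tensor with matching device and dtype.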
- logits = self._get_logits(hidden_states, lm_head, embedding_bias) + if hidden_states.shape[0] > 0: + logits = self._get_logits(hidden_states, lm_head, embedding_bias) + else: + logits = torch.empty([0, lm_head.weight.shape[0]], device=hidden_states.device, dtype=hidden_states.dtype) if logits is not None: if self.soft_cap is not None: logits = logits / self.soft_cap diff --git a/vllm/model_executor/layers/mamba/linear_attn.py b/vllm/model_executor/layers/mamba/linear_attn.py index 8b5f80f..8021418 100644 --- a/vllm/model_executor/layers/mamba/linear_attn.py +++ b/vllm/model_executor/layers/mamba/linear_attn.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math +from collections.abc import Callable import torch import torch.nn.functional as F @@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp): self.weight.weight_loader = self.weight_loader self.variance_epsilon = eps - return @staticmethod def weight_loader( @@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp): shard_size = loaded_weight.shape[0] // tp_world shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) param.data.copy_(loaded_weight[shard]) - return def _forward( self, @@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp): return q, k +def clear_linear_attention_cache_for_new_sequences( + kv_cache: torch.Tensor, + state_indices_tensor: torch.Tensor, + attn_metadata: LinearAttentionMetadata, +) -> None: + num_prefills = getattr(attn_metadata, "num_prefills", 0) + if num_prefills <= 0: + return + + num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0) + for prefill_idx in range(num_prefills): + q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx] + q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1] + query_len = q_end - q_start + context_len = ( + attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len + ) + if context_len == 0: + block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx] + kv_cache[block_to_clear, ...] 
= 0 + + +def linear_attention_decode( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_cache: torch.Tensor, + slope_rate: torch.Tensor, + state_indices_tensor: torch.Tensor, + q_start: int = 0, + q_end: int | None = None, + slot_start: int = 0, + slot_end: int | None = None, + block_size: int = 32, +) -> torch.Tensor: + q = q[q_start:q_end].unsqueeze(2).contiguous() + k = k[q_start:q_end].unsqueeze(2).contiguous() + v = v[q_start:q_end].unsqueeze(2).contiguous() + slot_id = state_indices_tensor[slot_start:slot_end] + return linear_decode_forward_triton( + q, k, v, kv_cache, slope_rate, slot_id, block_size + ) + + +def linear_attention_prefill_and_mix( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kv_cache: torch.Tensor, + state_indices_tensor: torch.Tensor, + attn_metadata: LinearAttentionMetadata, + slope_rate: torch.Tensor, + block_size: int, + decode_fn: Callable[..., torch.Tensor], + prefix_fn: Callable[..., torch.Tensor], + layer_idx: int | None = None, +) -> torch.Tensor: + hidden = [] + for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): + if _prefill_idx >= len(attn_metadata.query_start_loc): + break + if _prefill_idx >= len(state_indices_tensor): + break + offset = attn_metadata.num_decode_tokens + _start = attn_metadata.query_start_loc[offset + _prefill_idx] + _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] + slot_id = state_indices_tensor[offset + _prefill_idx] + qs = q[_start:_end].transpose(0, 1).contiguous() + ks = k[_start:_end].transpose(0, 1).contiguous() + vs = v[_start:_end].transpose(0, 1).contiguous() + slice_layer_cache = kv_cache[slot_id, ...] + out_slice = prefix_fn( + qs, + ks, + vs, + slice_layer_cache, + slope_rate, + block_size, + layer_idx=layer_idx, + ) + hidden.append(out_slice.contiguous()) + + if attn_metadata.num_decode_tokens > 0: + hidden_decode = decode_fn( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + hidden.insert(0, hidden_decode) + + if not hidden: + return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) + + hidden = torch.concat(hidden, dim=0).contiguous() + return hidden + + class MiniMaxText01LinearKernel: @staticmethod def jit_linear_forward_prefix( @@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): def _prefill_and_mix_infer( self, q, k, v, kv_cache, state_indices_tensor, attn_metadata ): - hidden = [] - for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)): - if _prefill_idx >= len(attn_metadata.query_start_loc): - break - if _prefill_idx >= len(state_indices_tensor): - break - offset = attn_metadata.num_decode_tokens - _start = attn_metadata.query_start_loc[offset + _prefill_idx] - _end = attn_metadata.query_start_loc[offset + _prefill_idx + 1] - slot_id = state_indices_tensor[offset + _prefill_idx] - qs = q[_start:_end].transpose(0, 1).contiguous() - ks = k[_start:_end].transpose(0, 1).contiguous() - vs = v[_start:_end].transpose(0, 1).contiguous() - slice_layer_cache = kv_cache[slot_id, ...] 
- - out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix( - qs, - ks, - vs, - slice_layer_cache, - self.tp_slope, - self.BLOCK, - layer_idx=self.layer_idx, - ) - hidden.append(out_slice.contiguous()) - if attn_metadata.num_decode_tokens > 0: - hidden_decode = self._decode_infer( - q, k, v, kv_cache, state_indices_tensor, attn_metadata - ) - hidden.insert(0, hidden_decode) - - if not hidden: - return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype) - - hidden = torch.concat(hidden, dim=0).contiguous() - return hidden + return linear_attention_prefill_and_mix( + q=q, + k=k, + v=v, + kv_cache=kv_cache, + state_indices_tensor=state_indices_tensor, + attn_metadata=attn_metadata, + slope_rate=self.tp_slope, + block_size=self.BLOCK, + decode_fn=self._decode_infer, + prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix, + layer_idx=self.layer_idx, + ) def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): - q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous() - slot_id = state_indices_tensor[: attn_metadata.num_decodes] - hidden = linear_decode_forward_triton( - q, k, v, kv_cache, self.tp_slope, slot_id, 32 + hidden = linear_attention_decode( + q, + k, + v, + kv_cache, + self.tp_slope, + state_indices_tensor, + q_start=0, + q_end=attn_metadata.num_decode_tokens, + slot_start=0, + slot_end=attn_metadata.num_decodes, + block_size=32, ) return hidden @@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase): if attn_metadata is not None: kv_cache = self.kv_cache[forward_context.virtual_engine][0] state_indices_tensor = attn_metadata.state_indices_tensor - - num_prefills = getattr(attn_metadata, "num_prefills", 0) - if num_prefills > 0: - num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0) - for prefill_idx in range(num_prefills): - q_start = attn_metadata.query_start_loc[ - num_decode_tokens + prefill_idx - ] - q_end = attn_metadata.query_start_loc[ - num_decode_tokens + prefill_idx + 1 - ] - query_len = q_end - q_start - context_len = ( - attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - - query_len - ) - if context_len == 0: - block_to_clear = state_indices_tensor[ - num_decode_tokens + prefill_idx - ] - kv_cache[block_to_clear, ...] = 0 + clear_linear_attention_cache_for_new_sequences( + kv_cache, state_indices_tensor, attn_metadata + ) decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 if attn_metadata is None: diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 24e189a..6a33fc7 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -271,6 +271,8 @@ class MambaMixer(MambaBase, PluggableLayer): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] has_initial_states_p = attn_metadata.has_initial_states_p + cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p + last_chunk_indices_p = attn_metadata.last_chunk_indices_p # 1. 
Gated MLP's linear projection projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) @@ -376,6 +378,8 @@ class MambaMixer(MambaBase, PluggableLayer): block_idx_first_scheduled_token=block_idx_first_scheduled_token_p, block_idx_last_scheduled_token=block_idx_last_scheduled_token_p, initial_state_idx=block_idx_last_computed_token_p, + cu_chunk_seqlen=cu_chunk_seqlen_p, + last_chunk_indices=last_chunk_indices_p, ) ssm_outputs.append(scan_out_p) diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py index fc8912f..1f6751f 100644 --- a/vllm/model_executor/layers/mamba/mamba_utils.py +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -289,9 +289,6 @@ def get_temporal_copy_spec( ) -get_full_copy_spec = get_temporal_copy_spec - - class MambaStateCopyFuncCalculator: @classmethod def linear_attention_state_copy_func(cls): diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 507a0a9..b0c1ffb 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -1159,7 +1159,7 @@ def causal_conv1d_update( f"ERROR: conv_state_indices should have shape ({batch},*) but got {conv_state_indices.shape}" ) - # assert num_cache_lines >= batch + assert num_cache_lines >= batch assert weight.stride(1) == 1 # Need this # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index a0df65f..44e73dd 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -497,6 +497,8 @@ def selective_scan_fn( block_idx_first_scheduled_token=None, block_idx_last_scheduled_token=None, initial_state_idx=None, + cu_chunk_seqlen=None, + last_chunk_indices=None, ) -> torch.Tensor: """ u: (dim, total_length) for varlen or (batch, dim, seqlen) @@ -588,6 +590,8 @@ def selective_scan_fn( block_idx_first_scheduled_token, block_idx_last_scheduled_token, initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, ) if z is None: diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index e289727..ca09b23 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -9,7 +9,6 @@ from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention import MLAAttention from vllm.model_executor.layers.quantization import QuantizationConfig - @dataclass class MLAModules: """Modules used in MLA.""" @@ -18,7 +17,7 @@ class MLAModules: kv_b_proj: torch.nn.Module rotary_emb: torch.nn.Module o_proj: torch.nn.Module - fused_qkv_a_proj: torch.nn.Module | None + q_a_proj: torch.nn.Module | None kv_a_proj_with_mqa: torch.nn.Module | None q_a_layernorm: torch.nn.Module | None q_b_proj: torch.nn.Module | None @@ -74,7 +73,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.num_heads = num_heads - self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj + self.q_a_proj = mla_modules.q_a_proj self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa self.q_a_layernorm = mla_modules.q_a_layernorm self.q_b_proj = mla_modules.q_b_proj @@ -106,7 +105,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): kv_b_proj=self.kv_b_proj, use_sparse=self.is_sparse, indexer=self.indexer, 
- rotary_emb=self.rotary_emb + rotary_emb=self.rotary_emb, ) self.prefix = prefix @@ -119,60 +118,47 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ) -> torch.Tensor: q_c = None kv_lora = None - if self.q_lora_rank is not None: - assert self.fused_qkv_a_proj is not None, ( - "fused_qkv_a_proj is required when q_lora_rank is not None" - ) - assert self.q_a_layernorm is not None, ( - "q_a_layernorm is required when q_lora_rank is not None" - ) - assert self.q_b_proj is not None, ( - "q_b_proj is required when q_lora_rank is not None" - ) - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_lora = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - q_c = self.q_a_layernorm(q_c) - q = self.q_b_proj(q_c)[0] + q = self.q_a_proj(hidden_states)[0] + kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_heads, self.qk_head_dim) + kv_a = self.kv_a_layernorm(kv_a) else: - assert self.kv_a_proj_with_mqa is not None, ( - "kv_a_proj_with_mqa is required when q_lora_rank is None" - ) - assert self.q_proj is not None, ( - "q_proj is required when q_lora_rank is None" - ) - kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0] - q = self.q_proj(hidden_states)[0] - - kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - kv_c_normed = self.kv_a_layernorm(kv_c) - - q = q.view(-1, self.num_heads, self.qk_head_dim) - # Add head dim of 1 to k_pe - # k_pe = k_pe.unsqueeze(1) - - # if self.rotary_emb is not None: - # q[..., self.qk_nope_head_dim :], k_pe = self.rotary_emb( - # positions, q[..., self.qk_nope_head_dim :], k_pe - # ) - - if self.indexer and self.is_sparse: - _topk_indices = self.indexer( - hidden_states, q_c, positions, self.indexer_rope_emb - ) - + q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim) + latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + kv_a = self.kv_a_layernorm(kv_a) + + # NOTE attention data do not have position, pass it here if llama_4_scaling is not None: q *= llama_4_scaling - - self.mla_attn.impl.forward_prepare(positions) - attn_out = self.mla_attn( - q, - kv_c_normed, - k_pe, - output_shape=(hidden_states.shape[0], self.num_heads * self.v_head_dim), - ) - + attn_out = self.mla_attn(q, kv_a, k_pe, positions) return self.o_proj(attn_out)[0] + + def forward_opt( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None = None): + if self.q_lora_rank is not None: + q_latent_kpe = self.q_a_proj(hidden_states)[0] + q, kv_a, k_pe, _ = q_latent_kpe.split([self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim, self.q_a_proj.output_padding_size], dim=1) + q_c = self.q_a_layernorm(q) + q = self.q_b_proj(q_c)[0].view(-1, self.num_heads, self.qk_head_dim) + kv_a = self.kv_a_layernorm(kv_a) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_heads, self.qk_head_dim) + latent_kpe = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, k_pe = latent_kpe.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + kv_a = self.kv_a_layernorm(kv_a) + if self.indexer and self.is_sparse: + _topk_indices = self.indexer(hidden_states, q_c, positions, + self.rotary_emb) + + # NOTE attention data do not have position, pass it here + if llama_4_scaling is not None: + q *= llama_4_scaling + attn_out = self.mla_attn(q, 
kv_a, k_pe, positions) + return self.o_proj(attn_out)[0] + diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 09e67f5..8794647 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -18,6 +18,7 @@ QuantizationMethods = Literal[ "modelopt", "modelopt_fp4", "modelopt_mxfp8", + "modelopt_mixed", "gguf", "gptq_marlin", "awq_marlin", @@ -32,6 +33,7 @@ QuantizationMethods = Literal[ "mxfp4", "petit_nvfp4", "cpu_awq", + "w8a16" ] QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) @@ -120,12 +122,18 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .gptq import GPTQConfig from .gptq_marlin import GPTQMarlinConfig from .inc import INCConfig - from .modelopt import ModelOptFp8Config, ModelOptMxFp8Config, ModelOptNvFp4Config + from .modelopt import ( + ModelOptFp8Config, + ModelOptMixedPrecisionConfig, + ModelOptMxFp8Config, + ModelOptNvFp4Config, + ) from .moe_wna16 import MoeWNA16Config from .mxfp4 import Mxfp4Config from .petit import PetitNvFp4Config from .ptpc_fp8 import PTPCFp8Config from .torchao import TorchAOConfig + from .w8a16 import W8a16Config method_to_config: dict[str, type[QuantizationConfig]] = { "awq": AWQConfig, @@ -135,6 +143,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "modelopt": ModelOptFp8Config, "modelopt_fp4": ModelOptNvFp4Config, "modelopt_mxfp8": ModelOptMxFp8Config, + "modelopt_mixed": ModelOptMixedPrecisionConfig, "gguf": GGUFConfig, "gptq_marlin": GPTQMarlinConfig, "awq_marlin": AWQMarlinConfig, @@ -151,6 +160,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: "mxfp4": Mxfp4Config, "petit_nvfp4": PetitNvFp4Config, "cpu_awq": CPUAWQConfig, + "w8a16": W8a16Config, } # Update the `method_to_config` with customized quantization methods. 
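# This overlay adds two entries above: "modelopt_mixed" -> ModelOptMixedPrecisionConfig
# and "w8a16" -> W8a16Config.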
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index d80a9ba..791ad0c 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE @@ -9,6 +9,7 @@ from torch.nn import Parameter import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops +from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -60,6 +61,7 @@ from vllm.transformers_utils.config import get_safetensors_params_metadata if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.models.utils import WeightsMapper +import ixformer.inference.functions as ixfops logger = init_logger(__name__) @@ -197,7 +199,7 @@ class AWQMarlinConfig(QuantizationConfig): quant_method.input_dtype = get_marlin_input_dtype(prefix) return quant_method elif isinstance(layer, FusedMoE): - from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config + # from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config # if is_layer_skipped( # prefix, @@ -213,9 +215,10 @@ class AWQMarlinConfig(QuantizationConfig): # return MoeWNA16Config.from_config(self.full_config).get_quant_method( # layer, prefix # ) - moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config) - moe_quant_method.input_dtype = get_marlin_input_dtype(prefix) - return moe_quant_method + # moe_quant_method = AWQMarlinMoEMethod(self, layer.moe_config) + # moe_quant_method.input_dtype = get_marlin_input_dtype(prefix) + # return moe_quant_method + return AWQMarlinMoEMethod(self, layer.moe_config) return None @classmethod @@ -389,13 +392,13 @@ class AWQMarlinLinearMethod(LinearMethodBase): replace_parameter(layer, "qweight", pad_qweight) replace_parameter(layer, "qzeros", pad_qzeros) replace_parameter(layer, "scales", pad_scales) - return + # TODO(gyf): Marlin format is not supported for now.
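# The quantized tensors are re-registered as plain Parameters below and the
# method returns early, skipping the Marlin repack and workspace allocation.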
device = layer.qweight.device layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) - + return # Allocate marlin workspace layer.workspace = marlin_make_workspace_new(device) @@ -811,49 +814,33 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): self, layer: FusedMoE, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + # Assign the value of shared_experts_output to variable shared_experts_input for fusion + shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # return fused_marlin_moe( - # x, - # layer.w13_qweight, - # layer.w2_qweight, - # getattr(layer, "w13_bias", None), - # getattr(layer, "w2_bias", None), - # layer.w13_scales, - # layer.w2_scales, - # topk_weights, - # topk_ids, - # input_global_scale1=getattr(layer, "w13_input_global_scale", None), - # input_global_scale2=getattr(layer, "w2_input_global_scale", None), - # quant_type_id=self.quant_type.id, - # apply_router_weight_on_input=layer.apply_router_weight_on_input, - # global_num_experts=layer.global_num_experts, - # expert_map=layer.expert_map, - # w1_zeros=layer.w13_qzeros, - # w2_zeros=layer.w2_qzeros, - # workspace=layer.workspace, - # input_dtype=self.input_dtype, - # inplace=not self.moe.disable_inplace, - # ) - num_tokens, num_experts = router_logits.shape + assert layer.activation.value == "silu", "Only SiLU activation is supported." 
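# Editor's note: the decode-only gating that follows is duplicated across the MoE
# methods in this patch; a standalone sketch of the check, assuming each
# attention-metadata entry exposes num_prefills / num_decodes as in the diff.
def is_decode_only(attn_metadata, use_ep: bool) -> bool:
    # Expert parallelism currently takes the general (non-gemv) path.
    if use_ep or not attn_metadata:
        return False
    # attn_metadata may be a per-layer dict or a single metadata object.
    entries = attn_metadata.values() if isinstance(attn_metadata, dict) else [attn_metadata]
    return all(m.num_decodes > 0 and m.num_prefills == 0 for m in entries)

# Example with a lightweight stand-in for the metadata object.
from types import SimpleNamespace
meta = SimpleNamespace(num_decodes=4, num_prefills=0)
assert is_decode_only({"layers.0.attn": meta}, use_ep=False)
assert not is_decode_only({"layers.0.attn": meta}, use_ep=True)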
+ use_ep = layer.expert_map is not None + attn_metadata = get_forward_context().attn_metadata + if attn_metadata: + if isinstance(attn_metadata, dict): + only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values()))) + else: + only_decode = use_ep == False and attn_metadata.num_decodes > 0 and attn_metadata.num_prefills == 0 + else: + only_decode = False + + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts) + if layer.apply_router_weight_on_input: + raise NotImplementedError( + "Apply router weight on input is not supported for" + "fused Marlin MoE method.") + + num_tokens = topk_ids.shape[0] + num_experts = layer.global_num_experts if use_ep: hidden_size = x.shape[1] ( @@ -875,7 +862,7 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): dtype=x.dtype, ) else: - expand_tokens = num_tokens * top_k + expand_tokens = num_tokens * layer.top_k ( src_to_dst, sorted_token_ids, @@ -885,7 +872,6 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, num_experts=num_experts, ) - expert_sizes_cpu = expert_sizes_gpu.cpu() # expand + reorder # TODO use kernel @@ -893,76 +879,130 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): hidden_states=x, dst_to_src=sorted_token_ids, dst_tokens=expand_tokens, - topk=top_k, + topk=layer.top_k, src_to_dst=src_to_dst, ) # w4a16 group gemm 1 # pt_output_1: (expand_tokens, 2n) dtype - pt_output_1 = ixfops.moe_w4a16_group_gemm( - input=expand_hidden_states, - weight=layer.w13_qweight, - w_scales=layer.w13_scales, - quant_type="awq", - tokens_per_experts=expert_sizes_cpu, - w_zeros=layer.w13_qzeros, - group_size=self.quant_config.group_size, - dst_to_src=None, - format="NN", - tokens_per_experts_gpu=expert_sizes_gpu, - ) - - # act - pt_output_2 = ixfops.silu_and_mul(pt_output_1) - - # w4a16 group gemm 2 + reorder - # pt_output_3: (expand_tokens, k) dtype - if use_ep: - pt_output_3 = torch.empty( - (num_tokens * top_k, hidden_size), - device=x.device, - dtype=x.dtype, - ) - - ixfops.moe_w4a16_group_gemm( - input=pt_output_2, - weight=layer.w2_qweight, - w_scales=layer.w2_scales, + if only_decode: + pt_output_1 = ixfops.moe_w4a16_group_gemv( + input=expand_hidden_states, + weight=layer.w13_qweight, + w_scales=layer.w13_scales, quant_type="awq", - tokens_per_experts=expert_sizes_cpu, - w_zeros=layer.w2_qzeros, + w_zeros=layer.w13_qzeros, group_size=self.quant_config.group_size, - dst_to_src=sorted_token_ids, - format="NN", - output=pt_output_3, - ) - - reduce_mask = src_to_dst == -1 - final_hidden_states = ixfops.moe_output_reduce_sum( - input=pt_output_3.view(num_tokens, top_k, -1), - topk_weight=topk_weights, - scaling_factor=routed_scaling_factor, - mask=reduce_mask, - ) - else: - pt_output_3 = ixfops.moe_w4a16_group_gemm( - input=pt_output_2, - weight=layer.w2_qweight, - w_scales=layer.w2_scales, - quant_type="awq", - tokens_per_experts=expert_sizes_cpu, - w_zeros=layer.w2_qzeros, - group_size=self.quant_config.group_size, - dst_to_src=sorted_token_ids, + dst_to_src=None, format="NN", tokens_per_experts_gpu=expert_sizes_gpu, ) - # mul + reduce_sum - # final_hidden_states: (num_tokens, k) + # act + pt_output_2 = ixfops.silu_and_mul(pt_output_1) + + pt_output_3 = ixfops.moe_w4a16_group_gemv( + input=pt_output_2, + weight=layer.w2_qweight, + w_scales=layer.w2_scales, + quant_type="awq", + w_zeros=layer.w2_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=sorted_token_ids, + format="NN", + 
tokens_per_experts_gpu=expert_sizes_gpu, + ) + + # mul + reduce_sum + # final_hidden_states: (num_tokens, k) final_hidden_states = ixfops.moe_output_reduce_sum( - input=pt_output_3.view(num_tokens, top_k, -1), + input=pt_output_3.view(num_tokens, layer.top_k, -1), topk_weight=topk_weights, - scaling_factor=routed_scaling_factor + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, ) + + else: + expert_sizes_cpu = expert_sizes_gpu.cpu() + pt_output_1 = ixfops.moe_w4a16_group_gemm( + input=expand_hidden_states, + weight=layer.w13_qweight, + w_scales=layer.w13_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w13_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=None, + format="NN", + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + # act + pt_output_2 = ixfops.silu_and_mul(pt_output_1) + + # w4a16 group gemm 2 + reorder + # pt_output_3: (expand_tokens, k) dtype + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * layer.top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + ixfops.moe_w4a16_group_gemm( + input=pt_output_2, + weight=layer.w2_qweight, + w_scales=layer.w2_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w2_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=sorted_token_ids, + format="NN", + output=pt_output_3, + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w4a16_group_gemm( + input=pt_output_2, + weight=layer.w2_qweight, + w_scales=layer.w2_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w2_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=sorted_token_ids, + format="NN", + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + # mul + reduce_sum + # final_hidden_states: (num_tokens, k) + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) return final_hidden_states + # return torch.ops.vllm.fused_marlin_moe( + # x, + # layer.w13_qweight, + # layer.w2_qweight, + # layer.w13_scales, + # layer.w2_scales, + # router_logits, + # topk_weights, + # topk_ids, + # w1_zeros=layer.w13_qzeros, + # w2_zeros=layer.w2_qzeros, + # num_bits=self.quant_config.weight_bits, + # ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0654e0c..4a02c91 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -18,7 +18,6 @@ from compressed_tensors.quantization import ( ) from compressed_tensors.transform import TransformConfig -import vllm.envs as envs from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -52,7 +51,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, CompressedTensorsWNA16, - CompressedTensorsW4A8Int8 ) from 
vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501 CompressedTensorsLinearTransformMethod, @@ -401,8 +399,8 @@ class CompressedTensorsConfig(QuantizationConfig): ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( - weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value ) is_tensor = ( weight_strategy @@ -420,8 +418,8 @@ class CompressedTensorsConfig(QuantizationConfig): ) -> bool: is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 weight_strategy = ( - weight_quant.strategy == QuantizationStrategy.TENSOR.value - or weight_quant.strategy == QuantizationStrategy.CHANNEL.value + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value ) is_token = ( weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value @@ -663,12 +661,6 @@ class CompressedTensorsConfig(QuantizationConfig): ) if self._is_dynamic_token_w8a8(weight_quant, input_quant): - if envs.VLLM_W8A8_LINEAR_USE_W4A8: - return CompressedTensorsW4A8Int8( - strategy=weight_quant.strategy, - is_static_input_scheme=False, - input_symmetric=input_quant.symmetric, - ) return CompressedTensorsW8A8Int8( strategy=weight_quant.strategy, is_static_input_scheme=False, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index d83b5c8..4c8b865 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -3,6 +3,7 @@ import enum from enum import Enum +from typing import List, Optional, Tuple import torch from compressed_tensors import CompressionFormat @@ -12,16 +13,16 @@ from compressed_tensors.quantization import ( QuantizationStrategy, ) -import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm import _custom_ops as ops from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEActivationFormat, + FusedMoEExpertsModular, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod, ) @@ -41,7 +42,6 @@ from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( - Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, @@ -60,18 +60,11 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compress WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - flashinfer_trtllm_fp4_moe, - flashinfer_trtllm_fp4_routed_moe, -) from vllm.model_executor.layers.quantization.utils.flashinfer_mxint4_moe import ( flashinfer_trtllm_mxint4_moe, is_flashinfer_mxint4_moe_available, prepare_static_weights_for_trtllm_mxint4_moe, ) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_fi_trtllm_fp8_per_tensor_moe, 
-) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( process_fp8_input_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe, @@ -106,8 +99,6 @@ from vllm.platforms import CpuArchEnum, current_platform logger = init_logger(__name__) import ixformer.inference.functions as ixfops import vllm.envs as envs -from vllm.forward_context import ForwardContext, get_forward_context - class GPTQMarlinState(Enum): @@ -126,6 +117,42 @@ __all__ = [ ] +def moe_w4a8_group_gemm_pad_k( + input: torch.Tensor, + weight: torch.Tensor, + i_scales: torch.Tensor, + w_scales: torch.Tensor, + output_dtype: torch.dtype, + tokens_per_experts: torch.Tensor, + w_i8scales: torch.Tensor = None, + w_i8zeros: torch.Tensor = None, + dst_to_src: torch.Tensor = None, + bias: Optional[torch.Tensor] = None, + format: int = 0, + version: int = 2, + group_size: int = -1, + persistent: int = 0, + gdr_buffer_ptr: int = None, + output: torch.Tensor = None, +): + # Pad k dimension if needed for kernel constraint (format 0/1: k%64==0 && k>=256; format 2: k%128==0 && k>=256) + if format in (0, 1): + k = weight.size(1) + k_padded = max(256, (k + 63) & ~63) # bitwise align to 64, avoids div/mul + else: + k = weight.size(2) << 1 # * 2 + k_padded = max(256, (k + 127) & ~127) # bitwise align to 128 + pad_k = k_padded - k + if pad_k > 0: + if format in (0, 1): + input = torch.nn.functional.pad(input, (0, pad_k), mode='constant', value=0) + weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_k), mode='constant', value=0) + else: + input = torch.nn.functional.pad(input, (0, pad_k), mode='constant', value=0) + weight = torch.nn.functional.pad(weight, (0, pad_k >> 1, 0, 0), mode='constant', value=0) + + return ixfops.moe_w4a8_group_gemm(input, weight, i_scales, w_scales, output_dtype, tokens_per_experts, w_i8scales, w_i8zeros, dst_to_src, bias, format, version, group_size, persistent, gdr_buffer_ptr, output) + class CompressedTensorsMoEMethod(FusedMoEMethodBase): @staticmethod def get_moe_method( @@ -182,9 +209,13 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): f"but got format: {CompressionFormat.pack_quantized.value} " f" and bits: {weight_quant.num_bits}", ) + if envs.VLLM_WNA16_MOE_USE_W4A8: + assert weight_quant.num_bits == 4 and group_size == -1 + logger.info_once("Using CompressedTensorsW4A8MoEMethod") + return CompressedTensorsW4A16MoEMethod(weight_quant, input_quant, layer.moe_config) # Prefer to use the MarlinMoE kernel when it is supported. 
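# Editor's note: the bitwise rounding in moe_w4a8_group_gemm_pad_k above is the
# standard round-up-to-a-multiple trick; a small self-check of the arithmetic,
# assuming only that the alignment is a power of two.
import math

def round_up(k: int, align: int) -> int:
    # Equivalent to math.ceil(k / align) * align when align is a power of two.
    return (k + align - 1) & ~(align - 1)

for k in (1, 64, 100, 255, 256, 257, 1000):
    assert round_up(k, 64) == math.ceil(k / 64) * 64
    assert round_up(k, 128) == math.ceil(k / 128) * 128

# The helper additionally enforces a floor of 256 on the padded k dimension:
assert max(256, round_up(100, 64)) == 256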
- if ( + elif ( not check_moe_marlin_supports_layer(layer, group_size) or current_platform.is_rocm() ): @@ -225,10 +256,11 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase): return CompressedTensorsW8A8Fp8MoEMethod( weight_quant, input_quant, layer.moe_config ) - elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant) or quant_config._is_static_tensor_w8a8(weight_quant, input_quant): if envs.VLLM_W8A8_MOE_USE_W4A8: - return CompressedTensorsW4A8MoEMethod(quant_config, layer.moe_config) - return CompressedTensorsW8A8Int8MoEMethod( + return CompressedTensorsW4A8MoEMethod(weight_quant, input_quant, layer.moe_config) + else: + return CompressedTensorsW8A8Int8MoEMethod( weight_quant, input_quant, layer.moe_config ) elif quant_config._is_fp8_w4a8_sm90(weight_quant, input_quant): @@ -343,7 +375,7 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config is not None: - self.moe_mk = make_nvfp4_moe_kernel( + self.moe_kernel = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, @@ -358,10 +390,9 @@ class CompressedTensorsW4A4Mxfp4MoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - **kwargs, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, @@ -570,43 +601,27 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): layer.w13_input_scale = a13_scale layer.w2_input_scale = a2_scale - # Setup modular kernel for TP case and naive DP/EP case. - # In non-naive DP/EP case, we will create a ModularKernelMethod. - # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel - # in both cases. + # Setup modular kernel. self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_mk = make_nvfp4_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - experts_cls=self.experts_cls, - shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), - ) + assert self.experts_cls is not None + self.moe_kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), + ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." 
- ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: return make_nvfp4_moe_quant_config( backend=self.nvfp4_backend, w13_scale=layer.w13_weight_scale, @@ -617,13 +632,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): a2_scale=layer.w2_input_scale, ) - @property - def is_monolithic(self) -> bool: - return ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not self.moe.moe_parallel_config.enable_eplb - ) - def apply_monolithic( self, layer: FusedMoE, @@ -631,24 +639,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert layer.activation == MoEActivation.SILU, ( - f"Only SiLU activation is supported, not {layer.activation}." - ) - assert ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not layer.enable_eplb - ) - return flashinfer_trtllm_fp4_moe( - layer=layer, - x=x, - router_logits=router_logits, - top_k=layer.top_k, + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, activation=layer.activation, global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, - custom_routing_function=layer.custom_routing_function, e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) def apply( @@ -658,36 +662,20 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - **kwargs, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not self.is_monolithic - - # EPLB path - if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - assert layer.enable_eplb - return flashinfer_trtllm_fp4_routed_moe( - layer=layer, - x=x, - topk_ids=topk_ids, - topk_weights=topk_weights, - top_k=layer.top_k, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - ) - else: - assert self.moe_mk is not None - return self.moe_mk( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): @@ -946,7 +934,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): w13, w13_scale, shard_size=layer.intermediate_size_per_partition, - num_experts=layer.num_local_experts, + num_experts=layer.local_num_experts, is_act_and_mul=self.moe.is_act_and_mul, ) @@ -975,7 +963,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self.moe_quant_config = 
self.get_fused_moe_quant_config(layer) if self.moe_quant_config: assert self.experts_cls is not None - self.moe_mk = make_fp8_moe_kernel( + self.moe_kernel = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, @@ -987,94 +975,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - w1_scale = layer.w13_weight_scale - w2_scale = layer.w2_weight_scale - a1_scale = layer.w13_input_scale - a2_scale = layer.w2_input_scale - + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: + is_per_token = self.input_quant.strategy == QuantizationStrategy.TOKEN return make_fp8_moe_quant_config( fp8_backend=self.fp8_backend, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, - per_act_token_quant=( - self.input_quant.strategy == QuantizationStrategy.TOKEN - ), - per_out_ch_quant=(self.input_quant.strategy == QuantizationStrategy.TOKEN), + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=is_per_token, + per_out_ch_quant=is_per_token, block_shape=self.weight_block_size, ) - @property - def is_monolithic(self) -> bool: - return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.is_monolithic - assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - assert layer.activation == MoEActivation.SILU, ( - f"Only SiLU activation is supported, not {layer.activation}." 
+ assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) - if self.block_quant: - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, - routing_method_type=layer.routing_method_type, - routed_scaling=layer.routed_scaling_factor, - ) - else: - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - def apply( self, layer: FusedMoE, @@ -1082,11 +1023,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - **kwargs ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, @@ -1115,8 +1055,13 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): layer_name: str | None = None, ): super().__init__(moe) + self.has_bias = self.moe.has_bias self.weight_quant = weight_quant self.input_quant = input_quant + self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size + self.gemm_format = envs.VLLM_W8A8_FORMAT + assert self.gemm_format in ["TN", "NN"], f"W8A8 INT8 MoE only supports TN or NN format, but got {self.gemm_format}" + self.padding_dim = -1 if self.gemm_format == "TN" else -2 per_channel = ( self.weight_quant.strategy == QuantizationStrategy.CHANNEL @@ -1130,11 +1075,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ) self.static_input_scales = not self.input_quant.dynamic - if self.static_input_scales: - raise ValueError( - "For INT8 Fused MoE layers, we require channelwise, " - "dynamic per token quantization. Found static input scales." - ) + # if self.static_input_scales: + # raise ValueError( + # "For INT8 Fused MoE layers, we require channelwise, " + # "dynamic per token quantization. 
Found static input scales.") + def create_weights( self, @@ -1146,50 +1091,88 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): **extra_weight_attrs, ): params_dtype = torch.int8 + w13_num_shards = 2 if self.moe.is_act_and_mul else 1 w13_remainder = hidden_size % 64 w2_remainder = intermediate_size_per_partition % 64 + if w13_remainder != 0: + hidden_size_padded = hidden_size + (64 - w13_remainder) + else: + hidden_size_padded = hidden_size + if w2_remainder != 0: + intermediate_size_per_partition_padded = intermediate_size_per_partition + (64 - w2_remainder) + else: + intermediate_size_per_partition_padded = intermediate_size_per_partition - hidden_size_padded = hidden_size if w13_remainder == 0 else hidden_size + (64 - w13_remainder) - - intermediate_size_per_partition_padded = intermediate_size_per_partition if w2_remainder == 0 else intermediate_size_per_partition + (64 - w2_remainder) - + if self.gemm_format == "NN": + w13_shape = (num_experts, hidden_size_padded, w13_num_shards * intermediate_size_per_partition_padded) + w2_shape = (num_experts, intermediate_size_per_partition_padded, hidden_size_padded) + else: + w13_shape = (num_experts, w13_num_shards * intermediate_size_per_partition, hidden_size_padded) + w2_shape = (num_experts, hidden_size, intermediate_size_per_partition_padded) + # WEIGHTS - w13_weight = torch.nn.Parameter( - torch.empty( - num_experts, - 2 * intermediate_size_per_partition, - hidden_size_padded, - dtype=params_dtype, - ), - requires_grad=False, - ) + w13_weight = torch.nn.Parameter(torch.zeros( + w13_shape, + dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - w2_weight = torch.nn.Parameter( - torch.empty( - num_experts, - hidden_size, - intermediate_size_per_partition_padded, - dtype=params_dtype, - ), - requires_grad=False, - ) + w2_weight = torch.nn.Parameter(torch.zeros( + w2_shape, + dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + if self.gemm_format == "TN": + w13_bias_shape = (num_experts, 2 * intermediate_size_per_partition) + w2_bias_shape = (num_experts, hidden_size) + elif self.gemm_format == "NN": + w13_bias_shape = (num_experts, 2 * intermediate_size_per_partition_padded) + w2_bias_shape = (num_experts, hidden_size_padded) + + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + w13_bias_shape, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + else: + setattr(layer, "w13_bias", None) + + if self.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros( + w2_bias_shape, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + setattr(layer, "w2_bias", None) + # WEIGHT_SCALES assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + if self.gemm_format == "NN": + w13_weight_scale_shape = (num_experts, w13_num_shards * intermediate_size_per_partition_padded, 1) + w2_weight_scale_shape = (num_experts, hidden_size_padded, 1) + else: + w13_weight_scale_shape = (num_experts, w13_num_shards * intermediate_size_per_partition, 1) + w2_weight_scale_shape = (num_experts, hidden_size, 1) + w13_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 - ), + 
torch.ones(w13_weight_scale_shape, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + torch.ones(w2_weight_scale_shape, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w2_weight_scale", w2_weight_scale) @@ -1201,9 +1184,23 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES - assert not self.static_input_scales - layer.w13_input_scale = None - layer.w2_input_scale = None + if self.static_input_scales: + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, intermediate_size_per_partition, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: pass @@ -1225,24 +1222,40 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - router_logits: torch.Tensor, - expert_map: torch.Tensor | None = None, - enable_eplb: bool = False, - extra_residual: torch.Tensor = None, - routed_scaling_factor: float = 1.0, - *args, - **kwargs + # Assign the value of shared_experts_output to variable shared_experts_input for fusion + shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - from vllm.model_executor.layers.fused_moe import fused_experts - use_ep = expert_map is not None + + attn_metadata = get_forward_context().attn_metadata + use_ep = layer.expert_map is not None + # unsupported ep now + if attn_metadata: + deepseek_instance = None + for value in attn_metadata.values(): + if hasattr(value, 'num_prefills') and hasattr(value, 'num_decodes'): + deepseek_instance = value + break + value_types = {type(value).__name__ for value in attn_metadata.values()} + is_same_class = len(value_types) == 1 + if is_same_class: + assert deepseek_instance + only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values()))) + else: + if deepseek_instance: + only_decode = (use_ep == False and deepseek_instance.num_decodes > 0 and deepseek_instance.num_prefills ==0) + else: + only_decode = False + else: + only_decode = False + if use_ep: start_eid = layer.ep_rank * layer.local_num_experts - end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, global_num_experts) - - - top_k = topk_ids.shape[-1] + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts) + dtype = x.dtype - num_tokens, num_experts = router_logits.shape + num_tokens = topk_ids.shape[0] + num_experts = layer.global_num_experts + if use_ep: hidden_size = x.shape[1] ( @@ -1264,7 +1277,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): dtype=x.dtype, ) else: - expand_tokens = num_tokens * top_k + expand_tokens = num_tokens * layer.top_k ( src_to_dst, sorted_token_ids, @@ 
-1281,108 +1294,1170 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): hidden_states=x, dst_to_src=sorted_token_ids, dst_tokens=expand_tokens, - topk=top_k, + topk=layer.top_k, src_to_dst=src_to_dst, topk_ids=None, smooth_scales=layer.w13_input_scale, ) - i8_hidden_states_align = i8_hidden_states - if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1]: padding = layer.w13_weight.shape[-1] - i8_hidden_states.shape[-1] i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states - # w8a8 group gemm 1 - pt_output_1 = ixfops.moe_w8a8_group_gemm( - input=i8_hidden_states_align, - weight=layer.w13_weight, - i_scales=a_scale, - w_scales=layer.w13_weight_scale, - output_dtype=dtype, - tokens_per_experts=expert_sizes_cpu, - dst_to_src=None, - format="TN", - ) + if only_decode and self.gemm_format == "NN": - # act + quant - pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( - input=pt_output_1, - bias=None, - smooth_scales=layer.w2_input_scale, - dst_to_src=sorted_token_ids, - topk_ids=None, - act_type="swiglu", - ) - - pt_output_2_align = pt_output_2 - - if pt_output_2.shape[-1] != layer.w2_weight.shape[-1]: - padding = layer.w2_weight.shape[-1] - pt_output_2.shape[-1] - pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) - - # w8a8 group gemm 2 + reorder - if use_ep: - pt_output_3 = torch.empty( - (num_tokens * top_k, hidden_size), - device=x.device, - dtype=x.dtype, + # expand + reorder + quant + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, ) - ixfops.moe_w8a8_group_gemm( + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-2]: + padding = layer.w13_weight.shape[-2] - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + # w8a8 group gemm 1 + pt_output_1 = ixfops.moe_w8a8_group_gemv( + input=i8_hidden_states_align, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_gpu, + dst_to_src=None, + bias=layer.w13_bias, + format=0, + group_size=self.group_size, + ) + + # act + quant + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) + + + if pt_output_2.shape[-1] != layer.w2_weight.shape[-2]: + padding = layer.w2_weight.shape[-2] - pt_output_2.shape[-1] + pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) + else: + pt_output_2_align = pt_output_2 + + # w8a8 group gemm 2 + reorder + pt_output_3 = ixfops.moe_w8a8_group_gemv( input=pt_output_2_align, weight=layer.w2_weight, i_scales=a2_scale, w_scales=layer.w2_weight_scale, output_dtype=dtype, - tokens_per_experts=expert_sizes_cpu, + tokens_per_experts=expert_sizes_gpu, dst_to_src=sorted_token_ids, - 
format="TN", - output=pt_output_3, + bias=layer.w2_bias, + format=0, + group_size=self.group_size, ) - - reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( - input=pt_output_3.view(num_tokens, top_k, -1), + input=pt_output_3.view(num_tokens, layer.top_k, -1), topk_weight=topk_weights, - extra_residual=extra_residual, - scaling_factor=routed_scaling_factor, - mask=reduce_mask, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, ) else: - pt_output_3 = ixfops.moe_w8a8_group_gemm( + expert_sizes_cpu = expert_sizes_gpu.cpu() + # expand + reorder + quant + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + ) + + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[self.padding_dim]: + padding = layer.w13_weight.shape[self.padding_dim] - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + # w8a8 group gemm 1 + pt_output_1 = ixfops.moe_w8a8_group_gemm( + input=i8_hidden_states_align, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=None, + bias=layer.w13_bias, + format=self.gemm_format, + ) + + # act + quant + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) + + if pt_output_2.shape[-1] != layer.w2_weight.shape[self.padding_dim]: + padding = layer.w2_weight.shape[self.padding_dim] - pt_output_2.shape[-1] + pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) + else: + pt_output_2_align = pt_output_2 + + + # w8a8 group gemm 2 + reorder + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * layer.top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + ixfops.moe_w8a8_group_gemm( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format=self.gemm_format, + output=pt_output_3, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w8a8_group_gemm( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format=self.gemm_format, + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + + 
return final_hidden_states + + +class CompressedTensorsW4A16MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.has_bias = self.moe.has_bias + self.weight_quant = weight_quant + self.input_quant = input_quant + + self.pack_factor = 2 + self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size + self.weight_symmetric = self.weight_quant.symmetric + self.gemm_format = envs.VLLM_W4A8_FORMAT + self.format_mapping = {"NN":0,"NT":1,"TN":2} + self.version = envs.VLLM_W4A8_VERSION + assert self.gemm_format in ["TN","NN"] + self.static_input_scales = False + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + w13_remainder = (hidden_size // self.pack_factor) % 64 + w2_remainder = (intermediate_size_per_partition // self.pack_factor) % 64 + if self.gemm_format == "TN": + if w13_remainder != 0: + w13_shape = (num_experts, 2 * intermediate_size_per_partition, (hidden_size // self.pack_factor) + 64 - w13_remainder) + else: + w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor) + + if w2_remainder != 0: + w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - w2_remainder) + else: + w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor) + else: + w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor) + w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor) + + # WEIGHTS + # use process_weights_after_loading to get get right layout if gemm_format is NN + w13_weight = torch.nn.Parameter(torch.empty(w13_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w13_weight, "shard_dim", 1) + + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + else: + setattr(layer, "w13_bias", None) + + if w2_remainder != 0: + w2_weight = torch.nn.Parameter(torch.zeros(w2_shape, + dtype=params_dtype), + requires_grad=False) + else: + w2_weight = torch.nn.Parameter(torch.empty(w2_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w2_weight, "shard_dim", 0) + + if self.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + setattr(layer, "w2_bias", None) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
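# Editor's note: pack_factor = 2 above reflects two 4-bit weights stored per int8
# element, which is why the packed weight dimensions are divided by 2. A minimal
# pack/unpack sketch; the nibble order and any kernel-specific interleaving are
# assumptions, not the layout the ixformer kernels actually require.
import torch

def pack_int4_pairs(w: torch.Tensor) -> torch.Tensor:
    # w: int8 tensor with values in [-8, 7]; packs pairs along the last dim.
    assert w.shape[-1] % 2 == 0
    nibbles = w.to(torch.int16) & 0x0F          # 4-bit two's-complement nibbles
    return (nibbles[..., 0::2] | (nibbles[..., 1::2] << 4)).to(torch.uint8)

def unpack_int4_pairs(packed: torch.Tensor) -> torch.Tensor:
    p = packed.to(torch.int16)
    lo, hi = p & 0x0F, (p >> 4) & 0x0F
    # Sign-extend the 4-bit values back to the [-8, 7] range.
    lo = torch.where(lo >= 8, lo - 16, lo)
    hi = torch.where(hi >= 8, hi - 16, hi)
    return torch.stack((lo, hi), dim=-1).flatten(-2).to(torch.int8)

w = torch.randint(-8, 8, (4, 6), dtype=torch.int8)
assert torch.equal(unpack_int4_pairs(pack_int4_pairs(w)), w)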
+ # The following scale or zero will use permute(0,2,1) to get right layout, init here to avoid rewrite data_loader + w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.version == 2 else 1 if self.group_size == -1 else hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + setattr(w13_weight_scale, "shard_dim", 1) + + w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.version == 2 else 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + setattr(w2_weight_scale, "shard_dim", 0) + # setattr(w2_weight_scale, "load_full_w2", True) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.version == 2 or self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + if self.version == 2: + # INT8 -> INT4 weight scales/zeros + if self.group_size != -1: + w13_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_i8_weight_scale", w13_i8_weight_scale) + setattr(w13_i8_weight_scale, "shard_dim", 1) + if not self.weight_symmetric: + w13_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.group_size == -1 else hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_i8_weight_zero", w13_i8_weight_zero) + setattr(w13_i8_weight_zero, "shard_dim", 1) + + if self.group_size != -1: + w2_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_i8_weight_scale", w2_i8_weight_scale) + setattr(w2_i8_weight_scale, "shard_dim", 0) + if not self.weight_symmetric: + w2_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_i8_weight_zero", w2_i8_weight_zero) + setattr(w2_i8_weight_zero, "shard_dim", 0) + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) + + if self.version == 2 and self.group_size != -1: + set_weight_attrs(w13_i8_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_i8_weight_scale, extra_weight_attrs) + else: + setattr(layer, "w13_i8_weight_scale", None) + setattr(layer, "w2_i8_weight_scale", None) + if self.version == 2 and not self.weight_symmetric: + set_weight_attrs(w13_i8_weight_zero, extra_weight_attrs) + set_weight_attrs(w2_i8_weight_zero, extra_weight_attrs) + else: + setattr(layer, "w13_i8_weight_zero", None) + setattr(layer, "w2_i8_weight_zero", None) + + # DO NOT SUPPORT INPUT_SCALES + if self.static_input_scales: + extra_weight_attrs.update( + {"quant_method": 
FusedMoeWeightScaleSupported.CHANNEL.value}) + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, intermediate_size_per_partition, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + self.gemm_format = self.format_mapping[self.gemm_format] + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None: + return None + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + # Assign the value of shared_experts_output to variable shared_experts_input for fusion + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + attn_metadata = get_forward_context().attn_metadata + use_ep = layer.expert_map is not None + # unsupported ep now + if attn_metadata: + deepseek_instance = None + for value in attn_metadata.values(): + if hasattr(value, 'num_prefills') and hasattr(value, 'num_decodes'): + deepseek_instance = value + break + value_types = {type(value).__name__ for value in attn_metadata.values()} + is_same_class = len(value_types) == 1 + if is_same_class: + assert deepseek_instance + only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values()))) + else: + if deepseek_instance: + only_decode = (use_ep == False and deepseek_instance.num_decodes > 0 and deepseek_instance.num_prefills ==0) + else: + only_decode = False + else: + only_decode = False + + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts) + + dtype = x.dtype + num_tokens = topk_ids.shape[0] + num_experts = layer.global_num_experts + if use_ep: + hidden_size = x.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=x.device, + dtype=x.dtype, + ) + else: + expand_tokens = num_tokens * layer.top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + + if only_decode and self.gemm_format == 2: + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + output_format = 1, + ) + + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2: + padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + pt_output_1 = ixfops.moe_w4a8_group_gemv( + input=i8_hidden_states_align, + weight=layer.w13_weight, + i_scales=a_scale, + 
w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_gpu, + w_i8scales=layer.w13_i8_weight_scale, + w_i8zeros=layer.w13_i8_weight_zero, + dst_to_src=None, + bias=layer.w13_bias, + format=self.gemm_format, + group_size=self.group_size, + ) + + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + output_format = 1, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + output_format = 1, + ) + if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2: + padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1] + pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) + else: + pt_output_2_align = pt_output_2 + + pt_output_3 = ixfops.moe_w4a8_group_gemv( input=pt_output_2_align, weight=layer.w2_weight, i_scales=a2_scale, w_scales=layer.w2_weight_scale, output_dtype=dtype, - tokens_per_experts=expert_sizes_cpu, + tokens_per_experts=expert_sizes_gpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, dst_to_src=sorted_token_ids, - format="TN", + bias=layer.w2_bias, + format=self.gemm_format, + group_size=self.group_size, ) - - # mul + reduce_sum + # mul + reduce_sum final_hidden_states = ixfops.moe_output_reduce_sum( - input=pt_output_3.view(num_tokens, top_k, -1), + input=pt_output_3.view(num_tokens, layer.top_k, -1), topk_weight=topk_weights, - extra_residual=extra_residual, - scaling_factor=routed_scaling_factor + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + else: + expert_sizes_cpu = expert_sizes_gpu.cpu() + # expand + reorder + quant + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, ) - return final_hidden_states - # return fused_experts( - # hidden_states=x, - # w1=layer.w13_weight, - # w2=layer.w2_weight, - # topk_weights=topk_weights, - # topk_ids=topk_ids, - # inplace=not self.moe.disable_inplace, - # activation=layer.activation, - # apply_router_weight_on_input=layer.apply_router_weight_on_input, - # global_num_experts=layer.global_num_experts, - # expert_map=layer.expert_map, - # quant_config=self.moe_quant_config, - # ) + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2: + padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + # w4a8 group gemm 1 + pt_output_1 = ixfops.moe_w4a8_group_gemm( + input=i8_hidden_states_align, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w13_i8_weight_scale, + w_i8zeros=layer.w13_i8_weight_zero, + dst_to_src=None, + bias=layer.w13_bias, + format=self.gemm_format, + group_size=self.group_size, + version=self.version + ) + + # act + quant + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + 
input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) + + if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2 and self.gemm_format == 2: + padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1] + pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) + else: + pt_output_2_align = pt_output_2 + + # w4a8 group gemm 2 + reorder + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * layer.top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + ixfops.moe_w4a8_group_gemm( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format=self.gemm_format, + group_size=self.group_size, + version=self.version, + output=pt_output_3, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w4a8_group_gemm( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format=self.gemm_format, + group_size=self.group_size, + version=self.version + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + return final_hidden_states + +class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + weight_quant: QuantizationArgs, + input_quant: QuantizationArgs, + moe: FusedMoEConfig, + layer_name: str | None = None, + ): + super().__init__(moe) + self.has_bias = self.moe.has_bias + self.weight_quant = weight_quant + self.input_quant = input_quant + + self.pack_factor = 2 + self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size + self.weight_symmetric = self.weight_quant.symmetric + self.gemm_format = envs.VLLM_W4A8_FORMAT + self.format_mapping = {"NN":0,"NT":1,"TN":2} + self.version = envs.VLLM_W4A8_VERSION + assert self.gemm_format in ["TN","NN"] + + if not ((self.weight_quant.strategy == QuantizationStrategy.CHANNEL + or self.weight_quant.strategy == QuantizationStrategy.GROUP) + and self.input_quant.strategy == QuantizationStrategy.TOKEN): + raise ValueError( + "For INT4 pack2 Fused MoE layers, only per-channel or group scales" + "for weights and per-token scales for activations are supported. 
Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + w13_remainder = (hidden_size // self.pack_factor) % 64 + w2_remainder = (intermediate_size_per_partition // self.pack_factor) % 64 + if self.gemm_format == "TN": + if w13_remainder != 0: + w13_shape = (num_experts, 2 * intermediate_size_per_partition, (hidden_size // self.pack_factor) + 64 - w13_remainder) + else: + w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor) + + if w2_remainder != 0: + w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - w2_remainder) + else: + w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor) + else: + w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor) + w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor) + + # WEIGHTS + # use process_weights_after_loading to get get right layout if gemm_format is NN + w13_weight = torch.nn.Parameter(torch.empty(w13_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w13_weight, "shard_dim", 1) + + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + else: + setattr(layer, "w13_bias", None) + + if w2_remainder != 0: + w2_weight = torch.nn.Parameter(torch.zeros(w2_shape, + dtype=params_dtype), + requires_grad=False) + else: + w2_weight = torch.nn.Parameter(torch.empty(w2_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w2_weight, "shard_dim", 0) + + if self.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + setattr(layer, "w2_bias", None) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
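# --- Editorial sketch (not part of the patch; sizes are hypothetical) ---
# Illustrates the "TN" weight-shape arithmetic defined above: two int4 values
# share one int8 byte (pack_factor = 2) and the packed K dimension is rounded
# up to the next multiple of 64. This is why activations are later zero-padded
# before each grouped GEMM whenever their last dim is narrower than
# weight.shape[-1] * 2.
pack_factor = 2
hidden_size = 2112                        # hypothetical example value
packed_k = hidden_size // pack_factor     # 1056 int8 columns
padded_k = packed_k + (-packed_k) % 64    # 1088 after rounding up to a multiple of 64
# w13 ("TN"): (num_experts, 2 * intermediate_size_per_partition, padded_k)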
+ # The following scale or zero will use permute(0,2,1) to get right layout, init here to avoid rewrite data_loader + w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.version == 2 else 1 if self.group_size == -1 else hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + setattr(w13_weight_scale, "shard_dim", 1) + + w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.version == 2 else 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + setattr(w2_weight_scale, "shard_dim", 0) + # setattr(w2_weight_scale, "load_full_w2", True) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.version == 2 or self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + if self.version == 2: + # INT8 -> INT4 weight scales/zeros + if self.group_size != -1: + w13_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_i8_weight_scale", w13_i8_weight_scale) + setattr(w13_i8_weight_scale, "shard_dim", 1) + if not self.weight_symmetric: + w13_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.group_size == -1 else hidden_size // self.group_size, + 2 * intermediate_size_per_partition, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_i8_weight_zero", w13_i8_weight_zero) + setattr(w13_i8_weight_zero, "shard_dim", 1) + + if self.group_size != -1: + w2_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_i8_weight_scale", w2_i8_weight_scale) + setattr(w2_i8_weight_scale, "shard_dim", 0) + if not self.weight_symmetric: + w2_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, + 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_i8_weight_zero", w2_i8_weight_zero) + setattr(w2_i8_weight_zero, "shard_dim", 0) + + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) + + if self.version == 2 and self.group_size != -1: + set_weight_attrs(w13_i8_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_i8_weight_scale, extra_weight_attrs) + else: + setattr(layer, "w13_i8_weight_scale", None) + setattr(layer, "w2_i8_weight_scale", None) + if self.version == 2 and not self.weight_symmetric: + set_weight_attrs(w13_i8_weight_zero, extra_weight_attrs) + set_weight_attrs(w2_i8_weight_zero, extra_weight_attrs) + else: + setattr(layer, "w13_i8_weight_zero", None) + setattr(layer, "w2_i8_weight_zero", None) + + # DO NOT SUPPORT INPUT_SCALES + if self.static_input_scales: + extra_weight_attrs.update( + {"quant_method": 
FusedMoeWeightScaleSupported.CHANNEL.value}) + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, hidden_size, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, intermediate_size_per_partition, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + self.gemm_format = self.format_mapping[self.gemm_format] + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None: + return None + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + # Assign the value of shared_experts_output to variable shared_experts_input for fusion + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + attn_metadata = get_forward_context().attn_metadata + use_ep = layer.expert_map is not None + # unsupported ep now + if attn_metadata: + deepseek_instance = None + for value in attn_metadata.values(): + if hasattr(value, 'num_prefills') and hasattr(value, 'num_decodes'): + deepseek_instance = value + break + value_types = {type(value).__name__ for value in attn_metadata.values()} + is_same_class = len(value_types) == 1 + if is_same_class: + assert deepseek_instance + only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values()))) + else: + if deepseek_instance: + only_decode = (use_ep == False and deepseek_instance.num_decodes > 0 and deepseek_instance.num_prefills ==0) + else: + only_decode = False + else: + only_decode = False + + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts) + + dtype = x.dtype + num_tokens = topk_ids.shape[0] + num_experts = layer.global_num_experts + if use_ep: + hidden_size = x.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=x.device, + dtype=x.dtype, + ) + else: + expand_tokens = num_tokens * layer.top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + + if only_decode and self.gemm_format == 2: + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + output_format = 1, + ) + + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2: + padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + pt_output_1 = ixfops.moe_w4a8_group_gemv( + input=i8_hidden_states_align, + weight=layer.w13_weight, + i_scales=a_scale, + 
w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_gpu, + w_i8scales=layer.w13_i8_weight_scale, + w_i8zeros=layer.w13_i8_weight_zero, + dst_to_src=None, + bias=layer.w13_bias, + format=self.gemm_format, + group_size=self.group_size, + ) + + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + output_format = 1, + ) + elif layer.activation.value == "swiglustep": + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglustep", + output_format = 1, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + output_format = 1, + ) + if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2: + padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1] + pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) + else: + pt_output_2_align = pt_output_2 + + pt_output_3 = ixfops.moe_w4a8_group_gemv( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_gpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format=self.gemm_format, + group_size=self.group_size, + ) + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + else: + expert_sizes_cpu = expert_sizes_gpu.cpu() + # expand + reorder + quant + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + ) + + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2: + padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + # w4a8 group gemm 1 + pt_output_1 = moe_w4a8_group_gemm_pad_k( + input=i8_hidden_states_align, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w13_i8_weight_scale, + w_i8zeros=layer.w13_i8_weight_zero, + dst_to_src=None, + bias=layer.w13_bias, + format=self.gemm_format, + group_size=self.group_size, + version=self.version + ) + + # act + quant + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + ) + elif layer.activation.value == "swiglustep": + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + 
topk_ids=None, + act_type="swiglustep", + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) + + if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2 and self.gemm_format == 2: + padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1] + pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) + else: + pt_output_2_align = pt_output_2 + + # w4a8 group gemm 2 + reorder + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * layer.top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + moe_w4a8_group_gemm_pad_k( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format=self.gemm_format, + group_size=self.group_size, + version=self.version, + output=pt_output_3, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = moe_w4a8_group_gemm_pad_k( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + w_i8scales=layer.w2_i8_weight_scale, + w_i8zeros=layer.w2_i8_weight_zero, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format=self.gemm_format, + group_size=self.group_size, + version=self.version + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + return final_hidden_states class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): @@ -1806,9 +2881,9 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: assert self.num_bits == 4, "only supporting w4" layer.w13_weight = layer.w13_weight_packed layer.w2_weight = layer.w2_weight_packed @@ -1878,7 +2953,6 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - **kwargs ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.kernel_backend == "Marlin" return fused_marlin_moe( @@ -2098,9 +3172,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: if self.moe.is_lora_enabled: assert self.moe_quant_config is not None from vllm.triton_utils import HAS_TRITON @@ -2128,7 +3202,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, 
shared_experts_input: torch.Tensor | None, - **kwargs ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts @@ -2683,7 +3756,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: return super().maybe_make_prepare_finalize(routing_tables) def get_fused_moe_quant_config( @@ -2704,9 +3777,9 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: assert self.moe_quant_config is not None assert ( prepare_finalize.activation_format == FusedMoEActivationFormat.Standard @@ -2714,7 +3787,7 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): from vllm.model_executor.layers.fused_moe import CutlassExpertsW4A8Fp8 - experts: FusedMoEPermuteExpertsUnpermute + experts: FusedMoEExpertsModular logger.debug("CutlassExpertsW4A8Fp8(%s)", self.__class__.__name__) experts = CutlassExpertsW4A8Fp8( @@ -2746,7 +3819,6 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, - **kwargs ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if layer.enable_eplb: raise NotImplementedError( @@ -2785,183 +3857,118 @@ class CompressedTensorsW4A8Fp8MoEMethod(CompressedTensorsMoEMethod): def supports_eplb(self) -> bool: return False - -# for corex -class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): +class CompressedTensorsL1OptMoEMethod(CompressedTensorsMoEMethod): def __init__( self, - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 moe: FusedMoEConfig, ): super().__init__(moe) - self.quant_config = quant_config - self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( - "weights") - self.input_quant = self.quant_config.target_scheme_map["Linear"].get( - "input_activations") - self.pack_factor = 2 - self.group_size = -1 if self.weight_quant.group_size is None else self.weight_quant.group_size - self.weight_symmetric = self.weight_quant.symmetric - self.gemm_format = envs.VLLM_W4A8_FORMAT - self.format_mapping = {"NN":0,"NT":1,"TN":2} - self.version = envs.VLLM_W4A8_VERSION - assert self.gemm_format in ["TN","NN"] + self.has_bias = self.moe.has_bias - if not ((self.weight_quant.strategy == QuantizationStrategy.CHANNEL - or self.weight_quant.strategy == QuantizationStrategy.GROUP) - and self.input_quant.strategy == QuantizationStrategy.TOKEN): - raise ValueError( - "For INT4 pack2 Fused MoE layers, only per-channel or group scales" - "for weights and per-token scales for activations are supported. 
Found " - f"{self.weight_quant}, {self.input_quant}") - - self.static_input_scales = not self.input_quant.dynamic - def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): params_dtype = torch.int8 - remainder = (intermediate_size_per_partition // self.pack_factor) % 64 - if self.gemm_format == "TN": - w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor) - if remainder != 0: - w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - remainder) - else: - w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor) + + w13_remainder = hidden_size % 64 + w2_remainder = intermediate_size_per_partition % 64 + if w13_remainder != 0: + hidden_size_padded = hidden_size + (64 - w13_remainder) else: - w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor) - w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor) - + hidden_size_padded = hidden_size + if w2_remainder != 0: + intermediate_size_per_partition_padded = intermediate_size_per_partition + (64 - w2_remainder) + else: + intermediate_size_per_partition_padded = intermediate_size_per_partition + # WEIGHTS - # use process_weights_after_loading to get get right layout if gemm_format is NN - w13_weight = torch.nn.Parameter(torch.empty(w13_shape, - dtype=params_dtype), + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size_padded, + dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) - if self.gemm_format == "NN": - setattr(w13_weight, "shard_dim", 1) - if remainder != 0: - w2_weight = torch.nn.Parameter(torch.zeros(w2_shape, - dtype=params_dtype), - requires_grad=False) + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) else: - w2_weight = torch.nn.Parameter(torch.empty(w2_shape, - dtype=params_dtype), - requires_grad=False) + setattr(layer, "w13_bias", None) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition_padded, + dtype=params_dtype), + requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - if self.gemm_format == "NN": - setattr(w2_weight, "shard_dim", 0) + + if self.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + setattr(layer, "w2_bias", None) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. 
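# --- Editorial sketch (not part of the patch; sizes are hypothetical) ---
# The int8 expert weights created above keep their logical K dimension but pad
# it up to the next multiple of 64; activations narrower than the padded weight
# are zero-padded to match before the w8a8 grouped GEMM.
hidden_size = 6000                                      # hypothetical example value
hidden_size_padded = hidden_size + (-hidden_size) % 64  # 6016, equivalent to the remainder logic above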
- # The following scale or zero will use permute(0,2,1) to get right layout, init here to avoid rewrite data_loader - w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts, - 1 if self.version == 2 else 1 if self.group_size == -1 else hidden_size // self.group_size, - 2 * intermediate_size_per_partition, - dtype=torch.float32), + w13_weight_scale = torch.nn.Parameter(torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_weight_scale", w13_weight_scale) - setattr(w13_weight_scale, "shard_dim", 1) - - w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts, - 1 if self.version == 2 else 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, - hidden_size, - dtype=torch.float32), + setattr(w13_weight_scale, "shard_dim", 0) + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + hidden_size, + 1, + dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) - setattr(w2_weight_scale, "shard_dim", 0) - # setattr(w2_weight_scale, "load_full_w2", True) - + setattr(w2_weight_scale, "shard_dim", 1) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.version == 2 or self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) - - if self.version == 2: - # INT8 -> INT4 weight scales/zeros - if self.group_size != -1: - w13_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, - hidden_size // self.group_size, - 2 * intermediate_size_per_partition, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_i8_weight_scale", w13_i8_weight_scale) - setattr(w13_i8_weight_scale, "shard_dim", 1) - if not self.weight_symmetric: - w13_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, - 1 if self.group_size == -1 else hidden_size // self.group_size, - 2 * intermediate_size_per_partition, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w13_i8_weight_zero", w13_i8_weight_zero) - setattr(w13_i8_weight_zero, "shard_dim", 1) - - if self.group_size != -1: - w2_i8_weight_scale = torch.nn.Parameter(torch.empty(num_experts, - intermediate_size_per_partition // self.group_size, - hidden_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_i8_weight_scale", w2_i8_weight_scale) - setattr(w2_i8_weight_scale, "shard_dim", 0) - if not self.weight_symmetric: - w2_i8_weight_zero = torch.nn.Parameter(torch.empty(num_experts, - 1 if self.group_size == -1 else intermediate_size_per_partition // self.group_size, - hidden_size, - dtype=torch.int32), - requires_grad=False) - layer.register_parameter("w2_i8_weight_zero", w2_i8_weight_zero) - setattr(w2_i8_weight_zero, "shard_dim", 0) - - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value if self.group_size == -1 else FusedMoeWeightScaleSupported.GROUP.value}) - if self.version == 2 and self.group_size != -1: - set_weight_attrs(w13_i8_weight_scale, extra_weight_attrs) - set_weight_attrs(w2_i8_weight_scale, extra_weight_attrs) - else: - setattr(layer, "w13_i8_weight_scale", 
None) - setattr(layer, "w2_i8_weight_scale", None) - if self.version == 2 and not self.weight_symmetric: - set_weight_attrs(w13_i8_weight_zero, extra_weight_attrs) - set_weight_attrs(w2_i8_weight_zero, extra_weight_attrs) - else: - setattr(layer, "w13_i8_weight_zero", None) - setattr(layer, "w2_i8_weight_zero", None) - - # DO NOT SUPPORT INPUT_SCALES - if self.static_input_scales: - extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) - w13_input_scale = torch.nn.Parameter(torch.ones( - num_experts, hidden_size, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, extra_weight_attrs) + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + pass - w2_input_scale = torch.nn.Parameter(torch.ones( - num_experts, intermediate_size_per_partition, dtype=torch.float32), - requires_grad=False) - layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, extra_weight_attrs) - else: - layer.w13_input_scale = None - layer.w2_input_scale = None - - self.gemm_format = self.format_mapping[self.gemm_format] - def get_fused_moe_quant_config( - self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None: - return None + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return int8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=True, + ) def apply( self, @@ -2969,45 +3976,23 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - expert_map: torch.Tensor | None = None, - enable_eplb: bool = False, - extra_residual: torch.Tensor = None, - routed_scaling_factor: float = 1.0, - **kwargs - ) -> torch.Tensor: - attn_metadata = get_forward_context().attn_metadata - use_ep = expert_map is not None - # unsupported ep now - if attn_metadata: - if isinstance(attn_metadata, dict): - deepseek_instance = None - for value in attn_metadata.values(): - if hasattr(value, 'num_prefills') and hasattr(value, 'num_decodes'): - deepseek_instance = value - break - value_types = {type(value).__name__ for value in attn_metadata.values()} - is_same_class = len(value_types) == 1 - if is_same_class: - assert deepseek_instance - only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values()))) - else: - if deepseek_instance: - only_decode = (use_ep == False and deepseek_instance.num_decodes > 0 and deepseek_instance.num_prefills ==0) - else: - only_decode = False - else: - only_decode = False + # Assign the value of shared_experts_output to variable shared_experts_input for fusion + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet." 
+ ) - if enable_eplb: - raise NotImplementedError("EPLB not supported for " - "`CompressedTensorsW4A8MoEMethod` yet.") + use_ep = layer.expert_map is not None if use_ep: start_eid = layer.ep_rank * layer.local_num_experts - end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, global_num_experts) + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts) + dtype = x.dtype - num_tokens, num_experts = router_logits.shape + num_tokens = topk_ids.shape[0] + num_experts = layer.global_num_experts + if use_ep: hidden_size = x.shape[1] ( @@ -3029,13 +4014,316 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): dtype=x.dtype, ) else: - expand_tokens = num_tokens * top_k + expand_tokens = num_tokens * layer.top_k ( src_to_dst, sorted_token_ids, expert_sizes_gpu, expert_sizes_cpu, - ) = torch.ops.ixf_ops.moe_compute_token_index( + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + expert_sizes_cpu = expert_sizes_gpu.cpu() + + # expand + reorder + quant + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, + topk_ids=None, + smooth_scales=layer.w13_input_scale, + ) + + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1]: + padding = layer.w13_weight.shape[-1] - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + # w8a8 group gemm 1 + pt_output_1 = ixfops.moe_w8a8_group_gemm( + input=i8_hidden_states_align, + weight=layer.w13_weight, + i_scales=a_scale, + w_scales=layer.w13_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=None, + bias=layer.w13_bias, + format="TN", + ) + + # act + quant + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) + + if pt_output_2.shape[-1] != layer.w2_weight.shape[-1]: + padding = layer.w2_weight.shape[-1] - pt_output_2.shape[-1] + pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) + else: + pt_output_2_align = pt_output_2 + + # w8a8 group gemm 2 + reorder + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * layer.top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + ixfops.moe_w8a8_group_gemm( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format="TN", + output=pt_output_3, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w8a8_group_gemm( + input=pt_output_2_align, + weight=layer.w2_weight, + i_scales=a2_scale, + w_scales=layer.w2_weight_scale, + output_dtype=dtype, + 
tokens_per_experts=expert_sizes_cpu, + dst_to_src=sorted_token_ids, + bias=layer.w2_bias, + format="TN", + ) + + # mul + reduce_sum + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + return final_hidden_states + +class CompressedTensorsL2OptMoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + moe: FusedMoEConfig, + ): + super().__init__(moe) + self.has_bias = self.moe.has_bias + self.pack_factor = 2 + self.group_size = -1 + self.version = 2 + self.gemm_format = envs.VLLM_W4A8_FORMAT + self.format_mapping = {"NN":0,"NT":1,"TN":2} + assert self.gemm_format in ["TN","NN"] + + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.int8 + w13_remainder = (hidden_size // self.pack_factor) % 64 + w2_remainder = (intermediate_size_per_partition // self.pack_factor) % 64 + if self.gemm_format == "TN": + if w13_remainder != 0: + w13_shape = (num_experts, 2 * intermediate_size_per_partition, (hidden_size // self.pack_factor) + 64 - w13_remainder) + else: + w13_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_size // self.pack_factor) + + if w2_remainder != 0: + w2_shape = (num_experts, hidden_size, (intermediate_size_per_partition // self.pack_factor) + 64 - w2_remainder) + else: + w2_shape = (num_experts, hidden_size, intermediate_size_per_partition // self.pack_factor) + else: + w13_shape = (num_experts, hidden_size, 2 * intermediate_size_per_partition // self.pack_factor) + w2_shape = (num_experts, intermediate_size_per_partition, hidden_size // self.pack_factor) + + # WEIGHTS + # use process_weights_after_loading to get get right layout if gemm_format is NN + w13_weight = torch.nn.Parameter(torch.empty(w13_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w13_weight, "shard_dim", 1) + + if self.has_bias: + w13_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + else: + setattr(layer, "w13_bias", None) + + if w2_remainder != 0: + w2_weight = torch.nn.Parameter(torch.zeros(w2_shape, + dtype=params_dtype), + requires_grad=False) + else: + w2_weight = torch.nn.Parameter(torch.empty(w2_shape, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + if self.gemm_format == "NN": + setattr(w2_weight, "shard_dim", 0) + + if self.has_bias: + w2_bias = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + dtype=torch.bfloat16, + ), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + else: + setattr(layer, "w2_bias", None) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ # The following scale or zero will use permute(0,2,1) to get right layout, init here to avoid rewrite data_loader + w13_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1, + 2 * intermediate_size_per_partition, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + setattr(w13_weight_scale, "shard_dim", 1) + + w2_weight_scale = torch.nn.Parameter(torch.empty(num_experts, + 1, + hidden_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + setattr(w2_weight_scale, "shard_dim", 0) + + extra_weight_attrs.update({"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + setattr(layer, "w13_i8_weight_scale", None) + setattr(layer, "w2_i8_weight_scale", None) + setattr(layer, "w13_i8_weight_zero", None) + setattr(layer, "w2_i8_weight_zero", None) + + layer.w13_input_scale = None + layer.w2_input_scale = None + + self.gemm_format = self.format_mapping[self.gemm_format] + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None: + return None + + def apply( + self, + layer: FusedMoE, + x: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + # Assign the value of shared_experts_output to variable shared_experts_input for fusion + shared_experts_input: torch.Tensor | None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + attn_metadata = get_forward_context().attn_metadata + use_ep = layer.expert_map is not None + # unsupported ep now + if attn_metadata: + deepseek_instance = None + for value in attn_metadata.values(): + if hasattr(value, 'num_prefills') and hasattr(value, 'num_decodes'): + deepseek_instance = value + break + value_types = {type(value).__name__ for value in attn_metadata.values()} + is_same_class = len(value_types) == 1 + if is_same_class: + assert deepseek_instance + only_decode = (use_ep == False and all(t.num_decodes > 0 and t.num_prefills ==0 for t in list(attn_metadata.values()))) + else: + if deepseek_instance: + only_decode = (use_ep == False and deepseek_instance.num_decodes > 0 and deepseek_instance.num_prefills ==0) + else: + only_decode = False + else: + only_decode = False + if layer.enable_eplb: + raise NotImplementedError("EPLB not supported for `CompressedTensorsW4A8MoEMethod` yet.") + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts) + + dtype = x.dtype + num_tokens = topk_ids.shape[0] + num_experts = layer.global_num_experts + if use_ep: + hidden_size = x.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=x.device, + dtype=x.dtype, + ) + else: + expand_tokens = num_tokens * layer.top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( topk_ids=topk_ids, num_experts=num_experts, ) @@ -3045,15 +4333,21 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): hidden_states=x, dst_to_src=sorted_token_ids, dst_tokens=expand_tokens, - topk=top_k, + topk=layer.top_k, src_to_dst=src_to_dst, 
topk_ids=None, smooth_scales=layer.w13_input_scale, output_format = 1, ) + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2: + padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + pt_output_1 = ixfops.moe_w4a8_group_gemv( - input=i8_hidden_states, + input=i8_hidden_states_align, weight=layer.w13_weight, i_scales=a_scale, w_scales=layer.w13_weight_scale, @@ -3062,19 +4356,30 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): w_i8scales=layer.w13_i8_weight_scale, w_i8zeros=layer.w13_i8_weight_zero, dst_to_src=None, + bias=layer.w13_bias, format=self.gemm_format, group_size=self.group_size, ) - pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( - input=pt_output_1, - bias=None, - smooth_scales=layer.w2_input_scale, - dst_to_src=sorted_token_ids, - topk_ids=None, - act_type="swiglu", - output_format = 1, - ) + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + output_format = 1, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + output_format = 1, + ) if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2: padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1] @@ -3092,32 +4397,39 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): w_i8scales=layer.w2_i8_weight_scale, w_i8zeros=layer.w2_i8_weight_zero, dst_to_src=sorted_token_ids, + bias=layer.w2_bias, format=self.gemm_format, group_size=self.group_size, ) # mul + reduce_sum final_hidden_states = ixfops.moe_output_reduce_sum( - input=pt_output_3.view(num_tokens, top_k, -1), + input=pt_output_3.view(num_tokens, layer.top_k, -1), topk_weight=topk_weights, - extra_residual=extra_residual, - scaling_factor=routed_scaling_factor - ) + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) else: expert_sizes_cpu = expert_sizes_gpu.cpu() # expand + reorder + quant - i8_hidden_states, a_scale = torch.ops.ixf_ops.moe_expand_input_dynamic_scaled_int8( + i8_hidden_states, a_scale = ixfops.moe_expand_input_dynamic_scaled_int8( hidden_states=x, dst_to_src=sorted_token_ids, dst_tokens=expand_tokens, - topk=top_k, + topk=layer.top_k, src_to_dst=src_to_dst, topk_ids=None, smooth_scales=layer.w13_input_scale, ) - # w4a8 group gemm 1 - pt_output_1 = torch.ops.ixf_ops.moe_w4a8_group_gemm( - input=i8_hidden_states, + if i8_hidden_states.shape[-1] != layer.w13_weight.shape[-1] * 2: + padding = layer.w13_weight.shape[-1] * 2 - i8_hidden_states.shape[-1] + i8_hidden_states_align = torch.nn.functional.pad(i8_hidden_states, (0, padding), mode='constant', value=0) + else: + i8_hidden_states_align = i8_hidden_states + + + pt_output_1 = moe_w4a8_group_gemm_pad_k( + input=i8_hidden_states_align, weight=layer.w13_weight, i_scales=a_scale, w_scales=layer.w13_weight_scale, @@ -3126,22 +4438,32 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): w_i8scales=layer.w13_i8_weight_scale, w_i8zeros=layer.w13_i8_weight_zero, dst_to_src=None, + bias=layer.w13_bias, format=self.gemm_format, group_size=self.group_size, 
version=self.version ) # act + quant - pt_output_2, a2_scale = torch.ops.ixf_ops.activation_dynamic_scaled_int8( - input=pt_output_1, - bias=None, - smooth_scales=layer.w2_input_scale, - dst_to_src=sorted_token_ids, - topk_ids=None, - act_type="swiglu", - ) + if layer.activation.value == "swigluoai": + pt_output_2, a2_scale = ixfops.activation_swigluoai_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + ) + else: + pt_output_2, a2_scale = ixfops.activation_dynamic_scaled_int8( + input=pt_output_1, + bias=None, + smooth_scales=layer.w2_input_scale, + dst_to_src=sorted_token_ids, + topk_ids=None, + act_type="swiglu", + ) - if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2: + if pt_output_2.shape[-1] != layer.w2_weight.shape[-1] * 2 and self.gemm_format == 2: padding = layer.w2_weight.shape[-1] * 2 - pt_output_2.shape[-1] pt_output_2_align = torch.nn.functional.pad(pt_output_2, (0, padding), mode='constant', value=0) else: @@ -3150,12 +4472,12 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): # w4a8 group gemm 2 + reorder if use_ep: pt_output_3 = torch.empty( - (num_tokens * top_k, hidden_size), + (num_tokens * layer.top_k, hidden_size), device=x.device, dtype=x.dtype, ) - torch.ops.ixf_ops.moe_w4a8_group_gemm( + moe_w4a8_group_gemm_pad_k( input=pt_output_2_align, weight=layer.w2_weight, i_scales=a2_scale, @@ -3165,6 +4487,7 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): w_i8scales=layer.w2_i8_weight_scale, w_i8zeros=layer.w2_i8_weight_zero, dst_to_src=sorted_token_ids, + bias=layer.w2_bias, format=self.gemm_format, group_size=self.group_size, version=self.version, @@ -3172,15 +4495,14 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): ) reduce_mask = src_to_dst == -1 - final_hidden_states = torch.ops.ixf_ops.moe_output_reduce_sum( - input=pt_output_3.view(num_tokens, top_k, -1), - topk_weight=topk_weight, - extra_residual=extra_residual, - scaling_factor=routed_scaling_factor, + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, mask=reduce_mask, ) else: - pt_output_3 = torch.ops.ixf_ops.moe_w4a8_group_gemm( + pt_output_3 = moe_w4a8_group_gemm_pad_k( input=pt_output_2_align, weight=layer.w2_weight, i_scales=a2_scale, @@ -3190,16 +4512,17 @@ class CompressedTensorsW4A8MoEMethod(CompressedTensorsMoEMethod): w_i8scales=layer.w2_i8_weight_scale, w_i8zeros=layer.w2_i8_weight_zero, dst_to_src=sorted_token_ids, + bias=layer.w2_bias, format=self.gemm_format, group_size=self.group_size, version=self.version ) # mul + reduce_sum - final_hidden_states = torch.ops.ixf_ops.moe_output_reduce_sum( - input=pt_output_3.view(num_tokens, top_k, -1), + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), topk_weight=topk_weights, - extra_residual=extra_residual, - scaling_factor=routed_scaling_factor - ) - return final_hidden_states \ No newline at end of file + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + return final_hidden_states diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index 8157182..c9dd98d 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -8,7 +8,7 @@ from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4 from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 -from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8, CompressedTensorsW4A8Int8 +from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16 @@ -28,5 +28,4 @@ __all__ = [ "CompressedTensorsW4A4Fp4", "CompressedTensorsW4A8Int", "CompressedTensorsW4A8Fp8", - "CompressedTensorsW4A8Int8" ] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py index e660fc4..ed919a7 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -25,11 +25,18 @@ logger = init_logger(__name__) class CompressedTensorsW8A8Int8(CompressedTensorsScheme): def __init__( - self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool + self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool, is_w4a8_linear: bool = False ): - self.strategy = strategy + import vllm.envs as env + if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR": + self.strategy = QuantizationStrategy.TENSOR + elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL": + self.strategy = QuantizationStrategy.CHANNEL + else: + self.strategy = strategy self.is_static_input_scheme = is_static_input_scheme self.input_symmetric = input_symmetric + self.is_w4a8_linear = is_w4a8_linear @classmethod def get_min_capability(cls) -> int: @@ -53,16 +60,32 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): input_symmetric=self.input_symmetric, module_name=self.__class__.__name__, ) + remainder = input_size_per_partition % 64 + if remainder != 0: + input_size_per_partition_padded = input_size_per_partition + (64 - remainder) + else: + input_size_per_partition_padded = input_size_per_partition # WEIGHT - weight = ModelWeightParameter( - data=torch.empty( - sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) + if self.is_w4a8_linear: + # only "NN" is supported + weight = ModelWeightParameter(data=torch.empty( + input_size_per_partition_padded, + sum(output_partition_sizes) // 2, + dtype=torch.int8), + input_dim=0, + output_dim=1, + weight_loader=weight_loader, + ) + else: + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition_padded, + dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) layer.register_parameter("weight", weight) @@ -109,104 +132,4 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): def apply_weights( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None ) -> torch.Tensor: - return self.kernel.apply_weights(layer, x, bias) - - - -class CompressedTensorsW4A8Int8(CompressedTensorsScheme): - def __init__( - self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool - ): - 
self.strategy = strategy - self.is_static_input_scheme = is_static_input_scheme - self.input_symmetric = input_symmetric - - @classmethod - def get_min_capability(cls) -> int: - # turing and up - return 75 - - def create_weights( - self, - layer: torch.nn.Module, - output_partition_sizes: list[int], - input_size_per_partition: int, - params_dtype: torch.dtype, - weight_loader: Callable, - **kwargs, - ): - layer.logical_widths = output_partition_sizes - - self.kernel = init_int8_linear_kernel( - is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL), - is_static_input_scheme=self.is_static_input_scheme, - input_symmetric=self.input_symmetric, - module_name=self.__class__.__name__, - ) - - # WEIGHT - # weight = ModelWeightParameter( - # data=torch.empty( - # sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 - # ), - # input_dim=1, - # output_dim=0, - # weight_loader=weight_loader, - # ) - weight = ModelWeightParameter( - data=torch.empty( - input_size_per_partition, - sum(output_partition_sizes) // 2, - dtype=torch.int8 - ), - input_dim=0, - output_dim=1, - weight_loader=weight_loader, - ) - - layer.register_parameter("weight", weight) - - # WEIGHT SCALE - if self.strategy == QuantizationStrategy.CHANNEL: - weight_scale = ChannelQuantScaleParameter( - data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), - output_dim=0, - weight_loader=weight_loader, - ) - else: - assert self.strategy == QuantizationStrategy.TENSOR - weight_scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - layer.register_parameter("weight_scale", weight_scale) - - # INPUT SCALE - input_zero_point = None - input_scale = None - if self.is_static_input_scheme: - input_scale = BasevLLMParameter( - data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader - ) - if not self.input_symmetric: - # Note: compressed-tensors stores the zp using the same dtype - # as the weights - # AZP loaded as int8 but used as int32 - input_zero_point = BasevLLMParameter( - data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader - ) - - layer.register_parameter("input_zero_point", input_zero_point) - layer.register_parameter("input_scale", input_scale) - if not hasattr(layer, "azp_adj"): - layer.register_parameter("azp_adj", None) - - # Checkpoints are serialized in compressed-tensors format, which is - # different from the format the kernel may want. Handle repacking here. 
- def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - self.kernel.process_weights_after_loading(layer) - - def apply_weights( - self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None - ) -> torch.Tensor: - return self.kernel.apply_weights(layer, x, bias) \ No newline at end of file + return self.kernel.apply_weights(layer, x, bias, self.is_w4a8_linear) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e3174ba..5101347 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -23,17 +23,13 @@ from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEMethodBase, - FusedMoEPermuteExpertsUnpermute, - FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, - MoEActivation, ) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( - Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, @@ -50,9 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_fi_trtllm_fp8_per_tensor_moe, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, create_fp8_input_scale, @@ -860,14 +853,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): replace_parameter(layer, f"w13_{self.weight_scale_name}", w13_scale) replace_parameter(layer, f"w2_{self.weight_scale_name}", w2_scale) - # Setup modular kernel for TP case and naive DP/EP case. - # In non-naive DP/EP case, we will create a ModularKernelMethod. - # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel - # in both cases. self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: assert self.experts_cls is not None - self.moe_mk = make_fp8_moe_kernel( + self.moe_kernel = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, @@ -930,29 +919,13 @@ class Fp8MoEMethod(FusedMoEMethodBase): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) - def select_gemm_impl( - self, - prepare_finalize: FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> FusedMoEPermuteExpertsUnpermute: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: - # TRTLLM does not use Modular Kernel. 
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM: - return None - + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: w1_scale = getattr(layer, f"w13_{self.weight_scale_name}") w2_scale = getattr(layer, f"w2_{self.weight_scale_name}") a1_scale = layer.w13_input_scale @@ -983,10 +956,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): def supports_eplb(self) -> bool: return True - @property - def is_monolithic(self) -> bool: - return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - def apply_monolithic( self, layer: FusedMoE, @@ -994,50 +963,22 @@ class Fp8MoEMethod(FusedMoEMethodBase): router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - - # TODO(rob): convert this to MK. - if layer.enable_eplb: - raise NotImplementedError("EPLB not supported for `Fp8MoEMethod` yet.") - assert layer.activation == MoEActivation.SILU, ( - f"Expected 'silu' activation but got {layer.activation}" + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + num_expert_group=layer.num_expert_group, + topk_group=layer.topk_group, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) - if self.block_quant: - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - - return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8( - routing_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - x=x, - w13_weight=layer.w13_weight, - w13_weight_scale_inv=layer.w13_weight_scale_inv, - w2_weight=layer.w2_weight, - w2_weight_scale_inv=layer.w2_weight_scale_inv, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - intermediate_size=layer.intermediate_size_per_partition, - expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - block_shape=self.weight_block_size, - routing_method_type=layer.routing_method_type, - routed_scaling=layer.routed_scaling_factor, - ) - else: - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, - global_num_experts=layer.global_num_experts, - top_k=layer.top_k, - num_expert_group=layer.num_expert_group, - topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - ) - def apply( self, layer: FusedMoE, @@ -1046,9 +987,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.moe_mk is not None assert not self.is_monolithic - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 8802334..c7dd67d 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -7,6 +7,7 @@ from typing import Any import gguf import torch +import torch.nn.functional as F from gguf 
import GGMLQuantizationType as WeightType from torch.nn.parameter import Parameter, UninitializedParameter @@ -234,7 +235,7 @@ try: op_func=_fused_mul_mat_gguf, fake_impl=_fused_mul_mat_gguf_fake, ) - fused_mul_mat_gguf = torch.ops.vllm._fused_mul_mat_gguf + fused_mul_mat_gguf = _fused_mul_mat_gguf except AttributeError as error: raise error @@ -365,7 +366,7 @@ try: op_func=_fused_moe_gguf, fake_impl=_fused_moe_gguf_fake, ) - fused_moe_gguf = torch.ops.vllm._fused_moe_gguf + fused_moe_gguf = _fused_moe_gguf except AttributeError as error: raise error @@ -410,7 +411,7 @@ try: op_func=_apply_gguf_embedding, fake_impl=_apply_gguf_embedding_fake, ) - apply_gguf_embedding = torch.ops.vllm._apply_gguf_embedding + apply_gguf_embedding = _apply_gguf_embedding except AttributeError as error: raise error @@ -451,6 +452,9 @@ class GGUFLinearMethod(LinearMethodBase): "data_container": [], "shard_id": [], "shard_id_map": {}, + "params_dtype": params_dtype, + "input_size_per_partition" :input_size_per_partition, # restore shape for qkv and merge + "output_partition_sizes" :output_partition_sizes, }, ) set_weight_attrs(qweight, extra_weight_attrs) @@ -664,6 +668,10 @@ class GGUFEmbeddingMethod(GGUFLinearMethod): """ def embedding(self, layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor: + weight = layer.weight + return F.embedding(x, weight) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: qweight = layer.qweight qweight_type = layer.qweight_type.weight_type hidden_size = qweight.tensor_shape[1] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 2cc86a9..dc91f17 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -128,7 +128,7 @@ class GPTQConfig(QuantizationConfig): @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: - return [torch.half, torch.bfloat16] + return [torch.bfloat16, torch.half] @classmethod # Need to figure it out diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index d7b2a36..5dd352c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -59,9 +59,164 @@ from vllm.platforms import current_platform from vllm.scalar_type import scalar_types from vllm.transformers_utils.config import get_safetensors_params_metadata from vllm.utils.collection_utils import is_list_of +import ixformer.inference.functions as ixfops logger = init_logger(__name__) +#[B,K//8,N] ->[B,K,N] +# less memmory +def unpack_k_batch_opt(packed_w: torch.Tensor, num_bits: int = 4, chunk_size: int = 2) -> torch.Tensor: + """ + Memory-efficient unpacking for 3D tensor. + Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor, + without broadcasting huge intermediate tensors (avoids OOM). + + Args: + packed_w: torch.int32 tensor of shape [B, K // pack_factor, N]. + num_bits: Number of bits per packed element (e.g., 4 or 2). + chunk_size: How many bit groups to unpack at once (tradeoff between speed and memory). + + Returns: + unpacked: torch.int8 tensor of shape [B, K, N]. 
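+
+    Example (illustrative only; num_bits=4 so pack_factor=8, values are
+    written LSB-first along K)::
+
+        >>> packed = torch.tensor([[[0x76543210]]], dtype=torch.int32)
+        >>> unpack_k_batch_opt(packed).flatten().tolist()
+        [0, 1, 2, 3, 4, 5, 6, 7]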
+ """ + B, k_packed, N = packed_w.shape + pack_factor = 32 // num_bits + K = k_packed * pack_factor + mask = (1 << num_bits) - 1 + + # Allocate output tensor once + unpacked = torch.empty((B, K, N), dtype=torch.int8, device=packed_w.device) + + # Process bit chunks iteratively to save memory + for i in range(0, pack_factor, chunk_size): + # Precompute shifts for this chunk + shift_vals = num_bits * torch.arange(i, min(i + chunk_size, pack_factor), device=packed_w.device) + # [chunk_size, 1, 1, 1] + shifts = shift_vals.view(-1, 1, 1, 1) + # Compute small chunk only + chunk = ((packed_w.unsqueeze(0) >> shifts) & mask).to(torch.int8) + + # chunk: [chunk_size, B, k_packed, N] + # write into output + for j in range(chunk.shape[0]): + unpacked[:, (i + j)::pack_factor, :] = chunk[j] + + del chunk # release memory early + + return unpacked + +# more memmory +def unpack_k_batch(packed_w: torch.Tensor, num_bits: int = 4) -> torch.Tensor: + """ + Efficient vectorized unpacking for 3D tensor (batch version). + Converts [B, K // pack_factor, N] int32 tensor → [B, K, N] int8 tensor. + + Args: + packed_w: torch.int32 tensor of shape [B, K // pack_factor, N]. + num_bits: Number of bits per packed element (e.g., 4). + + Returns: + unpacked: torch.int8 tensor of shape [B, K, N]. + """ + B, k_packed, n = packed_w.shape + pack_factor = 32 // num_bits + k = k_packed * pack_factor + + mask = (1 << num_bits) - 1 + + # [pack_factor, 1, 1, 1] + shifts = (num_bits * torch.arange(pack_factor, device=packed_w.device)).view(-1, 1, 1, 1) + + # [1, B, k_packed, N] + packed_expanded = packed_w.unsqueeze(0) + + # Extract each group of num_bits using bitwise ops + unpacked_groups = ((packed_expanded >> shifts) & mask).to(torch.int8) + + # [pack_factor, B, k_packed, N] → [B, K, N] + unpacked = unpacked_groups.permute(1, 2, 0, 3).reshape(B, k, n) + + return unpacked + + +#[B,K,N] ->[B,K,N//8] +# less memmory +def pack_n_batch_opt(x: torch.Tensor, pack_num: int = 8, order_map=None, chunk_size: int = 2) -> torch.Tensor: + """ + Memory-efficient batch packing with correct bit order. + [B, K, N] int4 -> [B, K, N//pack_num] int32. + """ + B, K, N = x.shape + assert N % pack_num == 0, "N must be divisible by pack_num" + cols = N // pack_num + unit = 32 // pack_num + + if order_map is None: + order_map = list(range(pack_num)) + order_map = torch.tensor(order_map, device=x.device) + + shifts = unit * torch.arange(pack_num, device=x.device) # always 0..unit*(pack_num-1) + packed = torch.zeros((B, K, cols), dtype=torch.int32, device=x.device) + x_reshape = x.view(B, K, cols, pack_num) & 0xF + + # process in chunks for memory efficiency + for start in range(0, pack_num, chunk_size): + end = min(start + chunk_size, pack_num) + idx_chunk = order_map[start:end] + shift_chunk = shifts[start:end] + + vals = torch.gather(x_reshape, 3, idx_chunk.view(1,1,1,-1).expand(B,K,cols,-1)).to(torch.int32) + for j in range(vals.shape[-1]): + packed.add_(vals[..., j] << shift_chunk[j]) + + return packed + +## more memmory +def pack_n_batch(x: torch.Tensor, pack_num: int = 8, order_map=None) -> torch.Tensor: + """ + Efficient vectorized batch packing: [B, K, N] int4 -> [B, K, N//pack_num] int32. + + Args: + x: torch.int32 tensor of shape [B, K, N], each element 0-15 (int4). + pack_num: Number of 4-bit elements per packed int32 (default=8). + order_map: Optional order of elements within each packed int32. + + Returns: + torch.int32 tensor of shape [B, K, N//pack_num]. 
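+
+    Example (illustrative only; order_map [0, 2, 4, 6, 1, 3, 5, 7] is the
+    interleave order used elsewhere in this file)::
+
+        >>> x = torch.arange(8, dtype=torch.int32).view(1, 1, 8)
+        >>> hex(pack_n_batch(x, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).item())
+        '0x75316420'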
+ """ + + B, K, N = x.shape + assert N % pack_num == 0, "N must be divisible by pack_num" + cols = N // pack_num + + if order_map is None: + order_map = list(range(pack_num)) + order_map = torch.tensor(order_map, device=x.device) + + unit = 32 // pack_num # number of bits per element + + # reshape to [B, K, cols, pack_num] + pack_num_int = int(pack_num) + + x_reshape = x.view(B, K, cols, pack_num_int) + + # reorder according to order_map + x_reorder = torch.gather( + x_reshape, 3, order_map.view(1, 1, 1, -1).expand(B, K, cols, -1) + ) + + # mask low 4 bits + x_reorder = x_reorder & 0xF + + # bit shifts [pack_num] -> [1,1,1,pack_num] broadcastable + shifts = (unit * torch.arange(pack_num_int, device=x.device)).view(1, 1, 1, -1) + + # shift and sum along last dimension to combine bits + packed = (x_reorder << shifts).sum(dim=-1).to(torch.int32) + + return packed + + def get_moe_quant_method( config: "GPTQMarlinConfig", @@ -495,8 +650,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): self.quant_config = quant_config if self.quant_config.quant_type.size_bits == 4: self.quant_type = scalar_types.uint4b8 - elif self.quant_config.quant_type.size_bits == 8: - self.quant_type = scalar_types.uint8b128 + # elif self.quant_config.quant_type.size_bits == 8: + # self.quant_type = scalar_types.uint8b128 else: raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.") self.input_dtype = None @@ -594,7 +749,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_experts, scales_size13, 2 * intermediate_size_per_partition // self.quant_config.pack_factor, - dtype=params_dtype, + dtype=torch.int32, ), requires_grad=False, ) @@ -606,7 +761,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): num_experts, scales_size2, hidden_size // self.quant_config.pack_factor, - dtype=params_dtype, + dtype=torch.int32, ), requires_grad=False, ) @@ -656,7 +811,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) device = layer.w13_qweight.device - layer.workspace = marlin_make_workspace_new(device, 4) + # layer.workspace = marlin_make_workspace_new(device, 4) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 @@ -673,119 +828,111 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): layer.w2_scales.data = layer.w2_scales.data * 512 # Process act_order - if self.quant_config.desc_act: + # if self.quant_config.desc_act: # Get sorting based on g_idx - num_experts = layer.w13_g_idx.shape[0] - w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) - w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) - w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) - w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) - for e in range(num_experts): - w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to( - torch.int32 - ) - w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( - torch.int32 - ) - w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] - w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] - replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) - replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) - replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) - replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) - else: - # Reset g_idx related tensors - num_experts = layer.w13_g_idx.shape[0] - device = layer.w13_g_idx.device - layer.w13_g_idx = torch.nn.Parameter( - 
torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w13_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - layer.w2_g_idx_sort_indices = torch.nn.Parameter( - torch.empty((num_experts, 0), dtype=torch.int32, device=device), - requires_grad=False, - ) - # Repack weights - marlin_w13_qweight = ops.gptq_marlin_moe_repack( - layer.w13_qweight, - layer.w13_g_idx_sort_indices, - layer.w13_qweight.shape[1] * self.quant_config.pack_factor, - layer.w13_qweight.shape[2], - self.quant_config.quant_type.size_bits, - is_a_8bit=is_a_8bit, - ) - replace_parameter(layer, "w13_qweight", marlin_w13_qweight) - marlin_w2_qweight = ops.gptq_marlin_moe_repack( - layer.w2_qweight, - layer.w2_g_idx_sort_indices, - layer.w2_qweight.shape[1] * self.quant_config.pack_factor, - layer.w2_qweight.shape[2], - self.quant_config.quant_type.size_bits, - is_a_8bit=is_a_8bit, - ) - replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + # num_experts = layer.w13_g_idx.shape[0] + # w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + # w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + # w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + # w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + # for e in range(num_experts): + # w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to( + # torch.int32 + # ) + # w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + # torch.int32 + # ) + # w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]] + # w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]] + # replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) + # replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) + # replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices) + # replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices) + # else: + # # Reset g_idx related tensors + # num_experts = layer.w13_g_idx.shape[0] + # device = layer.w13_g_idx.device + # layer.w13_g_idx = torch.nn.Parameter( + # torch.empty((num_experts, 0), dtype=torch.int32, device=device), + # requires_grad=False, + # ) + # layer.w2_g_idx = torch.nn.Parameter( + # torch.empty((num_experts, 0), dtype=torch.int32, device=device), + # requires_grad=False, + # ) + # layer.w13_g_idx_sort_indices = torch.nn.Parameter( + # torch.empty((num_experts, 0), dtype=torch.int32, device=device), + # requires_grad=False, + # ) + # layer.w2_g_idx_sort_indices = torch.nn.Parameter( + # torch.empty((num_experts, 0), dtype=torch.int32, device=device), + # requires_grad=False, + # ) + # # Repack weights + # marlin_w13_qweight = ops.gptq_marlin_moe_repack( + # layer.w13_qweight, + # layer.w13_g_idx_sort_indices, + # layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + # layer.w13_qweight.shape[2], + # self.quant_config.quant_type.size_bits, + # ) + # replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + # marlin_w2_qweight = ops.gptq_marlin_moe_repack( + # layer.w2_qweight, + # layer.w2_g_idx_sort_indices, + # layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + # layer.w2_qweight.shape[2], + # self.quant_config.quant_type.size_bits, + # ) + # replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + # # Repack scales + # marlin_w13_scales = marlin_moe_permute_scales( + # 
s=layer.w13_scales, + # size_k=layer.intermediate_size_per_partition, + # size_n=layer.w13_scales.shape[2], + # group_size=self.quant_config.group_size, + # ) + # replace_parameter(layer, "w13_scales", marlin_w13_scales) + # marlin_w2_scales = marlin_moe_permute_scales( + # s=layer.w2_scales, + # size_k=layer.w2_scales.shape[1] + # * ( + # self.quant_config.group_size + # if self.quant_config.group_size != -1 + # else self.quant_config.pack_factor + # ), + # size_n=layer.w2_scales.shape[2], + # group_size=self.quant_config.group_size, + # ) + # replace_parameter(layer, "w2_scales", marlin_w2_scales) - # The modular kernel expects w13_weight and w2_weight, - # but GPTQ uses w13_qweight and w2_qweight - # Alias for modular kernel - layer.w13_weight = layer.w13_qweight - # Alias for modular kernel - layer.w2_weight = layer.w2_qweight + # if hasattr(layer, "w13_bias") and layer.w13_bias is not None: + # layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) - # Repack scales - marlin_w13_scales = marlin_moe_permute_scales( - s=layer.w13_scales, - size_k=layer.intermediate_size_per_partition, - size_n=layer.w13_scales.shape[2], - group_size=self.quant_config.group_size, - is_a_8bit=is_a_8bit, - ) - if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1: - marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales( - marlin_w13_scales - ) - layer.register_parameter( - "w13_input_global_scale", - torch.nn.Parameter(w13_input_global_scale, requires_grad=False), - ) - - replace_parameter(layer, "w13_scales", marlin_w13_scales) - marlin_w2_scales = marlin_moe_permute_scales( - s=layer.w2_scales, - size_k=layer.w2_scales.shape[1] - * ( - self.quant_config.group_size - if self.quant_config.group_size != -1 - else self.quant_config.pack_factor - ), - size_n=layer.w2_scales.shape[2], - group_size=self.quant_config.group_size, - is_a_8bit=is_a_8bit, - ) - if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1: - marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales( - marlin_w2_scales - ) - layer.register_parameter( - "w2_input_global_scale", - torch.nn.Parameter(w2_input_global_scale, requires_grad=False), - ) - - replace_parameter(layer, "w2_scales", marlin_w2_scales) - - if hasattr(layer, "w13_bias") and layer.w13_bias is not None: - layer.w13_bias.data = marlin_permute_bias(layer.w13_bias) - - if hasattr(layer, "w2_bias") and layer.w2_bias is not None: - layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + # if hasattr(layer, "w2_bias") and layer.w2_bias is not None: + # layer.w2_bias.data = marlin_permute_bias(layer.w2_bias) + if self.quant_config.desc_act: + raise NotImplementedError( + "GPTQMarlinMoEMethod now not support desc_act. 
please fix it") + w13_qweight_unpacked = unpack_k_batch(layer.w13_qweight) + w13_qweight_repacked = pack_n_batch(w13_qweight_unpacked,self.quant_config.pack_factor,order_map=[0, 2, 4, 6, 1, 3, 5, 7]) + replace_parameter(layer, "w13_qweight", w13_qweight_repacked) + + # quant vllm/model_executor/layers/quantization/utils/quant_utils.py#quantize_weights + # if quant_type.has_bias(): + # w_q += quant_type.bias + # use quant_type.bias as zp,(ixformer support) + w13_zp = torch.full_like(layer.w13_scales, self.quant_type.bias, dtype=torch.int32) + w13_zp_pack = pack_n_batch(w13_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous() + replace_parameter(layer, "w13_qzeros", w13_zp_pack) + + w2_qweight_unpacked = unpack_k_batch(layer.w2_qweight) + w2_qweight_repacked = pack_n_batch(w2_qweight_unpacked,self.quant_config.pack_factor,order_map=[0, 2, 4, 6, 1, 3, 5, 7]) + replace_parameter(layer, "w2_qweight", w2_qweight_repacked) + + w2_zp = torch.full_like(layer.w2_scales, self.quant_type.bias, dtype=torch.int32) + w2_zp_pack = pack_n_batch(w2_zp, self.quant_config.pack_factor, order_map=[0, 2, 4, 6, 1, 3, 5, 7]).contiguous() + replace_parameter(layer, "w2_qzeros", w2_zp_pack) def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -900,30 +1047,165 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + # Assign the value of shared_experts_output to variable shared_experts_input for fusion shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - return fused_marlin_moe( - x, - layer.w13_qweight, - layer.w2_qweight, - getattr(layer, "w13_bias", None), - getattr(layer, "w2_bias", None), - layer.w13_scales, - layer.w2_scales, - topk_weights, - topk_ids, - input_global_scale1=getattr(layer, "w13_input_global_scale", None), - input_global_scale2=getattr(layer, "w2_input_global_scale", None), - quant_type_id=self.quant_type.id, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - g_idx1=layer.w13_g_idx, - g_idx2=layer.w2_g_idx, - sort_indices1=layer.w13_g_idx_sort_indices, - sort_indices2=layer.w2_g_idx_sort_indices, - workspace=layer.workspace, - is_k_full=self.is_k_full, - input_dtype=self.input_dtype, - inplace=not self.moe.disable_inplace, + assert layer.activation.value == "silu", "Only SiLU activation is supported." 
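+        # Overview of the ixformer W4A16 MoE path below (steps map to the
+        # ixfops calls that follow):
+        #   1. moe_compute_token_index / moe_compute_token_index_ep build the
+        #      token -> expert-slot mapping (under EP, slots routed to
+        #      non-local experts are dropped).
+        #   2. moe_expand_input gathers hidden states into expert-contiguous
+        #      order for the grouped GEMMs.
+        #   3. moe_w4a16_group_gemm computes the w13 projection, followed by
+        #      silu_and_mul and a second group GEMM for w2, whose output is
+        #      written back toward token order via the dst_to_src index.
+        #   4. moe_output_reduce_sum applies the top-k routing weights and
+        #      sums over the top-k dimension (masking dropped slots under EP,
+        #      or fusing shared_experts_input as a residual otherwise).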
+ use_ep = layer.expert_map is not None + + if use_ep: + start_eid = layer.ep_rank * layer.local_num_experts + end_eid = min((layer.ep_rank + 1) * layer.local_num_experts, layer.global_num_experts) + + if layer.apply_router_weight_on_input: + raise NotImplementedError( + "GPTQMarlinMoEMethod Apply router weight on input is not supported for" + "fused Marlin MoE method.") + + if (hasattr(layer, "w13_bias") and layer.w13_bias is not None) or (hasattr(layer, "w2_bias") and layer.w2_bias is not None): + raise NotImplementedError( + "GPTQMarlinMoEMethod moe_w4a16_group_gemm not supported bias, please fix this") + + num_tokens = topk_ids.shape[0] + num_experts = layer.global_num_experts + + if use_ep: + hidden_size = x.shape[1] + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + expand_tokens, + ) = ixfops.moe_compute_token_index_ep( + topk_ids=topk_ids, + num_experts=num_experts, + start_expert_id=start_eid, + end_expert_id=end_eid, + ) + if expert_sizes_cpu.sum() == 0: + return torch.zeros( + (num_tokens, hidden_size), + device=x.device, + dtype=x.dtype, + ) + else: + expand_tokens = num_tokens * layer.top_k + ( + src_to_dst, + sorted_token_ids, + expert_sizes_gpu, + expert_sizes_cpu, + ) = ixfops.moe_compute_token_index( + topk_ids=topk_ids, + num_experts=num_experts, + ) + expert_sizes_cpu = expert_sizes_gpu.cpu() + + # expand + reorder + # TODO use kernel + expand_hidden_states = ixfops.moe_expand_input( + hidden_states=x, + dst_to_src=sorted_token_ids, + dst_tokens=expand_tokens, + topk=layer.top_k, + src_to_dst=src_to_dst, ) + + # w4a16 group gemm 1 + # pt_output_1: (expand_tokens, 2n) dtype + pt_output_1 = ixfops.moe_w4a16_group_gemm( + input=expand_hidden_states, + weight=layer.w13_qweight, + w_scales=layer.w13_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w13_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=None, + format="NN", + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + # act + pt_output_2 = ixfops.silu_and_mul(pt_output_1) + + # w4a16 group gemm 2 + reorder + # pt_output_3: (expand_tokens, k) dtype + if use_ep: + pt_output_3 = torch.empty( + (num_tokens * layer.top_k, hidden_size), + device=x.device, + dtype=x.dtype, + ) + + ixfops.moe_w4a16_group_gemm( + input=pt_output_2, + weight=layer.w2_qweight, + w_scales=layer.w2_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w2_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=sorted_token_ids, + format="NN", + output=pt_output_3, + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + reduce_mask = src_to_dst == -1 + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + mask=reduce_mask, + ) + else: + pt_output_3 = ixfops.moe_w4a16_group_gemm( + input=pt_output_2, + weight=layer.w2_qweight, + w_scales=layer.w2_scales, + quant_type="awq", + tokens_per_experts=expert_sizes_cpu, + w_zeros=layer.w2_qzeros, + group_size=self.quant_config.group_size, + dst_to_src=sorted_token_ids, + format="NN", + tokens_per_experts_gpu=expert_sizes_gpu, + ) + + # mul + reduce_sum + # final_hidden_states: (num_tokens, k) + final_hidden_states = ixfops.moe_output_reduce_sum( + input=pt_output_3.view(num_tokens, layer.top_k, -1), + topk_weight=topk_weights, + scaling_factor=layer.routed_scaling_factor, + extra_residual=shared_experts_input, + ) + return final_hidden_states + + + + + + # return 
torch.ops.vllm.fused_marlin_moe( + # x, + # layer.w13_qweight, + # layer.w2_qweight, + # getattr(layer, "w13_bias", None), + # getattr(layer, "w2_bias", None), + # layer.w13_scales, + # layer.w2_scales, + # router_logits, + # topk_weights, + # topk_ids, + # quant_type_id=self.quant_type.id, + # apply_router_weight_on_input=apply_router_weight_on_input, + # global_num_experts=global_num_experts, + # expert_map=expert_map, + # g_idx1=layer.w13_g_idx, + # g_idx2=layer.w2_g_idx, + # sort_indices1=layer.w13_g_idx_sort_indices, + # sort_indices2=layer.w2_g_idx_sort_indices, + # workspace=layer.workspace, + # is_k_full=self.is_k_full) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 4c059da..f167e21 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -12,8 +12,7 @@ from vllm.logger import init_logger from vllm.model_executor.kernels.linear import ( init_fp8_linear_kernel, ) -from vllm.model_executor.layers.attention import Attention -from vllm.model_executor.layers.fused_moe.activation import MoEActivation +from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -24,14 +23,12 @@ from vllm.model_executor.layers.fused_moe.layer import ( FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( - Fp8MoeBackend, convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( - NvFp4MoeBackend, convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, make_nvfp4_moe_kernel, @@ -49,13 +46,6 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( - flashinfer_trtllm_fp4_moe, - flashinfer_trtllm_fp4_routed_moe, -) -from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - apply_fi_trtllm_fp8_per_tensor_moe, -) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, process_fp8_input_tensor_strategy_moe, @@ -114,6 +104,8 @@ QUANT_ALGOS = [ "NVFP4", # MXFP8 "MXFP8", + # MIXED_PRECISION, + "MIXED_PRECISION", ] KV_CACHE_QUANT_ALGOS = ["FP8"] @@ -181,7 +173,7 @@ class ModelOptQuantConfigBase(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": # handle kv-cache first so we can focus only on weight quantization thereafter - if isinstance(layer, Attention): + if isinstance(layer, (Attention, MLAAttention)): return self.KVCacheMethodCls(self) # handle exclusion @@ -235,6 +227,26 @@ class ModelOptQuantConfigBase(QuantizationConfig): self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules) + @staticmethod + def _extract_modelopt_quant_algo( + hf_quant_cfg: dict[str, Any] | None, + ) -> str | None: + """Extract upper-cased quant_algo from a modelopt config. + + Returns the quant_algo string (upper-cased), or None if the config + is not a modelopt config. 
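+
+        Example (illustrative)::
+
+            >>> ModelOptQuantConfigBase._extract_modelopt_quant_algo(
+            ...     {"quant_method": "modelopt", "quant_algo": "fp8"})
+            'FP8'
+            >>> ModelOptQuantConfigBase._extract_modelopt_quant_algo(
+            ...     {"quant_method": "modelopt",
+            ...      "quantization": {"quant_algo": "NVFP4"}})
+            'NVFP4'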
+ """ + if hf_quant_cfg is None: + return None + if hf_quant_cfg.get("quant_method", "").lower() != "modelopt": + return None + if "quantization" in hf_quant_cfg: + quant_config = hf_quant_cfg["quantization"] + if isinstance(quant_config, dict): + return str(quant_config.get("quant_algo", "")).upper() + return None + return str(hf_quant_cfg.get("quant_algo", "")).upper() + @staticmethod def get_config_filenames() -> list[str]: return ["hf_quant_config.json"] @@ -272,10 +284,20 @@ class ModelOptQuantConfigBase(QuantizationConfig): # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules", []) else: - # Compressed-tensors style format: + # Compressed-tensors style format (config.json quantization_config): # {"quant_algo": "...", "quant_method": "modelopt"} quant_method = config.get("quant_algo") - kv_cache_quant_method = config.get("kv_cache_quant_algo") + + # "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string). + kv_cache_scheme = config.get("kv_cache_scheme") + if isinstance(kv_cache_scheme, dict) and ( + kv_cache_scheme.get("type") == "float" + and kv_cache_scheme.get("num_bits") == 8 + ): + kv_cache_quant_method = "FP8" + else: + kv_cache_quant_method = None + # "ignore" is the key in config.json exclude_modules = config.get("ignore", []) group_size_raw = config.get("group_size") @@ -379,32 +401,9 @@ class ModelOptFp8Config(ModelOptQuantConfigBase): def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: - """Detect if this ModelOpt config should be used based on - quantization config.""" - - if hf_quant_cfg is None: - return None - - # Use the community standard 'quant_method' - quant_method = hf_quant_cfg.get("quant_method", "").lower() - - # Only proceed if the method is explicitly "modelopt" - if quant_method != "modelopt": - return None - - # Look for ModelOpt-specific config structure - if "quantization" in hf_quant_cfg: - quant_config = hf_quant_cfg["quantization"] - if isinstance(quant_config, dict): - quant_algo = str(quant_config.get("quant_algo", "")) - if quant_algo.upper() == "FP8": - return "modelopt" - else: - # Check for compressed-tensors style config with specific quant_algo - quant_algo = str(hf_quant_cfg.get("quant_algo", "")) - if quant_algo.upper() == "FP8": - return "modelopt" - + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and algo == "FP8": + return "modelopt" return None @classmethod @@ -737,7 +736,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." @@ -745,9 +744,9 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." @@ -862,16 +861,15 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): # Setup modular kernel. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_mk = make_fp8_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - fp8_backend=self.fp8_backend, - experts_cls=self.experts_cls, - routing_tables=layer._maybe_init_expert_routing_tables(), - shared_experts=layer.shared_experts, - ) + assert self.experts_cls is not None + self.moe_kernel = make_fp8_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + fp8_backend=self.fp8_backend, + experts_cls=self.experts_cls, + routing_tables=layer._maybe_init_expert_routing_tables(), + shared_experts=layer.shared_experts, + ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13 = layer.w13_weight @@ -904,9 +902,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale a1_scale = layer.w13_input_scale @@ -920,10 +916,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): a2_scale=a2_scale, ) - @property - def is_monolithic(self) -> bool: - return self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - def apply_monolithic( self, layer: FusedMoE, @@ -931,28 +923,20 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM - if layer.enable_eplb: - raise NotImplementedError( - "EPLB not supported for FlashInfer TRTLLM FP8 MoE Backend." - ) - # TODO(rob): this validation should happen at kernel selection - # time in the oracle rather than here. - SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - assert layer.activation in SUPPORTED_ACTIVATIONS, ( - f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer " - f"TRTLLM FP4 MoE, {layer.activation} found instead." - ) - return apply_fi_trtllm_fp8_per_tensor_moe( - layer=layer, - hidden_states=x, - router_logits=router_logits, - routing_bias=layer.e_score_correction_bias, + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + activation=layer.activation, global_num_experts=layer.global_num_experts, - top_k=layer.top_k, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, - apply_router_weight_on_input=layer.apply_router_weight_on_input, + e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) def apply( @@ -964,25 +948,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - - # TODO(rob): this validation should happen at kernel selection - # time in the oracle rather than here. 
- if self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - assert layer.activation in ( - MoEActivation.SILU, - MoEActivation.RELU2_NO_MUL, - ), ( - "Expected activation to be in ('silu', 'relu2_no_mul')," - f"but got {layer.activation}" - ) - - assert self.moe_mk is not None - return self.moe_mk( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, @@ -1031,32 +1003,9 @@ class ModelOptNvFp4Config(ModelOptQuantConfigBase): def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: - """Detect if this ModelOpt FP4 config should be used based on - quantization config.""" - if hf_quant_cfg is None: - return None - - # Use the community standard 'quant_method' - quant_method = hf_quant_cfg.get("quant_method", "").lower() - - # Only proceed if the method is explicitly "modelopt" - if quant_method != "modelopt": - return None - - # Look for ModelOpt-specific config structure - if "quantization" in hf_quant_cfg: - quant_config = hf_quant_cfg["quantization"] - if isinstance(quant_config, dict): - quant_algo = quant_config.get("quant_algo", "") - if "NVFP4" in quant_algo: - return "modelopt_fp4" - else: - # Check for compressed-tensors style config with specific - # quant_algo field - quant_algo = hf_quant_cfg.get("quant_algo", "") - if isinstance(quant_algo, str) and "FP4" in quant_algo.upper(): - return "modelopt_fp4" - + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and ("NVFP4" in algo or "FP4" in algo): + return "modelopt_fp4" return None @classmethod @@ -1249,17 +1198,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, - ) -> mk.FusedMoEPrepareAndFinalize | None: - raise ValueError( - f"{self.__class__.__name__} uses the new modular kernel initialization " - "logic. This function should not be called." - ) - - def select_gemm_impl( - self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, - layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." @@ -1434,51 +1373,18 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) replace_parameter(layer, "w2_input_scale", a2_scale) - # Setup modular kernel for TP case and naive DP/EP case. - # In non-naive DP/EP case, we will create a ModularKernelMethod. - # TODO(rob): unify these so FP8MoEMethod owns the ModularKernel - # in both cases. + # Setup modular kernel. 
self.moe_quant_config = self.get_fused_moe_quant_config(layer) - if self.moe_quant_config: - assert self.experts_cls is not None - self.moe_mk = make_nvfp4_moe_kernel( - moe_quant_config=self.moe_quant_config, - moe_config=self.moe, - experts_cls=self.experts_cls, - shared_experts=layer.shared_experts, - routing_tables=layer._maybe_init_expert_routing_tables(), - ) - - @property - def do_post_quant_allgather(self): - return self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - - def prepare_dp_allgather_tensor( - self, - layer: FusedMoE, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - ) -> tuple[torch.Tensor, list[torch.Tensor]]: - """Optionally prepare extra tensors to carry through DP allgather/EP.""" - if self.nvfp4_backend != NvFp4MoeBackend.FLASHINFER_TRTLLM: - raise RuntimeError( - "prepare_dp_allgather_tensor is only supported for " - "FlashInfer TRTLLM NVFP4 MoE backend." - ) - - import flashinfer - - hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( - hidden_states, - layer.a1_gscale, - is_sf_swizzled_layout=False, + assert self.experts_cls is not None + self.moe_kernel = make_nvfp4_moe_kernel( + moe_quant_config=self.moe_quant_config, + moe_config=self.moe, + experts_cls=self.experts_cls, + shared_experts=layer.shared_experts, + routing_tables=layer._maybe_init_expert_routing_tables(), ) - extra_tensors: list[torch.Tensor] = [hidden_states_sf] - return hidden_states_fp4, extra_tensors - def get_fused_moe_quant_config( - self, layer: torch.nn.Module - ) -> FusedMoEQuantConfig | None: + def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: return make_nvfp4_moe_quant_config( backend=self.nvfp4_backend, w13_scale=layer.w13_weight_scale, @@ -1493,13 +1399,6 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): def supports_eplb(self) -> bool: return True - @property - def is_monolithic(self) -> bool: - return ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not self.moe.moe_parallel_config.enable_eplb - ) - def apply_monolithic( self, layer: FusedMoE, @@ -1507,22 +1406,20 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic - assert ( - self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM - and not layer.enable_eplb - ) - - return flashinfer_trtllm_fp4_moe( - layer=layer, - x=x, - router_logits=router_logits, - top_k=layer.top_k, + assert self.moe_kernel is not None + return self.moe_kernel.apply_monolithic( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, activation=layer.activation, global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, - custom_routing_function=layer.custom_routing_function, e_score_correction_bias=layer.e_score_correction_bias, + routed_scaling_factor=layer.routed_scaling_factor, ) def apply( @@ -1534,33 +1431,19 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic - - # EPLB path - if self.nvfp4_backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: - assert layer.enable_eplb - return flashinfer_trtllm_fp4_routed_moe( - layer=layer, - x=x, - topk_ids=topk_ids, - topk_weights=topk_weights, - top_k=layer.top_k, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - ) - 
else: - assert self.moe_mk is not None - return self.moe_mk( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - apply_router_weight_on_input=layer.apply_router_weight_on_input, - shared_experts_input=shared_experts_input, - ) + assert self.moe_kernel is not None + return self.moe_kernel.apply( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation=layer.activation, + global_num_experts=layer.global_num_experts, + expert_map=layer.expert_map, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + shared_experts_input=shared_experts_input, + ) ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod @@ -1619,31 +1502,9 @@ class ModelOptMxFp8Config(ModelOptQuantConfigBase): def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: - """Detect if this ModelOpt MXFP8 config should be used based on - quantization config.""" - if hf_quant_cfg is None: - return None - - # Use the community standard 'quant_method' - quant_method = hf_quant_cfg.get("quant_method", "").lower() - - # Only proceed if the method is explicitly "modelopt" - if quant_method != "modelopt": - return None - - # Look for ModelOpt-specific config structure - if "quantization" in hf_quant_cfg: - quant_config = hf_quant_cfg["quantization"] - if isinstance(quant_config, dict): - quant_algo = str(quant_config.get("quant_algo", "")).upper() - if "MXFP8" in quant_algo: - return "modelopt_mxfp8" - else: - # Check for compressed-tensors style config with specific quant_algo - quant_algo = str(hf_quant_cfg.get("quant_algo", "")).upper() - if "MXFP8" in quant_algo: - return "modelopt_mxfp8" - + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and "MXFP8" in algo: + return "modelopt_mxfp8" return None @classmethod @@ -1841,3 +1702,188 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): # Register the method classes for ModelOptMxFp8Config ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod + + +class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase): + """Config class for ModelOpt MIXED_PRECISION. + + Supports checkpoints where different layers use different quantization + algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts). + The per-layer algorithm is specified in the ``quantized_layers`` dict + inside ``config.json``'s ``quantization_config`` (preferred) or the + legacy ``hf_quant_config.json``. 
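+
+    Illustrative ``quantization_config`` snippet (layer names below are
+    examples only, not fixed keys)::
+
+        {
+            "quant_method": "modelopt",
+            "quant_algo": "MIXED_PRECISION",
+            "quantized_layers": {
+                "model.layers.0.self_attn.q_proj": {"quant_algo": "FP8"},
+                "model.layers.0.mlp.experts.0.gate_proj": {
+                    "quant_algo": "NVFP4",
+                    "group_size": 16
+                }
+            }
+        }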
+ """ + + def __init__( + self, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + quantized_layers: dict[str, dict[str, Any]], + fp8_config: ModelOptFp8Config, + nvfp4_config: ModelOptNvFp4Config, + ) -> None: + super().__init__(exclude_modules) + self.kv_cache_quant_method = kv_cache_quant_method + self.quantized_layers = quantized_layers + self.fp8_config = fp8_config + self.nvfp4_config = nvfp4_config + + def get_name(self) -> QuantizationMethods: + return "modelopt_mixed" + + def get_supported_act_dtypes(self) -> list[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def override_quantization_method( + cls, hf_quant_cfg, user_quant + ) -> QuantizationMethods | None: + algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) + if algo is not None and algo == "MIXED_PRECISION": + return "modelopt_mixed" + return None + + @classmethod + def _from_config( + cls, + *, + quant_method: str, + kv_cache_quant_method: str | None, + exclude_modules: list[str], + original_config: dict[str, Any], + group_size: int | None, + **kwargs: Any, + ) -> "ModelOptMixedPrecisionConfig": + if "quantization" in original_config: + quantized_layers = original_config["quantization"].get( + "quantized_layers", {} + ) + else: + quantized_layers = original_config.get("quantized_layers", {}) + + if not quantized_layers: + raise ValueError( + "MIXED_PRECISION quant_algo requires a non-empty " + "'quantized_layers' mapping in the quantization config." + ) + + # Determine group_size from the first NVFP4 entry if not provided. + if group_size is None: + for layer_info in quantized_layers.values(): + if layer_info.get("quant_algo", "").upper() == "NVFP4": + group_size = layer_info.get("group_size", 16) + break + if group_size is None: + group_size = 16 + + fp8_config = ModelOptFp8Config( + quant_method="FP8", + is_checkpoint_fp8_serialized=True, + kv_cache_quant_method=kv_cache_quant_method, + exclude_modules=[], + ) + nvfp4_config = ModelOptNvFp4Config( + is_checkpoint_nvfp4_serialized=True, + kv_cache_quant_algo=kv_cache_quant_method, + exclude_modules=[], + group_size=group_size, + ) + + return cls( + kv_cache_quant_method=kv_cache_quant_method, + exclude_modules=exclude_modules, + quantized_layers=quantized_layers, + fp8_config=fp8_config, + nvfp4_config=nvfp4_config, + ) + + def _resolve_quant_algo(self, prefix: str) -> str | None: + """Look up the quant_algo for a vLLM-side layer prefix. + + Tries three strategies in order: + 1. Direct lookup in ``quantized_layers``. + 2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``). + 3. Prefix-based lookup for FusedMoE (any child key starts with + ``prefix + "."``). + + Returns the upper-cased quant_algo string, or *None* if the prefix + is not found. + """ + # 1. Direct lookup + if prefix in self.quantized_layers: + return self.quantized_layers[prefix]["quant_algo"].upper() + + # 2. Packed / fused layer lookup + proj_name = prefix.rsplit(".", 1)[-1] + if self.packed_modules_mapping and proj_name in self.packed_modules_mapping: + algos: set[str] = set() + base = prefix.rsplit(".", 1)[0] + for shard_name in self.packed_modules_mapping[proj_name]: + shard_prefix = f"{base}.{shard_name}" + if shard_prefix in self.quantized_layers: + algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper()) + if len(algos) == 1: + return algos.pop() + if len(algos) > 1: + raise ValueError( + f"Mixed quant_algo within fused layer {prefix}: " + f"{algos}. 
All shards must use the same quantization." + ) + + # 3. Prefix-based lookup (for FusedMoE / parent modules) + prefix_dot = prefix + "." + for key, info in self.quantized_layers.items(): + if key.startswith(prefix_dot): + return info["quant_algo"].upper() + + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> "QuantizeMethodBase | None": + """Return quantize-method based on layer.""" + # KV-cache quantization + if isinstance(layer, Attention): + if self.kv_cache_quant_method: + return ModelOptFp8KVCacheMethod(self) + return None + + # Excluded layers + if self.is_layer_excluded(prefix): + if isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + return None + + quant_algo = self._resolve_quant_algo(prefix) + + if isinstance(layer, LinearBase): + if quant_algo == "FP8": + return ModelOptFp8LinearMethod(self.fp8_config) + if quant_algo == "NVFP4": + return ModelOptNvFp4LinearMethod(self.nvfp4_config) + # Layer not in quantized_layers — leave unquantized + return UnquantizedLinearMethod() + + if isinstance(layer, FusedMoE): + if quant_algo == "FP8": + return ModelOptFp8MoEMethod( + quant_config=self.fp8_config, + moe_config=layer.moe_config, + ) + if quant_algo == "NVFP4": + return ModelOptNvFp4FusedMoE( + quant_config=self.nvfp4_config, + moe_config=layer.moe_config, + ) + return None + + return None + + def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): + super().apply_vllm_mapper(hf_to_vllm_mapper) + if self.quantized_layers: + self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers) diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index d81f0f8..97d6017 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -6,6 +6,7 @@ import torch from torch.nn.parameter import Parameter from vllm import envs +from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -77,6 +78,8 @@ class Mxfp4Backend(Enum): # Triton Backend TRITON = 6 + CK = 7 + def get_mxfp4_backend_with_lora() -> Mxfp4Backend: """ @@ -167,9 +170,15 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend: elif current_platform.is_xpu(): logger.info_once("Using xpu backend on XPU") return Mxfp4Backend.MARLIN - elif current_platform.is_rocm() and has_triton_kernels(): - logger.info_once("Using Triton backend") - return Mxfp4Backend.TRITON + elif current_platform.is_rocm(): + from vllm.platforms.rocm import on_gfx950 + + if rocm_aiter_ops.is_enabled() and on_gfx950(): + logger.info_once("Using CK MXFP4 MoE backend (Aiter ROCm)") + return Mxfp4Backend.CK + elif has_triton_kernels(): + logger.info_once("Using Triton backend") + return Mxfp4Backend.TRITON return Mxfp4Backend.NONE @@ -257,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ) self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} # Initialized in process_weights_after_loading for CUTLASS/SM90 backends - self.moe_mk: mk.FusedMoEModularKernel | None = None + self.moe_kernel: mk.FusedMoEKernel | None = None def create_weights( self, @@ -338,6 +347,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.intermediate_size = intermediate_size_per_partition_after_pad self.hidden_size = hidden_size + self.hidden_pad = extra_weight_attrs.get("hidden_pad", 0) + self.intermediate_pad = ( + intermediate_size_per_partition_after_pad 
- intermediate_size_per_partition + ) # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.zeros( @@ -427,7 +440,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ) assert prepare_finalize is not None - self.moe_mk = mk.FusedMoEModularKernel( + self.moe_kernel = mk.FusedMoEKernel( prepare_finalize, MarlinExperts( self.moe, @@ -776,7 +789,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ) assert prepare_finalize is not None - self.moe_mk = mk.FusedMoEModularKernel( + self.moe_kernel = mk.FusedMoEKernel( prepare_finalize, FlashInferExperts( moe_config=self.moe, @@ -784,6 +797,66 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ), shared_experts=None, ) + elif self.mxfp4_backend == Mxfp4Backend.CK: + if layer.w13_bias is not None: + layer.w13_bias.data = layer.w13_bias.data.to(torch.float32) + if layer.w2_bias.data is not None: + layer.w2_bias.data = layer.w2_bias.data.to(torch.float32) + + e, n, k = layer.w13_weight.shape + layer.w13_weight.view(torch.uint8).copy_( + layer.w13_weight.data.view(torch.uint8) + .view(e, n // 2, 2, k) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, k) + ) + layer.w13_weight_scale.data = ( + layer.w13_weight_scale.data.view(e, n // 2, 2, -1) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, -1) + ) + layer.w13_weight.data = layer.w13_weight.data.view(torch.float4_e2m1fn_x2) + layer.w2_weight.data = layer.w2_weight.data.view(torch.float4_e2m1fn_x2) + + layer.w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4( + layer.w13_weight, 16, True + ) + shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4( + layer.w13_weight_scale.view(-1, layer.w13_weight_scale.shape[-1]), + self.num_experts, + True, + ) + + layer.w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4( + layer.w2_weight, 16, False + ) + shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4( + layer.w2_weight_scale.view(-1, layer.w2_weight_scale.shape[-1]), + self.num_experts, + False, + ) + + layer.w13_bias.data = ( + layer.w13_bias.data.view(-1, n // 2, 2) + .permute(0, 2, 1) + .contiguous() + .view(-1, n) + ) + + layer.w13_weight_scale = torch.nn.Parameter( + shuffled_w13_scale, requires_grad=False + ) + layer.w2_weight_scale = torch.nn.Parameter( + shuffled_w2_scale, requires_grad=False + ) + # replace_parameter(layer, "w13_bias", w13_bias) + # replace_parameter(layer, "w13_weight_scale", w13_weight_scale) + # replace_parameter(layer, "w2_weight_scale", w2_weight_scale) + # replace_parameter(layer, "w13_weight", w13_weight) + # replace_parameter(layer, "w2_weight", w2_weight) + elif self.mxfp4_backend == Mxfp4Backend.TRITON: from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig @@ -792,18 +865,16 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer.w13_bias = Parameter(w13_bias, requires_grad=False) layer.w2_bias = Parameter(w2_bias, requires_grad=False) - # Ideally we'd use FusedMoEModularKernel.prepare_finalize object # (stored in self.fused_experts) to determine if the MoE has a # batched activation format. As self.fused_experts is not # initialized at this point, we resort to checking the MoE config # directly. 
- is_batched_moe = self.moe.use_pplx_kernels or self.moe.use_deepep_ll_kernels + is_batched_moe = self.moe.use_deepep_ll_kernels if is_batched_moe: num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 else: num_warps = 8 - w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( layer.w13_weight, layer.w13_weight_scale, num_warps ) @@ -817,13 +888,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self.w2_precision_config = PrecisionConfig( weight_scale=w2_scale, flex_ctx=FlexCtx(rhs_data=w2_flex) ) - self.w13_weight = w13_weight self.w2_weight = w2_weight del layer.w13_weight del layer.w2_weight layer.w13_weight = w13_weight layer.w2_weight = w2_weight + else: raise ValueError( f"Unsupported mxfp4_backend: {self.mxfp4_backend}: " @@ -862,6 +933,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): elif self.mxfp4_backend in [ Mxfp4Backend.SM100_FI_MXFP4_BF16, Mxfp4Backend.SM90_FI_MXFP4_BF16, + Mxfp4Backend.CK, ]: return mxfp4_w4a16_moe_quant_config( w1_bias=layer.w13_bias, @@ -882,9 +954,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): def select_gemm_impl( self, - prepare_finalize: mk.FusedMoEPrepareAndFinalize, + prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, - ) -> mk.FusedMoEPermuteExpertsUnpermute: + ) -> mk.FusedMoEExpertsModular: if ( prepare_finalize.activation_format == mk.FusedMoEActivationFormat.BatchedExperts @@ -929,10 +1001,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): @property def is_monolithic(self) -> bool: + if self.moe.is_lora_enabled: + return False return ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.TRITON + or self.mxfp4_backend == Mxfp4Backend.CK ) def apply( @@ -968,8 +1043,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): or self.mxfp4_backend == Mxfp4Backend.MARLIN ) - assert self.moe_mk is not None - return self.moe_mk( + assert self.moe_kernel is not None + return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1054,6 +1129,27 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): tune_max_num_tokens=max(self.max_capture_size, 1), )[0] return trtllm_gen_output + elif self.mxfp4_backend == Mxfp4Backend.CK: + topk_weights, topk_ids = rocm_aiter_ops.fused_topk( + x, router_logits, layer.top_k, True + ) + output = rocm_aiter_ops.fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights, + topk_ids, + activation_method=rocm_aiter_ops.get_aiter_activation_type("swiglu"), + quant_method=rocm_aiter_ops.get_aiter_quant_type("per_1x32"), + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + doweight_stage1=False, + hidden_pad=self.hidden_pad // 128 * 128, + intermediate_pad=self.intermediate_pad // 64 * 64 * 2, + bias1=layer.w13_bias, + bias2=layer.w2_bias, + ) + return output elif self.mxfp4_backend == Mxfp4Backend.TRITON: from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 triton_kernel_moe_forward, @@ -1162,7 +1258,7 @@ class XpuMxfp4MoEMethod(Mxfp4MoEMethod): topk_weights=routing_weights, topk_ids=selected_experts, n_experts_per_token=layer.top_k, - activation=layer.activation, + activation=layer.activation.value, num_experts=layer.local_num_experts, is_mxfp4=True, ) diff --git a/vllm/model_executor/layers/quantization/ptpc_fp8.py b/vllm/model_executor/layers/quantization/ptpc_fp8.py index 76410f2..5d7b7b5 100644 --- a/vllm/model_executor/layers/quantization/ptpc_fp8.py +++ b/vllm/model_executor/layers/quantization/ptpc_fp8.py @@ -7,7 +7,6 
@@ import torch from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.logger import init_logger from vllm.model_executor.kernels.linear import ( init_fp8_linear_kernel, ) @@ -26,10 +25,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ) from vllm.platforms import current_platform -ACTIVATION_SCHEMES = ["static", "dynamic"] - -logger = init_logger(__name__) - class PTPCFp8Config(Fp8Config): """Config class for Per-Token-Per-Channel Dynamic Quantization Fp8.""" diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 36f20c8..dedc7db 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -35,6 +35,7 @@ from vllm.model_executor.layers.quantization.quark.utils import ( ) from vllm.model_executor.models.utils import WeightsMapper from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -59,6 +60,22 @@ class QuarkConfig(QuantizationConfig): self.kv_cache_group = kv_cache_group self.kv_cache_config = kv_cache_config self.pack_method = pack_method + self.dynamic_mxfp4_quant = False + + def maybe_update_config(self, model_name: str, revision: str | None = None): + self.hf_config = get_config( + model=model_name, + trust_remote_code=False, # or get from model_config if available + revision=revision, + config_format="auto", + ) + + quant_config = getattr(self.hf_config, "quantization_config", None) + if quant_config is not None: + quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"] + model_type = self.hf_config.model_type + if quant_dtype == "fp4" and model_type == "deepseek_v3": + self.dynamic_mxfp4_quant = True def get_linear_method(self) -> "QuarkLinearMethod": return QuarkLinearMethod(self) @@ -108,7 +125,20 @@ class QuarkConfig(QuantizationConfig): if should_ignore_layer( prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping ): - return UnquantizedLinearMethod() + if ( + "self_attn" not in prefix # only quantize attention projections + or not getattr(self, "dynamic_mxfp4_quant", False) + or not isinstance(layer, LinearBase) # Ignore other methods + ): + return UnquantizedLinearMethod() + + scheme = self.get_scheme( + layer=layer, + layer_name=prefix, + dynamic_mxfp4_quant=True, + ) + layer.scheme = scheme + return QuarkLinearMethod(self) if isinstance(layer, LinearBase): scheme = self.get_scheme(layer=layer, layer_name=prefix) layer.scheme = scheme @@ -450,7 +480,9 @@ class QuarkConfig(QuantizationConfig): ) return global_quant_config - def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme": + def _get_scheme_from_config( + self, config: dict[str, Any], dynamic_mxfp4_quant: bool = False + ) -> "QuarkScheme": if config.get("output_tensors") or config.get("bias"): raise NotImplementedError( "Currently, Quark models with output_tensors " @@ -473,7 +505,9 @@ class QuarkConfig(QuantizationConfig): input_symmetric=input_config.get("symmetric"), ) elif self._is_w_ocp_mx_a_x(weight_config, input_config): - return QuarkOCP_MX(weight_config, input_config) + return QuarkOCP_MX( + weight_config, input_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant + ) raise NotImplementedError( "No quark compatible scheme was found. 
" @@ -481,11 +515,15 @@ class QuarkConfig(QuantizationConfig): f"Input config: {input_config}" ) - def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme": + def get_scheme( + self, layer: torch.nn.Module, layer_name: str, dynamic_mxfp4_quant: bool = False + ) -> "QuarkScheme": layer_quant_config = self._find_matched_config(layer_name, layer) # Find the quant_scheme - scheme = self._get_scheme_from_config(layer_quant_config) + scheme = self._get_scheme_from_config( + layer_quant_config, dynamic_mxfp4_quant=dynamic_mxfp4_quant + ) # Raise error if device does not support the scheme # (e.g. fp8 needs ada lovelace) self._check_scheme_supported(scheme.get_min_capability()) diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 8394857..b2abbce 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -5,8 +5,8 @@ from typing import Any import torch -import vllm.envs as envs from vllm import _custom_ops as ops +from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.config import get_current_vllm_config from vllm.logger import init_logger @@ -32,6 +32,7 @@ from vllm.model_executor.layers.quantization.mxfp4 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_fp8_moe_layer_for_marlin, ) +from vllm.model_executor.layers.quantization.utils.mxfp4_utils import _swizzle_mxfp4 from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, OCP_MX_Scheme, @@ -49,7 +50,11 @@ from vllm.utils.math_utils import round_up logger = init_logger(__name__) -__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"] +__all__ = [ + "QuarkMoEMethod", + "QuarkOCP_MX_MoEMethod", + "QuarkOCP_MX_MoEMethod_OSS", +] class QuarkMoEMethod(FusedMoEMethodBase): @@ -71,14 +76,30 @@ class QuarkMoEMethod(FusedMoEMethodBase): "output_tensors and bias " "quantized are not supported" ) + weight_config = layer_quant_config.get("weight") input_config = layer_quant_config.get("input_tensors") + if quant_config._is_fp8_w4a8(weight_config, input_config): return QuarkW4A8Fp8MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_w_ocp_mx_a_x(weight_config, input_config): - return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config) + emulate = not current_platform.supports_mx() or not ( + rocm_aiter_ops.is_fused_moe_enabled() + ) + if ( + input_config.get("dtype") == "fp8_e4m3" + and not input_config.get("is_dynamic") + and not emulate + ): + return QuarkOCP_MX_MoEMethod_OSS( + weight_config, input_config, module.moe_config + ) + else: + return QuarkOCP_MX_MoEMethod( + weight_config, input_config, module.moe_config + ) else: raise RuntimeError("Unsupported FusedMoe scheme") @@ -706,13 +727,11 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): get_current_vllm_config().model_config.hf_config, "model_type", None ) - self._emulate = ( + self.emulate = ( not current_platform.supports_mx() or not self.ocp_mx_scheme.startswith("w_mxfp4") ) and (self.mxfp4_backend is None or not self.use_rocm_aiter_moe) - self.emulate = True if self.model_type == "gpt_oss" else self._emulate - if self.emulate: logger.warning_once( f"The current mode (supports_mx={current_platform.supports_mx()}, 
" @@ -753,6 +772,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ) params_dtype = torch.uint8 + self.intermediate_size_per_partition = intermediate_size_per_partition if self.model_type == "gpt_oss": if current_platform.is_rocm(): intermediate_size_per_partition_after_pad = round_up( @@ -765,6 +785,10 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): else: intermediate_size_per_partition_after_pad = intermediate_size_per_partition + self.unpadded_hidden_size = extra_weight_attrs.get( + "unpadded_hidden_size", hidden_size + ) + # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( @@ -991,30 +1015,20 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if not self.emulate: - if ( - self.model_type == "gpt_oss" - and self.mxfp4_backend == Mxfp4Backend.TRITON - ): - raise NotImplementedError( - "Triton kernel implemented fused MoE for GPT_OSS model " - "in Quark(MoE) format is not integrated or provided yet." - ) + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) - else: - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_fused_experts, - ) - - return rocm_aiter_fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=layer.activation, - quant_config=self.moe_quant_config, - expert_map=layer.expert_map, - ) + return rocm_aiter_fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=layer.activation, + quant_config=self.moe_quant_config, + expert_map=layer.expert_map, + ) else: from vllm.model_executor.layers.fused_moe import fused_experts @@ -1031,3 +1045,133 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): expert_map=layer.expert_map, quant_config=self.moe_quant_config, ) + + +class QuarkOCP_MX_MoEMethod_OSS(QuarkOCP_MX_MoEMethod): + def __init__( + self, + weight_config: dict[str, Any], + input_config: dict[str, Any], + moe: FusedMoEConfig, + ): + super().__init__(weight_config, input_config, moe) + + def process_weights_after_loading(self, layer): + from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig + + w13_bias = layer.w13_bias.to(torch.float32) + w2_bias = layer.w2_bias.to(torch.float32) + + layer.w13_bias = torch.nn.Parameter(w13_bias, requires_grad=False) + layer.w2_bias = torch.nn.Parameter(w2_bias, requires_grad=False) + + # FIXME warp need to be adjusted based on batch size + # only apply to batched mode + if self.moe.use_ep: + num_warps = 4 if envs.VLLM_MOE_DP_CHUNK_SIZE <= 512 else 8 + else: + num_warps = 8 + + w13_weight, w13_flex, w13_scale = _swizzle_mxfp4( + layer.w13_weight, layer.w13_weight_scale, num_warps + ) + w2_weight, w2_flex, w2_scale = _swizzle_mxfp4( + layer.w2_weight, layer.w2_weight_scale, num_warps + ) + + self.w13_weight_triton_tensor = w13_weight + self.w2_weight_triton_tensor = w2_weight + + # need to delete the original weights to save memory on single GPU + del layer.w13_weight + del layer.w2_weight + layer.w13_weight = None + layer.w2_weight = None + torch.cuda.empty_cache() + + if self.static_input_scales: + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." 
+ ) + if not all_close_1d(layer.w13_input_scale) or not all_close_1d( + layer.w2_input_scale + ): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer." + ) + + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max().to(torch.float32), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max().to(torch.float32), requires_grad=False + ) + + from triton_kernels.numerics import InFlexData + + lhs_data13 = InFlexData(scale=layer.w13_input_scale) + lhs_data2 = InFlexData(scale=layer.w2_input_scale) + + self.w13_precision_config = PrecisionConfig( + weight_scale=w13_scale, + flex_ctx=FlexCtx(rhs_data=w13_flex, lhs_data=lhs_data13), + ) + + self.w2_precision_config = PrecisionConfig( + weight_scale=w2_scale, + flex_ctx=FlexCtx(rhs_data=w2_flex, lhs_data=lhs_data2), + ) + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return mxfp4_w4a8_moe_quant_config( + w1_scale=self.w13_precision_config, + w2_scale=self.w2_precision_config, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + w1_bias=layer.w13_bias, + w2_bias=layer.w2_bias, + block_shape=None, + ) + + @property + def is_monolithic(self) -> bool: + return True + + def apply_monolithic( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + expert_map: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if layer.enable_eplb: + raise NotImplementedError( + "EPLB not supported for `QuarkW4MXFp4MoEMethod_OSS` yet." + ) + + from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( # noqa: E501 + triton_kernel_moe_forward, + ) + + return triton_kernel_moe_forward( + hidden_states=x, + w1=self.w13_weight_triton_tensor, + w2=self.w2_weight_triton_tensor, + gating_output=router_logits, + topk=layer.top_k, + renormalize=layer.renormalize, + global_num_experts=layer.global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + apply_router_weight_on_input=layer.apply_router_weight_on_input, + unpadded_N_w1=self.intermediate_size_per_partition * 2, + unpadded_K_w1=self.unpadded_hidden_size, + unpadded_N_w2=self.unpadded_hidden_size, + unpadded_K_w2=self.intermediate_size_per_partition, + ) diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py index c5f5012..6917bb6 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py @@ -24,7 +24,12 @@ from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import ( OCP_MX_BLOCK_SIZE, OCP_MX_Scheme, ) -from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter +from vllm.model_executor.parameter import ( + GroupQuantScaleParameter, + ModelWeightParameter, + PackedvLLMParameter, +) +from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from .quark_scheme import QuarkScheme @@ -169,13 +174,16 @@ except (ImportError, AttributeError, RuntimeError): class QuarkOCP_MX(QuarkScheme): def __init__( - self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any] + self, + weight_quant_spec: dict[str, Any], + input_quant_spec: dict[str, Any], + dynamic_mxfp4_quant: bool = False, ): self.out_dtype = 
torch.get_default_dtype() self.qscheme = "per_group" self.weight_quant_spec = weight_quant_spec self.input_quant_spec = input_quant_spec - + self.dynamic_mxfp4_quant = dynamic_mxfp4_quant self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp") self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp") @@ -269,7 +277,13 @@ class QuarkOCP_MX(QuarkScheme): layer.weight_scale.data, requires_grad=False ) else: - if self.rocm_use_aiter_fp4_asm_gemm: + if self.dynamic_mxfp4_quant: + w_q, w_s = dynamic_mxfp4_quant(layer.weight) + layer.weight_scale = torch.nn.Parameter( + w_s.T.contiguous(), requires_grad=False + ) + layer.weight = torch.nn.Parameter(w_q, requires_grad=False) + elif self.rocm_use_aiter_fp4_asm_gemm: # shuffle weight scale weight_scale_shuffle = layer.weight_scale.data sm, sn = weight_scale_shuffle.shape @@ -302,36 +316,51 @@ class QuarkOCP_MX(QuarkScheme): weight_loader: Callable, **kwargs, ): - output_size_per_partition = sum(output_partition_sizes) - layer.logical_widths = output_partition_sizes + if self.dynamic_mxfp4_quant: + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) - # WEIGHT - weight = PackedvLLMParameter( - data=torch.empty( - output_size_per_partition, - self.get_packed_dim(input_size_per_partition, self.weight_dtype), - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - packed_dim=1, - packed_factor=self.packed_factor, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, kwargs) + else: + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes - # WEIGHT SCALE - weight_scale = GroupQuantScaleParameter( - data=torch.empty( - output_size_per_partition, - input_size_per_partition // OCP_MX_BLOCK_SIZE, - dtype=torch.uint8, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight_scale", weight_scale) + # WEIGHT + weight = PackedvLLMParameter( + data=torch.empty( + output_size_per_partition, + self.get_packed_dim(input_size_per_partition, self.weight_dtype), + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + packed_dim=1, + packed_factor=self.packed_factor, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = GroupQuantScaleParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition // OCP_MX_BLOCK_SIZE, + dtype=torch.uint8, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) def apply_weights( self, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index fadf56b..42677a5 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -6,28 +6,18 @@ from typing import TYPE_CHECKING import torch -import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm import _custom_ops as ops +import vllm.envs as envs from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe.activation import MoEActivation -from vllm.model_executor.layers.fused_moe.config import ( - FusedMoEConfig, - FusedMoEParallelConfig, - 
RoutingMethodType, -) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( - activation_to_flashinfer_int, align_fp4_moe_weights_for_fi, ) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( swizzle_blockscale, ) -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - QuantKey, - kNvfp4Dynamic, - kNvfp4Static, -) from vllm.platforms import current_platform +from vllm.utils.flashinfer import ( + has_flashinfer_cutlass_fused_moe, +) if TYPE_CHECKING: from vllm.model_executor.layers.fused_moe.layer import FusedMoE @@ -42,92 +32,15 @@ __all__ = [ "reorder_w1w3_to_w3w1", ] -# -# Methods used by the oracle for kernel selection. -# - -def _supports_current_device() -> bool: - """Supports only Blackwell-family GPUs.""" - p = current_platform - return p.is_cuda() and p.is_device_capability_family(100) - - -def _supports_no_act_and_mul() -> bool: - """Supports non-gated MoE.""" - return True - - -def _supports_quant_scheme( - weight_key: QuantKey | None, - activation_key: QuantKey | None, -) -> bool: - """Supports Nvfp4 quantization.""" - SUPPORTED_W_A = [ - (kNvfp4Static, kNvfp4Dynamic), - ] - return (weight_key, activation_key) in SUPPORTED_W_A - - -def _supports_activation(activation: MoEActivation) -> bool: - return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - - -def _supports_routing_method( - routing_method: RoutingMethodType, -) -> bool: - """Monolithic kernels need to express router support.""" - # NOTE(rob): potentially allow others here. This is a conservative list. - return routing_method in [ - RoutingMethodType.DeepSeekV3, - RoutingMethodType.Renormalize, - RoutingMethodType.RenormalizeNaive, - RoutingMethodType.Llama4, - ] - - -def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: - """ - TRTLLM is a monolithic kernel that requires dispatch_router_logits() for - the naive dispatch/combine path. DeepEP HT only implements dispatch() for - the modular kernel path, so TRTLLM is incompatible with DeepEP HT. 
- """ - return not moe_parallel_config.use_deepep_ht_kernels - - -def is_supported_config_trtllm( - moe_config: FusedMoEConfig, - weight_key: QuantKey | None, - activation_key: QuantKey | None, - activation_format: mk.FusedMoEActivationFormat, -) -> tuple[bool, str | None]: - """ - This method mirrors mk.FusedMoEPermuteExpertsUnpermute.is_supported_config - """ - - def _make_reason(reason: str) -> str: - return f"kernel does not support {reason}" - - if not _supports_current_device(): - return False, _make_reason(f"current device {current_platform.device_name}") - elif not (moe_config.is_act_and_mul or _supports_no_act_and_mul()): - return False, _make_reason("no act_and_mul MLP layer") - elif not _supports_activation(moe_config.activation): - return False, _make_reason(f"{moe_config.activation} activation") - elif not _supports_quant_scheme(weight_key, activation_key): - return False, _make_reason(f"quantization scheme {weight_key}x{activation_key}") - elif not _supports_parallel_config(moe_config.moe_parallel_config): - return False, _make_reason(f"parallel config {moe_config.moe_parallel_config}") - elif not _supports_routing_method(moe_config.routing_method): - return False, _make_reason(f"routing method {moe_config.routing_method}") - elif activation_format != mk.FusedMoEActivationFormat.Standard: - return False, _make_reason(f"activation format {activation_format}") - elif moe_config.hidden_dim % 512 != 0: - return False, _make_reason( - f"hidden_dim must be divisible by 512, found {moe_config.hidden_dim}" - ) - - return True, None +def is_flashinfer_fp4_cutlass_moe_available() -> bool: + """Return `True` when FlashInfer CUTLASS NV-FP4 kernels can be used.""" + return ( + envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutlass_fused_moe() + and current_platform.is_cuda() + and current_platform.has_device_capability(100) + ) def reorder_w1w3_to_w3w1( @@ -276,190 +189,6 @@ def prepare_static_weights_for_trtllm_fp4_moe( ) -def flashinfer_trtllm_fp4_moe( - layer: torch.nn.Module, - x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - router_logits: torch.Tensor, - top_k: int, - activation: MoEActivation, - global_num_experts: int, - num_expert_group: int | None, - topk_group: int | None, - custom_routing_function: object | None, - e_score_correction_bias: torch.Tensor | None, -) -> torch.Tensor: - """ - Apply FlashInfer TensorRT-LLM FP4 MoE kernel. - - Args: - layer: The MoE layer with weights and scales - x: Input tensor - router_logits: Router logits for expert selection - top_k: Number of experts to select per token - activation: Activation function to use - global_num_experts: Total number of experts across all ranks - num_expert_group: Number of expert groups (for grouped routing) - topk_group: Top-k within each group - custom_routing_function: Custom routing function (e.g., Llama4) - e_score_correction_bias: Optional routing bias correction - - Returns: - Output tensor from the MoE layer - """ - import flashinfer - - from vllm.model_executor.models.llama4 import Llama4MoE - - SUPPORTED_ACTIVATIONS = [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] - assert activation in SUPPORTED_ACTIVATIONS, ( - f"Only {SUPPORTED_ACTIVATIONS} activations are supported for FlashInfer " - f"TRTLLM FP4 MoE, {activation} found instead." 
- ) - - # Quantize input to FP4 - if isinstance(x, tuple): - hidden_states_fp4, hidden_states_scale_linear_fp4 = x - else: - # hidden_states is the already quantized - (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant( - x, layer.a1_gscale, is_sf_swizzled_layout=False - ) - - # Determine routing method type - use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function - routing_method_type = layer.routing_method_type - if use_llama4_routing: - routing_method_type = flashinfer.RoutingMethodType.Llama4 - - # Cast to Fp32 (required by kernel). - router_logits = ( - router_logits.to(torch.float32) - if routing_method_type == RoutingMethodType.DeepSeekV3 - else router_logits - ) - - # Determine activation type - activation_type = activation_to_flashinfer_int(layer.activation) - - # Call TRT-LLM FP4 block-scale MoE kernel - out = flashinfer.fused_moe.trtllm_fp4_block_scale_moe( - routing_logits=router_logits, - routing_bias=e_score_correction_bias, - hidden_states=hidden_states_fp4, - hidden_states_scale=hidden_states_scale_linear_fp4.view( - torch.float8_e4m3fn - ).reshape(*hidden_states_fp4.shape[:-1], -1), - gemm1_weights=layer.w13_weight.data, - gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=layer.w2_weight.data, - gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), - gemm2_bias=None, - output1_scale_scalar=layer.g1_scale_c.data, - output1_scale_gate_scalar=layer.g1_alphas.data, - output2_scale_scalar=layer.g2_alphas.data, - num_experts=global_num_experts, - top_k=top_k, - n_group=num_expert_group if num_expert_group is not None else 0, - topk_group=topk_group if topk_group is not None else 0, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - routed_scaling_factor=None, - routing_method_type=routing_method_type, - do_finalize=True, - activation_type=activation_type, - )[0] - - return out - - -def flashinfer_trtllm_fp4_routed_moe( - layer: torch.nn.Module, - x: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - top_k: int, - activation: MoEActivation, - global_num_experts: int, -) -> torch.Tensor: - """ - Apply FlashInfer TensorRT-LLM FP4 MoE kernel. Uses packed - input top k expert indices and scores rather than computing - top k expert indices from scores. - - Args: - layer: The MoE layer with weights and scales - x: Input tensor - topk_ids: Ids of selected experts - top_k: Number of experts to select per token - activation: Activation function to use - global_num_experts: Total number of experts across all ranks - - Returns: - Output tensor from the MoE layer - """ - import flashinfer - - # https://github.com/flashinfer-ai/flashinfer/blob/f0277fd1bff90e309e5c19cab36c5dae056d685d/flashinfer/fused_moe/core.py#L2535 - assert activation == MoEActivation.SILU, ( - "Only SiLU activation is supported for FlashInfer TRTLLM FP4 Routed MoE. " - f"{activation} found instead." 
- ) - - # Pack top k ids and expert weights into a single int32 tensor, as - # required by TRT-LLM - packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to( - torch.bfloat16 - ).view(torch.int16) - - if isinstance(x, tuple): - # Hidden_states is the already quantized - hidden_states_fp4, hidden_states_scale_linear_fp4 = x - else: - # Quantize input to FP4 - (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant( - x, layer.a1_gscale, is_sf_swizzled_layout=False - ) - - # Call TRT-LLM FP4 block-scale MoE kernel - out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( - topk_ids=packed_tensor, - routing_bias=None, - hidden_states=hidden_states_fp4, - hidden_states_scale=hidden_states_scale_linear_fp4.view( - torch.float8_e4m3fn - ).reshape(*hidden_states_fp4.shape[:-1], -1), - gemm1_weights=layer.w13_weight.data, - gemm1_weights_scale=layer.w13_weight_scale.data.view(torch.float8_e4m3fn), - gemm1_bias=None, - gemm1_alpha=None, - gemm1_beta=None, - gemm1_clamp_limit=None, - gemm2_weights=layer.w2_weight.data, - gemm2_weights_scale=layer.w2_weight_scale.data.view(torch.float8_e4m3fn), - gemm2_bias=None, - output1_scale_scalar=layer.g1_scale_c.data, - output1_scale_gate_scalar=layer.g1_alphas.data, - output2_scale_scalar=layer.g2_alphas.data, - num_experts=global_num_experts, - top_k=top_k, - n_group=0, - topk_group=0, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - routed_scaling_factor=None, - routing_method_type=1, - do_finalize=True, - )[0] - - return out - - def prepare_nvfp4_moe_layer_for_fi_or_cutlass( backend: "NvFp4MoeBackend", layer: "FusedMoE", @@ -526,6 +255,7 @@ def prepare_nvfp4_moe_layer_for_fi_or_cutlass( ) ) layer.intermediate_size_per_partition = padded_intermediate + layer.moe_config.intermediate_size_per_partition = padded_intermediate w13, w13_scale, w2, w2_scale = prepare_static_weights_for_trtllm_fp4_moe( w13, diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 3d7d8e6..a8be1d6 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from enum import Enum +from typing import TYPE_CHECKING import torch @@ -10,6 +11,9 @@ from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.platforms import current_platform from vllm.utils.math_utils import round_up +if TYPE_CHECKING: + from flashinfer.fused_moe.core import ActivationType + logger = init_logger(__name__) @@ -20,6 +24,10 @@ class FlashinferMoeBackend(Enum): def activation_to_flashinfer_int(activation: MoEActivation) -> int: + return activation_to_flashinfer_type(activation).value + + +def activation_to_flashinfer_type(activation: MoEActivation) -> "ActivationType": from flashinfer.fused_moe.core import ActivationType # silu and gelu are mapped to their gated versions SwiGLU and GeGLU respectively @@ -30,7 +38,7 @@ def activation_to_flashinfer_int(activation: MoEActivation) -> int: MoEActivation.GELU: ActivationType.Geglu, MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, } - return ACTIVATION_TO_FI_ACTIVATION[activation].value + return ACTIVATION_TO_FI_ACTIVATION[activation] def swap_w13_to_w31(x: torch.Tensor) -> 
torch.Tensor: @@ -87,104 +95,6 @@ def rotate_weights_for_fi_trtllm_fp8_per_tensor_moe( ) -def register_scales_for_trtllm_fp8_per_tensor_moe( - layer: torch.nn.Module, - w13_scale: torch.Tensor, - w13_input_scale: torch.Tensor, - w2_scale: torch.Tensor, - w2_input_scale: torch.Tensor, -) -> None: - """Register necessary scales for FlashInfer TRTLLM FP8 MoE kernel""" - g1_alphas, g2_alphas = make_fp8_moe_alpha_scales_for_fi( - w13_scale=w13_scale, - w13_input_scale=w13_input_scale, - w2_scale=w2_scale, - w2_input_scale=w2_input_scale, - ) - layer.w2_input_scale_inv = 1.0 / w2_input_scale - layer.output1_scales_gate_scalar = g1_alphas - - if layer.activation.is_gated: - layer.output1_scales_scalar = g1_alphas * layer.w2_input_scale_inv - else: - layer.output1_scales_scalar = ( - torch.ones_like(g1_alphas) * layer.w2_input_scale_inv - ) - layer.output2_scales_scalar = g2_alphas - - -def apply_fi_trtllm_fp8_per_tensor_moe( - layer: torch.nn.Module, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - routing_bias: torch.Tensor | None, - top_k: int, - num_expert_group: int | None, - topk_group: int | None, - global_num_experts: int, - apply_router_weight_on_input: bool, -) -> torch.Tensor: - from flashinfer.fused_moe import RoutingMethodType - - import vllm.model_executor.layers.fused_moe.flashinfer_trtllm_moe # noqa: E501, F401 - from vllm.model_executor.models.llama4 import Llama4MoE - - # Added to the layer by: register_scales_for_trtllm_fp8_per_tensor_moe - assert ( - hasattr(layer, "output1_scales_scalar") - and hasattr(layer, "output1_scales_gate_scalar") - and hasattr(layer, "output2_scales_scalar") - ) - - if layer.routing_method_type == RoutingMethodType.Llama4: - assert ( - not layer.renormalize - and layer.custom_routing_function == Llama4MoE.custom_routing_function - ), ( - "FusedMoE flashinfer kernels with Llama4 routing method are only " - "supported for Llama4" - ) - else: - assert layer.custom_routing_function is None, ( - "Custom routing function is only supported for Llama4" - ) - activation_type = activation_to_flashinfer_int(layer.activation) - - return torch.ops.vllm.fi_trtllm_fp8_per_tensor_moe( - routing_logits=router_logits, - routing_bias=routing_bias, - hidden_states=hidden_states, - input_scale=layer.w13_input_scale, - gemm1_weights=layer.w13_weight, - gemm2_weights=layer.w2_weight, - output1_scales_scalar=layer.output1_scales_scalar, - output1_scales_gate_scalar=layer.output1_scales_gate_scalar, - output2_scales_scalar=layer.output2_scales_scalar, - num_experts=global_num_experts, - top_k=top_k, - num_expert_group=num_expert_group, - topk_group=topk_group, - intermediate_size=layer.intermediate_size_per_partition, - local_expert_offset=layer.ep_rank * layer.local_num_experts, - local_num_experts=layer.local_num_experts, - use_routing_scales_on_input=apply_router_weight_on_input, - routing_method_type=layer.routing_method_type, - activation_type=activation_type, - ) - - -def make_fp8_moe_alpha_scales_for_fi( - w13_scale: torch.Tensor, - w13_input_scale: torch.Tensor, - w2_scale: torch.Tensor, - w2_input_scale: torch.Tensor, -) -> tuple[torch.Tensor, torch.Tensor]: - g1_alphas = (w13_scale * w13_input_scale).squeeze() - g2_alphas = (w2_scale * w2_input_scale).squeeze() - - return g1_alphas, g2_alphas - - def get_flashinfer_moe_backend() -> FlashinferMoeBackend: backend_map = { "throughput": FlashinferMoeBackend.CUTLASS, @@ -432,6 +342,7 @@ def prepare_fp8_moe_layer_for_fi( min_alignment, ) layer.intermediate_size_per_partition = new_intermediate + 
layer.moe_config.intermediate_size_per_partition = new_intermediate # FI kernels require W31 layout rather than W13. if layer.moe_config.is_act_and_mul: @@ -440,20 +351,12 @@ def prepare_fp8_moe_layer_for_fi( w13_scale = swap_w13_to_w31(w13_scale) # FI TRT-LLM FP8 per-tensor MoE kernel requires weight shuffle - # and registration of alpha scales. Note that we do not register - # as nn.Parameters since they are not needed for weight-reloading. + # and registration of alpha scales. if is_trtllm and not block_quant: assert w13_input_scale is not None assert w2_input_scale is not None rotate_weights_for_fi_trtllm_fp8_per_tensor_moe(w13, w2, is_gated) - register_scales_for_trtllm_fp8_per_tensor_moe( - layer, - w13_scale=w13_scale, - w13_input_scale=w13_input_scale, - w2_scale=w2_scale, - w2_input_scale=w2_input_scale, - ) # Clamp block scales to avoid NaN from the FlashInfer CUTLASS kernel. # Some FP8 models have near-zero block scales (~1e-23) for dead/unused diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index ee3f2ce..ccd5cce 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -53,7 +53,10 @@ logger = init_logger(__name__) def is_fp8(x: torch.dtype | torch.Tensor) -> bool: if isinstance(x, torch.Tensor): x = x.dtype - return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz + try: + return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz + except: + return False # We need to pass in the is_hopper flag as argument because the function diff --git a/vllm/model_executor/layers/quantization/utils/gguf_utils.py b/vllm/model_executor/layers/quantization/utils/gguf_utils.py new file mode 100644 index 0000000..79b34e2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/gguf_utils.py @@ -0,0 +1,373 @@ +import torch +import numpy as np +from gguf.constants import GGMLQuantizationType + +def get_awq_format(w, group_size=128, w_bit=4): + org_w_shape = w.shape + ori_w_dtype = torch.get_default_dtype() + assert w_bit == 4 + assert w.shape[1] % group_size == 0 + + in_features = org_w_shape[1] + w = w.reshape(-1, group_size) + assert torch.isnan(w).sum() == 0 + + max_val = w.amax(dim=1, keepdim=True) + min_val = w.amin(dim=1, keepdim=True) + max_int = 2**w_bit - 1 + min_int = 0 + scales = (max_val - min_val).clamp(min=1e-5) / max_int + zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int) + w = ( + torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros + ) * scales + zeros = zeros.view(org_w_shape[0], -1) + scales = scales.view(org_w_shape[0], -1) + w = w.reshape(org_w_shape) + assert torch.isnan(scales).sum() == 0 + assert torch.isnan(w).sum() == 0 + + scales = scales.t().contiguous() # input // group, o + zeros = zeros.t().contiguous() # input // group, o + + # from auto awq + scale_zeros = zeros * scales + scales = scales.clone().to(ori_w_dtype) + + pack_num = 32 // w_bit + intweight = [] + for idx in range(in_features): + intweight.append( + torch.round( + (w[:, idx] + scale_zeros[idx // group_size]) + / scales[idx // group_size] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.to(dtype=torch.int32) + + qweight = torch.zeros( + (intweight.shape[0], intweight.shape[1] // 32 * w_bit), + dtype=torch.int32, + device=intweight.device, + ) + + for col in range(intweight.shape[1] // pack_num): + 
order_map = [0, 2, w_bit, 6, 1, 3, 5, 7] + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * w_bit) + + zeros = zeros.to(dtype=torch.int32, device=qweight.device) + + qzeros = torch.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * w_bit), + dtype=torch.int32, + device=zeros.device, + ) + + for col in range(zeros.shape[1] // pack_num): + order_map = [0, 2, w_bit, 6, 1, 3, 5, 7] + for i in range(pack_num): + qzero_col = zeros[:, col * pack_num + order_map[i]] + qzeros[:, col] |= qzero_col << (i * w_bit) + + return qweight, qzeros, scales + +GGML_BLOCK_SIZES = { + "F32": 4, + "F16": 2, + "Q4_0": 2 + 16, + "Q5_0": 2 + 4 + 16, + "Q8_0": 2 + 32, + "Q2_K": 256 // 16 + 256 // 4 + 2 + 2, + "Q3_K": 256 // 8 + 256 // 4 + 12 + 2, + "Q4_K": 2 + 2 + 12 + 256 // 2, + "Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2, + "Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2, + "IQ4_XS": 2 + 2 + 256 // 2 + 256 // 64, +} + +def dequantize_f32(data): + return np.frombuffer(data, dtype=np.float32) + +def dequantize_f16(data): + return np.frombuffer(data, dtype=np.float16) + +def dequantize_q4_0(data): + num_blocks = len(data) // GGML_BLOCK_SIZES["Q4_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 8)[:, :1].astype(np.float32) + qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 16)[:, 2:] + + return np.concatenate([ + scales * ((qs & 0xf).astype(np.int8) - 8), + scales * ((qs >> 4).astype(np.int8) - 8), + ], axis=1) + +def dequantize_q5_0(data): + num_blocks = len(data) // GGML_BLOCK_SIZES["Q5_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 2 + 8)[:, :1].astype(np.float32) + qh = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2:2 + 4] + qs = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, 2 + 4 + 16)[:, 2 + 4:] + + bits = np.unpackbits(qh, axis=-1, bitorder="little") + + x0 = ((qs & 0xf).astype(np.int8) | (bits[:, :16] << 4)) - 16 + x1 = ((qs >> 4).astype(np.int8) | (bits[:, 16:] << 4)) - 16 + + return np.concatenate([ + scales * x0, + scales * x1, + ], axis=1) + +def dequantize_q8_0(data): + num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"] + + scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32) + qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:] + return scales * qs + +def dequantize_q2_k(data): + block_size = GGML_BLOCK_SIZES["Q2_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) + d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32) + scales = data_u8[:, :16].reshape(num_blocks, 16, 1) + qs = data_u8[:, 16:80].reshape(num_blocks, 64) + + tmp = np.stack([ + qs[:, 00:16] >> 0, + qs[:, 16:32] >> 0, + qs[:, 00:16] >> 2, + qs[:, 16:32] >> 2, + qs[:, 00:16] >> 4, + qs[:, 16:32] >> 4, + qs[:, 00:16] >> 6, + qs[:, 16:32] >> 6, + qs[:, 32:48] >> 0, + qs[:, 48:64] >> 0, + qs[:, 32:48] >> 2, + qs[:, 48:64] >> 2, + qs[:, 32:48] >> 4, + qs[:, 48:64] >> 4, + qs[:, 32:48] >> 6, + qs[:, 48:64] >> 6, + ], axis=1) + + return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4) + + +def dequantize_q3_k(data): + block_size = GGML_BLOCK_SIZES["Q3_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, 
dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32) + bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little") + bits = 4 ^ (bits << 2) + qs = data_u8[:, 32:32 + 64].astype(np.int16) + a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2) + scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8) + scales[:, 0] = (a & 15) | ((c & 3) << 4) + scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4) + scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4) + scales[:, 3] = (b >> 4) | ((c >> 6) << 4) + scales = scales.reshape(num_blocks, 16, 1).astype(np.int16) + + return d * (scales - 32) * np.stack([ + (((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]), + (((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]), + (((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]), + (((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]), + (((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]), + (((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]), + (((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]), + (((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]), + (((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]), + (((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]), + (((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]), + (((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]), + (((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]), + (((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]), + (((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]), + (((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7]) + ], axis=1) + +def dequantize_q4_k(data, device=None): + block_size = GGML_BLOCK_SIZES["Q4_K"] + num_blocks = len(data) // block_size + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + # Casting to float32 because float16 is very slow on CPU + scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32) + scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32) + qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1) + qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32) + # Dequantize scales and offsets (6 bits and 4 + 2 bits) + factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1) + offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1) + # Interleave low and high quantized bits + qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32) + # Dequantize final weights using scales and offsets + weight = factors * qs2 - offsets + if device is None: + return weight + return torch.from_numpy(weight).to(device=device) + +def dequantize_q5_k(data): + block_size = GGML_BLOCK_SIZES["Q5_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + + d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32) + dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32) + scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1) + qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1) + qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32) + + bits = np.unpackbits(qh, axis=-1, bitorder="little") + + qs_hi_4 = qs >> 4 + qs_lo_4 = qs & 15 + + scales_lo_6 = scales[:, :8] & 63 + scales_hi_6 = scales[:, :8] 
>> 6 + scales_lo_4 = scales[:, 8:] & 15 + scales_hi_4 = scales[:, 8:] >> 4 + + m1 = dmin * scales_lo_6[:, 4] + m2 = dmin * scales_lo_6[:, 5] + m3 = dmin * scales_lo_6[:, 6] + m4 = dmin * scales_lo_6[:, 7] + m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4)) + m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4)) + m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4)) + m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4)) + + d1 = d * scales_lo_6[:, 0] + d2 = d * scales_lo_6[:, 1] + d3 = d * scales_lo_6[:, 2] + d4 = d * scales_lo_6[:, 3] + d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4)) + d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4)) + d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4)) + d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4)) + + return np.concatenate([ + d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1, + d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2, + d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3, + d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4, + d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5, + d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6, + d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7, + d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8, + ], axis=1) + +def dequantize_q6_k(data, device = None): + block_size = GGML_BLOCK_SIZES["Q6_K"] + num_blocks = len(data) // block_size + + data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size) + data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size) + + scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32) + # TODO use uint8 and cast later? + ql = data_u8[:, :128].astype(np.int16) + qh = data_u8[:, 128:192].astype(np.int16) + sc = data_i8[:, 192:208, np.newaxis].astype(np.float32) + + # Unpack bits, subtraction requires signed data type + q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32 + q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32 + q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32 + q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32 + q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32 + q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32 + q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32 + q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32 + + # Dequantize + weight = scales * np.concatenate([ + sc[:, 0] * q1[:, :16], + sc[:, 1] * q1[:, 16:], + sc[:, 2] * q2[:, :16], + sc[:, 3] * q2[:, 16:], + sc[:, 4] * q3[:, :16], + sc[:, 5] * q3[:, 16:], + sc[:, 6] * q4[:, :16], + sc[:, 7] * q4[:, 16:], + sc[:, 8] * q5[:, :16], + sc[:, 9] * q5[:, 16:], + sc[:, 10] * q6[:, :16], + sc[:, 11] * q6[:, 16:], + sc[:, 12] * q7[:, :16], + sc[:, 13] * q7[:, 16:], + sc[:, 14] * q8[:, :16], + sc[:, 15] * q8[:, 16:], + ], axis=1) + + if device is None: + return weight + return torch.from_numpy(weight).to(device=device) + +QK_K = 256 +kvalues_iq4nl = np.array([-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113], dtype=np.int8) + +def dequantize_iq4_xs(data): + block_size = GGML_BLOCK_SIZES["IQ4_XS"] + num_blocks = len(data) // block_size + + d = np.frombuffer(data, dtype=np.float16)[0::block_size//2].astype(np.float32).reshape(num_blocks, 1) + scales_h = np.frombuffer(data, dtype=np.uint16)[1::block_size//2].reshape(num_blocks, 1) + data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)[:, 4:] 
+ scales_l = data_u8[:, :4].reshape(num_blocks, 4) + qs = data_u8[:, 4:].reshape(num_blocks, block_size - 8) + + ls = np.zeros((num_blocks, QK_K // 32), dtype=np.int8) + for ib in range(QK_K // 32): + ls[:, ib] = ((scales_l[:, ib // 2] >> 4 * (ib % 2)) & 0xf) | (((scales_h[:, 0] >> 2 * ib) & 3) << 4) + + dl = (d * (ls - 32)).reshape(num_blocks, -1, 1) + + qs_lo_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) & 0xf + qs_hi_4 = qs[:, :QK_K // 2].reshape(num_blocks, -1, 16) >> 4 + + y = np.zeros((num_blocks, QK_K), dtype=np.float32) + for ib in range(QK_K // 32): + y[:, ib*32:(ib*32)+16] = dl[:, ib] * kvalues_iq4nl[qs_lo_4[:, ib]] + y[:, (ib*32)+16:(ib*32)+32] = dl[:, ib] * kvalues_iq4nl[qs_hi_4[:, ib]] + + return y.flatten() + +GGML_DEQUANTIZE = { + int(GGMLQuantizationType.F32): dequantize_f32, + int(GGMLQuantizationType.F16): dequantize_f16, + int(GGMLQuantizationType.Q4_0): dequantize_q4_0, + int(GGMLQuantizationType.Q5_0): dequantize_q5_0, + int(GGMLQuantizationType.Q8_0): dequantize_q8_0, + int(GGMLQuantizationType.Q2_K): dequantize_q2_k, + int(GGMLQuantizationType.Q3_K): dequantize_q3_k, + int(GGMLQuantizationType.Q4_K): dequantize_q4_k, + int(GGMLQuantizationType.Q5_K): dequantize_q5_k, + int(GGMLQuantizationType.Q6_K): dequantize_q6_k, + int(GGMLQuantizationType.IQ4_XS): dequantize_iq4_xs, +} + + +def dequant_gguf(data, type, shape): + values = GGML_DEQUANTIZE[type](data) + values = torch.from_numpy(values).view(shape) + return values \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index c114772..23ccfc5 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -255,18 +255,6 @@ def marlin_moe_intermediate_size(w1_packed: torch.Tensor, w2_packed: torch.Tenso return w2_packed.size(1) * marlin_tile_size -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - def marlin_make_workspace_new( device: torch.device, max_blocks_per_sm: int = 1 ) -> torch.Tensor: @@ -297,12 +285,6 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: ) -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices diff --git a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py index 9dbfc6e..1bd2d75 100644 --- a/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/mxfp4_utils.py @@ -175,7 +175,7 @@ try: op_func=_dequant_mxfp4, fake_impl=_dequant_mxfp4_fake, ) - dequant_mxfp4 = torch.ops.vllm.dequant_mxfp4 + dequant_mxfp4 = None except AttributeError as error: raise error @@ -185,6 +185,6 @@ try: op_func=_quant_dequant_mxfp4, fake_impl=_quant_dequant_mxfp4_fake, ) - quant_dequant_mxfp4 = torch.ops.vllm.quant_dequant_mxfp4 + quant_dequant_mxfp4 = None except AttributeError as error: raise error diff --git 
a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 12a1799..5709bfb 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -271,12 +271,12 @@ def scaled_quantize( If None, uses input dtype. Use torch.float32 for higher precision. """ group_shape = _normalize_quant_group_shape(x, group_shape) - assert quant_dtype.is_floating_point, ( - "currently `scaled_quantize` only supports floating point dtypes " - "but could be extended to support other dtypes" - ) + # assert quant_dtype.is_floating_point, ( + # "currently `scaled_quantize` only supports floating point dtypes " + # "but could be extended to support other dtypes" + # ) - finfo = torch.finfo(quant_dtype) + finfo = torch.finfo(quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(quant_dtype) # Convert to compute dtype if specified x_compute = x if compute_dtype is None else x.to(compute_dtype) diff --git a/vllm/model_executor/layers/quantization/w8a16.py b/vllm/model_executor/layers/quantization/w8a16.py new file mode 100644 index 0000000..6c42ce7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/w8a16.py @@ -0,0 +1,114 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.model_executor.utils import set_weight_attrs + + +class W8a16Config(QuantizationConfig): + """Config class for W8a16. + + """ + + def __init__( + self, + ) -> None: + pass + + def __repr__(self) -> str: + return ("W8a16Config") + + def get_name(self) -> str: + return "w8a16" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + def get_min_capability(self) -> int: + return 75 + + @staticmethod + def get_config_filenames(): + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W8a16Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["W8a16LinearMethod"]: + if isinstance(layer, LinearBase): + return W8a16LinearMethod(self) + return None + + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W8a16LinearMethod(LinearMethodBase): + """Linear method for w8a16. 
+ + """ + + def __init__(self, quant_config: W8a16Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs( + weight, { + "input_dim": 1, + "output_dim": 0, + }) + + scales = Parameter( + torch.empty( + 1, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": None, + "output_dim": 1, + }) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.weight + scales = layer.scales + out_shape = (x.shape[:-1] + (qweight.shape[-2],)) + reshaped_x = x.reshape(-1, x.shape[-1]) + out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN") + if bias is not None: + out = out + bias + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 1374334..0a8fd19 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -227,6 +227,7 @@ class RotaryEmbedding(RotaryEmbeddingBase): self.head_size, cos_sin_cache, self.is_neox_style, + self.rotary_dim, ) return query, key diff --git a/vllm/model_executor/layers/rotary_embedding/common.py b/vllm/model_executor/layers/rotary_embedding/common.py index 54c6684..2668a70 100644 --- a/vllm/model_executor/layers/rotary_embedding/common.py +++ b/vllm/model_executor/layers/rotary_embedding/common.py @@ -229,22 +229,7 @@ class ApplyRotaryEmb(CustomOp): cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: - # from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb return self.forward_native(x, cos, sin) - x, cos, sin, origin_shape, origin_dtype = self._pre_process(x, cos, sin) - - """ - Arguments of apply_rotary_emb() in vllm_flash_attn: - x: [batch_size, seq_len, nheads, headdim] - cos, sin: [seqlen_rotary, rotary_dim / 2] - interleaved: defalut as False (Neox-style). - ... 
- """ - interleaved = not self.is_neox_style - output = apply_rotary_emb(x, cos, sin, interleaved) - - output = self._post_process(output, origin_shape, origin_dtype) - return output def forward_hip( self, diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index c3abdc1..c263843 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -8,7 +8,7 @@ import torch from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer -from .base import RotaryEmbeddingBase +from .base import RotaryEmbedding from .common import ( rotate_gptj, rotate_neox, @@ -23,7 +23,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 -class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): """RotaryEmbedding extended with YaRN method. Credits to Peng et al. github.com/jquesnelle/yarn @@ -110,73 +110,3 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase): sin = freqs.sin() * self.mscale cache = torch.cat((cos, sin), dim=-1) return cache - - def forward_native( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor | None = None, - offsets: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """PyTorch-native implementation equivalent to forward().""" - assert key is not None - cos_sin_cache = self._match_cos_sin_cache_dtype(query) - query_rot = query[..., : self.rotary_dim] - key_rot = key[..., : self.rotary_dim] - if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim :] - key_pass = key[..., self.rotary_dim :] - - cos_sin = cos_sin_cache[ - torch.add(positions, offsets) if offsets is not None else positions - ] - cos, sin = cos_sin.chunk(2, dim=-1) - if self.is_neox_style: - # NOTE(woosuk): Here we assume that the positions tensor has the - # shape [batch_size, seq_len]. 
- cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - - rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj - query_rot = query_rot * cos + rotate_fn(query_rot) * sin - key_rot = key_rot * cos + rotate_fn(key_rot) * sin - - if self.rotary_dim < self.head_size: - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - else: - query = query_rot - key = key_rot - return query, key - - def forward_hip( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor | None = None, - offsets: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - return self.forward_native(positions, query, key, offsets) - - def forward_cuda( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor | None = None, - offsets: torch.Tensor | None = None, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - if self.use_flashinfer: - torch.ops.vllm.flashinfer_rotary_embedding( - torch.add(positions, offsets) if offsets is not None else positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) - return query, key - else: - return self.forward_native(positions, query, key, offsets) diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 3c946dd..27635dd 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -330,48 +330,46 @@ class MRotaryEmbedding(RotaryEmbeddingBase): ) -> tuple[torch.Tensor, torch.Tensor | None]: assert positions.ndim == 1 or positions.ndim == 2 assert key is not None + from vllm import _custom_ops as ops - cos_sin_cache = self._match_cos_sin_cache_dtype(query) - num_tokens = positions.shape[-1] - cos_sin = cos_sin_cache[positions] - cos, sin = cos_sin.chunk(2, dim=-1) - query_shape = query.shape - key_shape = key.shape - if positions.ndim == 2: - assert self.mrope_section + self._match_cos_sin_cache_dtype(query) + + if self.mrope_interleaved: + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + query_shape = query.shape + key_shape = key.shape + if positions.ndim == 2: + assert self.mrope_section + q, k = triton_mrope( + query, + key, + cos, + sin, + self.mrope_section, + self.head_size, + self.rotary_dim, + self.mrope_interleaved, + ) - q, k = triton_mrope( - query, - key, - cos, - sin, - self.mrope_section, - self.head_size, - self.rotary_dim, - self.mrope_interleaved, - ) - - return q.reshape(query_shape), k.reshape(key_shape) - - query = query.view(num_tokens, -1, self.head_size) - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = self.apply_rotary_emb( - query_rot, - cos, - sin, - ) - query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) - - key = key.view(num_tokens, -1, self.head_size) - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = self.apply_rotary_emb( - key_rot, - cos, - sin, - ) - key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return q.reshape(query_shape), k.reshape(key_shape) + + if positions.ndim == 1: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + else: + if self.is_neox_style: + 
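+                # 2-D (per-section) positions with neox layout: dispatch to the fused
+                # multimodal-rope custom op, passing the mrope section split as an
+                # int32 tensor.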
ops.m_rotary_embedding(positions.contiguous(), query, key, self.head_size, + self.cos_sin_cache, + torch.tensor(self.mrope_section, dtype=torch.int), + self.is_neox_style) + else: + query, key = self.forward_native( + positions, query, key + ) + + return query, key def forward_cpu( diff --git a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py index e58c978..5e519cf 100644 --- a/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py @@ -12,6 +12,7 @@ from .common import rotate_neox logger = init_logger(__name__) +import ixformer.inference.functions as ixops class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): """Phi3 family of models scaled rotary embedding. @@ -133,27 +134,18 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): query = query.view(*query.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size) - if self.use_long_rope: - k = self.original_max_position_embeddings - long_prompt_offset = torch.full_like(positions, k).long() - idx = torch.add(positions, long_prompt_offset) - else: - idx = positions - idx = torch.add(idx, offsets) if offsets is not None else idx - cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + k = self.original_max_position_embeddings + long_prompt_offset = torch.any(positions > k) + + ixops.vllm_rotary_embedding_phi( + positions, + query, + key, + self.head_size, + self.long_short_cos_sin_cache, + long_prompt_offset, + k, + offsets + ) - cos, sin = cos_sin.chunk(2, dim=-1) - cos = cos.repeat(1, 2).unsqueeze(-2) - sin = sin.repeat(1, 2).unsqueeze(-2) - - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - query_rot = query_rot * cos + rotate_neox(query_rot) * sin - query = torch.cat((query_rot, query_pass), dim=-1) - - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - key_rot = key_rot * cos + rotate_neox(key_rot) * sin - key = torch.cat((key_rot, key_pass), dim=-1) - - return query.flatten(-2), key.flatten(-2) + return query, key diff --git a/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py new file mode 100644 index 0000000..e1e3d18 --- /dev/null +++ b/vllm/model_executor/layers/shared_fused_moe/shared_fused_moe.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + +# TODO(bnell): Add shared + fused combo function? e.g. + +class SharedFusedMoE(FusedMoE): + """ + A FusedMoE operation that also computes the results of shared experts. + If an all2all communicator is being used the shared expert computation + can be interleaved with the fused all2all dispatch communication step. 
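+
+    Usage sketch (illustrative only; the shared-expert MLP and the MoE sizes
+    below are assumptions, not taken from a particular model config):
+
+        shared = DeepseekV2MLP(hidden_size=4096, intermediate_size=2048,
+                               hidden_act="silu", reduce_results=False)
+        moe = SharedFusedMoE(shared_experts=shared, use_overlapped=True,
+                             num_experts=64, top_k=6, hidden_size=4096,
+                             intermediate_size=1024, reduce_results=False)
+        shared_out, routed_out = moe(hidden_states, router_logits)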
+ """ + + def __init__( + self, + shared_experts: torch.nn.Module, + use_overlapped: bool = True, + **kwargs, + ): + super().__init__(**kwargs) + self._shared_experts = shared_experts + self.use_overlapped = use_overlapped + + @property + def shared_experts(self) -> Optional[torch.nn.Module]: + return self._shared_experts if self.use_overlapped else None + + def forward( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not self.use_overlapped: + shared_out = self._shared_experts(hidden_states) + + # Reduce outputs if necessary, since the MLP should + # have been created with reduce_results=False. + if (self.reduce_results and self.tp_size > 1 + and self.must_reduce_shared_expert_outputs()): + shared_out = tensor_model_parallel_all_reduce(shared_out) + + fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + else: + shared_out, fused_out = super().forward( + hidden_states=hidden_states, + router_logits=router_logits, + ) + return shared_out, fused_out diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 826caa5..26e79f9 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -9,58 +9,108 @@ from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import CustomOp from vllm.platforms import current_platform -from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits -from vllm.utils.import_utils import has_deep_gemm +from vllm.utils.deep_gemm import ( + fp8_mqa_logits, + fp8_mqa_logits_torch, + fp8_paged_mqa_logits, + fp8_paged_mqa_logits_torch, + is_deep_gemm_supported, +) from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager - +from vllm.utils.math_utils import cdiv if current_platform.is_cuda_alike(): from vllm import _custom_ops as ops elif current_platform.is_xpu(): from vllm._xpu_ops import xpu_ops as ops +import ixformer.inference.functions as ixfops + logger = init_logger(__name__) +@torch.inference_mode() +def cp_gather_indexer_k_quant_cache( + kv_cache, # [num_blocks, block_size, head_dim] + dst_value, # [cu_seq_lens[-1], head_dim] + block_table, # [batch_size, num_blocks] + cu_seq_lens, # [batch_size + 1, ] + batch_size, +): + num_blocks, block_size, _ = kv_cache.shape + head_dim = dst_value.shape[-1] + kv_cache = kv_cache.view(num_blocks, -1) + + expected_value = [] + # expected_scale = [] + for b in range(batch_size): + s = cu_seq_lens[b + 1] - cu_seq_lens[b] + if s == 0: + continue + tot = cdiv(s, block_size) + blocks = block_table[b, :tot] + + value = [] + scale = [] + full_block = torch.arange(tot - 1, + device=kv_cache.device, + dtype=torch.int32) + non_remaining_value = kv_cache[blocks[full_block], :block_size * + head_dim].view(-1, head_dim) + # non_remaining_scale = kv_cache[blocks[full_block], + # block_size * head_dim:].view(-1, 4) + + remaining = s - (tot - 1) * block_size + + value = torch.cat([ + non_remaining_value, + kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim) + ], + dim=0) + # scale = torch.cat([ + # non_remaining_scale, + # kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim + + # remaining * 
4].view(-1, 4) + # ], + # dim=0) + + expected_value.append(value) + # expected_scale.append(scale) + + gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) + # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) + gather_value = gather_value.view(torch.bfloat16) + # gather_scale = gather_scale.view(torch.float32) + dst_value.copy_(gather_value) + # dst_scale.copy_(gather_scale) def sparse_attn_indexer( hidden_states: torch.Tensor, k_cache_prefix: str, kv_cache: torch.Tensor, - q_fp8: torch.Tensor, + q: torch.Tensor, k: torch.Tensor, weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, topk_tokens: int, head_dim: int, max_model_len: int, total_seq_lens: int, - topk_indices_buffer: torch.Tensor, + topk_indices_buffer: torch.Tensor | None, ) -> torch.Tensor: # careful! this will be None in dummy run attn_metadata = get_forward_context().attn_metadata - fp8_dtype = current_platform.fp8_dtype() - # assert isinstance(attn_metadata, dict) if not isinstance(attn_metadata, dict): - # Reserve workspace for indexer during profiling run - current_workspace_manager().get_simultaneous( - ((total_seq_lens, head_dim), torch.float8_e4m3fn), - ((total_seq_lens, 4), torch.uint8), - ) return sparse_attn_indexer_fake( hidden_states, k_cache_prefix, kv_cache, - q_fp8, + q, k, weights, - quant_block_size, - scale_fmt, topk_tokens, head_dim, max_model_len, @@ -74,12 +124,118 @@ def sparse_attn_indexer( has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens - ops.indexer_k_quant_and_cache( + ops.indexer_k_cache( + k, + kv_cache, + slot_mapping + ) + + # topk_indices_buffer[: hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + for chunk in prefill_metadata.chunks: + logits = ixfops.dsa_indexer_mqa_logits_with_blocks( + q[chunk.token_start:chunk.token_end], + chunk.cu_seqlens_q, + chunk.cu_seq_lens, + kv_cache, + chunk.block_table, + weights[chunk.token_start : chunk.token_end], + max_q_len=chunk.max_q_len, + max_kv_len=chunk.max_kv_len, + max_context_len=chunk.max_context_len + ) + ixfops.dsa_update_topk_indices( + logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, topk_tokens, + topk_indices_buffer[chunk.token_start:chunk.token_end] + ) + + if has_decode: + decode_metadata = attn_metadata.decode + # TODO: support speculative decode + if decode_metadata.requires_padding: + raise NotImplementedError( + "Sparse attention indexer does not support requires_padding" + ) + + # Use dsa_indexer_mqa_logits_with_blocks similar to prefill + logits = ixfops.dsa_indexer_mqa_logits_with_blocks( + q[:num_decode_tokens], + decode_metadata.cu_seqlens_q, + decode_metadata.cu_seqlens_kv, + kv_cache, + decode_metadata.block_table, + weights[:num_decode_tokens], + max_q_len=decode_metadata.max_q_len, + max_kv_len=decode_metadata.max_kv_len, + max_context_len=decode_metadata.max_context_len, + ) + + ixfops.dsa_update_topk_indices( + logits, + decode_metadata.cu_seqlen_ks, + decode_metadata.cu_seqlen_ke, + topk_tokens, + topk_indices_buffer[:num_decode_tokens], + ) + + return topk_indices_buffer + + +def sparse_attn_indexer_original( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: torch.Tensor, +) -> torch.Tensor: + # careful! 
this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + # fp8_dtype = current_platform.fp8_dtype() + + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + # Reserve workspace for indexer during profiling run + current_workspace_manager().get_simultaneous( + ((total_seq_lens, head_dim), torch.float8_e4m3fn), + ((total_seq_lens, 4), torch.uint8), + ) + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q, + k, + weights, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + # During speculative decoding, k may be padded to the CUDA graph batch + # size while slot_mapping only covers actual tokens. Truncate k to avoid + # out-of-bounds reads in the kernel. + num_tokens = slot_mapping.shape[0] + k = k[:num_tokens] + + ops.indexer_k_cache( k, kv_cache, slot_mapping, - quant_block_size, - scale_fmt, ) topk_indices_buffer[: hidden_states.shape[0]] = -1 @@ -88,44 +244,42 @@ def sparse_attn_indexer( # Get the full shared workspace buffers once (will allocate on first use) workspace_manager = current_workspace_manager() - k_fp8_full, k_scale_full = workspace_manager.get_simultaneous( - ((total_seq_lens, head_dim), fp8_dtype), - ((total_seq_lens, 4), torch.uint8), - ) + k_full = workspace_manager.get_simultaneous( + ((total_seq_lens, head_dim), torch.bfloat16), + )[0] for chunk in prefill_metadata.chunks: - k_fp8 = k_fp8_full[: chunk.total_seq_lens] - k_scale = k_scale_full[: chunk.total_seq_lens] - ops.cp_gather_indexer_k_quant_cache( + k = k_full[: chunk.total_seq_lens] + # k_scale = k_scale_full[: chunk.total_seq_lens] + cp_gather_indexer_k_quant_cache( kv_cache, - k_fp8, - k_scale, + k, chunk.block_table, chunk.cu_seq_lens, + chunk.num_reqs, ) - logits = fp8_mqa_logits( - q_fp8[chunk.token_start : chunk.token_end], - (k_fp8, k_scale.view(torch.float32).flatten()), + logits = ops.ref_mqa_logits( + q[chunk.token_start:chunk.token_end], + k, weights[chunk.token_start : chunk.token_end], chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, - clean_logits=False, - ) - num_rows = logits.shape[0] - - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, ) + topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), + dim=-1)[1] + topk_indices -= chunk.cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices - (chunk.cu_seqlen_ke - + chunk.cu_seqlen_ks)[:, None] < 0 + mask = torch.full_like(topk_indices, + False, + dtype=torch.bool, + device=topk_indices.device) + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + topk_indices_buffer[ + chunk.token_start:chunk.token_end, :topk_indices. 
+ shape[-1]] = topk_indices.to(dtype=torch.int32) # Compute lengths from row spans # lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32) @@ -147,63 +301,50 @@ def sparse_attn_indexer( # decode_threshold since we unstrictly split # prefill and decode by decode_threshold # (currently set to 1 + speculative tokens) - padded_q_fp8_decode_tokens = pack_seq_triton( - q_fp8[:num_decode_tokens], decode_lens - ) + padded_q_decode_tokens = pack_seq_triton( + q[:num_decode_tokens], decode_lens) else: - padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape( - decode_lens.shape[0], -1, *q_fp8.shape[1:] - ) + padded_q_decode_tokens = q[:num_decode_tokens].reshape( + decode_lens.shape[0], -1, *q.shape[1:]) # TODO: move and optimize below logic with triton kernels - batch_size = padded_q_fp8_decode_tokens.shape[0] - next_n = padded_q_fp8_decode_tokens.shape[1] + batch_size = padded_q_decode_tokens.shape[0] + next_n = padded_q_decode_tokens.shape[1] assert batch_size == decode_metadata.seq_lens.shape[0] num_padded_tokens = batch_size * next_n - logits = fp8_paged_mqa_logits( - padded_q_fp8_decode_tokens, + logits = ops.ref_paged_mqa_logits( + padded_q_decode_tokens, kv_cache, weights[:num_padded_tokens], decode_metadata.seq_lens, decode_metadata.block_table, - decode_metadata.schedule_metadata, max_model_len=max_model_len, clean_logits=False, ) - num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - - if decode_metadata.use_large_context_topk: - if next_n == 1: - lengths = decode_metadata.seq_lens - else: - # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,) - lengths = ( - decode_metadata.seq_lens.unsqueeze(1) - - next_n - + 1 - + decode_metadata.offsets - ).flatten() - - torch.ops._C.large_context_topk( - logits, - topk_indices, - lengths, - None, - ) - else: - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) - + # padded query len + current_device = padded_q_decode_tokens.device + padded_num_tokens = batch_size * next_n + positions = torch.arange(max_model_len, + device=current_device).unsqueeze(0).expand( + batch_size * next_n, -1) + row_indices = torch.arange(padded_num_tokens, + device=current_device) // next_n + next_n_offset = torch.arange( + padded_num_tokens, + device=padded_q_decode_tokens.device) % next_n + index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + + next_n_offset).unsqueeze(1) + # index_end_pos: [B * N, 1] + mask = positions <= index_end_pos + # mask: [B * N, L] + logits = logits.masked_fill(~mask, float('-inf')) + topk_indices = logits.topk(topk_tokens, + dim=-1)[1].to(torch.int32) # [B * N, K] + # ensure we don't set indices for the top k + # that is out of range(masked already) + # this will happen if context length is shorter than K + topk_indices[topk_indices > index_end_pos] = -1 if decode_metadata.requires_padding: # if padded, we need to unpack # the topk indices removing padded tokens @@ -211,9 +352,8 @@ def sparse_attn_indexer( topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens, ) - topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( - topk_indices - ) + topk_indices_buffer[:num_decode_tokens, :topk_indices. 
+ shape[-1]] = topk_indices.to(dtype=torch.int32) return topk_indices_buffer @@ -222,11 +362,9 @@ def sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: str, kv_cache: torch.Tensor, - q_fp8: torch.Tensor, + q: torch.Tensor, k: torch.Tensor, weights: torch.Tensor, - quant_block_size: int, - scale_fmt: str | None, topk_tokens: int, head_dim: int, max_model_len: int, @@ -278,9 +416,12 @@ class SparseAttnIndexer(CustomOp): self.max_model_len = max_model_len self.max_total_seq_len = max_total_seq_len self.topk_indices_buffer = topk_indices_buffer - if current_platform.is_cuda() and not has_deep_gemm(): - raise RuntimeError( - "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed." + if current_platform.is_cuda() and not is_deep_gemm_supported(): + logger.warning_once( + "DeepGEMM is not supported or available. SparseAttnIndexer will use a " + "less efficient PyTorch implementation. " + "Please make sure you have the required hardware and software setup " + "for DeepGEMM to achieve optimal performance." ) def forward_native( @@ -303,7 +444,7 @@ class SparseAttnIndexer(CustomOp): def forward_cuda( self, hidden_states: torch.Tensor, - q_fp8: torch.Tensor, + q: torch.Tensor, k: torch.Tensor, weights: torch.Tensor, ): @@ -311,11 +452,9 @@ class SparseAttnIndexer(CustomOp): hidden_states, self.k_cache.prefix, self.k_cache.kv_cache[0], - q_fp8, + q, k, weights, - self.quant_block_size, - self.scale_fmt, self.topk_tokens, self.head_dim, self.max_model_len, diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index bc51b0e..11a4ff0 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -3,6 +3,8 @@ """Utility methods for model layers.""" from collections.abc import Callable +import ast +import re import torch @@ -13,6 +15,7 @@ from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.utils.platform_utils import num_compute_units from vllm.utils.torch_utils import direct_register_custom_op +import ixformer.inference.functions as IXF logger = init_logger(__name__) @@ -31,27 +34,6 @@ def is_layer_moe_router_gate(prefix: str) -> bool: return prefix.rsplit(".", 1)[-1] in MOE_LAYER_ROUTER_GATE_SUFFIXES -def shuffle_weight(w: torch.Tensor) -> torch.Tensor: - # Shuffle weight along the last dimension so that - # we folded the weights to adjance location - # Example: - # input: - # [[1, 2, 3, 4, 5, 6], - # [7, 8, 9, 10, 11, 12]] - # output: - # [[1, 4, 2, 5, 3, 6], - # [7, 10, 8, 11, 9, 12]] - # This will be used together with triton swiglu kernel - shape = w.shape - N = shape[-1] - first = w[..., : N // 2] - second = w[..., N // 2 :] - - stacked = torch.stack((first, second), dim=-1) - w_shuffled = stacked.reshape(shape) - return w_shuffled - - def get_token_bin_counts_and_mask( tokens: torch.Tensor, vocab_size: int, @@ -116,7 +98,11 @@ def default_unquantized_gemm( weight: torch.Tensor, bias: torch.Tensor | None = None, ): - return torch.nn.functional.linear(x, weight, bias) + if bias is None and x.dtype in [torch.half, torch.bfloat16] and weight.dtype == torch.float32: + return IXF.mixed_type_linear(input=x, weight=layer.weight) + if x.dtype == torch.float32: + return torch.nn.functional.linear(x, weight, bias) + return IXF.linear(x, weight, bias) def use_aiter_triton_gemm(n, m, k, dtype): @@ -191,7 +177,6 @@ def rocm_unquantized_gemm_impl( and on_gfx9() and x.dtype in [torch.float16, torch.bfloat16] and k % 8 == 0 - and x.is_contiguous() ) if use_skinny is 
not True: @@ -302,3 +287,72 @@ def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]: return cpu_unquantized_gemm else: return default_unquantized_gemm + +def weight_quant_l1(input: torch.Tensor): + qmax = 127.0 + input = input.to(device="cuda") + abs_max = torch.abs(input).max(dim=1, keepdim=True)[0] + scale = abs_max / qmax + assert scale.shape == (input.shape[0], 1) + quantized = torch.round(input / scale) + quantized = torch.clamp(quantized, -qmax, qmax) + return quantized.to(torch.int8), scale.to(torch.float32) + +def weight_quant_l2(input: torch.Tensor, format: str = "TN"): + qmax = 127.0 + input = input.to(device="cuda") + abs_max = torch.abs(input).max(dim=1, keepdim=True)[0] # [rows, 1] + scale = abs_max / qmax # [rows, 1] + assert scale.shape == (input.shape[0], 1) + quantized = torch.round(input / scale) + quantized = torch.clamp(quantized, -qmax, qmax) + + i4_weights, i8scales, i8zeros = IXF.quant_repack_int4(quantized.to(torch.int8).unsqueeze_(0), -1, 2, format, False) + return i4_weights.squeeze(0), scale.to(torch.float32) + + +def parse_opt_exclude_layers( + opt_exclude_layers_str: str, + prefix: str, +) -> bool: + """ + Parses the VLLM_OPT_EXCLUDE_LAYERS environment variable to determine if + the current layer should be excluded from optimization. + + Args: + opt_exclude_layers_str: The string value from the + VLLM_OPT_EXCLUDE_LAYERS environment variable. + prefix: The prefix of the current layer (e.g., + "model.layers.12.qkv_proj"). + + Returns: + A boolean indicating whether the layer should be excluded. + """ + if not opt_exclude_layers_str: + return False + + try: + # Safely evaluate the string to a Python object + excluded_layers = ast.literal_eval(opt_exclude_layers_str) + + # If a single integer is provided, convert it to a set + if isinstance(excluded_layers, int): + excluded_layers = {excluded_layers} + elif not isinstance(excluded_layers, (set, tuple, list)): + raise TypeError + + excluded_layers: set[int] = set(excluded_layers) + + # Extract layer number from the prefix string + layer_match = re.search(r"\.(\d+)", prefix) + if layer_match and int(layer_match.group(1)) in excluded_layers: + return True # Exclude this layer + except (ValueError, SyntaxError, TypeError): + logger.warning( + "Failed to parse VLLM_OPT_EXCLUDE_LAYERS: %s. 
" + "Expected a string representation of an integer or a " + "tuple/list/set of integers.", + opt_exclude_layers_str, + ) + + return False # Do not exclude this layer \ No newline at end of file diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py index ff95d5b..30ac389 100644 --- a/vllm/model_executor/model_loader/__init__.py +++ b/vllm/model_executor/model_loader/__init__.py @@ -23,6 +23,10 @@ from vllm.model_executor.model_loader.utils import ( get_model_architecture, get_model_cls, ) +from vllm.model_executor.model_loader.weight_utils import ( + padding_weight_loader +) + logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index ebd3a15..f75d252 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -13,6 +13,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import ModelConfig from vllm.config.load import LoadConfig +from vllm import envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least from vllm.model_executor.model_loader.base_loader import BaseModelLoader @@ -32,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import ( ) from vllm.tracing import instrument from vllm.transformers_utils.repo_utils import list_filtered_repo_files +from vllm import envs + logger = init_logger(__name__) @@ -287,8 +290,7 @@ class DefaultModelLoader(BaseModelLoader): self.load_config.safetensors_load_strategy = "torchao" weights_to_load = {name for name, _ in model.named_parameters()} - all_weights = self.get_all_weights(model_config, model) - loaded_weights = model.load_weights(all_weights) + loaded_weights = model.load_weights(self.get_all_weights(model_config, model)) self.counter_after_loading_weights = time.perf_counter() logger.info_once( @@ -298,7 +300,8 @@ class DefaultModelLoader(BaseModelLoader): ) # We only enable strict check for non-quantized models # that have loaded weights tracking currently. - if model_config.quantization is None and loaded_weights is not None: + opt_flag = envs.VLLM_MOE_OPT_LEVEL != 0 or envs.VLLM_LINEAR_OPT_LEVEL != 0 + if model_config.quantization is None and loaded_weights is not None and not opt_flag: weights_not_loaded = weights_to_load - loaded_weights if weights_not_loaded: raise ValueError( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f3d0b03..0bcf1cf 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -39,8 +39,6 @@ from vllm.platforms import current_platform from vllm.tracing import instrument from vllm.utils.import_utils import PlaceholderModule -import ixformer.inference.functions as ixfop - try: from runai_model_streamer import SafetensorsStreamer except ImportError: @@ -289,7 +287,17 @@ def get_quant_config( ) if hf_quant_config is not None: - return quant_cls.from_config(hf_quant_config) + # For modelopt_mixed, config.json's quantization_config may or may + # not contain the per-layer quantized_layers map. Newer checkpoints + # embed it directly; older ones keep it only in hf_quant_config.json. + # If it is missing, fall through to the file-based loading path. 
+ if ( + model_config.quantization == "modelopt_mixed" + and "quantized_layers" not in hf_quant_config + ): + pass # fall through to file-based loading below + else: + return quant_cls.from_config(hf_quant_config) # if hf_quant_config is None, we will try to get config from # hf_overrides @@ -367,8 +375,8 @@ def get_quant_config( if model_config.quantization == "bitsandbytes": config["adapter_name_or_path"] = model_config.model - elif model_config.quantization == "modelopt": - if config["producer"]["name"] == "modelopt": + elif model_config.quantization in ("modelopt", "modelopt_mixed"): + if config.get("producer", {}).get("name") == "modelopt": return quant_cls.from_config(config) else: raise ValueError( @@ -697,13 +705,6 @@ def np_cache_weights_iterator( yield name, torch.from_numpy(param) -FP_TYPES = { - "torch.bfloat16", - "torch.float16", - "torch.float32", - "torch.half", -} - def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, @@ -714,12 +715,6 @@ def safetensors_weights_iterator( if safetensors_load_strategy == "eager": loading_desc += " (eager)" - CUSTOM_QUANT_CONFIG = os.environ.get("CUSTOM_QUANT_CONFIG", None) - try: - with open(f"{CUSTOM_QUANT_CONFIG}/quant_map.json", "r") as f: - quant_map = json.load(f) - except Exception as e: - quant_map = None leftover_state_dict: dict[str, torch.Tensor] = {} for st_file in tqdm( sorted(hf_weights_files, key=_natural_sort_key), @@ -763,160 +758,9 @@ def safetensors_weights_iterator( with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 param = f.get_tensor(name) - if not quant_map: - yield name, param - continue - quant_type = quant_map.get(name) - - if quant_type is None or quant_type in FP_TYPES: - yield name, param - continue - - qtype, qformat = quant_type.split("-") - - qname = name - qscale_name = f"{name}_scale" - is_expert = ("expert" in name and "shared" not in name) - - # INT8 - if qtype == "int8": - if param.ndim == 2: - param.unsqueeze_(0) - - qweight, qscale = weight_quant_bf16_to_int8(param) - - A = qweight.shape[0] - - if is_expert: - qscale = qscale.view(A, 1, -1).transpose(-2, -1).contiguous() - - if A == 1: - qweight = qweight.squeeze(0) - qscale = qscale.squeeze(0) - - yield qname, qweight - yield qscale_name, qscale - continue - - # INT4 - if qtype == "int4": - i8scales, i8zeros = None, None - - if param.ndim == 2: - param.unsqueeze_(0) - - qweight, qscale, i8scales, i8zeros = weight_quant_bf16_to_int4pack8( - param, - format=qformat, - symmetric=True, - ) - - A = qweight.shape[0] - - if is_expert: - qscale = qscale.view(A, 1, -1).contiguous() - - if A == 1: - qweight = qweight.squeeze(0) - qscale = qscale.squeeze(0) - - yield qname, qweight - yield qscale_name, qscale - - if i8scales is not None: - yield f"{name}_i8_weight_scale", i8scales.squeeze_(0) - - if i8zeros is not None: - yield f"{name}_i8_weight_zero", i8zeros.squeeze_(0) - - continue - yield name, param -def weight_quant_bf16_to_int8(inputs: torch.Tensor): - device = current_platform.current_device() - - assert inputs.dim() == 3, f"inputs shape is [batch, output_dim, input_dim], but got {inputs.dim()}" - - ori_device = inputs.device - if inputs.device != device: - inputs = inputs.to(device) - - qmax = 127.0 - abs_max = torch.abs(inputs).max(dim=2, keepdim=True)[0] - scale = abs_max / qmax - - assert scale.shape == (*inputs.shape[:2], 1) - - quantized = torch.round(inputs / scale) - quantized = torch.clamp(quantized, -qmax, qmax) - return quantized.to(torch.int8).to(ori_device), 
scale.to(torch.float32).to(ori_device) - - -def weight_quant_bf16_to_int4pack8( - v: torch.Tensor, # [B, R, C] - block_size: int = 128, - group_size: int = -1, - format: str = "TN", - symmetric: bool = True, - version: int = 2, -): - """ - Batch 版本 INT4 量化 + 打包。 - - Args: - v: [batch, rows, cols], float Tensor - - Returns: - i4_weights: [batch, rows, packed_cols] - scale: [batch, rows, 1] - i8scales: 来自 ixfop.quant_repack_int4 - i8zeros: 来自 ixfop.quant_repack_int4 - """ - device = current_platform.current_device() - ori_device = inputs.device - if inputs.device != device: - inputs = inputs.to(device) - assert v.dim() == 3, f"expected [batch, rows, cols], got {v.shape}" - - B, R, C = v.shape - - qmax = 127.0 - - # abs_max: [B, R, 1] - abs_max = torch.abs(v).amax(dim=2, keepdim=True) - scale = abs_max / qmax # [B, R, 1] - - # quantized: [B, R, C] - quantized = torch.round(v / scale) - quantized = torch.clamp(quantized, -qmax, qmax).to(torch.int8) - - # ixfop.quant_repack_int4 需要 [batch, rows, cols] - # 它本来就是 batch-first,可直接送进去 - # 返回形状一般是: - # i4_weights: [B, R, packed_C] - # i8scales: [B, R, groups] - # i8zeros: [B, R, groups] - i4_weights, i8scales, i8zeros = ixfop.quant_repack_int4( - quantized, # 不需要 unsqueeze,因为本来就是 [B, R, C] - group_size, - version, - format, - symmetric, - ) - if i8scales is not None: - i8scales = i8scales.to(ori_device) - i8zeros = i8zeros.to(ori_device) - return ( - i4_weights.to(ori_device), # [B, R, packed_C] - scale.to(torch.float32).to(ori_device), # [B, R, 1] - i8scales, # 来自 repack - i8zeros - ) - - - def multi_thread_safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, @@ -1426,3 +1270,50 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None: # If there were no matches, return the untouched param name return name + + +def padding_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Weight loader that allows padding on the last (and optionally middle) dims. + + If shapes match: copy directly. + If shapes differ: copy the overlapping slice (min along each dimension). + Special-cases MoE weights that have expert dim in front (2D/3D). + """ + # Fast path: exact match + if param.shape == loaded_weight.shape: + param.data.copy_(loaded_weight) + return + + # Basic sanity checks + if param.ndim != loaded_weight.ndim: + raise ValueError( + f"Cannot load weight with different ndim: param.ndim={param.ndim}, " + f"loaded.ndim={loaded_weight.ndim}, param.shape={tuple(param.shape)}, " + f"loaded.shape={tuple(loaded_weight.shape)}" + ) + + dims = param.ndim + if dims not in (2, 3): + raise ValueError( + f"padding_weight_loader only supports 2D/3D tensors, got {dims}D. " + f"param.shape={tuple(param.shape)}, loaded.shape={tuple(loaded_weight.shape)}" + ) + + # For MoE tensors, dim0 is num_experts and must match. 
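+    # Behavior sketch (shapes are illustrative assumptions):
+    #   param [8, 256, 128], loaded [8, 192, 128] -> copies param[:, :192, :128]
+    #   param [4096, 11008], loaded [4096, 11000] -> copies param[:, :11000]
+    # Anything outside the copied slice keeps whatever padding the param was
+    # initialized with.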
+ if param.shape[0] != loaded_weight.shape[0]: + raise AssertionError( + f"Mismatch in number of experts: param={param.shape[0]}, loaded={loaded_weight.shape[0]}" + ) + + # Copy the overlapping region: [:, :min(dim1), :min(dim2)] + if dims == 2: + copy_d1 = min(param.shape[1], loaded_weight.shape[1]) + param.data[:, :copy_d1].copy_(loaded_weight[:, :copy_d1]) + return + + # dims == 3 + copy_d1 = min(param.shape[1], loaded_weight.shape[1]) + copy_d2 = min(param.shape[2], loaded_weight.shape[2]) + param.data[:, :copy_d1, :copy_d2].copy_( + loaded_weight[:, :copy_d1, :copy_d2] + ) diff --git a/vllm/model_executor/models/AXK1.py b/vllm/model_executor/models/AXK1.py new file mode 100644 index 0000000..f5ed440 --- /dev/null +++ b/vllm/model_executor/models/AXK1.py @@ -0,0 +1,1168 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only A.X K1 model.""" + +import typing +from collections.abc import Callable, Iterable +from itertools import islice + +import torch +from torch import nn + +from vllm._aiter_ops import rocm_aiter_ops +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.distributed import ( + get_ep_group, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, +) +from vllm.logger import init_logger +from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekAttention, + DeepseekV2MLP, + yarn_get_mscale, +) +from vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.AXK1 import AXK1Config + +from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP +from .utils import ( + PPMissingLayer, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class AXK1MLP(DeepseekV2MLP): + pass + + +class AXK1MoE(nn.Module): + def __init__( + self, + config: AXK1Config, + parallel_config: ParallelConfig, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + self.routed_scaling_factor = config.routed_scaling_factor + + self.ep_group = get_ep_group().device_group + self.ep_rank = get_ep_group().rank_in_group + self.ep_size = self.ep_group.size() + self.n_routed_experts: int = config.n_routed_experts + self.n_shared_experts: int = config.n_shared_experts + + self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + self.gate = ReplicatedLinear( + config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate", + ) + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts, dtype=torch.float32) + ) + else: + self.gate.e_score_correction_bias = None + + # Load balancing settings. 
+ eplb_config = parallel_config.eplb_config + self.enable_eplb = parallel_config.enable_eplb + + self.n_redundant_experts = eplb_config.num_redundant_experts + self.n_logical_experts = self.n_routed_experts + self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts + self.n_local_physical_experts = self.n_physical_experts // self.ep_size + + self.physical_expert_start = self.ep_rank * self.n_local_physical_experts + self.physical_expert_end = ( + self.physical_expert_start + self.n_local_physical_experts + ) + + self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled() + self.is_fusion_moe_shared_experts_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled: + self.shared_experts = None + else: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + + self.shared_experts = AXK1MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + is_sequence_parallel=self.is_sequence_parallel, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + + self.experts = SharedFusedMoE( + shared_experts=self.shared_experts, + gate=self.gate, + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + # we do scaling outside, set factor to 1.0 to avoid double mul + # aiter applies routed_scaling_factor internally + routed_scaling_factor=1.0 + if not self.is_rocm_aiter_moe_enabled + else self.routed_scaling_factor, + e_score_correction_bias=self.gate.e_score_correction_bias, + enable_eplb=self.enable_eplb, + num_redundant_experts=self.n_redundant_experts, + is_sequence_parallel=self.is_sequence_parallel, + n_shared_experts=config.n_shared_experts + if self.is_fusion_moe_shared_experts_enabled + else None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # Chunk the hidden states so they aren't replicated across TP ranks. + # This avoids duplicate computation in self.experts. + # TODO: We can replace the all_reduce at the end of attn with a + # reduce_scatter instead of chunking here. + if self.is_sequence_parallel: + hidden_states = sequence_parallel_chunk(hidden_states) + + if self.experts.is_internal_router: + # In this case, the gate/router runs inside the FusedMoE class + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=hidden_states + ) + else: + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + fused_moe_out = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + shared_output, final_hidden_states = fused_moe_out + if self.shared_experts is None: + assert shared_output is None + + # Fix FP16 overflow + # See AXK1DecoderLayer for more details. 
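+        # Illustrative bookkeeping, assuming routed_scaling_factor = 2.5:
+        # for bf16/fp32 the routed output is multiplied by 2.5 right here
+        # (unless the AITER fused path already applied it); for fp16 that
+        # multiply is skipped to avoid overflow and the shared-expert output
+        # is divided by 2.5 instead, keeping the routed/shared ratio intact.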
+ if hidden_states.dtype != torch.float16: + if not self.is_rocm_aiter_moe_enabled: + final_hidden_states *= self.routed_scaling_factor + elif self.shared_experts is not None: + assert shared_output is not None + shared_output *= 1.0 / self.routed_scaling_factor + + if self.shared_experts is not None: + assert shared_output is not None + final_hidden_states += shared_output + + if self.is_sequence_parallel: + final_hidden_states = tensor_model_parallel_all_gather( + final_hidden_states, 0 + ) + final_hidden_states = final_hidden_states[:num_tokens] + elif self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +def _get_llama_4_scaling( + original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor +) -> torch.Tensor: + scaling = 1 + scaling_beta * torch.log( + 1 + torch.floor(positions / original_max_position_embeddings) + ) + # Broadcast over num_heads and head_dim + return scaling[..., None, None] + + +class AXK1Attention(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + config: AXK1Config, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + max_position_embeddings: int = 8192, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + topk_indices_buffer: torch.Tensor | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + assert topk_indices_buffer is None, ( + "topk_indices_buffer is not \ + supported for AXK1Attention" + ) + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + # O projection. 
+ self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + if config.rope_parameters["rope_type"] != "default": + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + + self.rotary_emb = get_rope( + qk_rope_head_dim, + max_position=max_position_embeddings, + rope_parameters=config.rope_parameters, + is_neox_style=False, + ) + + if config.rope_parameters["rope_type"] == "deepseek_yarn": + mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) + scaling_factor = config.rope_parameters["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.attn = Attention( + self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view( + -1, self.num_local_heads, self.qk_head_dim + ) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank :] + + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + + q[..., self.qk_nope_head_dim :] = q_pe + k = torch.empty_like(q) + k[..., : self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim :] = k_pe + + # Apply llama 4 scaling if provided + if llama_4_scaling is not None: + q *= llama_4_scaling + + # padding value to qk_head_dim for alignment + v = torch.nn.functional.pad( + v, [0, self.qk_head_dim - self.v_head_dim], value=0 + ).view(-1, self.num_local_heads * self.qk_head_dim) + attn_output = self.attn(q, k, v) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[ + ..., : self.v_head_dim + ].reshape(-1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class AXK1MLAAttention(nn.Module): + """ + Main reference: DeepseekV2 paper, and FlashInfer Implementation + (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). 
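+
+    Scaling sketch (head dims and YaRN parameters below are illustrative
+    assumptions, not the A.X K1 defaults): with qk_nope_head_dim=128 and
+    qk_rope_head_dim=64, qk_head_dim=192 and the base softmax scale is
+    192 ** -0.5 ~= 0.0722; with rope_type="deepseek_yarn", factor=40 and
+    mscale_all_dim=1.0, yarn_get_mscale returns 0.1 * ln(40) + 1 ~= 1.369,
+    and the scale is further multiplied by mscale ** 2 ~= 1.874.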
+ + For more info see MLACommonImpl in: + vllm/v1/attention/backends/mla/utils.py + """ + + def __init__( + self, + vllm_config: VllmConfig, + config: AXK1Config, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int | None, + kv_lora_rank: int, + max_position_embeddings: int = 8192, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + topk_indices_buffer: torch.Tensor | None = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.fused_qkv_a_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True, + ) + else: + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + + if self.q_lora_rank is not None: + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + if config.rope_parameters["rope_type"] != "default": + config.rope_parameters["rope_type"] = ( + "deepseek_yarn" + if config.rope_parameters.get("apply_yarn_scaling", True) + else "deepseek_llama_scaling" + ) + + self.rotary_emb = get_rope( + qk_rope_head_dim, + max_position=max_position_embeddings, + rope_parameters=config.rope_parameters, + is_neox_style=False, + ) + + if config.rope_parameters["rope_type"] == "deepseek_yarn": + mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False) + scaling_factor = config.rope_parameters["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj + if self.q_lora_rank is not None + else None, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa + if self.q_lora_rank is None + else None, + q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not 
None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=None, + indexer_rotary_emb=None, + is_sparse=False, + topk_indices_buffer=topk_indices_buffer, + ) + + self.mla_attn = MultiHeadLatentAttentionWrapper( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + llama_4_scaling: torch.Tensor | None, + ) -> torch.Tensor: + return self.mla_attn(positions, hidden_states, llama_4_scaling) + + +class AXK1DecoderLayer(nn.Module): + def __init__( + self, + vllm_config: VllmConfig, + prefix: str, + config: AXK1Config | None = None, + ) -> None: + super().__init__() + + if config is None: + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + self.config = config + + self.hidden_size = config.hidden_size + max_position_embeddings = config.max_position_embeddings + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep=".")[-1]) + self.layer_idx = layer_idx + + # verify MLA attention specific fields + qk_nope_head_dim = config.qk_nope_head_dim + qk_rope_head_dim = config.qk_rope_head_dim + v_head_dim = config.v_head_dim + kv_lora_rank = config.kv_lora_rank + use_mha = all(dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)) + self.use_mha = use_mha + + if use_mha: + attn_cls = DeepseekAttention + elif model_config.use_mla: + attn_cls = AXK1MLAAttention + else: + attn_cls = AXK1Attention + self.self_attn = attn_cls( + vllm_config=vllm_config, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + q_lora_rank=config.q_lora_rank, + kv_lora_rank=kv_lora_rank, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + topk_indices_buffer=None, + ) + + self.is_layer_sparse = self._is_layer_sparse() + if self.is_layer_sparse: + self.mlp = AXK1MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = AXK1MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + + def _is_layer_sparse(self) -> bool: + return ( + self.config.n_routed_experts is not None + and self.layer_idx >= self.config.first_k_dense_replace + and self.layer_idx % self.config.moe_layer_freq == 0 + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: torch.Tensor | None, + llama_4_scaling: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if 
residual is None: + residual = hidden_states.clone() + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_kwargs = { + "positions": positions, + "hidden_states": hidden_states, + } + if not self.use_mha: + attn_kwargs["llama_4_scaling"] = llama_4_scaling + hidden_states = self.self_attn(**attn_kwargs) + + if ( + not isinstance(self.self_attn, DeepseekAttention) + and hidden_states.dtype == torch.float16 + ): + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1.0 / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1.0 / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + if self.is_layer_sparse: + hidden_states = self.post_mlp_layernorm(hidden_states) + + if isinstance(self.mlp, AXK1MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the AXK1MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of AXK1MOE output would be done in the forward + # of AXK1MOE + hidden_states *= 1.0 / self.routed_scaling_factor + + return hidden_states, residual + + +@support_torch_compile +class AXK1Model(nn.Module): + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config: AXK1Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.device = current_platform.device_type + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: AXK1DecoderLayer(vllm_config, prefix), + prefix=f"{prefix}.layers", + ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size + ) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Compute llama 4 scaling once per forward pass if enabled + llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None) + llama_4_scaling: torch.Tensor | None + if llama_4_scaling_config is not None: + llama_4_scaling = _get_llama_4_scaling( + original_max_position_embeddings=llama_4_scaling_config[ + "original_max_position_embeddings" + ], + 
scaling_beta=llama_4_scaling_config["beta"], + positions=positions, + ) + else: + llama_4_scaling = None + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer( + positions, hidden_states, residual, llama_4_scaling + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class AXK1MixtureOfExperts(MixtureOfExperts): + moe_mlp_layers: list[AXK1MoE] + """ + List of MoE MLP layers in the model. + """ + + def extract_moe_parameters(self, example_moe: AXK1MoE | None): + if example_moe is None: + self.num_moe_layers = 0 + self.num_expert_groups = 0 + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + logger.warning("AXK1: No AXK1MoE layer found in model.layers.") + else: + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + def update_physical_experts_metadata( + self, + num_physical_experts: int, + num_local_physical_experts: int, + ) -> None: + assert self.num_local_physical_experts == num_local_physical_experts + self.num_physical_experts = num_physical_experts + self.num_local_physical_experts = num_local_physical_experts + self.num_redundant_experts = num_physical_experts - self.num_logical_experts + for moe in self.moe_mlp_layers: + moe.n_local_physical_experts = num_local_physical_experts + moe.n_physical_experts = num_physical_experts + moe.n_redundant_experts = self.num_redundant_experts + moe.experts.update_expert_map() + + +class AXK1ForCausalLM( + nn.Module, SupportsPP, AXK1MixtureOfExperts, SupportsLoRA, SupportsEagle +): + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + model_cls = AXK1Model + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: AXK1Config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + qk_nope_head_dim = config.qk_nope_head_dim + qk_rope_head_dim = config.qk_rope_head_dim + self.use_mha = all(dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)) + + if self.use_mha: + self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + + # `packed_modules_mapping` needs to be modified before + # initializing AXK1Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. 
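        # A rough illustration of the resulting mapping, assuming an MLA
        # checkpoint with q_lora_rank set:
        #     {
        #         "gate_up_proj": ["gate_proj", "up_proj"],
        #         "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"],
        #     }
        # The MHA path adds "qkv_proj": ["q_proj", "k_proj", "v_proj"] instead.
        # Quantization configs consult this dict while AXK1Model is being
        # constructed, which is why it must be populated before the
        # self.model_cls(...) call below.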
+ self.fuse_qkv_a_proj = config.q_lora_rank is not None + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = self.model_cls( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + # Set MoE hyperparameters + self.num_moe_layers = ( + self.config.num_hidden_layers - self.config.first_k_dense_replace + ) + self.set_moe_parameters() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.num_expert_groups = getattr(self.config, "n_group", 1) + + self.moe_layers = [] + self.moe_mlp_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, AXK1DecoderLayer) + if isinstance(layer.mlp, AXK1MoE): + # Pick last one layer since the first ones may be dense layers. + example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) + self.moe_layers.append(layer.mlp.experts) + + self.extract_moe_parameters(example_moe) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=0, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + rocm_aiter_moe_shared_expert_enabled = ( + rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() + ) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + mla_params_mapping = [ + ("fused_qkv_a_proj", "q_a_proj", 0), + ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + ] + mha_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + if self.use_mha: + stacked_params_mapping.extend(mha_params_mapping) + else: + stacked_params_mapping.extend(mla_params_mapping) + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = SharedFusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts + + ( + self.config.n_shared_experts + if 
rocm_aiter_moe_shared_expert_enabled + else 0 + ), + num_redundant_experts=self.num_redundant_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + is_fusion_moe_shared_experts_layer = ( + rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) + ) + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + if is_fusion_moe_shared_experts_layer: + continue + name_mapped = name.replace(weight_name, param_name) + + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: + continue + else: + name = name_mapped + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False + + # Special handling: when AITER fusion_shared_experts is enabled, + # checkpoints may provide a single widened shared_experts tensor + # without explicit expert indices + # (e.g. ...mlp.shared_experts.gate_proj.weight). + # For models with multiple shared experts, split that tensor + # evenly into per-shared-expert slices and load them into + # appended expert slots mlp.experts.{n_routed_experts + j}.* + # accordingly. + num_chunks = 1 + if is_fusion_moe_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 + # Determine split axis based on op type + # gate/up: ColumnParallel → split along dim 0 + # down: RowParallel → split along dim 1 + split_dim = ( + 1 + if ("down_proj.weight" in name and loaded_weight.ndim > 1) + else 0 + ) + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} " + f"not divisible by num_chunks {num_chunks}" + ) + chunk_size = total // num_chunks + + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fusion_moe_shared_experts_layer: + chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size) + if loaded_weight.ndim == 1: + weight_to_load = loaded_weight[chunk_slice] + elif split_dim == 0: + weight_to_load = loaded_weight[chunk_slice, :] + else: + weight_to_load = loaded_weight[:, chunk_slice] + # Synthesize an expert-style name so expert mapping + # can route it + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + # Use expert_params_mapping to locate the destination + # param and delegate to its expert-aware weight_loader + # with expert_id. 
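                        # A hedged illustration of what one mapping entry may
                        # look like (the concrete strings are assumptions for
                        # illustration, not read from make_expert_params_mapping):
                        #     ("experts.w13_weight", "experts.5.gate_proj.", 5, "w1")
                        # i.e. the checkpoint tensor for expert 5's gate_proj is
                        # routed into the fused gate/up parameter at expert slot 5
                        # with shard id "w1"; down_proj weights go to the "w2" shard.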
+ for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = chunk_name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fusion_moe_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + if not is_fusion_moe_shared_experts_layer: + loaded_params.add(name) + + return loaded_params + + +def get_spec_layer_idx_from_weight_name( + config: AXK1Config, weight_name: str +) -> int | None: + if config.num_nextn_predict_layers and config.num_nextn_predict_layers > 0: + layer_idx = config.num_hidden_layers + for i in range(config.num_nextn_predict_layers): + if weight_name.startswith(f"model.layers.{layer_idx + i}."): + return layer_idx + i + return None diff --git a/vllm/model_executor/models/bailing_moe_linear.py b/vllm/model_executor/models/bailing_moe_linear.py new file mode 100644 index 0000000..9b54ec6 --- /dev/null +++ b/vllm/model_executor/models/bailing_moe_linear.py @@ -0,0 +1,1246 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy +from collections.abc import Iterable + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.configuration_utils import PretrainedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config +from vllm.distributed import ( + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.forward_context import get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.layers.fla.ops.layernorm_guard import ( + RMSNormGated, + layernorm_fn, +) +from vllm.model_executor.layers.fused_moe import FusedMoE, SharedFusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.abstract import MambaBase +from 
vllm.model_executor.layers.mamba.linear_attn import ( + MiniMaxText01LinearAttention, + MiniMaxText01LinearKernel, + MiniMaxText01RMSNormTP, + clear_linear_attention_cache_for_new_sequences, + linear_attention_decode, + linear_attention_prefill_and_mix, +) +from vllm.model_executor.layers.mamba.mamba_utils import ( + MambaStateCopyFuncCalculator, + MambaStateDtypeCalculator, + MambaStateShapeCalculator, +) +from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from vllm.model_executor.models.bailing_moe import BailingMLP +from vllm.sequence import IntermediateTensors +from vllm.v1.attention.backend import AttentionMetadata +from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata + +from .interfaces import HasInnerState, IsHybrid, SupportsPP +from .utils import ( + AutoWeightsLoader, + PPMissingLayer, + is_pp_missing_parameter, + make_layers, + maybe_prefix, +) + +logger = init_logger(__name__) + + +def is_linear_layer(layer_idx, layer_group_size): + if layer_idx is None: + return False + if layer_group_size > 0: + return (layer_idx + 1) % layer_group_size != 0 + else: + return False + + +def _build_rope_parameters(config: PretrainedConfig) -> dict | None: + rope_parameters = copy.deepcopy(getattr(config, "rope_parameters", None)) or {} + if "rope_theta" not in rope_parameters and hasattr(config, "rope_theta"): + rope_parameters["rope_theta"] = config.rope_theta + if "partial_rotary_factor" not in rope_parameters and hasattr( + config, "partial_rotary_factor" + ): + rope_parameters["partial_rotary_factor"] = config.partial_rotary_factor + + rope_scaling = getattr(config, "rope_scaling", None) + if isinstance(rope_scaling, dict): + rope_scaling = copy.deepcopy(rope_scaling) + if "type" in rope_scaling and "rope_type" not in rope_scaling: + rope_scaling["rope_type"] = rope_scaling.pop("type") + rope_parameters.update(rope_scaling) + + return rope_parameters or None + + +class BailingMoeV25MLAAttention(nn.Module): + """ + MLA Attention for BailingMoeV2.5 full attention layers. 
+ """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + layer_id: int = 0, + prefix: str = "attention", + cache_config: CacheConfig | None = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.layer_id = layer_id + self.prefix = prefix + + # MLA dimensions + self.qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 128) + self.qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 64) + self.qk_head_dim = self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim = getattr(config, "v_head_dim", 128) + + # LoRA ranks + self.q_lora_rank = getattr(config, "q_lora_rank", None) + self.kv_lora_rank = getattr(config, "kv_lora_rank", 512) + + tp_size = get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0 + self.num_local_heads = self.num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + + # KV projections + self.kv_a_layernorm = RMSNorm( + self.kv_lora_rank, + eps=config.rms_norm_eps, + ) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + ) + + # Output projection + self.o_proj = RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + if self.q_lora_rank is not None: + # Use fused_qkv_a_proj when q_lora_rank is set + self.fused_qkv_a_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.fused_qkv_a_proj", + disable_tp=True, + ) + self.q_a_layernorm = RMSNorm( + self.q_lora_rank, + eps=config.rms_norm_eps, + ) + self.q_b_proj = ColumnParallelLinear( + self.q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + ) + self.q_proj = None + self.kv_a_proj_with_mqa = None + else: + # Direct projections when no q_lora_rank + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + ) + self.fused_qkv_a_proj = None + self.q_a_layernorm = None + self.q_b_proj = None + + rope_parameters = _build_rope_parameters(config) + max_position = getattr(config, "max_position_embeddings", 8192) + self.rotary_emb = get_rope( + head_size=self.qk_rope_head_dim, + max_position=max_position, + is_neox_style=False, + rope_parameters=rope_parameters or None, + dtype=torch.float32, + ) + + # Build MLAModules for MultiHeadLatentAttentionWrapper + mla_modules = MLAModules( + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + rotary_emb=self.rotary_emb, + o_proj=self.o_proj, + fused_qkv_a_proj=self.fused_qkv_a_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + q_a_layernorm=self.q_a_layernorm, + q_b_proj=self.q_b_proj, + q_proj=self.q_proj, + indexer=None, + is_sparse=False, + topk_indices_buffer=None, + ) + + self.mla_attn = MultiHeadLatentAttentionWrapper( + self.hidden_size, + self.num_local_heads, + self.scaling, + self.qk_nope_head_dim, + self.qk_rope_head_dim, + self.v_head_dim, + 
self.q_lora_rank, + self.kv_lora_rank, + mla_modules, + cache_config, + quant_config, + prefix, + ) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + ) -> torch.Tensor: + """Forward pass for MLA attention.""" + return self.mla_attn(positions, hidden_states) + + +class BailingMoEGate(nn.Module): + def __init__( + self, + config: PretrainedConfig, + params_dtype: torch.dtype | None = None, + prefix: str = "", + ): + super().__init__() + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.weight = nn.Parameter( + torch.empty( + (config.num_experts, config.hidden_size), + dtype=self.params_dtype, + ), + ) + if getattr(config, "moe_router_enable_expert_bias", False): + self.expert_bias = nn.Parameter( + torch.empty((config.num_experts,), dtype=torch.float32), + ) + else: + self.expert_bias = None + + def forward(self, hidden_states): + logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to( + hidden_states.dtype + ) + return logits + + +class BailingMoeV25(nn.Module): + """Bailing MoE v2.5 - standalone implementation for linear attention model.""" + + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + layer_id: int = 0, + prefix: str = "", + ): + super().__init__() + + self.layer_id = layer_id + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + norm_topk_prob = getattr(config, "norm_topk_prob", None) + # Ring-2.5 reference implementations normalize routing weights by default. + self.norm_expert_prob = True if norm_topk_prob is None else bool(norm_topk_prob) + self.hidden_size = config.hidden_size + self.quant_config = quant_config + self.num_shared_experts = config.num_shared_experts + self.score_function = getattr(config, "score_function", None) + self.n_group = getattr(config, "n_group", None) + self.topk_group = getattr(config, "topk_group", None) + self.use_grouped_topk = self.n_group is not None and self.topk_group is not None + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) + + router_dtype = getattr(config, "router_dtype", None) + if router_dtype is None or router_dtype == "fp32": + self.router_dtype = torch.float32 + else: + self.router_dtype = torch.bfloat16 + + # Gate for routing + self.gate = BailingMoEGate( + config=config, + params_dtype=self.router_dtype, + prefix=f"{prefix}.gate", + ) + correction_bias = ( + self.gate.expert_bias if self.gate.expert_bias is not None else None + ) + if self.score_function is not None: + assert (self.score_function == "softmax" and correction_bias is None) or ( + self.score_function == "sigmoid" and correction_bias is not None + ), ( + "score_function and correction_bias should be " + "(softmax, None) or (sigmoid, not None)" + ) + + # Shared experts (using BailingMLP) + if self.num_shared_experts > 0: + if hasattr(config, "moe_shared_expert_intermediate_size"): + intermediate_size = config.moe_shared_expert_intermediate_size + else: + intermediate_size = config.moe_intermediate_size + intermediate_size *= config.num_shared_experts + self.shared_experts = BailingMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None + + # Routed experts using SharedFusedMoE + self.experts = SharedFusedMoE( 
+ shared_experts=self.shared_experts, + num_experts=self.num_experts, + top_k=self.top_k, + hidden_size=self.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=self.norm_expert_prob, + quant_config=quant_config, + prefix=f"{prefix}.experts", + scoring_func=self.score_function, + e_score_correction_bias=correction_bias, + num_expert_group=self.n_group, + topk_group=self.topk_group, + use_grouped_topk=self.use_grouped_topk, + router_logits_dtype=self.router_dtype, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + # Ensure contiguous token-major layout before router/projections. + hidden_states = hidden_states.contiguous().view(-1, hidden_size) + + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states.to(self.router_dtype)) + router_logits = router_logits.to(hidden_states.dtype) + + final_hidden_states = self.experts( + hidden_states=hidden_states, router_logits=router_logits + ) + + # Handle tuple return from SharedFusedMoE + if self.shared_experts is not None: + shared_output, final_hidden_states = final_hidden_states + else: + shared_output = None + + final_hidden_states *= self.routed_scaling_factor + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1: + final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( + final_hidden_states + ) + + return final_hidden_states.view(num_tokens, hidden_size) + + +BailingRMSNormTP = MiniMaxText01RMSNormTP + + +class BailingGroupRMSNormGate(RMSNormGated): + def __init__( + self, + hidden_size, + eps=1e-5, + group_size=None, + norm_before_gate=True, + device=None, + dtype=None, + ): + super().__init__( + hidden_size, + eps=eps, + group_size=group_size, + norm_before_gate=norm_before_gate, + device=device, + dtype=dtype, + activation="sigmoid", + ) + # Add custom weight loader for TP sharding + self.weight.weight_loader = self._weight_loader + + @staticmethod + def _weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor) -> None: + """Load weight with TP sharding.""" + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + shard_size = loaded_weight.shape[0] // tp_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + param.data.copy_(loaded_weight[shard].contiguous()) + + +class BailingMoELinearAttention(nn.Module, MambaBase): + """ + Bailing MoE Linear Attention implementation using minimax backend. + + This implements the linear attention mechanism from sglang, adapted for vLLM's + v1 engine with MambaBase interface support. + """ + + @property + def mamba_type(self) -> str: + return "linear_attention" + + def get_state_shape(self) -> tuple[tuple[int, ...], ...]: + """Return state shape for linear attention cache. + + Must match the calculation in get_mamba_state_shape_from_config. + """ + return MambaStateShapeCalculator.linear_attention_state_shape( + num_heads=self.total_num_heads, + tp_size=self.tp_size, + head_dim=self.head_dim, + ) + + def get_state_dtype(self) -> tuple[torch.dtype, ...]: + """Return state dtype for linear attention cache. + + Must match the calculation in get_mamba_state_dtype_from_config. 
+ """ + return MambaStateDtypeCalculator.linear_attention_state_dtype( + self.model_config.dtype, + self.cache_config.mamba_cache_dtype, + ) + + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + layer_id: int = 0, + prefix: str = "linear_attn", + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + ): + super().__init__() + + self.layer_id = layer_id + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.total_kv_heads = config.num_attention_heads # MHA + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.model_config = model_config + self.cache_config = cache_config + self.prefix = prefix + + self.head_dim = ( + config.head_dim + if hasattr(config, "head_dim") + else config.hidden_size // self.total_num_heads + ) + + self.hidden_inner_size = self.head_dim * self.total_num_heads + self.scaling = self.head_dim**-0.5 + + assert self.total_num_heads % self.tp_size == 0 + self.tp_heads = self.total_num_heads // self.tp_size + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = getattr(config, "rope_theta", 600000) + + self.tp_kv_heads = self.total_kv_heads // self.tp_size + self.q_size_per_rank = self.head_dim * self.tp_heads + self.kv_size_per_rank = self.head_dim * self.tp_kv_heads + + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.linear_backend = "minimax" + self.linear_scale = self.linear_backend == "minimax" + self.linear_rope = getattr(config, "linear_rope", True) + if hasattr(config, "use_linear_silu"): + self.linear_silu = config.use_linear_silu + elif hasattr(config, "linear_silu"): + self.linear_silu = config.linear_silu + else: + self.linear_silu = False + + # Block size for lightning attention + self.BLOCK = getattr(config, "block", 256) + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, # MHA: kv_heads = num_heads + bias=(config.use_bias or config.use_qkv_bias), + quant_config=quant_config, + prefix=f"{prefix}.query_key_value", + ) + + if self.use_qk_norm: + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self.g_proj = ColumnParallelLinear( + self.hidden_size, + self.hidden_inner_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.g_proj", + ) + self.dense = RowParallelLinear( + self.hidden_inner_size, + self.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + prefix=f"{prefix}.dense", + reduce_results=True, + ) + + self.group_norm_size = getattr(config, "group_norm_size", 1) + self.rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5)) + assert self.tp_size <= self.group_norm_size, ( + "tp_size must be <= group_norm_size for local rms norm" + ) + assert self.group_norm_size % self.tp_size == 0, ( + "group_norm_size must be divisible by tp_size" + ) + + # When group_norm_size == 1, group_size equals hidden_size // tp_size + self.g_norm = BailingGroupRMSNormGate( + hidden_size=self.hidden_inner_size // self.tp_size, + eps=self.rms_norm_eps, + group_size=( + self.hidden_inner_size // self.group_norm_size + if self.group_norm_size > 1 + else self.hidden_inner_size // self.tp_size + ), + ) + + # use fp32 rotary embedding + rope_parameters = _build_rope_parameters(config) + + self.rotary_emb = get_rope( + self.head_dim, + 
max_position=self.max_position_embeddings, + is_neox_style=True, + dtype=torch.float32, + rope_parameters=rope_parameters or None, + ) + + # Build slope tensor for linear attention decay + num_hidden_layers = config.num_hidden_layers + slope_rate = MiniMaxText01LinearAttention._build_slope_tensor( + self.total_num_heads + ) + if num_hidden_layers <= 1: + self.slope_rate = slope_rate * (1 + 1e-5) + else: + self.slope_rate = slope_rate * ( + 1 - layer_id / (num_hidden_layers - 1) + 1e-5 + ) + self.tp_slope = self.slope_rate[ + self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads + ].contiguous() + + # Register for compilation + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + @staticmethod + def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + """Load weight for linear attention layers. + + For FP8 quantized parameters, we need to use the weight_loader if available, + as it handles special cases like tensor parallelism sharding. + """ + # Check if param has a weight_loader (for vLLM ModelWeightParameter) + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is not None: + # Use the weight_loader which handles TP sharding and quantization + weight_loader(param, loaded_weight) + else: + # Fall back to direct copy for standard tensors + assert param.size() == loaded_weight.size(), ( + f"Shape mismatch: {param.shape} vs {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + ) -> None: + """Forward method called by torch.ops.vllm.linear_attention""" + torch.ops.vllm.linear_attention( + hidden_states, + output, + positions, + self.prefix, + ) + + def _forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + positions: torch.Tensor, + ) -> None: + """Actual forward implementation.""" + forward_context = get_forward_context() + attn_metadata: AttentionMetadata = forward_context.attn_metadata + if attn_metadata is not None: + assert isinstance(attn_metadata, dict) + attn_metadata = attn_metadata[self.prefix] + assert isinstance(attn_metadata, LinearAttentionMetadata) + num_actual_tokens = ( + attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens + ) + else: + num_actual_tokens = hidden_states.shape[0] + + # QKV projection + qkv, _ = self.query_key_value(hidden_states[:num_actual_tokens]) + + # use rotary_emb support fp32 + qkv = qkv.to(torch.float32) + if self.linear_silu: + qkv = F.silu(qkv) + + # Split q, k, v + q, k, v = torch.split( + qkv, + [self.q_size_per_rank, self.kv_size_per_rank, self.kv_size_per_rank], + dim=-1, + ) + + # Apply QK norm if needed + if self.use_qk_norm: + q = q.reshape(-1, self.tp_heads, self.head_dim) + k = k.reshape(-1, self.tp_kv_heads, self.head_dim) + q = layernorm_fn( + q, + self.query_layernorm.weight.data, + bias=None, + eps=self.rms_norm_eps, + is_rms_norm=True, + ) + k = layernorm_fn( + k, + self.key_layernorm.weight.data, + bias=None, + eps=self.rms_norm_eps, + is_rms_norm=True, + ) + q = q.reshape(-1, self.q_size_per_rank) + k = k.reshape(-1, self.kv_size_per_rank) + + # Apply rotary embeddings + if self.linear_rope: + q, k = self.rotary_emb(positions[:num_actual_tokens], q, k) + + # Reshape to [batch, heads, seq_len, head_dim] + q = q.view((qkv.shape[0], 
self.tp_heads, self.head_dim)) + k = k.view((qkv.shape[0], self.tp_kv_heads, self.head_dim)) + v = v.view((qkv.shape[0], self.tp_kv_heads, self.head_dim)) + + # Apply scaling if using minimax backend + if self.linear_scale: + q = q * self.scaling + + # Get KV cache and state indices + if attn_metadata is not None: + kv_cache = self.kv_cache[forward_context.virtual_engine][0] + state_indices_tensor = attn_metadata.state_indices_tensor + clear_linear_attention_cache_for_new_sequences( + kv_cache, state_indices_tensor, attn_metadata + ) + + # Compute attention + decode_only = getattr(attn_metadata, "num_prefills", 0) == 0 + if attn_metadata is None: + hidden = torch.empty( + (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype + ) + else: + if not decode_only: + hidden = self._prefill_and_mix_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + else: + hidden = self._decode_infer( + q, k, v, kv_cache, state_indices_tensor, attn_metadata + ) + + # Apply group norm and gate (matching SGLang behavior) + gate, _ = self.g_proj(hidden_states[:num_actual_tokens]) + + if self.group_norm_size > 1: + hidden = self.g_norm(hidden, gate) + else: + hidden = self.g_norm(hidden) + hidden = F.sigmoid(gate) * hidden + + hidden = hidden.to(hidden_states.dtype) + + # Output projection + dense_out, _ = self.dense(hidden) + output[:num_actual_tokens] = dense_out + + def _prefill_and_mix_infer( + self, q, k, v, kv_cache, state_indices_tensor, attn_metadata + ): + """Handle prefill (mixed with decode if any).""" + return linear_attention_prefill_and_mix( + q=q, + k=k, + v=v, + kv_cache=kv_cache, + state_indices_tensor=state_indices_tensor, + attn_metadata=attn_metadata, + slope_rate=self.tp_slope, + block_size=self.BLOCK, + decode_fn=self._decode_infer, + prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix, + layer_idx=self.layer_id, + ) + + def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata): + """Handle decode (single token per sequence).""" + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_prefills = attn_metadata.num_prefills + hidden = linear_attention_decode( + q, + k, + v, + kv_cache, + self.tp_slope, + state_indices_tensor, + q_start=num_prefill_tokens, + q_end=None, + slot_start=num_prefills, + slot_end=None, + block_size=32, + ) + return hidden + + +class BailingMoeV25DecoderLayer(nn.Module): + """Decoder layer supporting both linear and full attention.""" + + def __init__( + self, + config: PretrainedConfig, + quant_config: QuantizationConfig | None = None, + layer_id: int = 0, + prefix: str = "layer", + model_config: ModelConfig | None = None, + cache_config: CacheConfig | None = None, + ) -> None: + super().__init__() + self.layer_id = layer_id + self.hidden_size = config.hidden_size + + # Determine attention type (0 = linear, 1 = full) + self.attention_type = getattr(config, "attention_type", 1) + + if self.attention_type == 0: # Linear attention + self.self_attn = BailingMoELinearAttention( + config, + quant_config=quant_config, + layer_id=layer_id, + prefix=f"{prefix}.self_attn", + model_config=model_config, + cache_config=cache_config, + ) + else: # Full attention + self.self_attn = BailingMoeV25MLAAttention( + config, + quant_config=quant_config, + layer_id=layer_id, + prefix=f"{prefix}.self_attn", + cache_config=cache_config, + ) + + # MLP/MoE + is_moe_layer = config.num_experts > 1 and layer_id >= getattr( + config, "first_k_dense_replace", 0 + ) + + if is_moe_layer: + self.mlp = BailingMoeV25( + config, + 
quant_config=quant_config, + layer_id=layer_id, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = BailingMLP( + intermediate_size=config.intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=True, + prefix=f"{prefix}.mlp", + ) + + # Layer norms + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-5)) + self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, + residual: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + # Input layernorm + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + # Self attention + if self.attention_type == 0: + # Linear attention uses output tensor + self_attention_output = torch.zeros_like(hidden_states) + self.self_attn( + hidden_states=hidden_states, + output=self_attention_output, + positions=positions, + ) + else: + # Full attention + self_attention_output = self.self_attn(hidden_states, positions) + + hidden_states, residual = self.post_attention_layernorm( + self_attention_output, residual + ) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + } +) +class BailingMoeV25Model(nn.Module): + """Bailing MoE v2.5 Model with hybrid attention support.""" + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + quant_config = vllm_config.quant_config + cache_config = vllm_config.cache_config + + self.config = config + self.vocab_size = config.vocab_size + self.embed_dim = config.hidden_size + + # Determine layer types based on layer_group_size + self.layer_group_size = getattr(config, "layer_group_size", 1) + self.num_layers = config.num_hidden_layers + + # decoder_attention_types: 0 = linear, 1 = full + self.decoder_attention_types = [ + 0 if is_linear_layer(i, self.layer_group_size) else 1 + for i in range(self.num_layers) + ] + + # Embeddings + if get_pp_group().is_first_rank: + self.word_embeddings = VocabParallelEmbedding( + self.vocab_size, + self.embed_dim, + org_num_embeddings=self.vocab_size, + ) + else: + from vllm.model_executor.models.utils import PPMissingLayer + + self.word_embeddings = PPMissingLayer() + + # Layers + def layer_fn(prefix): + layer_idx = int(prefix.split(".")[-1]) + layer_config = copy.deepcopy(config) + layer_config.attention_type = self.decoder_attention_types[layer_idx] + + return BailingMoeV25DecoderLayer( + config=layer_config, + quant_config=quant_config, + layer_id=layer_idx, + prefix=prefix, + model_config=model_config, + cache_config=cache_config, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + self.num_layers, layer_fn, prefix=f"{prefix}.layers" + ) + + # Final norm + norm_kwargs = {} + if hasattr(config, "rms_norm_eps"): + norm_kwargs["eps"] = config.rms_norm_eps + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, **norm_kwargs) + else: + from vllm.model_executor.models.utils import PPMissingLayer + + self.norm = PPMissingLayer() + + def embed_input_ids(self, input_ids: 
torch.Tensor) -> torch.Tensor: + return self.word_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + forward_context = get_forward_context() + attn_metadata = forward_context.attn_metadata + + if get_pp_group().is_first_rank: + if inputs_embeds is None: + hidden_states = self.word_embeddings(input_ids) + else: + hidden_states = inputs_embeds + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in self.layers[self.start_layer : self.end_layer]: + hidden_states, residual = layer( + hidden_states=hidden_states, + positions=positions, + attn_metadata=attn_metadata, + residual=residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + else: + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + return hidden_states + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + """Get expert parameter mapping for MoE layers.""" + return FusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + num_redundant_experts=0, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load checkpoint weights with simplified mapping.""" + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + # Stacked parameter mappings (fused projections) + stacked_mappings = [ + (".fused_qkv_a_proj", ".q_a_proj", 0), + (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + # Expert parameter mappings from FusedMoE + expert_mappings = list(self.get_expert_mapping()) + + def load_param(name: str, tensor: torch.Tensor, shard_id=None) -> bool: + """Load a single parameter.""" + if name not in params_dict or is_pp_missing_parameter(name, self): + return False + if name.endswith(".bias") and name not in params_dict: + return False + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + + if shard_id is None: + weight_loader(param, tensor) + elif isinstance(shard_id, int): + weight_loader(param, tensor, shard_id) + else: + # Expert param: (expert_id, shard_id) + weight_loader( + param, tensor, name, expert_id=shard_id[0], shard_id=shard_id[1] + ) + + loaded_params.add(name) + return True + + def normalize_name(name: str) -> str | None: + """Normalize checkpoint name to model parameter name.""" + # Skip special weights + if name.startswith("model.mtp"): + return None + # Remove 'model.' prefix if present + # (e.g., 'model.layers.0...' -> 'layers.0...') + name = name.removeprefix("model.") + # Map attention.dense based on layer type + if "attention.dense" in name: + layer_idx = ( + int(name.split("layers.")[1].split(".")[0]) + if "layers." 
in name + else 0 + ) + attn_name = ( + "self_attn.dense" + if is_linear_layer(layer_idx, self.config.layer_group_size) + else "self_attn.o_proj" + ) + name = name.replace("attention.dense", attn_name) + + # Standard mappings + name = name.replace("attention.", "self_attn.") + name = name.replace( + "mlp.gate.e_score_correction_bias", "mlp.gate.expert_bias" + ) + + return maybe_remap_kv_scale_name(name, params_dict) + + for orig_name, weight in weights: + norm_name = normalize_name(orig_name) + if norm_name is None: + continue + + # Try stacked mappings + loaded = False + for param_suf, weight_suf, shard_id in stacked_mappings: + if weight_suf not in norm_name: + continue + mapped = norm_name.replace(weight_suf, param_suf).replace( + "attention.", "self_attn." + ) + if load_param(mapped, weight, shard_id): + loaded = True + break + if loaded: + continue + + # Handle expert weights + if "mlp.experts" in norm_name: + # Expert bias + if ( + "mlp.experts.e_score_correction_bias" in norm_name + or "mlp.experts.expert_bias" in norm_name + ): + alt = norm_name.replace( + "mlp.experts.e_score_correction_bias", "mlp.gate.expert_bias" + ).replace("mlp.experts.expert_bias", "mlp.gate.expert_bias") + if load_param(alt, weight) or load_param(norm_name, weight): + continue + + # Routed experts + for param_name, weight_name, expert_id, shard_id in expert_mappings: + if weight_name not in norm_name: + continue + mapped = norm_name.replace(weight_name, param_name) + if load_param(mapped, weight, (expert_id, shard_id)): + break + continue + + # General parameters + load_param(norm_name, weight) + + return loaded_params + + +class BailingMoeV25ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsPP): + """Bailing MoE v2.5 For CausalLM.""" + + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ) -> None: + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.config = config + self.quant_config = quant_config + + self.model = BailingMoeV25Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + else: + self.lm_head = PPMissingLayer() + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.logits_processor(self.lm_head, hidden_states) + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, device: torch.device + ) -> IntermediateTensors: + return IntermediateTensors( + { + "hidden_states": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + "residual": torch.zeros( + (batch_size, self.config.hidden_size), dtype=dtype, device=device + ), + } + ) + + @classmethod + def get_mamba_state_shape_from_config( + cls, + 
vllm_config: VllmConfig, + ) -> tuple[tuple[int, ...], ...]: + """Calculate shape for linear attention cache.""" + config = vllm_config.model_config.hf_config + tp_size = vllm_config.parallel_config.tensor_parallel_size + + head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + + # Return base state shape from linear attention (no padding) + return MambaStateShapeCalculator.linear_attention_state_shape( + num_heads=config.num_attention_heads, + tp_size=tp_size, + head_dim=head_dim, + ) + + @classmethod + def get_mamba_state_dtype_from_config( + cls, + vllm_config: VllmConfig, + ) -> tuple[torch.dtype, ...]: + return MambaStateDtypeCalculator.linear_attention_state_dtype( + vllm_config.model_config.dtype, + vllm_config.cache_config.mamba_cache_dtype, + ) + + @classmethod + def get_mamba_state_copy_func(cls) -> tuple: + return MambaStateCopyFuncCalculator.linear_attention_state_copy_func() + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights) + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + return self.model.get_expert_mapping() diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 27cf3a7..1f326ce 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -112,6 +112,42 @@ class LlamaBidirectionalConfig(VerifyAndUpdateConfig): model_config.pooler_config.seq_pooling_type = pooling_type +class LlamaNemotronVLConfig(VerifyAndUpdateConfig): + """Config handler for LlamaNemotronVL embedding models.""" + + @staticmethod + def verify_and_update_model_config(model_config: "ModelConfig") -> None: + from vllm.config.pooler import SequencePoolingType + + hf_config = model_config.hf_config + + # Set bidirectional attention on the language model config + hf_config.is_causal = False + if hasattr(hf_config, "llm_config"): + hf_config.llm_config.is_causal = False + + if hasattr(hf_config, "vision_config"): + hf_config.patch_size = hf_config.vision_config.patch_size + + # Set up pooling type + pooling_type_map: dict[str, SequencePoolingType] = { + "avg": "MEAN", + "cls": "CLS", + "last": "LAST", + } + + # Get pooling type from config (check both top-level and llm_config) + pooling = getattr(hf_config, "pooling", None) + if pooling is None and hasattr(hf_config, "llm_config"): + pooling = getattr(hf_config.llm_config, "pooling", "avg") + + pooling_type = pooling_type_map.get(pooling) + if pooling_type is None: + raise ValueError(f"pool_type {pooling!r} not supported") + + model_config.pooler_config.seq_pooling_type = pooling_type + + class NomicBertModelConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_model_config(model_config: "ModelConfig") -> None: @@ -177,7 +213,7 @@ class NomicBertModelConfig(VerifyAndUpdateConfig): "Nomic context extension is disabled. " "Changing max_model_len from %s to %s. 
" "To enable context extension, see: " - "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html", + "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.py", max_model_len_before, model_config.max_model_len, ) @@ -293,6 +329,14 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): } +class Ernie4_5_VLMoeForConditionalGenerationConfig(VerifyAndUpdateConfig): + @staticmethod + def verify_and_update_config(vllm_config: "VllmConfig") -> None: + # Ernie4.5-VL conditionally executes text/vision MoE branches, so + # fast_moe_cold_start can silently produce incorrect execution order. + vllm_config.compilation_config.fast_moe_cold_start = False + + class GptOssForCausalLMConfig(VerifyAndUpdateConfig): @staticmethod def verify_and_update_config(vllm_config: "VllmConfig") -> None: @@ -553,7 +597,7 @@ class DeepseekV32ForCausalLM(VerifyAndUpdateConfig): if cache_config.cache_dtype.startswith("fp8"): cache_config.cache_dtype = "fp8_ds_mla" logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") - if cache_config.cache_dtype == "bfloat16": + if cache_config.cache_dtype == "auto" or cache_config.cache_dtype == "bfloat16": cache_config.cache_dtype = "auto" logger.info("Using bfloat16 kv-cache for DeepSeekV3.2") @@ -619,11 +663,14 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "Gemma3TextModel": Gemma3TextModelConfig, "LlamaBidirectionalForSequenceClassification": LlamaBidirectionalConfig, "LlamaBidirectionalModel": LlamaBidirectionalConfig, + "LlamaNemotronVLModel": LlamaNemotronVLConfig, + "LlamaNemotronVLForSequenceClassification": LlamaNemotronVLConfig, "NomicBertModel": NomicBertModelConfig, "Qwen2ForProcessRewardModel": Qwen2ForProcessRewardModelConfig, "Qwen2ForRewardModel": Qwen2ForRewardModelConfig, "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig, "Qwen3VLForSequenceClassification": Qwen3VLForSequenceClassificationConfig, + "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501 "XLMRobertaModel": JinaRobertaModelConfig, "ColBERTJinaRobertaModel": JinaRobertaModelConfig, "JinaVLForRanking": JinaVLForSequenceClassificationConfig, diff --git a/vllm/model_executor/models/deepencoder.py b/vllm/model_executor/models/deepencoder.py index f7ae426..68c1014 100644 --- a/vllm/model_executor/models/deepencoder.py +++ b/vllm/model_executor/models/deepencoder.py @@ -18,6 +18,7 @@ import torch.nn as nn import torch.nn.functional as F from transformers import CLIPVisionConfig +from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention import MMEncoderAttention from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.quantization import QuantizationConfig @@ -263,9 +264,13 @@ class Block(nn.Module): return x -class RelPosAttention(nn.Module): +# --8<-- [start:rel_pos_attention] +@PluggableLayer.register("rel_pos_attention") +class RelPosAttention(PluggableLayer): """Multi-head Attention block with relative position embeddings.""" + # --8<-- [end:rel_pos_attention] + def __init__( self, dim: int, diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 182828c..8dfb679 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -32,6 +32,7 @@ from .deepseek_v2 import ( DeepseekV2MoE, get_spec_layer_idx_from_weight_name, ) +from .interfaces import SupportsPP from 
.utils import maybe_prefix logger = init_logger(__name__) @@ -180,7 +181,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): @support_torch_compile -class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): +class DeepSeekMTP(nn.Module, SupportsPP, DeepseekV2MixtureOfExperts): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() self.config = vllm_config.model_config.hf_config @@ -415,8 +416,141 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts): weight_loader(param, loaded_weight) if not is_fusion_moe_shared_experts_layer: loaded_params.add(name) + + # Validate that weights were loaded for each expected MTP layer. + loaded_layers: set[int] = set() + for param_name in loaded_params: + spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name) + if spec_layer is not None: + loaded_layers.add(spec_layer) + for layer_idx in range( + self.model.mtp_start_layer_idx, + self.model.mtp_start_layer_idx + self.model.num_mtp_layers, + ): + if layer_idx not in loaded_layers: + raise ValueError( + f"MTP speculative decoding layer {layer_idx} weights " + f"missing from checkpoint. The checkpoint may have " + f"been quantized without including the MTP layers. " + f"Use a checkpoint that includes MTP layer weights, " + f"or disable speculative decoding." + ) + + # Post-load optimization: fuse q_a_proj and kv_a_proj_with_mqa + # into a single GEMM, then monkey-patch forward to forward_opt. + # Same logic as DeepseekV2ForCausalLM.load_weights. + opt_support_quant_method = [ + "GGUFLinearMethod", "UnquantizedLinearMethod", + "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod", + ] + + def inject_layer(layer, quant_method, is_mla): + logger.info( + "DeepSeekMTP optimization: fused q_a_proj and kv_a_proj_with_mqa for layer '%s' (quant_method=%s, is_mla=%s). 
Forward replaced with forward_opt.", + layer.__class__.__name__, quant_method, is_mla) + q_lora_rank = getattr(layer, "q_lora_rank", None) + if quant_method in ["UnquantizedLinearMethod", + "CompressedTensorsW8A8Int8"]: + if q_lora_rank is not None: + layer.q_a_proj.weight.data = torch.cat( + [layer.q_a_proj.weight, + layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_a_proj, "weight_scale"): + layer.q_a_proj.weight_scale.data = torch.cat( + [layer.q_a_proj.weight_scale, + layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + elif not is_mla: + layer.q_proj.weight.data = torch.cat( + [layer.q_proj.weight, + layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_proj, "weight_scale"): + layer.q_proj.weight_scale.data = torch.cat( + [layer.q_proj.weight_scale, + layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + else: + return + del layer.kv_a_proj_with_mqa.weight + del layer.kv_a_proj_with_mqa + if is_mla: + layer.mla_attn.forward = layer.mla_attn.forward_opt + else: + layer.forward = layer.forward_opt + elif quant_method == "GGUFLinearMethod": + pass + elif quant_method == "AWQMarlinLinearMethod": + dtype = layer.kv_a_proj_with_mqa.qweight.dtype + assert dtype == torch.int32 + if q_lora_rank is not None: + layer.q_a_proj.qweight.data = torch.cat( + [layer.q_a_proj.qweight, + layer.kv_a_proj_with_mqa.qweight], dim=1) + layer.q_a_proj.scales.data = torch.cat( + [layer.q_a_proj.scales, + layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_a_proj.qzeros.data = torch.cat( + [layer.q_a_proj.qzeros, + layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + elif not is_mla: + layer.q_proj.weight.data = torch.cat( + [layer.q_proj.weight, + layer.kv_a_proj_with_mqa.weight], dim=1) + layer.q_proj.scales.data = torch.cat( + [layer.q_proj.scales, + layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_proj.qzeros.data = torch.cat( + [layer.q_proj.qzeros, + layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + else: + return + del layer.kv_a_proj_with_mqa.qweight + del layer.kv_a_proj_with_mqa + if is_mla: + layer.mla_attn.forward = layer.mla_attn.forward_opt + else: + layer.forward = layer.forward_opt + + for _, layer in self.model.named_modules(): + if layer.__class__.__name__ in [ + "DeepseekV2Attention", "DeepseekV2MLAAttention" + ]: + if hasattr(layer.kv_a_proj_with_mqa, "scheme"): + quant_method = ( + layer.kv_a_proj_with_mqa.scheme.__class__.__name__) + else: + quant_method = ( + layer.kv_a_proj_with_mqa + .quant_method.__class__.__name__) + if quant_method not in opt_support_quant_method: + break + inject_layer( + layer, quant_method, + is_mla=(layer.__class__.__name__ + == "DeepseekV2MLAAttention")) + + # Check if all parameters have been loaded + all_params = set(params_dict.keys()) + not_loaded = all_params - loaded_params + if not_loaded: + logger.warning( + "DeepSeekMTP weight loading: %d parameters were NOT loaded.\n%s", + len(not_loaded), + "\n".join(sorted(not_loaded)), + ) + else: + logger.info( + "DeepSeekMTP weight loading: All %d parameters loaded successfully.", + len(all_params), + ) + return loaded_params + def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str: """ Rewrite the weight name to match the format of the original model. 
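The post-load fusion added above (concatenating the q_a_proj and kv_a_proj_with_mqa weights inside inject_layer and switching to forward_opt) rests on a simple identity: stacking two projection matrices along their output dimension and running a single GEMM gives the same result as running the two projections separately and concatenating the outputs, so two reads of the same activations collapse into one. A minimal, self-contained sketch of that identity with toy dimensions; the names (hidden_size, w_q_a, w_kv_a) are illustrative stand-ins, not symbols from the patch:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, q_lora_rank, kv_rank_plus_rope = 16, 6, 10
x = torch.randn(4, hidden_size)                       # toy batch of hidden states
w_q_a = torch.randn(q_lora_rank, hidden_size)         # stand-in for q_a_proj.weight
w_kv_a = torch.randn(kv_rank_plus_rope, hidden_size)  # stand-in for kv_a_proj_with_mqa.weight

# Unfused path: two separate GEMMs over the same activations.
q_a = F.linear(x, w_q_a)
kv_a = F.linear(x, w_kv_a)

# Fused path: concatenate weights along the output dim (dim=0), run one GEMM,
# then split the result back into the q_a and kv_a pieces.
w_fused = torch.cat([w_q_a, w_kv_a], dim=0)
q_a_fused, kv_a_fused = F.linear(x, w_fused).split(
    [q_lora_rank, kv_rank_plus_rope], dim=-1
)

assert torch.allclose(q_a, q_a_fused)
assert torch.allclose(kv_a, kv_a_fused)

The quantized branches in inject_layer apply the same idea to qweight/scales/qzeros, with the concatenation axis chosen to match each format's layout.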
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index cb2cf19..5a6547f 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -47,7 +47,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.fused_moe import SharedFusedMoE +from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -75,7 +75,9 @@ from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.backend import AttentionBackend +from vllm.utils.math_utils import cdiv from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerBackend, ) @@ -89,6 +91,7 @@ from .utils import ( make_layers, maybe_prefix, ) +import ixformer.inference.functions as ixfops logger = init_logger(__name__) @@ -221,73 +224,6 @@ class DeepseekV2MLP(nn.Module): return x -class DeepSeekV2Gate(ReplicatedLinear): - def __init__( - self, - hidden_size: int, - n_experts: int, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - ): - assert quant_config is None - super().__init__( - hidden_size, - n_experts, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate", - ) - - # Unquantized only, will be called "weight". - assert hasattr(self, "weight") - is_hopper_or_blackwell = current_platform.is_device_capability( - (9, 0) - ) or current_platform.is_device_capability_family(100) - SUPPORTED_NUM_EXPERTS = [256, 384] - SUPPORTED_HIDDEN_SIZES = [7168] - - self.allow_dsv3_router_gemm = ( - current_platform.is_cuda() - and is_hopper_or_blackwell - and n_experts in SUPPORTED_NUM_EXPERTS - and hidden_size in SUPPORTED_HIDDEN_SIZES - ) - - self._out_dtype: torch.dtype | None = None - - def set_out_dtype(self, out_dtype: torch.dtype) -> None: - """ - Set out dtype for the router logits. This is needed after - __init__, b/c we need to check if the trtllm kernel is - selected before we decide between bf16 and fp32. - """ - - if self._out_dtype is not None: - raise ValueError("out_dtype has already been set") - else: - self._out_dtype = out_dtype - - @property - def out_dtype(self) -> torch.dtype: - if self._out_dtype is None: - raise ValueError("out_dtype has not been set yet") - return self._out_dtype - - def forward( - self, - x: torch.Tensor, - ) -> tuple[torch.Tensor, None]: - """ - Use specialized GEMM for low batch size for DSV3 and KIMI. - """ - if self.allow_dsv3_router_gemm and x.shape[0] <= 16: - return ops.dsv3_router_gemm( - hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype - ), None - else: - return super().forward(x) - - class DeepseekV2MoE(nn.Module): def __init__( self, @@ -316,23 +252,12 @@ class DeepseekV2MoE(nn.Module): "Only silu is supported for now." 
) - # self.gate = DeepSeekV2Gate( - # config.hidden_size, - # config.n_routed_experts, - # quant_config=None, - # prefix=f"{prefix}.gate", - # ) - self.gate = ReplicatedLinear( + self.gate = GateLinear( config.hidden_size, config.n_routed_experts, - bias=False, - quant_config=None, prefix=f"{prefix}.gate", ) if getattr(config, "topk_method", None) == "noaux_tc": - # self.gate.e_score_correction_bias = nn.Parameter( - # torch.empty(config.n_routed_experts, dtype=torch.float32) - # ) self.gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts) ) @@ -401,12 +326,12 @@ class DeepseekV2MoE(nn.Module): else None, ) - # # NOTE(rob): this is a hack until we finish off the PR for - # # merging TRTLLM kernels into the MK framework. Then we can - # # query the MonolithicMK for the expected router logits. - # self.gate.set_out_dtype( - # torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16 - # ) + # NOTE(rob): this is a hack until we finish off the PR for + # merging TRTLLM kernels into the MK framework. Then we can + # query the MonolithicMK for the expected router logits. + self.gate.set_out_dtype( + torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16 + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -443,11 +368,12 @@ class DeepseekV2MoE(nn.Module): elif self.shared_experts is not None: assert shared_output is not None shared_output *= 1.0 / self.routed_scaling_factor - + if self.shared_experts is not None: assert shared_output is not None final_hidden_states += shared_output + if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states, 0 @@ -596,7 +522,7 @@ class DeepseekV2Attention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.attn", ) - + def forward( self, positions: torch.Tensor, @@ -605,23 +531,20 @@ class DeepseekV2Attention(nn.Module): ) -> torch.Tensor: if self.q_lora_rank is not None: q = self.q_a_proj(hidden_states)[0] + kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) q = self.q_a_layernorm(q) q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim) else: - q = self.q_proj(hidden_states)[0].view( - -1, self.num_local_heads, self.qk_head_dim - ) - q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - latent_cache = latent_cache.unsqueeze(1) + q = self.q_proj(hidden_states)[0] + kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split([self.kv_lora_rank, self.qk_rope_head_dim], dim=1) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + kv_a = self.kv_a_layernorm(kv_a) kv = self.kv_b_proj(kv_a)[0] kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k_pe = latent_cache[:, :, self.kv_lora_rank :] - - q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) q[..., self.qk_nope_head_dim :] = q_pe k = torch.empty_like(q) @@ -671,7 +594,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): def get_attn_backend(self) -> AttentionBackend: return DeepseekV32IndexerBackend - + class 
Indexer(nn.Module): def __init__( @@ -727,8 +650,8 @@ class Indexer(nn.Module): # where we store value in fp8 and scale in fp32 # per self.quant_block_size element self.k_cache = DeepseekV32IndexerCache( - head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4, - dtype=torch.uint8, + head_dim=self.head_dim, + dtype=torch.bfloat16, prefix=f"{prefix}.k_cache", cache_config=cache_config, ) @@ -776,23 +699,61 @@ class Indexer(nn.Module): k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) # we only quant q here since k quant is fused with cache insertion - q = q.view(-1, self.head_dim) - q_fp8, q_scale = per_token_group_quant_fp8( - q, - self.quant_block_size, - column_major_scales=False, - use_ue8m0=self.scale_fmt is not None, - ) - q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) - q_scale = q_scale.view(-1, self.n_head, 1) + # q = q.view(-1, self.head_dim) + # q_fp8, q_scale = per_token_group_quant_fp8( + # q, + # self.quant_block_size, + # column_major_scales=False, + # use_ue8m0=self.scale_fmt is not None, + # ) + # q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) + # q_scale = q_scale.view(-1, self.n_head, 1) weights, _ = self.weights_proj(hidden_states) weights = ( - weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5 + weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5 ) weights = weights.squeeze(-1) - return self.indexer_op(hidden_states, q_fp8, k, weights) + return self.indexer_op(hidden_states, q, k, weights) + + +def _min_latency_fused_qkv_a_proj_impl( + input_: torch.Tensor, + weight: torch.Tensor, +) -> torch.Tensor: + """ + Dynamically run min-latency gemm if num_tokens <= 16. + This must be wrapped in a custom op because our torch.compile integration + does not support runtime dispatching on num_tokens. 
+ """ + num_tokens = input_.shape[0] + if 0 < num_tokens <= 16: + output = torch.empty( + num_tokens, + weight.shape[0], + dtype=torch.bfloat16, + device=input_.device, + ) + ops.dsv3_fused_a_gemm(output, input_, weight.T) + return output + else: + return torch.nn.functional.linear(input_, weight) + + +def _min_latency_fused_qkv_a_proj_fake( + input_: torch.Tensor, + weight: torch.Tensor, +) -> torch.Tensor: + return input_.new_empty(input_.shape[0], weight.shape[0]) + + +direct_register_custom_op( + op_name="min_latency_fused_qkv_a_proj", + op_func=_min_latency_fused_qkv_a_proj_impl, + mutates_args=[], + fake_impl=_min_latency_fused_qkv_a_proj_fake, +) class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear): @@ -830,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear): self, input_, ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]: - num_tokens = input_.shape[0] - if self._use_min_latency_gemm and (0 < num_tokens <= 16): - output = torch.empty( - num_tokens, - 2112, - dtype=torch.bfloat16, - device=input_.device, - ) - ops.dsv3_fused_a_gemm( - output, - input_, - self.weight.T, - ) + if self._use_min_latency_gemm: + output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight) if not self.return_bias: return output output_bias = self.bias if self.skip_bias_add else None @@ -898,47 +848,35 @@ class DeepseekV2MLAAttention(nn.Module): self.max_position_embeddings = max_position_embeddings if self.q_lora_rank is not None: - # self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj( - # self.hidden_size, - # [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - # quant_config=quant_config, - # prefix=f"{prefix}.fused_qkv_a_proj", - # ) - self.fused_qkv_a_proj = MergedColumnParallelLinear( - self.hidden_size, - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.fused_qkv_a_proj", - disable_tp=True, - ) + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(self.q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") else: - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa", - ) + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") - if self.q_lora_rank is not None: - self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - self.q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj", - ) - else: - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj", - ) - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) self.kv_b_proj = 
ColumnParallelLinear( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), @@ -1005,9 +943,7 @@ class DeepseekV2MLAAttention(nn.Module): kv_b_proj=self.kv_b_proj, rotary_emb=self.rotary_emb, o_proj=self.o_proj, - fused_qkv_a_proj=self.fused_qkv_a_proj - if self.q_lora_rank is not None - else None, + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa if self.q_lora_rank is None else None, @@ -1346,14 +1282,14 @@ class DeepseekV2ForCausalLM( # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the # quant_method for relevant layers during initialization. - self.fuse_qkv_a_proj = ( - hasattr(config, "q_lora_rank") and config.q_lora_rank is not None - ) - if self.fuse_qkv_a_proj: - self.packed_modules_mapping["fused_qkv_a_proj"] = [ - "q_a_proj", - "kv_a_proj_with_mqa", - ] + # self.fuse_qkv_a_proj = ( + # hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + # ) + # if self.fuse_qkv_a_proj: + # self.packed_modules_mapping["fused_qkv_a_proj"] = [ + # "q_a_proj", + # "kv_a_proj_with_mqa", + # ] self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") @@ -1385,19 +1321,19 @@ class DeepseekV2ForCausalLM( self.moe_layers = [] self.moe_mlp_layers = [] example_moe = None + for layer in self.model.layers: if isinstance(layer, PPMissingLayer): continue - assert isinstance(layer, DeepseekV2DecoderLayer) if isinstance(layer.mlp, DeepseekV2MoE): # Pick last one layer since the first ones may be dense layers. example_moe = layer.mlp self.moe_mlp_layers.append(layer.mlp) self.moe_layers.append(layer.mlp.experts) - self.extract_moe_parameters(example_moe) + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -1441,10 +1377,10 @@ class DeepseekV2ForCausalLM( ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - mla_params_mapping = [ - ("fused_qkv_a_proj", "q_a_proj", 0), - ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), - ] + # mla_params_mapping = [ + # ("fused_qkv_a_proj", "q_a_proj", 0), + # ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + # ] mha_params_mapping = [ ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -1452,8 +1388,8 @@ class DeepseekV2ForCausalLM( ] if self.use_mha: stacked_params_mapping.extend(mha_params_mapping) - else: - stacked_params_mapping.extend(mla_params_mapping) + # else: + # stacked_params_mapping.extend(mla_params_mapping) # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) @@ -1474,168 +1410,232 @@ class DeepseekV2ForCausalLM( params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - - is_fusion_moe_shared_experts_layer = ( - rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) - ) - - for param_name, weight_name, shard_id in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. 
- # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if ("mlp.experts." in name) and name not in params_dict: - continue - if is_fusion_moe_shared_experts_layer: - continue - name_mapped = name.replace(weight_name, param_name) - - # QKV fusion is optional, fall back to normal - # weight loading if it's not enabled - # if go with fusion option, then update name - if ( - param_name == "fused_qkv_a_proj" - ) and name_mapped not in params_dict: - continue - else: - name = name_mapped - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + try: + if "rotary_emb.inv_freq" in name: continue - if is_pp_missing_parameter(name, self): - continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - is_expert_weight = False - - # Special handling: when AITER fusion_shared_experts is enabled, - # checkpoints may provide a single widened shared_experts tensor - # without explicit expert indices - # (e.g. ...mlp.shared_experts.gate_proj.weight). - # For models with multiple shared experts, split that tensor - # evenly into per-shared-expert slices and load them into - # appended expert slots mlp.experts.{n_routed_experts + j}.* - # accordingly. - num_chunks = 1 - if is_fusion_moe_shared_experts_layer: - num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 - # Determine split axis based on op type - # gate/up: ColumnParallel → split along dim 0 - # down: RowParallel → split along dim 1 - split_dim = ( - 1 - if ("down_proj.weight" in name and loaded_weight.ndim > 1) - else 0 - ) - total = loaded_weight.shape[split_dim] - assert total % num_chunks == 0, ( - f"Shared expert weight dim {total} " - f"not divisible by num_chunks {num_chunks}" - ) - chunk_size = total // num_chunks - - for j in range(num_chunks): - chunk_name = name - weight_to_load = loaded_weight + is_fusion_moe_shared_experts_layer = ( + rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) + ) + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." 
in name) and name not in params_dict: + continue if is_fusion_moe_shared_experts_layer: - chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size) - if loaded_weight.ndim == 1: - weight_to_load = loaded_weight[chunk_slice] - elif split_dim == 0: - weight_to_load = loaded_weight[chunk_slice, :] - else: - weight_to_load = loaded_weight[:, chunk_slice] - # Synthesize an expert-style name so expert mapping - # can route it - chunk_name = name.replace( - "mlp.shared_experts", - f"mlp.experts.{self.config.n_routed_experts + j}", - ) + continue + name_mapped = name.replace(weight_name, param_name) - # Use expert_params_mapping to locate the destination - # param and delegate to its expert-aware weight_loader - # with expert_id. - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in chunk_name: - continue - - # Anyway, this is an expert weight and should not be - # attempted to load as other weights later - is_expert_weight = True - - # Do not modify `name` since the loop may continue here - # Instead, create a new variable - name_mapped = chunk_name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name_mapped, self): - continue - - param = params_dict[name_mapped] - # We should ask the weight loader to return success or - # not here since otherwise we may skip experts with - # other available replicas. - weight_loader = typing.cast( - Callable[..., bool], param.weight_loader - ) - success = weight_loader( - param, - weight_to_load, - name_mapped, - shard_id=shard_id, - expert_id=expert_id, - return_success=True, - ) - if success: - if not is_fusion_moe_shared_experts_layer: - name = name_mapped - else: - loaded_params.add(name_mapped) - break + # QKV fusion is optional, fall back to normal + # weight loading if it's not enabled + # if go with fusion option, then update name + if ( + param_name == "fused_qkv_a_proj" + ) and name_mapped not in params_dict: + continue else: - if is_expert_weight: - # We've checked that this is an expert weight - # However it's not mapped locally to this rank - # So we simply skip it - continue + name = name_mapped + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue + if is_pp_missing_parameter(name, self): + continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + is_expert_weight = False - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader + # Special handling: when AITER fusion_shared_experts is enabled, + # checkpoints may provide a single widened shared_experts tensor + # without explicit expert indices + # (e.g. ...mlp.shared_experts.gate_proj.weight). + # For models with multiple shared experts, split that tensor + # evenly into per-shared-expert slices and load them into + # appended expert slots mlp.experts.{n_routed_experts + j}.* + # accordingly. 
+ num_chunks = 1 + if is_fusion_moe_shared_experts_layer: + num_chunks = getattr(self.config, "n_shared_experts", 1) or 1 + # Determine split axis based on op type + # gate/up: ColumnParallel → split along dim 0 + # down: RowParallel → split along dim 1 + split_dim = ( + 1 + if ("down_proj.weight" in name and loaded_weight.ndim > 1) + else 0 ) - weight_loader(param, loaded_weight) - if name is not None and not is_fusion_moe_shared_experts_layer: - loaded_params.add(name) + total = loaded_weight.shape[split_dim] + assert total % num_chunks == 0, ( + f"Shared expert weight dim {total} " + f"not divisible by num_chunks {num_chunks}" + ) + chunk_size = total // num_chunks + for j in range(num_chunks): + chunk_name = name + weight_to_load = loaded_weight + + if is_fusion_moe_shared_experts_layer: + chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size) + if loaded_weight.ndim == 1: + weight_to_load = loaded_weight[chunk_slice] + elif split_dim == 0: + weight_to_load = loaded_weight[chunk_slice, :] + else: + weight_to_load = loaded_weight[:, chunk_slice] + # Synthesize an expert-style name so expert mapping + # can route it + chunk_name = name.replace( + "mlp.shared_experts", + f"mlp.experts.{self.config.n_routed_experts + j}", + ) + + # Use expert_params_mapping to locate the destination + # param and delegate to its expert-aware weight_loader + # with expert_id. + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in chunk_name: + continue + # Anyway, this is an expert weight and should not be + # attempted to load as other weights later + is_expert_weight = True + + # Do not modify `name` since the loop may continue here + # Instead, create a new variable + name_mapped = chunk_name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name_mapped, self): + continue + + param = params_dict[name_mapped] + # We should ask the weight loader to return success or + # not here since otherwise we may skip experts with + # other available replicas. + weight_loader = typing.cast( + Callable[..., bool], param.weight_loader + ) + success = weight_loader( + param, + weight_to_load, + name_mapped, + shard_id=shard_id, + expert_id=expert_id, + return_success=True, + ) + if success: + if not is_fusion_moe_shared_experts_layer: + name = name_mapped + else: + loaded_params.add(name_mapped) + break + else: + if is_expert_weight: + # We've checked that this is an expert weight + # However it's not mapped locally to this rank + # So we simply skip it + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + if name is not None and not is_fusion_moe_shared_experts_layer: + loaded_params.add(name) + except: + pass + + opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"] + # add your opt here.. 
+ def inject_layer(layer, quant_method, is_mla): + q_lora_rank = getattr(layer, "q_lora_rank", None) + if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]: + if q_lora_rank is not None: + layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_a_proj, "weight_scale"): + layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + elif not is_mla: + layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_proj, "weight_scale"): + layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + else: + return + del layer.kv_a_proj_with_mqa.weight + del layer.kv_a_proj_with_mqa + if is_mla: + layer.mla_attn.forward = layer.mla_attn.forward_opt + else: + layer.forward = layer.forward_opt + elif quant_method == "GGUFLinearMethod": + pass + elif quant_method == "AWQMarlinLinearMethod": + dtype = layer.kv_a_proj_with_mqa.qweight.dtype + assert dtype == torch.int32 + if layer.q_lora_rank is not None: + layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1) + layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + elif not is_mla: + layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1) + layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + else: + return + + del layer.kv_a_proj_with_mqa.qweight + del layer.kv_a_proj_with_mqa + if is_mla: + layer.mla_attn.forward = layer.mla_attn.forward_opt + else: + layer.forward = layer.forward_opt + else: + pass + + for _, layer in self.model.named_modules(): + if layer.__class__.__name__ in ["DeepseekV2Attention","DeepseekV2MLAAttention"]: + if hasattr(layer.kv_a_proj_with_mqa, "scheme"): + quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__ + else: + quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__ + if quant_method not in opt_support_quant_method: + break + + inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "DeepseekV2MLAAttention") + return loaded_params diff --git a/vllm/model_executor/models/ernie45_moe.py b/vllm/model_executor/models/ernie45_moe.py index f038cfb..3cfdb4d 100644 --- a/vllm/model_executor/models/ernie45_moe.py +++ b/vllm/model_executor/models/ernie45_moe.py @@ -164,7 +164,7 @@ class Ernie4_5_MoeMoE(nn.Module): config.hidden_size, config.moe_num_experts, bias=False, - params_dtype=torch.float32, + # params_dtype=torch.float32, quant_config=None, prefix=f"{prefix}.gate", ) @@ -209,7 +209,7 @@ class Ernie4_5_MoeMoE(nn.Module): hidden_dim = hidden_states.shape[-1] hidden_states = hidden_states.view(-1, hidden_dim) - router_logits, _ = self.gate(hidden_states.to(dtype=torch.float32)) + router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( 
hidden_states=hidden_states, router_logits=router_logits @@ -429,7 +429,8 @@ class Ernie4_5_MoeModel(nn.Module): self.num_redundant_experts = eplb_config.num_redundant_experts - if get_pp_group().is_first_rank: + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, @@ -653,11 +654,11 @@ class Ernie4_5_MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA, MixtureOfExpe quant_config=quant_config, prefix=maybe_prefix(prefix, "lm_head"), ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight else: self.lm_head = PPMissingLayer() - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py index 1df4adf..edf4c2c 100644 --- a/vllm/model_executor/models/ernie45_vl.py +++ b/vllm/model_executor/models/ernie45_vl.py @@ -829,16 +829,31 @@ class Ernie4_5_VLProcessingInfo(BaseProcessingInfo): spatial_conv_size = hf_config.spatial_conv_size temporal_conv_size = hf_config.temporal_conv_size + if self.ctx.model_config.trust_remote_code: + # Defined in HF Hub repo + min_pixels_key = "min_pixels" + max_pixels_key = "max_pixels" + else: + # Defined in Transformers library (requires v5.0 or above) + min_pixels_key = "shortest_edge" + max_pixels_key = "longest_edge" + mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs) - size = mm_kwargs.get("size", image_processor.size) + size = image_processor.size + if override_size := mm_kwargs.get("size"): + size = size | override_size + if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None: + size = size | {min_pixels_key: override_min_pixels} + if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None: + size = size | {max_pixels_key: override_max_pixels} if do_resize: resized_height, resized_width = smart_resize( height=image_height, width=image_width, factor=patch_size * spatial_conv_size, - min_pixels=size["min_pixels"], - max_pixels=size["max_pixels"], + min_pixels=size[min_pixels_key], + max_pixels=size[max_pixels_key], ) preprocessed_size = ImageSize(width=resized_width, height=resized_height) else: diff --git a/vllm/model_executor/models/extract_hidden_states.py b/vllm/model_executor/models/extract_hidden_states.py new file mode 100644 index 0000000..ae9bdb5 --- /dev/null +++ b/vllm/model_executor/models/extract_hidden_states.py @@ -0,0 +1,394 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Hidden States Extractor Model. + +This model extracts and caches hidden states from the target model +without performing actual token generation. It's used with the +extract_hidden_states speculative decoding method. 
+""" + +from collections.abc import Iterable +from typing import ClassVar + +import torch +import torch.nn as nn + +from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config +from vllm.config.cache import CacheDType +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.attention.attention import set_default_quant_scales +from vllm.model_executor.layers.attention.kv_transfer_utils import ( + maybe_transfer_kv_layer, +) +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase +from vllm.model_executor.models.utils import maybe_prefix +from vllm.utils.torch_utils import kv_cache_dtype_str_to_dtype +from vllm.v1.attention.backend import ( + AttentionBackend, + AttentionImpl, + AttentionMetadataBuilder, + AttentionType, + CommonAttentionMetadata, + is_quantized_kv_cache, +) +from vllm.v1.kv_cache_interface import ( + AttentionSpec, + KVCacheSpec, + MLAAttentionSpec, +) + +########## Custom Ops ######## + + +def unified_kv_cache_update( + to_cache: torch.Tensor, + layer_name: str, +) -> torch.Tensor: + """ + Returns a dummy that is passed to unified_attention to signal a side effect and + the data dependency between them to ensure torch.compile preserves ordering. + """ + forward_context = get_forward_context() + attn_layer = forward_context.no_compile_layers[layer_name] + kv_cache = attn_layer.kv_cache[forward_context.virtual_engine] + + slot_mapping = forward_context.slot_mapping + assert isinstance(slot_mapping, dict), ( + f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. " + ) + layer_slot_mapping = slot_mapping.get(layer_name) + if layer_slot_mapping is not None: + assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( + f"{attn_layer.impl.__class__.__name__} does not support kv cache update" + ) + attn_layer.impl.do_kv_cache_update( + attn_layer, + to_cache, + kv_cache, + layer_slot_mapping, + ) + + return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) + + +@maybe_transfer_kv_layer +def dummy_attention(layer_name, _placeholder): + # Note: layer_name arg required by @maybe_transfer_kv_layer + return _placeholder + + +def basic_cache( + to_cache: torch.Tensor, # shape: [num_blocks, block_size, num_heads, head_size] + kv_cache: torch.Tensor, # shape: [seq_len, num_heads, head_size] + slot_mapping: torch.Tensor, # shape: [seq_len] +): + num_blocks, block_size, num_heads, head_size = kv_cache.shape + token_kv_cache = kv_cache.view(num_blocks * block_size, num_heads, head_size) + token_kv_cache[slot_mapping] = to_cache + + +######### CacheOnlyAttentionBackend ######## + + +class CacheOnlyAttentionBackend(AttentionBackend): + """Attention backend that only caches KV without computing attention.""" + + accept_output_buffer: bool = False + supported_dtypes: ClassVar[list[torch.dtype]] = [ + torch.float16, + torch.bfloat16, + torch.float32, + ] + supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ + "auto", + "bfloat16", + ] + forward_includes_kv_cache_update: bool = False + + @staticmethod + def get_name() -> str: + return "CACHE_ONLY_ATTN" + + @classmethod + def supports_attn_type(cls, attn_type: str) -> bool: + return attn_type == AttentionType.DECODER + + @classmethod + def supports_mm_prefix(cls) -> bool: + return True + + @staticmethod + def get_impl_cls() -> type["CacheOnlyAttentionImpl"]: + return CacheOnlyAttentionImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> 
tuple[int, ...]: + # We set `num_kv_heads = num_hidden_layers` and `head_size = hidden_size` + # We also don't use a k/v (2) dim + return (num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_builder_cls() -> type["CacheOnlyAttentionMetadataBuilder"]: + return CacheOnlyAttentionMetadataBuilder + + @staticmethod + def use_cascade_attention(*args, **kwargs) -> bool: + return False + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [] + + +class CacheOnlyAttentionMetadata: + def __init__(self, slot_mapping: torch.Tensor): + self.slot_mapping = slot_mapping + + +class CacheOnlyAttentionMetadataBuilder( + AttentionMetadataBuilder[CacheOnlyAttentionMetadata] +): + def __init__( + self, + kv_cache_spec: AttentionSpec, + layer_names: list[str], + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(kv_cache_spec, layer_names, vllm_config, device) + + def build( + self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False, + ) -> CacheOnlyAttentionMetadata: + use_cascade = common_prefix_len > 0 + if use_cascade: + raise NotImplementedError( + "Cascade attention not supported by CacheOnlyAttention" + ) + causal = common_attn_metadata.causal + if not causal: + raise NotImplementedError( + "Non-causal attention not supported by CacheOnlyAttention" + ) + + return CacheOnlyAttentionMetadata( + slot_mapping=common_attn_metadata.slot_mapping, + ) + + +class CacheOnlyAttentionImpl(AttentionImpl): + """Attention implementation that only caches KV states.""" + + def __init__( + self, + num_heads: int, + head_size: int, + kv_cache_dtype: str, + kv_cache_torch_dtype: torch.dtype, + attn_type: AttentionType = AttentionType.DECODER, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.kv_cache_dtype = kv_cache_dtype + self.kv_cache_torch_dtype = kv_cache_torch_dtype + + if attn_type != AttentionType.DECODER: + raise NotImplementedError(f"Unsupported attention type: {attn_type}") + if is_quantized_kv_cache(kv_cache_dtype): + raise NotImplementedError("Quantized KV cache not supported") + + self.num_queries_per_kv = 1 + + def do_kv_cache_update( + self, + layer, + to_cache, + kv_cache, + slot_mapping, + ): + assert to_cache.dtype == self.kv_cache_torch_dtype, ( + f"Data to cache must be {self.kv_cache_torch_dtype}, got {to_cache.dtype}" + ) + assert kv_cache.dtype == self.kv_cache_torch_dtype, ( + f"KV cache must be {self.kv_cache_torch_dtype}, got {kv_cache.dtype}" + ) + + basic_cache(to_cache, kv_cache, slot_mapping) + + def forward(self, *args, **kwargs): + # Empty implementation of abstract method + pass + + +############## CacheOnlyAttentionLayer (replaces Attention) ############ + + +class CacheOnlyAttentionLayer(nn.Module, AttentionLayerBase): + """Attention layer that only caches key/value states without computing attention.""" + + def __init__( + self, + num_heads: int, + head_size: int, + cache_config: CacheConfig | None = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER, + ): + super().__init__() + + self.num_heads = num_heads + self.head_size = head_size + self.layer_name = prefix + + vllm_config = get_current_vllm_config() + + # KV cache configuration + cache_config = cache_config or vllm_config.cache_config + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + self.block_size = cache_config.block_size + else: + kv_cache_dtype = "auto" + self.block_size = 16 + + assert kv_cache_dtype in ["auto", "bfloat16", "float16"], ( 
+ "CacheOnlyAttentionLayer doesn't currently support quantized kv cache but" + f"kv cache dtype was set to {kv_cache_dtype}" + ) + self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype( + kv_cache_dtype, vllm_config.model_config + ) + + # Initialize KV cache quantization attributes + set_default_quant_scales(self, register_buffer=True) + + # Attention backend + self.attn_backend = CacheOnlyAttentionBackend + impl_cls = self.attn_backend.get_impl_cls() + self.impl = impl_cls( + num_heads, + head_size, + kv_cache_dtype, + self.kv_cache_torch_dtype, + attn_type, + ) + + assert not self.attn_backend.forward_includes_kv_cache_update, ( + "KV cache update should be independent of forward" + ) + + # Placeholder KV cache (replaced by bind_kv_cache) + self.kv_cache = [ + torch.tensor([]) + for _ in range(vllm_config.parallel_config.pipeline_parallel_size) + ] + + # Register in compilation context + compilation_config = vllm_config.compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def forward(self, to_cache: torch.Tensor) -> torch.Tensor: + """Cache hidden states as KV pairs without computing attention. + + Args: + to_cache: The tensor to insert into the kv cache. + shape [num_tokens, num_heads, head_size] + + Returns: + Dummy output tensor (not used) + """ + # Note: we set num_heads to num_hidden_layers and + # head_size to hidden_size for hidden states storage + output = torch.empty(0, device=to_cache.device, dtype=to_cache.dtype) + + # Note: dummy_out is used to force torch.compile to preserve ordering between + # cache update and attention op (which triggers kv_connector transfer) + dummy_out = unified_kv_cache_update(to_cache, self.layer_name) + + # Triggers kv_connector transfer via decorator + _ = dummy_attention(self.layer_name, dummy_out) + + return output + + def get_attn_backend(self) -> type[AttentionBackend]: + return self.attn_backend + + def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec: + # Note: we use MLAAttentionSpec here to because it will + # produce page sizes of (block_size * num_kv_heads * head_size * dtype_size) + # whereas FullAttentionSpec will add an additional factor of 2 + return MLAAttentionSpec( + block_size=self.block_size, + num_kv_heads=self.num_heads, + head_size=self.head_size, + dtype=self.kv_cache_torch_dtype, + ) + + +############ ExtractHiddenStatesModel definition ########## + + +class ExtractHiddenStatesModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + self.vllm_config = vllm_config + self.hf_config = vllm_config.speculative_config.draft_model_config.hf_config + self.hidden_size = vllm_config.model_config.get_hidden_size() + self.target_num_hidden_layers = ( + vllm_config.model_config.get_total_num_hidden_layers() + ) + self.num_hidden_states = len( + getattr(self.hf_config, "eagle_aux_hidden_state_layer_ids", []) + ) + + cache_config = vllm_config.cache_config + + # Create a single cache-only attention layer + # Note: We set num_heads <- self.num_hidden_states + # and head_size <- hidden_size so that we can insert + # the hidden states directly into the cache without + # reshaping + self.cache_only_layers = nn.ModuleDict( + { + str(self.target_num_hidden_layers): CacheOnlyAttentionLayer( + num_heads=self.num_hidden_states, + head_size=self.hidden_size, + cache_config=cache_config, + prefix=maybe_prefix( + prefix, 
f"cache_only_layers.{self.target_num_hidden_layers}" + ), + ) + } + ) + + def forward(self, hidden_states: torch.Tensor) -> None: + """Process and cache hidden states. + + Args: + hidden_states: Hidden states from target model + shape: [num_tokens, num_hidden_states, hidden_size] + + Returns: + Tuple of (dummy_output, dummy_output) - both unused + """ + + # Call dummy attention layer to cache hidden states + # Output is ignored - we only care about the KV cache side effects + _ = self.cache_only_layers[str(self.target_num_hidden_layers)](hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """No weights to load for this dummy model.""" + return set() diff --git a/vllm/model_executor/models/fireredasr2.py b/vllm/model_executor/models/fireredasr2.py new file mode 100644 index 0000000..f0d3e12 --- /dev/null +++ b/vllm/model_executor/models/fireredasr2.py @@ -0,0 +1,829 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, cast + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import ( + BatchFeature, + Qwen2Config, +) + +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import ( + ReplicatedLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.models.whisper_utils import ( + ISO639_1_SUPPORTED_LANGS, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.processors.fireredasr2_processor import ( + FireRedASR2FeatureExtractor, +) +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsTranscription, + _require_is_multimodal, +) +from .qwen2 import Qwen2ForCausalLM +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class FireRedASR2AudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[ + list[torch.Tensor] | None, + TensorShape("b", "nmb", "t"), + ] + speech_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] + fake_token_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] + + +class Swish(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +class Conv2dSubsampling(nn.Module): + def __init__(self, idim: int, d_model: int, out_channels: int = 32): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(1, out_channels, 3, 2), + nn.ReLU(), + nn.Conv2d(out_channels, 
out_channels, 3, 2), + nn.ReLU(), + ) + subsample_idim = ((idim - 1) // 2 - 1) // 2 + self.out = ReplicatedLinear( + input_size=out_channels * subsample_idim, + output_size=d_model, + bias=True, + ) + + self.subsampling = 4 + left_context = right_context = 3 # both exclude currect frame + self.context = left_context + 1 + right_context # 7 + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = x.unsqueeze(1) + x = self.conv(x) + N, C, T, D = x.size() + x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D)) + mask = x_mask[:, :, :-2:2][:, :, :-2:2] + input_lengths = mask[:, -1, :].sum(dim=-1) + return x, input_lengths, mask + + +class RelPositionalEncoding(nn.Module): + def __init__(self, d_model: int, max_len: int = 5000): + super().__init__() + pe_positive = torch.zeros(max_len, d_model, requires_grad=False) + pe_negative = torch.zeros(max_len, d_model, requires_grad=False) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp( + torch.arange(0, d_model, 2).float() + * -(torch.log(torch.tensor(10000.0)).item() / d_model) + ) + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + pe_negative = pe_negative[1:].unsqueeze(0) + self.pe = torch.cat([pe_positive, pe_negative], dim=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Tmax = 2 * max_len - 1 + Tmax, T = self.pe.size(1), x.size(1) + pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach() + return pos_emb + + +class ConformerFeedForward(nn.Module): + def __init__(self, d_model: int): + super().__init__() + self.pre_layer_norm = nn.LayerNorm(d_model) + self.linear_expand = ReplicatedLinear( + input_size=d_model, + output_size=d_model * 4, + bias=True, + ) + self.nonlinear = Swish() + self.linear_project = ReplicatedLinear( + input_size=d_model * 4, + output_size=d_model, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.pre_layer_norm(x) + x, _ = self.linear_expand(x) + x = self.nonlinear(x) + x, _ = self.linear_project(x) + output = x + residual + return output + + +class EncoderMultiHeadAttention(nn.Module): + def __init__(self, n_head: int, d_model: int): + super().__init__() + assert d_model % n_head == 0 + self.n_head = n_head + self.d_k = d_model // n_head + self.d_v = self.d_k + + self.w_qs = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_k, bias=False + ) + self.w_ks = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_k, bias=False + ) + self.w_vs = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_v, bias=False + ) + + self.layer_norm_q = nn.LayerNorm(d_model) + self.layer_norm_k = nn.LayerNorm(d_model) + self.layer_norm_v = nn.LayerNorm(d_model) + + self.fc = ReplicatedLinear( + input_size=n_head * self.d_v, output_size=d_model, bias=False + ) + + def forward_qkv( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + q = self.layer_norm_q(q) + k = self.layer_norm_k(k) + v = self.layer_norm_v(v) + + q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k) + k = 
self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + return q, k, v + + def forward_output( + self, output: torch.Tensor, residual: torch.Tensor, sz_b: int, len_q: int + ) -> torch.Tensor: + output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + fc_out, _ = self.fc(output) + output = fc_out + output = output + residual + return output + + def forward_attention( + self, attn: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if mask is not None: + mask = mask.unsqueeze(1) + mask = mask.eq(0) + attn = attn.masked_fill(mask, -float("inf")) + attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0) + else: + attn = torch.softmax(attn, dim=-1) + + d_attn = attn + output = torch.matmul(d_attn, v) + + return output, attn + + +class RelPosMultiHeadAttention(EncoderMultiHeadAttention): + def __init__(self, n_head: int, d_model: int): + super().__init__(n_head, d_model) + d_k = d_model // n_head + self.scale = 1.0 / (d_k**0.5) + self.linear_pos = ReplicatedLinear( + input_size=d_model, output_size=n_head * d_k, bias=False + ) + self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k])) + self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k])) + + def _rel_shift(self, x): + N, H, T1, T2 = x.size() + zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(N, H, T2 + 1, T1) + x = x_padded[:, :, 1:].view_as(x) + x = x[:, :, :, : x.size(-1) // 2 + 1] + return x + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + pos_emb: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + sz_b, len_q = q.size(0), q.size(1) + + residual = q + q, k, v = self.forward_qkv(q, k, v) + + q = q.transpose(1, 2) + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k) + p = p.transpose(1, 2) + + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self._rel_shift(matrix_bd) + + attn_scores = matrix_ac + matrix_bd + attn_scores.mul_(self.scale) + + output, attn = self.forward_attention(attn_scores, v, mask=mask) + + output = self.forward_output(output, residual, sz_b, len_q) + return output, attn + + +class ConformerConvolution(nn.Module): + def __init__(self, d_model: int, kernel_size: int = 33): + super().__init__() + assert kernel_size % 2 == 1 + self.pre_layer_norm = nn.LayerNorm(d_model) + self.pointwise_conv1 = nn.Conv1d( + d_model, d_model * 4, kernel_size=1, bias=False + ) + self.padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + d_model * 2, + d_model * 2, + kernel_size, + stride=1, + padding=self.padding, + groups=d_model * 2, + bias=False, + ) + self.batch_norm = nn.LayerNorm(d_model * 2) + self.swish = Swish() + self.pointwise_conv2 = nn.Conv1d( + d_model * 2, d_model, kernel_size=1, bias=False + ) + + def forward( + self, x: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + residual = x + out = self.pre_layer_norm(x) + out = out.transpose(1, 2) + if mask is not None: + out.masked_fill_(mask.ne(1), 0.0) + out = self.pointwise_conv1(out) + out = F.glu(out, dim=1) + out = 
self.depthwise_conv(out) + + out = out.transpose(1, 2) + out = self.swish(self.batch_norm(out)) + out = out.transpose(1, 2) + + out = self.pointwise_conv2(out) + if mask is not None: + out.masked_fill_(mask.ne(1), 0.0) + out = out.transpose(1, 2) + return out + residual + + +class RelPosEmbConformerBlock(nn.Module): + def __init__(self, d_model, n_head, kernel_size=33): + super().__init__() + self.ffn1 = ConformerFeedForward(d_model) + self.mhsa = RelPosMultiHeadAttention(n_head, d_model) + self.conv = ConformerConvolution(d_model, kernel_size) + self.ffn2 = ConformerFeedForward(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x: torch.Tensor, + pos_emb: torch.Tensor, + slf_attn_mask: torch.Tensor | None = None, + pad_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + out = 0.5 * x + 0.5 * self.ffn1(x) + out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0] + out = self.conv(out, pad_mask) + out = 0.5 * out + 0.5 * self.ffn2(out) + out = self.layer_norm(out) + return out + + +class ConformerEncoder(nn.Module): + def __init__( + self, + idim: int, + n_layers_enc: int, + n_head: int, + d_model: int, + kernel_size: int = 33, + pe_maxlen: int = 5000, + ): + super().__init__() + self.odim = d_model + + self.input_preprocessor = Conv2dSubsampling(idim, d_model) + self.positional_encoding = RelPositionalEncoding(d_model) + + self.layer_stack = nn.ModuleList() + for _ in range(n_layers_enc): + block = RelPosEmbConformerBlock(d_model, n_head, kernel_size) + self.layer_stack.append(block) + + def forward( + self, padded_input: torch.Tensor, input_lengths: torch.Tensor, pad: bool = True + ): + if pad: + padded_input = F.pad( + padded_input, + (0, 0, 0, self.input_preprocessor.context - 1), + "constant", + 0.0, + ) + src_mask = self.padding_position_is_0(padded_input, input_lengths) + + embed_output, input_lengths, src_mask = self.input_preprocessor( + padded_input, src_mask + ) + enc_output = embed_output + + pos_emb = self.positional_encoding(embed_output) + + enc_outputs = [] + for enc_layer in self.layer_stack: + enc_output = enc_layer( + enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask + ) + enc_outputs.append(enc_output) + + return enc_output, input_lengths, src_mask + + def padding_position_is_0( + self, padded_input: torch.Tensor, input_lengths: torch.Tensor + ) -> torch.Tensor: + N, T = padded_input.size()[:2] + mask = torch.ones((N, T)).to(padded_input.device) + for i in range(N): + mask[i, input_lengths[i] :] = 0 + mask = mask.unsqueeze(dim=1) + return mask.to(torch.uint8) + + +class FireRedASR2Adapter(nn.Module): + def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2): + super().__init__() + self.ds = downsample_rate + self.linear1 = ReplicatedLinear( + input_size=encoder_dim * downsample_rate, + output_size=llm_dim, + bias=True, + ) + self.relu = _ACTIVATION_REGISTRY["relu"] + self.linear2 = ReplicatedLinear( + input_size=llm_dim, + output_size=llm_dim, + bias=True, + ) + + def forward(self, x, x_lens): + batch_size, seq_len, feat_dim = x.size() + num_frames_to_discard = seq_len % self.ds + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.ds, feat_dim * self.ds) + + x, _ = self.linear1(x) + x = self.relu(x) + x, _ = self.linear2(x) + + new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds + return x, new_x_lens + + +class FireRedASR2Encoder(nn.Module): + def __init__( + self, + *, + vllm_config: 
VllmConfig, + ): + super().__init__() + self.audio_encoder = ConformerEncoder( + **vllm_config.model_config.hf_config.audio_encoder_conf + ) + + +class FireRedASR2Model(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = FireRedASR2Encoder( + vllm_config=vllm_config, + ) + encoder_dim = self.encoder.audio_encoder.odim + llm_dim = vllm_config.model_config.hf_config.hidden_size + self.encoder_projector = FireRedASR2Adapter( + encoder_dim, + llm_dim, + vllm_config.model_config.hf_config.encoder_downsample_rate, + ) + + self.decoder = Qwen2ForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder") + ) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def get_encoder_outputs( + self, + speech: torch.Tensor | list[torch.Tensor] | None, + speech_lengths: torch.Tensor | list[torch.Tensor] | None, + ) -> torch.Tensor | None: + encoder_outs, enc_lengths, enc_mask = self.encoder.audio_encoder( + speech, speech_lengths + ) + speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths) + return speech_features + + +class FireRedASR2ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self) -> Qwen2Config: + return self.ctx.get_hf_config(Qwen2Config) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": 1} + + def get_feature_extractor(self, **kwargs: object) -> FireRedASR2FeatureExtractor: + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, FireRedASR2FeatureExtractor) + return feature_extractor + + def get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.get_feature_extractor() + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, + target_channels=self.get_target_channels(), + ) + + def get_target_channels(self) -> int: + return 1 + + +class FireRedASR2DummyInputsBuilder(BaseDummyInputsBuilder[FireRedASR2ProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + return "<|AUDIO|>" * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions], + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + audio_overrides = mm_options.get("audio") + + ret = { + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) + } + return ret + + +class FireRedASR2MultiModalProcessor( + BaseMultiModalProcessor[FireRedASR2ProcessingInfo] +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + 
mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + speech_lengths=MultiModalFieldConfig.batched("audio"), + fake_token_lengths=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + + audio_token_id = vocab[audio_token] + + out_mm_data = out_mm_kwargs.get_data() + + fake_token_lengths = out_mm_data.get("fake_token_lengths") + + if fake_token_lengths is None: + audio_output_lengths = [] + else: + assert isinstance(fake_token_lengths, torch.Tensor) + + audio_output_lengths = fake_token_lengths.tolist() + + def get_replacement_fireredasr2_audio(item_idx: int): + num_features = audio_output_lengths[item_idx] + + audio_tokens = [audio_token_id] * int(num_features) + + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=[audio_token_id], + replacement=get_replacement_fireredasr2_audio, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + FireRedASR2MultiModalProcessor, + info=FireRedASR2ProcessingInfo, + dummy_inputs=FireRedASR2DummyInputsBuilder, +) +class FireRedASR2ForConditionalGeneration( + nn.Module, SupportsTranscription, SupportsMultiModal +): + packed_modules_mapping = { + "self_attn.qkv_proj": [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ], + "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "llm.": "model.decoder.", + "encoder.": "model.encoder.audio_encoder.", + "encoder_projector.": "model.encoder_projector.", + "net.0": "pre_layer_norm", + "net.1": "linear_expand", + "net.4": "linear_project", + } + ) + + supports_transcription_only = True + supports_segment_timestamp = True + supported_languages = ISO639_1_SUPPORTED_LANGS + + @classmethod + def validate_language(cls, language: str | None) -> str | None: + if language is None: + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + logger.warning( + "Defaulting to language='en'. If you wish to transcribe " + "audio in a different language, pass the `language` field " + "in the TranscriptionRequest." 
+ ) + language = "en" + return super().validate_language(language) + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + if language is None: + raise ValueError( + "Language must be specified when creating the fireredasr2 prompt" + ) + + prompt_str = "<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n" # noqa: E501 + prompt = { + "prompt": prompt_str, + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + } + return cast(PromptType, prompt) + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + processor = cached_processor_from_config(model_config) + + return SpeechToTextConfig( + max_audio_clip_s=processor.feature_extractor.chunk_length, + sample_rate=processor.feature_extractor.sampling_rate, + ) + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> int | None: + processor = cached_processor_from_config(model_config) + hop_length = processor.feature_extractor.hop_length + assert hop_length is not None + return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = FireRedASR2Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + logit_scale = getattr(config, "logit_scale", 1.0) + + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + decoder_outputs = self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + + speech = audio_input["input_features"] + speech_lengths = audio_input["speech_lengths"].to(torch.int32) + enc_output = self.model.get_encoder_outputs( + speech=speech, speech_lengths=speech_lengths + ) + + return enc_output + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self.model.decoder.embed_input_ids(input_ids) + + ret = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=_require_is_multimodal(is_multimodal), + ) + return ret + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> FireRedASR2AudioInputs: + input_features = kwargs.pop("input_features", None) + speech_lengths = kwargs.pop("speech_lengths", None) + fake_token_lengths = kwargs.pop("fake_token_lengths", None) + + return FireRedASR2AudioInputs( + input_features=input_features, + speech_lengths=speech_lengths, + fake_token_lengths=fake_token_lengths, + ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = 
self.logits_processor(self.model.decoder.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, skip_prefixes=["model.encoder.audio_encoder.positional_encoding.pe"] + ) + + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/funaudiochat.py b/vllm/model_executor/models/funaudiochat.py index 5bcb49e..2265d04 100644 --- a/vllm/model_executor/models/funaudiochat.py +++ b/vllm/model_executor/models/funaudiochat.py @@ -13,7 +13,6 @@ positions via `inputs_embeds`, while `position_ids` (RoPE) remains standard 1D. from __future__ import annotations -import os from collections.abc import Iterable, Mapping, Sequence from functools import cached_property from typing import Any @@ -924,53 +923,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor f"sequence of Tensors (got {type(speech_attention_mask)})" ) - debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1" - if debug: - print( - f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} " - f"speech_attention_mask={tuple(speech_attention_mask.shape)}", - flush=True, - ) - attn_impl = getattr( - self.continuous_audio_tower.config, "_attn_implementation", None - ) - print( - f"[FunAudioChat] audio_attn_impl={attn_impl}", - flush=True, - ) - if hasattr(self.continuous_audio_tower, "conv1"): - conv1_w = self.continuous_audio_tower.conv1.weight - print( - f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}", - flush=True, - ) - try: - attn0 = self.continuous_audio_tower.layers[0].self_attn - q_norm = float(attn0.q_proj.weight.norm().item()) - k_norm = float(attn0.k_proj.weight.norm().item()) - v_norm = float(attn0.v_proj.weight.norm().item()) - o_norm = float(attn0.out_proj.weight.norm().item()) - print( - f"[FunAudioChat] attn0_q_norm={q_norm:.6g} " - f"k_norm={k_norm:.6g} " - f"v_norm={v_norm:.6g} " - f"o_norm={o_norm:.6g}", - flush=True, - ) - except Exception: - pass - if isinstance(input_features, torch.Tensor): - print( - f"[FunAudioChat] input_features={tuple(input_features.shape)}", - flush=True, - ) - if isinstance(feature_attention_mask, torch.Tensor): - print( - "[FunAudioChat] feature_attention_mask=" - f"{tuple(feature_attention_mask.shape)}", - flush=True, - ) - group_size = int(self.audio_tower.group_size) speech_maxlen = int(speech_ids.shape[-1]) @@ -1019,38 +971,6 @@ class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, Suppor embeds = tuple( audio_features[i, : int(length)] for i, length in enumerate(lengths) ) - if debug: - embed_lens = [int(t.shape[0]) for t in embeds] - print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True) - if embeds: - t0 = embeds[0] - print( - f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} " - f"nan={bool(torch.isnan(t0).any())} " - f"norm={float(t0.norm().item()):.6g}", - flush=True, - ) - dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "") - if ( - dump_path - and speech_ids.shape[0] == 1 - and len(embeds) == 1 - and embed_lens[0] > 10 - ): - if not os.path.exists(dump_path): - np.save(dump_path, embeds[0].detach().float().cpu().numpy()) - print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True) - cont_path = dump_path.replace(".npy", "_cont.npy") - if continuous_audio_features is not None and not os.path.exists( - cont_path - ): - np.save( - cont_path, - continuous_audio_features.detach().float().cpu().numpy(), - ) - 
print( - f"[FunAudioChat] dumped continuous to {cont_path}", flush=True - ) return embeds def forward( diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py index 770424b..5c84187 100644 --- a/vllm/model_executor/models/gemma3n.py +++ b/vllm/model_executor/models/gemma3n.py @@ -409,7 +409,7 @@ class Gemma3nAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - + q, k, v = q.contiguous(), k.contiguous(), v.contiguous() q = q.unflatten(-1, (self.num_heads, self.head_dim)) q = self.q_norm(q) q = q.flatten(-2, -1) diff --git a/vllm/model_executor/models/glm4_moe.py b/vllm/model_executor/models/glm4_moe.py index 9705bdf..c8fdd51 100644 --- a/vllm/model_executor/models/glm4_moe.py +++ b/vllm/model_executor/models/glm4_moe.py @@ -110,7 +110,12 @@ class Glm4MoeMLP(nn.Module): def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) - x, _ = self.down_proj(x) + if self.down_proj.quant_method.__class__.__name__ != "UnquantizedLinearMethod" and x.shape[-1] != self.down_proj.weight.shape[0]: + padding = self.down_proj.weight.shape[0] - x.shape[-1] + x_align = torch.nn.functional.pad(x, (0, padding), mode='constant', value=0) + else: + x_align = x + x, _ = self.down_proj(x_align) return x @@ -144,11 +149,10 @@ class Glm4MoE(nn.Module): config.hidden_size, config.n_routed_experts, bias=False, - # dtype=torch.float32, + dtype=torch.bfloat16, ) self.gate.e_score_correction_bias = nn.Parameter( - torch.empty(config.n_routed_experts) - ) + torch.empty(config.n_routed_experts, dtype=torch.bfloat16)) # Load balancing settings. vllm_config = get_current_vllm_config() @@ -205,8 +209,7 @@ class Glm4MoE(nn.Module): hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (num_tokens, n_experts) - # router_logits = self.gate(hidden_states.to(dtype=torch.float32)) - router_logits = self.gate(hidden_states) + router_logits = self.gate(hidden_states.to(dtype=torch.bfloat16)) fused_moe_out = self.experts( hidden_states=hidden_states, router_logits=router_logits @@ -312,6 +315,9 @@ class Glm4MoeAttention(nn.Module): ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() if self.use_qk_norm: q = self.q_norm(q.reshape(-1, self.num_heads, self.head_dim)).reshape( q.shape diff --git a/vllm/model_executor/models/glm4_moe_lite.py b/vllm/model_executor/models/glm4_moe_lite.py index 9139457..6529986 100644 --- a/vllm/model_executor/models/glm4_moe_lite.py +++ b/vllm/model_executor/models/glm4_moe_lite.py @@ -127,12 +127,10 @@ class Glm4MoeLiteDecoderLayer(nn.Module): v_head_dim = getattr(config, "v_head_dim", 0) kv_lora_rank = getattr(config, "kv_lora_rank", 0) - # if model_config.use_mla: - # attn_cls = Glm4MoeLiteMLAAttention - # else: - # attn_cls = Glm4MoeLiteAttention - - attn_cls = Glm4MoeLiteAttention + if model_config.use_mla: + attn_cls = Glm4MoeLiteMLAAttention + else: + attn_cls = Glm4MoeLiteAttention self.self_attn = attn_cls( vllm_config=vllm_config, @@ -306,7 +304,7 @@ class Glm4MoeLiteModel(nn.Module): ), } ) - + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) @@ -318,6 +316,120 @@ class Glm4MoeLiteModel(nn.Module): num_experts=self.config.n_routed_experts, ) + +class 
Glm4MoeLiteForCausalLM( + nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts +): + packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) + qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) + self.use_mha = config.model_type == "deepseek" or all( + dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) + ) + + if self.use_mha: + self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = Glm4MoeLiteModel( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + # Set MoE hyperparameters + self.num_moe_layers = ( + self.config.num_hidden_layers - self.config.first_k_dense_replace + ) + self.set_moe_parameters() + + def set_moe_parameters(self): + self.expert_weights = [] + + self.num_expert_groups = getattr(self.config, "n_group", 1) + + self.moe_layers = [] + self.moe_mlp_layers = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, Glm4MoeLiteDecoderLayer) + if isinstance(layer.mlp, Glm4MoeLite): + # Pick last one layer since the first ones may be dense layers. 
+ example_moe = layer.mlp + self.moe_mlp_layers.append(layer.mlp) + self.moe_layers.append(layer.mlp.experts) + + self.extract_moe_parameters(example_moe) + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + return SharedFusedMoE.make_expert_params_mapping( + self, + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + num_redundant_experts=0, + ) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: rocm_aiter_moe_shared_expert_enabled = ( rocm_aiter_ops.is_fusion_moe_shared_experts_enabled() @@ -327,12 +439,13 @@ class Glm4MoeLiteModel(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - mla_params_mapping = [ - ("fused_qkv_a_proj", "q_a_proj", 0), - ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), + mha_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), ] - - stacked_params_mapping.extend(mla_params_mapping) + if self.use_mha: + stacked_params_mapping.extend(mha_params_mapping) # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) @@ -510,128 +623,71 @@ class Glm4MoeLiteModel(nn.Module): weight_loader(param, loaded_weight) if not is_fusion_moe_shared_experts_layer: loaded_params.add(name) + opt_support_quant_method = ["GGUFLinearMethod", "UnquantizedLinearMethod", "CompressedTensorsW8A8Int8", "AWQMarlinLinearMethod"] + # add your opt here.. 
+ def inject_layer(layer, quant_method, is_mla): + q_lora_rank = getattr(layer, "q_lora_rank", None) + if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]: + if q_lora_rank is not None: + layer.q_a_proj.weight.data = torch.cat([layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_a_proj, "weight_scale"): + layer.q_a_proj.weight_scale.data = torch.cat([layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + elif not is_mla: + layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0) + if hasattr(layer.q_proj, "weight_scale"): + layer.q_proj.weight_scale.data = torch.cat([layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale], dim=0) + del layer.kv_a_proj_with_mqa.weight_scale + else: + return + del layer.kv_a_proj_with_mqa.weight + del layer.kv_a_proj_with_mqa + if is_mla: + layer.mla_attn.forward = layer.mla_attn.forward_opt + else: + layer.forward = layer.forward_opt + elif quant_method == "GGUFLinearMethod": + pass + elif quant_method == "AWQMarlinLinearMethod": + dtype = layer.kv_a_proj_with_mqa.qweight.dtype + assert dtype == torch.int32 + if layer.q_lora_rank is not None: + layer.q_a_proj.qweight.data = torch.cat([layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1) + layer.q_a_proj.scales.data = torch.cat([layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_a_proj.qzeros.data = torch.cat([layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + elif not is_mla: + layer.q_proj.weight.data = torch.cat([layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=1) + layer.q_proj.scales.data = torch.cat([layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1) + del layer.kv_a_proj_with_mqa.scales + layer.q_proj.qzeros.data = torch.cat([layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1) + del layer.kv_a_proj_with_mqa.qzeros + else: + return + del layer.kv_a_proj_with_mqa.qweight + del layer.kv_a_proj_with_mqa + if is_mla: + layer.mla_attn.forward = layer.mla_attn.forward_opt + else: + layer.forward = layer.forward_opt + else: + pass + + for _, layer in self.model.named_modules(): + if layer.__class__.__name__ in ["Glm4MoeLiteAttention","Glm4MoeLiteMLAAttention"]: + if hasattr(layer.kv_a_proj_with_mqa, "scheme"): + quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__ + else: + quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__ + if quant_method not in opt_support_quant_method: + break + + inject_layer(layer, quant_method, is_mla = layer.__class__.__name__ == "Glm4MoeLiteMLAAttention") return loaded_params -class Glm4MoeLiteForCausalLM( - nn.Module, SupportsPP, SupportsLoRA, Glm4LiteMixtureOfExperts -): - packed_modules_mapping = { - "gate_up_proj": ["gate_proj", "up_proj"], - } - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__() - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - - qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0) - qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0) - self.use_mha = config.model_type == "deepseek" or all( - dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim) - ) - - if self.use_mha: - self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"] - - # 
`packed_modules_mapping` needs to be modified before - # initializing DeepseekV2Model, as it is passed inplace to - # quantization config init and may be used to select the - # quant_method for relevant layers during initialization. - self.fuse_qkv_a_proj = ( - hasattr(config, "q_lora_rank") and config.q_lora_rank is not None - ) - if self.fuse_qkv_a_proj: - self.packed_modules_mapping["fused_qkv_a_proj"] = [ - "q_a_proj", - "kv_a_proj_with_mqa", - ] - - self.model = Glm4MoeLiteModel( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") - ) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "lm_head"), - ) - else: - self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors - ) - # Set MoE hyperparameters - self.num_moe_layers = ( - self.config.num_hidden_layers - self.config.first_k_dense_replace - ) - self.set_moe_parameters() - - def set_moe_parameters(self): - self.expert_weights = [] - - self.num_expert_groups = getattr(self.config, "n_group", 1) - - self.moe_layers = [] - self.moe_mlp_layers = [] - example_moe = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, Glm4MoeLiteDecoderLayer) - if isinstance(layer.mlp, Glm4MoeLite): - # Pick last one layer since the first ones may be dense layers. - example_moe = layer.mlp - self.moe_mlp_layers.append(layer.mlp) - self.moe_layers.append(layer.mlp.experts) - - self.extract_moe_parameters(example_moe) - - def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.model.embed_input_ids(input_ids) - - def forward( - self, - input_ids: torch.Tensor | None, - positions: torch.Tensor, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.model( - input_ids, positions, intermediate_tensors, inputs_embeds - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - ) -> torch.Tensor | None: - logits = self.logits_processor(self.lm_head, hidden_states) - return logits - - def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - return SharedFusedMoE.make_expert_params_mapping( - self, - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts, - num_redundant_experts=0, - ) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - loader = AutoWeightsLoader(self) - return loader.load_weights(weights) - - def get_spec_layer_idx_from_weight_name( config: "Glm4MoeLiteConfig", weight_name: str ) -> int | None: diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index fd70508..61f41c9 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -7,7 +7,7 @@ import torch import torch.distributed as dist from torch import nn from transformers import GptOssConfig - +import vllm.envs as envs from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import ( @@ -23,7 +23,11 @@ from 
vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ( + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_BLOCK_SIZE @@ -42,6 +46,7 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import AttentionType +from vllm.model_executor.model_loader import padding_weight_loader from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( @@ -107,7 +112,6 @@ class OAIAttention(nn.Module): quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) - self.o_proj = RowParallelLinear( input_size=self.num_attention_heads * self.head_dim, output_size=self.hidden_size, @@ -165,7 +169,14 @@ class MLPBlock(torch.nn.Module): self.hidden_size = config.hidden_size self.experts_per_token = config.num_experts_per_tok self.world_size = dist.get_world_size() if dist.is_initialized() else 1 - self.router = torch.nn.Linear(config.hidden_size, config.num_local_experts) + self.router = ReplicatedLinear( + config.hidden_size, + config.num_local_experts, + bias=True, + quant_config=None, + prefix=f"{prefix}.router", + return_bias=False, + ) assert config.intermediate_size % self.world_size == 0 self.experts = FusedMoE( num_experts=config.num_local_experts, @@ -969,8 +980,18 @@ class GptOssModel(nn.Module): weights: Iterable[tuple[str, torch.Tensor]], stacked_params_mapping: list[tuple[str, ...]], ) -> set[str]: - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() + + def handle_weight(name, weight, param_name, permute_dims=None, slice_dims=None, contiguous=True): + """Helper function to handle weight loading with optional slicing and permutation.""" + param = params_dict[param_name] + if slice_dims: + weight = weight[slice_dims] + if permute_dims: + weight = weight.permute(*permute_dims) + if contiguous: + weight = weight.contiguous() + padding_weight_loader(param, weight) + loaded_params.add(param_name) use_ep = self.parallel_config.enable_expert_parallel @@ -986,91 +1007,71 @@ class GptOssModel(nn.Module): intermediate_size = self.config.intermediate_size per_rank_intermediate_size = cdiv(intermediate_size, tp_size) - # Calculate common slicing bounds for current rank tp_rank_start = tp_rank * per_rank_intermediate_size tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size, intermediate_size) + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + pack_factor = 2 if envs.VLLM_W8A8_MOE_USE_W4A8 else 1 + w4a8_flag = envs.VLLM_W8A8_MOE_USE_W4A8 + gemm_format = envs.VLLM_W8A8_FORMAT + for name, weight in weights: - # Skip layers on other devices. + # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue - - if ".w13_weight" in name: - # Handle MLP gate and up projection weights - # Extract gate and up projection parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - narrow_weight = weight[:, :, 2 * tp_rank_start : 2 * tp_rank_end] - - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() + if ".experts.w13_weight" in name and "scale" not in name and "bias" not in name: + slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end)) + permute_dims = None if gemm_format == "NN" else (0, 2, 1) + handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims) + elif ".experts.w2_weight" in name and "scale" not in name and "bias" not in name: + slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(tp_rank_start // pack_factor, tp_rank_end // pack_factor), slice(None)) + permute_dims = None if gemm_format == "NN" else (0, 2, 1) + handle_weight(name, weight, name, permute_dims=permute_dims, slice_dims=slice_dims) + elif ".experts.gate_up_proj_scale" in name: + new_name = name.replace("gate_up_proj_scale", "w13_weight_scale") + slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end)) + permute_dims = None if w4a8_flag else (0, 2, 1) + handle_weight(name, weight, new_name, permute_dims=permute_dims, slice_dims=slice_dims, contiguous=w4a8_flag) + elif ".experts.down_proj_scale" in name: + new_name = name.replace("down_proj_scale", "w2_weight_scale") + slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else None + permute_dims = None if w4a8_flag else (0, 2, 1) + handle_weight(name, weight, new_name, permute_dims=permute_dims, slice_dims=slice_dims, contiguous=w4a8_flag) + elif ".experts.w13_bias" in name: + slice_dims = (slice(ep_rank_start, ep_rank_end), ...) if use_ep else (slice(None), slice(2 * tp_rank_start, 2 * tp_rank_end)) + handle_weight(name, weight, name, slice_dims=slice_dims, contiguous=False) + elif ".experts.w2_bias" in name: param = params_dict[name] - - param.copy_(narrow_weight) - loaded_params.add(name) - continue - elif ".w2_weight" in name: - # Handle MLP down projection weights - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, tp_rank_start:tp_rank_end, :] - narrow_weight = narrow_weight.permute(0, 2, 1).contiguous() - param = params_dict[name] - - param.copy_(narrow_weight) - loaded_params.add(name) - continue - elif ".w13_bias" in name: - # Handle MLP gate and up projection biases - # Extract gate and up projection bias parts - if use_ep: - narrow_weight = weight[ep_rank_start:ep_rank_end, ...] - else: - narrow_weight = weight[:, 2 * tp_rank_start : 2 * tp_rank_end] - - param = params_dict[name] - param.copy_(narrow_weight) - loaded_params.add(name) - continue - elif ".w2_bias" in name: - # Handle MLP down projection bias if use_ep: weight = weight[ep_rank_start:ep_rank_end, ...] 
- else: - # (only load on rank 0 to avoid duplication) - if tp_rank != 0: - weight.zero_() - param = params_dict[name] - param.copy_(weight) + elif tp_rank != 0: + weight.zero_() + param.data.copy_(weight) loaded_params.add(name) - continue elif "sinks" in name: - # Handle attention sinks (distributed across ranks) + name = name.replace("self_attn", "attn") param = params_dict[name] narrow_weight = weight.narrow(0, head_start, heads_per_rank) param.data.copy_(narrow_weight) loaded_params.add(name) - continue - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", default_weight_loader) - if weight_loader == default_weight_loader: - weight_loader(param, weight) - else: - weight_loader(param, weight, shard_id) - break + elif ("q_proj" in name or "k_proj" in name or "v_proj" in name): + shard_id = ("q" if "q_proj" in name else "k" if "k_proj" in name else "v") + name = name.replace("self_attn", "attn") + param_name = name.replace(f"{shard_id}_proj", "qkv_proj") + param = params_dict[param_name] + weight_loader = param.weight_loader + weight_loader(param, weight, loaded_shard_id=shard_id) + loaded_params.add(param_name) else: - # Handle all other weights with potential renaming if name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, weight) - loaded_params.add(name) + loaded_params.add(name) + return loaded_params def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: diff --git a/vllm/model_executor/models/hunyuan_vision.py b/vllm/model_executor/models/hunyuan_vision.py index 3f2d0e7..b6fda25 100644 --- a/vllm/model_executor/models/hunyuan_vision.py +++ b/vllm/model_executor/models/hunyuan_vision.py @@ -636,7 +636,13 @@ class HunYuanVLProcessingInfo(BaseProcessingInfo): spatial_merge_size = vision_config.spatial_merge_size mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs) - size = mm_kwargs.get("size", image_processor.size) + size = image_processor.size + if override_size := mm_kwargs.get("size"): + size = size | override_size + if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None: + size = size | {"shortest_edge": override_min_pixels} + if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None: + size = size | {"longest_edge": override_max_pixels} if do_resize: resized_height, resized_width = smart_resize( diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index 1fb0d5e..5b0dfe4 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -49,7 +49,6 @@ from .utils import ( ) from .vision import get_vision_encoder_info -EOT = "<|endofturn|>" IMAGE_TOKEN: str = "<|dummy3|>" VIDEO_TOKEN: str = "<|_unuse_missing_100270|>" diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py index f4f7ce4..6d8b45a 100644 --- a/vllm/model_executor/models/isaac.py +++ b/vllm/model_executor/models/isaac.py @@ -551,7 +551,7 @@ def process_vision_for_patches( `(num_images, height, width, channels)` for a batch. Channels are expected to be RGB. patch_size (`int`): - Edge length of square patches; implictly controls resize grid granularity. + Edge length of square patches; implicitly controls resize grid granularity. 
max_num_patches (`int`): Maximum number of patches allowed after resizing. min_num_patches (`int`, *optional*): diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index 2cb7dc4..4c43e41 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1021,7 +1021,13 @@ class KeyeProcessingInfo(BaseProcessingInfo): temporal_patch_size = 1 mm_kwargs = self.ctx.get_merged_mm_kwargs(mm_kwargs) - size = mm_kwargs.get("size", image_processor.size) + size = image_processor.size + if override_size := mm_kwargs.get("size"): + size = size | override_size + if (override_min_pixels := mm_kwargs.get("min_pixels")) is not None: + size = size | {"min_pixels": override_min_pixels} + if (override_max_pixels := mm_kwargs.get("max_pixels")) is not None: + size = size | {"max_pixels": override_max_pixels} if do_resize: resized_height, resized_width = smart_resize( diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 4492b57..529233c 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -654,4 +654,4 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): self, skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None), ) - return loader.load_weights(weights) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index e61e9d0..522f4a2 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -230,4 +230,4 @@ class MiniCPM3ForCausalLM(MiniCPMForCausalLM): } def _init_model(self, *, vllm_config: VllmConfig, prefix: str = ""): - return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) + return MiniCPM3Model(vllm_config=vllm_config, prefix=prefix) \ No newline at end of file diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index c4591ac..0761393 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -88,7 +88,7 @@ class MiniMaxM2MoE(nn.Module): self.use_routing_bias = getattr(config, "use_routing_bias", False) if self.use_routing_bias: self.e_score_correction_bias = nn.Parameter( - torch.empty(config.num_local_experts, dtype=torch.float32) + torch.empty(config.num_local_experts, dtype=torch.get_default_dtype()) ) self.e_score_correction_bias.weight_loader = ( MiniMaxM2MoE.ebias_weight_loader @@ -107,13 +107,14 @@ class MiniMaxM2MoE(nn.Module): renormalize=True, quant_config=quant_config, prefix=f"{prefix}.experts", + router_logits_dtype=torch.float32, ) self.gate = ReplicatedLinear( config.hidden_size, config.num_local_experts, bias=False, - # params_dtype=torch.float32, + params_dtype=torch.float32, quant_config=None, prefix=f"{prefix}.gate", ) @@ -121,7 +122,6 @@ class MiniMaxM2MoE(nn.Module): @staticmethod def ebias_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None: assert param.size() == loaded_weight.size() - # param.data.copy_(loaded_weight.to(torch.float32)) param.data.copy_(loaded_weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -129,10 +129,9 @@ class MiniMaxM2MoE(nn.Module): hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (num_tokens, n_experts) - # router_logits, _ = self.gate(hidden_states.to(torch.float32)) - router_logits, _ = self.gate(hidden_states) + router_logits, _ = 
self.gate(hidden_states.to(torch.float32)) final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits + hidden_states=hidden_states, router_logits=router_logits.to(hidden_states.dtype) ) final_hidden_states = final_hidden_states if self.tp_size > 1: diff --git a/vllm/model_executor/models/nano_nemotron_vl.py b/vllm/model_executor/models/nano_nemotron_vl.py index 46cf7fe..82422e8 100644 --- a/vllm/model_executor/models/nano_nemotron_vl.py +++ b/vllm/model_executor/models/nano_nemotron_vl.py @@ -44,6 +44,7 @@ from vllm.model_executor.models.internvl import ( ) from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM +from vllm.model_executor.models.parakeet import ParakeetExtractor, ProjectedParakeet from vllm.model_executor.models.radio import RadioModel, calc_seq_lens from vllm.model_executor.models.utils import ( init_vllm_registered_model, @@ -55,12 +56,14 @@ from vllm.multimodal.evs import ( compute_retention_mask, ) from vllm.multimodal.inputs import ( + AudioItem, MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems, VideoItem, ) from vllm.multimodal.parse import ( + AudioProcessorItems, ImageEmbeddingItems, ImageProcessorItems, ImageSize, @@ -91,9 +94,29 @@ Image.MAX_IMAGE_PIXELS = None # Disable the limit entirely # Alternative: Set a specific higher limit # Image.MAX_IMAGE_PIXELS = 300000000 # ~300M pixels + +class NanoNemotronVLAudioFeatureInputs(TensorSchema): + """ + Dimensions: + - b: Number of audio clips + - t: Audio feature length + - f: Feature size (mel bins) + """ + + type: Literal["audio_features"] = "audio_features" + input_audio_features: Annotated[torch.Tensor, TensorShape("b", "t", "f")] + feature_attention_mask: Annotated[torch.Tensor, TensorShape("b", "t")] + audio_feature_lengths: Annotated[torch.Tensor, TensorShape("b")] + + +MAX_AUDIO_LEN_S = 10 * 60 # 10 minutes + IMG_START = "" IMG_END = "" IMG_CONTEXT = "" +AUDIO_START = "" +AUDIO_END = "" +AUDIO_CONTEXT = "" # Profiling # MAX_FRAMES = 16 @@ -820,6 +843,11 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): self.video_token = video_token self.video_pruning_rate = video_pruning_rate + self.audio_extractor: ParakeetExtractor | None = None + raw_sound_config = getattr(config, "sound_config", None) + if raw_sound_config is not None: + self.audio_extractor = ParakeetExtractor(raw_sound_config) + # Pre-tokenize special tokens for video processing # to avoid repeated tokenization self._img_start_token_ids = tokenizer.encode( @@ -952,11 +980,53 @@ class NanoNemotronVLProcessor(BaseNanoNemotronVLProcessor): text = [t.replace("